diff --git a/ruma-events/CHANGELOG.md b/ruma-events/CHANGELOG.md index 976ce1a9..cf85ecd6 100644 --- a/ruma-events/CHANGELOG.md +++ b/ruma-events/CHANGELOG.md @@ -4,6 +4,8 @@ Breaking changes: +* Change the structure of `StartEventContent` so that we can access transaction + ids without the need to understand the concrete method. * Change `get_message_events` limit field type from `Option` to `UInt` * Add `alt_aliases` to `CanonicalAliasEventContent` * Replace `format` and `formatted_body` fields in `TextMessagEventContent`, diff --git a/ruma-events/src/key/verification/start.rs b/ruma-events/src/key/verification/start.rs index b29b236a..54c0ef8d 100644 --- a/ruma-events/src/key/verification/start.rs +++ b/ruma-events/src/key/verification/start.rs @@ -1,8 +1,11 @@ //! Types for the *m.key.verification.start* event. +use std::collections::BTreeMap; + use ruma_events_macros::BasicEventContent; use ruma_identifiers::DeviceId; use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; use super::{ HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString, @@ -16,19 +19,8 @@ pub type StartEvent = BasicEvent; /// The payload of an *m.key.verification.start* event. #[derive(Clone, Debug, Deserialize, Serialize, BasicEventContent)] -#[non_exhaustive] #[ruma_event(type = "m.key.verification.start")] -#[serde(tag = "method")] -pub enum StartEventContent { - /// The *m.sas.v1* verification method. - #[serde(rename = "m.sas.v1")] - MSasV1(MSasV1Content), -} - -/// The payload of an *m.key.verification.start* event using the *m.sas.v1* method. -#[derive(Clone, Debug, Deserialize, Serialize)] -#[non_exhaustive] -pub struct MSasV1Content { +pub struct StartEventContent { /// The device ID which is initiating the process. pub from_device: Box, @@ -39,6 +31,40 @@ pub struct MSasV1Content { /// from a request. pub transaction_id: String, + /// Method specific content. + #[serde(flatten)] + pub method: StartMethod, +} + +/// An enum representing the different method specific +/// *m.key.verification.start* content. +#[derive(Clone, Debug, Deserialize, Serialize)] +#[non_exhaustive] +#[serde(untagged)] +pub enum StartMethod { + /// The *m.sas.v1* verification method. + MSasV1(MSasV1Content), + + /// Any unknown start method. + Custom(CustomContent), +} + +/// Method specific content of a unknown key verification method. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct CustomContent { + /// The name of the method. + pub method: String, + + /// The additional fields that the method contains. + #[serde(flatten)] + pub fields: BTreeMap, +} + +/// The payload of an *m.key.verification.start* event using the *m.sas.v1* method. +#[derive(Clone, Debug, Deserialize, Serialize)] +#[non_exhaustive] +#[serde(rename = "m.sas.v1", tag = "method")] +pub struct MSasV1Content { /// The key agreement protocols the sending device understands. /// /// Must include at least `Curve25519` or `Curve25519HkdfSha256`. @@ -63,16 +89,6 @@ pub struct MSasV1Content { /// Options for creating an `MSasV1Content` with `MSasV1Content::new`. #[derive(Clone, Debug, Deserialize)] pub struct MSasV1ContentOptions { - /// The device ID which is initiating the process. - pub from_device: Box, - - /// An opaque identifier for the verification process. - /// - /// Must be unique with respect to the devices involved. Must be the same as the - /// `transaction_id` given in the *m.key.verification.request* if this process is originating - /// from a request. - pub transaction_id: String, - /// The key agreement protocols the sending device understands. /// /// Must include at least `curve25519`. @@ -146,8 +162,6 @@ impl MSasV1Content { } Ok(Self { - from_device: options.from_device, - transaction_id: options.transaction_id, key_agreement_protocols: options.key_agreement_protocols, hashes: options.hashes, message_authentication_codes: options.message_authentication_codes, @@ -158,20 +172,23 @@ impl MSasV1Content { #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use matches::assert_matches; - use serde_json::{from_value as from_json_value, json, to_value as to_json_value}; + use serde_json::{ + from_value as from_json_value, json, to_value as to_json_value, Value as JsonValue, + }; use super::{ - HashAlgorithm, KeyAgreementProtocol, MSasV1Content, MSasV1ContentOptions, + CustomContent, HashAlgorithm, KeyAgreementProtocol, MSasV1Content, MSasV1ContentOptions, MessageAuthenticationCode, ShortAuthenticationString, StartEvent, StartEventContent, + StartMethod, }; use ruma_common::Raw; #[test] fn invalid_m_sas_v1_content_missing_required_key_agreement_protocols() { let error = MSasV1Content::new(MSasV1ContentOptions { - from_device: "123".into(), - transaction_id: "456".into(), hashes: vec![HashAlgorithm::Sha256], key_agreement_protocols: vec![], message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256], @@ -186,8 +203,6 @@ mod tests { #[test] fn invalid_m_sas_v1_content_missing_required_hashes() { let error = MSasV1Content::new(MSasV1ContentOptions { - from_device: "123".into(), - transaction_id: "456".into(), hashes: vec![], key_agreement_protocols: vec![KeyAgreementProtocol::Curve25519], message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256], @@ -202,8 +217,6 @@ mod tests { #[test] fn invalid_m_sas_v1_content_missing_required_message_authentication_codes() { let error = MSasV1Content::new(MSasV1ContentOptions { - from_device: "123".into(), - transaction_id: "456".into(), hashes: vec![HashAlgorithm::Sha256], key_agreement_protocols: vec![KeyAgreementProtocol::Curve25519], message_authentication_codes: vec![], @@ -218,8 +231,6 @@ mod tests { #[test] fn invalid_m_sas_v1_content_missing_required_short_authentication_string() { let error = MSasV1Content::new(MSasV1ContentOptions { - from_device: "123".into(), - transaction_id: "456".into(), hashes: vec![HashAlgorithm::Sha256], key_agreement_protocols: vec![KeyAgreementProtocol::Curve25519], message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256], @@ -233,17 +244,19 @@ mod tests { #[test] fn serialization() { - let key_verification_start_content = StartEventContent::MSasV1( - MSasV1Content::new(MSasV1ContentOptions { - from_device: "123".into(), - transaction_id: "456".into(), - hashes: vec![HashAlgorithm::Sha256], - key_agreement_protocols: vec![KeyAgreementProtocol::Curve25519], - message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256], - short_authentication_string: vec![ShortAuthenticationString::Decimal], - }) - .unwrap(), - ); + let key_verification_start_content = StartEventContent { + from_device: "123".into(), + transaction_id: "456".into(), + method: StartMethod::MSasV1( + MSasV1Content::new(MSasV1ContentOptions { + hashes: vec![HashAlgorithm::Sha256], + key_agreement_protocols: vec![KeyAgreementProtocol::Curve25519], + message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256], + short_authentication_string: vec![ShortAuthenticationString::Decimal], + }) + .unwrap(), + ), + }; let key_verification_start = StartEvent { content: key_verification_start_content }; @@ -261,6 +274,31 @@ mod tests { }); assert_eq!(to_json_value(&key_verification_start).unwrap(), json_data); + + let json_data = json!({ + "content": { + "from_device": "123", + "transaction_id": "456", + "method": "m.sas.custom", + "test": "field", + }, + "type": "m.key.verification.start" + }); + + let key_verification_start_content = StartEventContent { + from_device: "123".into(), + transaction_id: "456".into(), + method: StartMethod::Custom(CustomContent { + method: "m.sas.custom".to_owned(), + fields: vec![("test".to_string(), JsonValue::from("field"))] + .into_iter() + .collect::>(), + }), + }; + + let key_verification_start = StartEvent { content: key_verification_start_content }; + + assert_eq!(to_json_value(&key_verification_start).unwrap(), json_data); } #[test] @@ -281,14 +319,16 @@ mod tests { .unwrap() .deserialize() .unwrap(), - StartEventContent::MSasV1(MSasV1Content { + StartEventContent { from_device, transaction_id, - hashes, - key_agreement_protocols, - message_authentication_codes, - short_authentication_string, - }) if from_device == "123" + method: StartMethod::MSasV1(MSasV1Content { + hashes, + key_agreement_protocols, + message_authentication_codes, + short_authentication_string, + }) + } if from_device == "123" && transaction_id == "456" && hashes == vec![HashAlgorithm::Sha256] && key_agreement_protocols == vec![KeyAgreementProtocol::Curve25519] @@ -315,21 +355,53 @@ mod tests { .deserialize() .unwrap(), StartEvent { - content: StartEventContent::MSasV1(MSasV1Content { + content: StartEventContent { from_device, transaction_id, - hashes, - key_agreement_protocols, - message_authentication_codes, - short_authentication_string, - }) + method: StartMethod::MSasV1(MSasV1Content { + hashes, + key_agreement_protocols, + message_authentication_codes, + short_authentication_string, + }) + } } if from_device == "123" && transaction_id == "456" && hashes == vec![HashAlgorithm::Sha256] && key_agreement_protocols == vec![KeyAgreementProtocol::Curve25519] && message_authentication_codes == vec![MessageAuthenticationCode::HkdfHmacSha256] && short_authentication_string == vec![ShortAuthenticationString::Decimal] - ) + ); + + let json = json!({ + "content": { + "from_device": "123", + "transaction_id": "456", + "method": "m.sas.custom", + "test": "field", + }, + "type": "m.key.verification.start" + }); + + assert_matches!( + from_json_value::>(json) + .unwrap() + .deserialize() + .unwrap(), + StartEvent { + content: StartEventContent { + from_device, + transaction_id, + method: StartMethod::Custom(CustomContent { + method, + fields, + }) + } + } if from_device == "123" + && transaction_id == "456" + && method == "m.sas.custom" + && fields.get("test").unwrap() == &JsonValue::from("field") + ); } #[test]