diff --git a/src/key/verification/start.rs b/src/key/verification/start.rs index be40dc86..f7c4c019 100644 --- a/src/key/verification/start.rs +++ b/src/key/verification/start.rs @@ -52,10 +52,15 @@ impl<'de> Deserialize<'de> for EventResult { } }; - let content = match raw.content { - raw::StartEventContent::MSasV1(content) => StartEventContent::MSasV1(content), - raw::StartEventContent::__Nonexhaustive => { - panic!("__Nonexhaustive enum variant is not intended for use."); + let content = match StartEventContent::from_raw(raw.content) { + Ok(content) => content, + Err(error) => { + return Ok(EventResult::Err(InvalidEvent( + InnerInvalidEvent::Validation { + json, + message: error.to_string(), + }, + ))); } }; @@ -123,6 +128,44 @@ impl_event!( EventType::KeyVerificationStart ); +impl StartEventContent { + fn from_raw(raw: raw::StartEventContent) -> Result { + match raw { + raw::StartEventContent::MSasV1(content) => { + if !content + .key_agreement_protocols + .contains(&KeyAgreementProtocol::Curve25519) + { + return Err("`key_agreement_protocols` must contain at least `KeyAgreementProtocol::Curve25519`"); + } + + if !content.hashes.contains(&HashAlgorithm::Sha256) { + return Err("`hashes` must contain at least `HashAlgorithm::Sha256`"); + } + + if !content + .message_authentication_codes + .contains(&MessageAuthenticationCode::HkdfHmacSha256) + { + return Err("`message_authentication_codes` must contain at least `MessageAuthenticationCode::HkdfHmacSha256`"); + } + + if !content + .short_authentication_string + .contains(&ShortAuthenticationString::Decimal) + { + return Err("`short_authentication_string` must contain at least `ShortAuthenticationString::Decimal`"); + } + + Ok(StartEventContent::MSasV1(content)) + } + raw::StartEventContent::__Nonexhaustive => { + panic!("__Nonexhaustive enum variant is not intended for use."); + } + } + } +} + impl<'de> Deserialize<'de> for EventResult { fn deserialize(deserializer: D) -> Result where @@ -142,53 +185,14 @@ impl<'de> Deserialize<'de> for EventResult { } }; - match raw { - raw::StartEventContent::MSasV1(content) => { - if !content - .key_agreement_protocols - .contains(&KeyAgreementProtocol::Curve25519) - { - return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation { - json, - message: "`key_agreement_protocols` must contain at least `KeyAgreementProtocol::Curve25519`".to_string(), - }))); - } - - if !content.hashes.contains(&HashAlgorithm::Sha256) { - return Ok(EventResult::Err(InvalidEvent( - InnerInvalidEvent::Validation { - json, - message: "`hashes` must contain at least `HashAlgorithm::Sha256`" - .to_string(), - }, - ))); - } - - if !content - .message_authentication_codes - .contains(&MessageAuthenticationCode::HkdfHmacSha256) - { - return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation { - json, - message: "`message_authentication_codes` must contain at least `MessageAuthenticationCode::HkdfHmacSha256`".to_string(), - }))); - } - - if !content - .short_authentication_string - .contains(&ShortAuthenticationString::Decimal) - { - return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation { - json, - message: "`short_authentication_string` must contain at least `ShortAuthenticationString::Decimal`".to_string(), - }))); - } - - Ok(EventResult::Ok(StartEventContent::MSasV1(content))) - } - raw::StartEventContent::__Nonexhaustive => { - panic!("__Nonexhaustive enum variant is not intended for use."); - } + match StartEventContent::from_raw(raw) { + Ok(content) => Ok(EventResult::Ok(content)), + Err(error) => Ok(EventResult::Err(InvalidEvent( + InnerInvalidEvent::Validation { + json, + message: error.to_string(), + }, + ))), } } } @@ -696,4 +700,19 @@ mod tests { assert!(error.message().contains("short_authentication_string")); assert!(error.json().is_some()); } + + #[test] + fn deserialization_of_event_validates_content() { + // This JSON is missing the required value of "curve25519" for "key_agreement_protocols". + let error = + serde_json::from_str::>( + r#"{"content":{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":[],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]},"type":"m.key.verification.start"}"# + ) + .unwrap() + .into_result() + .unwrap_err(); + + assert!(error.message().contains("key_agreement_protocols")); + assert!(error.json().is_some()); + } }