Ensure validation logic for m.key.verification.start is run when deserializing the event, not just the content.

This commit is contained in:
Jimmy Cuadra 2019-08-06 01:25:29 -07:00
parent 4984868e21
commit 685a61954c

View File

@ -52,10 +52,15 @@ impl<'de> Deserialize<'de> for EventResult<StartEvent> {
} }
}; };
let content = match raw.content { let content = match StartEventContent::from_raw(raw.content) {
raw::StartEventContent::MSasV1(content) => StartEventContent::MSasV1(content), Ok(content) => content,
raw::StartEventContent::__Nonexhaustive => { Err(error) => {
panic!("__Nonexhaustive enum variant is not intended for use."); return Ok(EventResult::Err(InvalidEvent(
InnerInvalidEvent::Validation {
json,
message: error.to_string(),
},
)));
} }
}; };
@ -123,6 +128,44 @@ impl_event!(
EventType::KeyVerificationStart EventType::KeyVerificationStart
); );
impl StartEventContent {
fn from_raw(raw: raw::StartEventContent) -> Result<Self, &'static str> {
match raw {
raw::StartEventContent::MSasV1(content) => {
if !content
.key_agreement_protocols
.contains(&KeyAgreementProtocol::Curve25519)
{
return Err("`key_agreement_protocols` must contain at least `KeyAgreementProtocol::Curve25519`");
}
if !content.hashes.contains(&HashAlgorithm::Sha256) {
return Err("`hashes` must contain at least `HashAlgorithm::Sha256`");
}
if !content
.message_authentication_codes
.contains(&MessageAuthenticationCode::HkdfHmacSha256)
{
return Err("`message_authentication_codes` must contain at least `MessageAuthenticationCode::HkdfHmacSha256`");
}
if !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal)
{
return Err("`short_authentication_string` must contain at least `ShortAuthenticationString::Decimal`");
}
Ok(StartEventContent::MSasV1(content))
}
raw::StartEventContent::__Nonexhaustive => {
panic!("__Nonexhaustive enum variant is not intended for use.");
}
}
}
}
impl<'de> Deserialize<'de> for EventResult<StartEventContent> { impl<'de> Deserialize<'de> for EventResult<StartEventContent> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where where
@ -142,53 +185,14 @@ impl<'de> Deserialize<'de> for EventResult<StartEventContent> {
} }
}; };
match raw { match StartEventContent::from_raw(raw) {
raw::StartEventContent::MSasV1(content) => { Ok(content) => Ok(EventResult::Ok(content)),
if !content Err(error) => Ok(EventResult::Err(InvalidEvent(
.key_agreement_protocols InnerInvalidEvent::Validation {
.contains(&KeyAgreementProtocol::Curve25519) json,
{ message: error.to_string(),
return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation { },
json, ))),
message: "`key_agreement_protocols` must contain at least `KeyAgreementProtocol::Curve25519`".to_string(),
})));
}
if !content.hashes.contains(&HashAlgorithm::Sha256) {
return Ok(EventResult::Err(InvalidEvent(
InnerInvalidEvent::Validation {
json,
message: "`hashes` must contain at least `HashAlgorithm::Sha256`"
.to_string(),
},
)));
}
if !content
.message_authentication_codes
.contains(&MessageAuthenticationCode::HkdfHmacSha256)
{
return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation {
json,
message: "`message_authentication_codes` must contain at least `MessageAuthenticationCode::HkdfHmacSha256`".to_string(),
})));
}
if !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal)
{
return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation {
json,
message: "`short_authentication_string` must contain at least `ShortAuthenticationString::Decimal`".to_string(),
})));
}
Ok(EventResult::Ok(StartEventContent::MSasV1(content)))
}
raw::StartEventContent::__Nonexhaustive => {
panic!("__Nonexhaustive enum variant is not intended for use.");
}
} }
} }
} }
@ -696,4 +700,19 @@ mod tests {
assert!(error.message().contains("short_authentication_string")); assert!(error.message().contains("short_authentication_string"));
assert!(error.json().is_some()); assert!(error.json().is_some());
} }
#[test]
fn deserialization_of_event_validates_content() {
// This JSON is missing the required value of "curve25519" for "key_agreement_protocols".
let error =
serde_json::from_str::<EventResult<StartEvent>>(
r#"{"content":{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":[],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]},"type":"m.key.verification.start"}"#
)
.unwrap()
.into_result()
.unwrap_err();
assert!(error.message().contains("key_agreement_protocols"));
assert!(error.json().is_some());
}
} }