Ensure validation logic for m.key.verification.start is run when deserializing the event, not just the content.

This commit is contained in:
Jimmy Cuadra 2019-08-06 01:25:29 -07:00
parent 4984868e21
commit 685a61954c

View File

@ -52,10 +52,15 @@ impl<'de> Deserialize<'de> for EventResult<StartEvent> {
}
};
let content = match raw.content {
raw::StartEventContent::MSasV1(content) => StartEventContent::MSasV1(content),
raw::StartEventContent::__Nonexhaustive => {
panic!("__Nonexhaustive enum variant is not intended for use.");
let content = match StartEventContent::from_raw(raw.content) {
Ok(content) => content,
Err(error) => {
return Ok(EventResult::Err(InvalidEvent(
InnerInvalidEvent::Validation {
json,
message: error.to_string(),
},
)));
}
};
@ -123,6 +128,44 @@ impl_event!(
EventType::KeyVerificationStart
);
impl StartEventContent {
fn from_raw(raw: raw::StartEventContent) -> Result<Self, &'static str> {
match raw {
raw::StartEventContent::MSasV1(content) => {
if !content
.key_agreement_protocols
.contains(&KeyAgreementProtocol::Curve25519)
{
return Err("`key_agreement_protocols` must contain at least `KeyAgreementProtocol::Curve25519`");
}
if !content.hashes.contains(&HashAlgorithm::Sha256) {
return Err("`hashes` must contain at least `HashAlgorithm::Sha256`");
}
if !content
.message_authentication_codes
.contains(&MessageAuthenticationCode::HkdfHmacSha256)
{
return Err("`message_authentication_codes` must contain at least `MessageAuthenticationCode::HkdfHmacSha256`");
}
if !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal)
{
return Err("`short_authentication_string` must contain at least `ShortAuthenticationString::Decimal`");
}
Ok(StartEventContent::MSasV1(content))
}
raw::StartEventContent::__Nonexhaustive => {
panic!("__Nonexhaustive enum variant is not intended for use.");
}
}
}
}
impl<'de> Deserialize<'de> for EventResult<StartEventContent> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
@ -142,53 +185,14 @@ impl<'de> Deserialize<'de> for EventResult<StartEventContent> {
}
};
match raw {
raw::StartEventContent::MSasV1(content) => {
if !content
.key_agreement_protocols
.contains(&KeyAgreementProtocol::Curve25519)
{
return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation {
json,
message: "`key_agreement_protocols` must contain at least `KeyAgreementProtocol::Curve25519`".to_string(),
})));
}
if !content.hashes.contains(&HashAlgorithm::Sha256) {
return Ok(EventResult::Err(InvalidEvent(
InnerInvalidEvent::Validation {
json,
message: "`hashes` must contain at least `HashAlgorithm::Sha256`"
.to_string(),
},
)));
}
if !content
.message_authentication_codes
.contains(&MessageAuthenticationCode::HkdfHmacSha256)
{
return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation {
json,
message: "`message_authentication_codes` must contain at least `MessageAuthenticationCode::HkdfHmacSha256`".to_string(),
})));
}
if !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal)
{
return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation {
json,
message: "`short_authentication_string` must contain at least `ShortAuthenticationString::Decimal`".to_string(),
})));
}
Ok(EventResult::Ok(StartEventContent::MSasV1(content)))
}
raw::StartEventContent::__Nonexhaustive => {
panic!("__Nonexhaustive enum variant is not intended for use.");
}
match StartEventContent::from_raw(raw) {
Ok(content) => Ok(EventResult::Ok(content)),
Err(error) => Ok(EventResult::Err(InvalidEvent(
InnerInvalidEvent::Validation {
json,
message: error.to_string(),
},
))),
}
}
}
@ -696,4 +700,19 @@ mod tests {
assert!(error.message().contains("short_authentication_string"));
assert!(error.json().is_some());
}
#[test]
fn deserialization_of_event_validates_content() {
// This JSON is missing the required value of "curve25519" for "key_agreement_protocols".
let error =
serde_json::from_str::<EventResult<StartEvent>>(
r#"{"content":{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":[],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]},"type":"m.key.verification.start"}"#
)
.unwrap()
.into_result()
.unwrap_err();
assert!(error.message().contains("key_agreement_protocols"));
assert!(error.json().is_some());
}
}