From 30c1ef07dcd5fae831d6a71b74069e86fda4b544 Mon Sep 17 00:00:00 2001 From: Jimmy Cuadra Date: Mon, 5 Aug 2019 15:55:25 -0700 Subject: [PATCH] impl Deserialize m.key.verification.start --- src/key/verification/start.rs | 171 ++++++++++++++++++++++++++++------ 1 file changed, 143 insertions(+), 28 deletions(-) diff --git a/src/key/verification/start.rs b/src/key/verification/start.rs index 505cff8a..4ca6690f 100644 --- a/src/key/verification/start.rs +++ b/src/key/verification/start.rs @@ -10,7 +10,7 @@ use super::{ HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString, VerificationMethod, }; -use crate::{Event, EventType, InnerInvalidEvent, InvalidEvent, InvalidInput}; +use crate::{Event, EventResult, EventType, InnerInvalidEvent, InvalidEvent, InvalidInput}; /// Begins an SAS key verification process. /// @@ -33,6 +33,36 @@ pub enum StartEventContent { __Nonexhaustive, } +impl<'de> Deserialize<'de> for EventResult { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let json = serde_json::Value::deserialize(deserializer)?; + + let raw: raw::StartEvent = match serde_json::from_value(json.clone()) { + Ok(raw) => raw, + Err(error) => { + return Ok(EventResult::Err(InvalidEvent( + InnerInvalidEvent::Validation { + json, + message: error.to_string(), + }, + ))); + } + }; + + let content = match raw.content { + raw::StartEventContent::MSasV1(content) => StartEventContent::MSasV1(content), + raw::StartEventContent::__Nonexhaustive => { + panic!("__Nonexhaustive enum variant is not intended for use."); + } + }; + + Ok(EventResult::Ok(StartEvent { content })) + } +} + impl FromStr for StartEvent { type Err = InvalidEvent; @@ -93,6 +123,76 @@ impl_event!( EventType::KeyVerificationStart ); +impl<'de> Deserialize<'de> for EventResult { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let json = serde_json::Value::deserialize(deserializer)?; + + let raw: raw::StartEventContent = match serde_json::from_value(json.clone()) { + Ok(raw) => raw, + Err(error) => { + return Ok(EventResult::Err(InvalidEvent( + InnerInvalidEvent::Validation { + json, + message: error.to_string(), + }, + ))); + } + }; + + match raw { + raw::StartEventContent::MSasV1(content) => { + if !content + .key_agreement_protocols + .contains(&KeyAgreementProtocol::Curve25519) + { + return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation { + json, + message: "`key_agreement_protocols` must contain at least `KeyAgreementProtocol::Curve25519`".to_string(), + }))); + } + + if !content.hashes.contains(&HashAlgorithm::Sha256) { + return Ok(EventResult::Err(InvalidEvent( + InnerInvalidEvent::Validation { + json, + message: "`hashes` must contain at least `HashAlgorithm::Sha256`" + .to_string(), + }, + ))); + } + + if !content + .message_authentication_codes + .contains(&MessageAuthenticationCode::HkdfHmacSha256) + { + return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation { + json, + message: "`message_authentication_codes` must contain at least `MessageAuthenticationCode::HkdfHmacSha256`".to_string(), + }))); + } + + if !content + .short_authentication_string + .contains(&ShortAuthenticationString::Decimal) + { + return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation { + json, + message: "`short_authentication_string` must contain at least `ShortAuthenticationString::Decimal`".to_string(), + }))); + } + + Ok(EventResult::Ok(StartEventContent::MSasV1(content))) + } + raw::StartEventContent::__Nonexhaustive => { + panic!("__Nonexhaustive enum variant is not intended for use."); + } + } + } +} + impl FromStr for StartEventContent { type Err = InvalidEvent; @@ -390,7 +490,7 @@ mod tests { use serde_json::to_string; use super::{ - HashAlgorithm, KeyAgreementProtocol, MSasV1Content, MSasV1ContentOptions, + EventResult, HashAlgorithm, KeyAgreementProtocol, MSasV1Content, MSasV1ContentOptions, MessageAuthenticationCode, ShortAuthenticationString, StartEvent, StartEventContent, }; @@ -498,8 +598,11 @@ mod tests { // Deserialize the content struct separately to verify `FromStr` is implemented for it. assert_eq!( - r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","hashes":["sha256"],"key_agreement_protocols":["curve25519"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]}"# - .parse::() + serde_json::from_str::>( + r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","hashes":["sha256"],"key_agreement_protocols":["curve25519"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]}"# + ) + .unwrap() + .into() .unwrap(), key_verification_start_content ); @@ -509,8 +612,11 @@ mod tests { }; assert_eq!( - r#"{"content":{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]},"type":"m.key.verification.start"}"# - .parse::() + serde_json::from_str::>( + r#"{"content":{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]},"type":"m.key.verification.start"}"# + ) + .unwrap() + .into() .unwrap(), key_verification_start ) @@ -518,17 +624,18 @@ mod tests { #[test] fn deserialization_failure() { - // Invalid JSON - let error = "{".parse::().err().unwrap(); - - // No `serde_json::Value` available if deserialization failed. - assert!(error.json().is_none()); + // Ensure that invalid JSON creates a `serde_json::Error` and not `InvalidEvent` + assert!(serde_json::from_str::>("{").is_err()); } #[test] fn deserialization_structure_mismatch() { // Missing several required fields. - let error = r#"{"from_device":"123"}"#.parse::().err().unwrap(); + let error = + serde_json::from_str::>(r#"{"from_device":"123"}"#) + .unwrap() + .into() + .unwrap_err(); assert!(error.message().contains("missing field")); assert!(error.json().is_some()); @@ -537,10 +644,12 @@ mod tests { #[test] fn deserialization_validation_missing_required_key_agreement_protocols() { let error = - r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":[],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]}"# - .parse::() - .err() - .unwrap(); + serde_json::from_str::>( + r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":[],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]}"# + ) + .unwrap() + .into() + .unwrap_err(); assert!(error.message().contains("key_agreement_protocols")); assert!(error.json().is_some()); @@ -549,10 +658,12 @@ mod tests { #[test] fn deserialization_validation_missing_required_hashes() { let error = - r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":[],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]}"# - .parse::() - .err() - .unwrap(); + serde_json::from_str::>( + r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":[],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]}"# + ) + .unwrap() + .into() + .unwrap_err(); assert!(error.message().contains("hashes")); assert!(error.json().is_some()); @@ -561,10 +672,12 @@ mod tests { #[test] fn deserialization_validation_missing_required_message_authentication_codes() { let error = - r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":["sha256"],"message_authentication_codes":[],"short_authentication_string":["decimal"]}"# - .parse::() - .err() - .unwrap(); + serde_json::from_str::>( + r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":["sha256"],"message_authentication_codes":[],"short_authentication_string":["decimal"]}"# + ) + .unwrap() + .into() + .unwrap_err(); assert!(error.message().contains("message_authentication_codes")); assert!(error.json().is_some()); @@ -573,10 +686,12 @@ mod tests { #[test] fn deserialization_validation_missing_required_short_authentication_string() { let error = - r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":[]}"# - .parse::() - .err() - .unwrap(); + serde_json::from_str::>( + r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":[]}"# + ) + .unwrap() + .into() + .unwrap_err(); assert!(error.message().contains("short_authentication_string")); assert!(error.json().is_some());