Ensure validation logic for m.key.verification.start is run when deserializing the event, not just the content.
This commit is contained in:
		
							parent
							
								
									4984868e21
								
							
						
					
					
						commit
						685a61954c
					
				| @ -52,10 +52,15 @@ impl<'de> Deserialize<'de> for EventResult<StartEvent> { | |||||||
|             } |             } | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let content = match raw.content { |         let content = match StartEventContent::from_raw(raw.content) { | ||||||
|             raw::StartEventContent::MSasV1(content) => StartEventContent::MSasV1(content), |             Ok(content) => content, | ||||||
|             raw::StartEventContent::__Nonexhaustive => { |             Err(error) => { | ||||||
|                 panic!("__Nonexhaustive enum variant is not intended for use."); |                 return Ok(EventResult::Err(InvalidEvent( | ||||||
|  |                     InnerInvalidEvent::Validation { | ||||||
|  |                         json, | ||||||
|  |                         message: error.to_string(), | ||||||
|  |                     }, | ||||||
|  |                 ))); | ||||||
|             } |             } | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
| @ -123,6 +128,44 @@ impl_event!( | |||||||
|     EventType::KeyVerificationStart |     EventType::KeyVerificationStart | ||||||
| ); | ); | ||||||
| 
 | 
 | ||||||
|  | impl StartEventContent { | ||||||
|  |     fn from_raw(raw: raw::StartEventContent) -> Result<Self, &'static str> { | ||||||
|  |         match raw { | ||||||
|  |             raw::StartEventContent::MSasV1(content) => { | ||||||
|  |                 if !content | ||||||
|  |                     .key_agreement_protocols | ||||||
|  |                     .contains(&KeyAgreementProtocol::Curve25519) | ||||||
|  |                 { | ||||||
|  |                     return Err("`key_agreement_protocols` must contain at least `KeyAgreementProtocol::Curve25519`"); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 if !content.hashes.contains(&HashAlgorithm::Sha256) { | ||||||
|  |                     return Err("`hashes` must contain at least `HashAlgorithm::Sha256`"); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 if !content | ||||||
|  |                     .message_authentication_codes | ||||||
|  |                     .contains(&MessageAuthenticationCode::HkdfHmacSha256) | ||||||
|  |                 { | ||||||
|  |                     return Err("`message_authentication_codes` must contain at least `MessageAuthenticationCode::HkdfHmacSha256`"); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 if !content | ||||||
|  |                     .short_authentication_string | ||||||
|  |                     .contains(&ShortAuthenticationString::Decimal) | ||||||
|  |                 { | ||||||
|  |                     return Err("`short_authentication_string` must contain at least `ShortAuthenticationString::Decimal`"); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 Ok(StartEventContent::MSasV1(content)) | ||||||
|  |             } | ||||||
|  |             raw::StartEventContent::__Nonexhaustive => { | ||||||
|  |                 panic!("__Nonexhaustive enum variant is not intended for use."); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| impl<'de> Deserialize<'de> for EventResult<StartEventContent> { | impl<'de> Deserialize<'de> for EventResult<StartEventContent> { | ||||||
|     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> |     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||||||
|     where |     where | ||||||
| @ -142,53 +185,14 @@ impl<'de> Deserialize<'de> for EventResult<StartEventContent> { | |||||||
|             } |             } | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         match raw { |         match StartEventContent::from_raw(raw) { | ||||||
|             raw::StartEventContent::MSasV1(content) => { |             Ok(content) => Ok(EventResult::Ok(content)), | ||||||
|                 if !content |             Err(error) => Ok(EventResult::Err(InvalidEvent( | ||||||
|                     .key_agreement_protocols |                 InnerInvalidEvent::Validation { | ||||||
|                     .contains(&KeyAgreementProtocol::Curve25519) |                     json, | ||||||
|                 { |                     message: error.to_string(), | ||||||
|                     return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation { |                 }, | ||||||
|                         json, |             ))), | ||||||
|                         message: "`key_agreement_protocols` must contain at least `KeyAgreementProtocol::Curve25519`".to_string(), |  | ||||||
|                     }))); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 if !content.hashes.contains(&HashAlgorithm::Sha256) { |  | ||||||
|                     return Ok(EventResult::Err(InvalidEvent( |  | ||||||
|                         InnerInvalidEvent::Validation { |  | ||||||
|                             json, |  | ||||||
|                             message: "`hashes` must contain at least `HashAlgorithm::Sha256`" |  | ||||||
|                                 .to_string(), |  | ||||||
|                         }, |  | ||||||
|                     ))); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 if !content |  | ||||||
|                     .message_authentication_codes |  | ||||||
|                     .contains(&MessageAuthenticationCode::HkdfHmacSha256) |  | ||||||
|                 { |  | ||||||
|                     return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation { |  | ||||||
|                         json, |  | ||||||
|                         message: "`message_authentication_codes` must contain at least `MessageAuthenticationCode::HkdfHmacSha256`".to_string(), |  | ||||||
|                     }))); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 if !content |  | ||||||
|                     .short_authentication_string |  | ||||||
|                     .contains(&ShortAuthenticationString::Decimal) |  | ||||||
|                 { |  | ||||||
|                     return Ok(EventResult::Err(InvalidEvent(InnerInvalidEvent::Validation { |  | ||||||
|                         json, |  | ||||||
|                         message: "`short_authentication_string` must contain at least `ShortAuthenticationString::Decimal`".to_string(), |  | ||||||
|                     }))); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 Ok(EventResult::Ok(StartEventContent::MSasV1(content))) |  | ||||||
|             } |  | ||||||
|             raw::StartEventContent::__Nonexhaustive => { |  | ||||||
|                 panic!("__Nonexhaustive enum variant is not intended for use."); |  | ||||||
|             } |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @ -696,4 +700,19 @@ mod tests { | |||||||
|         assert!(error.message().contains("short_authentication_string")); |         assert!(error.message().contains("short_authentication_string")); | ||||||
|         assert!(error.json().is_some()); |         assert!(error.json().is_some()); | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     #[test] | ||||||
|  |     fn deserialization_of_event_validates_content() { | ||||||
|  |         // This JSON is missing the required value of "curve25519" for "key_agreement_protocols".
 | ||||||
|  |         let error = | ||||||
|  |             serde_json::from_str::<EventResult<StartEvent>>( | ||||||
|  |                 r#"{"content":{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":[],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]},"type":"m.key.verification.start"}"# | ||||||
|  |             ) | ||||||
|  |                 .unwrap() | ||||||
|  |                 .into_result() | ||||||
|  |                 .unwrap_err(); | ||||||
|  | 
 | ||||||
|  |         assert!(error.message().contains("key_agreement_protocols")); | ||||||
|  |         assert!(error.json().is_some()); | ||||||
|  |     } | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user