From 2acca3e3ef270e677f121b9d05367d6b1db0312b Mon Sep 17 00:00:00 2001 From: Jimmy Cuadra Date: Tue, 6 Aug 2019 14:54:25 -0700 Subject: [PATCH] impl Deserialize for m.room.encrypted --- src/room/encrypted.rs | 95 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 88 insertions(+), 7 deletions(-) diff --git a/src/room/encrypted.rs b/src/room/encrypted.rs index f05da06c..88ff077b 100644 --- a/src/room/encrypted.rs +++ b/src/room/encrypted.rs @@ -7,7 +7,7 @@ use ruma_identifiers::{DeviceId, EventId, RoomId, UserId}; use serde::{de::Error, ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer}; use serde_json::{from_value, Value}; -use crate::{Algorithm, Event, EventType, InnerInvalidEvent, InvalidEvent, RoomEvent}; +use crate::{Algorithm, Event, EventResult, EventType, InnerInvalidEvent, InvalidEvent, RoomEvent}; /// This event type is used when sending encrypted events. /// @@ -50,6 +50,48 @@ pub enum EncryptedEventContent { __Nonexhaustive, } +impl<'de> Deserialize<'de> for EventResult { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let json = serde_json::Value::deserialize(deserializer)?; + + let raw: raw::EncryptedEvent = match serde_json::from_value(json.clone()) { + Ok(raw) => raw, + Err(error) => { + return Ok(EventResult::Err(InvalidEvent( + InnerInvalidEvent::Validation { + json, + message: error.to_string(), + }, + ))); + } + }; + + let content = match raw.content { + raw::EncryptedEventContent::OlmV1Curve25519AesSha2(content) => { + EncryptedEventContent::OlmV1Curve25519AesSha2(content) + } + raw::EncryptedEventContent::MegolmV1AesSha2(content) => { + EncryptedEventContent::MegolmV1AesSha2(content) + } + raw::EncryptedEventContent::__Nonexhaustive => { + panic!("__Nonexhaustive enum variant is not intended for use."); + } + }; + + Ok(EventResult::Ok(EncryptedEvent { + content, + event_id: raw.event_id, + origin_server_ts: raw.origin_server_ts, + room_id: raw.room_id, + sender: raw.sender, + unsigned: raw.unsigned, + })) + } +} + impl FromStr for EncryptedEvent { type Err = InvalidEvent; @@ -144,6 +186,39 @@ impl_room_event!( EventType::RoomEncrypted ); +impl<'de> Deserialize<'de> for EventResult { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let json = serde_json::Value::deserialize(deserializer)?; + + let raw: raw::EncryptedEventContent = match serde_json::from_value(json.clone()) { + Ok(raw) => raw, + Err(error) => { + return Ok(EventResult::Err(InvalidEvent( + InnerInvalidEvent::Validation { + json, + message: error.to_string(), + }, + ))); + } + }; + + match raw { + raw::EncryptedEventContent::OlmV1Curve25519AesSha2(content) => Ok(EventResult::Ok( + EncryptedEventContent::OlmV1Curve25519AesSha2(content), + )), + raw::EncryptedEventContent::MegolmV1AesSha2(content) => Ok(EventResult::Ok( + EncryptedEventContent::MegolmV1AesSha2(content), + )), + raw::EncryptedEventContent::__Nonexhaustive => { + panic!("__Nonexhaustive enum variant is not intended for use."); + } + } + } +} + impl FromStr for EncryptedEventContent { type Err = InvalidEvent; @@ -340,7 +415,7 @@ pub struct MegolmV1AesSha2Content { mod tests { use serde_json::to_string; - use super::{Algorithm, EncryptedEventContent, MegolmV1AesSha2Content}; + use super::{Algorithm, EncryptedEventContent, EventResult, MegolmV1AesSha2Content}; #[test] fn serializtion() { @@ -371,8 +446,11 @@ mod tests { }); assert_eq!( - r#"{"algorithm":"m.megolm.v1.aes-sha2","ciphertext":"ciphertext","sender_key":"sender_key","device_id":"device_id","session_id":"session_id"}"# - .parse::() + serde_json::from_str::>( + r#"{"algorithm":"m.megolm.v1.aes-sha2","ciphertext":"ciphertext","sender_key":"sender_key","device_id":"device_id","session_id":"session_id"}"# + ) + .unwrap() + .into_result() .unwrap(), key_verification_start_content ); @@ -380,8 +458,11 @@ mod tests { #[test] fn deserialization_failure() { - assert!( - r#"{"algorithm":"m.megolm.v1.aes-sha2"}"#.parse::().is_err() - ); + assert!(serde_json::from_str::>( + r#"{"algorithm":"m.megolm.v1.aes-sha2"}"# + ) + .unwrap() + .into_result() + .is_err()); } }