diff --git a/src/room/message.rs b/src/room/message.rs index 40622554..357ebf20 100644 --- a/src/room/message.rs +++ b/src/room/message.rs @@ -12,7 +12,7 @@ use serde::{ use serde_json::{from_value, Value}; use super::{EncryptedFile, ImageInfo, ThumbnailInfo}; -use crate::{Event, EventType, InnerInvalidEvent, InvalidEvent, RoomEvent}; +use crate::{Event, EventResult, EventType, InnerInvalidEvent, InvalidEvent, RoomEvent}; pub mod feedback; @@ -76,6 +76,53 @@ pub enum MessageEventContent { __Nonexhaustive, } +impl<'de> Deserialize<'de> for EventResult { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let json = serde_json::Value::deserialize(deserializer)?; + + let raw: raw::MessageEvent = match serde_json::from_value(json.clone()) { + Ok(raw) => raw, + Err(error) => { + return Ok(EventResult::Err(InvalidEvent( + InnerInvalidEvent::Validation { + json, + message: error.to_string(), + }, + ))); + } + }; + + Ok(EventResult::Ok(MessageEvent { + content: match raw.content { + raw::MessageEventContent::Audio(content) => MessageEventContent::Audio(content), + raw::MessageEventContent::Emote(content) => MessageEventContent::Emote(content), + raw::MessageEventContent::File(content) => MessageEventContent::File(content), + raw::MessageEventContent::Image(content) => MessageEventContent::Image(content), + raw::MessageEventContent::Location(content) => { + MessageEventContent::Location(content) + } + raw::MessageEventContent::Notice(content) => MessageEventContent::Notice(content), + raw::MessageEventContent::ServerNotice(content) => { + MessageEventContent::ServerNotice(content) + } + raw::MessageEventContent::Text(content) => MessageEventContent::Text(content), + raw::MessageEventContent::Video(content) => MessageEventContent::Video(content), + raw::MessageEventContent::__Nonexhaustive => { + panic!("__Nonexhaustive enum variant is not intended for use.") + } + }, + event_id: raw.event_id, + origin_server_ts: raw.origin_server_ts, + room_id: raw.room_id, + sender: raw.sender, + unsigned: raw.unsigned, + })) + } +} + impl FromStr for MessageEvent { type Err = InvalidEvent; @@ -193,6 +240,46 @@ impl Serialize for MessageEventContent { } } +impl<'de> Deserialize<'de> for EventResult { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let json = serde_json::Value::deserialize(deserializer)?; + + let raw: raw::MessageEventContent = match serde_json::from_value(json.clone()) { + Ok(raw) => raw, + Err(error) => { + return Ok(EventResult::Err(InvalidEvent( + InnerInvalidEvent::Validation { + json, + message: error.to_string(), + }, + ))); + } + }; + + let content = match raw { + raw::MessageEventContent::Audio(content) => MessageEventContent::Audio(content), + raw::MessageEventContent::Emote(content) => MessageEventContent::Emote(content), + raw::MessageEventContent::File(content) => MessageEventContent::File(content), + raw::MessageEventContent::Image(content) => MessageEventContent::Image(content), + raw::MessageEventContent::Location(content) => MessageEventContent::Location(content), + raw::MessageEventContent::Notice(content) => MessageEventContent::Notice(content), + raw::MessageEventContent::ServerNotice(content) => { + MessageEventContent::ServerNotice(content) + } + raw::MessageEventContent::Text(content) => MessageEventContent::Text(content), + raw::MessageEventContent::Video(content) => MessageEventContent::Video(content), + raw::MessageEventContent::__Nonexhaustive => { + panic!("__Nonexhaustive enum variant is not intended for use.") + } + }; + + Ok(EventResult::Ok(content)) + } +} + impl FromStr for MessageEventContent { type Err = InvalidEvent; @@ -1107,7 +1194,7 @@ impl Serialize for VideoMessageEventContent { mod tests { use serde_json::to_string; - use super::{AudioMessageEventContent, MessageEventContent}; + use super::{AudioMessageEventContent, EventResult, MessageEventContent}; #[test] fn serialization() { @@ -1134,19 +1221,23 @@ mod tests { }); assert_eq!( - r#"{"body":"test","msgtype":"m.audio","url":"http://example.com/audio.mp3"}"# - .parse::() - .unwrap(), + serde_json::from_str::>( + r#"{"body":"test","msgtype":"m.audio","url":"http://example.com/audio.mp3"}"# + ) + .unwrap() + .into_result() + .unwrap(), message_event_content ); } #[test] fn deserialization_failure() { - assert!( + assert!(serde_json::from_str::>( r#"{"body":"test","msgtype":"m.location","url":"http://example.com/audio.mp3"}"# - .parse::() - .is_err() - ); + ) + .unwrap() + .into_result() + .is_err()); } }