From 2c8d609095aec25c73d029e2eabf97c8ef377a6a Mon Sep 17 00:00:00 2001 From: "Ragotzy.devin" Date: Thu, 4 Jun 2020 15:19:54 -0400 Subject: [PATCH] Add Event derive macro --- Cargo.toml | 2 +- ruma-events-macros/src/event.rs | 311 ++++++++++++++++++ ruma-events-macros/src/event_content.rs | 10 +- ruma-events-macros/src/lib.rs | 17 +- src/message.rs | 226 +------------ src/state.rs | 259 +-------------- tests/event.rs | 10 + tests/event_content.rs | 2 +- ...ty-check.rs => 01-content-sanity-check.rs} | 0 tests/ui/04-event-sanity-check.rs | 19 ++ tests/ui/05-named-fields.rs | 10 + tests/ui/05-named-fields.stderr | 5 + tests/ui/06-no-content-field.rs | 13 + tests/ui/06-no-content-field.stderr | 7 + 14 files changed, 404 insertions(+), 487 deletions(-) create mode 100644 ruma-events-macros/src/event.rs create mode 100644 tests/event.rs rename tests/ui/{01-sanity-check.rs => 01-content-sanity-check.rs} (100%) create mode 100644 tests/ui/04-event-sanity-check.rs create mode 100644 tests/ui/05-named-fields.rs create mode 100644 tests/ui/05-named-fields.stderr create mode 100644 tests/ui/06-no-content-field.rs create mode 100644 tests/ui/06-no-content-field.stderr diff --git a/Cargo.toml b/Cargo.toml index ea2a5c32..633d0056 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ strum = { version = "0.18.0", features = ["derive"] } maplit = "1.0.2" matches = "0.1.8" ruma-identifiers = { version = "0.16.2", features = ["rand"] } -trybuild = "1.0.27" +trybuild = "1.0.28" [workspace] members = [ diff --git a/ruma-events-macros/src/event.rs b/ruma-events-macros/src/event.rs new file mode 100644 index 00000000..265715fb --- /dev/null +++ b/ruma-events-macros/src/event.rs @@ -0,0 +1,311 @@ +//! Implementation of the top level `*Event` derive macro. + +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::{ + Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, GenericParam, Ident, TypeParam, +}; + +/// Derive `Event` macro code generation. +pub fn expand_event(input: DeriveInput) -> syn::Result { + let ident = &input.ident; + let fields = if let Data::Struct(DataStruct { fields, .. }) = input.data.clone() { + if let Fields::Named(FieldsNamed { named, .. }) = fields { + if !named.iter().any(|f| f.ident.as_ref().unwrap() == "content") { + return Err(syn::Error::new( + Span::call_site(), + "struct must contain a `content` field", + )); + } + + named.into_iter().collect::>() + } else { + return Err(syn::Error::new_spanned( + fields, + "the `Event` derive only supports named fields", + )); + } + } else { + return Err(syn::Error::new_spanned( + input.ident, + "the `Event` derive only supports structs with named fields", + )); + }; + + let content_trait = Ident::new(&format!("{}Content", ident), input.ident.span()); + let try_from_raw_fields = fields + .iter() + .map(|field| { + let name = field.ident.as_ref().unwrap(); + if name == "content" { + quote! { content: C::try_from_raw(raw.content)? } + } else if name == "prev_content" { + quote! { prev_content: raw.prev_content.map(C::try_from_raw).transpose()? } + } else { + quote! { #name: raw.#name } + } + }) + .collect::>(); + + let try_from_raw_impl = quote! { + impl ::ruma_events::TryFromRaw for #ident + where + C: ::ruma_events::#content_trait + ::ruma_events::TryFromRaw, + C::Raw: ::ruma_events::RawEventContent, + { + type Raw = raw_event::#ident; + type Err = C::Err; + + fn try_from_raw(raw: Self::Raw) -> Result { + Ok(Self { + #( #try_from_raw_fields ),* + }) + } + } + }; + + let serialize_fields = fields + .iter() + .map(|field| { + let name = field.ident.as_ref().unwrap(); + if name == "prev_content" { + quote! { + if let Some(content) = self.prev_content.as_ref() { + state.serialize_field("prev_content", content)?; + } + } + } else if name == "origin_server_ts" { + quote! { + let time_since_epoch = + self.origin_server_ts.duration_since(::std::time::UNIX_EPOCH).unwrap(); + + let timestamp = ::js_int::UInt::try_from(time_since_epoch.as_millis()) + .map_err(S::Error::custom)?; + + state.serialize_field("origin_server_ts", ×tamp)?; + } + } else if name == "unsigned" { + quote! { + if !self.unsigned.is_empty() { + state.serialize_field("unsigned", &self.unsigned)?; + } + } + } else { + quote! { + state.serialize_field(stringify!(#name), &self.#name)?; + } + } + }) + .collect::>(); + + let serialize_impl = quote! { + impl ::serde::ser::Serialize for #ident + where + C::Raw: ::ruma_events::RawEventContent, + { + fn serialize(&self, serializer: S) -> Result + where + S: ::serde::ser::Serializer, + { + use ::serde::ser::SerializeStruct as _; + + let event_type = self.content.event_type(); + + let mut state = serializer.serialize_struct("StateEvent", 7)?; + + state.serialize_field("type", event_type)?; + #( #serialize_fields )* + state.end() + } + } + }; + + let raw_mod = expand_raw_state_event(&input, fields)?; + + Ok(quote! { + #try_from_raw_impl + + #serialize_impl + + #raw_mod + }) +} + +fn expand_raw_state_event(input: &DeriveInput, fields: Vec) -> syn::Result { + let ident = &input.ident; + let content_ident = Ident::new(&format!("{}Content", ident), input.ident.span()); + + // the raw version has no bounds on its type param + let generics = { + let mut gen = input.generics.clone(); + for p in &mut gen.params { + if let GenericParam::Type(TypeParam { bounds, .. }) = p { + bounds.clear(); + } + } + gen + }; + + let enum_variants = fields + .iter() + .map(|field| { + let name = field.ident.as_ref().unwrap(); + to_camel_case(name) + }) + .collect::>(); + + let deserialize_var_types = fields + .iter() + .map(|field| { + let name = field.ident.as_ref().unwrap(); + let ty = &field.ty; + if name == "content" || name == "prev_content" { + quote! { Box<::serde_json::value::RawValue> } + } else if name == "origin_server_ts" { + quote! { ::js_int::UInt } + } else { + quote! { #ty } + } + }) + .collect::>(); + + let ok_or_else_fields = fields + .iter() + .map(|field| { + let name = field.ident.as_ref().unwrap(); + if name == "content" { + quote! { + let raw = content.ok_or_else(|| ::serde::de::Error::missing_field("content"))?; + let content = C::from_parts(&event_type, raw).map_err(A::Error::custom)?; + } + } else if name == "prev_content" { + quote! { + let prev_content = if let Some(raw) = prev_content { + Some(C::from_parts(&event_type, raw).map_err(A::Error::custom)?) + } else { + None + }; + } + } else if name == "origin_server_ts" { + quote! { + let origin_server_ts = origin_server_ts + .map(|time| { + let t = time.into(); + ::std::time::UNIX_EPOCH + ::std::time::Duration::from_millis(t) + }) + .ok_or_else(|| ::serde::de::Error::missing_field("origin_server_ts"))?; + } + } else if name == "unsigned" { + quote! { let unsigned = unsigned.unwrap_or_default(); } + } else { + quote! { + let #name = #name.ok_or_else(|| { + ::serde::de::Error::missing_field(stringify!(#name)) + })?; + } + } + }) + .collect::>(); + + let field_names = fields.iter().flat_map(|f| &f.ident).collect::>(); + + let deserialize_impl = quote! { + impl<'de, C> ::serde::de::Deserialize<'de> for #ident + where + C: ::ruma_events::RawEventContent, + { + fn deserialize(deserializer: D) -> Result + where + D: ::serde::de::Deserializer<'de>, + { + #[derive(serde::Deserialize)] + #[serde(field_identifier, rename_all = "snake_case")] + enum Field { + // since this is represented as an enum we have to add it so the JSON picks it up + Type, + #( #enum_variants ),* + } + + /// Visits the fields of an event struct to handle deserialization of + /// the `content` and `prev_content` fields. + struct EventVisitor(::std::marker::PhantomData); + + impl<'de, C> ::serde::de::Visitor<'de> for EventVisitor + where + C: ::ruma_events::RawEventContent, + { + type Value = #ident; + + fn expecting(&self, formatter: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + write!(formatter, "struct implementing {}", stringify!(#content_ident)) + } + + fn visit_map(self, mut map: A) -> Result + where + A: ::serde::de::MapAccess<'de>, + { + use ::serde::de::Error as _; + + let mut event_type: Option = None; + #( let mut #field_names: Option<#deserialize_var_types> = None; )* + + while let Some(key) = map.next_key()? { + match key { + Field::Type => { + if event_type.is_some() { + return Err(::serde::de::Error::duplicate_field("type")); + } + event_type = Some(map.next_value()?); + } + #( + Field::#enum_variants => { + if #field_names.is_some() { + return Err(::serde::de::Error::duplicate_field(stringify!(#field_names))); + } + #field_names = Some(map.next_value()?); + } + )* + } + } + + let event_type = event_type.ok_or_else(|| ::serde::de::Error::missing_field("type"))?; + #( #ok_or_else_fields )* + + Ok(#ident { + #( #field_names ),* + }) + } + } + + deserializer.deserialize_map(EventVisitor(::std::marker::PhantomData)) + } + } + }; + + let raw_docs = format!("The raw version of {}, allows for deserialization.", ident); + Ok(quote! { + #[doc = #raw_docs] + mod raw_event { + use super::*; + + #[derive(Clone, Debug)] + pub struct #ident #generics { + #( #fields ),* + } + + #deserialize_impl + } + }) +} + +/// CamelCase's a field ident like "foo_bar" to "FooBar". +fn to_camel_case(name: &Ident) -> Ident { + let span = name.span(); + let name = name.to_string(); + + let s = name + .split('_') + .map(|s| s.chars().next().unwrap().to_uppercase().to_string() + &s[1..]) + .collect::(); + Ident::new(&s, span) +} diff --git a/ruma-events-macros/src/event_content.rs b/ruma-events-macros/src/event_content.rs index 5c6f202c..619ea128 100644 --- a/ruma-events-macros/src/event_content.rs +++ b/ruma-events-macros/src/event_content.rs @@ -26,7 +26,7 @@ impl Parse for EventMeta { /// Create a `RoomEventContent` implementation for a struct. /// /// This is used internally for code sharing as `RoomEventContent` is not derivable. -fn expand_room_event(input: DeriveInput) -> syn::Result { +fn expand_room_event_content(input: DeriveInput) -> syn::Result { let ident = &input.ident; let event_type_attr = input @@ -76,9 +76,9 @@ fn expand_room_event(input: DeriveInput) -> syn::Result { } /// Create a `MessageEventContent` implementation for a struct -pub fn expand_message_event(input: DeriveInput) -> syn::Result { +pub fn expand_message_event_content(input: DeriveInput) -> syn::Result { let ident = input.ident.clone(); - let room_ev_content = expand_room_event(input)?; + let room_ev_content = expand_room_event_content(input)?; Ok(quote! { #room_ev_content @@ -88,9 +88,9 @@ pub fn expand_message_event(input: DeriveInput) -> syn::Result { } /// Create a `MessageEventContent` implementation for a struct -pub fn expand_state_event(input: DeriveInput) -> syn::Result { +pub fn expand_state_event_content(input: DeriveInput) -> syn::Result { let ident = input.ident.clone(); - let room_ev_content = expand_room_event(input)?; + let room_ev_content = expand_room_event_content(input)?; Ok(quote! { #room_ev_content diff --git a/ruma-events-macros/src/lib.rs b/ruma-events-macros/src/lib.rs index 585dcdad..3e2a143d 100644 --- a/ruma-events-macros/src/lib.rs +++ b/ruma-events-macros/src/lib.rs @@ -16,13 +16,15 @@ use syn::{parse_macro_input, DeriveInput}; use self::{ collection::{expand_collection, parse::RumaCollectionInput}, - event_content::{expand_message_event, expand_state_event}, + event::expand_event, + event_content::{expand_message_event_content, expand_state_event_content}, from_raw::expand_from_raw, gen::RumaEvent, parse::RumaEventInput, }; mod collection; +mod event; mod event_content; mod from_raw; mod gen; @@ -148,7 +150,7 @@ pub fn derive_from_raw(input: TokenStream) -> TokenStream { #[proc_macro_derive(MessageEventContent, attributes(ruma_event))] pub fn derive_message_event_content(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); - expand_message_event(input) + expand_message_event_content(input) .unwrap_or_else(|err| err.to_compile_error()) .into() } @@ -157,7 +159,16 @@ pub fn derive_message_event_content(input: TokenStream) -> TokenStream { #[proc_macro_derive(StateEventContent, attributes(ruma_event))] pub fn derive_state_event_content(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); - expand_state_event(input) + expand_state_event_content(input) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} + +/// Generates implementations needed to serialize and deserialize Matrix events. +#[proc_macro_derive(Event, attributes(ruma_event))] +pub fn derive_state_event(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + expand_event(input) .unwrap_or_else(|err| err.to_compile_error()) .into() } diff --git a/src/message.rs b/src/message.rs index 9ed05695..95bf3ca8 100644 --- a/src/message.rs +++ b/src/message.rs @@ -14,7 +14,7 @@ use serde::{ }; use crate::{MessageEventContent, RawEventContent, RoomEventContent, TryFromRaw, UnsignedData}; -use ruma_events_macros::event_content_collection; +use ruma_events_macros::{event_content_collection, Event}; event_content_collection! { /// A message event. @@ -29,7 +29,7 @@ event_content_collection! { } /// Message event. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Event)] pub struct MessageEvent where C::Raw: RawEventContent, @@ -53,220 +53,6 @@ where pub unsigned: UnsignedData, } -impl TryFromRaw for MessageEvent -where - C: MessageEventContent + TryFromRaw, - C::Raw: RawEventContent, -{ - type Raw = raw_message_event::MessageEvent; - type Err = C::Err; - - fn try_from_raw(raw: Self::Raw) -> Result { - Ok(Self { - content: C::try_from_raw(raw.content)?, - event_id: raw.event_id, - sender: raw.sender, - origin_server_ts: raw.origin_server_ts, - room_id: raw.room_id, - unsigned: raw.unsigned, - }) - } -} - -impl Serialize for MessageEvent -where - C::Raw: RawEventContent, -{ - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let event_type = self.content.event_type(); - - let time_since_epoch = self.origin_server_ts.duration_since(UNIX_EPOCH).unwrap(); - let timestamp = match UInt::try_from(time_since_epoch.as_millis()) { - Ok(uint) => uint, - Err(err) => return Err(S::Error::custom(err)), - }; - - let mut message = serializer.serialize_struct("MessageEvent", 7)?; - - message.serialize_field("type", event_type)?; - message.serialize_field("content", &self.content)?; - message.serialize_field("event_id", &self.event_id)?; - message.serialize_field("sender", &self.sender)?; - message.serialize_field("origin_server_ts", ×tamp)?; - message.serialize_field("room_id", &self.room_id)?; - - if !self.unsigned.is_empty() { - message.serialize_field("unsigned", &self.unsigned)?; - } - - message.end() - } -} - -mod raw_message_event { - use std::{ - fmt, - marker::PhantomData, - time::{Duration, SystemTime, UNIX_EPOCH}, - }; - - use js_int::UInt; - use ruma_identifiers::{EventId, RoomId, UserId}; - use serde::de::{self, Deserialize, Deserializer, Error as _, MapAccess, Visitor}; - use serde_json::value::RawValue as RawJsonValue; - - use crate::{RawEventContent, UnsignedData}; - - /// The raw half of a message event. - #[derive(Clone, Debug)] - pub struct MessageEvent { - /// Data specific to the event type. - pub content: C, - - /// The globally unique event identifier for the user who sent the event. - pub event_id: EventId, - - /// Contains the fully-qualified ID of the user who sent this event. - pub sender: UserId, - - /// Timestamp in milliseconds on originating homeserver when this event was sent. - pub origin_server_ts: SystemTime, - - /// The ID of the room associated with this event. - pub room_id: RoomId, - - /// Additional key-value pairs not signed by the homeserver. - pub unsigned: UnsignedData, - } - - impl<'de, C> Deserialize<'de> for MessageEvent - where - C: RawEventContent, - { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - deserializer.deserialize_map(MessageEventVisitor(std::marker::PhantomData)) - } - } - - #[derive(serde::Deserialize)] - #[serde(field_identifier, rename_all = "snake_case")] - enum Field { - Type, - Content, - EventId, - Sender, - OriginServerTs, - RoomId, - Unsigned, - } - - /// Visits the fields of a MessageEvent to handle deserialization of - /// the `content` and `prev_content` fields. - struct MessageEventVisitor(PhantomData); - - impl<'de, C> Visitor<'de> for MessageEventVisitor - where - C: RawEventContent, - { - type Value = MessageEvent; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(formatter, "struct implementing MessageEventContent") - } - - fn visit_map(self, mut map: A) -> Result - where - A: MapAccess<'de>, - { - let mut content: Option> = None; - let mut event_type: Option = None; - let mut event_id: Option = None; - let mut sender: Option = None; - let mut origin_server_ts: Option = None; - let mut room_id: Option = None; - let mut unsigned: Option = None; - - while let Some(key) = map.next_key()? { - match key { - Field::Content => { - if content.is_some() { - return Err(de::Error::duplicate_field("content")); - } - content = Some(map.next_value()?); - } - Field::EventId => { - if event_id.is_some() { - return Err(de::Error::duplicate_field("event_id")); - } - event_id = Some(map.next_value()?); - } - Field::Sender => { - if sender.is_some() { - return Err(de::Error::duplicate_field("sender")); - } - sender = Some(map.next_value()?); - } - Field::OriginServerTs => { - if origin_server_ts.is_some() { - return Err(de::Error::duplicate_field("origin_server_ts")); - } - origin_server_ts = Some(map.next_value()?); - } - Field::RoomId => { - if room_id.is_some() { - return Err(de::Error::duplicate_field("room_id")); - } - room_id = Some(map.next_value()?); - } - Field::Type => { - if event_type.is_some() { - return Err(de::Error::duplicate_field("type")); - } - event_type = Some(map.next_value()?); - } - Field::Unsigned => { - if unsigned.is_some() { - return Err(de::Error::duplicate_field("unsigned")); - } - unsigned = Some(map.next_value()?); - } - } - } - - let event_type = event_type.ok_or_else(|| de::Error::missing_field("type"))?; - - let raw = content.ok_or_else(|| de::Error::missing_field("content"))?; - let content = C::from_parts(&event_type, raw).map_err(A::Error::custom)?; - - let event_id = event_id.ok_or_else(|| de::Error::missing_field("event_id"))?; - let sender = sender.ok_or_else(|| de::Error::missing_field("sender"))?; - - let origin_server_ts = origin_server_ts - .map(|time| UNIX_EPOCH + Duration::from_millis(time.into())) - .ok_or_else(|| de::Error::missing_field("origin_server_ts"))?; - - let room_id = room_id.ok_or_else(|| de::Error::missing_field("room_id"))?; - - let unsigned = unsigned.unwrap_or_default(); - - Ok(MessageEvent { - content, - event_id, - sender, - origin_server_ts, - room_id, - unsigned, - }) - } - } -} - #[cfg(test)] mod tests { use std::{ @@ -288,7 +74,7 @@ mod tests { }; #[test] - fn message_serialize_aliases() { + fn message_serialize_sticker() { let aliases_event = MessageEvent { content: AnyMessageEventContent::Sticker(StickerEventContent { body: "Hello".into(), @@ -345,7 +131,7 @@ mod tests { } #[test] - fn deserialize_message_aliases_content() { + fn deserialize_message_call_answer_content() { let json_data = json!({ "answer": { "type": "answer", @@ -372,7 +158,7 @@ mod tests { } #[test] - fn deserialize_message_aliases() { + fn deserialize_message_call_answer() { let json_data = json!({ "content": { "answer": { @@ -418,7 +204,7 @@ mod tests { } #[test] - fn deserialize_message_avatar() { + fn deserialize_message_sticker() { let json_data = json!({ "content": { "body": "Hello", diff --git a/src/state.rs b/src/state.rs index 4bfe7c74..c3f2e5f1 100644 --- a/src/state.rs +++ b/src/state.rs @@ -14,7 +14,7 @@ use serde::{ }; use crate::{RawEventContent, RoomEventContent, StateEventContent, TryFromRaw, UnsignedData}; -use ruma_events_macros::event_content_collection; +use ruma_events_macros::{event_content_collection, Event}; event_content_collection! { /// A state event. @@ -23,7 +23,7 @@ event_content_collection! { } /// State event. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Event)] pub struct StateEvent where C::Raw: RawEventContent, @@ -56,261 +56,6 @@ where pub unsigned: UnsignedData, } -impl TryFromRaw for StateEvent -where - C: StateEventContent + TryFromRaw, - C::Raw: RawEventContent, -{ - type Raw = raw_state_event::StateEvent; - type Err = C::Err; - - fn try_from_raw(raw: Self::Raw) -> Result { - Ok(Self { - content: C::try_from_raw(raw.content)?, - event_id: raw.event_id, - sender: raw.sender, - origin_server_ts: raw.origin_server_ts, - room_id: raw.room_id, - state_key: raw.state_key, - prev_content: raw.prev_content.map(C::try_from_raw).transpose()?, - unsigned: raw.unsigned, - }) - } -} - -impl Serialize for StateEvent -where - C::Raw: RawEventContent, -{ - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let event_type = self.content.event_type(); - - let time_since_epoch = self.origin_server_ts.duration_since(UNIX_EPOCH).unwrap(); - let timestamp = match UInt::try_from(time_since_epoch.as_millis()) { - Ok(uint) => uint, - Err(err) => return Err(S::Error::custom(err)), - }; - - let mut state = serializer.serialize_struct("StateEvent", 7)?; - - state.serialize_field("type", event_type)?; - state.serialize_field("content", &self.content)?; - state.serialize_field("event_id", &self.event_id)?; - state.serialize_field("sender", &self.sender)?; - state.serialize_field("origin_server_ts", ×tamp)?; - state.serialize_field("room_id", &self.room_id)?; - state.serialize_field("state_key", &self.state_key)?; - - if let Some(content) = self.prev_content.as_ref() { - state.serialize_field("prev_content", content)?; - } - - if !self.unsigned.is_empty() { - state.serialize_field("unsigned", &self.unsigned)?; - } - - state.end() - } -} - -mod raw_state_event { - use std::{ - fmt, - marker::PhantomData, - time::{Duration, SystemTime, UNIX_EPOCH}, - }; - - use js_int::UInt; - use ruma_identifiers::{EventId, RoomId, UserId}; - use serde::de::{self, Deserialize, Deserializer, Error as _, MapAccess, Visitor}; - use serde_json::value::RawValue as RawJsonValue; - - use crate::{RawEventContent, UnsignedData}; - - /// State event. - #[derive(Clone, Debug)] - pub struct StateEvent { - /// Data specific to the event type. - pub content: C, - - /// The globally unique event identifier for the user who sent the event. - pub event_id: EventId, - - /// Contains the fully-qualified ID of the user who sent this event. - pub sender: UserId, - - /// Timestamp in milliseconds on originating homeserver when this event was sent. - pub origin_server_ts: SystemTime, - - /// The ID of the room associated with this event. - pub room_id: RoomId, - - /// A unique key which defines the overwriting semantics for this piece of room state. - /// - /// This is often an empty string, but some events send a `UserId` to show - /// which user the event affects. - pub state_key: String, - - /// Optional previous content for this event. - pub prev_content: Option, - - /// Additional key-value pairs not signed by the homeserver. - pub unsigned: UnsignedData, - } - - impl<'de, C> Deserialize<'de> for StateEvent - where - C: RawEventContent, - { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - deserializer.deserialize_map(StateEventVisitor(std::marker::PhantomData)) - } - } - - #[derive(serde::Deserialize)] - #[serde(field_identifier, rename_all = "snake_case")] - enum Field { - Type, - Content, - EventId, - Sender, - OriginServerTs, - RoomId, - StateKey, - PrevContent, - Unsigned, - } - - /// Visits the fields of a StateEvent to handle deserialization of - /// the `content` and `prev_content` fields. - struct StateEventVisitor(PhantomData); - - impl<'de, C> Visitor<'de> for StateEventVisitor - where - C: RawEventContent, - { - type Value = StateEvent; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(formatter, "struct implementing StateEventContent") - } - - fn visit_map(self, mut map: A) -> Result - where - A: MapAccess<'de>, - { - let mut content: Option> = None; - let mut event_type: Option = None; - let mut event_id: Option = None; - let mut sender: Option = None; - let mut origin_server_ts: Option = None; - let mut room_id: Option = None; - let mut state_key: Option = None; - let mut prev_content: Option> = None; - let mut unsigned: Option = None; - - while let Some(key) = map.next_key()? { - match key { - Field::Content => { - if content.is_some() { - return Err(de::Error::duplicate_field("content")); - } - content = Some(map.next_value()?); - } - Field::EventId => { - if event_id.is_some() { - return Err(de::Error::duplicate_field("event_id")); - } - event_id = Some(map.next_value()?); - } - Field::Sender => { - if sender.is_some() { - return Err(de::Error::duplicate_field("sender")); - } - sender = Some(map.next_value()?); - } - Field::OriginServerTs => { - if origin_server_ts.is_some() { - return Err(de::Error::duplicate_field("origin_server_ts")); - } - origin_server_ts = Some(map.next_value()?); - } - Field::RoomId => { - if room_id.is_some() { - return Err(de::Error::duplicate_field("room_id")); - } - room_id = Some(map.next_value()?); - } - Field::StateKey => { - if state_key.is_some() { - return Err(de::Error::duplicate_field("state_key")); - } - state_key = Some(map.next_value()?); - } - Field::PrevContent => { - if prev_content.is_some() { - return Err(de::Error::duplicate_field("prev_content")); - } - prev_content = Some(map.next_value()?); - } - Field::Type => { - if event_type.is_some() { - return Err(de::Error::duplicate_field("type")); - } - event_type = Some(map.next_value()?); - } - Field::Unsigned => { - if unsigned.is_some() { - return Err(de::Error::duplicate_field("unsigned")); - } - unsigned = Some(map.next_value()?); - } - } - } - - let event_type = event_type.ok_or_else(|| de::Error::missing_field("type"))?; - - let raw = content.ok_or_else(|| de::Error::missing_field("content"))?; - let content = C::from_parts(&event_type, raw).map_err(A::Error::custom)?; - - let event_id = event_id.ok_or_else(|| de::Error::missing_field("event_id"))?; - let sender = sender.ok_or_else(|| de::Error::missing_field("sender"))?; - - let origin_server_ts = origin_server_ts - .map(|time| UNIX_EPOCH + Duration::from_millis(time.into())) - .ok_or_else(|| de::Error::missing_field("origin_server_ts"))?; - - let room_id = room_id.ok_or_else(|| de::Error::missing_field("room_id"))?; - let state_key = state_key.ok_or_else(|| de::Error::missing_field("state_key"))?; - - let prev_content = if let Some(raw) = prev_content { - Some(C::from_parts(&event_type, raw).map_err(A::Error::custom)?) - } else { - None - }; - - let unsigned = unsigned.unwrap_or_default(); - - Ok(StateEvent { - content, - event_id, - sender, - origin_server_ts, - room_id, - state_key, - prev_content, - unsigned, - }) - } - } -} - #[cfg(test)] mod tests { use std::{ diff --git a/tests/event.rs b/tests/event.rs new file mode 100644 index 00000000..d9dfa4fa --- /dev/null +++ b/tests/event.rs @@ -0,0 +1,10 @@ +#[test] +fn ui() { + let t = trybuild::TestCases::new(); + // rustc overflows when compiling this see: + // https://github.com/rust-lang/rust/issues/55779 + // there is a workaround in the file. + t.pass("tests/ui/04-event-sanity-check.rs"); + t.compile_fail("tests/ui/05-named-fields.rs"); + t.compile_fail("tests/ui/06-no-content-field.rs"); +} diff --git a/tests/event_content.rs b/tests/event_content.rs index aacac007..ba0426f1 100644 --- a/tests/event_content.rs +++ b/tests/event_content.rs @@ -1,7 +1,7 @@ #[test] fn ui() { let t = trybuild::TestCases::new(); - t.pass("tests/ui/01-sanity-check.rs"); + t.pass("tests/ui/01-content-sanity-check.rs"); t.compile_fail("tests/ui/02-no-event-type.rs"); t.compile_fail("tests/ui/03-invalid-event-type.rs"); } diff --git a/tests/ui/01-sanity-check.rs b/tests/ui/01-content-sanity-check.rs similarity index 100% rename from tests/ui/01-sanity-check.rs rename to tests/ui/01-content-sanity-check.rs diff --git a/tests/ui/04-event-sanity-check.rs b/tests/ui/04-event-sanity-check.rs new file mode 100644 index 00000000..e2cb3b71 --- /dev/null +++ b/tests/ui/04-event-sanity-check.rs @@ -0,0 +1,19 @@ +// rustc overflows when compiling this see: +// https://github.com/rust-lang/rust/issues/55779 +extern crate serde; + +use ruma_events::{RawEventContent, StateEventContent}; +use ruma_events_macros::Event; + +/// State event. +#[derive(Clone, Debug, Event)] +pub struct StateEvent +where + C::Raw: RawEventContent, +{ + pub content: C, + pub state_key: String, + pub prev_content: Option, +} + +fn main() {} diff --git a/tests/ui/05-named-fields.rs b/tests/ui/05-named-fields.rs new file mode 100644 index 00000000..c059c235 --- /dev/null +++ b/tests/ui/05-named-fields.rs @@ -0,0 +1,10 @@ +use ruma_events::{RawEventContent, StateEventContent}; +use ruma_events_macros::Event; + +/// State event. +#[derive(Clone, Debug, Event)] +pub struct StateEvent(C) +where + C::Raw: RawEventContent; + +fn main() {} diff --git a/tests/ui/05-named-fields.stderr b/tests/ui/05-named-fields.stderr new file mode 100644 index 00000000..2d11539b --- /dev/null +++ b/tests/ui/05-named-fields.stderr @@ -0,0 +1,5 @@ +error: the `Event` derive only supports named fields + --> $DIR/05-named-fields.rs:6:44 + | +6 | pub struct StateEvent(C) + | ^^^ diff --git a/tests/ui/06-no-content-field.rs b/tests/ui/06-no-content-field.rs new file mode 100644 index 00000000..b991dc9d --- /dev/null +++ b/tests/ui/06-no-content-field.rs @@ -0,0 +1,13 @@ +use ruma_events::{RawEventContent, StateEventContent}; +use ruma_events_macros::Event; + +/// State event. +#[derive(Clone, Debug, Event)] +pub struct StateEvent +where + C::Raw: RawEventContent +{ + pub not_content: C, +} + +fn main() {} diff --git a/tests/ui/06-no-content-field.stderr b/tests/ui/06-no-content-field.stderr new file mode 100644 index 00000000..361f94a0 --- /dev/null +++ b/tests/ui/06-no-content-field.stderr @@ -0,0 +1,7 @@ +error: struct must contain a `content` field + --> $DIR/06-no-content-field.rs:5:24 + | +5 | #[derive(Clone, Debug, Event)] + | ^^^^^ + | + = note: this error originates in a derive macro (in Nightly builds, run with -Z macro-backtrace for more info)