diff --git a/crates/ruma-events-macros/src/event.rs b/crates/ruma-events-macros/src/event.rs index f9fc8487..5394b276 100644 --- a/crates/ruma-events-macros/src/event.rs +++ b/crates/ruma-events-macros/src/event.rs @@ -10,6 +10,7 @@ use syn::{ use crate::{ event_parse::{to_kind_variation, EventKind, EventKindVariation}, import_ruma_events, + util::is_non_stripped_room_event, }; /// Derive `Event` macro code generation. @@ -41,19 +42,26 @@ pub fn expand_event(input: DeriveInput) -> syn::Result { )); }; - let conversion_impl = expand_from_into(&input, kind, var, &fields, &ruma_events); - let serialize_impl = expand_serialize_event(&input, var, &fields, &ruma_events); - let deserialize_impl = expand_deserialize_event(&input, kind, var, &fields, &ruma_events)?; - let redact_impl = expand_redact_event(&input, kind, var, &fields, &ruma_events); - let eq_impl = expand_eq_ord_event(&input, &fields); + let mut res = TokenStream::new(); - Ok(quote! { - #conversion_impl - #serialize_impl - #deserialize_impl - #redact_impl - #eq_impl - }) + res.extend(expand_serialize_event(&input, var, &fields, &ruma_events)); + res.extend(expand_deserialize_event(&input, kind, var, &fields, &ruma_events)?); + + if var.is_sync() { + res.extend(expand_sync_from_into_full(&input, kind, var, &fields, &ruma_events)); + } + + if matches!(kind, EventKind::Message | EventKind::State) + && matches!(var, EventKindVariation::Full | EventKindVariation::Sync) + { + res.extend(expand_redact_event(&input, kind, var, &fields, &ruma_events)); + } + + if is_non_stripped_room_event(kind, var) { + res.extend(expand_eq_ord_event(&input)); + } + + Ok(res) } fn expand_serialize_event( @@ -405,17 +413,17 @@ fn expand_redact_event( var: EventKindVariation, fields: &[Field], ruma_events: &TokenStream, -) -> Option { +) -> TokenStream { let ruma_identifiers = quote! { #ruma_events::exports::ruma_identifiers }; - let redacted_type = kind.to_event_ident(var.to_redacted()?)?; + let redacted_type = kind.to_event_ident(var.to_redacted().unwrap()).unwrap(); let redacted_content_trait = format_ident!("{}Content", kind.to_event_ident(EventKindVariation::Redacted).unwrap()); let ident = &input.ident; let mut generics = input.generics.clone(); if generics.params.is_empty() { - return None; + return TokenStream::new(); } assert_eq!(generics.params.len(), 1, "expected one generic parameter"); @@ -450,7 +458,7 @@ fn expand_redact_event( } }); - Some(quote! { + quote! { #[automatically_derived] impl #impl_generics #ruma_events::Redact for #ident #ty_gen #where_clause { type Redacted = @@ -468,91 +476,83 @@ fn expand_redact_event( } } } - }) + } } -fn expand_from_into( +fn expand_sync_from_into_full( input: &DeriveInput, kind: EventKind, var: EventKindVariation, fields: &[Field], ruma_events: &TokenStream, -) -> Option { +) -> TokenStream { let ruma_identifiers = quote! { #ruma_events::exports::ruma_identifiers }; let ident = &input.ident; - + let full_struct = kind.to_event_ident(var.to_full().unwrap()).unwrap(); let (impl_generics, ty_gen, where_clause) = input.generics.split_for_impl(); - let fields: Vec<_> = fields.iter().flat_map(|f| &f.ident).collect(); - if let EventKindVariation::Sync | EventKindVariation::RedactedSync = var { - let full_struct = kind.to_event_ident(var.to_full().unwrap()).unwrap(); - Some(quote! { - #[automatically_derived] - impl #impl_generics ::std::convert::From<#full_struct #ty_gen> - for #ident #ty_gen #where_clause - { - fn from(event: #full_struct #ty_gen) -> Self { - let #full_struct { #( #fields, )* .. } = event; - Self { #( #fields, )* } - } + quote! { + #[automatically_derived] + impl #impl_generics ::std::convert::From<#full_struct #ty_gen> + for #ident #ty_gen #where_clause + { + fn from(event: #full_struct #ty_gen) -> Self { + let #full_struct { #( #fields, )* .. } = event; + Self { #( #fields, )* } } + } - #[automatically_derived] - impl #impl_generics #ident #ty_gen #where_clause { - /// Convert this sync event into a full event, one with a room_id field. - pub fn into_full_event( - self, - room_id: #ruma_identifiers::RoomId, - ) -> #full_struct #ty_gen { - let Self { #( #fields, )* } = self; - #full_struct { - #( #fields, )* - room_id, - } - } - } - }) - } else { - None - } -} - -fn expand_eq_ord_event(input: &DeriveInput, fields: &[Field]) -> Option { - fields.iter().flat_map(|f| f.ident.as_ref()).any(|f| f == "event_id").then(|| { - let ident = &input.ident; - let (impl_gen, ty_gen, where_clause) = input.generics.split_for_impl(); - - quote! { - #[automatically_derived] - impl #impl_gen ::std::cmp::PartialEq for #ident #ty_gen #where_clause { - /// Checks if two `EventId`s are equal. - fn eq(&self, other: &Self) -> ::std::primitive::bool { - self.event_id == other.event_id - } - } - - #[automatically_derived] - impl #impl_gen ::std::cmp::Eq for #ident #ty_gen #where_clause {} - - #[automatically_derived] - impl #impl_gen ::std::cmp::PartialOrd for #ident #ty_gen #where_clause { - /// Compares `EventId`s and orders them lexicographically. - fn partial_cmp(&self, other: &Self) -> ::std::option::Option<::std::cmp::Ordering> { - self.event_id.partial_cmp(&other.event_id) - } - } - - #[automatically_derived] - impl #impl_gen ::std::cmp::Ord for #ident #ty_gen #where_clause { - /// Compares `EventId`s and orders them lexicographically. - fn cmp(&self, other: &Self) -> ::std::cmp::Ordering { - self.event_id.cmp(&other.event_id) + #[automatically_derived] + impl #impl_generics #ident #ty_gen #where_clause { + /// Convert this sync event into a full event, one with a room_id field. + pub fn into_full_event( + self, + room_id: #ruma_identifiers::RoomId, + ) -> #full_struct #ty_gen { + let Self { #( #fields, )* } = self; + #full_struct { + #( #fields, )* + room_id, } } } - }) + } +} + +fn expand_eq_ord_event(input: &DeriveInput) -> TokenStream { + let ident = &input.ident; + let (impl_gen, ty_gen, where_clause) = input.generics.split_for_impl(); + + quote! { + #[automatically_derived] + impl #impl_gen ::std::cmp::PartialEq for #ident #ty_gen #where_clause { + /// Checks if two `EventId`s are equal. + fn eq(&self, other: &Self) -> ::std::primitive::bool { + self.event_id == other.event_id + } + } + + #[automatically_derived] + impl #impl_gen ::std::cmp::Eq for #ident #ty_gen #where_clause {} + + #[automatically_derived] + impl #impl_gen ::std::cmp::PartialOrd for #ident #ty_gen #where_clause { + /// Compares `EventId`s and orders them lexicographically. + fn partial_cmp(&self, other: &Self) -> ::std::option::Option<::std::cmp::Ordering> { + self.event_id.partial_cmp(&other.event_id) + } + } + + #[automatically_derived] + impl #impl_gen ::std::cmp::Ord for #ident #ty_gen #where_clause { + /// Compares `EventId`s and orders them lexicographically. + fn cmp(&self, other: &Self) -> ::std::cmp::Ordering { + self.event_id.cmp(&other.event_id) + } + } + } } /// CamelCase's a field ident like "foo_bar" to "FooBar". diff --git a/crates/ruma-events-macros/src/event_enum.rs b/crates/ruma-events-macros/src/event_enum.rs index bfd193b6..05cdac1e 100644 --- a/crates/ruma-events-macros/src/event_enum.rs +++ b/crates/ruma-events-macros/src/event_enum.rs @@ -4,43 +4,10 @@ use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; use syn::{Attribute, Data, DataEnum, DeriveInput, Ident, LitStr}; -use crate::event_parse::{EventEnumDecl, EventEnumEntry, EventKind, EventKindVariation}; - -fn is_non_stripped_room_event(kind: EventKind, var: EventKindVariation) -> bool { - matches!(kind, EventKind::Message | EventKind::State) - && matches!( - var, - EventKindVariation::Full - | EventKindVariation::Sync - | EventKindVariation::Redacted - | EventKindVariation::RedactedSync - ) -} - -fn has_prev_content_field(kind: EventKind, var: EventKindVariation) -> bool { - matches!(kind, EventKind::State) - && matches!(var, EventKindVariation::Full | EventKindVariation::Sync) -} - -type EventKindFn = fn(EventKind, EventKindVariation) -> bool; - -/// This const is used to generate the accessor methods for the `Any*Event` enums. -/// -/// DO NOT alter the field names unless the structs in `ruma_events::event_kinds` have changed. -const EVENT_FIELDS: &[(&str, EventKindFn)] = &[ - ("origin_server_ts", is_non_stripped_room_event), - ("room_id", |kind, var| { - matches!(kind, EventKind::Message | EventKind::State | EventKind::Ephemeral) - && matches!(var, EventKindVariation::Full | EventKindVariation::Redacted) - }), - ("event_id", is_non_stripped_room_event), - ("sender", |kind, var| { - matches!(kind, EventKind::Message | EventKind::State | EventKind::ToDevice) - && var != EventKindVariation::Initial - }), - ("state_key", |kind, _| matches!(kind, EventKind::State)), - ("unsigned", is_non_stripped_room_event), -]; +use crate::{ + event_parse::{EventEnumDecl, EventEnumEntry, EventKind, EventKindVariation}, + util::{has_prev_content_field, EVENT_FIELDS}, +}; /// Create a content enum from `EventEnumInput`. pub fn expand_event_enums(input: &EventEnumDecl) -> syn::Result { diff --git a/crates/ruma-events-macros/src/event_parse.rs b/crates/ruma-events-macros/src/event_parse.rs index 82e2fb81..fce5433f 100644 --- a/crates/ruma-events-macros/src/event_parse.rs +++ b/crates/ruma-events-macros/src/event_parse.rs @@ -45,6 +45,10 @@ impl EventKindVariation { matches!(self, Self::Redacted | Self::RedactedSync) } + pub fn is_sync(self) -> bool { + matches!(self, Self::Sync | Self::RedactedSync) + } + pub fn to_redacted(self) -> Option { match self { EventKindVariation::Full => Some(EventKindVariation::Redacted), diff --git a/crates/ruma-events-macros/src/lib.rs b/crates/ruma-events-macros/src/lib.rs index a777067f..37b43f48 100644 --- a/crates/ruma-events-macros/src/lib.rs +++ b/crates/ruma-events-macros/src/lib.rs @@ -27,6 +27,7 @@ mod event_content; mod event_enum; mod event_parse; mod event_type; +mod util; /// Generates an enum to represent the various Matrix event types. /// diff --git a/crates/ruma-events-macros/src/util.rs b/crates/ruma-events-macros/src/util.rs new file mode 100644 index 00000000..9bf3b86f --- /dev/null +++ b/crates/ruma-events-macros/src/util.rs @@ -0,0 +1,37 @@ +use crate::event_parse::{EventKind, EventKindVariation}; + +pub(crate) fn is_non_stripped_room_event(kind: EventKind, var: EventKindVariation) -> bool { + matches!(kind, EventKind::Message | EventKind::State) + && matches!( + var, + EventKindVariation::Full + | EventKindVariation::Sync + | EventKindVariation::Redacted + | EventKindVariation::RedactedSync + ) +} + +pub(crate) fn has_prev_content_field(kind: EventKind, var: EventKindVariation) -> bool { + matches!(kind, EventKind::State) + && matches!(var, EventKindVariation::Full | EventKindVariation::Sync) +} + +pub(crate) type EventKindFn = fn(EventKind, EventKindVariation) -> bool; + +/// This const is used to generate the accessor methods for the `Any*Event` enums. +/// +/// DO NOT alter the field names unless the structs in `ruma_events::event_kinds` have changed. +pub(crate) const EVENT_FIELDS: &[(&str, EventKindFn)] = &[ + ("origin_server_ts", is_non_stripped_room_event), + ("room_id", |kind, var| { + matches!(kind, EventKind::Message | EventKind::State | EventKind::Ephemeral) + && matches!(var, EventKindVariation::Full | EventKindVariation::Redacted) + }), + ("event_id", is_non_stripped_room_event), + ("sender", |kind, var| { + matches!(kind, EventKind::Message | EventKind::State | EventKind::ToDevice) + && var != EventKindVariation::Initial + }), + ("state_key", |kind, _| matches!(kind, EventKind::State)), + ("unsigned", is_non_stripped_room_event), +];