diff --git a/ruma-events-macros/src/event_enum.rs b/ruma-events-macros/src/event_enum.rs index 82e26a85..953c7014 100644 --- a/ruma-events-macros/src/event_enum.rs +++ b/ruma-events-macros/src/event_enum.rs @@ -1,28 +1,67 @@ //! Implementation of event enum and event content enum macros. -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{ parse::{self, Parse, ParseStream}, Attribute, Expr, ExprLit, Ident, Lit, LitStr, Token, }; +use crate::event_names::{ + ANY_BASIC_EVENT, ANY_EPHEMERAL_EVENT, ANY_MESSAGE_EVENT, ANY_STATE_EVENT, + ANY_STRIPPED_STATE_EVENT, ANY_SYNC_MESSAGE_EVENT, ANY_SYNC_STATE_EVENT, ANY_TO_DEVICE_EVENT, +}; + +// Arrays of event enum names grouped by a field they share in common. +const ROOM_EVENT_KIND: &[&str] = + &[ANY_MESSAGE_EVENT, ANY_SYNC_MESSAGE_EVENT, ANY_STATE_EVENT, ANY_SYNC_STATE_EVENT]; + +const ROOM_ID_KIND: &[&str] = &[ANY_MESSAGE_EVENT, ANY_STATE_EVENT, ANY_EPHEMERAL_EVENT]; + +const EVENT_ID_KIND: &[&str] = + &[ANY_MESSAGE_EVENT, ANY_SYNC_MESSAGE_EVENT, ANY_STATE_EVENT, ANY_SYNC_STATE_EVENT]; + +const SENDER_KIND: &[&str] = &[ + ANY_MESSAGE_EVENT, + ANY_STATE_EVENT, + ANY_SYNC_STATE_EVENT, + ANY_TO_DEVICE_EVENT, + ANY_SYNC_MESSAGE_EVENT, + ANY_STRIPPED_STATE_EVENT, +]; + +const PREV_CONTENT_KIND: &[&str] = &[ANY_STATE_EVENT, ANY_SYNC_STATE_EVENT]; + +const STATE_KEY_KIND: &[&str] = &[ANY_STATE_EVENT, ANY_SYNC_STATE_EVENT, ANY_STRIPPED_STATE_EVENT]; + +/// This const is used to generate the accessor methods for the `Any*Event` enums. +/// +/// DO NOT alter the field names unless the structs in `ruma_events::event_kinds` have changed. +const EVENT_FIELDS: &[(&str, &[&str])] = &[ + ("origin_server_ts", ROOM_EVENT_KIND), + ("room_id", ROOM_ID_KIND), + ("event_id", EVENT_ID_KIND), + ("sender", SENDER_KIND), + ("state_key", STATE_KEY_KIND), + ("unsigned", ROOM_EVENT_KIND), +]; + /// Create a content enum from `EventEnumInput`. pub fn expand_event_enum(input: EventEnumInput) -> syn::Result { let ident = &input.name; let event_enum = expand_any_enum_with_deserialize(&input, ident)?; - let needs_event_content = ident == "AnyStateEvent" - || ident == "AnyMessageEvent" - || ident == "AnyToDeviceEvent" - || ident == "AnyEphemeralRoomEvent" - || ident == "AnyBasicEvent"; + let needs_event_content = ident == ANY_STATE_EVENT + || ident == ANY_MESSAGE_EVENT + || ident == ANY_TO_DEVICE_EVENT + || ident == ANY_EPHEMERAL_EVENT + || ident == ANY_BASIC_EVENT; let needs_event_stub = - ident == "AnyStateEvent" || ident == "AnyMessageEvent" || ident == "AnyEphemeralRoomEvent"; + ident == ANY_STATE_EVENT || ident == ANY_MESSAGE_EVENT || ident == ANY_EPHEMERAL_EVENT; - let needs_stripped_event = ident == "AnyStateEvent"; + let needs_stripped_event = ident == ANY_STATE_EVENT; let event_stub_enum = if needs_event_stub { expand_stub_enum(&input)? } else { TokenStream::new() }; @@ -175,9 +214,13 @@ fn expand_any_enum_with_deserialize( } }; + let field_accessor_impl = accessor_methods(ident, &variants); + Ok(quote! { #any_enum + #field_accessor_impl + #event_deserialize_impl }) } @@ -202,6 +245,57 @@ fn marker_traits(ident: &Ident) -> TokenStream { } } +fn accessor_methods(ident: &Ident, variants: &[Ident]) -> TokenStream { + let fields = EVENT_FIELDS + .iter() + .map(|(name, has_field)| generate_accessor(name, ident, *has_field, variants)); + + let any_content = ident.to_string().replace("Stub", "").replace("Stripped", ""); + let content_enum = Ident::new(&format!("{}Content", any_content), ident.span()); + + let content = quote! { + /// Returns the any content enum for this event. + pub fn content(&self) -> #content_enum { + match self { + #( + Self::#variants(event) => #content_enum::#variants(event.content.clone()), + )* + Self::Custom(event) => #content_enum::Custom(event.content.clone()), + } + } + }; + + let prev_content = if PREV_CONTENT_KIND.contains(&ident.to_string().as_str()) { + quote! { + /// Returns the any content enum for this events prev_content. + pub fn prev_content(&self) -> Option<#content_enum> { + match self { + #( + Self::#variants(event) => { + event.prev_content.as_ref().map(|c| #content_enum::#variants(c.clone())) + }, + )* + Self::Custom(event) => { + event.prev_content.as_ref().map(|c| #content_enum::Custom(c.clone())) + }, + } + } + } + } else { + TokenStream::new() + }; + + quote! { + impl #ident { + #content + + #prev_content + + #( #fields )* + } + } +} + fn to_event_path(name: &LitStr, struct_name: &Ident) -> TokenStream { let span = name.span(); let name = name.value(); @@ -280,6 +374,45 @@ pub(crate) fn to_camel_case(name: &LitStr) -> syn::Result { Ok(Ident::new(&s, span)) } +fn generate_accessor( + name: &str, + ident: &Ident, + event_kind_list: &[&str], + variants: &[Ident], +) -> TokenStream { + if event_kind_list.contains(&ident.to_string().as_str()) { + let field_type = field_return_type(name); + + let name = Ident::new(name, Span::call_site()); + let docs = format!("Returns this events {} field.", name); + quote! { + #[doc = #docs] + pub fn #name(&self) -> &#field_type { + match self { + #( + Self::#variants(event) => &event.#name, + )* + Self::Custom(event) => &event.#name, + } + } + } + } else { + TokenStream::new() + } +} + +fn field_return_type(name: &str) -> TokenStream { + match name { + "origin_server_ts" => quote! { ::std::time::SystemTime }, + "room_id" => quote! { ::ruma_identifiers::RoomId }, + "event_id" => quote! { ::ruma_identifiers::EventId }, + "sender" => quote! { ::ruma_identifiers::UserId }, + "state_key" => quote! { str }, + "unsigned" => quote! { ::ruma_events::UnsignedData }, + _ => panic!("the `ruma_events_macros::event_enum::EVENT_FIELD` const was changed"), + } +} + /// Custom keywords for the `event_enum!` macro mod kw { syn::custom_keyword!(name); diff --git a/ruma-events-macros/src/event_names.rs b/ruma-events-macros/src/event_names.rs new file mode 100644 index 00000000..9148b5c6 --- /dev/null +++ b/ruma-events-macros/src/event_names.rs @@ -0,0 +1,28 @@ +//! The names of the `Any*Event` enums. The event_enum! macro uses these names to generate +//! certain code for certain enums. If the names change this is the one source of truth, +//! most comparisons and branching uses these constants. + +// State events +pub const ANY_STATE_EVENT: &str = "AnyStateEvent"; + +pub const ANY_SYNC_STATE_EVENT: &str = "AnyStateEventStub"; + +pub const ANY_STRIPPED_STATE_EVENT: &str = "AnyStrippedStateEventStub"; + +// Message events +pub const ANY_MESSAGE_EVENT: &str = "AnyMessageEvent"; + +pub const ANY_SYNC_MESSAGE_EVENT: &str = "AnyMessageEventStub"; + +// Ephemeral events +pub const ANY_EPHEMERAL_EVENT: &str = "AnyEphemeralRoomEvent"; + +#[allow(dead_code)] +// This is currently not used but, left for completeness sake. +pub const ANY_SYNC_EPHEMERAL_EVENT: &str = "AnyEphemeralRoomEventStub"; + +// Basic event +pub const ANY_BASIC_EVENT: &str = "AnyBasicEvent"; + +// To device event +pub const ANY_TO_DEVICE_EVENT: &str = "AnyToDeviceEvent"; diff --git a/ruma-events-macros/src/lib.rs b/ruma-events-macros/src/lib.rs index aca3c288..657897ab 100644 --- a/ruma-events-macros/src/lib.rs +++ b/ruma-events-macros/src/lib.rs @@ -22,6 +22,7 @@ use self::{ mod event; mod event_content; mod event_enum; +mod event_names; /// Generates an enum to represent the various Matrix event types. /// diff --git a/ruma-events/tests/enums.rs b/ruma-events/tests/enums.rs index 58b3f124..e8a04f8f 100644 --- a/ruma-events/tests/enums.rs +++ b/ruma-events/tests/enums.rs @@ -1,7 +1,7 @@ use std::convert::TryFrom; use matches::assert_matches; -use ruma_identifiers::RoomAliasId; +use ruma_identifiers::{EventId, RoomAliasId, RoomId, UserId}; use serde_json::{from_value as from_json_value, json, Value as JsonValue}; use ruma_events::{ @@ -11,7 +11,8 @@ use ruma_events::{ power_levels::PowerLevelsEventContent, }, AnyEvent, AnyMessageEvent, AnyMessageEventStub, AnyRoomEvent, AnyRoomEventStub, AnyStateEvent, - AnyStateEventStub, MessageEvent, MessageEventStub, StateEvent, StateEventStub, + AnyStateEventContent, AnyStateEventStub, MessageEvent, MessageEventStub, StateEvent, + StateEventStub, }; fn message_event() -> JsonValue { @@ -243,3 +244,24 @@ fn alias_event_deserialization() { if aliases == vec![ RoomAliasId::try_from("#somewhere:localhost").unwrap() ] ); } + +#[test] +fn alias_event_field_access() { + let json_data = aliases_event(); + + assert_matches!( + from_json_value::(json_data.clone()), + Ok(AnyEvent::State(state_event)) + if state_event.state_key() == "" + && state_event.room_id() == &RoomId::try_from("!room:room.com").unwrap() + && state_event.event_id() == &EventId::try_from("$152037280074GZeOm:localhost").unwrap() + && state_event.sender() == &UserId::try_from("@example:localhost").unwrap() + ); + + let deser = from_json_value::(json_data).unwrap(); + if let AnyStateEventContent::RoomAliases(AliasesEventContent { aliases }) = deser.content() { + assert_eq!(aliases, vec![RoomAliasId::try_from("#somewhere:localhost").unwrap()]) + } else { + panic!("the `Any*Event` enum's accessor methods may have been altered") + } +}