diff --git a/crates/ruma-common/src/events/_custom.rs b/crates/ruma-common/src/events/_custom.rs index cb79d21f..f3104ad8 100644 --- a/crates/ruma-common/src/events/_custom.rs +++ b/crates/ruma-common/src/events/_custom.rs @@ -4,7 +4,7 @@ use serde_json::value::RawValue as RawJsonValue; use super::{ EphemeralRoomEventType, EventContent, GlobalAccountDataEventType, HasDeserializeFields, MessageLikeEventType, RedactContent, RedactedEventContent, RoomAccountDataEventType, - StateEventType, ToDeviceEventType, + StateEventContent, StateEventType, ToDeviceEventType, }; use crate::RoomVersionId; @@ -68,3 +68,7 @@ custom_event_content!(CustomEphemeralRoomEventContent, EphemeralRoomEventType); custom_room_event_content!(CustomMessageLikeEventContent, MessageLikeEventType); custom_room_event_content!(CustomStateEventContent, StateEventType); custom_event_content!(CustomToDeviceEventContent, ToDeviceEventType); + +impl StateEventContent for CustomStateEventContent { + type StateKey = String; +} diff --git a/crates/ruma-common/src/events/content.rs b/crates/ruma-common/src/events/content.rs index 61a03ec1..d6fd59f1 100644 --- a/crates/ruma-common/src/events/content.rs +++ b/crates/ruma-common/src/events/content.rs @@ -1,6 +1,6 @@ use std::fmt; -use serde::Serialize; +use serde::{de::DeserializeOwned, Serialize}; use serde_json::value::RawValue as RawJsonValue; use crate::serde::Raw; @@ -164,17 +164,18 @@ trait_aliases! { /// An alias for `EventContent`. trait MessageLikeEventContent = EventContent; - /// An alias for `EventContent + RedactedEventContent`. - trait RedactedMessageLikeEventContent = - EventContent, RedactedEventContent; + /// An alias for `MessageLikeEventContent + RedactedEventContent`. + trait RedactedMessageLikeEventContent = MessageLikeEventContent, RedactedEventContent; - /// An alias for `EventContent`. - trait StateEventContent = EventContent; - - /// An alias for `EventContent + RedactedEventContent`. - trait RedactedStateEventContent = - EventContent, RedactedEventContent; + /// An alias for `StateEventContent + RedactedEventContent`. + trait RedactedStateEventContent = StateEventContent, RedactedEventContent; /// An alias for `EventContent`. trait ToDeviceEventContent = EventContent; } + +/// An alias for `EventContent`. +pub trait StateEventContent: EventContent { + /// The type of the event's `state_key` field. + type StateKey: AsRef + Clone + fmt::Debug + DeserializeOwned + Serialize; +} diff --git a/crates/ruma-common/src/events/kinds.rs b/crates/ruma-common/src/events/kinds.rs index a47bb36e..b9467a85 100644 --- a/crates/ruma-common/src/events/kinds.rs +++ b/crates/ruma-common/src/events/kinds.rs @@ -203,7 +203,7 @@ pub struct OriginalStateEvent { /// /// This is often an empty string, but some events send a `UserId` to show which user the event /// affects. - pub state_key: String, + pub state_key: C::StateKey, /// Additional key-value pairs not signed by the homeserver. pub unsigned: StateUnsigned, @@ -231,7 +231,7 @@ pub struct OriginalSyncStateEvent { /// /// This is often an empty string, but some events send a `UserId` to show which user the event /// affects. - pub state_key: String, + pub state_key: C::StateKey, /// Additional key-value pairs not signed by the homeserver. pub unsigned: StateUnsigned, @@ -250,7 +250,7 @@ pub struct StrippedStateEvent { /// /// This is often an empty string, but some events send a `UserId` to show which user the event /// affects. - pub state_key: String, + pub state_key: C::StateKey, } /// A minimal state event, used for creating a new room. @@ -265,8 +265,7 @@ pub struct InitialStateEvent { /// affects. /// /// Defaults to the empty string. - #[ruma_event(default)] - pub state_key: String, + pub state_key: C::StateKey, } /// A redacted state event. @@ -294,7 +293,7 @@ pub struct RedactedStateEvent { /// /// This is often an empty string, but some events send a `UserId` to show which user the event /// affects. - pub state_key: String, + pub state_key: C::StateKey, /// Additional key-value pairs not signed by the homeserver. pub unsigned: RedactedUnsigned, @@ -322,7 +321,7 @@ pub struct RedactedSyncStateEvent { /// /// This is often an empty string, but some events send a `UserId` to show which user the event /// affects. - pub state_key: String, + pub state_key: C::StateKey, /// Additional key-value pairs not signed by the homeserver. pub unsigned: RedactedUnsigned, @@ -411,11 +410,16 @@ pub struct DecryptedMegolmV1Event { } macro_rules! impl_possibly_redacted_event { - ($ty:ident ( $content_trait:ident, $event_type:ident ) { $($extra:tt)* }) => { + ( + $ty:ident ( $content_trait:ident, $event_type:ident ) + $( where C::Redacted: $trait:ident, )? + { $($extra:tt)* } + ) => { impl $ty where C: $content_trait + RedactContent, C::Redacted: $content_trait + RedactedEventContent, + $( C::Redacted: $trait, )? { /// Returns the `type` of this event. pub fn event_type(&self) -> $event_type { @@ -457,6 +461,7 @@ macro_rules! impl_possibly_redacted_event { where C: $content_trait + RedactContent, C::Redacted: $content_trait + RedactedEventContent, + $( C::Redacted: $trait, )? { type Redacted = Self; @@ -472,6 +477,7 @@ macro_rules! impl_possibly_redacted_event { where C: $content_trait + RedactContent, C::Redacted: $content_trait + RedactedEventContent, + $( C::Redacted: $trait, )? { fn deserialize(deserializer: D) -> Result where @@ -526,57 +532,67 @@ impl_possibly_redacted_event!(SyncMessageLikeEvent(MessageLikeEventContent, Mess } }); -impl_possibly_redacted_event!(StateEvent(StateEventContent, StateEventType) { - /// Returns this event's `room_id` field. - pub fn room_id(&self) -> &RoomId { - match self { - Self::Original(ev) => &ev.room_id, - Self::Redacted(ev) => &ev.room_id, +impl_possibly_redacted_event!( + StateEvent(StateEventContent, StateEventType) + where + C::Redacted: StateEventContent, + { + /// Returns this event's `room_id` field. + pub fn room_id(&self) -> &RoomId { + match self { + Self::Original(ev) => &ev.room_id, + Self::Redacted(ev) => &ev.room_id, + } } - } - /// Returns this event's `state_key` field. - pub fn state_key(&self) -> &str { - match self { - Self::Original(ev) => &ev.state_key, - Self::Redacted(ev) => &ev.state_key, + /// Returns this event's `state_key` field. + pub fn state_key(&self) -> &C::StateKey { + match self { + Self::Original(ev) => &ev.state_key, + Self::Redacted(ev) => &ev.state_key, + } } - } - /// Get the inner `OriginalStateEvent` if this is an unredacted event. - pub fn as_original(&self) -> Option<&OriginalStateEvent> { - match self { - Self::Original(v) => Some(v), - _ => None, + /// Get the inner `OriginalStateEvent` if this is an unredacted event. + pub fn as_original(&self) -> Option<&OriginalStateEvent> { + match self { + Self::Original(v) => Some(v), + _ => None, + } } } -}); +); -impl_possibly_redacted_event!(SyncStateEvent(StateEventContent, StateEventType) { - /// Returns this event's `state_key` field. - pub fn state_key(&self) -> &str { - match self { - Self::Original(ev) => &ev.state_key, - Self::Redacted(ev) => &ev.state_key, +impl_possibly_redacted_event!( + SyncStateEvent(StateEventContent, StateEventType) + where + C::Redacted: StateEventContent, + { + /// Returns this event's `state_key` field. + pub fn state_key(&self) -> &C::StateKey { + match self { + Self::Original(ev) => &ev.state_key, + Self::Redacted(ev) => &ev.state_key, + } } - } - /// Get the inner `OriginalSyncStateEvent` if this is an unredacted event. - pub fn as_original(&self) -> Option<&OriginalSyncStateEvent> { - match self { - Self::Original(v) => Some(v), - _ => None, + /// Get the inner `OriginalSyncStateEvent` if this is an unredacted event. + pub fn as_original(&self) -> Option<&OriginalSyncStateEvent> { + match self { + Self::Original(v) => Some(v), + _ => None, + } } - } - /// Convert this sync event into a full event (one with a `room_id` field). - pub fn into_full_event(self, room_id: OwnedRoomId) -> StateEvent { - match self { - Self::Original(ev) => StateEvent::Original(ev.into_full_event(room_id)), - Self::Redacted(ev) => StateEvent::Redacted(ev.into_full_event(room_id)), + /// Convert this sync event into a full event (one with a `room_id` field). + pub fn into_full_event(self, room_id: OwnedRoomId) -> StateEvent { + match self { + Self::Original(ev) => StateEvent::Original(ev.into_full_event(room_id)), + Self::Redacted(ev) => StateEvent::Redacted(ev.into_full_event(room_id)), + } } } -}); +); macro_rules! impl_sync_from_full { ($ty:ident, $full:ident, $content_trait:ident) => { diff --git a/crates/ruma-common/src/events/room/aliases.rs b/crates/ruma-common/src/events/room/aliases.rs index cb31571d..888d38b3 100644 --- a/crates/ruma-common/src/events/room/aliases.rs +++ b/crates/ruma-common/src/events/room/aliases.rs @@ -6,7 +6,8 @@ use serde_json::value::RawValue as RawJsonValue; use crate::{ events::{ - EventContent, HasDeserializeFields, RedactContent, RedactedEventContent, StateEventType, + EventContent, HasDeserializeFields, RedactContent, RedactedEventContent, StateEventContent, + StateEventType, }, OwnedRoomAliasId, RoomVersionId, }; @@ -94,6 +95,10 @@ impl EventContent for RedactedRoomAliasesEventContent { } } +impl StateEventContent for RedactedRoomAliasesEventContent { + type StateKey = String; // Box +} + // Since this redacted event has fields we leave the default `empty` method // that will error if called. impl RedactedEventContent for RedactedRoomAliasesEventContent { diff --git a/crates/ruma-common/src/events/room/member.rs b/crates/ruma-common/src/events/room/member.rs index aad78b51..3ec25e66 100644 --- a/crates/ruma-common/src/events/room/member.rs +++ b/crates/ruma-common/src/events/room/member.rs @@ -11,7 +11,7 @@ use serde_json::value::RawValue as RawJsonValue; use crate::{ events::{ EventContent, HasDeserializeFields, OriginalSyncStateEvent, RedactContent, - RedactedEventContent, StateEventType, StrippedStateEvent, + RedactedEventContent, StateEventContent, StateEventType, StrippedStateEvent, }, serde::StringEnum, MxcUri, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, PrivOwnedStr, RoomVersionId, @@ -175,6 +175,10 @@ impl EventContent for RedactedRoomMemberEventContent { } } +impl StateEventContent for RedactedRoomMemberEventContent { + type StateKey = String; // Box +} + // Since this redacted event has fields we leave the default `empty` method // that will error if called. impl RedactedEventContent for RedactedRoomMemberEventContent { diff --git a/crates/ruma-common/tests/events/ui/03-invalid-event-type.stderr b/crates/ruma-common/tests/events/ui/03-invalid-event-type.stderr index e25bfccf..d8fc4543 100644 --- a/crates/ruma-common/tests/events/ui/03-invalid-event-type.stderr +++ b/crates/ruma-common/tests/events/ui/03-invalid-event-type.stderr @@ -6,7 +6,7 @@ error: no event type attribute found, add `#[ruma_event(type = "any.room.event", | = note: this error originates in the derive macro `EventContent` (in Nightly builds, run with -Z macro-backtrace for more info) -error: expected one of: `type`, `kind`, `skip_redaction`, `custom_redacted`, `type_fragment` +error: expected one of: `type`, `kind`, `skip_redaction`, `custom_redacted`, `type_fragment`, `state_key_type` --> tests/events/ui/03-invalid-event-type.rs:11:14 | 11 | #[ruma_event(event = "m.macro.test", kind = State)] diff --git a/crates/ruma-common/tests/events/ui/04-event-sanity-check.rs b/crates/ruma-common/tests/events/ui/04-event-sanity-check.rs index 9fd874e2..7eb11ff4 100644 --- a/crates/ruma-common/tests/events/ui/04-event-sanity-check.rs +++ b/crates/ruma-common/tests/events/ui/04-event-sanity-check.rs @@ -17,7 +17,7 @@ pub struct OriginalStateEvent { pub sender: OwnedUserId, pub origin_server_ts: MilliSecondsSinceUnixEpoch, pub room_id: OwnedRoomId, - pub state_key: String, + pub state_key: C::StateKey, pub unsigned: StateUnsigned, } diff --git a/crates/ruma-macros/src/events/event.rs b/crates/ruma-macros/src/events/event.rs index ad7d445b..344a2166 100644 --- a/crates/ruma-macros/src/events/event.rs +++ b/crates/ruma-macros/src/events/event.rs @@ -2,10 +2,7 @@ use proc_macro2::{Span, TokenStream}; use quote::quote; -use syn::{ - parse_quote, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, GenericParam, Meta, - MetaList, NestedMeta, -}; +use syn::{parse_quote, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, GenericParam}; use super::{ event_parse::{to_kind_variation, EventKind, EventKindVariation}, @@ -169,14 +166,16 @@ fn expand_deserialize_event( .iter() .map(|field| { let name = field.ident.as_ref().unwrap(); - let ty = &field.ty; if name == "content" || (name == "unsigned" && has_prev_content(kind, var)) { if is_generic { quote! { ::std::boxed::Box<#serde_json::value::RawValue> } } else { quote! { #content_type } } + } else if name == "state_key" && var == EventKindVariation::Initial { + quote! { ::std::string::String } } else { + let ty = &field.ty; quote! { #ty } } }) @@ -225,42 +224,33 @@ fn expand_deserialize_event( )?; } } - } else if name == "unsigned" && has_prev_content(kind, var) { - quote! { - let unsigned = unsigned.map(|json| { - #ruma_common::events::StateUnsigned::_from_parts(&event_type, &json) - .map_err(#serde::de::Error::custom) - }).transpose()?.unwrap_or_default(); - } - } else { - let attrs: Vec<_> = field - .attrs - .iter() - .filter(|a| a.path.is_ident("ruma_event")) - .map(|a| a.parse_meta()) - .collect::>()?; - - let has_default_attr = attrs.iter().any(|a| { - matches!( - a, - Meta::List(MetaList { nested, .. }) - if nested.iter().any(|n| { - matches!(n, NestedMeta::Meta(Meta::Path(p)) if p.is_ident("default")) - }) - ) - }); - - if has_default_attr || name == "unsigned" { + } else if name == "unsigned" { + if has_prev_content(kind, var) { quote! { - let #name = #name.unwrap_or_default(); + let unsigned = unsigned.map(|json| { + #ruma_common::events::StateUnsigned::_from_parts(&event_type, &json) + .map_err(#serde::de::Error::custom) + }).transpose()?.unwrap_or_default(); } } else { quote! { - let #name = #name.ok_or_else(|| { - #serde::de::Error::missing_field(stringify!(#name)) - })?; + let unsigned = unsigned.unwrap_or_default(); } } + } else if name == "state_key" && var == EventKindVariation::Initial { + let ty = &field.ty; + quote! { + let state_key: ::std::string::String = state_key.unwrap_or_default(); + let state_key: #ty = <#ty as #serde::de::Deserialize>::deserialize( + #serde::de::IntoDeserializer::::into_deserializer(state_key), + )?; + } + } else { + quote! { + let #name = #name.ok_or_else(|| { + #serde::de::Error::missing_field(stringify!(#name)) + })?; + } }) }) .collect::>()?; @@ -385,12 +375,21 @@ fn expand_redact_event( let where_clause = generics.make_where_clause(); where_clause.predicates.push(parse_quote! { #ty_param: #ruma_common::events::RedactContent }); - where_clause.predicates.push(parse_quote! { - <#ty_param as #ruma_common::events::RedactContent>::Redacted: + + let redacted_event_content_bound = if kind == EventKind::State { + quote! { + #ruma_common::events::StateEventContent + } + } else { + quote! { #ruma_common::events::EventContent< EventType = #ruma_common::events::#redacted_event_type_enum > - + #ruma_common::events::RedactedEventContent + } + }; + where_clause.predicates.push(parse_quote! { + <#ty_param as #ruma_common::events::RedactContent>::Redacted: + #redacted_event_content_bound + #ruma_common::events::RedactedEventContent }); let (impl_generics, ty_gen, where_clause) = generics.split_for_impl(); diff --git a/crates/ruma-macros/src/events/event_content.rs b/crates/ruma-macros/src/events/event_content.rs index 253f3d91..98bc5290 100644 --- a/crates/ruma-macros/src/events/event_content.rs +++ b/crates/ruma-macros/src/events/event_content.rs @@ -6,7 +6,7 @@ use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ parse::{Parse, ParseStream}, - DeriveInput, Field, Ident, LitStr, Token, + DeriveInput, Field, Ident, LitStr, Token, Type, }; use crate::util::m_prefix_name_to_type_name; @@ -21,6 +21,8 @@ mod kw { // The kind of event content this is. syn::custom_keyword!(kind); syn::custom_keyword!(type_fragment); + // The type to use for a state events' `state_key` field. + syn::custom_keyword!(state_key_type); } /// Parses attributes for `*EventContent` derives. @@ -43,6 +45,8 @@ enum EventMeta { /// The given field holds a part of the event type (replaces the `*` in a `m.foo.*` event /// type). TypeFragment, + + StateKeyType(Box), } impl EventMeta { @@ -59,6 +63,13 @@ impl EventMeta { _ => None, } } + + fn get_state_key_type(&self) -> Option<&Type> { + match self { + Self::StateKeyType(ty) => Some(ty), + _ => None, + } + } } impl Parse for EventMeta { @@ -71,7 +82,7 @@ impl Parse for EventMeta { } else if lookahead.peek(kw::kind) { let _: kw::kind = input.parse()?; let _: Token![=] = input.parse()?; - EventKind::parse(input).map(EventMeta::Kind) + input.parse().map(EventMeta::Kind) } else if lookahead.peek(kw::skip_redaction) { let _: kw::skip_redaction = input.parse()?; Ok(EventMeta::SkipRedaction) @@ -81,6 +92,10 @@ impl Parse for EventMeta { } else if lookahead.peek(kw::type_fragment) { let _: kw::type_fragment = input.parse()?; Ok(EventMeta::TypeFragment) + } else if lookahead.peek(kw::state_key_type) { + let _: kw::state_key_type = input.parse()?; + let _: Token![=] = input.parse()?; + input.parse().map(EventMeta::StateKeyType) } else { Err(lookahead.error()) } @@ -101,6 +116,10 @@ impl MetaAttrs { fn get_event_kind(&self) -> Option { self.0.iter().find_map(|a| a.get_event_kind()) } + + fn get_state_key_type(&self) -> Option<&Type> { + self.0.iter().find_map(|a| a.get_state_key_type()) + } } impl Parse for MetaAttrs { @@ -155,6 +174,26 @@ pub fn expand_event_content( } }; + let state_key_types: Vec<_> = + content_attr.iter().filter_map(|attrs| attrs.get_state_key_type()).collect(); + let state_key_type = match (event_kind, state_key_types.as_slice()) { + (Some(EventKind::State), []) => Some(quote! { ::std::string::String }), + (Some(EventKind::State), [ty]) => Some(quote! { #ty }), + (Some(EventKind::State), _) => { + return Err(syn::Error::new( + Span::call_site(), + "multiple state_key_type attribute found, there can only be one", + )); + } + (_, []) => None, + (_, [ty, ..]) => { + return Err(syn::Error::new_spanned( + ty, + "state_key_type attribute is not valid for non-state event kinds", + )); + } + }; + let ident = &input.ident; let fields = match &input.data { syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.iter(), @@ -185,13 +224,26 @@ pub fn expand_event_content( // We only generate redacted content structs for state and message-like events let redacted_event_content = needs_redacted(&content_attr, event_kind).then(|| { - generate_redacted_event_content(ident, fields.clone(), event_type, event_kind, ruma_common) - .unwrap_or_else(syn::Error::into_compile_error) + generate_redacted_event_content( + ident, + fields.clone(), + event_type, + event_kind, + state_key_type.as_ref(), + ruma_common, + ) + .unwrap_or_else(syn::Error::into_compile_error) }); - let event_content_impl = - generate_event_content_impl(ident, fields, event_type, event_kind, ruma_common) - .unwrap_or_else(syn::Error::into_compile_error); + let event_content_impl = generate_event_content_impl( + ident, + fields, + event_type, + event_kind, + state_key_type.as_ref(), + ruma_common, + ) + .unwrap_or_else(syn::Error::into_compile_error); let static_event_content_impl = event_kind .map(|k| generate_static_event_content_impl(ident, k, false, event_type, ruma_common)); let type_aliases = event_kind.map(|k| { @@ -212,6 +264,7 @@ fn generate_redacted_event_content<'a>( fields: impl Iterator, event_type: &LitStr, event_kind: Option, + state_key_type: Option<&TokenStream>, ruma_common: &TokenStream, ) -> syn::Result { assert!( @@ -295,6 +348,7 @@ fn generate_redacted_event_content<'a>( kept_redacted_fields.iter(), event_type, event_kind, + state_key_type, ruma_common, ) .unwrap_or_else(syn::Error::into_compile_error); @@ -416,6 +470,7 @@ fn generate_event_content_impl<'a>( mut fields: impl Iterator, event_type: &LitStr, event_kind: Option, + state_key_type: Option<&TokenStream>, ruma_common: &TokenStream, ) -> syn::Result { let serde = quote! { #ruma_common::exports::serde }; @@ -489,6 +544,16 @@ fn generate_event_content_impl<'a>( } } + let state_event_content_impl = (event_kind == Some(EventKind::State)).then(|| { + assert!(state_key_type.is_some()); + quote! { + #[automatically_derived] + impl #ruma_common::events::StateEventContent for #ident { + type StateKey = #state_key_type; + } + } + }); + Ok(quote! { #event_type_ty_decl @@ -513,6 +578,8 @@ fn generate_event_content_impl<'a>( #serde_json::from_str(content.get()) } } + + #state_event_content_impl }) } diff --git a/crates/ruma-macros/src/events/event_enum.rs b/crates/ruma-macros/src/events/event_enum.rs index 49010209..26d7443b 100644 --- a/crates/ruma-macros/src/events/event_enum.rs +++ b/crates/ruma-macros/src/events/event_enum.rs @@ -37,7 +37,6 @@ const EVENT_FIELDS: &[(&str, EventKindFn)] = &[ matches!(kind, EventKind::MessageLike | EventKind::State | EventKind::ToDevice) && var != EventEnumVariation::Initial }), - ("state_key", |kind, _| matches!(kind, EventKind::State)), ]; /// Create a content enum from `EventEnumInput`. @@ -484,6 +483,21 @@ fn expand_accessor_methods( }) }); + let state_key_accessor = (kind == EventKind::State).then(|| { + let variants = variants.iter().map(|v| v.match_arm(quote! { Self })); + let call_parens = maybe_redacted.then(|| quote! { () }); + + quote! { + /// Returns this event's `state_key` field. + pub fn state_key(&self) -> &::std::primitive::str { + match self { + #( #variants(event) => &event.state_key #call_parens .as_ref(), )* + Self::_Custom(event) => &event.state_key #call_parens .as_ref(), + } + } + } + }); + let txn_id_accessor = maybe_redacted.then(|| { let variants = variants.iter().map(|v| v.match_arm(quote! { Self })); quote! { @@ -513,6 +527,7 @@ fn expand_accessor_methods( #content_accessors #( #methods )* + #state_key_accessor #txn_id_accessor } }) @@ -572,7 +587,6 @@ fn field_return_type(name: &str, ruma_common: &TokenStream) -> TokenStream { "room_id" => quote! { #ruma_common::RoomId }, "event_id" => quote! { #ruma_common::EventId }, "sender" => quote! { #ruma_common::UserId }, - "state_key" => quote! { ::std::primitive::str }, _ => panic!("the `ruma_macros::event_enum::EVENT_FIELD` const was changed"), } }