events: Allow state key type to be customized by content type
This commit is contained in:
		
							parent
							
								
									da222a87c9
								
							
						
					
					
						commit
						d8b7886382
					
				| @ -4,7 +4,7 @@ use serde_json::value::RawValue as RawJsonValue; | ||||
| use super::{ | ||||
|     EphemeralRoomEventType, EventContent, GlobalAccountDataEventType, HasDeserializeFields, | ||||
|     MessageLikeEventType, RedactContent, RedactedEventContent, RoomAccountDataEventType, | ||||
|     StateEventType, ToDeviceEventType, | ||||
|     StateEventContent, StateEventType, ToDeviceEventType, | ||||
| }; | ||||
| use crate::RoomVersionId; | ||||
| 
 | ||||
| @ -68,3 +68,7 @@ custom_event_content!(CustomEphemeralRoomEventContent, EphemeralRoomEventType); | ||||
| custom_room_event_content!(CustomMessageLikeEventContent, MessageLikeEventType); | ||||
| custom_room_event_content!(CustomStateEventContent, StateEventType); | ||||
| custom_event_content!(CustomToDeviceEventContent, ToDeviceEventType); | ||||
| 
 | ||||
| impl StateEventContent for CustomStateEventContent { | ||||
|     type StateKey = String; | ||||
| } | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| use std::fmt; | ||||
| 
 | ||||
| use serde::Serialize; | ||||
| use serde::{de::DeserializeOwned, Serialize}; | ||||
| use serde_json::value::RawValue as RawJsonValue; | ||||
| 
 | ||||
| use crate::serde::Raw; | ||||
| @ -164,17 +164,18 @@ trait_aliases! { | ||||
|     /// An alias for `EventContent<EventType = MessageLikeEventType>`.
 | ||||
|     trait MessageLikeEventContent = EventContent<EventType = MessageLikeEventType>; | ||||
| 
 | ||||
|     /// An alias for `EventContent<EventType = MessageLikeEventType> + RedactedEventContent`.
 | ||||
|     trait RedactedMessageLikeEventContent = | ||||
|         EventContent<EventType = MessageLikeEventType>, RedactedEventContent; | ||||
|     /// An alias for `MessageLikeEventContent + RedactedEventContent`.
 | ||||
|     trait RedactedMessageLikeEventContent = MessageLikeEventContent, RedactedEventContent; | ||||
| 
 | ||||
|     /// An alias for `EventContent<EventType = StateEventType>`.
 | ||||
|     trait StateEventContent = EventContent<EventType = StateEventType>; | ||||
| 
 | ||||
|     /// An alias for `EventContent<EventType = StateEventType> + RedactedEventContent`.
 | ||||
|     trait RedactedStateEventContent = | ||||
|         EventContent<EventType = StateEventType>, RedactedEventContent; | ||||
|     /// An alias for `StateEventContent + RedactedEventContent`.
 | ||||
|     trait RedactedStateEventContent = StateEventContent, RedactedEventContent; | ||||
| 
 | ||||
|     /// An alias for `EventContent<EventType = ToDeviceEventType>`.
 | ||||
|     trait ToDeviceEventContent = EventContent<EventType = ToDeviceEventType>; | ||||
| } | ||||
| 
 | ||||
| /// An alias for `EventContent<EventType = StateEventType>`.
 | ||||
| pub trait StateEventContent: EventContent<EventType = StateEventType> { | ||||
|     /// The type of the event's `state_key` field.
 | ||||
|     type StateKey: AsRef<str> + Clone + fmt::Debug + DeserializeOwned + Serialize; | ||||
| } | ||||
|  | ||||
| @ -203,7 +203,7 @@ pub struct OriginalStateEvent<C: StateEventContent> { | ||||
|     ///
 | ||||
|     /// This is often an empty string, but some events send a `UserId` to show which user the event
 | ||||
|     /// affects.
 | ||||
|     pub state_key: String, | ||||
|     pub state_key: C::StateKey, | ||||
| 
 | ||||
|     /// Additional key-value pairs not signed by the homeserver.
 | ||||
|     pub unsigned: StateUnsigned<C>, | ||||
| @ -231,7 +231,7 @@ pub struct OriginalSyncStateEvent<C: StateEventContent> { | ||||
|     ///
 | ||||
|     /// This is often an empty string, but some events send a `UserId` to show which user the event
 | ||||
|     /// affects.
 | ||||
|     pub state_key: String, | ||||
|     pub state_key: C::StateKey, | ||||
| 
 | ||||
|     /// Additional key-value pairs not signed by the homeserver.
 | ||||
|     pub unsigned: StateUnsigned<C>, | ||||
| @ -250,7 +250,7 @@ pub struct StrippedStateEvent<C: StateEventContent> { | ||||
|     ///
 | ||||
|     /// This is often an empty string, but some events send a `UserId` to show which user the event
 | ||||
|     /// affects.
 | ||||
|     pub state_key: String, | ||||
|     pub state_key: C::StateKey, | ||||
| } | ||||
| 
 | ||||
| /// A minimal state event, used for creating a new room.
 | ||||
| @ -265,8 +265,7 @@ pub struct InitialStateEvent<C: StateEventContent> { | ||||
|     /// affects.
 | ||||
|     ///
 | ||||
|     /// Defaults to the empty string.
 | ||||
|     #[ruma_event(default)] | ||||
|     pub state_key: String, | ||||
|     pub state_key: C::StateKey, | ||||
| } | ||||
| 
 | ||||
| /// A redacted state event.
 | ||||
| @ -294,7 +293,7 @@ pub struct RedactedStateEvent<C: RedactedStateEventContent> { | ||||
|     ///
 | ||||
|     /// This is often an empty string, but some events send a `UserId` to show which user the event
 | ||||
|     /// affects.
 | ||||
|     pub state_key: String, | ||||
|     pub state_key: C::StateKey, | ||||
| 
 | ||||
|     /// Additional key-value pairs not signed by the homeserver.
 | ||||
|     pub unsigned: RedactedUnsigned, | ||||
| @ -322,7 +321,7 @@ pub struct RedactedSyncStateEvent<C: RedactedStateEventContent> { | ||||
|     ///
 | ||||
|     /// This is often an empty string, but some events send a `UserId` to show which user the event
 | ||||
|     /// affects.
 | ||||
|     pub state_key: String, | ||||
|     pub state_key: C::StateKey, | ||||
| 
 | ||||
|     /// Additional key-value pairs not signed by the homeserver.
 | ||||
|     pub unsigned: RedactedUnsigned, | ||||
| @ -411,11 +410,16 @@ pub struct DecryptedMegolmV1Event<C: MessageLikeEventContent> { | ||||
| } | ||||
| 
 | ||||
| macro_rules! impl_possibly_redacted_event { | ||||
|     ($ty:ident ( $content_trait:ident, $event_type:ident ) { $($extra:tt)* }) => { | ||||
|     ( | ||||
|         $ty:ident ( $content_trait:ident, $event_type:ident ) | ||||
|         $( where C::Redacted: $trait:ident<StateKey = C::StateKey>, )? | ||||
|         { $($extra:tt)* } | ||||
|     ) => { | ||||
|         impl<C> $ty<C> | ||||
|         where | ||||
|             C: $content_trait + RedactContent, | ||||
|             C::Redacted: $content_trait + RedactedEventContent, | ||||
|             $( C::Redacted: $trait<StateKey = C::StateKey>, )? | ||||
|         { | ||||
|             /// Returns the `type` of this event.
 | ||||
|             pub fn event_type(&self) -> $event_type { | ||||
| @ -457,6 +461,7 @@ macro_rules! impl_possibly_redacted_event { | ||||
|         where | ||||
|             C: $content_trait + RedactContent, | ||||
|             C::Redacted: $content_trait + RedactedEventContent, | ||||
|             $( C::Redacted: $trait<StateKey = C::StateKey>, )? | ||||
|         { | ||||
|             type Redacted = Self; | ||||
| 
 | ||||
| @ -472,6 +477,7 @@ macro_rules! impl_possibly_redacted_event { | ||||
|         where | ||||
|             C: $content_trait + RedactContent, | ||||
|             C::Redacted: $content_trait + RedactedEventContent, | ||||
|             $( C::Redacted: $trait<StateKey = C::StateKey>, )? | ||||
|         { | ||||
|             fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||||
|             where | ||||
| @ -526,57 +532,67 @@ impl_possibly_redacted_event!(SyncMessageLikeEvent(MessageLikeEventContent, Mess | ||||
|     } | ||||
| }); | ||||
| 
 | ||||
| impl_possibly_redacted_event!(StateEvent(StateEventContent, StateEventType) { | ||||
|     /// Returns this event's `room_id` field.
 | ||||
|     pub fn room_id(&self) -> &RoomId { | ||||
|         match self { | ||||
|             Self::Original(ev) => &ev.room_id, | ||||
|             Self::Redacted(ev) => &ev.room_id, | ||||
| impl_possibly_redacted_event!( | ||||
|     StateEvent(StateEventContent, StateEventType) | ||||
|     where | ||||
|         C::Redacted: StateEventContent<StateKey = C::StateKey>, | ||||
|     { | ||||
|         /// Returns this event's `room_id` field.
 | ||||
|         pub fn room_id(&self) -> &RoomId { | ||||
|             match self { | ||||
|                 Self::Original(ev) => &ev.room_id, | ||||
|                 Self::Redacted(ev) => &ev.room_id, | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Returns this event's `state_key` field.
 | ||||
|     pub fn state_key(&self) -> &str { | ||||
|         match self { | ||||
|             Self::Original(ev) => &ev.state_key, | ||||
|             Self::Redacted(ev) => &ev.state_key, | ||||
|         /// Returns this event's `state_key` field.
 | ||||
|         pub fn state_key(&self) -> &C::StateKey { | ||||
|             match self { | ||||
|                 Self::Original(ev) => &ev.state_key, | ||||
|                 Self::Redacted(ev) => &ev.state_key, | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Get the inner `OriginalStateEvent` if this is an unredacted event.
 | ||||
|     pub fn as_original(&self) -> Option<&OriginalStateEvent<C>> { | ||||
|         match self { | ||||
|             Self::Original(v) => Some(v), | ||||
|             _ => None, | ||||
|         /// Get the inner `OriginalStateEvent` if this is an unredacted event.
 | ||||
|         pub fn as_original(&self) -> Option<&OriginalStateEvent<C>> { | ||||
|             match self { | ||||
|                 Self::Original(v) => Some(v), | ||||
|                 _ => None, | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| }); | ||||
| ); | ||||
| 
 | ||||
| impl_possibly_redacted_event!(SyncStateEvent(StateEventContent, StateEventType) { | ||||
|     /// Returns this event's `state_key` field.
 | ||||
|     pub fn state_key(&self) -> &str { | ||||
|         match self { | ||||
|             Self::Original(ev) => &ev.state_key, | ||||
|             Self::Redacted(ev) => &ev.state_key, | ||||
| impl_possibly_redacted_event!( | ||||
|     SyncStateEvent(StateEventContent, StateEventType) | ||||
|     where | ||||
|         C::Redacted: StateEventContent<StateKey = C::StateKey>, | ||||
|     { | ||||
|         /// Returns this event's `state_key` field.
 | ||||
|         pub fn state_key(&self) -> &C::StateKey { | ||||
|             match self { | ||||
|                 Self::Original(ev) => &ev.state_key, | ||||
|                 Self::Redacted(ev) => &ev.state_key, | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Get the inner `OriginalSyncStateEvent` if this is an unredacted event.
 | ||||
|     pub fn as_original(&self) -> Option<&OriginalSyncStateEvent<C>> { | ||||
|         match self { | ||||
|             Self::Original(v) => Some(v), | ||||
|             _ => None, | ||||
|         /// Get the inner `OriginalSyncStateEvent` if this is an unredacted event.
 | ||||
|         pub fn as_original(&self) -> Option<&OriginalSyncStateEvent<C>> { | ||||
|             match self { | ||||
|                 Self::Original(v) => Some(v), | ||||
|                 _ => None, | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Convert this sync event into a full event (one with a `room_id` field).
 | ||||
|     pub fn into_full_event(self, room_id: OwnedRoomId) -> StateEvent<C> { | ||||
|         match self { | ||||
|             Self::Original(ev) => StateEvent::Original(ev.into_full_event(room_id)), | ||||
|             Self::Redacted(ev) => StateEvent::Redacted(ev.into_full_event(room_id)), | ||||
|         /// Convert this sync event into a full event (one with a `room_id` field).
 | ||||
|         pub fn into_full_event(self, room_id: OwnedRoomId) -> StateEvent<C> { | ||||
|             match self { | ||||
|                 Self::Original(ev) => StateEvent::Original(ev.into_full_event(room_id)), | ||||
|                 Self::Redacted(ev) => StateEvent::Redacted(ev.into_full_event(room_id)), | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| }); | ||||
| ); | ||||
| 
 | ||||
| macro_rules! impl_sync_from_full { | ||||
|     ($ty:ident, $full:ident, $content_trait:ident) => { | ||||
|  | ||||
| @ -6,7 +6,8 @@ use serde_json::value::RawValue as RawJsonValue; | ||||
| 
 | ||||
| use crate::{ | ||||
|     events::{ | ||||
|         EventContent, HasDeserializeFields, RedactContent, RedactedEventContent, StateEventType, | ||||
|         EventContent, HasDeserializeFields, RedactContent, RedactedEventContent, StateEventContent, | ||||
|         StateEventType, | ||||
|     }, | ||||
|     OwnedRoomAliasId, RoomVersionId, | ||||
| }; | ||||
| @ -94,6 +95,10 @@ impl EventContent for RedactedRoomAliasesEventContent { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl StateEventContent for RedactedRoomAliasesEventContent { | ||||
|     type StateKey = String; // Box<ServerName>
 | ||||
| } | ||||
| 
 | ||||
| // Since this redacted event has fields we leave the default `empty` method
 | ||||
| // that will error if called.
 | ||||
| impl RedactedEventContent for RedactedRoomAliasesEventContent { | ||||
|  | ||||
| @ -11,7 +11,7 @@ use serde_json::value::RawValue as RawJsonValue; | ||||
| use crate::{ | ||||
|     events::{ | ||||
|         EventContent, HasDeserializeFields, OriginalSyncStateEvent, RedactContent, | ||||
|         RedactedEventContent, StateEventType, StrippedStateEvent, | ||||
|         RedactedEventContent, StateEventContent, StateEventType, StrippedStateEvent, | ||||
|     }, | ||||
|     serde::StringEnum, | ||||
|     MxcUri, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, PrivOwnedStr, RoomVersionId, | ||||
| @ -175,6 +175,10 @@ impl EventContent for RedactedRoomMemberEventContent { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl StateEventContent for RedactedRoomMemberEventContent { | ||||
|     type StateKey = String; // Box<UserId>
 | ||||
| } | ||||
| 
 | ||||
| // Since this redacted event has fields we leave the default `empty` method
 | ||||
| // that will error if called.
 | ||||
| impl RedactedEventContent for RedactedRoomMemberEventContent { | ||||
|  | ||||
| @ -6,7 +6,7 @@ error: no event type attribute found, add `#[ruma_event(type = "any.room.event", | ||||
|   | | ||||
|   = note: this error originates in the derive macro `EventContent` (in Nightly builds, run with -Z macro-backtrace for more info) | ||||
| 
 | ||||
| error: expected one of: `type`, `kind`, `skip_redaction`, `custom_redacted`, `type_fragment` | ||||
| error: expected one of: `type`, `kind`, `skip_redaction`, `custom_redacted`, `type_fragment`, `state_key_type` | ||||
|   --> tests/events/ui/03-invalid-event-type.rs:11:14 | ||||
|    | | ||||
| 11 | #[ruma_event(event = "m.macro.test", kind = State)] | ||||
|  | ||||
| @ -17,7 +17,7 @@ pub struct OriginalStateEvent<C: StateEventContent> { | ||||
|     pub sender: OwnedUserId, | ||||
|     pub origin_server_ts: MilliSecondsSinceUnixEpoch, | ||||
|     pub room_id: OwnedRoomId, | ||||
|     pub state_key: String, | ||||
|     pub state_key: C::StateKey, | ||||
|     pub unsigned: StateUnsigned<C>, | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -2,10 +2,7 @@ | ||||
| 
 | ||||
| use proc_macro2::{Span, TokenStream}; | ||||
| use quote::quote; | ||||
| use syn::{ | ||||
|     parse_quote, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, GenericParam, Meta, | ||||
|     MetaList, NestedMeta, | ||||
| }; | ||||
| use syn::{parse_quote, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, GenericParam}; | ||||
| 
 | ||||
| use super::{ | ||||
|     event_parse::{to_kind_variation, EventKind, EventKindVariation}, | ||||
| @ -169,14 +166,16 @@ fn expand_deserialize_event( | ||||
|         .iter() | ||||
|         .map(|field| { | ||||
|             let name = field.ident.as_ref().unwrap(); | ||||
|             let ty = &field.ty; | ||||
|             if name == "content" || (name == "unsigned" && has_prev_content(kind, var)) { | ||||
|                 if is_generic { | ||||
|                     quote! { ::std::boxed::Box<#serde_json::value::RawValue> } | ||||
|                 } else { | ||||
|                     quote! { #content_type } | ||||
|                 } | ||||
|             } else if name == "state_key" && var == EventKindVariation::Initial { | ||||
|                 quote! { ::std::string::String } | ||||
|             } else { | ||||
|                 let ty = &field.ty; | ||||
|                 quote! { #ty } | ||||
|             } | ||||
|         }) | ||||
| @ -225,42 +224,33 @@ fn expand_deserialize_event( | ||||
|                         )?; | ||||
|                     } | ||||
|                 } | ||||
|             } else if name == "unsigned" && has_prev_content(kind, var) { | ||||
|                 quote! { | ||||
|                     let unsigned = unsigned.map(|json| { | ||||
|                         #ruma_common::events::StateUnsigned::_from_parts(&event_type, &json) | ||||
|                             .map_err(#serde::de::Error::custom) | ||||
|                     }).transpose()?.unwrap_or_default(); | ||||
|                 } | ||||
|             } else { | ||||
|                 let attrs: Vec<_> = field | ||||
|                     .attrs | ||||
|                     .iter() | ||||
|                     .filter(|a| a.path.is_ident("ruma_event")) | ||||
|                     .map(|a| a.parse_meta()) | ||||
|                     .collect::<syn::Result<_>>()?; | ||||
| 
 | ||||
|                 let has_default_attr = attrs.iter().any(|a| { | ||||
|                     matches!( | ||||
|                         a, | ||||
|                         Meta::List(MetaList { nested, .. }) | ||||
|                         if nested.iter().any(|n| { | ||||
|                             matches!(n, NestedMeta::Meta(Meta::Path(p)) if p.is_ident("default")) | ||||
|                         }) | ||||
|                     ) | ||||
|                 }); | ||||
| 
 | ||||
|                 if has_default_attr || name == "unsigned" { | ||||
|             } else if name == "unsigned" { | ||||
|                 if has_prev_content(kind, var) { | ||||
|                     quote! { | ||||
|                         let #name = #name.unwrap_or_default(); | ||||
|                         let unsigned = unsigned.map(|json| { | ||||
|                             #ruma_common::events::StateUnsigned::_from_parts(&event_type, &json) | ||||
|                                 .map_err(#serde::de::Error::custom) | ||||
|                         }).transpose()?.unwrap_or_default(); | ||||
|                     } | ||||
|                 } else { | ||||
|                     quote! { | ||||
|                         let #name = #name.ok_or_else(|| { | ||||
|                             #serde::de::Error::missing_field(stringify!(#name)) | ||||
|                         })?; | ||||
|                         let unsigned = unsigned.unwrap_or_default(); | ||||
|                     } | ||||
|                 } | ||||
|             } else if name == "state_key" && var == EventKindVariation::Initial { | ||||
|                 let ty = &field.ty; | ||||
|                 quote! { | ||||
|                     let state_key: ::std::string::String = state_key.unwrap_or_default(); | ||||
|                     let state_key: #ty = <#ty as #serde::de::Deserialize>::deserialize( | ||||
|                         #serde::de::IntoDeserializer::<A::Error>::into_deserializer(state_key), | ||||
|                     )?; | ||||
|                 } | ||||
|             } else { | ||||
|                 quote! { | ||||
|                     let #name = #name.ok_or_else(|| { | ||||
|                         #serde::de::Error::missing_field(stringify!(#name)) | ||||
|                     })?; | ||||
|                 } | ||||
|             }) | ||||
|         }) | ||||
|         .collect::<syn::Result<_>>()?; | ||||
| @ -385,12 +375,21 @@ fn expand_redact_event( | ||||
| 
 | ||||
|     let where_clause = generics.make_where_clause(); | ||||
|     where_clause.predicates.push(parse_quote! { #ty_param: #ruma_common::events::RedactContent }); | ||||
|     where_clause.predicates.push(parse_quote! { | ||||
|         <#ty_param as #ruma_common::events::RedactContent>::Redacted: | ||||
| 
 | ||||
|     let redacted_event_content_bound = if kind == EventKind::State { | ||||
|         quote! { | ||||
|             #ruma_common::events::StateEventContent<StateKey = #ty_param::StateKey> | ||||
|         } | ||||
|     } else { | ||||
|         quote! { | ||||
|             #ruma_common::events::EventContent< | ||||
|                 EventType = #ruma_common::events::#redacted_event_type_enum | ||||
|             > | ||||
|                 + #ruma_common::events::RedactedEventContent | ||||
|         } | ||||
|     }; | ||||
|     where_clause.predicates.push(parse_quote! { | ||||
|         <#ty_param as #ruma_common::events::RedactContent>::Redacted: | ||||
|             #redacted_event_content_bound + #ruma_common::events::RedactedEventContent | ||||
|     }); | ||||
| 
 | ||||
|     let (impl_generics, ty_gen, where_clause) = generics.split_for_impl(); | ||||
|  | ||||
| @ -6,7 +6,7 @@ use proc_macro2::{Span, TokenStream}; | ||||
| use quote::{format_ident, quote}; | ||||
| use syn::{ | ||||
|     parse::{Parse, ParseStream}, | ||||
|     DeriveInput, Field, Ident, LitStr, Token, | ||||
|     DeriveInput, Field, Ident, LitStr, Token, Type, | ||||
| }; | ||||
| 
 | ||||
| use crate::util::m_prefix_name_to_type_name; | ||||
| @ -21,6 +21,8 @@ mod kw { | ||||
|     // The kind of event content this is.
 | ||||
|     syn::custom_keyword!(kind); | ||||
|     syn::custom_keyword!(type_fragment); | ||||
|     // The type to use for a state events' `state_key` field.
 | ||||
|     syn::custom_keyword!(state_key_type); | ||||
| } | ||||
| 
 | ||||
| /// Parses attributes for `*EventContent` derives.
 | ||||
| @ -43,6 +45,8 @@ enum EventMeta { | ||||
|     /// The given field holds a part of the event type (replaces the `*` in a `m.foo.*` event
 | ||||
|     /// type).
 | ||||
|     TypeFragment, | ||||
| 
 | ||||
|     StateKeyType(Box<Type>), | ||||
| } | ||||
| 
 | ||||
| impl EventMeta { | ||||
| @ -59,6 +63,13 @@ impl EventMeta { | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn get_state_key_type(&self) -> Option<&Type> { | ||||
|         match self { | ||||
|             Self::StateKeyType(ty) => Some(ty), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Parse for EventMeta { | ||||
| @ -71,7 +82,7 @@ impl Parse for EventMeta { | ||||
|         } else if lookahead.peek(kw::kind) { | ||||
|             let _: kw::kind = input.parse()?; | ||||
|             let _: Token![=] = input.parse()?; | ||||
|             EventKind::parse(input).map(EventMeta::Kind) | ||||
|             input.parse().map(EventMeta::Kind) | ||||
|         } else if lookahead.peek(kw::skip_redaction) { | ||||
|             let _: kw::skip_redaction = input.parse()?; | ||||
|             Ok(EventMeta::SkipRedaction) | ||||
| @ -81,6 +92,10 @@ impl Parse for EventMeta { | ||||
|         } else if lookahead.peek(kw::type_fragment) { | ||||
|             let _: kw::type_fragment = input.parse()?; | ||||
|             Ok(EventMeta::TypeFragment) | ||||
|         } else if lookahead.peek(kw::state_key_type) { | ||||
|             let _: kw::state_key_type = input.parse()?; | ||||
|             let _: Token![=] = input.parse()?; | ||||
|             input.parse().map(EventMeta::StateKeyType) | ||||
|         } else { | ||||
|             Err(lookahead.error()) | ||||
|         } | ||||
| @ -101,6 +116,10 @@ impl MetaAttrs { | ||||
|     fn get_event_kind(&self) -> Option<EventKind> { | ||||
|         self.0.iter().find_map(|a| a.get_event_kind()) | ||||
|     } | ||||
| 
 | ||||
|     fn get_state_key_type(&self) -> Option<&Type> { | ||||
|         self.0.iter().find_map(|a| a.get_state_key_type()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Parse for MetaAttrs { | ||||
| @ -155,6 +174,26 @@ pub fn expand_event_content( | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
|     let state_key_types: Vec<_> = | ||||
|         content_attr.iter().filter_map(|attrs| attrs.get_state_key_type()).collect(); | ||||
|     let state_key_type = match (event_kind, state_key_types.as_slice()) { | ||||
|         (Some(EventKind::State), []) => Some(quote! { ::std::string::String }), | ||||
|         (Some(EventKind::State), [ty]) => Some(quote! { #ty }), | ||||
|         (Some(EventKind::State), _) => { | ||||
|             return Err(syn::Error::new( | ||||
|                 Span::call_site(), | ||||
|                 "multiple state_key_type attribute found, there can only be one", | ||||
|             )); | ||||
|         } | ||||
|         (_, []) => None, | ||||
|         (_, [ty, ..]) => { | ||||
|             return Err(syn::Error::new_spanned( | ||||
|                 ty, | ||||
|                 "state_key_type attribute is not valid for non-state event kinds", | ||||
|             )); | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
|     let ident = &input.ident; | ||||
|     let fields = match &input.data { | ||||
|         syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.iter(), | ||||
| @ -185,13 +224,26 @@ pub fn expand_event_content( | ||||
| 
 | ||||
|     // We only generate redacted content structs for state and message-like events
 | ||||
|     let redacted_event_content = needs_redacted(&content_attr, event_kind).then(|| { | ||||
|         generate_redacted_event_content(ident, fields.clone(), event_type, event_kind, ruma_common) | ||||
|             .unwrap_or_else(syn::Error::into_compile_error) | ||||
|         generate_redacted_event_content( | ||||
|             ident, | ||||
|             fields.clone(), | ||||
|             event_type, | ||||
|             event_kind, | ||||
|             state_key_type.as_ref(), | ||||
|             ruma_common, | ||||
|         ) | ||||
|         .unwrap_or_else(syn::Error::into_compile_error) | ||||
|     }); | ||||
| 
 | ||||
|     let event_content_impl = | ||||
|         generate_event_content_impl(ident, fields, event_type, event_kind, ruma_common) | ||||
|             .unwrap_or_else(syn::Error::into_compile_error); | ||||
|     let event_content_impl = generate_event_content_impl( | ||||
|         ident, | ||||
|         fields, | ||||
|         event_type, | ||||
|         event_kind, | ||||
|         state_key_type.as_ref(), | ||||
|         ruma_common, | ||||
|     ) | ||||
|     .unwrap_or_else(syn::Error::into_compile_error); | ||||
|     let static_event_content_impl = event_kind | ||||
|         .map(|k| generate_static_event_content_impl(ident, k, false, event_type, ruma_common)); | ||||
|     let type_aliases = event_kind.map(|k| { | ||||
| @ -212,6 +264,7 @@ fn generate_redacted_event_content<'a>( | ||||
|     fields: impl Iterator<Item = &'a Field>, | ||||
|     event_type: &LitStr, | ||||
|     event_kind: Option<EventKind>, | ||||
|     state_key_type: Option<&TokenStream>, | ||||
|     ruma_common: &TokenStream, | ||||
| ) -> syn::Result<TokenStream> { | ||||
|     assert!( | ||||
| @ -295,6 +348,7 @@ fn generate_redacted_event_content<'a>( | ||||
|         kept_redacted_fields.iter(), | ||||
|         event_type, | ||||
|         event_kind, | ||||
|         state_key_type, | ||||
|         ruma_common, | ||||
|     ) | ||||
|     .unwrap_or_else(syn::Error::into_compile_error); | ||||
| @ -416,6 +470,7 @@ fn generate_event_content_impl<'a>( | ||||
|     mut fields: impl Iterator<Item = &'a Field>, | ||||
|     event_type: &LitStr, | ||||
|     event_kind: Option<EventKind>, | ||||
|     state_key_type: Option<&TokenStream>, | ||||
|     ruma_common: &TokenStream, | ||||
| ) -> syn::Result<TokenStream> { | ||||
|     let serde = quote! { #ruma_common::exports::serde }; | ||||
| @ -489,6 +544,16 @@ fn generate_event_content_impl<'a>( | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     let state_event_content_impl = (event_kind == Some(EventKind::State)).then(|| { | ||||
|         assert!(state_key_type.is_some()); | ||||
|         quote! { | ||||
|             #[automatically_derived] | ||||
|             impl #ruma_common::events::StateEventContent for #ident { | ||||
|                 type StateKey = #state_key_type; | ||||
|             } | ||||
|         } | ||||
|     }); | ||||
| 
 | ||||
|     Ok(quote! { | ||||
|         #event_type_ty_decl | ||||
| 
 | ||||
| @ -513,6 +578,8 @@ fn generate_event_content_impl<'a>( | ||||
|                 #serde_json::from_str(content.get()) | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         #state_event_content_impl | ||||
|     }) | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -37,7 +37,6 @@ const EVENT_FIELDS: &[(&str, EventKindFn)] = &[ | ||||
|         matches!(kind, EventKind::MessageLike | EventKind::State | EventKind::ToDevice) | ||||
|             && var != EventEnumVariation::Initial | ||||
|     }), | ||||
|     ("state_key", |kind, _| matches!(kind, EventKind::State)), | ||||
| ]; | ||||
| 
 | ||||
| /// Create a content enum from `EventEnumInput`.
 | ||||
| @ -484,6 +483,21 @@ fn expand_accessor_methods( | ||||
|         }) | ||||
|     }); | ||||
| 
 | ||||
|     let state_key_accessor = (kind == EventKind::State).then(|| { | ||||
|         let variants = variants.iter().map(|v| v.match_arm(quote! { Self })); | ||||
|         let call_parens = maybe_redacted.then(|| quote! { () }); | ||||
| 
 | ||||
|         quote! { | ||||
|             /// Returns this event's `state_key` field.
 | ||||
|             pub fn state_key(&self) -> &::std::primitive::str { | ||||
|                 match self { | ||||
|                     #( #variants(event) => &event.state_key #call_parens .as_ref(), )* | ||||
|                     Self::_Custom(event) => &event.state_key #call_parens .as_ref(), | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     }); | ||||
| 
 | ||||
|     let txn_id_accessor = maybe_redacted.then(|| { | ||||
|         let variants = variants.iter().map(|v| v.match_arm(quote! { Self })); | ||||
|         quote! { | ||||
| @ -513,6 +527,7 @@ fn expand_accessor_methods( | ||||
| 
 | ||||
|             #content_accessors | ||||
|             #( #methods )* | ||||
|             #state_key_accessor | ||||
|             #txn_id_accessor | ||||
|         } | ||||
|     }) | ||||
| @ -572,7 +587,6 @@ fn field_return_type(name: &str, ruma_common: &TokenStream) -> TokenStream { | ||||
|         "room_id" => quote! { #ruma_common::RoomId }, | ||||
|         "event_id" => quote! { #ruma_common::EventId }, | ||||
|         "sender" => quote! { #ruma_common::UserId }, | ||||
|         "state_key" => quote! { ::std::primitive::str }, | ||||
|         _ => panic!("the `ruma_macros::event_enum::EVENT_FIELD` const was changed"), | ||||
|     } | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user