events: Allow state key type to be customized by content type

This commit is contained in:
Jonas Platte 2022-04-12 22:07:46 +02:00 committed by Jonas Platte
parent da222a87c9
commit d8b7886382
10 changed files with 218 additions and 108 deletions

View File

@ -4,7 +4,7 @@ use serde_json::value::RawValue as RawJsonValue;
use super::{ use super::{
EphemeralRoomEventType, EventContent, GlobalAccountDataEventType, HasDeserializeFields, EphemeralRoomEventType, EventContent, GlobalAccountDataEventType, HasDeserializeFields,
MessageLikeEventType, RedactContent, RedactedEventContent, RoomAccountDataEventType, MessageLikeEventType, RedactContent, RedactedEventContent, RoomAccountDataEventType,
StateEventType, ToDeviceEventType, StateEventContent, StateEventType, ToDeviceEventType,
}; };
use crate::RoomVersionId; use crate::RoomVersionId;
@ -68,3 +68,7 @@ custom_event_content!(CustomEphemeralRoomEventContent, EphemeralRoomEventType);
custom_room_event_content!(CustomMessageLikeEventContent, MessageLikeEventType); custom_room_event_content!(CustomMessageLikeEventContent, MessageLikeEventType);
custom_room_event_content!(CustomStateEventContent, StateEventType); custom_room_event_content!(CustomStateEventContent, StateEventType);
custom_event_content!(CustomToDeviceEventContent, ToDeviceEventType); custom_event_content!(CustomToDeviceEventContent, ToDeviceEventType);
impl StateEventContent for CustomStateEventContent {
type StateKey = String;
}

View File

@ -1,6 +1,6 @@
use std::fmt; use std::fmt;
use serde::Serialize; use serde::{de::DeserializeOwned, Serialize};
use serde_json::value::RawValue as RawJsonValue; use serde_json::value::RawValue as RawJsonValue;
use crate::serde::Raw; use crate::serde::Raw;
@ -164,17 +164,18 @@ trait_aliases! {
/// An alias for `EventContent<EventType = MessageLikeEventType>`. /// An alias for `EventContent<EventType = MessageLikeEventType>`.
trait MessageLikeEventContent = EventContent<EventType = MessageLikeEventType>; trait MessageLikeEventContent = EventContent<EventType = MessageLikeEventType>;
/// An alias for `EventContent<EventType = MessageLikeEventType> + RedactedEventContent`. /// An alias for `MessageLikeEventContent + RedactedEventContent`.
trait RedactedMessageLikeEventContent = trait RedactedMessageLikeEventContent = MessageLikeEventContent, RedactedEventContent;
EventContent<EventType = MessageLikeEventType>, RedactedEventContent;
/// An alias for `EventContent<EventType = StateEventType>`. /// An alias for `StateEventContent + RedactedEventContent`.
trait StateEventContent = EventContent<EventType = StateEventType>; trait RedactedStateEventContent = StateEventContent, RedactedEventContent;
/// An alias for `EventContent<EventType = StateEventType> + RedactedEventContent`.
trait RedactedStateEventContent =
EventContent<EventType = StateEventType>, RedactedEventContent;
/// An alias for `EventContent<EventType = ToDeviceEventType>`. /// An alias for `EventContent<EventType = ToDeviceEventType>`.
trait ToDeviceEventContent = EventContent<EventType = ToDeviceEventType>; trait ToDeviceEventContent = EventContent<EventType = ToDeviceEventType>;
} }
/// An alias for `EventContent<EventType = StateEventType>`.
pub trait StateEventContent: EventContent<EventType = StateEventType> {
/// The type of the event's `state_key` field.
type StateKey: AsRef<str> + Clone + fmt::Debug + DeserializeOwned + Serialize;
}

View File

@ -203,7 +203,7 @@ pub struct OriginalStateEvent<C: StateEventContent> {
/// ///
/// This is often an empty string, but some events send a `UserId` to show which user the event /// This is often an empty string, but some events send a `UserId` to show which user the event
/// affects. /// affects.
pub state_key: String, pub state_key: C::StateKey,
/// Additional key-value pairs not signed by the homeserver. /// Additional key-value pairs not signed by the homeserver.
pub unsigned: StateUnsigned<C>, pub unsigned: StateUnsigned<C>,
@ -231,7 +231,7 @@ pub struct OriginalSyncStateEvent<C: StateEventContent> {
/// ///
/// This is often an empty string, but some events send a `UserId` to show which user the event /// This is often an empty string, but some events send a `UserId` to show which user the event
/// affects. /// affects.
pub state_key: String, pub state_key: C::StateKey,
/// Additional key-value pairs not signed by the homeserver. /// Additional key-value pairs not signed by the homeserver.
pub unsigned: StateUnsigned<C>, pub unsigned: StateUnsigned<C>,
@ -250,7 +250,7 @@ pub struct StrippedStateEvent<C: StateEventContent> {
/// ///
/// This is often an empty string, but some events send a `UserId` to show which user the event /// This is often an empty string, but some events send a `UserId` to show which user the event
/// affects. /// affects.
pub state_key: String, pub state_key: C::StateKey,
} }
/// A minimal state event, used for creating a new room. /// A minimal state event, used for creating a new room.
@ -265,8 +265,7 @@ pub struct InitialStateEvent<C: StateEventContent> {
/// affects. /// affects.
/// ///
/// Defaults to the empty string. /// Defaults to the empty string.
#[ruma_event(default)] pub state_key: C::StateKey,
pub state_key: String,
} }
/// A redacted state event. /// A redacted state event.
@ -294,7 +293,7 @@ pub struct RedactedStateEvent<C: RedactedStateEventContent> {
/// ///
/// This is often an empty string, but some events send a `UserId` to show which user the event /// This is often an empty string, but some events send a `UserId` to show which user the event
/// affects. /// affects.
pub state_key: String, pub state_key: C::StateKey,
/// Additional key-value pairs not signed by the homeserver. /// Additional key-value pairs not signed by the homeserver.
pub unsigned: RedactedUnsigned, pub unsigned: RedactedUnsigned,
@ -322,7 +321,7 @@ pub struct RedactedSyncStateEvent<C: RedactedStateEventContent> {
/// ///
/// This is often an empty string, but some events send a `UserId` to show which user the event /// This is often an empty string, but some events send a `UserId` to show which user the event
/// affects. /// affects.
pub state_key: String, pub state_key: C::StateKey,
/// Additional key-value pairs not signed by the homeserver. /// Additional key-value pairs not signed by the homeserver.
pub unsigned: RedactedUnsigned, pub unsigned: RedactedUnsigned,
@ -411,11 +410,16 @@ pub struct DecryptedMegolmV1Event<C: MessageLikeEventContent> {
} }
macro_rules! impl_possibly_redacted_event { macro_rules! impl_possibly_redacted_event {
($ty:ident ( $content_trait:ident, $event_type:ident ) { $($extra:tt)* }) => { (
$ty:ident ( $content_trait:ident, $event_type:ident )
$( where C::Redacted: $trait:ident<StateKey = C::StateKey>, )?
{ $($extra:tt)* }
) => {
impl<C> $ty<C> impl<C> $ty<C>
where where
C: $content_trait + RedactContent, C: $content_trait + RedactContent,
C::Redacted: $content_trait + RedactedEventContent, C::Redacted: $content_trait + RedactedEventContent,
$( C::Redacted: $trait<StateKey = C::StateKey>, )?
{ {
/// Returns the `type` of this event. /// Returns the `type` of this event.
pub fn event_type(&self) -> $event_type { pub fn event_type(&self) -> $event_type {
@ -457,6 +461,7 @@ macro_rules! impl_possibly_redacted_event {
where where
C: $content_trait + RedactContent, C: $content_trait + RedactContent,
C::Redacted: $content_trait + RedactedEventContent, C::Redacted: $content_trait + RedactedEventContent,
$( C::Redacted: $trait<StateKey = C::StateKey>, )?
{ {
type Redacted = Self; type Redacted = Self;
@ -472,6 +477,7 @@ macro_rules! impl_possibly_redacted_event {
where where
C: $content_trait + RedactContent, C: $content_trait + RedactContent,
C::Redacted: $content_trait + RedactedEventContent, C::Redacted: $content_trait + RedactedEventContent,
$( C::Redacted: $trait<StateKey = C::StateKey>, )?
{ {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where where
@ -526,7 +532,11 @@ impl_possibly_redacted_event!(SyncMessageLikeEvent(MessageLikeEventContent, Mess
} }
}); });
impl_possibly_redacted_event!(StateEvent(StateEventContent, StateEventType) { impl_possibly_redacted_event!(
StateEvent(StateEventContent, StateEventType)
where
C::Redacted: StateEventContent<StateKey = C::StateKey>,
{
/// Returns this event's `room_id` field. /// Returns this event's `room_id` field.
pub fn room_id(&self) -> &RoomId { pub fn room_id(&self) -> &RoomId {
match self { match self {
@ -536,7 +546,7 @@ impl_possibly_redacted_event!(StateEvent(StateEventContent, StateEventType) {
} }
/// Returns this event's `state_key` field. /// Returns this event's `state_key` field.
pub fn state_key(&self) -> &str { pub fn state_key(&self) -> &C::StateKey {
match self { match self {
Self::Original(ev) => &ev.state_key, Self::Original(ev) => &ev.state_key,
Self::Redacted(ev) => &ev.state_key, Self::Redacted(ev) => &ev.state_key,
@ -550,11 +560,16 @@ impl_possibly_redacted_event!(StateEvent(StateEventContent, StateEventType) {
_ => None, _ => None,
} }
} }
}); }
);
impl_possibly_redacted_event!(SyncStateEvent(StateEventContent, StateEventType) { impl_possibly_redacted_event!(
SyncStateEvent(StateEventContent, StateEventType)
where
C::Redacted: StateEventContent<StateKey = C::StateKey>,
{
/// Returns this event's `state_key` field. /// Returns this event's `state_key` field.
pub fn state_key(&self) -> &str { pub fn state_key(&self) -> &C::StateKey {
match self { match self {
Self::Original(ev) => &ev.state_key, Self::Original(ev) => &ev.state_key,
Self::Redacted(ev) => &ev.state_key, Self::Redacted(ev) => &ev.state_key,
@ -576,7 +591,8 @@ impl_possibly_redacted_event!(SyncStateEvent(StateEventContent, StateEventType)
Self::Redacted(ev) => StateEvent::Redacted(ev.into_full_event(room_id)), Self::Redacted(ev) => StateEvent::Redacted(ev.into_full_event(room_id)),
} }
} }
}); }
);
macro_rules! impl_sync_from_full { macro_rules! impl_sync_from_full {
($ty:ident, $full:ident, $content_trait:ident) => { ($ty:ident, $full:ident, $content_trait:ident) => {

View File

@ -6,7 +6,8 @@ use serde_json::value::RawValue as RawJsonValue;
use crate::{ use crate::{
events::{ events::{
EventContent, HasDeserializeFields, RedactContent, RedactedEventContent, StateEventType, EventContent, HasDeserializeFields, RedactContent, RedactedEventContent, StateEventContent,
StateEventType,
}, },
OwnedRoomAliasId, RoomVersionId, OwnedRoomAliasId, RoomVersionId,
}; };
@ -94,6 +95,10 @@ impl EventContent for RedactedRoomAliasesEventContent {
} }
} }
impl StateEventContent for RedactedRoomAliasesEventContent {
type StateKey = String; // Box<ServerName>
}
// Since this redacted event has fields we leave the default `empty` method // Since this redacted event has fields we leave the default `empty` method
// that will error if called. // that will error if called.
impl RedactedEventContent for RedactedRoomAliasesEventContent { impl RedactedEventContent for RedactedRoomAliasesEventContent {

View File

@ -11,7 +11,7 @@ use serde_json::value::RawValue as RawJsonValue;
use crate::{ use crate::{
events::{ events::{
EventContent, HasDeserializeFields, OriginalSyncStateEvent, RedactContent, EventContent, HasDeserializeFields, OriginalSyncStateEvent, RedactContent,
RedactedEventContent, StateEventType, StrippedStateEvent, RedactedEventContent, StateEventContent, StateEventType, StrippedStateEvent,
}, },
serde::StringEnum, serde::StringEnum,
MxcUri, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, PrivOwnedStr, RoomVersionId, MxcUri, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, PrivOwnedStr, RoomVersionId,
@ -175,6 +175,10 @@ impl EventContent for RedactedRoomMemberEventContent {
} }
} }
impl StateEventContent for RedactedRoomMemberEventContent {
type StateKey = String; // Box<UserId>
}
// Since this redacted event has fields we leave the default `empty` method // Since this redacted event has fields we leave the default `empty` method
// that will error if called. // that will error if called.
impl RedactedEventContent for RedactedRoomMemberEventContent { impl RedactedEventContent for RedactedRoomMemberEventContent {

View File

@ -6,7 +6,7 @@ error: no event type attribute found, add `#[ruma_event(type = "any.room.event",
| |
= note: this error originates in the derive macro `EventContent` (in Nightly builds, run with -Z macro-backtrace for more info) = note: this error originates in the derive macro `EventContent` (in Nightly builds, run with -Z macro-backtrace for more info)
error: expected one of: `type`, `kind`, `skip_redaction`, `custom_redacted`, `type_fragment` error: expected one of: `type`, `kind`, `skip_redaction`, `custom_redacted`, `type_fragment`, `state_key_type`
--> tests/events/ui/03-invalid-event-type.rs:11:14 --> tests/events/ui/03-invalid-event-type.rs:11:14
| |
11 | #[ruma_event(event = "m.macro.test", kind = State)] 11 | #[ruma_event(event = "m.macro.test", kind = State)]

View File

@ -17,7 +17,7 @@ pub struct OriginalStateEvent<C: StateEventContent> {
pub sender: OwnedUserId, pub sender: OwnedUserId,
pub origin_server_ts: MilliSecondsSinceUnixEpoch, pub origin_server_ts: MilliSecondsSinceUnixEpoch,
pub room_id: OwnedRoomId, pub room_id: OwnedRoomId,
pub state_key: String, pub state_key: C::StateKey,
pub unsigned: StateUnsigned<C>, pub unsigned: StateUnsigned<C>,
} }

View File

@ -2,10 +2,7 @@
use proc_macro2::{Span, TokenStream}; use proc_macro2::{Span, TokenStream};
use quote::quote; use quote::quote;
use syn::{ use syn::{parse_quote, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, GenericParam};
parse_quote, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, GenericParam, Meta,
MetaList, NestedMeta,
};
use super::{ use super::{
event_parse::{to_kind_variation, EventKind, EventKindVariation}, event_parse::{to_kind_variation, EventKind, EventKindVariation},
@ -169,14 +166,16 @@ fn expand_deserialize_event(
.iter() .iter()
.map(|field| { .map(|field| {
let name = field.ident.as_ref().unwrap(); let name = field.ident.as_ref().unwrap();
let ty = &field.ty;
if name == "content" || (name == "unsigned" && has_prev_content(kind, var)) { if name == "content" || (name == "unsigned" && has_prev_content(kind, var)) {
if is_generic { if is_generic {
quote! { ::std::boxed::Box<#serde_json::value::RawValue> } quote! { ::std::boxed::Box<#serde_json::value::RawValue> }
} else { } else {
quote! { #content_type } quote! { #content_type }
} }
} else if name == "state_key" && var == EventKindVariation::Initial {
quote! { ::std::string::String }
} else { } else {
let ty = &field.ty;
quote! { #ty } quote! { #ty }
} }
}) })
@ -225,7 +224,8 @@ fn expand_deserialize_event(
)?; )?;
} }
} }
} else if name == "unsigned" && has_prev_content(kind, var) { } else if name == "unsigned" {
if has_prev_content(kind, var) {
quote! { quote! {
let unsigned = unsigned.map(|json| { let unsigned = unsigned.map(|json| {
#ruma_common::events::StateUnsigned::_from_parts(&event_type, &json) #ruma_common::events::StateUnsigned::_from_parts(&event_type, &json)
@ -233,26 +233,17 @@ fn expand_deserialize_event(
}).transpose()?.unwrap_or_default(); }).transpose()?.unwrap_or_default();
} }
} else { } else {
let attrs: Vec<_> = field
.attrs
.iter()
.filter(|a| a.path.is_ident("ruma_event"))
.map(|a| a.parse_meta())
.collect::<syn::Result<_>>()?;
let has_default_attr = attrs.iter().any(|a| {
matches!(
a,
Meta::List(MetaList { nested, .. })
if nested.iter().any(|n| {
matches!(n, NestedMeta::Meta(Meta::Path(p)) if p.is_ident("default"))
})
)
});
if has_default_attr || name == "unsigned" {
quote! { quote! {
let #name = #name.unwrap_or_default(); let unsigned = unsigned.unwrap_or_default();
}
}
} else if name == "state_key" && var == EventKindVariation::Initial {
let ty = &field.ty;
quote! {
let state_key: ::std::string::String = state_key.unwrap_or_default();
let state_key: #ty = <#ty as #serde::de::Deserialize>::deserialize(
#serde::de::IntoDeserializer::<A::Error>::into_deserializer(state_key),
)?;
} }
} else { } else {
quote! { quote! {
@ -260,7 +251,6 @@ fn expand_deserialize_event(
#serde::de::Error::missing_field(stringify!(#name)) #serde::de::Error::missing_field(stringify!(#name))
})?; })?;
} }
}
}) })
}) })
.collect::<syn::Result<_>>()?; .collect::<syn::Result<_>>()?;
@ -385,12 +375,21 @@ fn expand_redact_event(
let where_clause = generics.make_where_clause(); let where_clause = generics.make_where_clause();
where_clause.predicates.push(parse_quote! { #ty_param: #ruma_common::events::RedactContent }); where_clause.predicates.push(parse_quote! { #ty_param: #ruma_common::events::RedactContent });
where_clause.predicates.push(parse_quote! {
<#ty_param as #ruma_common::events::RedactContent>::Redacted: let redacted_event_content_bound = if kind == EventKind::State {
quote! {
#ruma_common::events::StateEventContent<StateKey = #ty_param::StateKey>
}
} else {
quote! {
#ruma_common::events::EventContent< #ruma_common::events::EventContent<
EventType = #ruma_common::events::#redacted_event_type_enum EventType = #ruma_common::events::#redacted_event_type_enum
> >
+ #ruma_common::events::RedactedEventContent }
};
where_clause.predicates.push(parse_quote! {
<#ty_param as #ruma_common::events::RedactContent>::Redacted:
#redacted_event_content_bound + #ruma_common::events::RedactedEventContent
}); });
let (impl_generics, ty_gen, where_clause) = generics.split_for_impl(); let (impl_generics, ty_gen, where_clause) = generics.split_for_impl();

View File

@ -6,7 +6,7 @@ use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::{ use syn::{
parse::{Parse, ParseStream}, parse::{Parse, ParseStream},
DeriveInput, Field, Ident, LitStr, Token, DeriveInput, Field, Ident, LitStr, Token, Type,
}; };
use crate::util::m_prefix_name_to_type_name; use crate::util::m_prefix_name_to_type_name;
@ -21,6 +21,8 @@ mod kw {
// The kind of event content this is. // The kind of event content this is.
syn::custom_keyword!(kind); syn::custom_keyword!(kind);
syn::custom_keyword!(type_fragment); syn::custom_keyword!(type_fragment);
// The type to use for a state events' `state_key` field.
syn::custom_keyword!(state_key_type);
} }
/// Parses attributes for `*EventContent` derives. /// Parses attributes for `*EventContent` derives.
@ -43,6 +45,8 @@ enum EventMeta {
/// The given field holds a part of the event type (replaces the `*` in a `m.foo.*` event /// The given field holds a part of the event type (replaces the `*` in a `m.foo.*` event
/// type). /// type).
TypeFragment, TypeFragment,
StateKeyType(Box<Type>),
} }
impl EventMeta { impl EventMeta {
@ -59,6 +63,13 @@ impl EventMeta {
_ => None, _ => None,
} }
} }
fn get_state_key_type(&self) -> Option<&Type> {
match self {
Self::StateKeyType(ty) => Some(ty),
_ => None,
}
}
} }
impl Parse for EventMeta { impl Parse for EventMeta {
@ -71,7 +82,7 @@ impl Parse for EventMeta {
} else if lookahead.peek(kw::kind) { } else if lookahead.peek(kw::kind) {
let _: kw::kind = input.parse()?; let _: kw::kind = input.parse()?;
let _: Token![=] = input.parse()?; let _: Token![=] = input.parse()?;
EventKind::parse(input).map(EventMeta::Kind) input.parse().map(EventMeta::Kind)
} else if lookahead.peek(kw::skip_redaction) { } else if lookahead.peek(kw::skip_redaction) {
let _: kw::skip_redaction = input.parse()?; let _: kw::skip_redaction = input.parse()?;
Ok(EventMeta::SkipRedaction) Ok(EventMeta::SkipRedaction)
@ -81,6 +92,10 @@ impl Parse for EventMeta {
} else if lookahead.peek(kw::type_fragment) { } else if lookahead.peek(kw::type_fragment) {
let _: kw::type_fragment = input.parse()?; let _: kw::type_fragment = input.parse()?;
Ok(EventMeta::TypeFragment) Ok(EventMeta::TypeFragment)
} else if lookahead.peek(kw::state_key_type) {
let _: kw::state_key_type = input.parse()?;
let _: Token![=] = input.parse()?;
input.parse().map(EventMeta::StateKeyType)
} else { } else {
Err(lookahead.error()) Err(lookahead.error())
} }
@ -101,6 +116,10 @@ impl MetaAttrs {
fn get_event_kind(&self) -> Option<EventKind> { fn get_event_kind(&self) -> Option<EventKind> {
self.0.iter().find_map(|a| a.get_event_kind()) self.0.iter().find_map(|a| a.get_event_kind())
} }
fn get_state_key_type(&self) -> Option<&Type> {
self.0.iter().find_map(|a| a.get_state_key_type())
}
} }
impl Parse for MetaAttrs { impl Parse for MetaAttrs {
@ -155,6 +174,26 @@ pub fn expand_event_content(
} }
}; };
let state_key_types: Vec<_> =
content_attr.iter().filter_map(|attrs| attrs.get_state_key_type()).collect();
let state_key_type = match (event_kind, state_key_types.as_slice()) {
(Some(EventKind::State), []) => Some(quote! { ::std::string::String }),
(Some(EventKind::State), [ty]) => Some(quote! { #ty }),
(Some(EventKind::State), _) => {
return Err(syn::Error::new(
Span::call_site(),
"multiple state_key_type attribute found, there can only be one",
));
}
(_, []) => None,
(_, [ty, ..]) => {
return Err(syn::Error::new_spanned(
ty,
"state_key_type attribute is not valid for non-state event kinds",
));
}
};
let ident = &input.ident; let ident = &input.ident;
let fields = match &input.data { let fields = match &input.data {
syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.iter(), syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.iter(),
@ -185,12 +224,25 @@ pub fn expand_event_content(
// We only generate redacted content structs for state and message-like events // We only generate redacted content structs for state and message-like events
let redacted_event_content = needs_redacted(&content_attr, event_kind).then(|| { let redacted_event_content = needs_redacted(&content_attr, event_kind).then(|| {
generate_redacted_event_content(ident, fields.clone(), event_type, event_kind, ruma_common) generate_redacted_event_content(
ident,
fields.clone(),
event_type,
event_kind,
state_key_type.as_ref(),
ruma_common,
)
.unwrap_or_else(syn::Error::into_compile_error) .unwrap_or_else(syn::Error::into_compile_error)
}); });
let event_content_impl = let event_content_impl = generate_event_content_impl(
generate_event_content_impl(ident, fields, event_type, event_kind, ruma_common) ident,
fields,
event_type,
event_kind,
state_key_type.as_ref(),
ruma_common,
)
.unwrap_or_else(syn::Error::into_compile_error); .unwrap_or_else(syn::Error::into_compile_error);
let static_event_content_impl = event_kind let static_event_content_impl = event_kind
.map(|k| generate_static_event_content_impl(ident, k, false, event_type, ruma_common)); .map(|k| generate_static_event_content_impl(ident, k, false, event_type, ruma_common));
@ -212,6 +264,7 @@ fn generate_redacted_event_content<'a>(
fields: impl Iterator<Item = &'a Field>, fields: impl Iterator<Item = &'a Field>,
event_type: &LitStr, event_type: &LitStr,
event_kind: Option<EventKind>, event_kind: Option<EventKind>,
state_key_type: Option<&TokenStream>,
ruma_common: &TokenStream, ruma_common: &TokenStream,
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
assert!( assert!(
@ -295,6 +348,7 @@ fn generate_redacted_event_content<'a>(
kept_redacted_fields.iter(), kept_redacted_fields.iter(),
event_type, event_type,
event_kind, event_kind,
state_key_type,
ruma_common, ruma_common,
) )
.unwrap_or_else(syn::Error::into_compile_error); .unwrap_or_else(syn::Error::into_compile_error);
@ -416,6 +470,7 @@ fn generate_event_content_impl<'a>(
mut fields: impl Iterator<Item = &'a Field>, mut fields: impl Iterator<Item = &'a Field>,
event_type: &LitStr, event_type: &LitStr,
event_kind: Option<EventKind>, event_kind: Option<EventKind>,
state_key_type: Option<&TokenStream>,
ruma_common: &TokenStream, ruma_common: &TokenStream,
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
let serde = quote! { #ruma_common::exports::serde }; let serde = quote! { #ruma_common::exports::serde };
@ -489,6 +544,16 @@ fn generate_event_content_impl<'a>(
} }
} }
let state_event_content_impl = (event_kind == Some(EventKind::State)).then(|| {
assert!(state_key_type.is_some());
quote! {
#[automatically_derived]
impl #ruma_common::events::StateEventContent for #ident {
type StateKey = #state_key_type;
}
}
});
Ok(quote! { Ok(quote! {
#event_type_ty_decl #event_type_ty_decl
@ -513,6 +578,8 @@ fn generate_event_content_impl<'a>(
#serde_json::from_str(content.get()) #serde_json::from_str(content.get())
} }
} }
#state_event_content_impl
}) })
} }

View File

@ -37,7 +37,6 @@ const EVENT_FIELDS: &[(&str, EventKindFn)] = &[
matches!(kind, EventKind::MessageLike | EventKind::State | EventKind::ToDevice) matches!(kind, EventKind::MessageLike | EventKind::State | EventKind::ToDevice)
&& var != EventEnumVariation::Initial && var != EventEnumVariation::Initial
}), }),
("state_key", |kind, _| matches!(kind, EventKind::State)),
]; ];
/// Create a content enum from `EventEnumInput`. /// Create a content enum from `EventEnumInput`.
@ -484,6 +483,21 @@ fn expand_accessor_methods(
}) })
}); });
let state_key_accessor = (kind == EventKind::State).then(|| {
let variants = variants.iter().map(|v| v.match_arm(quote! { Self }));
let call_parens = maybe_redacted.then(|| quote! { () });
quote! {
/// Returns this event's `state_key` field.
pub fn state_key(&self) -> &::std::primitive::str {
match self {
#( #variants(event) => &event.state_key #call_parens .as_ref(), )*
Self::_Custom(event) => &event.state_key #call_parens .as_ref(),
}
}
}
});
let txn_id_accessor = maybe_redacted.then(|| { let txn_id_accessor = maybe_redacted.then(|| {
let variants = variants.iter().map(|v| v.match_arm(quote! { Self })); let variants = variants.iter().map(|v| v.match_arm(quote! { Self }));
quote! { quote! {
@ -513,6 +527,7 @@ fn expand_accessor_methods(
#content_accessors #content_accessors
#( #methods )* #( #methods )*
#state_key_accessor
#txn_id_accessor #txn_id_accessor
} }
}) })
@ -572,7 +587,6 @@ fn field_return_type(name: &str, ruma_common: &TokenStream) -> TokenStream {
"room_id" => quote! { #ruma_common::RoomId }, "room_id" => quote! { #ruma_common::RoomId },
"event_id" => quote! { #ruma_common::EventId }, "event_id" => quote! { #ruma_common::EventId },
"sender" => quote! { #ruma_common::UserId }, "sender" => quote! { #ruma_common::UserId },
"state_key" => quote! { ::std::primitive::str },
_ => panic!("the `ruma_macros::event_enum::EVENT_FIELD` const was changed"), _ => panic!("the `ruma_macros::event_enum::EVENT_FIELD` const was changed"),
} }
} }