diff --git a/ruma-api-macros/src/derive_outgoing.rs b/ruma-api-macros/src/derive_outgoing.rs index 0014bad4..c5a11967 100644 --- a/ruma-api-macros/src/derive_outgoing.rs +++ b/ruma-api-macros/src/derive_outgoing.rs @@ -1,9 +1,9 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ - parse_quote, AngleBracketedGenericArguments, Attribute, Data, DeriveInput, Fields, - GenericArgument, GenericParam, Generics, ImplGenerics, PathArguments, Type, TypeGenerics, - TypePath, TypeReference, TypeSlice, + parse_quote, AngleBracketedGenericArguments, Attribute, Data, DeriveInput, Field, Fields, + GenericArgument, GenericParam, Generics, ImplGenerics, ParenthesizedGenericArguments, + PathArguments, Type, TypeGenerics, TypePath, TypeReference, TypeSlice, Variant, }; enum StructKind { @@ -11,6 +11,11 @@ enum StructKind { Tuple, } +enum DataKind { + Struct(Vec, StructKind), + Enum(Vec), +} + pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result { let derive_deserialize = if no_deserialize_in_attrs(&input.attrs) { TokenStream::new() @@ -18,61 +23,114 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result { quote!(::ruma_api::exports::serde::Deserialize) }; - let (mut fields, struct_kind): (Vec<_>, _) = match input.data { - Data::Enum(_) | Data::Union(_) => { - panic!("#[derive(Outgoing)] is only supported for structs") - } + let data = match input.data { + Data::Union(_) => panic!("#[derive(Outgoing)] does not support Union types"), + Data::Enum(e) => DataKind::Enum(e.variants.into_iter().collect()), Data::Struct(s) => match s.fields { - Fields::Named(fs) => (fs.named.into_iter().collect(), StructKind::Struct), - Fields::Unnamed(fs) => (fs.unnamed.into_iter().collect(), StructKind::Tuple), - Fields::Unit => return Ok(impl_outgoing_with_incoming_self(input.ident)), + Fields::Named(fs) => { + DataKind::Struct(fs.named.into_iter().collect(), StructKind::Struct) + } + Fields::Unnamed(fs) => { + DataKind::Struct(fs.unnamed.into_iter().collect(), StructKind::Tuple) + } + Fields::Unit => return Ok(impl_outgoing_with_incoming_self(&input.ident, None, None)), }, }; - let mut any_attribute = false; + match data { + DataKind::Enum(mut vars) => { + let mut any_attribute = false; + for var in &mut vars { + for field in &mut var.fields { + if strip_lifetimes(&mut field.ty) { + any_attribute = true; + } + } + } - for field in &mut fields { - if strip_lifetimes(&mut field.ty) { - any_attribute = true; + let original_ident = &input.ident; + let (original_impl_gen, original_ty_gen, _) = input.generics.split_for_impl(); + + if !any_attribute { + return Ok(impl_outgoing_with_incoming_self( + original_ident, + Some(original_impl_gen), + Some(original_ty_gen), + )); + } + + let vis = input.vis; + let doc = format!("'Incoming' variant of [{ty}](enum.{ty}.html).", ty = &input.ident); + let incoming_ident = + format_ident!("Incoming{}", original_ident, span = Span::call_site()); + let mut gen_copy = input.generics.clone(); + let (impl_gen, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy); + + Ok(quote! { + #[doc = #doc] + #[derive(Debug, #derive_deserialize)] + #vis enum #incoming_ident #ty_gen { #( #vars, )* } + + impl #original_impl_gen ::ruma_api::Outgoing for #original_ident #original_ty_gen { + type Incoming = #incoming_ident #impl_gen; + } + }) + } + DataKind::Struct(mut fields, struct_kind) => { + let mut any_attribute = false; + for field in &mut fields { + if strip_lifetimes(&mut field.ty) { + any_attribute = true; + } + } + + let original_ident = &input.ident; + let (original_impl_gen, original_ty_gen, _) = input.generics.split_for_impl(); + + if !any_attribute { + return Ok(impl_outgoing_with_incoming_self( + original_ident, + Some(original_impl_gen), + Some(original_ty_gen), + )); + } + + let vis = input.vis; + let doc = format!("'Incoming' variant of [{ty}](struct.{ty}.html).", ty = &input.ident); + let incoming_ident = + format_ident!("Incoming{}", original_ident, span = Span::call_site()); + let mut gen_copy = input.generics.clone(); + let (impl_gen, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy); + + let struct_def = match struct_kind { + StructKind::Struct => quote! { { #(#fields,)* } }, + StructKind::Tuple => quote! { ( #(#fields,)* ); }, + }; + + Ok(quote! { + #[doc = #doc] + #[derive(Debug, #derive_deserialize)] + #vis struct #incoming_ident #ty_gen #struct_def + + impl #original_impl_gen ::ruma_api::Outgoing for #original_ident #original_ty_gen { + type Incoming = #incoming_ident #impl_gen; + } + }) } } - - if !any_attribute { - return Ok(impl_outgoing_with_incoming_self(input.ident)); - } - - let original_ident = &input.ident; - let (original_impl_gen, original_ty_gen, _) = input.generics.split_for_impl(); - - let vis = input.vis; - let doc = format!("'Incoming' variant of [{ty}](struct.{ty}.html).", ty = &input.ident); - let incoming_ident = format_ident!("Incoming{}", original_ident, span = Span::call_site()); - let mut gen_copy = input.generics.clone(); - let (impl_gen, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy); - - let struct_def = match struct_kind { - StructKind::Struct => quote! { { #(#fields,)* } }, - StructKind::Tuple => quote! { ( #(#fields,)* ); }, - }; - - Ok(quote! { - #[doc = #doc] - #[derive(Debug, #derive_deserialize)] - #vis struct #incoming_ident #ty_gen #struct_def - - impl #original_impl_gen ::ruma_api::Outgoing for #original_ident #original_ty_gen { - type Incoming = #incoming_ident #impl_gen; - } - }) } fn no_deserialize_in_attrs(attrs: &[Attribute]) -> bool { attrs.iter().any(|attr| attr.path.is_ident("incoming_no_deserialize")) } -fn impl_outgoing_with_incoming_self(ident: Ident) -> TokenStream { +fn impl_outgoing_with_incoming_self( + ident: &Ident, + impl_gen: Option, + ty_gen: Option, +) -> TokenStream { quote! { - impl ::ruma_api::Outgoing for #ident { + impl #impl_gen ::ruma_api::Outgoing for #ident #ty_gen { type Incoming = Self; } } @@ -98,22 +156,42 @@ fn strip_lifetimes(field_type: &mut Type) -> bool { let mut has_lifetimes = false; for seg in &mut path.segments { // strip generic lifetimes - if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { - args, .. - }) = &mut seg.arguments - { - *args = args - .clone() - .into_iter() - .filter(|arg| { - if let GenericArgument::Lifetime(_) = arg { - has_lifetimes = true; - false - } else { - true - } - }) - .collect(); + match &mut seg.arguments { + PathArguments::AngleBracketed(AngleBracketedGenericArguments { + args, .. + }) => { + *args = args + .clone() + .into_iter() + .map(|mut ty| { + if let GenericArgument::Type(ty) = &mut ty { + strip_lifetimes(ty); + } + ty + }) + .filter(|arg| { + if let GenericArgument::Lifetime(_) = arg { + has_lifetimes = true; + false + } else { + true + } + }) + .collect(); + } + PathArguments::Parenthesized(ParenthesizedGenericArguments { + inputs, .. + }) => { + *inputs = inputs + .clone() + .into_iter() + .map(|mut ty| { + strip_lifetimes(&mut ty); + ty + }) + .collect(); + } + _ => {} } } diff --git a/ruma-api/tests/outgoing.rs b/ruma-api/tests/outgoing.rs index f793f042..3805693d 100644 --- a/ruma-api/tests/outgoing.rs +++ b/ruma-api/tests/outgoing.rs @@ -22,4 +22,14 @@ pub struct Request<'a, T> { pub user_id: &'a UserId, pub bytes: &'a [u8], pub recursive: &'a [Thing<'a, T>], + pub option: Option<&'a [u8]>, +} + +#[derive(Outgoing)] +#[incoming_no_deserialize] +pub enum EnumThing<'a, T> { + Abc(&'a str), + Stuff(Thing<'a, T>), + Boxy(&'a ::ruma_identifiers::DeviceId), + Other(Option<&'a str>), }