Pass the attributes of any type deriving Outgoing to the Incoming type

This commit is contained in:
Devin Ragotzy 2020-08-08 21:01:40 -04:00 committed by GitHub
parent f455d4c8ab
commit ddb1b48e71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -24,6 +24,9 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result<TokenStream> {
quote!(::ruma_api::exports::serde::Deserialize) quote!(::ruma_api::exports::serde::Deserialize)
}; };
let input_attrs =
input.attrs.clone().into_iter().filter(filter_input_attrs).collect::<Vec<_>>();
let data = match input.data.clone() { let data = match input.data.clone() {
Data::Union(_) => panic!("#[derive(Outgoing)] does not support Union types"), Data::Union(_) => panic!("#[derive(Outgoing)] does not support Union types"),
Data::Enum(e) => DataKind::Enum(e.variants.into_iter().collect()), Data::Enum(e) => DataKind::Enum(e.variants.into_iter().collect()),
@ -67,6 +70,7 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result<TokenStream> {
Ok(quote! { Ok(quote! {
#[doc = #doc] #[doc = #doc]
#[derive(Debug, #derive_deserialize)] #[derive(Debug, #derive_deserialize)]
#( #input_attrs )*
#vis enum #incoming_ident #ty_gen { #( #vars, )* } #vis enum #incoming_ident #ty_gen { #( #vars, )* }
impl #original_impl_gen ::ruma_api::Outgoing for #original_ident #original_ty_gen { impl #original_impl_gen ::ruma_api::Outgoing for #original_ident #original_ty_gen {
@ -104,6 +108,7 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result<TokenStream> {
Ok(quote! { Ok(quote! {
#[doc = #doc] #[doc = #doc]
#[derive(Debug, #derive_deserialize)] #[derive(Debug, #derive_deserialize)]
#( #input_attrs )*
#vis struct #incoming_ident #ty_gen #struct_def #vis struct #incoming_ident #ty_gen #struct_def
impl #original_impl_gen ::ruma_api::Outgoing for #original_ident #original_ty_gen { impl #original_impl_gen ::ruma_api::Outgoing for #original_ident #original_ty_gen {
@ -114,6 +119,12 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result<TokenStream> {
} }
} }
/// Keep any `serde` or `non_exhaustive` attributes found and
/// pass them to the Incoming variant.
fn filter_input_attrs(attr: &Attribute) -> bool {
attr.path.is_ident("serde") || attr.path.is_ident("non_exhaustive")
}
fn no_deserialize_in_attrs(attrs: &[Attribute]) -> bool { fn no_deserialize_in_attrs(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| attr.path.is_ident("incoming_no_deserialize")) attrs.iter().any(|attr| attr.path.is_ident("incoming_no_deserialize"))
} }