From 5a791b3c6e749d5428e7e3efe3ad781cb693e8d5 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Tue, 22 Mar 2022 10:17:04 +0100 Subject: [PATCH] macros: Simplify Incoming derive implementation --- crates/ruma-macros/src/serde/incoming.rs | 141 +++++++---------------- 1 file changed, 43 insertions(+), 98 deletions(-) diff --git a/crates/ruma-macros/src/serde/incoming.rs b/crates/ruma-macros/src/serde/incoming.rs index b0ff42b1..d621c771 100644 --- a/crates/ruma-macros/src/serde/incoming.rs +++ b/crates/ruma-macros/src/serde/incoming.rs @@ -4,32 +4,49 @@ use syn::{ parse::{Parse, ParseStream}, parse_quote, punctuated::Punctuated, - AngleBracketedGenericArguments, Attribute, Data, DeriveInput, Field, Fields, GenericArgument, - GenericParam, Generics, Ident, ImplGenerics, ParenthesizedGenericArguments, Path, - PathArguments, Token, Type, TypeGenerics, TypePath, TypeReference, TypeSlice, Variant, + AngleBracketedGenericArguments, Attribute, Data, DeriveInput, GenericArgument, GenericParam, + Generics, Ident, ParenthesizedGenericArguments, Path, PathArguments, Token, Type, TypePath, + TypeReference, TypeSlice, }; use crate::util::import_ruma_common; -enum StructKind { - Struct, - Tuple, -} - -enum DataKind { - Struct(Vec, StructKind), - Enum(Vec), - Unit, -} - -pub fn expand_derive_incoming(input: DeriveInput) -> syn::Result { +pub fn expand_derive_incoming(mut ty_def: DeriveInput) -> syn::Result { let ruma_common = import_ruma_common(); + let mut found_lifetime = false; + match &mut ty_def.data { + Data::Union(_) => panic!("#[derive(Incoming)] does not support Union types"), + Data::Enum(e) => { + for var in &mut e.variants { + for field in &mut var.fields { + if strip_lifetimes(&mut field.ty) { + found_lifetime = true; + } + } + } + } + Data::Struct(s) => { + for field in &mut s.fields { + if !matches!(field.vis, syn::Visibility::Public(_)) { + return Err(syn::Error::new_spanned(field, "All fields must be marked `pub`")); + } + if strip_lifetimes(&mut field.ty) { + found_lifetime = true; + } + } + } + } + + if !found_lifetime { + return Ok(TokenStream::new()); + } + let mut derives = vec![quote! { Debug }]; let mut derive_deserialize = true; derives.extend( - input + ty_def .attrs .iter() .filter(|attr| attr.path.is_ident("incoming_derive")) @@ -52,86 +69,17 @@ pub fn expand_derive_incoming(input: DeriveInput) -> syn::Result { quote! { #ruma_common::serde::_FakeDeriveSerde } }); - let input_attrs = - input.attrs.iter().filter(|attr| filter_input_attrs(attr)).collect::>(); + ty_def.attrs.retain(filter_input_attrs); + clean_generics(&mut ty_def.generics); - let data = match input.data.clone() { - Data::Union(_) => panic!("#[derive(Incoming)] does not support Union types"), - Data::Enum(e) => DataKind::Enum(e.variants.into_iter().collect()), - Data::Struct(s) => match s.fields { - Fields::Named(fs) => { - DataKind::Struct(fs.named.into_iter().collect(), StructKind::Struct) - } - Fields::Unnamed(fs) => { - DataKind::Struct(fs.unnamed.into_iter().collect(), StructKind::Tuple) - } - Fields::Unit => DataKind::Unit, - }, - }; + let doc = format!("'Incoming' variant of [{}].", &ty_def.ident); + ty_def.ident = format_ident!("Incoming{}", ty_def.ident, span = Span::call_site()); - match data { - DataKind::Unit => Ok(TokenStream::new()), - DataKind::Enum(mut vars) => { - let mut found_lifetime = false; - for var in &mut vars { - for field in &mut var.fields { - if strip_lifetimes(&mut field.ty) { - found_lifetime = true; - } - } - } - - if !found_lifetime { - return Ok(TokenStream::new()); - } - - let vis = input.vis; - let doc = format!("'Incoming' variant of [{ty}](enum.{ty}.html).", ty = &input.ident); - let incoming_ident = format_ident!("Incoming{}", input.ident, span = Span::call_site()); - let mut gen_copy = input.generics.clone(); - let (_, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy); - - Ok(quote! { - #[doc = #doc] - #[derive( #( #derives ),* )] - #( #input_attrs )* - #vis enum #incoming_ident #ty_gen { #( #vars, )* } - }) - } - DataKind::Struct(mut fields, struct_kind) => { - let mut found_lifetime = false; - for field in &mut fields { - if !matches!(field.vis, syn::Visibility::Public(_)) { - return Err(syn::Error::new_spanned(field, "All fields must be marked `pub`")); - } - if strip_lifetimes(&mut field.ty) { - found_lifetime = true; - } - } - - if !found_lifetime { - return Ok(TokenStream::new()); - } - - let vis = input.vis; - let doc = format!("'Incoming' variant of [{ty}](struct.{ty}.html).", ty = &input.ident); - let incoming_ident = format_ident!("Incoming{}", input.ident, span = Span::call_site()); - let mut gen_copy = input.generics.clone(); - let (_, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy); - - let struct_def = match struct_kind { - StructKind::Struct => quote! { { #(#fields,)* } }, - StructKind::Tuple => quote! { ( #(#fields,)* ); }, - }; - - Ok(quote! { - #[doc = #doc] - #[derive( #( #derives ),* )] - #( #input_attrs )* - #vis struct #incoming_ident #ty_gen #struct_def - }) - } - } + Ok(quote! { + #[doc = #doc] + #[derive( #( #derives ),* )] + #ty_def + }) } /// Keep any `cfg`, `cfg_attr`, `serde` or `non_exhaustive` attributes found and pass them to the @@ -144,16 +92,13 @@ fn filter_input_attrs(attr: &Attribute) -> bool { || attr.path.is_ident("allow") } -fn split_for_impl_lifetime_less(generics: &mut Generics) -> (ImplGenerics<'_>, TypeGenerics<'_>) { +fn clean_generics(generics: &mut Generics) { generics.params = generics .params .clone() .into_iter() .filter(|param| !matches!(param, GenericParam::Lifetime(_))) .collect(); - - let (impl_gen, ty_gen, _) = generics.split_for_impl(); - (impl_gen, ty_gen) } fn strip_lifetimes(field_type: &mut Type) -> bool {