macros: Simplify Incoming derive implementation

This commit is contained in:
Jonas Platte 2022-03-22 10:17:04 +01:00 committed by Jonas Platte
parent a6a530dcc8
commit 5a791b3c6e

View File

@ -4,32 +4,49 @@ use syn::{
parse::{Parse, ParseStream},
parse_quote,
punctuated::Punctuated,
AngleBracketedGenericArguments, Attribute, Data, DeriveInput, Field, Fields, GenericArgument,
GenericParam, Generics, Ident, ImplGenerics, ParenthesizedGenericArguments, Path,
PathArguments, Token, Type, TypeGenerics, TypePath, TypeReference, TypeSlice, Variant,
AngleBracketedGenericArguments, Attribute, Data, DeriveInput, GenericArgument, GenericParam,
Generics, Ident, ParenthesizedGenericArguments, Path, PathArguments, Token, Type, TypePath,
TypeReference, TypeSlice,
};
use crate::util::import_ruma_common;
enum StructKind {
Struct,
Tuple,
}
enum DataKind {
Struct(Vec<Field>, StructKind),
Enum(Vec<Variant>),
Unit,
}
pub fn expand_derive_incoming(input: DeriveInput) -> syn::Result<TokenStream> {
pub fn expand_derive_incoming(mut ty_def: DeriveInput) -> syn::Result<TokenStream> {
let ruma_common = import_ruma_common();
let mut found_lifetime = false;
match &mut ty_def.data {
Data::Union(_) => panic!("#[derive(Incoming)] does not support Union types"),
Data::Enum(e) => {
for var in &mut e.variants {
for field in &mut var.fields {
if strip_lifetimes(&mut field.ty) {
found_lifetime = true;
}
}
}
}
Data::Struct(s) => {
for field in &mut s.fields {
if !matches!(field.vis, syn::Visibility::Public(_)) {
return Err(syn::Error::new_spanned(field, "All fields must be marked `pub`"));
}
if strip_lifetimes(&mut field.ty) {
found_lifetime = true;
}
}
}
}
if !found_lifetime {
return Ok(TokenStream::new());
}
let mut derives = vec![quote! { Debug }];
let mut derive_deserialize = true;
derives.extend(
input
ty_def
.attrs
.iter()
.filter(|attr| attr.path.is_ident("incoming_derive"))
@ -52,86 +69,17 @@ pub fn expand_derive_incoming(input: DeriveInput) -> syn::Result<TokenStream> {
quote! { #ruma_common::serde::_FakeDeriveSerde }
});
let input_attrs =
input.attrs.iter().filter(|attr| filter_input_attrs(attr)).collect::<Vec<_>>();
ty_def.attrs.retain(filter_input_attrs);
clean_generics(&mut ty_def.generics);
let data = match input.data.clone() {
Data::Union(_) => panic!("#[derive(Incoming)] does not support Union types"),
Data::Enum(e) => DataKind::Enum(e.variants.into_iter().collect()),
Data::Struct(s) => match s.fields {
Fields::Named(fs) => {
DataKind::Struct(fs.named.into_iter().collect(), StructKind::Struct)
}
Fields::Unnamed(fs) => {
DataKind::Struct(fs.unnamed.into_iter().collect(), StructKind::Tuple)
}
Fields::Unit => DataKind::Unit,
},
};
let doc = format!("'Incoming' variant of [{}].", &ty_def.ident);
ty_def.ident = format_ident!("Incoming{}", ty_def.ident, span = Span::call_site());
match data {
DataKind::Unit => Ok(TokenStream::new()),
DataKind::Enum(mut vars) => {
let mut found_lifetime = false;
for var in &mut vars {
for field in &mut var.fields {
if strip_lifetimes(&mut field.ty) {
found_lifetime = true;
}
}
}
if !found_lifetime {
return Ok(TokenStream::new());
}
let vis = input.vis;
let doc = format!("'Incoming' variant of [{ty}](enum.{ty}.html).", ty = &input.ident);
let incoming_ident = format_ident!("Incoming{}", input.ident, span = Span::call_site());
let mut gen_copy = input.generics.clone();
let (_, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy);
Ok(quote! {
#[doc = #doc]
#[derive( #( #derives ),* )]
#( #input_attrs )*
#vis enum #incoming_ident #ty_gen { #( #vars, )* }
})
}
DataKind::Struct(mut fields, struct_kind) => {
let mut found_lifetime = false;
for field in &mut fields {
if !matches!(field.vis, syn::Visibility::Public(_)) {
return Err(syn::Error::new_spanned(field, "All fields must be marked `pub`"));
}
if strip_lifetimes(&mut field.ty) {
found_lifetime = true;
}
}
if !found_lifetime {
return Ok(TokenStream::new());
}
let vis = input.vis;
let doc = format!("'Incoming' variant of [{ty}](struct.{ty}.html).", ty = &input.ident);
let incoming_ident = format_ident!("Incoming{}", input.ident, span = Span::call_site());
let mut gen_copy = input.generics.clone();
let (_, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy);
let struct_def = match struct_kind {
StructKind::Struct => quote! { { #(#fields,)* } },
StructKind::Tuple => quote! { ( #(#fields,)* ); },
};
Ok(quote! {
#[doc = #doc]
#[derive( #( #derives ),* )]
#( #input_attrs )*
#vis struct #incoming_ident #ty_gen #struct_def
})
}
}
Ok(quote! {
#[doc = #doc]
#[derive( #( #derives ),* )]
#ty_def
})
}
/// Keep any `cfg`, `cfg_attr`, `serde` or `non_exhaustive` attributes found and pass them to the
@ -144,16 +92,13 @@ fn filter_input_attrs(attr: &Attribute) -> bool {
|| attr.path.is_ident("allow")
}
fn split_for_impl_lifetime_less(generics: &mut Generics) -> (ImplGenerics<'_>, TypeGenerics<'_>) {
fn clean_generics(generics: &mut Generics) {
generics.params = generics
.params
.clone()
.into_iter()
.filter(|param| !matches!(param, GenericParam::Lifetime(_)))
.collect();
let (impl_gen, ty_gen, _) = generics.split_for_impl();
(impl_gen, ty_gen)
}
fn strip_lifetimes(field_type: &mut Type) -> bool {