Handle enums and nested types properly

This commit is contained in:
Devin Ragotzy 2020-07-31 21:16:11 -04:00 committed by Jonas Platte
parent 03288c2140
commit f0fb9a69c0
2 changed files with 148 additions and 60 deletions

View File

@ -1,9 +1,9 @@
use proc_macro2::{Ident, Span, TokenStream}; use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::{ use syn::{
parse_quote, AngleBracketedGenericArguments, Attribute, Data, DeriveInput, Fields, parse_quote, AngleBracketedGenericArguments, Attribute, Data, DeriveInput, Field, Fields,
GenericArgument, GenericParam, Generics, ImplGenerics, PathArguments, Type, TypeGenerics, GenericArgument, GenericParam, Generics, ImplGenerics, ParenthesizedGenericArguments,
TypePath, TypeReference, TypeSlice, PathArguments, Type, TypeGenerics, TypePath, TypeReference, TypeSlice, Variant,
}; };
enum StructKind { enum StructKind {
@ -11,6 +11,11 @@ enum StructKind {
Tuple, Tuple,
} }
enum DataKind {
Struct(Vec<Field>, StructKind),
Enum(Vec<Variant>),
}
pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result<TokenStream> { pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result<TokenStream> {
let derive_deserialize = if no_deserialize_in_attrs(&input.attrs) { let derive_deserialize = if no_deserialize_in_attrs(&input.attrs) {
TokenStream::new() TokenStream::new()
@ -18,35 +23,82 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result<TokenStream> {
quote!(::ruma_api::exports::serde::Deserialize) quote!(::ruma_api::exports::serde::Deserialize)
}; };
let (mut fields, struct_kind): (Vec<_>, _) = match input.data { let data = match input.data {
Data::Enum(_) | Data::Union(_) => { Data::Union(_) => panic!("#[derive(Outgoing)] does not support Union types"),
panic!("#[derive(Outgoing)] is only supported for structs") Data::Enum(e) => DataKind::Enum(e.variants.into_iter().collect()),
}
Data::Struct(s) => match s.fields { Data::Struct(s) => match s.fields {
Fields::Named(fs) => (fs.named.into_iter().collect(), StructKind::Struct), Fields::Named(fs) => {
Fields::Unnamed(fs) => (fs.unnamed.into_iter().collect(), StructKind::Tuple), DataKind::Struct(fs.named.into_iter().collect(), StructKind::Struct)
Fields::Unit => return Ok(impl_outgoing_with_incoming_self(input.ident)), }
Fields::Unnamed(fs) => {
DataKind::Struct(fs.unnamed.into_iter().collect(), StructKind::Tuple)
}
Fields::Unit => return Ok(impl_outgoing_with_incoming_self(&input.ident, None, None)),
}, },
}; };
match data {
DataKind::Enum(mut vars) => {
let mut any_attribute = false; let mut any_attribute = false;
for var in &mut vars {
for field in &mut var.fields {
if strip_lifetimes(&mut field.ty) {
any_attribute = true;
}
}
}
let original_ident = &input.ident;
let (original_impl_gen, original_ty_gen, _) = input.generics.split_for_impl();
if !any_attribute {
return Ok(impl_outgoing_with_incoming_self(
original_ident,
Some(original_impl_gen),
Some(original_ty_gen),
));
}
let vis = input.vis;
let doc = format!("'Incoming' variant of [{ty}](enum.{ty}.html).", ty = &input.ident);
let incoming_ident =
format_ident!("Incoming{}", original_ident, span = Span::call_site());
let mut gen_copy = input.generics.clone();
let (impl_gen, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy);
Ok(quote! {
#[doc = #doc]
#[derive(Debug, #derive_deserialize)]
#vis enum #incoming_ident #ty_gen { #( #vars, )* }
impl #original_impl_gen ::ruma_api::Outgoing for #original_ident #original_ty_gen {
type Incoming = #incoming_ident #impl_gen;
}
})
}
DataKind::Struct(mut fields, struct_kind) => {
let mut any_attribute = false;
for field in &mut fields { for field in &mut fields {
if strip_lifetimes(&mut field.ty) { if strip_lifetimes(&mut field.ty) {
any_attribute = true; any_attribute = true;
} }
} }
if !any_attribute {
return Ok(impl_outgoing_with_incoming_self(input.ident));
}
let original_ident = &input.ident; let original_ident = &input.ident;
let (original_impl_gen, original_ty_gen, _) = input.generics.split_for_impl(); let (original_impl_gen, original_ty_gen, _) = input.generics.split_for_impl();
if !any_attribute {
return Ok(impl_outgoing_with_incoming_self(
original_ident,
Some(original_impl_gen),
Some(original_ty_gen),
));
}
let vis = input.vis; let vis = input.vis;
let doc = format!("'Incoming' variant of [{ty}](struct.{ty}.html).", ty = &input.ident); let doc = format!("'Incoming' variant of [{ty}](struct.{ty}.html).", ty = &input.ident);
let incoming_ident = format_ident!("Incoming{}", original_ident, span = Span::call_site()); let incoming_ident =
format_ident!("Incoming{}", original_ident, span = Span::call_site());
let mut gen_copy = input.generics.clone(); let mut gen_copy = input.generics.clone();
let (impl_gen, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy); let (impl_gen, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy);
@ -65,14 +117,20 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result<TokenStream> {
} }
}) })
} }
}
}
fn no_deserialize_in_attrs(attrs: &[Attribute]) -> bool { fn no_deserialize_in_attrs(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| attr.path.is_ident("incoming_no_deserialize")) attrs.iter().any(|attr| attr.path.is_ident("incoming_no_deserialize"))
} }
fn impl_outgoing_with_incoming_self(ident: Ident) -> TokenStream { fn impl_outgoing_with_incoming_self(
ident: &Ident,
impl_gen: Option<ImplGenerics>,
ty_gen: Option<TypeGenerics>,
) -> TokenStream {
quote! { quote! {
impl ::ruma_api::Outgoing for #ident { impl #impl_gen ::ruma_api::Outgoing for #ident #ty_gen {
type Incoming = Self; type Incoming = Self;
} }
} }
@ -98,13 +156,19 @@ fn strip_lifetimes(field_type: &mut Type) -> bool {
let mut has_lifetimes = false; let mut has_lifetimes = false;
for seg in &mut path.segments { for seg in &mut path.segments {
// strip generic lifetimes // strip generic lifetimes
if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { match &mut seg.arguments {
PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args, .. args, ..
}) = &mut seg.arguments }) => {
{
*args = args *args = args
.clone() .clone()
.into_iter() .into_iter()
.map(|mut ty| {
if let GenericArgument::Type(ty) = &mut ty {
strip_lifetimes(ty);
}
ty
})
.filter(|arg| { .filter(|arg| {
if let GenericArgument::Lifetime(_) = arg { if let GenericArgument::Lifetime(_) = arg {
has_lifetimes = true; has_lifetimes = true;
@ -115,6 +179,20 @@ fn strip_lifetimes(field_type: &mut Type) -> bool {
}) })
.collect(); .collect();
} }
PathArguments::Parenthesized(ParenthesizedGenericArguments {
inputs, ..
}) => {
*inputs = inputs
.clone()
.into_iter()
.map(|mut ty| {
strip_lifetimes(&mut ty);
ty
})
.collect();
}
_ => {}
}
} }
if has_lifetimes { if has_lifetimes {

View File

@ -22,4 +22,14 @@ pub struct Request<'a, T> {
pub user_id: &'a UserId, pub user_id: &'a UserId,
pub bytes: &'a [u8], pub bytes: &'a [u8],
pub recursive: &'a [Thing<'a, T>], pub recursive: &'a [Thing<'a, T>],
pub option: Option<&'a [u8]>,
}
#[derive(Outgoing)]
#[incoming_no_deserialize]
pub enum EnumThing<'a, T> {
Abc(&'a str),
Stuff(Thing<'a, T>),
Boxy(&'a ::ruma_identifiers::DeviceId),
Other(Option<&'a str>),
} }