Handle enums and nested types properly

This commit is contained in:
Devin Ragotzy 2020-07-31 21:16:11 -04:00 committed by Jonas Platte
parent 03288c2140
commit f0fb9a69c0
2 changed files with 148 additions and 60 deletions

View File

@ -1,9 +1,9 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote};
use syn::{
parse_quote, AngleBracketedGenericArguments, Attribute, Data, DeriveInput, Fields,
GenericArgument, GenericParam, Generics, ImplGenerics, PathArguments, Type, TypeGenerics,
TypePath, TypeReference, TypeSlice,
parse_quote, AngleBracketedGenericArguments, Attribute, Data, DeriveInput, Field, Fields,
GenericArgument, GenericParam, Generics, ImplGenerics, ParenthesizedGenericArguments,
PathArguments, Type, TypeGenerics, TypePath, TypeReference, TypeSlice, Variant,
};
enum StructKind {
@ -11,6 +11,11 @@ enum StructKind {
Tuple,
}
enum DataKind {
Struct(Vec<Field>, StructKind),
Enum(Vec<Variant>),
}
pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result<TokenStream> {
let derive_deserialize = if no_deserialize_in_attrs(&input.attrs) {
TokenStream::new()
@ -18,35 +23,82 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result<TokenStream> {
quote!(::ruma_api::exports::serde::Deserialize)
};
let (mut fields, struct_kind): (Vec<_>, _) = match input.data {
Data::Enum(_) | Data::Union(_) => {
panic!("#[derive(Outgoing)] is only supported for structs")
}
let data = match input.data {
Data::Union(_) => panic!("#[derive(Outgoing)] does not support Union types"),
Data::Enum(e) => DataKind::Enum(e.variants.into_iter().collect()),
Data::Struct(s) => match s.fields {
Fields::Named(fs) => (fs.named.into_iter().collect(), StructKind::Struct),
Fields::Unnamed(fs) => (fs.unnamed.into_iter().collect(), StructKind::Tuple),
Fields::Unit => return Ok(impl_outgoing_with_incoming_self(input.ident)),
Fields::Named(fs) => {
DataKind::Struct(fs.named.into_iter().collect(), StructKind::Struct)
}
Fields::Unnamed(fs) => {
DataKind::Struct(fs.unnamed.into_iter().collect(), StructKind::Tuple)
}
Fields::Unit => return Ok(impl_outgoing_with_incoming_self(&input.ident, None, None)),
},
};
match data {
DataKind::Enum(mut vars) => {
let mut any_attribute = false;
for var in &mut vars {
for field in &mut var.fields {
if strip_lifetimes(&mut field.ty) {
any_attribute = true;
}
}
}
let original_ident = &input.ident;
let (original_impl_gen, original_ty_gen, _) = input.generics.split_for_impl();
if !any_attribute {
return Ok(impl_outgoing_with_incoming_self(
original_ident,
Some(original_impl_gen),
Some(original_ty_gen),
));
}
let vis = input.vis;
let doc = format!("'Incoming' variant of [{ty}](enum.{ty}.html).", ty = &input.ident);
let incoming_ident =
format_ident!("Incoming{}", original_ident, span = Span::call_site());
let mut gen_copy = input.generics.clone();
let (impl_gen, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy);
Ok(quote! {
#[doc = #doc]
#[derive(Debug, #derive_deserialize)]
#vis enum #incoming_ident #ty_gen { #( #vars, )* }
impl #original_impl_gen ::ruma_api::Outgoing for #original_ident #original_ty_gen {
type Incoming = #incoming_ident #impl_gen;
}
})
}
DataKind::Struct(mut fields, struct_kind) => {
let mut any_attribute = false;
for field in &mut fields {
if strip_lifetimes(&mut field.ty) {
any_attribute = true;
}
}
if !any_attribute {
return Ok(impl_outgoing_with_incoming_self(input.ident));
}
let original_ident = &input.ident;
let (original_impl_gen, original_ty_gen, _) = input.generics.split_for_impl();
if !any_attribute {
return Ok(impl_outgoing_with_incoming_self(
original_ident,
Some(original_impl_gen),
Some(original_ty_gen),
));
}
let vis = input.vis;
let doc = format!("'Incoming' variant of [{ty}](struct.{ty}.html).", ty = &input.ident);
let incoming_ident = format_ident!("Incoming{}", original_ident, span = Span::call_site());
let incoming_ident =
format_ident!("Incoming{}", original_ident, span = Span::call_site());
let mut gen_copy = input.generics.clone();
let (impl_gen, ty_gen) = split_for_impl_lifetime_less(&mut gen_copy);
@ -64,15 +116,21 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result<TokenStream> {
type Incoming = #incoming_ident #impl_gen;
}
})
}
}
}
fn no_deserialize_in_attrs(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| attr.path.is_ident("incoming_no_deserialize"))
}
fn impl_outgoing_with_incoming_self(ident: Ident) -> TokenStream {
fn impl_outgoing_with_incoming_self(
ident: &Ident,
impl_gen: Option<ImplGenerics>,
ty_gen: Option<TypeGenerics>,
) -> TokenStream {
quote! {
impl ::ruma_api::Outgoing for #ident {
impl #impl_gen ::ruma_api::Outgoing for #ident #ty_gen {
type Incoming = Self;
}
}
@ -98,13 +156,19 @@ fn strip_lifetimes(field_type: &mut Type) -> bool {
let mut has_lifetimes = false;
for seg in &mut path.segments {
// strip generic lifetimes
if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
match &mut seg.arguments {
PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args, ..
}) = &mut seg.arguments
{
}) => {
*args = args
.clone()
.into_iter()
.map(|mut ty| {
if let GenericArgument::Type(ty) = &mut ty {
strip_lifetimes(ty);
}
ty
})
.filter(|arg| {
if let GenericArgument::Lifetime(_) = arg {
has_lifetimes = true;
@ -115,6 +179,20 @@ fn strip_lifetimes(field_type: &mut Type) -> bool {
})
.collect();
}
PathArguments::Parenthesized(ParenthesizedGenericArguments {
inputs, ..
}) => {
*inputs = inputs
.clone()
.into_iter()
.map(|mut ty| {
strip_lifetimes(&mut ty);
ty
})
.collect();
}
_ => {}
}
}
if has_lifetimes {

View File

@ -22,4 +22,14 @@ pub struct Request<'a, T> {
pub user_id: &'a UserId,
pub bytes: &'a [u8],
pub recursive: &'a [Thing<'a, T>],
pub option: Option<&'a [u8]>,
}
#[derive(Outgoing)]
#[incoming_no_deserialize]
pub enum EnumThing<'a, T> {
Abc(&'a str),
Stuff(Thing<'a, T>),
Boxy(&'a ::ruma_identifiers::DeviceId),
Other(Option<&'a str>),
}