use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ parse::{Parse, ParseStream}, parse_quote, punctuated::Punctuated, AngleBracketedGenericArguments, Attribute, Data, DeriveInput, GenericArgument, GenericParam, Generics, Ident, ItemType, ParenthesizedGenericArguments, Path, PathArguments, Token, Type, TypePath, TypeReference, TypeSlice, }; use crate::util::import_ruma_common; pub fn expand_derive_incoming(mut ty_def: DeriveInput) -> syn::Result { let ruma_common = import_ruma_common(); let mut found_lifetime = false; match &mut ty_def.data { Data::Union(_) => panic!("#[derive(Incoming)] does not support Union types"), Data::Enum(e) => { for var in &mut e.variants { for field in &mut var.fields { if strip_lifetimes(&mut field.ty, &ruma_common) { found_lifetime = true; } } } } Data::Struct(s) => { for field in &mut s.fields { if !matches!(field.vis, syn::Visibility::Public(_)) { return Err(syn::Error::new_spanned(field, "All fields must be marked `pub`")); } if strip_lifetimes(&mut field.ty, &ruma_common) { found_lifetime = true; } } } } let ident = format_ident!("Incoming{}", ty_def.ident, span = Span::call_site()); if !found_lifetime { let doc = format!( "Convenience type alias for [{}], for consistency with other [{}] types.", &ty_def.ident, ident ); let mut type_alias: ItemType = parse_quote! { type X = Y; }; type_alias.vis = ty_def.vis.clone(); type_alias.ident = ident; type_alias.generics = ty_def.generics.clone(); type_alias.ty = Box::new(TypePath { qself: None, path: ty_def.ident.clone().into() }.into()); return Ok(quote! { #[doc = #doc] #type_alias }); } let meta: Vec = ty_def .attrs .iter() .filter(|attr| attr.path.is_ident("incoming_derive")) .map(|attr| attr.parse_args()) .collect::>()?; let mut derive_debug = true; let mut derive_deserialize = true; let mut derives: Vec<_> = meta .into_iter() .flat_map(|m| m.derive_macs) .filter_map(|derive_mac| match derive_mac { DeriveMac::Regular(id) => Some(quote! { #id }), DeriveMac::NegativeDebug => { derive_debug = false; None } DeriveMac::NegativeDeserialize => { derive_deserialize = false; None } }) .collect(); if derive_debug { derives.push(quote! { ::std::fmt::Debug }); } derives.push(if derive_deserialize { quote! { #ruma_common::exports::serde::Deserialize } } else { quote! { #ruma_common::serde::_FakeDeriveSerde } }); ty_def.attrs.retain(filter_input_attrs); clean_generics(&mut ty_def.generics); let doc = format!("'Incoming' variant of [{}].", &ty_def.ident); ty_def.ident = ident; Ok(quote! { #[doc = #doc] #[derive( #( #derives ),* )] #ty_def }) } /// Keep any `cfg`, `cfg_attr`, `serde` or `non_exhaustive` attributes found and pass them to the /// Incoming variant. fn filter_input_attrs(attr: &Attribute) -> bool { attr.path.is_ident("cfg") || attr.path.is_ident("cfg_attr") || attr.path.is_ident("serde") || attr.path.is_ident("non_exhaustive") || attr.path.is_ident("allow") } fn clean_generics(generics: &mut Generics) { generics.params = generics .params .clone() .into_iter() .filter(|param| !matches!(param, GenericParam::Lifetime(_))) .collect(); } fn strip_lifetimes(field_type: &mut Type, ruma_common: &TokenStream) -> bool { match field_type { // T<'a> -> IncomingT // The IncomingT has to be declared by the user of this derive macro. Type::Path(TypePath { path, .. }) => { let mut has_lifetimes = false; let mut is_lifetime_generic = false; for seg in &mut path.segments { // strip generic lifetimes match &mut seg.arguments { PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => { *args = args .clone() .into_iter() .map(|mut ty| { if let GenericArgument::Type(ty) = &mut ty { if strip_lifetimes(ty, ruma_common) { has_lifetimes = true; }; } ty }) .filter(|arg| { if let GenericArgument::Lifetime(_) = arg { is_lifetime_generic = true; false } else { true } }) .collect(); } PathArguments::Parenthesized(ParenthesizedGenericArguments { inputs, .. }) => { *inputs = inputs .clone() .into_iter() .map(|mut ty| { if strip_lifetimes(&mut ty, ruma_common) { has_lifetimes = true; }; ty }) .collect(); } _ => {} } } // If a type has a generic lifetime parameter there must be an `Incoming` variant of // that type. if is_lifetime_generic { if let Some(name) = path.segments.last_mut() { let incoming_ty_ident = format_ident!("Incoming{}", name.ident); name.ident = incoming_ty_ident; } } has_lifetimes || is_lifetime_generic } Type::Reference(TypeReference { elem, .. }) => { let special_replacement = match &mut **elem { Type::Path(ty) => { let path = &ty.path; let last_seg = path.segments.last().unwrap(); if last_seg.ident == "str" { // &str -> String Some(parse_quote! { ::std::string::String }) } else if last_seg.ident == "RawJsonValue" { Some(parse_quote! { ::std::boxed::Box<#path> }) } else if last_seg.ident == "ClientSecret" || last_seg.ident == "DeviceId" || last_seg.ident == "DeviceKeyId" || last_seg.ident == "DeviceSigningKeyId" || last_seg.ident == "EventId" || last_seg.ident == "KeyId" || last_seg.ident == "MxcUri" || last_seg.ident == "ServerName" || last_seg.ident == "SessionId" || last_seg.ident == "RoomAliasId" || last_seg.ident == "RoomId" || last_seg.ident == "RoomOrAliasId" || last_seg.ident == "RoomName" || last_seg.ident == "ServerSigningKeyId" || last_seg.ident == "SigningKeyId" || last_seg.ident == "TransactionId" || last_seg.ident == "UserId" { let ident = format_ident!("Owned{}", last_seg.ident); Some(parse_quote! { #ruma_common::#ident }) } else { None } } // &[T] -> Vec Type::Slice(TypeSlice { elem, .. }) => { // Recursively strip the lifetimes of the slice's elements. strip_lifetimes(&mut *elem, ruma_common); Some(parse_quote! { Vec<#elem> }) } _ => None, }; *field_type = match special_replacement { Some(ty) => ty, None => { // Strip lifetimes of `elem`. strip_lifetimes(elem, ruma_common); // Replace reference with `elem`. (**elem).clone() } }; true } Type::Tuple(syn::TypeTuple { elems, .. }) => { let mut has_lifetime = false; for elem in elems { if strip_lifetimes(elem, ruma_common) { has_lifetime = true; } } has_lifetime } _ => false, } } pub struct Meta { derive_macs: Vec, } impl Parse for Meta { fn parse(input: ParseStream<'_>) -> syn::Result { Ok(Self { derive_macs: Punctuated::<_, Token![,]>::parse_terminated(input)?.into_iter().collect(), }) } } pub enum DeriveMac { Regular(Path), NegativeDebug, NegativeDeserialize, } impl Parse for DeriveMac { fn parse(input: ParseStream<'_>) -> syn::Result { if input.peek(Token![!]) { let _: Token![!] = input.parse()?; let mac: Ident = input.parse()?; if mac == "Debug" { Ok(Self::NegativeDebug) } else if mac == "Deserialize" { Ok(Self::NegativeDeserialize) } else { Err(syn::Error::new_spanned( mac, "Negative incoming_derive can only be used for Debug and Deserialize", )) } } else { let mac = input.parse()?; Ok(Self::Regular(mac)) } } }