diff --git a/ruma-common-macros/src/attr.rs b/ruma-common-macros/src/attr.rs new file mode 100644 index 00000000..c2d28fd8 --- /dev/null +++ b/ruma-common-macros/src/attr.rs @@ -0,0 +1,48 @@ +use syn::{ + parse::{Parse, ParseStream}, + LitStr, Token, +}; + +use crate::case::RenameRule; + +mod kw { + syn::custom_keyword!(rename); + syn::custom_keyword!(rename_all); +} + +pub struct RenameAttr(LitStr); + +impl RenameAttr { + pub fn into_inner(self) -> LitStr { + self.0 + } +} + +impl Parse for RenameAttr { + fn parse(input: ParseStream) -> syn::Result { + let _: kw::rename = input.parse()?; + let _: Token![=] = input.parse()?; + Ok(Self(input.parse()?)) + } +} + +pub struct RenameAllAttr(RenameRule); + +impl RenameAllAttr { + pub fn into_inner(self) -> RenameRule { + self.0 + } +} + +impl Parse for RenameAllAttr { + fn parse(input: ParseStream) -> syn::Result { + let _: kw::rename_all = input.parse()?; + let _: Token![=] = input.parse()?; + let s: LitStr = input.parse()?; + Ok(Self( + s.value() + .parse() + .map_err(|_| syn::Error::new_spanned(s, "invalid value for rename_all"))?, + )) + } +} diff --git a/ruma-common-macros/src/deserialize_from_cow_str.rs b/ruma-common-macros/src/deserialize_from_cow_str.rs new file mode 100644 index 00000000..68c27101 --- /dev/null +++ b/ruma-common-macros/src/deserialize_from_cow_str.rs @@ -0,0 +1,22 @@ +use proc_macro2::{Ident, TokenStream}; +use quote::quote; + +use crate::util::import_ruma_common; + +pub fn expand_deserialize_from_cow_str(ident: &Ident) -> syn::Result { + let ruma_common = import_ruma_common(); + + Ok(quote! { + impl<'de> #ruma_common::exports::serde::de::Deserialize<'de> for #ident { + fn deserialize(deserializer: D) -> ::std::result::Result + where + D: #ruma_common::exports::serde::de::Deserializer<'de>, + { + type CowStr<'a> = ::std::borrow::Cow<'a, ::std::primitive::str>; + + let cow = #ruma_common::exports::ruma_serde::deserialize_cow_str(deserializer)?; + Ok(::std::convert::From::>::from(cow)) + } + } + }) +} diff --git a/ruma-common-macros/src/display_as_ref_str.rs b/ruma-common-macros/src/display_as_ref_str.rs new file mode 100644 index 00000000..67a1a9e0 --- /dev/null +++ b/ruma-common-macros/src/display_as_ref_str.rs @@ -0,0 +1,12 @@ +use proc_macro2::{Ident, TokenStream}; +use quote::quote; + +pub fn expand_display_as_ref_str(ident: &Ident) -> syn::Result { + Ok(quote! { + impl ::std::fmt::Display for #ident { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + f.write_str(>::as_ref(self)) + } + } + }) +} diff --git a/ruma-common-macros/src/enum_as_ref_str.rs b/ruma-common-macros/src/enum_as_ref_str.rs new file mode 100644 index 00000000..4a2bbf97 --- /dev/null +++ b/ruma-common-macros/src/enum_as_ref_str.rs @@ -0,0 +1,58 @@ +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{Fields, FieldsNamed, FieldsUnnamed, ItemEnum}; + +use crate::util::{get_rename, get_rename_rule}; + +pub fn expand_enum_as_ref_str(input: &ItemEnum) -> syn::Result { + let enum_name = &input.ident; + let rename_rule = get_rename_rule(&input)?; + let branches: Vec<_> = input + .variants + .iter() + .map(|v| { + let variant_name = &v.ident; + let (field_capture, variant_str) = match (get_rename(v)?, &v.fields) { + (None, Fields::Unit) => ( + None, + rename_rule.apply_to_variant(&variant_name.to_string()).into_token_stream(), + ), + (Some(rename), Fields::Unit) => (None, rename.into_token_stream()), + (None, Fields::Named(FieldsNamed { named: fields, .. })) + | (None, Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. })) => { + if fields.len() != 1 { + return Err(syn::Error::new_spanned( + v, + "multiple data fields are not supported", + )); + } + + let capture = match &fields[0].ident { + Some(name) => quote! { { #name: inner } }, + None => quote! { (inner) }, + }; + + (Some(capture), quote! { inner }) + } + (Some(_), _) => { + return Err(syn::Error::new_spanned( + v, + "ruma_enum(rename) is only allowed on unit variants", + )); + } + }; + + Ok(quote! { + #enum_name :: #variant_name #field_capture => #variant_str + }) + }) + .collect::>()?; + + Ok(quote! { + impl ::std::convert::AsRef<::std::primitive::str> for #enum_name { + fn as_ref(&self) -> &::std::primitive::str { + match self { #(#branches),* } + } + } + }) +} diff --git a/ruma-common-macros/src/enum_from_string.rs b/ruma-common-macros/src/enum_from_string.rs new file mode 100644 index 00000000..c029e633 --- /dev/null +++ b/ruma-common-macros/src/enum_from_string.rs @@ -0,0 +1,84 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{quote, ToTokens}; +use syn::{Fields, FieldsNamed, FieldsUnnamed, ItemEnum}; + +use crate::util::{get_rename, get_rename_rule}; + +pub fn expand_enum_from_string(input: &ItemEnum) -> syn::Result { + let enum_name = &input.ident; + let rename_rule = get_rename_rule(&input)?; + let mut fallback = None; + let mut fallback_ty = None; + let branches: Vec<_> = input + .variants + .iter() + .map(|v| { + let variant_name = &v.ident; + let variant_str = match (get_rename(v)?, &v.fields) { + (None, Fields::Unit) => Some( + rename_rule.apply_to_variant(&variant_name.to_string()).into_token_stream(), + ), + (Some(rename), Fields::Unit) => Some(rename.into_token_stream()), + (None, Fields::Named(FieldsNamed { named: fields, .. })) + | (None, Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. })) => { + if fields.len() != 1 { + return Err(syn::Error::new_spanned( + v, + "multiple data fields are not supported", + )); + } + + if fallback.is_some() { + return Err(syn::Error::new_spanned( + v, + "multiple data-carrying variants are not supported", + )); + } + + let member = match &fields[0].ident { + Some(name) => name.into_token_stream(), + None => quote! { 0 }, + }; + + fallback = Some(quote! { + _ => #enum_name :: #variant_name { #member: s.into() } + }); + + fallback_ty = Some(&fields[0].ty); + + None + } + (Some(_), _) => { + return Err(syn::Error::new_spanned( + v, + "ruma_enum(rename) is only allowed on unit variants", + )); + } + }; + + Ok(variant_str.map(|s| quote! { #s => #enum_name :: #variant_name })) + }) + .collect::>()?; + + // Remove `None` from the iterator to avoid emitting consecutive commas in repetition + let branches = branches.iter().flatten(); + + if fallback.is_none() { + return Err(syn::Error::new(Span::call_site(), "required fallback variant not found")); + } + + Ok(quote! { + impl ::std::convert::From for #enum_name + where + T: ::std::convert::AsRef<::std::primitive::str> + + ::std::convert::Into<#fallback_ty> + { + fn from(s: T) -> Self { + match s.as_ref() { + #( #branches, )* + #fallback + } + } + } + }) +} diff --git a/ruma-common-macros/src/lib.rs b/ruma-common-macros/src/lib.rs index 408d89d3..7c6cdde8 100644 --- a/ruma-common-macros/src/lib.rs +++ b/ruma-common-macros/src/lib.rs @@ -1,9 +1,22 @@ use proc_macro::TokenStream; -use syn::{parse_macro_input, DeriveInput}; +use quote::quote; +use syn::{parse_macro_input, DeriveInput, ItemEnum}; +use deserialize_from_cow_str::expand_deserialize_from_cow_str; +use display_as_ref_str::expand_display_as_ref_str; +use enum_as_ref_str::expand_enum_as_ref_str; +use enum_from_string::expand_enum_from_string; use outgoing::expand_derive_outgoing; +use serialize_as_ref_str::expand_serialize_as_ref_str; +mod attr; +mod case; +mod deserialize_from_cow_str; +mod display_as_ref_str; +mod enum_as_ref_str; +mod enum_from_string; mod outgoing; +mod serialize_as_ref_str; mod util; /// Derive the `Outgoing` trait, possibly generating an 'Incoming' version of the struct this @@ -53,3 +66,62 @@ pub fn derive_outgoing(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand_derive_outgoing(input).unwrap_or_else(|err| err.to_compile_error()).into() } + +#[proc_macro_derive(AsRefStr, attributes(ruma_enum))] +pub fn derive_enum_as_ref_str(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemEnum); + expand_enum_as_ref_str(&input).unwrap_or_else(|err| err.to_compile_error()).into() +} + +#[proc_macro_derive(FromString, attributes(ruma_enum))] +pub fn derive_enum_from_string(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemEnum); + expand_enum_from_string(&input).unwrap_or_else(|err| err.to_compile_error()).into() +} + +// FIXME: The following macros aren't actually interested in type details beyond name (and possibly +// generics in the future). They probably shouldn't use `DeriveInput`. + +#[proc_macro_derive(DisplayAsRefStr)] +pub fn derive_display_as_ref_str(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + expand_display_as_ref_str(&input.ident).unwrap_or_else(|err| err.to_compile_error()).into() +} + +#[proc_macro_derive(SerializeAsRefStr)] +pub fn derive_serialize_as_ref_str(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + expand_serialize_as_ref_str(&input.ident).unwrap_or_else(|err| err.to_compile_error()).into() +} + +#[proc_macro_derive(DeserializeFromCowStr)] +pub fn derive_deserialize_from_cow_str(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + expand_deserialize_from_cow_str(&input.ident) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} + +/// Shorthand for the derives `AsRefStr`, `FromString`, `DisplayAsRefStr`, `SerializeAsRefStr` and +/// `DeserializeFromCowStr`. +#[proc_macro_derive(StringEnum, attributes(ruma_enum))] +pub fn derive_string_enum(input: TokenStream) -> TokenStream { + fn expand_all(input: ItemEnum) -> syn::Result { + let as_ref_str_impl = expand_enum_as_ref_str(&input)?; + let from_string_impl = expand_enum_from_string(&input)?; + let display_impl = expand_display_as_ref_str(&input.ident)?; + let serialize_impl = expand_serialize_as_ref_str(&input.ident)?; + let deserialize_impl = expand_deserialize_from_cow_str(&input.ident)?; + + Ok(quote! { + #as_ref_str_impl + #from_string_impl + #display_impl + #serialize_impl + #deserialize_impl + }) + } + + let input = parse_macro_input!(input as ItemEnum); + expand_all(input).unwrap_or_else(|err| err.to_compile_error()).into() +} diff --git a/ruma-common-macros/src/serialize_as_ref_str.rs b/ruma-common-macros/src/serialize_as_ref_str.rs new file mode 100644 index 00000000..bcaf5620 --- /dev/null +++ b/ruma-common-macros/src/serialize_as_ref_str.rs @@ -0,0 +1,20 @@ +use proc_macro2::{Ident, TokenStream}; +use quote::quote; + +use crate::util::import_ruma_common; + +pub fn expand_serialize_as_ref_str(ident: &Ident) -> syn::Result { + let ruma_common = import_ruma_common(); + + Ok(quote! { + impl #ruma_common::exports::serde::ser::Serialize for #ident { + fn serialize(&self, serializer: S) -> ::std::result::Result + where + S: #ruma_common::exports::serde::ser::Serializer, + { + >::as_ref(self) + .serialize(serializer) + } + } + }) +} diff --git a/ruma-common-macros/src/util.rs b/ruma-common-macros/src/util.rs index 511db937..66013737 100644 --- a/ruma-common-macros/src/util.rs +++ b/ruma-common-macros/src/util.rs @@ -1,6 +1,12 @@ use proc_macro2::{Ident, Span, TokenStream}; use proc_macro_crate::crate_name; use quote::quote; +use syn::{ItemEnum, LitStr, Variant}; + +use crate::{ + attr::{RenameAllAttr, RenameAttr}, + case::RenameRule, +}; pub fn import_ruma_common() -> TokenStream { if let Ok(possibly_renamed) = crate_name("ruma-common") { @@ -13,3 +19,35 @@ pub fn import_ruma_common() -> TokenStream { quote! { ::ruma_common } } } + +pub fn get_rename_rule(input: &ItemEnum) -> syn::Result { + let rules: Vec<_> = input + .attrs + .iter() + .filter(|attr| attr.path.is_ident("ruma_enum")) + .map(|attr| attr.parse_args::().map(RenameAllAttr::into_inner)) + .collect::>()?; + + match rules.len() { + 0 => Ok(RenameRule::None), + 1 => Ok(rules[0]), + _ => Err(syn::Error::new( + Span::call_site(), + "found multiple ruma_enum(rename_all) attributes", + )), + } +} + +pub fn get_rename(input: &Variant) -> syn::Result> { + let renames: Vec<_> = input + .attrs + .iter() + .filter(|attr| attr.path.is_ident("ruma_enum")) + .map(|attr| attr.parse_args::().map(RenameAttr::into_inner)) + .collect::>()?; + + match renames.len() { + 0 | 1 => Ok(renames.into_iter().next()), + _ => Err(syn::Error::new(Span::call_site(), "found multiple ruma_enum(rename) attributes")), + } +} diff --git a/ruma-common/src/lib.rs b/ruma-common/src/lib.rs index 6318bd01..8b935416 100644 --- a/ruma-common/src/lib.rs +++ b/ruma-common/src/lib.rs @@ -9,7 +9,7 @@ pub mod push; mod raw; pub mod thirdparty; -pub use ruma_common_macros::Outgoing; +pub use ruma_common_macros::*; pub use self::raw::Raw; @@ -33,5 +33,6 @@ extern crate self as ruma_common; /// It is not considered part of ruma-common's public API. #[doc(hidden)] pub mod exports { + pub use ruma_serde; pub use serde; } diff --git a/ruma-common/tests/enum_derive.rs b/ruma-common/tests/enum_derive.rs new file mode 100644 index 00000000..3e339b5c --- /dev/null +++ b/ruma-common/tests/enum_derive.rs @@ -0,0 +1,57 @@ +use ruma_common::StringEnum; +use serde_json::{from_value as from_json_value, json, to_value as to_json_value}; + +#[derive(Debug, PartialEq, StringEnum)] +#[ruma_enum(rename_all = "snake_case")] +enum MyEnum { + First, + Second, + #[ruma_enum(rename = "m.third")] + Third, + HelloWorld, + _Custom(String), +} + +#[test] +fn as_ref_str() { + assert_eq!(MyEnum::First.as_ref(), "first"); + assert_eq!(MyEnum::Second.as_ref(), "second"); + assert_eq!(MyEnum::Third.as_ref(), "m.third"); + assert_eq!(MyEnum::HelloWorld.as_ref(), "hello_world"); + assert_eq!(MyEnum::_Custom("HelloWorld".into()).as_ref(), "HelloWorld"); +} + +#[test] +fn display() { + assert_eq!(MyEnum::First.to_string(), "first"); + assert_eq!(MyEnum::Second.to_string(), "second"); + assert_eq!(MyEnum::Third.to_string(), "m.third"); + assert_eq!(MyEnum::HelloWorld.to_string(), "hello_world"); + assert_eq!(MyEnum::_Custom("HelloWorld".into()).to_string(), "HelloWorld"); +} + +#[test] +fn from_string() { + assert_eq!(MyEnum::from("first"), MyEnum::First); + assert_eq!(MyEnum::from("second"), MyEnum::Second); + assert_eq!(MyEnum::from("m.third"), MyEnum::Third); + assert_eq!(MyEnum::from("hello_world"), MyEnum::HelloWorld); + assert_eq!(MyEnum::from("HelloWorld"), MyEnum::_Custom("HelloWorld".into())); +} + +#[test] +fn serialize() { + assert_eq!(to_json_value(MyEnum::First).unwrap(), json!("first")); + assert_eq!(to_json_value(MyEnum::HelloWorld).unwrap(), json!("hello_world")); + assert_eq!(to_json_value(MyEnum::_Custom("\\\n\\".into())).unwrap(), json!("\\\n\\")); +} + +#[test] +fn deserialize() { + assert_eq!(from_json_value::(json!("first")).unwrap(), MyEnum::First); + assert_eq!(from_json_value::(json!("hello_world")).unwrap(), MyEnum::HelloWorld); + assert_eq!( + from_json_value::(json!("\\\n\\")).unwrap(), + MyEnum::_Custom("\\\n\\".into()) + ); +}