From 1c0dab5a479737bc2ae83ff710321c191dd502f7 Mon Sep 17 00:00:00 2001 From: Akshay Date: Fri, 5 Feb 2021 17:17:43 +0530 Subject: [PATCH] api: Allow cfg attributes on rate_limited and authentication metadata fields --- ruma-api-macros/src/api.rs | 98 ++++++++++++++++++++--------- ruma-api-macros/src/api/metadata.rs | 56 +++++++++++++---- ruma-api-macros/src/util.rs | 6 +- ruma-api/tests/ruma_api_macros.rs | 8 +++ 4 files changed, 124 insertions(+), 44 deletions(-) diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index dea6d079..69fc1d51 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -83,8 +83,32 @@ pub fn expand_all(api: Api) -> syn::Result { // with only the literal's value from here on. let name = &api.metadata.name.value(); let path = &api.metadata.path; - let rate_limited = &api.metadata.rate_limited; - let authentication = &api.metadata.authentication; + let rate_limited: TokenStream = api + .metadata + .rate_limited + .iter() + .map(|r| { + let attrs = &r.attrs; + let value = &r.value; + quote! { + #( #attrs )* + rate_limited: #value, + } + }) + .collect(); + let authentication: TokenStream = api + .metadata + .authentication + .iter() + .map(|r| { + let attrs = &r.attrs; + let value = &r.value; + quote! { + #( #attrs )* + authentication: #ruma_api::AuthScheme::#value, + } + }) + .collect(); let request_type = &api.request; let response_type = &api.response; @@ -118,20 +142,24 @@ pub fn expand_all(api: Api) -> syn::Result { }; let mut header_kvs = api.request.append_header_kvs(); - if authentication == "AccessToken" { - header_kvs.extend(quote! { - req_headers.insert( - #http::header::AUTHORIZATION, - #http::header::HeaderValue::from_str( - &::std::format!( - "Bearer {}", - access_token.ok_or( - #ruma_api::error::IntoHttpError::NeedsAuthentication - )? - ) - )? - ); - }); + for auth in &api.metadata.authentication { + if auth.value == "AccessToken" { + let attrs = &auth.attrs; + header_kvs.extend(quote! { + #( #attrs )* + req_headers.insert( + #http::header::AUTHORIZATION, + #http::header::HeaderValue::from_str( + &::std::format!( + "Bearer {}", + access_token.ok_or( + #ruma_api::error::IntoHttpError::NeedsAuthentication + )? + ) + )? + ); + }); + } } let extract_request_headers = if api.request.has_header_fields() { @@ -227,19 +255,29 @@ pub fn expand_all(api: Api) -> syn::Result { let error = &api.error_ty; let request_lifetimes = api.request.combine_lifetimes(); - let non_auth_endpoint_impls = if authentication != "None" { - TokenStream::new() - } else { - quote! { - #[automatically_derived] - impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest - for Request #request_lifetimes - {} + let non_auth_endpoint_impls: TokenStream = api + .metadata + .authentication + .iter() + .map(|auth| { + if auth.value != "None" { + TokenStream::new() + } else { + let attrs = &auth.attrs; + quote! { + #( #attrs )* + #[automatically_derived] + impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest + for Request #request_lifetimes + {} - #[automatically_derived] - impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {} - } - }; + #( #attrs )* + #[automatically_derived] + impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {} + } + } + }) + .collect(); Ok(quote! { #[doc = #request_doc] @@ -301,8 +339,8 @@ pub fn expand_all(api: Api) -> syn::Result { method: #http::Method::#method, name: #name, path: #path, - rate_limited: #rate_limited, - authentication: #ruma_api::AuthScheme::#authentication, + #rate_limited + #authentication }; #[automatically_derived] diff --git a/ruma-api-macros/src/api/metadata.rs b/ruma-api-macros/src/api/metadata.rs index a6df2bc8..acc1435f 100644 --- a/ruma-api-macros/src/api/metadata.rs +++ b/ruma-api-macros/src/api/metadata.rs @@ -4,7 +4,7 @@ use quote::ToTokens; use syn::{ braced, parse::{Parse, ParseStream}, - Ident, LitBool, LitStr, Token, + Attribute, Ident, LitBool, LitStr, Token, }; use crate::util; @@ -19,6 +19,15 @@ mod kw { syn::custom_keyword!(authentication); } +/// A field of Metadata that contains attribute macros +pub struct MetadataField { + /// attributes over the field + pub attrs: Vec, + + /// the field itself + pub value: T, +} + /// The result of processing the `metadata` section of the macro. pub struct Metadata { /// The description field. @@ -34,10 +43,10 @@ pub struct Metadata { pub path: LitStr, /// The rate_limited field. - pub rate_limited: LitBool, + pub rate_limited: Vec>, /// The authentication field. - pub authentication: Ident, + pub authentication: Vec>, } fn set_field(field: &mut Option, value: T) -> syn::Result<()> { @@ -69,8 +78,8 @@ impl Parse for Metadata { let mut method = None; let mut name = None; let mut path = None; - let mut rate_limited = None; - let mut authentication = None; + let mut rate_limited = vec![]; + let mut authentication = vec![]; for field_value in field_values { match field_value { @@ -78,8 +87,12 @@ impl Parse for Metadata { FieldValue::Method(m) => set_field(&mut method, m)?, FieldValue::Name(n) => set_field(&mut name, n)?, FieldValue::Path(p) => set_field(&mut path, p)?, - FieldValue::RateLimited(rl) => set_field(&mut rate_limited, rl)?, - FieldValue::Authentication(a) => set_field(&mut authentication, a)?, + FieldValue::RateLimited(value, attrs) => { + rate_limited.push(MetadataField { value, attrs }) + } + FieldValue::Authentication(value, attrs) => { + authentication.push(MetadataField { value, attrs }) + } } } @@ -91,8 +104,16 @@ impl Parse for Metadata { method: method.ok_or_else(|| missing_field("method"))?, name: name.ok_or_else(|| missing_field("name"))?, path: path.ok_or_else(|| missing_field("path"))?, - rate_limited: rate_limited.ok_or_else(|| missing_field("rate_limited"))?, - authentication: authentication.ok_or_else(|| missing_field("authentication"))?, + rate_limited: if rate_limited.is_empty() { + return Err(missing_field("rate_limited")); + } else { + rate_limited + }, + authentication: if authentication.is_empty() { + return Err(missing_field("authentication")); + } else { + authentication + }, }) } } @@ -139,12 +160,21 @@ enum FieldValue { Method(Ident), Name(LitStr), Path(LitStr), - RateLimited(LitBool), - Authentication(Ident), + RateLimited(LitBool, Vec), + Authentication(Ident, Vec), } impl Parse for FieldValue { fn parse(input: ParseStream) -> syn::Result { + let attrs: Vec = input.call(Attribute::parse_outer)?; + for attr in attrs.iter() { + if !util::is_cfg_attribute(attr) { + return Err(syn::Error::new_spanned( + &attr, + "only `cfg` attributes may appear here", + )); + } + } let field: Field = input.parse()?; let _: Token![:] = input.parse()?; @@ -164,8 +194,8 @@ impl Parse for FieldValue { Self::Path(path) } - Field::RateLimited => Self::RateLimited(input.parse()?), - Field::Authentication => Self::Authentication(input.parse()?), + Field::RateLimited => Self::RateLimited(input.parse()?, attrs), + Field::Authentication => Self::Authentication(input.parse()?, attrs), }) } } diff --git a/ruma-api-macros/src/util.rs b/ruma-api-macros/src/util.rs index 928ece93..93dca2cb 100644 --- a/ruma-api-macros/src/util.rs +++ b/ruma-api-macros/src/util.rs @@ -5,7 +5,7 @@ use proc_macro_crate::crate_name; use quote::quote; use std::collections::BTreeSet; use syn::{ - AngleBracketedGenericArguments, GenericArgument, Ident, Lifetime, + AngleBracketedGenericArguments, AttrStyle, Attribute, GenericArgument, Ident, Lifetime, ParenthesizedGenericArguments, PathArguments, Type, TypeArray, TypeBareFn, TypeGroup, TypeParen, TypePath, TypePtr, TypeReference, TypeSlice, TypeTuple, }; @@ -402,3 +402,7 @@ pub fn import_ruma_api() -> TokenStream { quote! { ::ruma_api } } } + +pub(crate) fn is_cfg_attribute(attr: &Attribute) -> bool { + attr.style == AttrStyle::Outer && attr.path.is_ident("cfg") +} diff --git a/ruma-api/tests/ruma_api_macros.rs b/ruma-api/tests/ruma_api_macros.rs index f8c2b20b..d8966979 100644 --- a/ruma-api/tests/ruma_api_macros.rs +++ b/ruma-api/tests/ruma_api_macros.rs @@ -9,7 +9,15 @@ pub mod some_endpoint { method: POST, // An `http::Method` constant. No imports required. name: "some_endpoint", path: "/_matrix/some/endpoint/:baz", + + #[cfg(all())] + rate_limited: true, + #[cfg(any())] rate_limited: false, + + #[cfg(all())] + authentication: AccessToken, + #[cfg(any())] authentication: None, }