api: Allow cfg attributes on rate_limited and authentication metadata fields

This commit is contained in:
Akshay 2021-02-05 17:17:43 +05:30 committed by GitHub
parent d8c5c326e6
commit 1c0dab5a47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 124 additions and 44 deletions

View File

@ -83,8 +83,32 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
// with only the literal's value from here on. // with only the literal's value from here on.
let name = &api.metadata.name.value(); let name = &api.metadata.name.value();
let path = &api.metadata.path; let path = &api.metadata.path;
let rate_limited = &api.metadata.rate_limited; let rate_limited: TokenStream = api
let authentication = &api.metadata.authentication; .metadata
.rate_limited
.iter()
.map(|r| {
let attrs = &r.attrs;
let value = &r.value;
quote! {
#( #attrs )*
rate_limited: #value,
}
})
.collect();
let authentication: TokenStream = api
.metadata
.authentication
.iter()
.map(|r| {
let attrs = &r.attrs;
let value = &r.value;
quote! {
#( #attrs )*
authentication: #ruma_api::AuthScheme::#value,
}
})
.collect();
let request_type = &api.request; let request_type = &api.request;
let response_type = &api.response; let response_type = &api.response;
@ -118,20 +142,24 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
}; };
let mut header_kvs = api.request.append_header_kvs(); let mut header_kvs = api.request.append_header_kvs();
if authentication == "AccessToken" { for auth in &api.metadata.authentication {
header_kvs.extend(quote! { if auth.value == "AccessToken" {
req_headers.insert( let attrs = &auth.attrs;
#http::header::AUTHORIZATION, header_kvs.extend(quote! {
#http::header::HeaderValue::from_str( #( #attrs )*
&::std::format!( req_headers.insert(
"Bearer {}", #http::header::AUTHORIZATION,
access_token.ok_or( #http::header::HeaderValue::from_str(
#ruma_api::error::IntoHttpError::NeedsAuthentication &::std::format!(
)? "Bearer {}",
) access_token.ok_or(
)? #ruma_api::error::IntoHttpError::NeedsAuthentication
); )?
}); )
)?
);
});
}
} }
let extract_request_headers = if api.request.has_header_fields() { let extract_request_headers = if api.request.has_header_fields() {
@ -227,19 +255,29 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
let error = &api.error_ty; let error = &api.error_ty;
let request_lifetimes = api.request.combine_lifetimes(); let request_lifetimes = api.request.combine_lifetimes();
let non_auth_endpoint_impls = if authentication != "None" { let non_auth_endpoint_impls: TokenStream = api
TokenStream::new() .metadata
} else { .authentication
quote! { .iter()
#[automatically_derived] .map(|auth| {
impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest if auth.value != "None" {
for Request #request_lifetimes TokenStream::new()
{} } else {
let attrs = &auth.attrs;
quote! {
#( #attrs )*
#[automatically_derived]
impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest
for Request #request_lifetimes
{}
#[automatically_derived] #( #attrs )*
impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {} #[automatically_derived]
} impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {}
}; }
}
})
.collect();
Ok(quote! { Ok(quote! {
#[doc = #request_doc] #[doc = #request_doc]
@ -301,8 +339,8 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
method: #http::Method::#method, method: #http::Method::#method,
name: #name, name: #name,
path: #path, path: #path,
rate_limited: #rate_limited, #rate_limited
authentication: #ruma_api::AuthScheme::#authentication, #authentication
}; };
#[automatically_derived] #[automatically_derived]

View File

@ -4,7 +4,7 @@ use quote::ToTokens;
use syn::{ use syn::{
braced, braced,
parse::{Parse, ParseStream}, parse::{Parse, ParseStream},
Ident, LitBool, LitStr, Token, Attribute, Ident, LitBool, LitStr, Token,
}; };
use crate::util; use crate::util;
@ -19,6 +19,15 @@ mod kw {
syn::custom_keyword!(authentication); syn::custom_keyword!(authentication);
} }
/// A field of Metadata that contains attribute macros
pub struct MetadataField<T> {
/// attributes over the field
pub attrs: Vec<Attribute>,
/// the field itself
pub value: T,
}
/// The result of processing the `metadata` section of the macro. /// The result of processing the `metadata` section of the macro.
pub struct Metadata { pub struct Metadata {
/// The description field. /// The description field.
@ -34,10 +43,10 @@ pub struct Metadata {
pub path: LitStr, pub path: LitStr,
/// The rate_limited field. /// The rate_limited field.
pub rate_limited: LitBool, pub rate_limited: Vec<MetadataField<LitBool>>,
/// The authentication field. /// The authentication field.
pub authentication: Ident, pub authentication: Vec<MetadataField<Ident>>,
} }
fn set_field<T: ToTokens>(field: &mut Option<T>, value: T) -> syn::Result<()> { fn set_field<T: ToTokens>(field: &mut Option<T>, value: T) -> syn::Result<()> {
@ -69,8 +78,8 @@ impl Parse for Metadata {
let mut method = None; let mut method = None;
let mut name = None; let mut name = None;
let mut path = None; let mut path = None;
let mut rate_limited = None; let mut rate_limited = vec![];
let mut authentication = None; let mut authentication = vec![];
for field_value in field_values { for field_value in field_values {
match field_value { match field_value {
@ -78,8 +87,12 @@ impl Parse for Metadata {
FieldValue::Method(m) => set_field(&mut method, m)?, FieldValue::Method(m) => set_field(&mut method, m)?,
FieldValue::Name(n) => set_field(&mut name, n)?, FieldValue::Name(n) => set_field(&mut name, n)?,
FieldValue::Path(p) => set_field(&mut path, p)?, FieldValue::Path(p) => set_field(&mut path, p)?,
FieldValue::RateLimited(rl) => set_field(&mut rate_limited, rl)?, FieldValue::RateLimited(value, attrs) => {
FieldValue::Authentication(a) => set_field(&mut authentication, a)?, rate_limited.push(MetadataField { value, attrs })
}
FieldValue::Authentication(value, attrs) => {
authentication.push(MetadataField { value, attrs })
}
} }
} }
@ -91,8 +104,16 @@ impl Parse for Metadata {
method: method.ok_or_else(|| missing_field("method"))?, method: method.ok_or_else(|| missing_field("method"))?,
name: name.ok_or_else(|| missing_field("name"))?, name: name.ok_or_else(|| missing_field("name"))?,
path: path.ok_or_else(|| missing_field("path"))?, path: path.ok_or_else(|| missing_field("path"))?,
rate_limited: rate_limited.ok_or_else(|| missing_field("rate_limited"))?, rate_limited: if rate_limited.is_empty() {
authentication: authentication.ok_or_else(|| missing_field("authentication"))?, return Err(missing_field("rate_limited"));
} else {
rate_limited
},
authentication: if authentication.is_empty() {
return Err(missing_field("authentication"));
} else {
authentication
},
}) })
} }
} }
@ -139,12 +160,21 @@ enum FieldValue {
Method(Ident), Method(Ident),
Name(LitStr), Name(LitStr),
Path(LitStr), Path(LitStr),
RateLimited(LitBool), RateLimited(LitBool, Vec<Attribute>),
Authentication(Ident), Authentication(Ident, Vec<Attribute>),
} }
impl Parse for FieldValue { impl Parse for FieldValue {
fn parse(input: ParseStream) -> syn::Result<Self> { fn parse(input: ParseStream) -> syn::Result<Self> {
let attrs: Vec<Attribute> = input.call(Attribute::parse_outer)?;
for attr in attrs.iter() {
if !util::is_cfg_attribute(attr) {
return Err(syn::Error::new_spanned(
&attr,
"only `cfg` attributes may appear here",
));
}
}
let field: Field = input.parse()?; let field: Field = input.parse()?;
let _: Token![:] = input.parse()?; let _: Token![:] = input.parse()?;
@ -164,8 +194,8 @@ impl Parse for FieldValue {
Self::Path(path) Self::Path(path)
} }
Field::RateLimited => Self::RateLimited(input.parse()?), Field::RateLimited => Self::RateLimited(input.parse()?, attrs),
Field::Authentication => Self::Authentication(input.parse()?), Field::Authentication => Self::Authentication(input.parse()?, attrs),
}) })
} }
} }

View File

@ -5,7 +5,7 @@ use proc_macro_crate::crate_name;
use quote::quote; use quote::quote;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use syn::{ use syn::{
AngleBracketedGenericArguments, GenericArgument, Ident, Lifetime, AngleBracketedGenericArguments, AttrStyle, Attribute, GenericArgument, Ident, Lifetime,
ParenthesizedGenericArguments, PathArguments, Type, TypeArray, TypeBareFn, TypeGroup, ParenthesizedGenericArguments, PathArguments, Type, TypeArray, TypeBareFn, TypeGroup,
TypeParen, TypePath, TypePtr, TypeReference, TypeSlice, TypeTuple, TypeParen, TypePath, TypePtr, TypeReference, TypeSlice, TypeTuple,
}; };
@ -402,3 +402,7 @@ pub fn import_ruma_api() -> TokenStream {
quote! { ::ruma_api } quote! { ::ruma_api }
} }
} }
pub(crate) fn is_cfg_attribute(attr: &Attribute) -> bool {
attr.style == AttrStyle::Outer && attr.path.is_ident("cfg")
}

View File

@ -9,7 +9,15 @@ pub mod some_endpoint {
method: POST, // An `http::Method` constant. No imports required. method: POST, // An `http::Method` constant. No imports required.
name: "some_endpoint", name: "some_endpoint",
path: "/_matrix/some/endpoint/:baz", path: "/_matrix/some/endpoint/:baz",
#[cfg(all())]
rate_limited: true,
#[cfg(any())]
rate_limited: false, rate_limited: false,
#[cfg(all())]
authentication: AccessToken,
#[cfg(any())]
authentication: None, authentication: None,
} }