api: Allow cfg attributes on rate_limited and authentication metadata fields

This commit is contained in:
Akshay 2021-02-05 17:17:43 +05:30 committed by GitHub
parent d8c5c326e6
commit 1c0dab5a47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 124 additions and 44 deletions

View File

@ -83,8 +83,32 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
// with only the literal's value from here on.
let name = &api.metadata.name.value();
let path = &api.metadata.path;
let rate_limited = &api.metadata.rate_limited;
let authentication = &api.metadata.authentication;
let rate_limited: TokenStream = api
.metadata
.rate_limited
.iter()
.map(|r| {
let attrs = &r.attrs;
let value = &r.value;
quote! {
#( #attrs )*
rate_limited: #value,
}
})
.collect();
let authentication: TokenStream = api
.metadata
.authentication
.iter()
.map(|r| {
let attrs = &r.attrs;
let value = &r.value;
quote! {
#( #attrs )*
authentication: #ruma_api::AuthScheme::#value,
}
})
.collect();
let request_type = &api.request;
let response_type = &api.response;
@ -118,20 +142,24 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
};
let mut header_kvs = api.request.append_header_kvs();
if authentication == "AccessToken" {
header_kvs.extend(quote! {
req_headers.insert(
#http::header::AUTHORIZATION,
#http::header::HeaderValue::from_str(
&::std::format!(
"Bearer {}",
access_token.ok_or(
#ruma_api::error::IntoHttpError::NeedsAuthentication
)?
)
)?
);
});
for auth in &api.metadata.authentication {
if auth.value == "AccessToken" {
let attrs = &auth.attrs;
header_kvs.extend(quote! {
#( #attrs )*
req_headers.insert(
#http::header::AUTHORIZATION,
#http::header::HeaderValue::from_str(
&::std::format!(
"Bearer {}",
access_token.ok_or(
#ruma_api::error::IntoHttpError::NeedsAuthentication
)?
)
)?
);
});
}
}
let extract_request_headers = if api.request.has_header_fields() {
@ -227,19 +255,29 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
let error = &api.error_ty;
let request_lifetimes = api.request.combine_lifetimes();
let non_auth_endpoint_impls = if authentication != "None" {
TokenStream::new()
} else {
quote! {
#[automatically_derived]
impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest
for Request #request_lifetimes
{}
let non_auth_endpoint_impls: TokenStream = api
.metadata
.authentication
.iter()
.map(|auth| {
if auth.value != "None" {
TokenStream::new()
} else {
let attrs = &auth.attrs;
quote! {
#( #attrs )*
#[automatically_derived]
impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest
for Request #request_lifetimes
{}
#[automatically_derived]
impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {}
}
};
#( #attrs )*
#[automatically_derived]
impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {}
}
}
})
.collect();
Ok(quote! {
#[doc = #request_doc]
@ -301,8 +339,8 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
method: #http::Method::#method,
name: #name,
path: #path,
rate_limited: #rate_limited,
authentication: #ruma_api::AuthScheme::#authentication,
#rate_limited
#authentication
};
#[automatically_derived]

View File

@ -4,7 +4,7 @@ use quote::ToTokens;
use syn::{
braced,
parse::{Parse, ParseStream},
Ident, LitBool, LitStr, Token,
Attribute, Ident, LitBool, LitStr, Token,
};
use crate::util;
@ -19,6 +19,15 @@ mod kw {
syn::custom_keyword!(authentication);
}
/// A field of Metadata that contains attribute macros
pub struct MetadataField<T> {
/// attributes over the field
pub attrs: Vec<Attribute>,
/// the field itself
pub value: T,
}
/// The result of processing the `metadata` section of the macro.
pub struct Metadata {
/// The description field.
@ -34,10 +43,10 @@ pub struct Metadata {
pub path: LitStr,
/// The rate_limited field.
pub rate_limited: LitBool,
pub rate_limited: Vec<MetadataField<LitBool>>,
/// The authentication field.
pub authentication: Ident,
pub authentication: Vec<MetadataField<Ident>>,
}
fn set_field<T: ToTokens>(field: &mut Option<T>, value: T) -> syn::Result<()> {
@ -69,8 +78,8 @@ impl Parse for Metadata {
let mut method = None;
let mut name = None;
let mut path = None;
let mut rate_limited = None;
let mut authentication = None;
let mut rate_limited = vec![];
let mut authentication = vec![];
for field_value in field_values {
match field_value {
@ -78,8 +87,12 @@ impl Parse for Metadata {
FieldValue::Method(m) => set_field(&mut method, m)?,
FieldValue::Name(n) => set_field(&mut name, n)?,
FieldValue::Path(p) => set_field(&mut path, p)?,
FieldValue::RateLimited(rl) => set_field(&mut rate_limited, rl)?,
FieldValue::Authentication(a) => set_field(&mut authentication, a)?,
FieldValue::RateLimited(value, attrs) => {
rate_limited.push(MetadataField { value, attrs })
}
FieldValue::Authentication(value, attrs) => {
authentication.push(MetadataField { value, attrs })
}
}
}
@ -91,8 +104,16 @@ impl Parse for Metadata {
method: method.ok_or_else(|| missing_field("method"))?,
name: name.ok_or_else(|| missing_field("name"))?,
path: path.ok_or_else(|| missing_field("path"))?,
rate_limited: rate_limited.ok_or_else(|| missing_field("rate_limited"))?,
authentication: authentication.ok_or_else(|| missing_field("authentication"))?,
rate_limited: if rate_limited.is_empty() {
return Err(missing_field("rate_limited"));
} else {
rate_limited
},
authentication: if authentication.is_empty() {
return Err(missing_field("authentication"));
} else {
authentication
},
})
}
}
@ -139,12 +160,21 @@ enum FieldValue {
Method(Ident),
Name(LitStr),
Path(LitStr),
RateLimited(LitBool),
Authentication(Ident),
RateLimited(LitBool, Vec<Attribute>),
Authentication(Ident, Vec<Attribute>),
}
impl Parse for FieldValue {
fn parse(input: ParseStream) -> syn::Result<Self> {
let attrs: Vec<Attribute> = input.call(Attribute::parse_outer)?;
for attr in attrs.iter() {
if !util::is_cfg_attribute(attr) {
return Err(syn::Error::new_spanned(
&attr,
"only `cfg` attributes may appear here",
));
}
}
let field: Field = input.parse()?;
let _: Token![:] = input.parse()?;
@ -164,8 +194,8 @@ impl Parse for FieldValue {
Self::Path(path)
}
Field::RateLimited => Self::RateLimited(input.parse()?),
Field::Authentication => Self::Authentication(input.parse()?),
Field::RateLimited => Self::RateLimited(input.parse()?, attrs),
Field::Authentication => Self::Authentication(input.parse()?, attrs),
})
}
}

View File

@ -5,7 +5,7 @@ use proc_macro_crate::crate_name;
use quote::quote;
use std::collections::BTreeSet;
use syn::{
AngleBracketedGenericArguments, GenericArgument, Ident, Lifetime,
AngleBracketedGenericArguments, AttrStyle, Attribute, GenericArgument, Ident, Lifetime,
ParenthesizedGenericArguments, PathArguments, Type, TypeArray, TypeBareFn, TypeGroup,
TypeParen, TypePath, TypePtr, TypeReference, TypeSlice, TypeTuple,
};
@ -402,3 +402,7 @@ pub fn import_ruma_api() -> TokenStream {
quote! { ::ruma_api }
}
}
pub(crate) fn is_cfg_attribute(attr: &Attribute) -> bool {
attr.style == AttrStyle::Outer && attr.path.is_ident("cfg")
}

View File

@ -9,7 +9,15 @@ pub mod some_endpoint {
method: POST, // An `http::Method` constant. No imports required.
name: "some_endpoint",
path: "/_matrix/some/endpoint/:baz",
#[cfg(all())]
rate_limited: true,
#[cfg(any())]
rate_limited: false,
#[cfg(all())]
authentication: AccessToken,
#[cfg(any())]
authentication: None,
}