api: Allow cfg attributes on rate_limited and authentication metadata fields
This commit is contained in:
parent
d8c5c326e6
commit
1c0dab5a47
@ -83,8 +83,32 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
|
||||
// with only the literal's value from here on.
|
||||
let name = &api.metadata.name.value();
|
||||
let path = &api.metadata.path;
|
||||
let rate_limited = &api.metadata.rate_limited;
|
||||
let authentication = &api.metadata.authentication;
|
||||
let rate_limited: TokenStream = api
|
||||
.metadata
|
||||
.rate_limited
|
||||
.iter()
|
||||
.map(|r| {
|
||||
let attrs = &r.attrs;
|
||||
let value = &r.value;
|
||||
quote! {
|
||||
#( #attrs )*
|
||||
rate_limited: #value,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let authentication: TokenStream = api
|
||||
.metadata
|
||||
.authentication
|
||||
.iter()
|
||||
.map(|r| {
|
||||
let attrs = &r.attrs;
|
||||
let value = &r.value;
|
||||
quote! {
|
||||
#( #attrs )*
|
||||
authentication: #ruma_api::AuthScheme::#value,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let request_type = &api.request;
|
||||
let response_type = &api.response;
|
||||
@ -118,20 +142,24 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
|
||||
};
|
||||
|
||||
let mut header_kvs = api.request.append_header_kvs();
|
||||
if authentication == "AccessToken" {
|
||||
header_kvs.extend(quote! {
|
||||
req_headers.insert(
|
||||
#http::header::AUTHORIZATION,
|
||||
#http::header::HeaderValue::from_str(
|
||||
&::std::format!(
|
||||
"Bearer {}",
|
||||
access_token.ok_or(
|
||||
#ruma_api::error::IntoHttpError::NeedsAuthentication
|
||||
)?
|
||||
)
|
||||
)?
|
||||
);
|
||||
});
|
||||
for auth in &api.metadata.authentication {
|
||||
if auth.value == "AccessToken" {
|
||||
let attrs = &auth.attrs;
|
||||
header_kvs.extend(quote! {
|
||||
#( #attrs )*
|
||||
req_headers.insert(
|
||||
#http::header::AUTHORIZATION,
|
||||
#http::header::HeaderValue::from_str(
|
||||
&::std::format!(
|
||||
"Bearer {}",
|
||||
access_token.ok_or(
|
||||
#ruma_api::error::IntoHttpError::NeedsAuthentication
|
||||
)?
|
||||
)
|
||||
)?
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let extract_request_headers = if api.request.has_header_fields() {
|
||||
@ -227,19 +255,29 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
|
||||
let error = &api.error_ty;
|
||||
let request_lifetimes = api.request.combine_lifetimes();
|
||||
|
||||
let non_auth_endpoint_impls = if authentication != "None" {
|
||||
TokenStream::new()
|
||||
} else {
|
||||
quote! {
|
||||
#[automatically_derived]
|
||||
impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest
|
||||
for Request #request_lifetimes
|
||||
{}
|
||||
let non_auth_endpoint_impls: TokenStream = api
|
||||
.metadata
|
||||
.authentication
|
||||
.iter()
|
||||
.map(|auth| {
|
||||
if auth.value != "None" {
|
||||
TokenStream::new()
|
||||
} else {
|
||||
let attrs = &auth.attrs;
|
||||
quote! {
|
||||
#( #attrs )*
|
||||
#[automatically_derived]
|
||||
impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest
|
||||
for Request #request_lifetimes
|
||||
{}
|
||||
|
||||
#[automatically_derived]
|
||||
impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {}
|
||||
}
|
||||
};
|
||||
#( #attrs )*
|
||||
#[automatically_derived]
|
||||
impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(quote! {
|
||||
#[doc = #request_doc]
|
||||
@ -301,8 +339,8 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
|
||||
method: #http::Method::#method,
|
||||
name: #name,
|
||||
path: #path,
|
||||
rate_limited: #rate_limited,
|
||||
authentication: #ruma_api::AuthScheme::#authentication,
|
||||
#rate_limited
|
||||
#authentication
|
||||
};
|
||||
|
||||
#[automatically_derived]
|
||||
|
@ -4,7 +4,7 @@ use quote::ToTokens;
|
||||
use syn::{
|
||||
braced,
|
||||
parse::{Parse, ParseStream},
|
||||
Ident, LitBool, LitStr, Token,
|
||||
Attribute, Ident, LitBool, LitStr, Token,
|
||||
};
|
||||
|
||||
use crate::util;
|
||||
@ -19,6 +19,15 @@ mod kw {
|
||||
syn::custom_keyword!(authentication);
|
||||
}
|
||||
|
||||
/// A field of Metadata that contains attribute macros
|
||||
pub struct MetadataField<T> {
|
||||
/// attributes over the field
|
||||
pub attrs: Vec<Attribute>,
|
||||
|
||||
/// the field itself
|
||||
pub value: T,
|
||||
}
|
||||
|
||||
/// The result of processing the `metadata` section of the macro.
|
||||
pub struct Metadata {
|
||||
/// The description field.
|
||||
@ -34,10 +43,10 @@ pub struct Metadata {
|
||||
pub path: LitStr,
|
||||
|
||||
/// The rate_limited field.
|
||||
pub rate_limited: LitBool,
|
||||
pub rate_limited: Vec<MetadataField<LitBool>>,
|
||||
|
||||
/// The authentication field.
|
||||
pub authentication: Ident,
|
||||
pub authentication: Vec<MetadataField<Ident>>,
|
||||
}
|
||||
|
||||
fn set_field<T: ToTokens>(field: &mut Option<T>, value: T) -> syn::Result<()> {
|
||||
@ -69,8 +78,8 @@ impl Parse for Metadata {
|
||||
let mut method = None;
|
||||
let mut name = None;
|
||||
let mut path = None;
|
||||
let mut rate_limited = None;
|
||||
let mut authentication = None;
|
||||
let mut rate_limited = vec![];
|
||||
let mut authentication = vec![];
|
||||
|
||||
for field_value in field_values {
|
||||
match field_value {
|
||||
@ -78,8 +87,12 @@ impl Parse for Metadata {
|
||||
FieldValue::Method(m) => set_field(&mut method, m)?,
|
||||
FieldValue::Name(n) => set_field(&mut name, n)?,
|
||||
FieldValue::Path(p) => set_field(&mut path, p)?,
|
||||
FieldValue::RateLimited(rl) => set_field(&mut rate_limited, rl)?,
|
||||
FieldValue::Authentication(a) => set_field(&mut authentication, a)?,
|
||||
FieldValue::RateLimited(value, attrs) => {
|
||||
rate_limited.push(MetadataField { value, attrs })
|
||||
}
|
||||
FieldValue::Authentication(value, attrs) => {
|
||||
authentication.push(MetadataField { value, attrs })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,8 +104,16 @@ impl Parse for Metadata {
|
||||
method: method.ok_or_else(|| missing_field("method"))?,
|
||||
name: name.ok_or_else(|| missing_field("name"))?,
|
||||
path: path.ok_or_else(|| missing_field("path"))?,
|
||||
rate_limited: rate_limited.ok_or_else(|| missing_field("rate_limited"))?,
|
||||
authentication: authentication.ok_or_else(|| missing_field("authentication"))?,
|
||||
rate_limited: if rate_limited.is_empty() {
|
||||
return Err(missing_field("rate_limited"));
|
||||
} else {
|
||||
rate_limited
|
||||
},
|
||||
authentication: if authentication.is_empty() {
|
||||
return Err(missing_field("authentication"));
|
||||
} else {
|
||||
authentication
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -139,12 +160,21 @@ enum FieldValue {
|
||||
Method(Ident),
|
||||
Name(LitStr),
|
||||
Path(LitStr),
|
||||
RateLimited(LitBool),
|
||||
Authentication(Ident),
|
||||
RateLimited(LitBool, Vec<Attribute>),
|
||||
Authentication(Ident, Vec<Attribute>),
|
||||
}
|
||||
|
||||
impl Parse for FieldValue {
|
||||
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||
let attrs: Vec<Attribute> = input.call(Attribute::parse_outer)?;
|
||||
for attr in attrs.iter() {
|
||||
if !util::is_cfg_attribute(attr) {
|
||||
return Err(syn::Error::new_spanned(
|
||||
&attr,
|
||||
"only `cfg` attributes may appear here",
|
||||
));
|
||||
}
|
||||
}
|
||||
let field: Field = input.parse()?;
|
||||
let _: Token![:] = input.parse()?;
|
||||
|
||||
@ -164,8 +194,8 @@ impl Parse for FieldValue {
|
||||
|
||||
Self::Path(path)
|
||||
}
|
||||
Field::RateLimited => Self::RateLimited(input.parse()?),
|
||||
Field::Authentication => Self::Authentication(input.parse()?),
|
||||
Field::RateLimited => Self::RateLimited(input.parse()?, attrs),
|
||||
Field::Authentication => Self::Authentication(input.parse()?, attrs),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ use proc_macro_crate::crate_name;
|
||||
use quote::quote;
|
||||
use std::collections::BTreeSet;
|
||||
use syn::{
|
||||
AngleBracketedGenericArguments, GenericArgument, Ident, Lifetime,
|
||||
AngleBracketedGenericArguments, AttrStyle, Attribute, GenericArgument, Ident, Lifetime,
|
||||
ParenthesizedGenericArguments, PathArguments, Type, TypeArray, TypeBareFn, TypeGroup,
|
||||
TypeParen, TypePath, TypePtr, TypeReference, TypeSlice, TypeTuple,
|
||||
};
|
||||
@ -402,3 +402,7 @@ pub fn import_ruma_api() -> TokenStream {
|
||||
quote! { ::ruma_api }
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_cfg_attribute(attr: &Attribute) -> bool {
|
||||
attr.style == AttrStyle::Outer && attr.path.is_ident("cfg")
|
||||
}
|
||||
|
@ -9,7 +9,15 @@ pub mod some_endpoint {
|
||||
method: POST, // An `http::Method` constant. No imports required.
|
||||
name: "some_endpoint",
|
||||
path: "/_matrix/some/endpoint/:baz",
|
||||
|
||||
#[cfg(all())]
|
||||
rate_limited: true,
|
||||
#[cfg(any())]
|
||||
rate_limited: false,
|
||||
|
||||
#[cfg(all())]
|
||||
authentication: AccessToken,
|
||||
#[cfg(any())]
|
||||
authentication: None,
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user