diff --git a/crates/ruma-api-macros/src/api/metadata.rs b/crates/ruma-api-macros/src/api/metadata.rs index 5fa28d68..d3e07997 100644 --- a/crates/ruma-api-macros/src/api/metadata.rs +++ b/crates/ruma-api-macros/src/api/metadata.rs @@ -1,5 +1,6 @@ //! Details of the `metadata` section of the procedural macro. +use proc_macro2::TokenStream; use quote::ToTokens; use syn::{ braced, @@ -17,6 +18,11 @@ mod kw { syn::custom_keyword!(path); syn::custom_keyword!(rate_limited); syn::custom_keyword!(authentication); + + syn::custom_keyword!(None); + syn::custom_keyword!(AccessToken); + syn::custom_keyword!(ServerSignatures); + syn::custom_keyword!(QueryOnlyAccessToken); } /// A field of Metadata that contains attribute macros @@ -46,7 +52,7 @@ pub struct Metadata { pub rate_limited: Vec>, /// The authentication field. - pub authentication: Vec>, + pub authentication: Vec>, } fn set_field(field: &mut Option, value: T) -> syn::Result<()> { @@ -118,6 +124,43 @@ impl Parse for Metadata { } } +#[derive(PartialEq)] +pub enum AuthScheme { + None(kw::None), + AccessToken(kw::AccessToken), + ServerSignatures(kw::ServerSignatures), + QueryOnlyAccessToken(kw::QueryOnlyAccessToken), +} + +impl Parse for AuthScheme { + fn parse(input: ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + + if lookahead.peek(kw::None) { + input.parse().map(Self::None) + } else if lookahead.peek(kw::AccessToken) { + input.parse().map(Self::AccessToken) + } else if lookahead.peek(kw::ServerSignatures) { + input.parse().map(Self::ServerSignatures) + } else if lookahead.peek(kw::QueryOnlyAccessToken) { + input.parse().map(Self::QueryOnlyAccessToken) + } else { + Err(lookahead.error()) + } + } +} + +impl ToTokens for AuthScheme { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + AuthScheme::None(kw) => kw.to_tokens(tokens), + AuthScheme::AccessToken(kw) => kw.to_tokens(tokens), + AuthScheme::ServerSignatures(kw) => kw.to_tokens(tokens), + AuthScheme::QueryOnlyAccessToken(kw) => kw.to_tokens(tokens), + } + } +} + enum Field { Description, Method, @@ -161,7 +204,7 @@ enum FieldValue { Name(LitStr), Path(LitStr), RateLimited(LitBool, Vec), - Authentication(Ident, Vec), + Authentication(AuthScheme, Vec), } impl Parse for FieldValue { diff --git a/crates/ruma-api-macros/src/api/request/incoming.rs b/crates/ruma-api-macros/src/api/request/incoming.rs index b1d56c34..e0c33c10 100644 --- a/crates/ruma-api-macros/src/api/request/incoming.rs +++ b/crates/ruma-api-macros/src/api/request/incoming.rs @@ -1,7 +1,8 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; -use super::{Metadata, Request, RequestField, RequestFieldKind}; +use super::{Request, RequestField, RequestFieldKind}; +use crate::api::metadata::{AuthScheme, Metadata}; impl Request { pub fn expand_incoming( @@ -193,7 +194,7 @@ impl Request { }; let non_auth_impls = metadata.authentication.iter().filter_map(|auth| { - (auth.value == "None").then(|| { + matches!(auth.value, AuthScheme::None(_)).then(|| { let attrs = &auth.attrs; quote! { #( #attrs )* diff --git a/crates/ruma-api-macros/src/api/request/outgoing.rs b/crates/ruma-api-macros/src/api/request/outgoing.rs index da363302..cfaee60e 100644 --- a/crates/ruma-api-macros/src/api/request/outgoing.rs +++ b/crates/ruma-api-macros/src/api/request/outgoing.rs @@ -1,7 +1,9 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; -use super::{Metadata, Request, RequestField, RequestFieldKind}; +use crate::api::metadata::{AuthScheme, Metadata}; + +use super::{Request, RequestField, RequestFieldKind}; impl Request { pub fn expand_outgoing( @@ -133,8 +135,8 @@ impl Request { for auth in &metadata.authentication { let attrs = &auth.attrs; - let hdr_kv = if auth.value == "AccessToken" { - quote! { + let hdr_kv = match auth.value { + AuthScheme::AccessToken(_) => quote! { #( #attrs )* req_headers.insert( #http::header::AUTHORIZATION, @@ -145,9 +147,8 @@ impl Request { .ok_or(#ruma_api::error::IntoHttpError::NeedsAuthentication)?, ))?, ); - } - } else { - quote! { + }, + AuthScheme::None(_) => quote! { if let Some(access_token) = access_token.get_not_required_for_endpoint() { #( #attrs )* req_headers.insert( @@ -157,7 +158,8 @@ impl Request { )? ); } - } + }, + AuthScheme::QueryOnlyAccessToken(_) | AuthScheme::ServerSignatures(_) => quote! {}, }; header_kvs.extend(hdr_kv); @@ -184,7 +186,7 @@ impl Request { }; let non_auth_impls = metadata.authentication.iter().filter_map(|auth| { - (auth.value == "None").then(|| { + matches!(auth.value, AuthScheme::None(_)).then(|| { let attrs = &auth.attrs; quote! { #( #attrs )*