diff --git a/crates/ruma-common/src/api/metadata.rs b/crates/ruma-common/src/api/metadata.rs index 7b599adc..bc46a162 100644 --- a/crates/ruma-common/src/api/metadata.rs +++ b/crates/ruma-common/src/api/metadata.rs @@ -4,13 +4,16 @@ use std::{ }; use bytes::BufMut; -use http::Method; +use http::{ + header::{self, HeaderName, HeaderValue}, + Method, +}; use percent_encoding::utf8_percent_encode; use tracing::warn; use super::{ error::{IntoHttpError, UnknownVersionError}, - AuthScheme, + AuthScheme, SendAccessToken, }; use crate::{serde::slice_to_buf, RoomVersionId}; @@ -53,6 +56,33 @@ impl Metadata { } } + /// Transform the `SendAccessToken` into an access token if the endpoint requires it, or if it + /// is `SendAccessToken::Force`. + /// + /// Fails if the endpoint requires an access token but the parameter is `SendAccessToken::None`, + /// or if the access token can't be converted to a [`HeaderValue`]. + pub fn authorization_header( + &self, + access_token: SendAccessToken<'_>, + ) -> Result, IntoHttpError> { + Ok(match self.authentication { + AuthScheme::None => match access_token.get_not_required_for_endpoint() { + Some(token) => Some((header::AUTHORIZATION, format!("Bearer {token}").try_into()?)), + None => None, + }, + + AuthScheme::AccessToken => { + let token = access_token + .get_required_for_endpoint() + .ok_or(IntoHttpError::NeedsAuthentication)?; + + Some((header::AUTHORIZATION, format!("Bearer {token}").try_into()?)) + } + + AuthScheme::ServerSignatures => None, + }) + } + /// Generate the endpoint URL for this endpoint. pub fn make_endpoint_url( &self, diff --git a/crates/ruma-macros/src/api/api_request.rs b/crates/ruma-macros/src/api/api_request.rs index 60754146..8618c5b9 100644 --- a/crates/ruma-macros/src/api/api_request.rs +++ b/crates/ruma-macros/src/api/api_request.rs @@ -77,8 +77,6 @@ impl Request { ); let struct_attributes = &self.attributes; - let authentication = &metadata.authentication; - let request_ident = Ident::new("Request", self.request_kw.span()); let lifetimes = self.all_lifetimes(); let lifetimes = lifetimes.iter().map(|(lt, attr)| quote! { #attr #lt }); @@ -95,7 +93,6 @@ impl Request { )] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] #[incoming_derive(!Deserialize, #ruma_macros::_FakeDeriveRumaApi)] - #[ruma_api(authentication = #authentication)] #( #struct_attributes )* pub struct #request_ident < #(#lifetimes),* > { #fields diff --git a/crates/ruma-macros/src/api/attribute.rs b/crates/ruma-macros/src/api/attribute.rs index 1e6ded1f..10c9915f 100644 --- a/crates/ruma-macros/src/api/attribute.rs +++ b/crates/ruma-macros/src/api/attribute.rs @@ -2,7 +2,7 @@ use syn::{ parse::{Parse, ParseStream}, - Ident, Token, Type, + Ident, Token, }; mod kw { @@ -12,7 +12,6 @@ mod kw { syn::custom_keyword!(query); syn::custom_keyword!(query_map); syn::custom_keyword!(header); - syn::custom_keyword!(authentication); syn::custom_keyword!(manual_body_serde); } @@ -53,23 +52,6 @@ impl Parse for RequestMeta { } } -pub enum DeriveRequestMeta { - Authentication(Type), -} - -impl Parse for DeriveRequestMeta { - fn parse(input: ParseStream<'_>) -> syn::Result { - let lookahead = input.lookahead1(); - if lookahead.peek(kw::authentication) { - let _: kw::authentication = input.parse()?; - let _: Token![=] = input.parse()?; - input.parse().map(Self::Authentication) - } else { - Err(lookahead.error()) - } - } -} - pub enum ResponseMeta { NewtypeBody, RawBody, diff --git a/crates/ruma-macros/src/api/request.rs b/crates/ruma-macros/src/api/request.rs index ce62292f..96b87b7d 100644 --- a/crates/ruma-macros/src/api/request.rs +++ b/crates/ruma-macros/src/api/request.rs @@ -4,16 +4,10 @@ use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{ parse::{Parse, ParseStream}, - parse_quote, - punctuated::Punctuated, - DeriveInput, Field, Generics, Ident, Lifetime, Token, + DeriveInput, Field, Generics, Ident, Lifetime, }; -use super::{ - attribute::{DeriveRequestMeta, RequestMeta}, - auth_scheme::AuthScheme, - util::collect_lifetime_idents, -}; +use super::{attribute::RequestMeta, util::collect_lifetime_idents}; use crate::util::import_ruma_common; mod incoming; @@ -46,29 +40,7 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result { }) .collect::>()?; - let mut authentication = None; - - for attr in input.attrs { - if !attr.path.is_ident("ruma_api") { - continue; - } - - let metas = - attr.parse_args_with(Punctuated::::parse_terminated)?; - for meta in metas { - match meta { - DeriveRequestMeta::Authentication(t) => authentication = Some(parse_quote!(#t)), - } - } - } - - let request = Request { - ident: input.ident, - generics: input.generics, - fields, - lifetimes, - authentication: authentication.expect("missing authentication attribute"), - }; + let request = Request { ident: input.ident, generics: input.generics, fields, lifetimes }; let ruma_common = import_ruma_common(); let test = request.check(&ruma_common)?; @@ -93,8 +65,6 @@ struct Request { generics: Generics, lifetimes: RequestLifetimes, fields: Vec, - - authentication: AuthScheme, } impl Request { diff --git a/crates/ruma-macros/src/api/request/outgoing.rs b/crates/ruma-macros/src/api/request/outgoing.rs index 93ef47aa..0e5361ea 100644 --- a/crates/ruma-macros/src/api/request/outgoing.rs +++ b/crates/ruma-macros/src/api/request/outgoing.rs @@ -3,7 +3,6 @@ use quote::quote; use syn::Field; use super::{Request, RequestField}; -use crate::api::auth_scheme::AuthScheme; impl Request { pub fn expand_outgoing(&self, ruma_common: &TokenStream) -> TokenStream { @@ -100,29 +99,8 @@ impl Request { } })); - header_kvs.extend(match self.authentication { - AuthScheme::AccessToken(_) => quote! { - req_headers.insert( - #http::header::AUTHORIZATION, - ::std::convert::TryFrom::<_>::try_from(::std::format!( - "Bearer {}", - access_token - .get_required_for_endpoint() - .ok_or(#ruma_common::api::error::IntoHttpError::NeedsAuthentication)?, - ))?, - ); - }, - AuthScheme::None(_) => quote! { - if let Some(access_token) = access_token.get_not_required_for_endpoint() { - req_headers.insert( - #http::header::AUTHORIZATION, - ::std::convert::TryFrom::<_>::try_from( - ::std::format!("Bearer {}", access_token), - )? - ); - } - }, - AuthScheme::ServerSignatures(_) => quote! {}, + header_kvs.extend(quote! { + req_headers.extend(METADATA.authorization_header(access_token)?); }); let request_body = if let Some(field) = self.raw_body_field() {