diff --git a/ruma-api-macros/src/api/request/outgoing.rs b/ruma-api-macros/src/api/request/outgoing.rs index 12289bc5..46aed9b6 100644 --- a/ruma-api-macros/src/api/request/outgoing.rs +++ b/ruma-api-macros/src/api/request/outgoing.rs @@ -133,32 +133,34 @@ impl Request { for auth in &metadata.authentication { let attrs = &auth.attrs; - let (if_required_expr, none_expr) = if auth.value == "AccessToken" { - ( - quote! { Some(access_token) }, - quote! { return Err(#ruma_api::error::IntoHttpError::NeedsAuthentication) }, - ) - } else { - (quote! { None }, quote! { None }) - }; - - header_kvs.extend(quote! {{ - let access_token = match access_token { - #ruma_api::SendAccessToken::IfRequired(access_token) => #if_required_expr, - #ruma_api::SendAccessToken::Always(access_token) => Some(access_token), - #ruma_api::SendAccessToken::None => #none_expr, - }; - - if let Some(access_token) = access_token { + let hdr_kv = if auth.value == "AccessToken" { + quote! { #( #attrs )* req_headers.insert( #http::header::AUTHORIZATION, - #http::header::HeaderValue::from_str( - &::std::format!("Bearer {}", access_token) - )? + #http::header::HeaderValue::from_str(&::std::format!( + "Bearer {}", + access_token + .get_required() + .ok_or(#ruma_api::error::IntoHttpError::NeedsAuthentication)?, + ))? ); } - }}); + } else { + quote! { + if let Some(access_token) = access_token.get_optional() { + #( #attrs )* + req_headers.insert( + #http::header::AUTHORIZATION, + #http::header::HeaderValue::from_str( + &::std::format!("Bearer {}", access_token) + )? + ); + } + } + }; + + header_kvs.extend(hdr_kv); } let request_body = if let Some(field) = self.newtype_raw_body_field() { diff --git a/ruma-api/src/lib.rs b/ruma-api/src/lib.rs index 1d62a9fe..5a474533 100644 --- a/ruma-api/src/lib.rs +++ b/ruma-api/src/lib.rs @@ -213,7 +213,7 @@ pub mod exports { use error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError}; /// An enum to control whether an access token should be added to outgoing requests -#[derive(Debug, Clone)] +#[derive(Clone, Copy, Debug)] pub enum SendAccessToken<'a> { /// Add the given access token to the request only if the `METADATA` on the request requires it IfRequired(&'a str), @@ -226,6 +226,28 @@ pub enum SendAccessToken<'a> { None, } +impl<'a> SendAccessToken<'a> { + /// Get the access token for an endpoint that should not require one. + /// + /// Returns `Some(_)` only if `self` is `SendAccessToken::Always(_)`. + pub fn get_optional(self) -> Option<&'a str> { + match self { + Self::Always(tok) => Some(tok), + Self::IfRequired(_) | Self::None => None, + } + } + + /// Get the access token for an endpoint that requires one. + /// + /// Returns `Some(_)` if `self` contains an access token. + pub fn get_required(self) -> Option<&'a str> { + match self { + Self::IfRequired(tok) | Self::Always(tok) => Some(tok), + Self::None => None, + } + } +} + /// A request type for a Matrix API endpoint, used for sending requests. pub trait OutgoingRequest: Sized { /// A type capturing the expected error conditions the server can return. diff --git a/ruma-client-api/src/r0/state/get_state_events_for_key.rs b/ruma-client-api/src/r0/state/get_state_events_for_key.rs index 5abaaa12..f4c341db 100644 --- a/ruma-client-api/src/r0/state/get_state_events_for_key.rs +++ b/ruma-client-api/src/r0/state/get_state_events_for_key.rs @@ -87,20 +87,19 @@ impl<'a> ruma_api::OutgoingRequest for Request<'a> { url.push_str(&Cow::from(utf8_percent_encode(&self.state_key, NON_ALPHANUMERIC))); } - let access_token = match access_token { - SendAccessToken::IfRequired(access_token) | SendAccessToken::Always(access_token) => { - access_token - } - SendAccessToken::None => { - return Err(ruma_api::error::IntoHttpError::NeedsAuthentication) - } - }; - http::Request::builder() .method(http::Method::GET) .uri(url) .header(CONTENT_TYPE, "application/json") - .header(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", access_token))?) + .header( + AUTHORIZATION, + HeaderValue::from_str(&format!( + "Bearer {}", + access_token + .get_required() + .ok_or(ruma_api::error::IntoHttpError::NeedsAuthentication)?, + ))?, + ) .body(Vec::new()) .map_err(Into::into) }