api-macros: Refactor access token sending logic

This commit is contained in:
Jonas Platte 2021-04-23 13:22:20 +02:00
parent 527007b957
commit ae26be88c5
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
3 changed files with 55 additions and 32 deletions

View File

@ -133,23 +133,22 @@ impl Request {
for auth in &metadata.authentication { for auth in &metadata.authentication {
let attrs = &auth.attrs; let attrs = &auth.attrs;
let (if_required_expr, none_expr) = if auth.value == "AccessToken" { let hdr_kv = if auth.value == "AccessToken" {
( quote! {
quote! { Some(access_token) }, #( #attrs )*
quote! { return Err(#ruma_api::error::IntoHttpError::NeedsAuthentication) }, req_headers.insert(
) #http::header::AUTHORIZATION,
#http::header::HeaderValue::from_str(&::std::format!(
"Bearer {}",
access_token
.get_required()
.ok_or(#ruma_api::error::IntoHttpError::NeedsAuthentication)?,
))?
);
}
} else { } else {
(quote! { None }, quote! { None }) quote! {
}; if let Some(access_token) = access_token.get_optional() {
header_kvs.extend(quote! {{
let access_token = match access_token {
#ruma_api::SendAccessToken::IfRequired(access_token) => #if_required_expr,
#ruma_api::SendAccessToken::Always(access_token) => Some(access_token),
#ruma_api::SendAccessToken::None => #none_expr,
};
if let Some(access_token) = access_token {
#( #attrs )* #( #attrs )*
req_headers.insert( req_headers.insert(
#http::header::AUTHORIZATION, #http::header::AUTHORIZATION,
@ -158,7 +157,10 @@ impl Request {
)? )?
); );
} }
}}); }
};
header_kvs.extend(hdr_kv);
} }
let request_body = if let Some(field) = self.newtype_raw_body_field() { let request_body = if let Some(field) = self.newtype_raw_body_field() {

View File

@ -213,7 +213,7 @@ pub mod exports {
use error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError}; use error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError};
/// An enum to control whether an access token should be added to outgoing requests /// An enum to control whether an access token should be added to outgoing requests
#[derive(Debug, Clone)] #[derive(Clone, Copy, Debug)]
pub enum SendAccessToken<'a> { pub enum SendAccessToken<'a> {
/// Add the given access token to the request only if the `METADATA` on the request requires it /// Add the given access token to the request only if the `METADATA` on the request requires it
IfRequired(&'a str), IfRequired(&'a str),
@ -226,6 +226,28 @@ pub enum SendAccessToken<'a> {
None, None,
} }
impl<'a> SendAccessToken<'a> {
/// Get the access token for an endpoint that should not require one.
///
/// Returns `Some(_)` only if `self` is `SendAccessToken::Always(_)`.
pub fn get_optional(self) -> Option<&'a str> {
match self {
Self::Always(tok) => Some(tok),
Self::IfRequired(_) | Self::None => None,
}
}
/// Get the access token for an endpoint that requires one.
///
/// Returns `Some(_)` if `self` contains an access token.
pub fn get_required(self) -> Option<&'a str> {
match self {
Self::IfRequired(tok) | Self::Always(tok) => Some(tok),
Self::None => None,
}
}
}
/// A request type for a Matrix API endpoint, used for sending requests. /// A request type for a Matrix API endpoint, used for sending requests.
pub trait OutgoingRequest: Sized { pub trait OutgoingRequest: Sized {
/// A type capturing the expected error conditions the server can return. /// A type capturing the expected error conditions the server can return.

View File

@ -87,20 +87,19 @@ impl<'a> ruma_api::OutgoingRequest for Request<'a> {
url.push_str(&Cow::from(utf8_percent_encode(&self.state_key, NON_ALPHANUMERIC))); url.push_str(&Cow::from(utf8_percent_encode(&self.state_key, NON_ALPHANUMERIC)));
} }
let access_token = match access_token {
SendAccessToken::IfRequired(access_token) | SendAccessToken::Always(access_token) => {
access_token
}
SendAccessToken::None => {
return Err(ruma_api::error::IntoHttpError::NeedsAuthentication)
}
};
http::Request::builder() http::Request::builder()
.method(http::Method::GET) .method(http::Method::GET)
.uri(url) .uri(url)
.header(CONTENT_TYPE, "application/json") .header(CONTENT_TYPE, "application/json")
.header(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", access_token))?) .header(
AUTHORIZATION,
HeaderValue::from_str(&format!(
"Bearer {}",
access_token
.get_required()
.ok_or(ruma_api::error::IntoHttpError::NeedsAuthentication)?,
))?,
)
.body(Vec::new()) .body(Vec::new())
.map_err(Into::into) .map_err(Into::into)
} }