api: Remove authentication from Request derive attributes

This commit is contained in:
Jonas Platte 2022-10-22 00:32:20 +02:00
parent c9bd9bf00b
commit dff84efb0c
No known key found for this signature in database
GPG Key ID: 7D261D771D915378
5 changed files with 38 additions and 81 deletions

View File

@ -4,13 +4,16 @@ use std::{
}; };
use bytes::BufMut; use bytes::BufMut;
use http::Method; use http::{
header::{self, HeaderName, HeaderValue},
Method,
};
use percent_encoding::utf8_percent_encode; use percent_encoding::utf8_percent_encode;
use tracing::warn; use tracing::warn;
use super::{ use super::{
error::{IntoHttpError, UnknownVersionError}, error::{IntoHttpError, UnknownVersionError},
AuthScheme, AuthScheme, SendAccessToken,
}; };
use crate::{serde::slice_to_buf, RoomVersionId}; use crate::{serde::slice_to_buf, RoomVersionId};
@ -53,6 +56,33 @@ impl Metadata {
} }
} }
/// Transform the `SendAccessToken` into an access token if the endpoint requires it, or if it
/// is `SendAccessToken::Force`.
///
/// Fails if the endpoint requires an access token but the parameter is `SendAccessToken::None`,
/// or if the access token can't be converted to a [`HeaderValue`].
pub fn authorization_header(
&self,
access_token: SendAccessToken<'_>,
) -> Result<Option<(HeaderName, HeaderValue)>, IntoHttpError> {
Ok(match self.authentication {
AuthScheme::None => match access_token.get_not_required_for_endpoint() {
Some(token) => Some((header::AUTHORIZATION, format!("Bearer {token}").try_into()?)),
None => None,
},
AuthScheme::AccessToken => {
let token = access_token
.get_required_for_endpoint()
.ok_or(IntoHttpError::NeedsAuthentication)?;
Some((header::AUTHORIZATION, format!("Bearer {token}").try_into()?))
}
AuthScheme::ServerSignatures => None,
})
}
/// Generate the endpoint URL for this endpoint. /// Generate the endpoint URL for this endpoint.
pub fn make_endpoint_url( pub fn make_endpoint_url(
&self, &self,

View File

@ -77,8 +77,6 @@ impl Request {
); );
let struct_attributes = &self.attributes; let struct_attributes = &self.attributes;
let authentication = &metadata.authentication;
let request_ident = Ident::new("Request", self.request_kw.span()); let request_ident = Ident::new("Request", self.request_kw.span());
let lifetimes = self.all_lifetimes(); let lifetimes = self.all_lifetimes();
let lifetimes = lifetimes.iter().map(|(lt, attr)| quote! { #attr #lt }); let lifetimes = lifetimes.iter().map(|(lt, attr)| quote! { #attr #lt });
@ -95,7 +93,6 @@ impl Request {
)] )]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[incoming_derive(!Deserialize, #ruma_macros::_FakeDeriveRumaApi)] #[incoming_derive(!Deserialize, #ruma_macros::_FakeDeriveRumaApi)]
#[ruma_api(authentication = #authentication)]
#( #struct_attributes )* #( #struct_attributes )*
pub struct #request_ident < #(#lifetimes),* > { pub struct #request_ident < #(#lifetimes),* > {
#fields #fields

View File

@ -2,7 +2,7 @@
use syn::{ use syn::{
parse::{Parse, ParseStream}, parse::{Parse, ParseStream},
Ident, Token, Type, Ident, Token,
}; };
mod kw { mod kw {
@ -12,7 +12,6 @@ mod kw {
syn::custom_keyword!(query); syn::custom_keyword!(query);
syn::custom_keyword!(query_map); syn::custom_keyword!(query_map);
syn::custom_keyword!(header); syn::custom_keyword!(header);
syn::custom_keyword!(authentication);
syn::custom_keyword!(manual_body_serde); syn::custom_keyword!(manual_body_serde);
} }
@ -53,23 +52,6 @@ impl Parse for RequestMeta {
} }
} }
pub enum DeriveRequestMeta {
Authentication(Type),
}
impl Parse for DeriveRequestMeta {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(kw::authentication) {
let _: kw::authentication = input.parse()?;
let _: Token![=] = input.parse()?;
input.parse().map(Self::Authentication)
} else {
Err(lookahead.error())
}
}
}
pub enum ResponseMeta { pub enum ResponseMeta {
NewtypeBody, NewtypeBody,
RawBody, RawBody,

View File

@ -4,16 +4,10 @@ use proc_macro2::TokenStream;
use quote::{quote, ToTokens}; use quote::{quote, ToTokens};
use syn::{ use syn::{
parse::{Parse, ParseStream}, parse::{Parse, ParseStream},
parse_quote, DeriveInput, Field, Generics, Ident, Lifetime,
punctuated::Punctuated,
DeriveInput, Field, Generics, Ident, Lifetime, Token,
}; };
use super::{ use super::{attribute::RequestMeta, util::collect_lifetime_idents};
attribute::{DeriveRequestMeta, RequestMeta},
auth_scheme::AuthScheme,
util::collect_lifetime_idents,
};
use crate::util::import_ruma_common; use crate::util::import_ruma_common;
mod incoming; mod incoming;
@ -46,29 +40,7 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> {
}) })
.collect::<syn::Result<_>>()?; .collect::<syn::Result<_>>()?;
let mut authentication = None; let request = Request { ident: input.ident, generics: input.generics, fields, lifetimes };
for attr in input.attrs {
if !attr.path.is_ident("ruma_api") {
continue;
}
let metas =
attr.parse_args_with(Punctuated::<DeriveRequestMeta, Token![,]>::parse_terminated)?;
for meta in metas {
match meta {
DeriveRequestMeta::Authentication(t) => authentication = Some(parse_quote!(#t)),
}
}
}
let request = Request {
ident: input.ident,
generics: input.generics,
fields,
lifetimes,
authentication: authentication.expect("missing authentication attribute"),
};
let ruma_common = import_ruma_common(); let ruma_common = import_ruma_common();
let test = request.check(&ruma_common)?; let test = request.check(&ruma_common)?;
@ -93,8 +65,6 @@ struct Request {
generics: Generics, generics: Generics,
lifetimes: RequestLifetimes, lifetimes: RequestLifetimes,
fields: Vec<RequestField>, fields: Vec<RequestField>,
authentication: AuthScheme,
} }
impl Request { impl Request {

View File

@ -3,7 +3,6 @@ use quote::quote;
use syn::Field; use syn::Field;
use super::{Request, RequestField}; use super::{Request, RequestField};
use crate::api::auth_scheme::AuthScheme;
impl Request { impl Request {
pub fn expand_outgoing(&self, ruma_common: &TokenStream) -> TokenStream { pub fn expand_outgoing(&self, ruma_common: &TokenStream) -> TokenStream {
@ -100,29 +99,8 @@ impl Request {
} }
})); }));
header_kvs.extend(match self.authentication { header_kvs.extend(quote! {
AuthScheme::AccessToken(_) => quote! { req_headers.extend(METADATA.authorization_header(access_token)?);
req_headers.insert(
#http::header::AUTHORIZATION,
::std::convert::TryFrom::<_>::try_from(::std::format!(
"Bearer {}",
access_token
.get_required_for_endpoint()
.ok_or(#ruma_common::api::error::IntoHttpError::NeedsAuthentication)?,
))?,
);
},
AuthScheme::None(_) => quote! {
if let Some(access_token) = access_token.get_not_required_for_endpoint() {
req_headers.insert(
#http::header::AUTHORIZATION,
::std::convert::TryFrom::<_>::try_from(
::std::format!("Bearer {}", access_token),
)?
);
}
},
AuthScheme::ServerSignatures(_) => quote! {},
}); });
let request_body = if let Some(field) = self.raw_body_field() { let request_body = if let Some(field) = self.raw_body_field() {