api: Remove authentication from Request derive attributes

This commit is contained in:
Jonas Platte 2022-10-22 00:32:20 +02:00
parent c9bd9bf00b
commit dff84efb0c
No known key found for this signature in database
GPG Key ID: 7D261D771D915378
5 changed files with 38 additions and 81 deletions

View File

@ -4,13 +4,16 @@ use std::{
};
use bytes::BufMut;
use http::Method;
use http::{
header::{self, HeaderName, HeaderValue},
Method,
};
use percent_encoding::utf8_percent_encode;
use tracing::warn;
use super::{
error::{IntoHttpError, UnknownVersionError},
AuthScheme,
AuthScheme, SendAccessToken,
};
use crate::{serde::slice_to_buf, RoomVersionId};
@ -53,6 +56,33 @@ impl Metadata {
}
}
/// Transform the `SendAccessToken` into an access token if the endpoint requires it, or if it
/// is `SendAccessToken::Force`.
///
/// Fails if the endpoint requires an access token but the parameter is `SendAccessToken::None`,
/// or if the access token can't be converted to a [`HeaderValue`].
pub fn authorization_header(
&self,
access_token: SendAccessToken<'_>,
) -> Result<Option<(HeaderName, HeaderValue)>, IntoHttpError> {
Ok(match self.authentication {
AuthScheme::None => match access_token.get_not_required_for_endpoint() {
Some(token) => Some((header::AUTHORIZATION, format!("Bearer {token}").try_into()?)),
None => None,
},
AuthScheme::AccessToken => {
let token = access_token
.get_required_for_endpoint()
.ok_or(IntoHttpError::NeedsAuthentication)?;
Some((header::AUTHORIZATION, format!("Bearer {token}").try_into()?))
}
AuthScheme::ServerSignatures => None,
})
}
/// Generate the endpoint URL for this endpoint.
pub fn make_endpoint_url(
&self,

View File

@ -77,8 +77,6 @@ impl Request {
);
let struct_attributes = &self.attributes;
let authentication = &metadata.authentication;
let request_ident = Ident::new("Request", self.request_kw.span());
let lifetimes = self.all_lifetimes();
let lifetimes = lifetimes.iter().map(|(lt, attr)| quote! { #attr #lt });
@ -95,7 +93,6 @@ impl Request {
)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[incoming_derive(!Deserialize, #ruma_macros::_FakeDeriveRumaApi)]
#[ruma_api(authentication = #authentication)]
#( #struct_attributes )*
pub struct #request_ident < #(#lifetimes),* > {
#fields

View File

@ -2,7 +2,7 @@
use syn::{
parse::{Parse, ParseStream},
Ident, Token, Type,
Ident, Token,
};
mod kw {
@ -12,7 +12,6 @@ mod kw {
syn::custom_keyword!(query);
syn::custom_keyword!(query_map);
syn::custom_keyword!(header);
syn::custom_keyword!(authentication);
syn::custom_keyword!(manual_body_serde);
}
@ -53,23 +52,6 @@ impl Parse for RequestMeta {
}
}
pub enum DeriveRequestMeta {
Authentication(Type),
}
impl Parse for DeriveRequestMeta {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(kw::authentication) {
let _: kw::authentication = input.parse()?;
let _: Token![=] = input.parse()?;
input.parse().map(Self::Authentication)
} else {
Err(lookahead.error())
}
}
}
pub enum ResponseMeta {
NewtypeBody,
RawBody,

View File

@ -4,16 +4,10 @@ use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_quote,
punctuated::Punctuated,
DeriveInput, Field, Generics, Ident, Lifetime, Token,
DeriveInput, Field, Generics, Ident, Lifetime,
};
use super::{
attribute::{DeriveRequestMeta, RequestMeta},
auth_scheme::AuthScheme,
util::collect_lifetime_idents,
};
use super::{attribute::RequestMeta, util::collect_lifetime_idents};
use crate::util::import_ruma_common;
mod incoming;
@ -46,29 +40,7 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> {
})
.collect::<syn::Result<_>>()?;
let mut authentication = None;
for attr in input.attrs {
if !attr.path.is_ident("ruma_api") {
continue;
}
let metas =
attr.parse_args_with(Punctuated::<DeriveRequestMeta, Token![,]>::parse_terminated)?;
for meta in metas {
match meta {
DeriveRequestMeta::Authentication(t) => authentication = Some(parse_quote!(#t)),
}
}
}
let request = Request {
ident: input.ident,
generics: input.generics,
fields,
lifetimes,
authentication: authentication.expect("missing authentication attribute"),
};
let request = Request { ident: input.ident, generics: input.generics, fields, lifetimes };
let ruma_common = import_ruma_common();
let test = request.check(&ruma_common)?;
@ -93,8 +65,6 @@ struct Request {
generics: Generics,
lifetimes: RequestLifetimes,
fields: Vec<RequestField>,
authentication: AuthScheme,
}
impl Request {

View File

@ -3,7 +3,6 @@ use quote::quote;
use syn::Field;
use super::{Request, RequestField};
use crate::api::auth_scheme::AuthScheme;
impl Request {
pub fn expand_outgoing(&self, ruma_common: &TokenStream) -> TokenStream {
@ -100,29 +99,8 @@ impl Request {
}
}));
header_kvs.extend(match self.authentication {
AuthScheme::AccessToken(_) => quote! {
req_headers.insert(
#http::header::AUTHORIZATION,
::std::convert::TryFrom::<_>::try_from(::std::format!(
"Bearer {}",
access_token
.get_required_for_endpoint()
.ok_or(#ruma_common::api::error::IntoHttpError::NeedsAuthentication)?,
))?,
);
},
AuthScheme::None(_) => quote! {
if let Some(access_token) = access_token.get_not_required_for_endpoint() {
req_headers.insert(
#http::header::AUTHORIZATION,
::std::convert::TryFrom::<_>::try_from(
::std::format!("Bearer {}", access_token),
)?
);
}
},
AuthScheme::ServerSignatures(_) => quote! {},
header_kvs.extend(quote! {
req_headers.extend(METADATA.authorization_header(access_token)?);
});
let request_body = if let Some(field) = self.raw_body_field() {