api: Remove method from Request derive attributes

This commit is contained in:
Jonas Platte 2022-10-21 23:28:34 +02:00
parent 70c5e84107
commit e7e39a6af1
No known key found for this signature in database
GPG Key ID: 7D261D771D915378
6 changed files with 47 additions and 40 deletions

View File

@ -3,6 +3,7 @@ use std::{
str::FromStr, str::FromStr,
}; };
use bytes::BufMut;
use http::Method; use http::Method;
use percent_encoding::utf8_percent_encode; use percent_encoding::utf8_percent_encode;
use tracing::warn; use tracing::warn;
@ -11,7 +12,7 @@ use super::{
error::{IntoHttpError, UnknownVersionError}, error::{IntoHttpError, UnknownVersionError},
AuthScheme, AuthScheme,
}; };
use crate::RoomVersionId; use crate::{serde::slice_to_buf, RoomVersionId};
/// Metadata about an API endpoint. /// Metadata about an API endpoint.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -37,6 +38,21 @@ pub struct Metadata {
} }
impl Metadata { impl Metadata {
/// Returns an empty request body for this Matrix request.
///
/// For `GET` requests, it returns an entirely empty buffer, for others it returns an empty JSON
/// object (`{}`).
pub fn empty_request_body<B>(&self) -> B
where
B: Default + BufMut,
{
if self.method == Method::GET {
Default::default()
} else {
slice_to_buf(b"{}")
}
}
/// Generate the endpoint URL for this endpoint. /// Generate the endpoint URL for this endpoint.
pub fn make_endpoint_url( pub fn make_endpoint_url(
&self, &self,

View File

@ -82,7 +82,6 @@ impl Request {
); );
let struct_attributes = &self.attributes; let struct_attributes = &self.attributes;
let method = &metadata.method;
let authentication = &metadata.authentication; let authentication = &metadata.authentication;
let request_ident = Ident::new("Request", self.request_kw.span()); let request_ident = Ident::new("Request", self.request_kw.span());
@ -101,11 +100,7 @@ impl Request {
)] )]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[incoming_derive(!Deserialize, #ruma_macros::_FakeDeriveRumaApi)] #[incoming_derive(!Deserialize, #ruma_macros::_FakeDeriveRumaApi)]
#[ruma_api( #[ruma_api(authentication = #authentication, error_ty = #error_ty)]
method = #method,
authentication = #authentication,
error_ty = #error_ty,
)]
#( #struct_attributes )* #( #struct_attributes )*
pub struct #request_ident < #(#lifetimes),* > { pub struct #request_ident < #(#lifetimes),* > {
#fields #fields

View File

@ -13,7 +13,6 @@ mod kw {
syn::custom_keyword!(query_map); syn::custom_keyword!(query_map);
syn::custom_keyword!(header); syn::custom_keyword!(header);
syn::custom_keyword!(authentication); syn::custom_keyword!(authentication);
syn::custom_keyword!(method);
syn::custom_keyword!(error_ty); syn::custom_keyword!(error_ty);
syn::custom_keyword!(manual_body_serde); syn::custom_keyword!(manual_body_serde);
} }
@ -57,7 +56,6 @@ impl Parse for RequestMeta {
pub enum DeriveRequestMeta { pub enum DeriveRequestMeta {
Authentication(Type), Authentication(Type),
Method(Type),
ErrorTy(Type), ErrorTy(Type),
} }
@ -68,10 +66,6 @@ impl Parse for DeriveRequestMeta {
let _: kw::authentication = input.parse()?; let _: kw::authentication = input.parse()?;
let _: Token![=] = input.parse()?; let _: Token![=] = input.parse()?;
input.parse().map(Self::Authentication) input.parse().map(Self::Authentication)
} else if lookahead.peek(kw::method) {
let _: kw::method = input.parse()?;
let _: Token![=] = input.parse()?;
input.parse().map(Self::Method)
} else if lookahead.peek(kw::error_ty) { } else if lookahead.peek(kw::error_ty) {
let _: kw::error_ty = input.parse()?; let _: kw::error_ty = input.parse()?;
let _: Token![=] = input.parse()?; let _: Token![=] = input.parse()?;

View File

@ -48,7 +48,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> {
let mut authentication = None; let mut authentication = None;
let mut error_ty = None; let mut error_ty = None;
let mut method = None;
for attr in input.attrs { for attr in input.attrs {
if !attr.path.is_ident("ruma_api") { if !attr.path.is_ident("ruma_api") {
@ -60,7 +59,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> {
for meta in metas { for meta in metas {
match meta { match meta {
DeriveRequestMeta::Authentication(t) => authentication = Some(parse_quote!(#t)), DeriveRequestMeta::Authentication(t) => authentication = Some(parse_quote!(#t)),
DeriveRequestMeta::Method(t) => method = Some(parse_quote!(#t)),
DeriveRequestMeta::ErrorTy(t) => error_ty = Some(t), DeriveRequestMeta::ErrorTy(t) => error_ty = Some(t),
} }
} }
@ -72,12 +70,17 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> {
fields, fields,
lifetimes, lifetimes,
authentication: authentication.expect("missing authentication attribute"), authentication: authentication.expect("missing authentication attribute"),
method: method.expect("missing method attribute"),
error_ty: error_ty.expect("missing error_ty attribute"), error_ty: error_ty.expect("missing error_ty attribute"),
}; };
request.check()?; let ruma_common = import_ruma_common();
Ok(request.expand_all()) let test = request.check(&ruma_common)?;
let types_impls = request.expand_all(&ruma_common);
Ok(quote! {
#types_impls
#test
})
} }
#[derive(Default)] #[derive(Default)]
@ -95,7 +98,6 @@ struct Request {
fields: Vec<RequestField>, fields: Vec<RequestField>,
authentication: AuthScheme, authentication: AuthScheme,
method: Ident,
error_ty: Type, error_ty: Type,
} }
@ -149,8 +151,7 @@ impl Request {
self.fields.iter().find_map(RequestField::as_query_map_field) self.fields.iter().find_map(RequestField::as_query_map_field)
} }
fn expand_all(&self) -> TokenStream { fn expand_all(&self, ruma_common: &TokenStream) -> TokenStream {
let ruma_common = import_ruma_common();
let ruma_macros = quote! { #ruma_common::exports::ruma_macros }; let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
let serde = quote! { #ruma_common::exports::serde }; let serde = quote! { #ruma_common::exports::serde };
@ -206,8 +207,8 @@ impl Request {
} }
}); });
let outgoing_request_impl = self.expand_outgoing(&ruma_common); let outgoing_request_impl = self.expand_outgoing(ruma_common);
let incoming_request_impl = self.expand_incoming(&ruma_common); let incoming_request_impl = self.expand_incoming(ruma_common);
quote! { quote! {
#request_body_struct #request_body_struct
@ -218,7 +219,9 @@ impl Request {
} }
} }
pub(super) fn check(&self) -> syn::Result<()> { pub(super) fn check(&self, ruma_common: &TokenStream) -> syn::Result<Option<TokenStream>> {
let http = quote! { #ruma_common::exports::http };
// TODO: highlight problematic fields // TODO: highlight problematic fields
let newtype_body_fields = self.fields.iter().filter(|f| { let newtype_body_fields = self.fields.iter().filter(|f| {
@ -275,14 +278,17 @@ impl Request {
)); ));
} }
if self.method == "GET" && (has_body_fields || has_newtype_body_field) { Ok((has_body_fields || has_newtype_body_field).then(|| {
return Err(syn::Error::new_spanned( quote! {
&self.ident, #[::std::prelude::v1::test]
fn request_is_not_get() {
::std::assert_ne!(
METADATA.method, #http::Method::GET,
"GET endpoints can't have body fields", "GET endpoints can't have body fields",
)); );
} }
}
Ok(()) }))
} }
} }

View File

@ -11,7 +11,6 @@ impl Request {
let serde = quote! { #ruma_common::exports::serde }; let serde = quote! { #ruma_common::exports::serde };
let serde_json = quote! { #ruma_common::exports::serde_json }; let serde_json = quote! { #ruma_common::exports::serde_json };
let method = &self.method;
let error_ty = &self.error_ty; let error_ty = &self.error_ty;
let incoming_request_type = if self.has_lifetimes() { let incoming_request_type = if self.has_lifetimes() {
@ -203,9 +202,9 @@ impl Request {
B: ::std::convert::AsRef<[::std::primitive::u8]>, B: ::std::convert::AsRef<[::std::primitive::u8]>,
S: ::std::convert::AsRef<::std::primitive::str>, S: ::std::convert::AsRef<::std::primitive::str>,
{ {
if request.method() != #http::Method::#method { if request.method() != METADATA.method {
return Err(#ruma_common::api::error::FromHttpRequestError::MethodMismatch { return Err(#ruma_common::api::error::FromHttpRequestError::MethodMismatch {
expected: #http::Method::#method, expected: METADATA.method,
received: request.method().clone(), received: request.method().clone(),
}); });
} }

View File

@ -10,7 +10,6 @@ impl Request {
let bytes = quote! { #ruma_common::exports::bytes }; let bytes = quote! { #ruma_common::exports::bytes };
let http = quote! { #ruma_common::exports::http }; let http = quote! { #ruma_common::exports::http };
let method = &self.method;
let error_ty = &self.error_ty; let error_ty = &self.error_ty;
let path_fields = let path_fields =
@ -137,10 +136,8 @@ impl Request {
quote! { quote! {
#ruma_common::serde::json_to_buf(&RequestBody { #initializers })? #ruma_common::serde::json_to_buf(&RequestBody { #initializers })?
} }
} else if method == "GET" {
quote! { <T as ::std::default::Default>::default() }
} else { } else {
quote! { #ruma_common::serde::slice_to_buf(b"{}") } quote! { METADATA.empty_request_body::<T>() }
}; };
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
@ -170,7 +167,7 @@ impl Request {
considering_versions: &'_ [#ruma_common::api::MatrixVersion], considering_versions: &'_ [#ruma_common::api::MatrixVersion],
) -> ::std::result::Result<#http::Request<T>, #ruma_common::api::error::IntoHttpError> { ) -> ::std::result::Result<#http::Request<T>, #ruma_common::api::error::IntoHttpError> {
let mut req_builder = #http::Request::builder() let mut req_builder = #http::Request::builder()
.method(#http::Method::#method) .method(METADATA.method)
.uri(METADATA.make_endpoint_url( .uri(METADATA.make_endpoint_url(
considering_versions, considering_versions,
base_url, base_url,