diff --git a/crates/ruma-common/src/api/metadata.rs b/crates/ruma-common/src/api/metadata.rs index 5fca2476..7b599adc 100644 --- a/crates/ruma-common/src/api/metadata.rs +++ b/crates/ruma-common/src/api/metadata.rs @@ -3,6 +3,7 @@ use std::{ str::FromStr, }; +use bytes::BufMut; use http::Method; use percent_encoding::utf8_percent_encode; use tracing::warn; @@ -11,7 +12,7 @@ use super::{ error::{IntoHttpError, UnknownVersionError}, AuthScheme, }; -use crate::RoomVersionId; +use crate::{serde::slice_to_buf, RoomVersionId}; /// Metadata about an API endpoint. #[derive(Clone, Debug)] @@ -37,6 +38,21 @@ pub struct Metadata { } impl Metadata { + /// Returns an empty request body for this Matrix request. + /// + /// For `GET` requests, it returns an entirely empty buffer, for others it returns an empty JSON + /// object (`{}`). + pub fn empty_request_body(&self) -> B + where + B: Default + BufMut, + { + if self.method == Method::GET { + Default::default() + } else { + slice_to_buf(b"{}") + } + } + /// Generate the endpoint URL for this endpoint. pub fn make_endpoint_url( &self, diff --git a/crates/ruma-macros/src/api/api_request.rs b/crates/ruma-macros/src/api/api_request.rs index 88aed2bb..e3a86793 100644 --- a/crates/ruma-macros/src/api/api_request.rs +++ b/crates/ruma-macros/src/api/api_request.rs @@ -82,7 +82,6 @@ impl Request { ); let struct_attributes = &self.attributes; - let method = &metadata.method; let authentication = &metadata.authentication; let request_ident = Ident::new("Request", self.request_kw.span()); @@ -101,11 +100,7 @@ impl Request { )] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] #[incoming_derive(!Deserialize, #ruma_macros::_FakeDeriveRumaApi)] - #[ruma_api( - method = #method, - authentication = #authentication, - error_ty = #error_ty, - )] + #[ruma_api(authentication = #authentication, error_ty = #error_ty)] #( #struct_attributes )* pub struct #request_ident < #(#lifetimes),* > { #fields diff --git a/crates/ruma-macros/src/api/attribute.rs b/crates/ruma-macros/src/api/attribute.rs index f88125c3..ad30764d 100644 --- a/crates/ruma-macros/src/api/attribute.rs +++ b/crates/ruma-macros/src/api/attribute.rs @@ -13,7 +13,6 @@ mod kw { syn::custom_keyword!(query_map); syn::custom_keyword!(header); syn::custom_keyword!(authentication); - syn::custom_keyword!(method); syn::custom_keyword!(error_ty); syn::custom_keyword!(manual_body_serde); } @@ -57,7 +56,6 @@ impl Parse for RequestMeta { pub enum DeriveRequestMeta { Authentication(Type), - Method(Type), ErrorTy(Type), } @@ -68,10 +66,6 @@ impl Parse for DeriveRequestMeta { let _: kw::authentication = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(Self::Authentication) - } else if lookahead.peek(kw::method) { - let _: kw::method = input.parse()?; - let _: Token![=] = input.parse()?; - input.parse().map(Self::Method) } else if lookahead.peek(kw::error_ty) { let _: kw::error_ty = input.parse()?; let _: Token![=] = input.parse()?; diff --git a/crates/ruma-macros/src/api/request.rs b/crates/ruma-macros/src/api/request.rs index bccd3c23..15e4bc5e 100644 --- a/crates/ruma-macros/src/api/request.rs +++ b/crates/ruma-macros/src/api/request.rs @@ -48,7 +48,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result { let mut authentication = None; let mut error_ty = None; - let mut method = None; for attr in input.attrs { if !attr.path.is_ident("ruma_api") { @@ -60,7 +59,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result { for meta in metas { match meta { DeriveRequestMeta::Authentication(t) => authentication = Some(parse_quote!(#t)), - DeriveRequestMeta::Method(t) => method = Some(parse_quote!(#t)), DeriveRequestMeta::ErrorTy(t) => error_ty = Some(t), } } @@ -72,12 +70,17 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result { fields, lifetimes, authentication: authentication.expect("missing authentication attribute"), - method: method.expect("missing method attribute"), error_ty: error_ty.expect("missing error_ty attribute"), }; - request.check()?; - Ok(request.expand_all()) + let ruma_common = import_ruma_common(); + let test = request.check(&ruma_common)?; + let types_impls = request.expand_all(&ruma_common); + + Ok(quote! { + #types_impls + #test + }) } #[derive(Default)] @@ -95,7 +98,6 @@ struct Request { fields: Vec, authentication: AuthScheme, - method: Ident, error_ty: Type, } @@ -149,8 +151,7 @@ impl Request { self.fields.iter().find_map(RequestField::as_query_map_field) } - fn expand_all(&self) -> TokenStream { - let ruma_common = import_ruma_common(); + fn expand_all(&self, ruma_common: &TokenStream) -> TokenStream { let ruma_macros = quote! { #ruma_common::exports::ruma_macros }; let serde = quote! { #ruma_common::exports::serde }; @@ -206,8 +207,8 @@ impl Request { } }); - let outgoing_request_impl = self.expand_outgoing(&ruma_common); - let incoming_request_impl = self.expand_incoming(&ruma_common); + let outgoing_request_impl = self.expand_outgoing(ruma_common); + let incoming_request_impl = self.expand_incoming(ruma_common); quote! { #request_body_struct @@ -218,7 +219,9 @@ impl Request { } } - pub(super) fn check(&self) -> syn::Result<()> { + pub(super) fn check(&self, ruma_common: &TokenStream) -> syn::Result> { + let http = quote! { #ruma_common::exports::http }; + // TODO: highlight problematic fields let newtype_body_fields = self.fields.iter().filter(|f| { @@ -275,14 +278,17 @@ impl Request { )); } - if self.method == "GET" && (has_body_fields || has_newtype_body_field) { - return Err(syn::Error::new_spanned( - &self.ident, - "GET endpoints can't have body fields", - )); - } - - Ok(()) + Ok((has_body_fields || has_newtype_body_field).then(|| { + quote! { + #[::std::prelude::v1::test] + fn request_is_not_get() { + ::std::assert_ne!( + METADATA.method, #http::Method::GET, + "GET endpoints can't have body fields", + ); + } + } + })) } } diff --git a/crates/ruma-macros/src/api/request/incoming.rs b/crates/ruma-macros/src/api/request/incoming.rs index f6f16072..a4966357 100644 --- a/crates/ruma-macros/src/api/request/incoming.rs +++ b/crates/ruma-macros/src/api/request/incoming.rs @@ -11,7 +11,6 @@ impl Request { let serde = quote! { #ruma_common::exports::serde }; let serde_json = quote! { #ruma_common::exports::serde_json }; - let method = &self.method; let error_ty = &self.error_ty; let incoming_request_type = if self.has_lifetimes() { @@ -203,9 +202,9 @@ impl Request { B: ::std::convert::AsRef<[::std::primitive::u8]>, S: ::std::convert::AsRef<::std::primitive::str>, { - if request.method() != #http::Method::#method { + if request.method() != METADATA.method { return Err(#ruma_common::api::error::FromHttpRequestError::MethodMismatch { - expected: #http::Method::#method, + expected: METADATA.method, received: request.method().clone(), }); } diff --git a/crates/ruma-macros/src/api/request/outgoing.rs b/crates/ruma-macros/src/api/request/outgoing.rs index 624f85bc..b4c68c16 100644 --- a/crates/ruma-macros/src/api/request/outgoing.rs +++ b/crates/ruma-macros/src/api/request/outgoing.rs @@ -10,7 +10,6 @@ impl Request { let bytes = quote! { #ruma_common::exports::bytes }; let http = quote! { #ruma_common::exports::http }; - let method = &self.method; let error_ty = &self.error_ty; let path_fields = @@ -137,10 +136,8 @@ impl Request { quote! { #ruma_common::serde::json_to_buf(&RequestBody { #initializers })? } - } else if method == "GET" { - quote! { ::default() } } else { - quote! { #ruma_common::serde::slice_to_buf(b"{}") } + quote! { METADATA.empty_request_body::() } }; let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); @@ -170,7 +167,7 @@ impl Request { considering_versions: &'_ [#ruma_common::api::MatrixVersion], ) -> ::std::result::Result<#http::Request, #ruma_common::api::error::IntoHttpError> { let mut req_builder = #http::Request::builder() - .method(#http::Method::#method) + .method(METADATA.method) .uri(METADATA.make_endpoint_url( considering_versions, base_url,