diff --git a/crates/ruma-api-macros/src/request.rs b/crates/ruma-api-macros/src/request.rs index 09e243b5..a0d3d3e8 100644 --- a/crates/ruma-api-macros/src/request.rs +++ b/crates/ruma-api-macros/src/request.rs @@ -124,7 +124,13 @@ impl Request { } fn has_body_fields(&self) -> bool { - self.fields.iter().any(|f| matches!(f, RequestField::Body(..))) + self.fields + .iter() + .any(|f| matches!(f, RequestField::Body(_) | RequestField::NewtypeBody(_))) + } + + fn has_newtype_body(&self) -> bool { + self.fields.iter().any(|f| matches!(f, RequestField::NewtypeBody(_))) } fn has_header_fields(&self) -> bool { @@ -132,11 +138,11 @@ impl Request { } fn has_path_fields(&self) -> bool { - self.fields.iter().any(|f| matches!(f, RequestField::Path(..))) + self.fields.iter().any(|f| matches!(f, RequestField::Path(_))) } fn has_query_fields(&self) -> bool { - self.fields.iter().any(|f| matches!(f, RequestField::Query(..))) + self.fields.iter().any(|f| matches!(f, RequestField::Query(_))) } fn has_lifetimes(&self) -> bool { @@ -154,10 +160,6 @@ impl Request { self.fields.iter().filter(|f| matches!(f, RequestField::Path(..))).count() } - fn newtype_body_field(&self) -> Option<&Field> { - self.fields.iter().find_map(RequestField::as_newtype_body_field) - } - fn raw_body_field(&self) -> Option<&Field> { self.fields.iter().find_map(RequestField::as_raw_body_field) } @@ -172,17 +174,10 @@ impl Request { let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let serde = quote! { #ruma_api::exports::serde }; - let request_body_def = if let Some(body_field) = self.newtype_body_field() { - let field = Field { ident: None, colon_token: None, ..body_field.clone() }; - Some(quote! { (#field); }) - } else if self.has_body_fields() { + let request_body_struct = self.has_body_fields().then(|| { + let serde_attr = self.has_newtype_body().then(|| quote! { #[serde(transparent)] }); let fields = self.fields.iter().filter_map(RequestField::as_body_field); - Some(quote! { { #(#fields),* } }) - } else { - None - }; - let request_body_struct = request_body_def.map(|def| { // Though we don't track the difference between newtype body and body // for lifetimes, the outer check and the macro failing if it encounters // an illegal combination of field attributes, is enough to guarantee @@ -199,7 +194,8 @@ impl Request { #serde::Serialize, #derive_deserialize )] - struct RequestBody< #(#lifetimes),* > #def + #serde_attr + struct RequestBody< #(#lifetimes),* > { #(#fields),* } } }); @@ -351,12 +347,10 @@ impl RequestField { /// Return the contained field if this request field is a body kind. pub fn as_body_field(&self) -> Option<&Field> { - self.field_of_kind(RequestFieldKind::Body) - } - - /// Return the contained field if this request field is a body kind. - pub fn as_newtype_body_field(&self) -> Option<&Field> { - self.field_of_kind(RequestFieldKind::NewtypeBody) + match self { + RequestField::Body(field) | RequestField::NewtypeBody(field) => Some(field), + _ => None, + } } /// Return the contained field if this request field is a raw body kind. diff --git a/crates/ruma-api-macros/src/request/incoming.rs b/crates/ruma-api-macros/src/request/incoming.rs index 3fa8a263..e2349852 100644 --- a/crates/ruma-api-macros/src/request/incoming.rs +++ b/crates/ruma-api-macros/src/request/incoming.rs @@ -1,5 +1,6 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; +use syn::Field; use super::{Request, RequestField, RequestFieldKind}; use crate::auth_scheme::AuthScheme; @@ -166,43 +167,34 @@ impl Request { (TokenStream::new(), TokenStream::new()) }; - let extract_body = - (self.has_body_fields() || self.newtype_body_field().is_some()).then(|| { - let body_lifetimes = (!self.lifetimes.body.is_empty()).then(|| { - // duplicate the anonymous lifetime as many times as needed - let lifetimes = - std::iter::repeat(quote! { '_ }).take(self.lifetimes.body.len()); - quote! { < #( #lifetimes, )* > } - }); - - quote! { - let request_body: < - RequestBody #body_lifetimes - as #ruma_serde::Outgoing - >::Incoming = { - let body = ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref( - request.body(), - ); - - #serde_json::from_slice(match body { - // If the request body is completely empty, pretend it is an empty JSON - // object instead. This allows requests with only optional body parameters - // to be deserialized in that case. - [] => b"{}", - b => b, - })? - }; - } + let extract_body = self.has_body_fields().then(|| { + let body_lifetimes = (!self.lifetimes.body.is_empty()).then(|| { + // duplicate the anonymous lifetime as many times as needed + let lifetimes = std::iter::repeat(quote! { '_ }).take(self.lifetimes.body.len()); + quote! { < #( #lifetimes, )* > } }); - let (parse_body, body_vars) = if let Some(field) = self.newtype_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - let parse = quote! { - let #field_name = request_body.0; - }; + quote! { + let request_body: < + RequestBody #body_lifetimes + as #ruma_serde::Outgoing + >::Incoming = { + let body = ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref( + request.body(), + ); - (parse, quote! { #field_name, }) - } else if let Some(field) = self.raw_body_field() { + #serde_json::from_slice(match body { + // If the request body is completely empty, pretend it is an empty JSON + // object instead. This allows requests with only optional body parameters + // to be deserialized in that case. + [] => b"{}", + b => b, + })? + }; + } + }); + + let (parse_body, body_vars) = if let Some(field) = self.raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let parse = quote! { let #field_name = @@ -211,7 +203,7 @@ impl Request { (parse, quote! { #field_name, }) } else { - self.vars(RequestFieldKind::Body, quote! { request_body }) + vars(self.body_fields(), quote! { request_body }) }; let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| { @@ -266,28 +258,33 @@ impl Request { request_field_kind: RequestFieldKind, src: TokenStream, ) -> (TokenStream, TokenStream) { - self.fields - .iter() - .filter_map(|f| f.field_of_kind(request_field_kind)) - .map(|field| { - let field_name = - field.ident.as_ref().expect("expected field to have an identifier"); - let cfg_attrs = - field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); - - let decl = quote! { - #( #cfg_attrs )* - let #field_name = #src.#field_name; - }; - - ( - decl, - quote! { - #( #cfg_attrs )* - #field_name, - }, - ) - }) - .unzip() + vars(self.fields.iter().filter_map(|f| f.field_of_kind(request_field_kind)), src) } } + +fn vars<'a>( + fields: impl IntoIterator, + src: TokenStream, +) -> (TokenStream, TokenStream) { + fields + .into_iter() + .map(|field| { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + let cfg_attrs = + field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); + + let decl = quote! { + #( #cfg_attrs )* + let #field_name = #src.#field_name; + }; + + ( + decl, + quote! { + #( #cfg_attrs )* + #field_name, + }, + ) + }) + .unzip() +} diff --git a/crates/ruma-api-macros/src/request/outgoing.rs b/crates/ruma-api-macros/src/request/outgoing.rs index 874a960b..40240c12 100644 --- a/crates/ruma-api-macros/src/request/outgoing.rs +++ b/crates/ruma-api-macros/src/request/outgoing.rs @@ -1,5 +1,6 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; +use syn::Field; use crate::auth_scheme::AuthScheme; @@ -156,18 +157,11 @@ impl Request { let request_body = if let Some(field) = self.raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); quote! { #ruma_serde::slice_to_buf(&self.#field_name) } - } else if self.has_body_fields() || self.newtype_body_field().is_some() { - let request_body_initializers = if let Some(field) = self.newtype_body_field() { - let field_name = - field.ident.as_ref().expect("expected field to have an identifier"); - quote! { (self.#field_name) } - } else { - let initializers = self.struct_init_fields(RequestFieldKind::Body, quote! { self }); - quote! { { #initializers } } - }; + } else if self.has_body_fields() { + let initializers = struct_init_fields(self.body_fields(), quote! { self }); quote! { - #ruma_serde::json_to_buf(&RequestBody #request_body_initializers)? + #ruma_serde::json_to_buf(&RequestBody { #initializers })? } } else { quote! { ::default() } @@ -227,27 +221,35 @@ impl Request { } } - /// Produces code for a struct initializer for the given field kind to be accessed through the - /// given variable name. fn struct_init_fields( &self, request_field_kind: RequestFieldKind, src: TokenStream, ) -> TokenStream { - self.fields - .iter() - .filter_map(|f| f.field_of_kind(request_field_kind)) - .map(|field| { - let field_name = - field.ident.as_ref().expect("expected field to have an identifier"); - let cfg_attrs = - field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); - - quote! { - #( #cfg_attrs )* - #field_name: #src.#field_name, - } - }) - .collect() + struct_init_fields( + self.fields.iter().filter_map(|f| f.field_of_kind(request_field_kind)), + src, + ) } } + +/// Produces code for a struct initializer for the given field kind to be accessed through the +/// given variable name. +fn struct_init_fields<'a>( + fields: impl IntoIterator, + src: TokenStream, +) -> TokenStream { + fields + .into_iter() + .map(|field| { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + let cfg_attrs = + field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); + + quote! { + #( #cfg_attrs )* + #field_name: #src.#field_name, + } + }) + .collect() +} diff --git a/crates/ruma-api-macros/src/response.rs b/crates/ruma-api-macros/src/response.rs index c7473028..9660ef1c 100644 --- a/crates/ruma-api-macros/src/response.rs +++ b/crates/ruma-api-macros/src/response.rs @@ -65,17 +65,19 @@ struct Response { impl Response { /// Whether or not this request has any data in the HTTP body. fn has_body_fields(&self) -> bool { - self.fields.iter().any(|f| matches!(f, ResponseField::Body(_))) + self.fields + .iter() + .any(|f| matches!(f, ResponseField::Body(_) | &ResponseField::NewtypeBody(_))) } - /// Returns the body field. - fn newtype_body_field(&self) -> Option<&Field> { - self.fields.iter().find_map(ResponseField::as_newtype_body_field) + /// Whether or not this request has a single newtype body field. + fn has_newtype_body(&self) -> bool { + self.fields.iter().any(|f| matches!(f, ResponseField::NewtypeBody(_))) } - /// Returns the body field. - fn raw_body_field(&self) -> Option<&Field> { - self.fields.iter().find_map(ResponseField::as_raw_body_field) + /// Whether or not this request has a single raw body field. + fn has_raw_body(&self) -> bool { + self.fields.iter().any(|f| matches!(f, ResponseField::RawBody(_))) } /// Whether or not this request has any data in the URL path. @@ -89,31 +91,23 @@ impl Response { let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let serde = quote! { #ruma_api::exports::serde }; - let response_body_struct = - self.fields.iter().all(|f| !matches!(f, ResponseField::RawBody(_))).then(|| { - let newtype_body_field = - self.fields.iter().find(|f| matches!(f, ResponseField::NewtypeBody(_))); - let def = if let Some(body_field) = newtype_body_field { - let field = - Field { ident: None, colon_token: None, ..body_field.field().clone() }; - quote! { (#field); } - } else { - let fields = self.fields.iter().filter_map(|f| f.as_body_field()); - quote! { { #(#fields),* } } - }; + let response_body_struct = (!self.has_raw_body()).then(|| { + let serde_attr = self.has_newtype_body().then(|| quote! { #[serde(transparent)] }); + let fields = self.fields.iter().filter_map(ResponseField::as_body_field); - quote! { - /// Data in the response body. - #[derive( - Debug, - #ruma_api_macros::_FakeDeriveRumaApi, - #ruma_serde::Outgoing, - #serde::Deserialize, - #serde::Serialize, - )] - struct ResponseBody #def - } - }); + quote! { + /// Data in the response body. + #[derive( + Debug, + #ruma_api_macros::_FakeDeriveRumaApi, + #ruma_serde::Outgoing, + #serde::Deserialize, + #serde::Serialize, + )] + #serde_attr + struct ResponseBody { #(#fields),* } + } + }); let outgoing_response_impl = self.expand_outgoing(&ruma_api); let incoming_response_impl = self.expand_incoming(&self.error_ty, &ruma_api); @@ -190,23 +184,7 @@ impl ResponseField { /// Return the contained field if this response field is a body kind. fn as_body_field(&self) -> Option<&Field> { match self { - ResponseField::Body(field) => Some(field), - _ => None, - } - } - - /// Return the contained field and HTTP header ident if this repsonse field is a header kind. - fn as_header_field(&self) -> Option<(&Field, &Ident)> { - match self { - ResponseField::Header(field, ident) => Some((field, ident)), - _ => None, - } - } - - /// Return the contained field if this response field is a newtype body kind. - fn as_newtype_body_field(&self) -> Option<&Field> { - match self { - ResponseField::NewtypeBody(field) => Some(field), + ResponseField::Body(field) | ResponseField::NewtypeBody(field) => Some(field), _ => None, } } @@ -218,6 +196,14 @@ impl ResponseField { _ => None, } } + + /// Return the contained field and HTTP header ident if this repsonse field is a header kind. + fn as_header_field(&self) -> Option<(&Field, &Ident)> { + match self { + ResponseField::Header(field, ident) => Some((field, ident)), + _ => None, + } + } } impl TryFrom for ResponseField { diff --git a/crates/ruma-api-macros/src/response/incoming.rs b/crates/ruma-api-macros/src/response/incoming.rs index 5fa16ae2..39ee8a3c 100644 --- a/crates/ruma-api-macros/src/response/incoming.rs +++ b/crates/ruma-api-macros/src/response/incoming.rs @@ -16,27 +16,26 @@ impl Response { } }); - let typed_response_body_decl = - (self.has_body_fields() || self.newtype_body_field().is_some()).then(|| { - quote! { - let response_body: < - ResponseBody - as #ruma_serde::Outgoing - >::Incoming = { - let body = ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref( - response.body(), - ); + let typed_response_body_decl = self.has_body_fields().then(|| { + quote! { + let response_body: < + ResponseBody + as #ruma_serde::Outgoing + >::Incoming = { + let body = ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref( + response.body(), + ); - #serde_json::from_slice(match body { - // If the response body is completely empty, pretend it is an empty - // JSON object instead. This allows responses with only optional body - // parameters to be deserialized in that case. - [] => b"{}", - b => b, - })? - }; - } - }); + #serde_json::from_slice(match body { + // If the response body is completely empty, pretend it is an empty + // JSON object instead. This allows responses with only optional body + // parameters to be deserialized in that case. + [] => b"{}", + b => b, + })? + }; + } + }); let response_init_fields = { let mut fields = vec![]; @@ -50,7 +49,7 @@ impl Response { field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); fields.push(match response_field { - ResponseField::Body(_) => { + ResponseField::Body(_) | ResponseField::NewtypeBody(_) => { quote! { #( #cfg_attrs )* #field_name: response_body.#field_name @@ -83,12 +82,6 @@ impl Response { }; quote! { #optional_header } } - ResponseField::NewtypeBody(_) => { - quote! { - #( #cfg_attrs )* - #field_name: response_body.0 - } - } // This field must be instantiated last to avoid `use of move value` error. // We are guaranteed only one new body field because of a check in // `parse_response`. diff --git a/crates/ruma-api-macros/src/response/outgoing.rs b/crates/ruma-api-macros/src/response/outgoing.rs index 08e46b0d..f09b3615 100644 --- a/crates/ruma-api-macros/src/response/outgoing.rs +++ b/crates/ruma-api-macros/src/response/outgoing.rs @@ -1,7 +1,7 @@ use proc_macro2::TokenStream; use quote::quote; -use super::Response; +use super::{Response, ResponseField}; impl Response { pub fn expand_outgoing(&self, ruma_api: &TokenStream) -> TokenStream { @@ -37,15 +37,12 @@ impl Response { }) }); - let body = if let Some(field) = self.raw_body_field() { + let body = if let Some(field) = + self.fields.iter().find_map(ResponseField::as_raw_body_field) + { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); quote! { #ruma_serde::slice_to_buf(&self.#field_name) } - } else if let Some(field) = self.newtype_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { - #ruma_serde::json_to_buf(&self.#field_name)? - } - } else { + } else if self.has_body_fields() { let fields = self.fields.iter().filter_map(|response_field| { response_field.as_body_field().map(|field| { let field_name = @@ -62,6 +59,8 @@ impl Response { quote! { #ruma_serde::json_to_buf(&ResponseBody { #(#fields)* })? } + } else { + quote! { ::default() } }; quote! {