api-macros: Use serde(transparent) for newtype bodies

… rather than special casing them in many places.
This commit is contained in:
Jonas Platte 2021-08-16 22:15:19 +02:00
parent 42fda7c89f
commit 96ab1674af
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
6 changed files with 162 additions and 191 deletions

View File

@ -124,7 +124,13 @@ impl Request {
} }
fn has_body_fields(&self) -> bool { fn has_body_fields(&self) -> bool {
self.fields.iter().any(|f| matches!(f, RequestField::Body(..))) self.fields
.iter()
.any(|f| matches!(f, RequestField::Body(_) | RequestField::NewtypeBody(_)))
}
fn has_newtype_body(&self) -> bool {
self.fields.iter().any(|f| matches!(f, RequestField::NewtypeBody(_)))
} }
fn has_header_fields(&self) -> bool { fn has_header_fields(&self) -> bool {
@ -132,11 +138,11 @@ impl Request {
} }
fn has_path_fields(&self) -> bool { fn has_path_fields(&self) -> bool {
self.fields.iter().any(|f| matches!(f, RequestField::Path(..))) self.fields.iter().any(|f| matches!(f, RequestField::Path(_)))
} }
fn has_query_fields(&self) -> bool { fn has_query_fields(&self) -> bool {
self.fields.iter().any(|f| matches!(f, RequestField::Query(..))) self.fields.iter().any(|f| matches!(f, RequestField::Query(_)))
} }
fn has_lifetimes(&self) -> bool { fn has_lifetimes(&self) -> bool {
@ -154,10 +160,6 @@ impl Request {
self.fields.iter().filter(|f| matches!(f, RequestField::Path(..))).count() self.fields.iter().filter(|f| matches!(f, RequestField::Path(..))).count()
} }
fn newtype_body_field(&self) -> Option<&Field> {
self.fields.iter().find_map(RequestField::as_newtype_body_field)
}
fn raw_body_field(&self) -> Option<&Field> { fn raw_body_field(&self) -> Option<&Field> {
self.fields.iter().find_map(RequestField::as_raw_body_field) self.fields.iter().find_map(RequestField::as_raw_body_field)
} }
@ -172,17 +174,10 @@ impl Request {
let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
let serde = quote! { #ruma_api::exports::serde }; let serde = quote! { #ruma_api::exports::serde };
let request_body_def = if let Some(body_field) = self.newtype_body_field() { let request_body_struct = self.has_body_fields().then(|| {
let field = Field { ident: None, colon_token: None, ..body_field.clone() }; let serde_attr = self.has_newtype_body().then(|| quote! { #[serde(transparent)] });
Some(quote! { (#field); })
} else if self.has_body_fields() {
let fields = self.fields.iter().filter_map(RequestField::as_body_field); let fields = self.fields.iter().filter_map(RequestField::as_body_field);
Some(quote! { { #(#fields),* } })
} else {
None
};
let request_body_struct = request_body_def.map(|def| {
// Though we don't track the difference between newtype body and body // Though we don't track the difference between newtype body and body
// for lifetimes, the outer check and the macro failing if it encounters // for lifetimes, the outer check and the macro failing if it encounters
// an illegal combination of field attributes, is enough to guarantee // an illegal combination of field attributes, is enough to guarantee
@ -199,7 +194,8 @@ impl Request {
#serde::Serialize, #serde::Serialize,
#derive_deserialize #derive_deserialize
)] )]
struct RequestBody< #(#lifetimes),* > #def #serde_attr
struct RequestBody< #(#lifetimes),* > { #(#fields),* }
} }
}); });
@ -351,12 +347,10 @@ impl RequestField {
/// Return the contained field if this request field is a body kind. /// Return the contained field if this request field is a body kind.
pub fn as_body_field(&self) -> Option<&Field> { pub fn as_body_field(&self) -> Option<&Field> {
self.field_of_kind(RequestFieldKind::Body) match self {
RequestField::Body(field) | RequestField::NewtypeBody(field) => Some(field),
_ => None,
} }
/// Return the contained field if this request field is a body kind.
pub fn as_newtype_body_field(&self) -> Option<&Field> {
self.field_of_kind(RequestFieldKind::NewtypeBody)
} }
/// Return the contained field if this request field is a raw body kind. /// Return the contained field if this request field is a raw body kind.

View File

@ -1,5 +1,6 @@
use proc_macro2::{Ident, Span, TokenStream}; use proc_macro2::{Ident, Span, TokenStream};
use quote::quote; use quote::quote;
use syn::Field;
use super::{Request, RequestField, RequestFieldKind}; use super::{Request, RequestField, RequestFieldKind};
use crate::auth_scheme::AuthScheme; use crate::auth_scheme::AuthScheme;
@ -166,12 +167,10 @@ impl Request {
(TokenStream::new(), TokenStream::new()) (TokenStream::new(), TokenStream::new())
}; };
let extract_body = let extract_body = self.has_body_fields().then(|| {
(self.has_body_fields() || self.newtype_body_field().is_some()).then(|| {
let body_lifetimes = (!self.lifetimes.body.is_empty()).then(|| { let body_lifetimes = (!self.lifetimes.body.is_empty()).then(|| {
// duplicate the anonymous lifetime as many times as needed // duplicate the anonymous lifetime as many times as needed
let lifetimes = let lifetimes = std::iter::repeat(quote! { '_ }).take(self.lifetimes.body.len());
std::iter::repeat(quote! { '_ }).take(self.lifetimes.body.len());
quote! { < #( #lifetimes, )* > } quote! { < #( #lifetimes, )* > }
}); });
@ -195,14 +194,7 @@ impl Request {
} }
}); });
let (parse_body, body_vars) = if let Some(field) = self.newtype_body_field() { let (parse_body, body_vars) = if let Some(field) = self.raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
let parse = quote! {
let #field_name = request_body.0;
};
(parse, quote! { #field_name, })
} else if let Some(field) = self.raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let field_name = field.ident.as_ref().expect("expected field to have an identifier");
let parse = quote! { let parse = quote! {
let #field_name = let #field_name =
@ -211,7 +203,7 @@ impl Request {
(parse, quote! { #field_name, }) (parse, quote! { #field_name, })
} else { } else {
self.vars(RequestFieldKind::Body, quote! { request_body }) vars(self.body_fields(), quote! { request_body })
}; };
let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| { let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| {
@ -266,12 +258,18 @@ impl Request {
request_field_kind: RequestFieldKind, request_field_kind: RequestFieldKind,
src: TokenStream, src: TokenStream,
) -> (TokenStream, TokenStream) { ) -> (TokenStream, TokenStream) {
self.fields vars(self.fields.iter().filter_map(|f| f.field_of_kind(request_field_kind)), src)
.iter() }
.filter_map(|f| f.field_of_kind(request_field_kind)) }
fn vars<'a>(
fields: impl IntoIterator<Item = &'a Field>,
src: TokenStream,
) -> (TokenStream, TokenStream) {
fields
.into_iter()
.map(|field| { .map(|field| {
let field_name = let field_name = field.ident.as_ref().expect("expected field to have an identifier");
field.ident.as_ref().expect("expected field to have an identifier");
let cfg_attrs = let cfg_attrs =
field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>(); field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>();
@ -289,5 +287,4 @@ impl Request {
) )
}) })
.unzip() .unzip()
}
} }

View File

@ -1,5 +1,6 @@
use proc_macro2::{Ident, Span, TokenStream}; use proc_macro2::{Ident, Span, TokenStream};
use quote::quote; use quote::quote;
use syn::Field;
use crate::auth_scheme::AuthScheme; use crate::auth_scheme::AuthScheme;
@ -156,18 +157,11 @@ impl Request {
let request_body = if let Some(field) = self.raw_body_field() { let request_body = if let Some(field) = self.raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! { #ruma_serde::slice_to_buf(&self.#field_name) } quote! { #ruma_serde::slice_to_buf(&self.#field_name) }
} else if self.has_body_fields() || self.newtype_body_field().is_some() { } else if self.has_body_fields() {
let request_body_initializers = if let Some(field) = self.newtype_body_field() { let initializers = struct_init_fields(self.body_fields(), quote! { self });
let field_name =
field.ident.as_ref().expect("expected field to have an identifier");
quote! { (self.#field_name) }
} else {
let initializers = self.struct_init_fields(RequestFieldKind::Body, quote! { self });
quote! { { #initializers } }
};
quote! { quote! {
#ruma_serde::json_to_buf(&RequestBody #request_body_initializers)? #ruma_serde::json_to_buf(&RequestBody { #initializers })?
} }
} else { } else {
quote! { <T as ::std::default::Default>::default() } quote! { <T as ::std::default::Default>::default() }
@ -227,19 +221,28 @@ impl Request {
} }
} }
/// Produces code for a struct initializer for the given field kind to be accessed through the
/// given variable name.
fn struct_init_fields( fn struct_init_fields(
&self, &self,
request_field_kind: RequestFieldKind, request_field_kind: RequestFieldKind,
src: TokenStream, src: TokenStream,
) -> TokenStream { ) -> TokenStream {
self.fields struct_init_fields(
.iter() self.fields.iter().filter_map(|f| f.field_of_kind(request_field_kind)),
.filter_map(|f| f.field_of_kind(request_field_kind)) src,
)
}
}
/// Produces code for a struct initializer for the given field kind to be accessed through the
/// given variable name.
fn struct_init_fields<'a>(
fields: impl IntoIterator<Item = &'a Field>,
src: TokenStream,
) -> TokenStream {
fields
.into_iter()
.map(|field| { .map(|field| {
let field_name = let field_name = field.ident.as_ref().expect("expected field to have an identifier");
field.ident.as_ref().expect("expected field to have an identifier");
let cfg_attrs = let cfg_attrs =
field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>(); field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>();
@ -249,5 +252,4 @@ impl Request {
} }
}) })
.collect() .collect()
}
} }

View File

@ -65,17 +65,19 @@ struct Response {
impl Response { impl Response {
/// Whether or not this request has any data in the HTTP body. /// Whether or not this request has any data in the HTTP body.
fn has_body_fields(&self) -> bool { fn has_body_fields(&self) -> bool {
self.fields.iter().any(|f| matches!(f, ResponseField::Body(_))) self.fields
.iter()
.any(|f| matches!(f, ResponseField::Body(_) | &ResponseField::NewtypeBody(_)))
} }
/// Returns the body field. /// Whether or not this request has a single newtype body field.
fn newtype_body_field(&self) -> Option<&Field> { fn has_newtype_body(&self) -> bool {
self.fields.iter().find_map(ResponseField::as_newtype_body_field) self.fields.iter().any(|f| matches!(f, ResponseField::NewtypeBody(_)))
} }
/// Returns the body field. /// Whether or not this request has a single raw body field.
fn raw_body_field(&self) -> Option<&Field> { fn has_raw_body(&self) -> bool {
self.fields.iter().find_map(ResponseField::as_raw_body_field) self.fields.iter().any(|f| matches!(f, ResponseField::RawBody(_)))
} }
/// Whether or not this request has any data in the URL path. /// Whether or not this request has any data in the URL path.
@ -89,18 +91,9 @@ impl Response {
let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
let serde = quote! { #ruma_api::exports::serde }; let serde = quote! { #ruma_api::exports::serde };
let response_body_struct = let response_body_struct = (!self.has_raw_body()).then(|| {
self.fields.iter().all(|f| !matches!(f, ResponseField::RawBody(_))).then(|| { let serde_attr = self.has_newtype_body().then(|| quote! { #[serde(transparent)] });
let newtype_body_field = let fields = self.fields.iter().filter_map(ResponseField::as_body_field);
self.fields.iter().find(|f| matches!(f, ResponseField::NewtypeBody(_)));
let def = if let Some(body_field) = newtype_body_field {
let field =
Field { ident: None, colon_token: None, ..body_field.field().clone() };
quote! { (#field); }
} else {
let fields = self.fields.iter().filter_map(|f| f.as_body_field());
quote! { { #(#fields),* } }
};
quote! { quote! {
/// Data in the response body. /// Data in the response body.
@ -111,7 +104,8 @@ impl Response {
#serde::Deserialize, #serde::Deserialize,
#serde::Serialize, #serde::Serialize,
)] )]
struct ResponseBody #def #serde_attr
struct ResponseBody { #(#fields),* }
} }
}); });
@ -190,23 +184,7 @@ impl ResponseField {
/// Return the contained field if this response field is a body kind. /// Return the contained field if this response field is a body kind.
fn as_body_field(&self) -> Option<&Field> { fn as_body_field(&self) -> Option<&Field> {
match self { match self {
ResponseField::Body(field) => Some(field), ResponseField::Body(field) | ResponseField::NewtypeBody(field) => Some(field),
_ => None,
}
}
/// Return the contained field and HTTP header ident if this repsonse field is a header kind.
fn as_header_field(&self) -> Option<(&Field, &Ident)> {
match self {
ResponseField::Header(field, ident) => Some((field, ident)),
_ => None,
}
}
/// Return the contained field if this response field is a newtype body kind.
fn as_newtype_body_field(&self) -> Option<&Field> {
match self {
ResponseField::NewtypeBody(field) => Some(field),
_ => None, _ => None,
} }
} }
@ -218,6 +196,14 @@ impl ResponseField {
_ => None, _ => None,
} }
} }
/// Return the contained field and HTTP header ident if this repsonse field is a header kind.
fn as_header_field(&self) -> Option<(&Field, &Ident)> {
match self {
ResponseField::Header(field, ident) => Some((field, ident)),
_ => None,
}
}
} }
impl TryFrom<Field> for ResponseField { impl TryFrom<Field> for ResponseField {

View File

@ -16,8 +16,7 @@ impl Response {
} }
}); });
let typed_response_body_decl = let typed_response_body_decl = self.has_body_fields().then(|| {
(self.has_body_fields() || self.newtype_body_field().is_some()).then(|| {
quote! { quote! {
let response_body: < let response_body: <
ResponseBody ResponseBody
@ -50,7 +49,7 @@ impl Response {
field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>(); field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>();
fields.push(match response_field { fields.push(match response_field {
ResponseField::Body(_) => { ResponseField::Body(_) | ResponseField::NewtypeBody(_) => {
quote! { quote! {
#( #cfg_attrs )* #( #cfg_attrs )*
#field_name: response_body.#field_name #field_name: response_body.#field_name
@ -83,12 +82,6 @@ impl Response {
}; };
quote! { #optional_header } quote! { #optional_header }
} }
ResponseField::NewtypeBody(_) => {
quote! {
#( #cfg_attrs )*
#field_name: response_body.0
}
}
// This field must be instantiated last to avoid `use of move value` error. // This field must be instantiated last to avoid `use of move value` error.
// We are guaranteed only one new body field because of a check in // We are guaranteed only one new body field because of a check in
// `parse_response`. // `parse_response`.

View File

@ -1,7 +1,7 @@
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::quote; use quote::quote;
use super::Response; use super::{Response, ResponseField};
impl Response { impl Response {
pub fn expand_outgoing(&self, ruma_api: &TokenStream) -> TokenStream { pub fn expand_outgoing(&self, ruma_api: &TokenStream) -> TokenStream {
@ -37,15 +37,12 @@ impl Response {
}) })
}); });
let body = if let Some(field) = self.raw_body_field() { let body = if let Some(field) =
self.fields.iter().find_map(ResponseField::as_raw_body_field)
{
let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! { #ruma_serde::slice_to_buf(&self.#field_name) } quote! { #ruma_serde::slice_to_buf(&self.#field_name) }
} else if let Some(field) = self.newtype_body_field() { } else if self.has_body_fields() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#ruma_serde::json_to_buf(&self.#field_name)?
}
} else {
let fields = self.fields.iter().filter_map(|response_field| { let fields = self.fields.iter().filter_map(|response_field| {
response_field.as_body_field().map(|field| { response_field.as_body_field().map(|field| {
let field_name = let field_name =
@ -62,6 +59,8 @@ impl Response {
quote! { quote! {
#ruma_serde::json_to_buf(&ResponseBody { #(#fields)* })? #ruma_serde::json_to_buf(&ResponseBody { #(#fields)* })?
} }
} else {
quote! { <T as ::std::default::Default>::default() }
}; };
quote! { quote! {