diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index 49196e73..80eb34c1 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -274,26 +274,55 @@ impl ToTokens for Api { TokenStream::new() }; + let extract_request_body = + if self.request.has_body_fields() || self.request.newtype_body_field().is_some() { + quote! { + let request_body: ::Incoming = + ruma_api::exports::serde_json::from_slice(request.body().as_slice())?; + } + } else { + TokenStream::new() + }; + let parse_request_headers = if self.request.has_header_fields() { self.request.parse_headers_from_request() } else { TokenStream::new() }; - let request_body_initializers = if let Some(field) = self.request.newtype_body_field() { + let request_body = if let Some(field) = self.request.newtype_raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { (request.#field_name) } + quote!(request.#field_name) + } else if self.request.has_body_fields() || self.request.newtype_body_field().is_some() { + let request_body_initializers = if let Some(field) = self.request.newtype_body_field() { + let field_name = + field.ident.as_ref().expect("expected field to have an identifier"); + quote! { (request.#field_name) } + } else { + let initializers = self.request.request_body_init_fields(); + quote! { { #initializers } } + }; + + quote! { + { + let request_body = RequestBody #request_body_initializers; + ruma_api::exports::serde_json::to_vec(&request_body)? + } + } } else { - let initializers = self.request.request_body_init_fields(); - quote! { { #initializers } } + quote!(Vec::new()) }; let parse_request_body = if let Some(field) = self.request.newtype_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { #field_name: request_body.0, } + } else if let Some(field) = self.request.newtype_raw_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { + #field_name: request.into_body(), + } } else { self.request.request_init_body_fields() }; @@ -306,6 +335,16 @@ impl ToTokens for Api { TokenStream::new() }; + let typed_response_body_decl = + if self.response.has_body_fields() || self.response.newtype_body_field().is_some() { + quote! { + let response_body: ::Incoming = + ruma_api::exports::serde_json::from_slice(response_body.as_slice())?; + } + } else { + TokenStream::new() + }; + let response_init_fields = self.response.init_fields(); let serialize_response_headers = self.response.apply_header_fields(); @@ -337,9 +376,7 @@ impl ToTokens for Api { #extract_request_path #extract_request_query #extract_request_headers - - let request_body: ::Incoming = - ruma_api::exports::serde_json::from_slice(request.body().as_slice())?; + #extract_request_body Ok(Self { #parse_request_path @@ -367,11 +404,7 @@ impl ToTokens for Api { { #url_set_path } { #url_set_querystring } - let request_body = RequestBody #request_body_initializers; - - let mut http_request = ruma_api::exports::http::Request::new( - ruma_api::exports::serde_json::to_vec(&request_body)?, - ); + let mut http_request = ruma_api::exports::http::Request::new(#request_body); *http_request.method_mut() = ruma_api::exports::http::Method::#method; *http_request.uri_mut() = url.into_string().parse().unwrap(); @@ -393,7 +426,7 @@ impl ToTokens for Api { let response = ruma_api::exports::http::Response::builder() .header(ruma_api::exports::http::header::CONTENT_TYPE, "application/json") #serialize_response_headers - .body(ruma_api::exports::serde_json::to_vec(&#body)?) + .body(#body) .unwrap(); Ok(response) } @@ -409,10 +442,8 @@ impl ToTokens for Api { if http_response.status().is_success() { #extract_response_headers - let response_body: ::Incoming = - ruma_api::exports::serde_json::from_slice( - http_response.into_body().as_slice(), - )?; + let response_body = http_response.into_body(); + #typed_response_body_decl Ok(Self { #response_init_fields diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index ad741f89..d4c02d06 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -118,6 +118,11 @@ impl Request { self.fields.iter().find_map(RequestField::as_newtype_body_field) } + /// Returns the body field. + pub fn newtype_raw_body_field(&self) -> Option<&Field> { + self.fields.iter().find_map(RequestField::as_newtype_raw_body_field) + } + /// Returns the query map field. pub fn query_map_field(&self) -> Option<&Field> { self.fields.iter().find_map(RequestField::as_query_map_field) @@ -205,7 +210,7 @@ impl TryFrom for Request { field_kind = Some(match meta { Meta::Word(ident) => { match &ident.to_string()[..] { - "body" => { + s @ "body" | s @ "raw_body" => { if let Some(f) = &newtype_body_field { let mut error = syn::Error::new_spanned( field, @@ -219,7 +224,11 @@ impl TryFrom for Request { } newtype_body_field = Some(field.clone()); - RequestFieldKind::NewtypeBody + match s { + "body" => RequestFieldKind::NewtypeBody, + "raw_body" => RequestFieldKind::NewtypeRawBody, + _ => unreachable!(), + } } "path" => RequestFieldKind::Path, "query" => RequestFieldKind::Query, @@ -299,7 +308,7 @@ impl ToTokens for Request { quote! { { #(#fields),* } } }; - let (derive_deserialize, request_body_def) = + let request_body_struct = if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) { let field = Field { ident: None, colon_token: None, ..body_field.field().clone() }; let derive_deserialize = if body_field.has_wrap_incoming_attr() { @@ -308,7 +317,7 @@ impl ToTokens for Request { quote!(ruma_api::exports::serde::Deserialize) }; - (derive_deserialize, quote! { (#field); }) + Some((derive_deserialize, quote! { (#field); })) } else if self.has_body_fields() { let fields = self.fields.iter().filter(|f| f.is_body()); let derive_deserialize = if fields.clone().any(|f| f.has_wrap_incoming_attr()) { @@ -318,10 +327,22 @@ impl ToTokens for Request { }; let fields = fields.map(RequestField::field); - (derive_deserialize, quote! { { #(#fields),* } }) + Some((derive_deserialize, quote! { { #(#fields),* } })) } else { - (quote!(ruma_api::exports::serde::Deserialize), quote!(;)) - }; + None + } + .map(|(derive_deserialize, def)| { + quote! { + /// Data in the request body. + #[derive( + Debug, + ruma_api::Outgoing, + ruma_api::exports::serde::Serialize, + #derive_deserialize + )] + struct RequestBody #def + } + }); let request_path_struct = if self.has_path_fields() { let fields = self.fields.iter().filter_map(RequestField::as_path_field); @@ -376,15 +397,7 @@ impl ToTokens for Request { #[incoming_no_deserialize] pub struct Request #request_def - /// Data in the request body. - #[derive( - Debug, - ruma_api::Outgoing, - ruma_api::exports::serde::Serialize, - #derive_deserialize - )] - struct RequestBody #request_body_def - + #request_body_struct #request_path_struct #request_query_struct }; @@ -401,6 +414,8 @@ pub enum RequestField { Header(Field, Ident), /// A specific data type in the body of the request. NewtypeBody(Field), + /// Arbitrary bytes in the body of the request. + NewtypeRawBody(Field), /// Data that appears in the URL path. Path(Field), /// Data that appears in the query string. @@ -418,6 +433,7 @@ impl RequestField { RequestField::Header(field, header.expect("missing header name")) } RequestFieldKind::NewtypeBody => RequestField::NewtypeBody(field), + RequestFieldKind::NewtypeRawBody => RequestField::NewtypeRawBody(field), RequestFieldKind::Path => RequestField::Path(field), RequestFieldKind::Query => RequestField::Query(field), RequestFieldKind::QueryMap => RequestField::QueryMap(field), @@ -430,6 +446,7 @@ impl RequestField { RequestField::Body(..) => RequestFieldKind::Body, RequestField::Header(..) => RequestFieldKind::Header, RequestField::NewtypeBody(..) => RequestFieldKind::NewtypeBody, + RequestField::NewtypeRawBody(..) => RequestFieldKind::NewtypeRawBody, RequestField::Path(..) => RequestFieldKind::Path, RequestField::Query(..) => RequestFieldKind::Query, RequestField::QueryMap(..) => RequestFieldKind::QueryMap, @@ -471,6 +488,11 @@ impl RequestField { self.field_of_kind(RequestFieldKind::NewtypeBody) } + /// Return the contained field if this request field is a raw body kind. + fn as_newtype_raw_body_field(&self) -> Option<&Field> { + self.field_of_kind(RequestFieldKind::NewtypeRawBody) + } + /// Return the contained field if this request field is a path kind. fn as_path_field(&self) -> Option<&Field> { self.field_of_kind(RequestFieldKind::Path) @@ -492,6 +514,7 @@ impl RequestField { RequestField::Body(field) | RequestField::Header(field, _) | RequestField::NewtypeBody(field) + | RequestField::NewtypeRawBody(field) | RequestField::Path(field) | RequestField::Query(field) | RequestField::QueryMap(field) => field, @@ -525,6 +548,8 @@ enum RequestFieldKind { /// See the similarly named variant of `RequestField`. NewtypeBody, /// See the similarly named variant of `RequestField`. + NewtypeRawBody, + /// See the similarly named variant of `RequestField`. Path, /// See the similarly named variant of `RequestField`. Query, diff --git a/ruma-api-macros/src/api/response.rs b/ruma-api-macros/src/api/response.rs index 7753594a..c289b2a9 100644 --- a/ruma-api-macros/src/api/response.rs +++ b/ruma-api-macros/src/api/response.rs @@ -60,6 +60,11 @@ impl Response { #field_name: response_body.0 } } + ResponseField::NewtypeRawBody(_) => { + quote_spanned! {span=> + #field_name: response_body + } + } } }); @@ -89,7 +94,17 @@ impl Response { /// Produces code to initialize the struct that will be used to create the response body. pub fn to_body(&self) -> TokenStream { - if let Some(field) = self.newtype_body_field() { + if let Some(field) = self.newtype_raw_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + let span = field.span(); + return quote_spanned!(span=> response.#field_name); + } + + if !self.has_body_fields() && self.newtype_body_field().is_none() { + return quote!(Vec::new()); + } + + let body = if let Some(field) = self.newtype_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let span = field.span(); quote_spanned!(span=> response.#field_name) @@ -111,13 +126,20 @@ impl Response { quote! { ResponseBody { #(#fields),* } } - } + }; + + quote!(ruma_api::exports::serde_json::to_vec(&#body)?) } /// Gets the newtype body field, if this response has one. pub fn newtype_body_field(&self) -> Option<&Field> { self.fields.iter().find_map(ResponseField::as_newtype_body_field) } + + /// Gets the newtype raw body field, if this response has one. + pub fn newtype_raw_body_field(&self) -> Option<&Field> { + self.fields.iter().find_map(ResponseField::as_newtype_raw_body_field) + } } impl TryFrom for Response { @@ -150,29 +172,34 @@ impl TryFrom for Response { } field_kind = Some(match meta { - Meta::Word(ident) => { - if ident != "body" { + Meta::Word(ident) => match &ident.to_string()[..] { + s @ "body" | s @ "raw_body" => { + if let Some(f) = &newtype_body_field { + let mut error = syn::Error::new_spanned( + field, + "There can only be one newtype body field", + ); + error.combine(syn::Error::new_spanned( + f, + "Previous newtype body field", + )); + return Err(error); + } + + newtype_body_field = Some(field.clone()); + match s { + "body" => ResponseFieldKind::NewtypeBody, + "raw_body" => ResponseFieldKind::NewtypeRawBody, + _ => unreachable!(), + } + } + _ => { return Err(syn::Error::new_spanned( ident, "Invalid #[ruma_api] argument with value, expected `body`", )); } - - if let Some(f) = &newtype_body_field { - let mut error = syn::Error::new_spanned( - field, - "There can only be one newtype body field", - ); - error.combine(syn::Error::new_spanned( - f, - "Previous newtype body field", - )); - return Err(error); - } - - newtype_body_field = Some(field.clone()); - ResponseFieldKind::NewtypeBody - } + }, Meta::NameValue(MetaNameValue { name, value }) => { if name != "header" { return Err(syn::Error::new_spanned( @@ -193,6 +220,7 @@ impl TryFrom for Response { ResponseField::Header(field, header.expect("missing header name")) } ResponseFieldKind::NewtypeBody => ResponseField::NewtypeBody(field), + ResponseFieldKind::NewtypeRawBody => ResponseField::NewtypeRawBody(field), }) }) .collect::>>()?; @@ -220,7 +248,7 @@ impl ToTokens for Response { quote! { { #(#fields),* } } }; - let (derive_deserialize, response_body_def) = + let response_body_struct = if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) { let field = Field { ident: None, colon_token: None, ..body_field.field().clone() }; let derive_deserialize = if body_field.has_wrap_incoming_attr() { @@ -229,7 +257,7 @@ impl ToTokens for Response { quote!(ruma_api::exports::serde::Deserialize) }; - (derive_deserialize, quote! { (#field); }) + Some((derive_deserialize, quote! { (#field); })) } else if self.has_body_fields() { let fields = self.fields.iter().filter(|f| f.is_body()); let derive_deserialize = if fields.clone().any(|f| f.has_wrap_incoming_attr()) { @@ -239,24 +267,29 @@ impl ToTokens for Response { }; let fields = fields.map(ResponseField::field); - (derive_deserialize, quote!({ #(#fields),* })) + Some((derive_deserialize, quote!({ #(#fields),* }))) } else { - (quote!(ruma_api::exports::serde::Deserialize), quote!(;)) - }; + None + } + .map(|(derive_deserialize, def)| { + quote! { + /// Data in the response body. + #[derive( + Debug, + ruma_api::Outgoing, + ruma_api::exports::serde::Serialize, + #derive_deserialize + )] + struct ResponseBody #def + } + }); let response = quote! { #[derive(Debug, Clone, ruma_api::Outgoing)] #[incoming_no_deserialize] pub struct Response #response_def - /// Data in the response body. - #[derive( - Debug, - ruma_api::Outgoing, - ruma_api::exports::serde::Serialize, - #derive_deserialize - )] - struct ResponseBody #response_body_def + #response_body_struct }; response.to_tokens(tokens); @@ -271,6 +304,8 @@ pub enum ResponseField { Header(Field, Ident), /// A specific data type in the body of the response. NewtypeBody(Field), + /// Arbitrary bytes in the body of the response. + NewtypeRawBody(Field), } impl ResponseField { @@ -279,7 +314,8 @@ impl ResponseField { match self { ResponseField::Body(field) | ResponseField::Header(field, _) - | ResponseField::NewtypeBody(field) => field, + | ResponseField::NewtypeBody(field) + | ResponseField::NewtypeRawBody(field) => field, } } @@ -317,6 +353,14 @@ impl ResponseField { } } + /// Return the contained field if this response field is a newtype raw body kind. + fn as_newtype_raw_body_field(&self) -> Option<&Field> { + match self { + ResponseField::NewtypeRawBody(field) => Some(field), + _ => None, + } + } + /// Whether or not the reponse field has a #[wrap_incoming] attribute. fn has_wrap_incoming_attr(&self) -> bool { self.field().attrs.iter().any(|attr| { @@ -333,4 +377,6 @@ enum ResponseFieldKind { Header, /// See the similarly named variant of `ResponseField`. NewtypeBody, + /// See the similarly named variant of `ResponseField`. + NewtypeRawBody, } diff --git a/src/lib.rs b/src/lib.rs index a669e55c..b387b826 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -198,11 +198,9 @@ use serde_urlencoded; /// /// ## Fallible deserialization /// -/// All request and response types also derive [`Outgoing`][]. As such, to allow fallible +/// All request and response types also derive [`Outgoing`][Outgoing]. As such, to allow fallible /// deserialization, you can use the `#[wrap_incoming]` attribute. For details, see the -/// documentation for [`Outgoing`][]. -/// -/// [`Outgoing`]: derive.Outgoing.html +/// documentation for [the derive macro](derive.Outgoing.html). // TODO: Explain the concept of fallible deserialization before jumping to `ruma_api::Outgoing` #[cfg(feature = "with-ruma-api-macros")] pub use ruma_api_macros::ruma_api; diff --git a/tests/ruma_api_macros.rs b/tests/ruma_api_macros.rs index e18c33ff..96d983f3 100644 --- a/tests/ruma_api_macros.rs +++ b/tests/ruma_api_macros.rs @@ -86,12 +86,42 @@ pub mod newtype_body_endpoint { request { #[ruma_api(body)] - pub file: Vec, + pub list_of_custom_things: Vec, } response { #[ruma_api(body)] - pub my_custom_type: MyCustomType, + pub my_custom_thing: MyCustomType, + } + } +} + +pub mod newtype_raw_body_endpoint { + use ruma_api_macros::ruma_api; + + #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] + pub struct MyCustomType { + pub foo: String, + } + + ruma_api! { + metadata { + description: "Does something.", + method: PUT, + name: "newtype_body_endpoint", + path: "/_matrix/some/newtype/body/endpoint", + rate_limited: false, + requires_authentication: false, + } + + request { + #[ruma_api(raw_body)] + pub file: Vec, + } + + response { + #[ruma_api(raw_body)] + pub file: Vec, } } }