From 7cb7e6a211fce7b6cb136a9e4c8c294d5e6e0ed2 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Mon, 5 Apr 2021 14:05:43 +0200 Subject: [PATCH] api-macros: Move most parts of api and util into more specific modules --- ruma-api-macros/src/api.rs | 301 +------------------ ruma-api-macros/src/api/request.rs | 432 +++++++++++++++++++++++++++- ruma-api-macros/src/api/response.rs | 129 ++++++++- ruma-api-macros/src/util.rs | 215 +------------- 4 files changed, 543 insertions(+), 534 deletions(-) diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index 3248a0f4..2034743f 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -29,18 +29,12 @@ pub struct Api { } pub fn expand_all(api: Api) -> syn::Result { - // Guarantee `ruma_api` is available and named something we can refer to. let ruma_api = util::import_ruma_api(); let http = quote! { #ruma_api::exports::http }; - let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; - let serde_json = quote! { #ruma_api::exports::serde_json }; let description = &api.metadata.description; let method = &api.metadata.method; - // We don't (currently) use this literal as a literal in the generated code. Instead we just - // put it into doc comments, for which the span information is irrelevant. So we can work - // with only the literal's value from here on. - let name = &api.metadata.name.value(); + let name = &api.metadata.name; let path = &api.metadata.path; let rate_limited: TokenStream = api .metadata @@ -69,235 +63,15 @@ pub fn expand_all(api: Api) -> syn::Result { }) .collect(); - let request_type = api.request.expand_type_def(&ruma_api); - let response_type = api.response.expand_type_def(&ruma_api); - - let incoming_request_type = - if api.request.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) }; - - let extract_request_path = if api.request.has_path_fields() { - quote! { - let path_segments: ::std::vec::Vec<&::std::primitive::str> = - request.uri().path()[1..].split('/').collect(); - } - } else { - TokenStream::new() - }; - - let (request_path_string, parse_request_path) = - util::request_path_string_and_parse(&api.request, &api.metadata, &ruma_api); - - let request_query_string = util::build_query_string(&api.request, &ruma_api); - let extract_request_query = util::extract_request_query(&api.request, &ruma_api); - - let parse_request_query = if let Some(field) = api.request.query_map_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - - quote! { - #field_name: request_query, - } - } else { - api.request.request_init_query_fields() - }; - - let mut header_kvs = api.request.append_header_kvs(&ruma_api); - for auth in &api.metadata.authentication { - if auth.value == "AccessToken" { - let attrs = &auth.attrs; - header_kvs.extend(quote! { - #( #attrs )* - req_headers.insert( - #http::header::AUTHORIZATION, - #http::header::HeaderValue::from_str( - &::std::format!( - "Bearer {}", - access_token.ok_or( - #ruma_api::error::IntoHttpError::NeedsAuthentication - )? - ) - )? - ); - }); - } - } - - let extract_request_headers = if api.request.has_header_fields() { - quote! { - let headers = request.headers(); - } - } else { - TokenStream::new() - }; - - let extract_request_body = - if api.request.has_body_fields() || api.request.newtype_body_field().is_some() { - let body_lifetimes = if api.request.has_body_lifetimes() { - // duplicate the anonymous lifetime as many times as needed - let lifetimes = - std::iter::repeat(quote! { '_ }).take(api.request.body_lifetime_count()); - quote! { < #( #lifetimes, )* >} - } else { - TokenStream::new() - }; - quote! { - let request_body: < - RequestBody #body_lifetimes - as #ruma_serde::Outgoing - >::Incoming = { - // If the request body is completely empty, pretend it is an empty JSON object - // instead. This allows requests with only optional body parameters to be - // deserialized in that case. - let json = match request.body().as_slice() { - b"" => b"{}", - body => body, - }; - - #ruma_api::try_deserialize!(request, #serde_json::from_slice(json)) - }; - } - } else { - TokenStream::new() - }; - - let parse_request_headers = if api.request.has_header_fields() { - api.request.parse_headers_from_request(&ruma_api) - } else { - TokenStream::new() - }; - - let request_body = util::build_request_body(&api.request, &ruma_api); - let parse_request_body = util::parse_request_body(&api.request); - - let extract_response_headers = if api.response.has_header_fields() { - quote! { - let mut headers = response.headers().clone(); - } - } else { - TokenStream::new() - }; - - let typed_response_body_decl = - if api.response.has_body_fields() || api.response.newtype_body_field().is_some() { - quote! { - let response_body: < - ResponseBody - as #ruma_serde::Outgoing - >::Incoming = { - // If the reponse body is completely empty, pretend it is an empty JSON object - // instead. This allows reponses with only optional body parameters to be - // deserialized in that case. - let json = match response.body().as_slice() { - b"" => b"{}", - body => body, - }; - - #ruma_api::try_deserialize!( - response, - #serde_json::from_slice(json), - ) - }; - } - } else { - TokenStream::new() - }; - - let response_init_fields = api.response.init_fields(&ruma_api); - let serialize_response_headers = api.response.apply_header_fields(&ruma_api); - - let body = api.response.to_body(&ruma_api); - - let metadata_doc = format!("Metadata for the `{}` API endpoint.", name); - let request_doc = - format!("Data for a request to the `{}` API endpoint.\n\n{}", name, description.value()); - let response_doc = format!("Data in the response from the `{}` API endpoint.", name); - let error_ty = api.error_ty.map_or_else(|| quote! { #ruma_api::error::Void }, |err_ty| quote! { #err_ty }); - let request_lifetimes = api.request.combine_lifetimes(); + let request = api.request.expand(&api.metadata, &error_ty, &ruma_api); + let response = api.response.expand(&api.metadata, &error_ty, &ruma_api); - let non_auth_endpoint_impls: TokenStream = api - .metadata - .authentication - .iter() - .map(|auth| { - if auth.value != "None" { - TokenStream::new() - } else { - let attrs = &auth.attrs; - quote! { - #( #attrs )* - #[automatically_derived] - #[cfg(feature = "client")] - impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest - for Request #request_lifetimes - {} - - #( #attrs )* - #[automatically_derived] - #[cfg(feature = "server")] - impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {} - } - } - }) - .collect(); + let metadata_doc = format!("Metadata for the `{}` API endpoint.", name.value()); Ok(quote! { - #[doc = #request_doc] - #request_type - - #[doc = #response_doc] - #response_type - - #[automatically_derived] - #[cfg(feature = "server")] - impl ::std::convert::TryFrom for #http::Response> { - type Error = #ruma_api::error::IntoHttpError; - - fn try_from(response: Response) -> ::std::result::Result { - let mut resp_builder = #http::Response::builder() - .header(#http::header::CONTENT_TYPE, "application/json"); - - let mut headers = resp_builder - .headers_mut() - .expect("`http::ResponseBuilder` is in unusable state"); - #serialize_response_headers - - // This cannot fail because we parse each header value - // checking for errors as each value is inserted and - // we only allow keys from the `http::header` module. - let response = resp_builder.body(#body).unwrap(); - Ok(response) - } - } - - #[automatically_derived] - #[cfg(feature = "client")] - impl ::std::convert::TryFrom<#http::Response>> for Response { - type Error = #ruma_api::error::FromHttpResponseError<#error_ty>; - - fn try_from( - response: #http::Response>, - ) -> ::std::result::Result { - if response.status().as_u16() < 400 { - #extract_response_headers - - #typed_response_body_decl - - Ok(Self { - #response_init_fields - }) - } else { - match <#error_ty as #ruma_api::EndpointError>::try_from_response(response) { - Ok(err) => Err(#ruma_api::error::ServerError::Known(err).into()), - Err(response_err) => { - Err(#ruma_api::error::ServerError::Unknown(response_err).into()) - } - } - } - } - } - #[doc = #metadata_doc] pub const METADATA: #ruma_api::Metadata = #ruma_api::Metadata { description: #description, @@ -308,70 +82,7 @@ pub fn expand_all(api: Api) -> syn::Result { #authentication }; - #[automatically_derived] - #[cfg(feature = "client")] - impl #request_lifetimes #ruma_api::OutgoingRequest for Request #request_lifetimes { - type EndpointError = #error_ty; - type IncomingResponse = ::Incoming; - - #[doc = #metadata_doc] - const METADATA: #ruma_api::Metadata = self::METADATA; - - fn try_into_http_request( - self, - base_url: &::std::primitive::str, - access_token: ::std::option::Option<&str>, - ) -> ::std::result::Result<#http::Request>, #ruma_api::error::IntoHttpError> { - let metadata = self::METADATA; - - let mut req_builder = #http::Request::builder() - .method(#http::Method::#method) - .uri(::std::format!( - "{}{}{}", - base_url.strip_suffix('/').unwrap_or(base_url), - #request_path_string, - #request_query_string, - )) - .header(#ruma_api::exports::http::header::CONTENT_TYPE, "application/json"); - - let mut req_headers = req_builder - .headers_mut() - .expect("`http::RequestBuilder` is in unusable state"); - - #header_kvs - - let http_request = req_builder.body(#request_body)?; - - Ok(http_request) - } - } - - #[automatically_derived] - #[cfg(feature = "server")] - impl #ruma_api::IncomingRequest for #incoming_request_type { - type EndpointError = #error_ty; - type OutgoingResponse = Response; - - #[doc = #metadata_doc] - const METADATA: #ruma_api::Metadata = self::METADATA; - - fn try_from_http_request( - request: #http::Request> - ) -> ::std::result::Result { - #extract_request_path - #extract_request_query - #extract_request_headers - #extract_request_body - - Ok(Self { - #parse_request_path - #parse_request_query - #parse_request_headers - #parse_request_body - }) - } - } - - #non_auth_endpoint_impls + #request + #response }) } diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index ea8a634f..039a8c10 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -2,12 +2,14 @@ use std::collections::BTreeSet; -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; use syn::{spanned::Spanned, Attribute, Field, Ident, Lifetime}; use crate::util; +use super::metadata::Metadata; + #[derive(Debug, Default)] pub(super) struct RequestLifetimes { pub body: BTreeSet, @@ -30,7 +32,7 @@ pub(crate) struct Request { impl Request { /// Produces code to add necessary HTTP headers to an `http::Request`. - pub fn append_header_kvs(&self, ruma_api: &TokenStream) -> TokenStream { + fn append_header_kvs(&self, ruma_api: &TokenStream) -> TokenStream { let http = quote! { #ruma_api::exports::http }; self.header_fields() @@ -67,7 +69,7 @@ impl Request { } /// Produces code to extract fields from the HTTP headers in an `http::Request`. - pub fn parse_headers_from_request(&self, ruma_api: &TokenStream) -> TokenStream { + fn parse_headers_from_request(&self, ruma_api: &TokenStream) -> TokenStream { let http = quote! { #ruma_api::exports::http }; let serde = quote! { #ruma_api::exports::serde }; let serde_json = quote! { #ruma_api::exports::serde_json }; @@ -282,10 +284,24 @@ impl Request { quote! { #(#fields,)* } } - pub(super) fn expand_type_def(&self, ruma_api: &TokenStream) -> TokenStream { + pub(super) fn expand( + &self, + metadata: &Metadata, + error_ty: &TokenStream, + ruma_api: &TokenStream, + ) -> TokenStream { + let http = quote! { #ruma_api::exports::http }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let serde = quote! { #ruma_api::exports::serde }; + let serde_json = quote! { #ruma_api::exports::serde_json }; + let method = &metadata.method; + + let docs = format!( + "Data for a request to the `{}` API endpoint.\n\n{}", + metadata.name.value(), + metadata.description.value(), + ); let struct_attributes = &self.attributes; let request_def = if self.fields.is_empty() { @@ -295,6 +311,101 @@ impl Request { quote! { { #(#fields),* } } }; + let incoming_request_type = + if self.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) }; + + let extract_request_path = if self.has_path_fields() { + quote! { + let path_segments: ::std::vec::Vec<&::std::primitive::str> = + request.uri().path()[1..].split('/').collect(); + } + } else { + TokenStream::new() + }; + + let (request_path_string, parse_request_path) = + path_string_and_parse(self, metadata, &ruma_api); + + let request_query_string = build_query_string(self, &ruma_api); + let extract_request_query = extract_request_query(self, &ruma_api); + + let parse_request_query = if let Some(field) = self.query_map_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + + quote! { + #field_name: request_query, + } + } else { + self.request_init_query_fields() + }; + + let mut header_kvs = self.append_header_kvs(&ruma_api); + for auth in &metadata.authentication { + if auth.value == "AccessToken" { + let attrs = &auth.attrs; + header_kvs.extend(quote! { + #( #attrs )* + req_headers.insert( + #http::header::AUTHORIZATION, + #http::header::HeaderValue::from_str( + &::std::format!( + "Bearer {}", + access_token.ok_or( + #ruma_api::error::IntoHttpError::NeedsAuthentication + )? + ) + )? + ); + }); + } + } + + let extract_request_headers = if self.has_header_fields() { + quote! { + let headers = request.headers(); + } + } else { + TokenStream::new() + }; + + let extract_request_body = if self.has_body_fields() || self.newtype_body_field().is_some() + { + let body_lifetimes = if self.has_body_lifetimes() { + // duplicate the anonymous lifetime as many times as needed + let lifetimes = std::iter::repeat(quote! { '_ }).take(self.body_lifetime_count()); + quote! { < #( #lifetimes, )* >} + } else { + TokenStream::new() + }; + quote! { + let request_body: < + RequestBody #body_lifetimes + as #ruma_serde::Outgoing + >::Incoming = { + // If the request body is completely empty, pretend it is an empty JSON object + // instead. This allows requests with only optional body parameters to be + // deserialized in that case. + let json = match request.body().as_slice() { + b"" => b"{}", + body => body, + }; + + #ruma_api::try_deserialize!(request, #serde_json::from_slice(json)) + }; + } + } else { + TokenStream::new() + }; + + let parse_request_headers = if self.has_header_fields() { + self.parse_headers_from_request(&ruma_api) + } else { + TokenStream::new() + }; + + let request_body = build_request_body(self, &ruma_api); + let parse_request_body = parse_request_body(self); + let request_generics = self.combine_lifetimes(); let request_body_struct = @@ -379,17 +490,107 @@ impl Request { TokenStream::new() }; + let request_lifetimes = self.combine_lifetimes(); + let non_auth_endpoint_impls: TokenStream = metadata + .authentication + .iter() + .map(|auth| { + if auth.value != "None" { + TokenStream::new() + } else { + let attrs = &auth.attrs; + quote! { + #( #attrs )* + #[automatically_derived] + #[cfg(feature = "client")] + impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest + for Request #request_lifetimes + {} + + #( #attrs )* + #[automatically_derived] + #[cfg(feature = "server")] + impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {} + } + } + }) + .collect(); + quote! { - #[derive(Debug, Clone, #ruma_serde::Outgoing, #ruma_serde::_FakeDeriveSerde)] - #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] - #[incoming_derive(!Deserialize)] - #( #struct_attributes )* - pub struct Request #request_generics #request_def + #[doc = #docs] + #[derive(Debug, Clone, #ruma_serde::Outgoing, #ruma_serde::_FakeDeriveSerde)] + #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] + #[incoming_derive(!Deserialize)] + #( #struct_attributes )* + pub struct Request #request_generics #request_def - #request_body_struct + #non_auth_endpoint_impls - #request_query_struct + #request_body_struct + #request_query_struct + + #[automatically_derived] + #[cfg(feature = "client")] + impl #request_lifetimes #ruma_api::OutgoingRequest for Request #request_lifetimes { + type EndpointError = #error_ty; + type IncomingResponse = ::Incoming; + + const METADATA: #ruma_api::Metadata = self::METADATA; + + fn try_into_http_request( + self, + base_url: &::std::primitive::str, + access_token: ::std::option::Option<&str>, + ) -> ::std::result::Result<#http::Request>, #ruma_api::error::IntoHttpError> { + let metadata = self::METADATA; + + let mut req_builder = #http::Request::builder() + .method(#http::Method::#method) + .uri(::std::format!( + "{}{}{}", + base_url.strip_suffix('/').unwrap_or(base_url), + #request_path_string, + #request_query_string, + )) + .header(#ruma_api::exports::http::header::CONTENT_TYPE, "application/json"); + + let mut req_headers = req_builder + .headers_mut() + .expect("`http::RequestBuilder` is in unusable state"); + + #header_kvs + + let http_request = req_builder.body(#request_body)?; + + Ok(http_request) + } + } + + #[automatically_derived] + #[cfg(feature = "server")] + impl #ruma_api::IncomingRequest for #incoming_request_type { + type EndpointError = #error_ty; + type OutgoingResponse = Response; + + const METADATA: #ruma_api::Metadata = self::METADATA; + + fn try_from_http_request( + request: #http::Request> + ) -> ::std::result::Result { + #extract_request_path + #extract_request_query + #extract_request_headers + #extract_request_body + + Ok(Self { + #parse_request_path + #parse_request_query + #parse_request_headers + #parse_request_body + }) } + } + } } } @@ -543,3 +744,212 @@ pub(crate) enum RequestFieldKind { /// See the similarly named variant of `RequestField`. QueryMap, } + +/// The first item in the tuple generates code for the request path from +/// the `Metadata` and `Request` structs. The second item in the returned tuple +/// is the code to generate a Request struct field created from any segments +/// of the path that start with ":". +/// +/// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is +/// used for implementing `TryFrom>>`, from path strings deserialized to Ruma +/// types. +pub(crate) fn path_string_and_parse( + request: &Request, + metadata: &Metadata, + ruma_api: &TokenStream, +) -> (TokenStream, TokenStream) { + let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; + + if request.has_path_fields() { + let path_string = metadata.path.value(); + + assert!(path_string.starts_with('/'), "path needs to start with '/'"); + assert!( + path_string.chars().filter(|c| *c == ':').count() == request.path_field_count(), + "number of declared path parameters needs to match amount of placeholders in path" + ); + + let format_call = { + let mut format_string = path_string.clone(); + let mut format_args = Vec::new(); + + while let Some(start_of_segment) = format_string.find(':') { + // ':' should only ever appear at the start of a segment + assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/"); + + let end_of_segment = match format_string[start_of_segment..].find('/') { + Some(rel_pos) => start_of_segment + rel_pos, + None => format_string.len(), + }; + + let path_var = Ident::new( + &format_string[start_of_segment + 1..end_of_segment], + Span::call_site(), + ); + format_args.push(quote! { + #percent_encoding::utf8_percent_encode( + &self.#path_var.to_string(), + #percent_encoding::NON_ALPHANUMERIC, + ) + }); + format_string.replace_range(start_of_segment..end_of_segment, "{}"); + } + + quote! { + format_args!(#format_string, #(#format_args),*) + } + }; + + let path_fields = + path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map( + |(i, segment)| { + let path_var = &segment[1..]; + let path_var_ident = Ident::new(path_var, Span::call_site()); + quote! { + #path_var_ident: { + use #ruma_api::error::RequestDeserializationError; + + let segment = path_segments.get(#i).unwrap().as_bytes(); + let decoded = #ruma_api::try_deserialize!( + request, + #percent_encoding::percent_decode(segment) + .decode_utf8(), + ); + + #ruma_api::try_deserialize!( + request, + ::std::convert::TryFrom::try_from(&*decoded), + ) + } + } + }, + ); + + (format_call, quote! { #(#path_fields,)* }) + } else { + (quote! { metadata.path.to_owned() }, TokenStream::new()) + } +} + +/// The function determines the type of query string that needs to be built +/// and then builds it using `ruma_serde::urlencoded::to_string`. +pub(crate) fn build_query_string(request: &Request, ruma_api: &TokenStream) -> TokenStream { + let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; + + if let Some(field) = request.query_map_field() { + let field_name = field.ident.as_ref().expect("expected field to have identifier"); + + quote!({ + // This function exists so that the compiler will throw an + // error when the type of the field with the query_map + // attribute doesn't implement IntoIterator + // + // This is necessary because the ruma_serde::urlencoded::to_string + // call will result in a runtime error when the type cannot be + // encoded as a list key-value pairs (?key1=value1&key2=value2) + // + // By asserting that it implements the iterator trait, we can + // ensure that it won't fail. + fn assert_trait_impl(_: &T) + where + T: ::std::iter::IntoIterator, + {} + + let request_query = RequestQuery(self.#field_name); + assert_trait_impl(&request_query.0); + + format_args!( + "?{}", + #ruma_serde::urlencoded::to_string(request_query)? + ) + }) + } else if request.has_query_fields() { + let request_query_init_fields = request.request_query_init_fields(); + + quote!({ + let request_query = RequestQuery { + #request_query_init_fields + }; + + format_args!( + "?{}", + #ruma_serde::urlencoded::to_string(request_query)? + ) + }) + } else { + quote! { "" } + } +} + +/// Deserialize the query string. +pub(crate) fn extract_request_query(request: &Request, ruma_api: &TokenStream) -> TokenStream { + let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; + + if request.query_map_field().is_some() { + quote! { + let request_query = #ruma_api::try_deserialize!( + request, + #ruma_serde::urlencoded::from_str( + &request.uri().query().unwrap_or("") + ), + ); + } + } else if request.has_query_fields() { + quote! { + let request_query: ::Incoming = + #ruma_api::try_deserialize!( + request, + #ruma_serde::urlencoded::from_str( + &request.uri().query().unwrap_or("") + ), + ); + } + } else { + TokenStream::new() + } +} + +/// Generates the code to initialize a `Request`. +/// +/// Used to construct an `http::Request`s body. +pub(crate) fn build_request_body(request: &Request, ruma_api: &TokenStream) -> TokenStream { + let serde_json = quote! { #ruma_api::exports::serde_json }; + + if let Some(field) = request.newtype_raw_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote!(self.#field_name) + } else if request.has_body_fields() || request.newtype_body_field().is_some() { + let request_body_initializers = if let Some(field) = request.newtype_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { (self.#field_name) } + } else { + let initializers = request.request_body_init_fields(); + quote! { { #initializers } } + }; + + quote! { + { + let request_body = RequestBody #request_body_initializers; + #serde_json::to_vec(&request_body)? + } + } + } else { + quote!(Vec::new()) + } +} + +pub(crate) fn parse_request_body(request: &Request) -> TokenStream { + if let Some(field) = request.newtype_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { + #field_name: request_body.0, + } + } else if let Some(field) = request.newtype_raw_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { + #field_name: request.into_body(), + } + } else { + request.request_init_body_fields() + } +} diff --git a/ruma-api-macros/src/api/response.rs b/ruma-api-macros/src/api/response.rs index b76cf510..5a795ba2 100644 --- a/ruma-api-macros/src/api/response.rs +++ b/ruma-api-macros/src/api/response.rs @@ -4,6 +4,8 @@ use proc_macro2::TokenStream; use quote::{quote, quote_spanned}; use syn::{spanned::Spanned, Attribute, Field, Ident}; +use super::metadata::Metadata; + /// The result of processing the `response` section of the macro. pub(crate) struct Response { /// The attributes that will be applied to the struct definition. @@ -15,17 +17,17 @@ pub(crate) struct Response { impl Response { /// Whether or not this response has any data in the HTTP body. - pub fn has_body_fields(&self) -> bool { + fn has_body_fields(&self) -> bool { self.fields.iter().any(|field| field.is_body()) } /// Whether or not this response has any data in HTTP headers. - pub fn has_header_fields(&self) -> bool { + fn has_header_fields(&self) -> bool { self.fields.iter().any(|field| field.is_header()) } /// Produces code for a response struct initializer. - pub fn init_fields(&self, ruma_api: &TokenStream) -> TokenStream { + fn init_fields(&self, ruma_api: &TokenStream) -> TokenStream { let http = quote! { #ruma_api::exports::http }; let mut fields = vec![]; @@ -95,7 +97,7 @@ impl Response { } /// Produces code to add necessary HTTP headers to an `http::Response`. - pub fn apply_header_fields(&self, ruma_api: &TokenStream) -> TokenStream { + fn apply_header_fields(&self, ruma_api: &TokenStream) -> TokenStream { let http = quote! { #ruma_api::exports::http }; let header_calls = self.fields.iter().filter_map(|response_field| { @@ -139,7 +141,7 @@ impl Response { } /// Produces code to initialize the struct that will be used to create the response body. - pub fn to_body(&self, ruma_api: &TokenStream) -> TokenStream { + fn to_body(&self, ruma_api: &TokenStream) -> TokenStream { let serde_json = quote! { #ruma_api::exports::serde_json }; if let Some(field) = self.newtype_raw_body_field() { @@ -179,21 +181,68 @@ impl Response { } /// Gets the newtype body field, if this response has one. - pub fn newtype_body_field(&self) -> Option<&Field> { + fn newtype_body_field(&self) -> Option<&Field> { self.fields.iter().find_map(ResponseField::as_newtype_body_field) } /// Gets the newtype raw body field, if this response has one. - pub fn newtype_raw_body_field(&self) -> Option<&Field> { + fn newtype_raw_body_field(&self) -> Option<&Field> { self.fields.iter().find_map(ResponseField::as_newtype_raw_body_field) } - pub(super) fn expand_type_def(&self, ruma_api: &TokenStream) -> TokenStream { + pub(super) fn expand( + &self, + metadata: &Metadata, + error_ty: &TokenStream, + ruma_api: &TokenStream, + ) -> TokenStream { + let http = quote! { #ruma_api::exports::http }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let serde = quote! { #ruma_api::exports::serde }; + let serde_json = quote! { #ruma_api::exports::serde_json }; + let docs = + format!("Data in the response from the `{}` API endpoint.", metadata.name.value()); let struct_attributes = &self.attributes; + let extract_response_headers = if self.has_header_fields() { + quote! { + let mut headers = response.headers().clone(); + } + } else { + TokenStream::new() + }; + + let typed_response_body_decl = + if self.has_body_fields() || self.newtype_body_field().is_some() { + quote! { + let response_body: < + ResponseBody + as #ruma_serde::Outgoing + >::Incoming = { + // If the reponse body is completely empty, pretend it is an empty JSON object + // instead. This allows reponses with only optional body parameters to be + // deserialized in that case. + let json = match response.body().as_slice() { + b"" => b"{}", + body => body, + }; + + #ruma_api::try_deserialize!( + response, + #serde_json::from_slice(json), + ) + }; + } + } else { + TokenStream::new() + }; + + let response_init_fields = self.init_fields(&ruma_api); + let serialize_response_headers = self.apply_header_fields(&ruma_api); + + let body = self.to_body(&ruma_api); + let response_def = if self.fields.is_empty() { quote!(;) } else { @@ -222,6 +271,7 @@ impl Response { }; quote! { + #[doc = #docs] #[derive(Debug, Clone, #ruma_serde::Outgoing, #ruma_serde::_FakeDeriveSerde)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] #[incoming_derive(!Deserialize)] @@ -229,6 +279,55 @@ impl Response { pub struct Response #response_def #response_body_struct + + #[automatically_derived] + #[cfg(feature = "server")] + impl ::std::convert::TryFrom for #http::Response> { + type Error = #ruma_api::error::IntoHttpError; + + fn try_from(response: Response) -> ::std::result::Result { + let mut resp_builder = #http::Response::builder() + .header(#http::header::CONTENT_TYPE, "application/json"); + + let mut headers = resp_builder + .headers_mut() + .expect("`http::ResponseBuilder` is in unusable state"); + #serialize_response_headers + + // This cannot fail because we parse each header value + // checking for errors as each value is inserted and + // we only allow keys from the `http::header` module. + let response = resp_builder.body(#body).unwrap(); + Ok(response) + } + } + + #[automatically_derived] + #[cfg(feature = "client")] + impl ::std::convert::TryFrom<#http::Response>> for Response { + type Error = #ruma_api::error::FromHttpResponseError<#error_ty>; + + fn try_from( + response: #http::Response>, + ) -> ::std::result::Result { + if response.status().as_u16() < 400 { + #extract_response_headers + + #typed_response_body_decl + + Ok(Self { + #response_init_fields + }) + } else { + match <#error_ty as #ruma_api::EndpointError>::try_from_response(response) { + Ok(err) => Err(#ruma_api::error::ServerError::Known(err).into()), + Err(response_err) => { + Err(#ruma_api::error::ServerError::Unknown(response_err).into()) + } + } + } + } + } } } } @@ -250,7 +349,7 @@ pub(crate) enum ResponseField { impl ResponseField { /// Gets the inner `Field` value. - pub fn field(&self) -> &Field { + fn field(&self) -> &Field { match self { ResponseField::Body(field) | ResponseField::Header(field, _) @@ -260,22 +359,22 @@ impl ResponseField { } /// Whether or not this response field is a body kind. - pub fn is_body(&self) -> bool { + pub(super) fn is_body(&self) -> bool { self.as_body_field().is_some() } /// Whether or not this response field is a header kind. - pub fn is_header(&self) -> bool { + fn is_header(&self) -> bool { matches!(self, ResponseField::Header(..)) } /// Whether or not this response field is a newtype body kind. - pub fn is_newtype_body(&self) -> bool { + fn is_newtype_body(&self) -> bool { self.as_newtype_body_field().is_some() } /// Return the contained field if this response field is a body kind. - pub fn as_body_field(&self) -> Option<&Field> { + fn as_body_field(&self) -> Option<&Field> { match self { ResponseField::Body(field) => Some(field), _ => None, @@ -283,7 +382,7 @@ impl ResponseField { } /// Return the contained field if this response field is a newtype body kind. - pub fn as_newtype_body_field(&self) -> Option<&Field> { + fn as_newtype_body_field(&self) -> Option<&Field> { match self { ResponseField::NewtypeBody(field) => Some(field), _ => None, @@ -291,7 +390,7 @@ impl ResponseField { } /// Return the contained field if this response field is a newtype raw body kind. - pub fn as_newtype_raw_body_field(&self) -> Option<&Field> { + fn as_newtype_raw_body_field(&self) -> Option<&Field> { match self { ResponseField::NewtypeRawBody(field) => Some(field), _ => None, diff --git a/ruma-api-macros/src/util.rs b/ruma-api-macros/src/util.rs index 00389453..a6beed62 100644 --- a/ruma-api-macros/src/util.rs +++ b/ruma-api-macros/src/util.rs @@ -7,10 +7,8 @@ use proc_macro_crate::{crate_name, FoundCrate}; use quote::quote; use syn::{AttrStyle, Attribute, Ident, Lifetime}; -use crate::api::{metadata::Metadata, request::Request}; - /// Generates a `TokenStream` of lifetime identifiers `<'lifetime>`. -pub fn unique_lifetimes_to_tokens<'a, I: Iterator>( +pub(crate) fn unique_lifetimes_to_tokens<'a, I: Iterator>( lifetimes: I, ) -> TokenStream { let lifetimes = lifetimes.collect::>(); @@ -22,220 +20,11 @@ pub fn unique_lifetimes_to_tokens<'a, I: Iterator>( } } -/// The first item in the tuple generates code for the request path from -/// the `Metadata` and `Request` structs. The second item in the returned tuple -/// is the code to generate a Request struct field created from any segments -/// of the path that start with ":". -/// -/// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is -/// used for implementing `TryFrom>>`, from path strings deserialized to Ruma -/// types. -pub(crate) fn request_path_string_and_parse( - request: &Request, - metadata: &Metadata, - ruma_api: &TokenStream, -) -> (TokenStream, TokenStream) { - let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; - - if request.has_path_fields() { - let path_string = metadata.path.value(); - - assert!(path_string.starts_with('/'), "path needs to start with '/'"); - assert!( - path_string.chars().filter(|c| *c == ':').count() == request.path_field_count(), - "number of declared path parameters needs to match amount of placeholders in path" - ); - - let format_call = { - let mut format_string = path_string.clone(); - let mut format_args = Vec::new(); - - while let Some(start_of_segment) = format_string.find(':') { - // ':' should only ever appear at the start of a segment - assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/"); - - let end_of_segment = match format_string[start_of_segment..].find('/') { - Some(rel_pos) => start_of_segment + rel_pos, - None => format_string.len(), - }; - - let path_var = Ident::new( - &format_string[start_of_segment + 1..end_of_segment], - Span::call_site(), - ); - format_args.push(quote! { - #percent_encoding::utf8_percent_encode( - &self.#path_var.to_string(), - #percent_encoding::NON_ALPHANUMERIC, - ) - }); - format_string.replace_range(start_of_segment..end_of_segment, "{}"); - } - - quote! { - format_args!(#format_string, #(#format_args),*) - } - }; - - let path_fields = - path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map( - |(i, segment)| { - let path_var = &segment[1..]; - let path_var_ident = Ident::new(path_var, Span::call_site()); - quote! { - #path_var_ident: { - use #ruma_api::error::RequestDeserializationError; - - let segment = path_segments.get(#i).unwrap().as_bytes(); - let decoded = #ruma_api::try_deserialize!( - request, - #percent_encoding::percent_decode(segment) - .decode_utf8(), - ); - - #ruma_api::try_deserialize!( - request, - ::std::convert::TryFrom::try_from(&*decoded), - ) - } - } - }, - ); - - (format_call, quote! { #(#path_fields,)* }) - } else { - (quote! { metadata.path.to_owned() }, TokenStream::new()) - } -} - -/// The function determines the type of query string that needs to be built -/// and then builds it using `ruma_serde::urlencoded::to_string`. -pub(crate) fn build_query_string(request: &Request, ruma_api: &TokenStream) -> TokenStream { - let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; - - if let Some(field) = request.query_map_field() { - let field_name = field.ident.as_ref().expect("expected field to have identifier"); - - quote!({ - // This function exists so that the compiler will throw an - // error when the type of the field with the query_map - // attribute doesn't implement IntoIterator - // - // This is necessary because the ruma_serde::urlencoded::to_string - // call will result in a runtime error when the type cannot be - // encoded as a list key-value pairs (?key1=value1&key2=value2) - // - // By asserting that it implements the iterator trait, we can - // ensure that it won't fail. - fn assert_trait_impl(_: &T) - where - T: ::std::iter::IntoIterator, - {} - - let request_query = RequestQuery(self.#field_name); - assert_trait_impl(&request_query.0); - - format_args!( - "?{}", - #ruma_serde::urlencoded::to_string(request_query)? - ) - }) - } else if request.has_query_fields() { - let request_query_init_fields = request.request_query_init_fields(); - - quote!({ - let request_query = RequestQuery { - #request_query_init_fields - }; - - format_args!( - "?{}", - #ruma_serde::urlencoded::to_string(request_query)? - ) - }) - } else { - quote! { "" } - } -} - -/// Deserialize the query string. -pub(crate) fn extract_request_query(request: &Request, ruma_api: &TokenStream) -> TokenStream { - let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; - - if request.query_map_field().is_some() { - quote! { - let request_query = #ruma_api::try_deserialize!( - request, - #ruma_serde::urlencoded::from_str( - &request.uri().query().unwrap_or("") - ), - ); - } - } else if request.has_query_fields() { - quote! { - let request_query: ::Incoming = - #ruma_api::try_deserialize!( - request, - #ruma_serde::urlencoded::from_str( - &request.uri().query().unwrap_or("") - ), - ); - } - } else { - TokenStream::new() - } -} - -/// Generates the code to initialize a `Request`. -/// -/// Used to construct an `http::Request`s body. -pub(crate) fn build_request_body(request: &Request, ruma_api: &TokenStream) -> TokenStream { - let serde_json = quote! { #ruma_api::exports::serde_json }; - - if let Some(field) = request.newtype_raw_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote!(self.#field_name) - } else if request.has_body_fields() || request.newtype_body_field().is_some() { - let request_body_initializers = if let Some(field) = request.newtype_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { (self.#field_name) } - } else { - let initializers = request.request_body_init_fields(); - quote! { { #initializers } } - }; - - quote! { - { - let request_body = RequestBody #request_body_initializers; - #serde_json::to_vec(&request_body)? - } - } - } else { - quote!(Vec::new()) - } -} - -pub(crate) fn parse_request_body(request: &Request) -> TokenStream { - if let Some(field) = request.newtype_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { - #field_name: request_body.0, - } - } else if let Some(field) = request.newtype_raw_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { - #field_name: request.into_body(), - } - } else { - request.request_init_body_fields() - } -} - pub(crate) fn is_valid_endpoint_path(string: &str) -> bool { string.as_bytes().iter().all(|b| (0x21..=0x7E).contains(b)) } -pub fn import_ruma_api() -> TokenStream { +pub(crate) fn import_ruma_api() -> TokenStream { if let Ok(FoundCrate::Name(possibly_renamed)) = crate_name("ruma-api") { let import = Ident::new(&possibly_renamed, Span::call_site()); quote! { ::#import }