From a68b854734c1af93d9fcba31800171eb5a796fb8 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 10 Apr 2021 14:50:01 +0200 Subject: [PATCH] api: Refactor macro code and improve error handling * Inline lots of methods that were only used once * Create a separate error case for missing headers --- ruma-api-macros/src/api/request.rs | 309 ++++++++++++----------------- ruma-api-macros/src/util.rs | 4 +- ruma-api/src/error.rs | 21 +- 3 files changed, 150 insertions(+), 184 deletions(-) diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index 75b80b0a..3f293b72 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -31,97 +31,6 @@ pub(crate) struct Request { } impl Request { - /// Produces code to add necessary HTTP headers to an `http::Request`. - fn append_header_kvs(&self, ruma_api: &TokenStream) -> TokenStream { - let http = quote! { #ruma_api::exports::http }; - - self.header_fields() - .map(|request_field| { - let (field, header_name) = match request_field { - RequestField::Header(field, header_name) => (field, header_name), - _ => unreachable!("expected request field to be header variant"), - }; - - let field_name = &field.ident; - - match &field.ty { - syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) - if segments.last().unwrap().ident == "Option" => - { - quote! { - if let Some(header_val) = self.#field_name.as_ref() { - req_headers.insert( - #http::header::#header_name, - #http::header::HeaderValue::from_str(header_val)?, - ); - } - } - } - _ => quote! { - req_headers.insert( - #http::header::#header_name, - #http::header::HeaderValue::from_str(self.#field_name.as_ref())?, - ); - }, - } - }) - .collect() - } - - /// Produces code to extract fields from the HTTP headers in an `http::Request`. - fn parse_headers_from_request(&self, ruma_api: &TokenStream) -> TokenStream { - let http = quote! { #ruma_api::exports::http }; - let serde = quote! { #ruma_api::exports::serde }; - let serde_json = quote! { #ruma_api::exports::serde_json }; - - let fields = self.header_fields().map(|request_field| { - let (field, header_name) = match request_field { - RequestField::Header(field, header_name) => (field, header_name), - _ => panic!("expected request field to be header variant"), - }; - - let field_name = &field.ident; - let header_name_string = header_name.to_string(); - - let (some_case, none_case) = match &field.ty { - syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) - if segments.last().unwrap().ident == "Option" => - { - (quote! { Some(header.to_owned()) }, quote! { None }) - } - _ => ( - quote! { header.to_owned() }, - quote! {{ - use #serde::de::Error as _; - - // FIXME: Not a missing json field, a missing header! - return Err(#ruma_api::error::RequestDeserializationError::new( - #serde_json::Error::missing_field( - #header_name_string - ), - request, - ) - .into()); - }}, - ), - }; - - quote! { - #field_name: match headers - .get(#http::header::#header_name) - .and_then(|v| v.to_str().ok()) // FIXME: Should have a distinct error message - { - Some(header) => #some_case, - None => #none_case, - } - } - }); - - quote! { - #(#fields,)* - } - } - /// Whether or not this request has any data in the HTTP body. pub fn has_body_fields(&self) -> bool { self.fields.iter().any(|field| field.is_body()) @@ -173,32 +82,27 @@ impl Request { /// The combination of every fields unique lifetime annotation. pub fn combine_lifetimes(&self) -> TokenStream { util::unique_lifetimes_to_tokens( - self.lifetimes - .body - .iter() - .chain(self.lifetimes.path.iter()) - .chain(self.lifetimes.query.iter()) - .chain(self.lifetimes.header.iter()) - .collect::>() - .into_iter(), + [ + &self.lifetimes.body, + &self.lifetimes.path, + &self.lifetimes.query, + &self.lifetimes.header, + ] + .iter() + .flat_map(|set| set.iter()), ) } /// The lifetimes on fields with the `query` attribute. pub fn query_lifetimes(&self) -> TokenStream { - util::unique_lifetimes_to_tokens(self.lifetimes.query.iter()) + util::unique_lifetimes_to_tokens(&self.lifetimes.query) } /// The lifetimes on fields with the `body` attribute. pub fn body_lifetimes(&self) -> TokenStream { - util::unique_lifetimes_to_tokens(self.lifetimes.body.iter()) + util::unique_lifetimes_to_tokens(&self.lifetimes.body) } - // /// The lifetimes on fields with the `header` attribute. - // pub fn header_lifetimes(&self) -> TokenStream { - // util::generics_to_tokens(self.lifetimes.header.iter()) - // } - /// Produces an iterator over all the header fields. pub fn header_fields(&self) -> impl Iterator { self.fields.iter().filter(|field| field.is_header()) @@ -224,28 +128,6 @@ impl Request { self.fields.iter().find_map(RequestField::as_query_map_field) } - /// Produces code for a struct initializer for body fields on a variable named `request`. - pub fn request_body_init_fields(&self) -> TokenStream { - self.struct_init_fields(RequestFieldKind::Body, quote!(self)) - } - - /// Produces code for a struct initializer for query string fields on a variable named - /// `request`. - pub fn request_query_init_fields(&self) -> TokenStream { - self.struct_init_fields(RequestFieldKind::Query, quote!(self)) - } - - /// Produces code for a struct initializer for body fields on a variable named `request_body`. - pub fn request_init_body_fields(&self) -> TokenStream { - self.struct_init_fields(RequestFieldKind::Body, quote!(request_body)) - } - - /// Produces code for a struct initializer for query string fields on a variable named - /// `request_query`. - pub fn request_init_query_fields(&self) -> TokenStream { - self.struct_init_fields(RequestFieldKind::Query, quote!(request_query)) - } - /// Produces code for a struct initializer for the given field kind to be accessed through the /// given variable name. fn struct_init_fields( @@ -336,10 +218,42 @@ impl Request { #field_name: request_query, } } else { - self.request_init_query_fields() + self.struct_init_fields(RequestFieldKind::Query, quote!(request_query)) }; - let mut header_kvs = self.append_header_kvs(&ruma_api); + let mut header_kvs: TokenStream = self + .header_fields() + .map(|request_field| { + let (field, header_name) = match request_field { + RequestField::Header(field, header_name) => (field, header_name), + _ => unreachable!("expected request field to be header variant"), + }; + + let field_name = &field.ident; + + match &field.ty { + syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) + if segments.last().unwrap().ident == "Option" => + { + quote! { + if let Some(header_val) = self.#field_name.as_ref() { + req_headers.insert( + #http::header::#header_name, + #http::header::HeaderValue::from_str(header_val)?, + ); + } + } + } + _ => quote! { + req_headers.insert( + #http::header::#header_name, + #http::header::HeaderValue::from_str(self.#field_name.as_ref())?, + ); + }, + } + }) + .collect(); + for auth in &metadata.authentication { if auth.value == "AccessToken" { let attrs = &auth.attrs; @@ -398,13 +312,91 @@ impl Request { }; let parse_request_headers = if self.has_header_fields() { - self.parse_headers_from_request(&ruma_api) + let fields = self.header_fields().map(|request_field| { + let (field, header_name) = match request_field { + RequestField::Header(field, header_name) => (field, header_name), + _ => panic!("expected request field to be header variant"), + }; + + let field_name = &field.ident; + let header_name_string = header_name.to_string(); + + let (some_case, none_case) = match &field.ty { + syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) + if segments.last().unwrap().ident == "Option" => + { + (quote! { Some(str_value.to_owned()) }, quote! { None }) + } + _ => ( + quote! { str_value.to_owned() }, + quote! { + // FIXME: Not a missing json field, a missing header! + return Err(#ruma_api::error::RequestDeserializationError::new( + #ruma_api::error::HeaderDeserializationError::MissingHeader( + #header_name_string.into() + ), + request, + ) + .into()) + }, + ), + }; + + quote! { + #field_name: match headers.get(#http::header::#header_name) { + Some(header_value) => { + let str_value = + #ruma_api::try_deserialize!(request, header_value.to_str()); + #some_case + } + None => #none_case, + } + } + }); + + quote! { + #(#fields,)* + } } else { TokenStream::new() }; - let request_body = self.build_request_body(&ruma_api); - let parse_request_body = self.parse_request_body(); + let request_body = if let Some(field) = self.newtype_raw_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { self.#field_name } + } else if self.has_body_fields() || self.newtype_body_field().is_some() { + let request_body_initializers = if let Some(field) = self.newtype_body_field() { + let field_name = + field.ident.as_ref().expect("expected field to have an identifier"); + quote! { (self.#field_name) } + } else { + let initializers = self.struct_init_fields(RequestFieldKind::Body, quote!(self)); + quote! { { #initializers } } + }; + + quote! { + { + let request_body = RequestBody #request_body_initializers; + #serde_json::to_vec(&request_body)? + } + } + } else { + quote! { Vec::new() } + }; + + let parse_request_body = if let Some(field) = self.newtype_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { + #field_name: request_body.0, + } + } else if let Some(field) = self.newtype_raw_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { + #field_name: request.into_body(), + } + } else { + self.struct_init_fields(RequestFieldKind::Body, quote!(request_body)) + }; let request_generics = self.combine_lifetimes(); @@ -634,52 +626,6 @@ impl Request { } } - /// Generates the code to initialize a `Request`. - /// - /// Used to construct an `http::Request`s body. - fn build_request_body(&self, ruma_api: &TokenStream) -> TokenStream { - let serde_json = quote! { #ruma_api::exports::serde_json }; - - if let Some(field) = self.newtype_raw_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote!(self.#field_name) - } else if self.has_body_fields() || self.newtype_body_field().is_some() { - let request_body_initializers = if let Some(field) = self.newtype_body_field() { - let field_name = - field.ident.as_ref().expect("expected field to have an identifier"); - quote! { (self.#field_name) } - } else { - let initializers = self.request_body_init_fields(); - quote! { { #initializers } } - }; - - quote! { - { - let request_body = RequestBody #request_body_initializers; - #serde_json::to_vec(&request_body)? - } - } - } else { - quote!(Vec::new()) - } - } - - fn parse_request_body(&self) -> TokenStream { - if let Some(field) = self.newtype_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { - #field_name: request_body.0, - } - } else if let Some(field) = self.newtype_raw_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { - #field_name: request.into_body(), - } - } else { - self.request_init_body_fields() - } - } - /// The function determines the type of query string that needs to be built /// and then builds it using `ruma_serde::urlencoded::to_string`. fn build_query_string(&self, ruma_api: &TokenStream) -> TokenStream { @@ -715,7 +661,8 @@ impl Request { ) }) } else if self.has_query_fields() { - let request_query_init_fields = self.request_query_init_fields(); + let request_query_init_fields = + self.struct_init_fields(RequestFieldKind::Query, quote!(self)); quote!({ let request_query = RequestQuery { @@ -739,7 +686,7 @@ impl Request { /// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is /// used for implementing `TryFrom>>`, from path strings deserialized to /// Ruma types. - pub(crate) fn path_string_and_parse( + fn path_string_and_parse( &self, metadata: &Metadata, ruma_api: &TokenStream, diff --git a/ruma-api-macros/src/util.rs b/ruma-api-macros/src/util.rs index a6beed62..8498edb3 100644 --- a/ruma-api-macros/src/util.rs +++ b/ruma-api-macros/src/util.rs @@ -8,10 +8,10 @@ use quote::quote; use syn::{AttrStyle, Attribute, Ident, Lifetime}; /// Generates a `TokenStream` of lifetime identifiers `<'lifetime>`. -pub(crate) fn unique_lifetimes_to_tokens<'a, I: Iterator>( +pub(crate) fn unique_lifetimes_to_tokens<'a, I: IntoIterator>( lifetimes: I, ) -> TokenStream { - let lifetimes = lifetimes.collect::>(); + let lifetimes = lifetimes.into_iter().collect::>(); if lifetimes.is_empty() { TokenStream::new() } else { diff --git a/ruma-api/src/error.rs b/ruma-api/src/error.rs index 5f2aa95f..101151cd 100644 --- a/ruma-api/src/error.rs +++ b/ruma-api/src/error.rs @@ -209,7 +209,7 @@ pub enum DeserializationError { /// Header value deserialization failed. #[error("{0}")] - Header(#[from] http::header::ToStrError), + Header(#[from] HeaderDeserializationError), } impl From for DeserializationError { @@ -217,3 +217,22 @@ impl From for DeserializationError { match err {} } } + +impl From for DeserializationError { + fn from(err: http::header::ToStrError) -> Self { + Self::Header(HeaderDeserializationError::ToStrError(err)) + } +} + +/// An error with the http headers. +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum HeaderDeserializationError { + /// Failed to convert `http::header::HeaderValue` to `str`. + #[error("{0}")] + ToStrError(http::header::ToStrError), + + /// The given required header is missing. + #[error("Missing header `{0}`")] + MissingHeader(String), +}