From 23ba0bc1641ed5e1d1ecc01e56c9f149e6b33ee5 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 10 Apr 2021 16:47:34 +0200 Subject: [PATCH] api-macros: Refactor request code generation --- ruma-api-macros/src/api/request.rs | 313 ++++++++++++++++------------- ruma-api/tests/conversions.rs | 25 +-- ruma-api/tests/ruma_api_macros.rs | 13 +- 3 files changed, 188 insertions(+), 163 deletions(-) diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index f0d6f847..05bc57cd 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -32,47 +32,47 @@ pub(crate) struct Request { impl Request { /// Whether or not this request has any data in the HTTP body. - pub fn has_body_fields(&self) -> bool { + pub(super) fn has_body_fields(&self) -> bool { self.fields.iter().any(|field| field.is_body()) } /// Whether or not this request has any data in HTTP headers. - pub fn has_header_fields(&self) -> bool { + fn has_header_fields(&self) -> bool { self.fields.iter().any(|field| field.is_header()) } /// Whether or not this request has any data in the URL path. - pub fn has_path_fields(&self) -> bool { + fn has_path_fields(&self) -> bool { self.fields.iter().any(|field| field.is_path()) } /// Whether or not this request has any data in the query string. - pub fn has_query_fields(&self) -> bool { + fn has_query_fields(&self) -> bool { self.fields.iter().any(|field| field.is_query()) } /// Produces an iterator over all the body fields. - pub fn body_fields(&self) -> impl Iterator { + pub(super) fn body_fields(&self) -> impl Iterator { self.fields.iter().filter_map(|field| field.as_body_field()) } /// The number of unique lifetime annotations for `body` fields. - pub fn body_lifetime_count(&self) -> usize { + fn body_lifetime_count(&self) -> usize { self.lifetimes.body.len() } /// Whether any `body` field has a lifetime annotation. - pub fn has_body_lifetimes(&self) -> bool { + fn has_body_lifetimes(&self) -> bool { !self.lifetimes.body.is_empty() } /// Whether any `query` field has a lifetime annotation. - pub fn has_query_lifetimes(&self) -> bool { + fn has_query_lifetimes(&self) -> bool { !self.lifetimes.query.is_empty() } /// Whether any field has a lifetime. - pub fn contains_lifetimes(&self) -> bool { + fn contains_lifetimes(&self) -> bool { !(self.lifetimes.body.is_empty() && self.lifetimes.path.is_empty() && self.lifetimes.query.is_empty() @@ -80,7 +80,7 @@ impl Request { } /// The combination of every fields unique lifetime annotation. - pub fn combine_lifetimes(&self) -> TokenStream { + fn combine_lifetimes(&self) -> TokenStream { util::unique_lifetimes_to_tokens( [ &self.lifetimes.body, @@ -94,22 +94,22 @@ impl Request { } /// The lifetimes on fields with the `query` attribute. - pub fn query_lifetimes(&self) -> TokenStream { + fn query_lifetimes(&self) -> TokenStream { util::unique_lifetimes_to_tokens(&self.lifetimes.query) } /// The lifetimes on fields with the `body` attribute. - pub fn body_lifetimes(&self) -> TokenStream { + fn body_lifetimes(&self) -> TokenStream { util::unique_lifetimes_to_tokens(&self.lifetimes.body) } /// Produces an iterator over all the header fields. - pub fn header_fields(&self) -> impl Iterator { + fn header_fields(&self) -> impl Iterator { self.fields.iter().filter(|field| field.is_header()) } /// Gets the number of path fields. - pub fn path_field_count(&self) -> usize { + fn path_field_count(&self) -> usize { self.fields.iter().filter(|field| field.is_path()).count() } @@ -119,12 +119,12 @@ impl Request { } /// Returns the body field. - pub fn newtype_raw_body_field(&self) -> Option<&Field> { + fn newtype_raw_body_field(&self) -> Option<&Field> { self.fields.iter().find_map(RequestField::as_newtype_raw_body_field) } /// Returns the query map field. - pub fn query_map_field(&self) -> Option<&Field> { + fn query_map_field(&self) -> Option<&Field> { self.fields.iter().find_map(RequestField::as_query_map_field) } @@ -135,8 +135,8 @@ impl Request { request_field_kind: RequestFieldKind, src: TokenStream, ) -> TokenStream { - let process_field = |f: &RequestField| { - f.field_of_kind(request_field_kind).map(|field| { + let fields = + self.fields.iter().filter_map(|f| f.field_of_kind(request_field_kind)).map(|field| { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let span = field.span(); @@ -147,25 +147,43 @@ impl Request { #( #cfg_attrs )* #field_name: #src.#field_name } - }) - }; - - let mut fields = vec![]; - let mut new_type_body = None; - for field in &self.fields { - if let RequestField::NewtypeRawBody(_) = field { - new_type_body = process_field(field); - } else { - fields.extend(process_field(field)); - } - } - - // Move field that consumes `request` to the end of the init list. - fields.extend(new_type_body); + }); quote! { #(#fields,)* } } + /// Produces code for a struct initializer for the given field kind to be accessed through the + /// given variable name. + fn vars( + &self, + request_field_kind: RequestFieldKind, + src: TokenStream, + ) -> (TokenStream, TokenStream) { + let (decls, names): (TokenStream, Vec<_>) = self + .fields + .iter() + .filter_map(|f| f.field_of_kind(request_field_kind)) + .map(|field| { + let field_name = + field.ident.as_ref().expect("expected field to have an identifier"); + let span = field.span(); + let cfg_attrs = + field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); + + let decl = quote_spanned! {span=> + #( #cfg_attrs )* + let #field_name = #src.#field_name; + }; + + (decl, field_name) + }) + .unzip(); + + let names = quote! { #(#names,)* }; + + (decls, names) + } + pub(super) fn expand( &self, metadata: &Metadata, @@ -197,16 +215,7 @@ impl Request { let incoming_request_type = if self.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) }; - let extract_request_path = if self.has_path_fields() { - quote! { - let path_segments: ::std::vec::Vec<&::std::primitive::str> = - request.uri().path()[1..].split('/').collect(); - } - } else { - TokenStream::new() - }; - - let (request_path_string, parse_request_path) = if self.has_path_fields() { + let (request_path_string, parse_request_path, path_vars) = if self.has_path_fields() { let path_string = metadata.path.value(); assert!(path_string.starts_with('/'), "path needs to start with '/'"); @@ -246,26 +255,38 @@ impl Request { } }; - let path_fields = - path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map( - |(i, segment)| { - let path_var = &segment[1..]; - let path_var_ident = Ident::new(path_var, Span::call_site()); - quote! { - #path_var_ident: { - let segment = path_segments[#i].as_bytes(); - let decoded = - #percent_encoding::percent_decode(segment).decode_utf8()?; + let path_var_decls = path_string[1..] + .split('/') + .enumerate() + .filter(|(_, seg)| seg.starts_with(':')) + .map(|(i, seg)| { + let path_var = Ident::new(&seg[1..], Span::call_site()); + quote! { + let #path_var = { + let segment = path_segments[#i].as_bytes(); + let decoded = + #percent_encoding::percent_decode(segment).decode_utf8()?; - ::std::convert::TryFrom::try_from(&*decoded)? - } - } - }, - ); + ::std::convert::TryFrom::try_from(&*decoded)? + }; + } + }); - (format_call, quote! { #(#path_fields,)* }) + let parse_request_path = quote! { + let path_segments: ::std::vec::Vec<&::std::primitive::str> = + request.uri().path()[1..].split('/').collect(); + + #(#path_var_decls)* + }; + + let path_vars = path_string[1..] + .split('/') + .filter(|seg| seg.starts_with(':')) + .map(|seg| Ident::new(&seg[1..], Span::call_site())); + + (format_call, parse_request_path, quote! { #(#path_vars,)* }) } else { - (quote! { metadata.path.to_owned() }, TokenStream::new()) + (quote! { metadata.path.to_owned() }, TokenStream::new(), TokenStream::new()) }; let request_query_string = if let Some(field) = self.query_map_field() { @@ -315,31 +336,30 @@ impl Request { quote! { "" } }; - let extract_request_query = if self.query_map_field().is_some() { - quote! { - let request_query = #ruma_serde::urlencoded::from_str( + let (parse_query, query_vars) = if let Some(field) = self.query_map_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + let parse = quote! { + let #field_name = #ruma_serde::urlencoded::from_str( &request.uri().query().unwrap_or(""), )?; - } + }; + + (parse, quote! { #field_name, }) } else if self.has_query_fields() { - quote! { + let (decls, names) = self.vars(RequestFieldKind::Query, quote!(request_query)); + + let parse = quote! { let request_query: ::Incoming = #ruma_serde::urlencoded::from_str( &request.uri().query().unwrap_or("") )?; - } - } else { - TokenStream::new() - }; - let parse_request_query = if let Some(field) = self.query_map_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + #decls + }; - quote! { - #field_name: request_query, - } + (parse, names) } else { - self.struct_init_fields(RequestFieldKind::Query, quote!(request_query)) + (TokenStream::new(), TokenStream::new()) }; let mut header_kvs: TokenStream = self @@ -395,16 +415,62 @@ impl Request { } } - let extract_request_headers = if self.has_header_fields() { - quote! { + let (parse_headers, header_vars) = if self.has_header_fields() { + let (decls, names): (TokenStream, Vec<_>) = self + .header_fields() + .map(|request_field| { + let (field, header_name) = match request_field { + RequestField::Header(field, header_name) => (field, header_name), + _ => panic!("expected request field to be header variant"), + }; + + let field_name = &field.ident; + let header_name_string = header_name.to_string(); + + let (some_case, none_case) = match &field.ty { + syn::Type::Path(syn::TypePath { + path: syn::Path { segments, .. }, .. + }) if segments.last().unwrap().ident == "Option" => { + (quote! { Some(str_value.to_owned()) }, quote! { None }) + } + _ => ( + quote! { str_value.to_owned() }, + quote! { + return Err( + #ruma_api::error::HeaderDeserializationError::MissingHeader( + #header_name_string.into() + ).into(), + ) + }, + ), + }; + + let decl = quote! { + let #field_name = match headers.get(#http::header::#header_name) { + Some(header_value) => { + let str_value = header_value.to_str()?; + #some_case + } + None => #none_case, + }; + }; + + (decl, field_name) + }) + .unzip(); + + let parse = quote! { let headers = request.headers(); - } + + #decls + }; + + (parse, quote! { #(#names,)* }) } else { - TokenStream::new() + (TokenStream::new(), TokenStream::new()) }; - let extract_request_body = if self.has_body_fields() || self.newtype_body_field().is_some() - { + let extract_body = if self.has_body_fields() || self.newtype_body_field().is_some() { let body_lifetimes = if self.has_body_lifetimes() { // duplicate the anonymous lifetime as many times as needed let lifetimes = std::iter::repeat(quote! { '_ }).take(self.body_lifetime_count()); @@ -412,6 +478,7 @@ impl Request { } else { TokenStream::new() }; + quote! { let request_body: < RequestBody #body_lifetimes @@ -432,52 +499,6 @@ impl Request { TokenStream::new() }; - let parse_request_headers = if self.has_header_fields() { - let fields = self.header_fields().map(|request_field| { - let (field, header_name) = match request_field { - RequestField::Header(field, header_name) => (field, header_name), - _ => panic!("expected request field to be header variant"), - }; - - let field_name = &field.ident; - let header_name_string = header_name.to_string(); - - let (some_case, none_case) = match &field.ty { - syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) - if segments.last().unwrap().ident == "Option" => - { - (quote! { Some(str_value.to_owned()) }, quote! { None }) - } - _ => ( - quote! { str_value.to_owned() }, - quote! { - return Err( - #ruma_api::error::HeaderDeserializationError::MissingHeader( - #header_name_string.into() - ).into(), - ) - }, - ), - }; - - quote! { - #field_name: match headers.get(#http::header::#header_name) { - Some(header_value) => { - let str_value = header_value.to_str()?; - #some_case - } - None => #none_case, - } - } - }); - - quote! { - #(#fields,)* - } - } else { - TokenStream::new() - }; - let request_body = if let Some(field) = self.newtype_raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); quote! { self.#field_name } @@ -501,18 +522,22 @@ impl Request { quote! { Vec::new() } }; - let parse_request_body = if let Some(field) = self.newtype_body_field() { + let (parse_body, body_vars) = if let Some(field) = self.newtype_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { - #field_name: request_body.0, - } + let parse = quote! { + let #field_name = request_body.0; + }; + + (parse, quote! { #field_name, }) } else if let Some(field) = self.newtype_raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { - #field_name: request.into_body(), - } + let parse = quote! { + let #field_name = request.into_body(); + }; + + (parse, quote! { #field_name, }) } else { - self.struct_init_fields(RequestFieldKind::Body, quote!(request_body)) + self.vars(RequestFieldKind::Body, quote!(request_body)) }; let request_generics = self.combine_lifetimes(); @@ -699,16 +724,18 @@ impl Request { }); } - #extract_request_path - #extract_request_query - #extract_request_headers - #extract_request_body + #parse_request_path + #parse_query + #parse_headers + + #extract_body + #parse_body Ok(Self { - #parse_request_path - #parse_request_query - #parse_request_headers - #parse_request_body + #path_vars + #query_vars + #header_vars + #body_vars }) } } diff --git a/ruma-api/tests/conversions.rs b/ruma-api/tests/conversions.rs index 2c6875eb..839bda63 100644 --- a/ruma-api/tests/conversions.rs +++ b/ruma-api/tests/conversions.rs @@ -8,7 +8,7 @@ ruma_api! { description: "Does something.", method: POST, name: "my_endpoint", - path: "/_matrix/foo/:bar/:baz", + path: "/_matrix/foo/:bar/:user", rate_limited: false, authentication: None, } @@ -24,7 +24,7 @@ ruma_api! { #[ruma_api(path)] pub bar: String, #[ruma_api(path)] - pub baz: UserId, + pub user: UserId, } response: { @@ -44,7 +44,7 @@ fn request_serde() -> Result<(), Box> { q1: "query_param_special_chars %/&@!".to_owned(), q2: 55, bar: "barVal".to_owned(), - baz: user_id!("@bazme:ruma.io"), + user: user_id!("@bazme:ruma.io"), }; let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?; @@ -55,7 +55,7 @@ fn request_serde() -> Result<(), Box> { assert_eq!(req.q1, req2.q1); assert_eq!(req.q2, req2.q2); assert_eq!(req.bar, req2.bar); - assert_eq!(req.baz, req2.baz); + assert_eq!(req.user, req2.user); Ok(()) } @@ -68,12 +68,12 @@ fn request_with_user_id_serde() -> Result<(), Box