api-macros: Refactor request code generation
This commit is contained in:
		
							parent
							
								
									a20f03894e
								
							
						
					
					
						commit
						23ba0bc164
					
				| @ -32,47 +32,47 @@ pub(crate) struct Request { | |||||||
| 
 | 
 | ||||||
| impl Request { | impl Request { | ||||||
|     /// Whether or not this request has any data in the HTTP body.
 |     /// Whether or not this request has any data in the HTTP body.
 | ||||||
|     pub fn has_body_fields(&self) -> bool { |     pub(super) fn has_body_fields(&self) -> bool { | ||||||
|         self.fields.iter().any(|field| field.is_body()) |         self.fields.iter().any(|field| field.is_body()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Whether or not this request has any data in HTTP headers.
 |     /// Whether or not this request has any data in HTTP headers.
 | ||||||
|     pub fn has_header_fields(&self) -> bool { |     fn has_header_fields(&self) -> bool { | ||||||
|         self.fields.iter().any(|field| field.is_header()) |         self.fields.iter().any(|field| field.is_header()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Whether or not this request has any data in the URL path.
 |     /// Whether or not this request has any data in the URL path.
 | ||||||
|     pub fn has_path_fields(&self) -> bool { |     fn has_path_fields(&self) -> bool { | ||||||
|         self.fields.iter().any(|field| field.is_path()) |         self.fields.iter().any(|field| field.is_path()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Whether or not this request has any data in the query string.
 |     /// Whether or not this request has any data in the query string.
 | ||||||
|     pub fn has_query_fields(&self) -> bool { |     fn has_query_fields(&self) -> bool { | ||||||
|         self.fields.iter().any(|field| field.is_query()) |         self.fields.iter().any(|field| field.is_query()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Produces an iterator over all the body fields.
 |     /// Produces an iterator over all the body fields.
 | ||||||
|     pub fn body_fields(&self) -> impl Iterator<Item = &Field> { |     pub(super) fn body_fields(&self) -> impl Iterator<Item = &Field> { | ||||||
|         self.fields.iter().filter_map(|field| field.as_body_field()) |         self.fields.iter().filter_map(|field| field.as_body_field()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// The number of unique lifetime annotations for `body` fields.
 |     /// The number of unique lifetime annotations for `body` fields.
 | ||||||
|     pub fn body_lifetime_count(&self) -> usize { |     fn body_lifetime_count(&self) -> usize { | ||||||
|         self.lifetimes.body.len() |         self.lifetimes.body.len() | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Whether any `body` field has a lifetime annotation.
 |     /// Whether any `body` field has a lifetime annotation.
 | ||||||
|     pub fn has_body_lifetimes(&self) -> bool { |     fn has_body_lifetimes(&self) -> bool { | ||||||
|         !self.lifetimes.body.is_empty() |         !self.lifetimes.body.is_empty() | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Whether any `query` field has a lifetime annotation.
 |     /// Whether any `query` field has a lifetime annotation.
 | ||||||
|     pub fn has_query_lifetimes(&self) -> bool { |     fn has_query_lifetimes(&self) -> bool { | ||||||
|         !self.lifetimes.query.is_empty() |         !self.lifetimes.query.is_empty() | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Whether any field has a lifetime.
 |     /// Whether any field has a lifetime.
 | ||||||
|     pub fn contains_lifetimes(&self) -> bool { |     fn contains_lifetimes(&self) -> bool { | ||||||
|         !(self.lifetimes.body.is_empty() |         !(self.lifetimes.body.is_empty() | ||||||
|             && self.lifetimes.path.is_empty() |             && self.lifetimes.path.is_empty() | ||||||
|             && self.lifetimes.query.is_empty() |             && self.lifetimes.query.is_empty() | ||||||
| @ -80,7 +80,7 @@ impl Request { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// The combination of every fields unique lifetime annotation.
 |     /// The combination of every fields unique lifetime annotation.
 | ||||||
|     pub fn combine_lifetimes(&self) -> TokenStream { |     fn combine_lifetimes(&self) -> TokenStream { | ||||||
|         util::unique_lifetimes_to_tokens( |         util::unique_lifetimes_to_tokens( | ||||||
|             [ |             [ | ||||||
|                 &self.lifetimes.body, |                 &self.lifetimes.body, | ||||||
| @ -94,22 +94,22 @@ impl Request { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// The lifetimes on fields with the `query` attribute.
 |     /// The lifetimes on fields with the `query` attribute.
 | ||||||
|     pub fn query_lifetimes(&self) -> TokenStream { |     fn query_lifetimes(&self) -> TokenStream { | ||||||
|         util::unique_lifetimes_to_tokens(&self.lifetimes.query) |         util::unique_lifetimes_to_tokens(&self.lifetimes.query) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// The lifetimes on fields with the `body` attribute.
 |     /// The lifetimes on fields with the `body` attribute.
 | ||||||
|     pub fn body_lifetimes(&self) -> TokenStream { |     fn body_lifetimes(&self) -> TokenStream { | ||||||
|         util::unique_lifetimes_to_tokens(&self.lifetimes.body) |         util::unique_lifetimes_to_tokens(&self.lifetimes.body) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Produces an iterator over all the header fields.
 |     /// Produces an iterator over all the header fields.
 | ||||||
|     pub fn header_fields(&self) -> impl Iterator<Item = &RequestField> { |     fn header_fields(&self) -> impl Iterator<Item = &RequestField> { | ||||||
|         self.fields.iter().filter(|field| field.is_header()) |         self.fields.iter().filter(|field| field.is_header()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Gets the number of path fields.
 |     /// Gets the number of path fields.
 | ||||||
|     pub fn path_field_count(&self) -> usize { |     fn path_field_count(&self) -> usize { | ||||||
|         self.fields.iter().filter(|field| field.is_path()).count() |         self.fields.iter().filter(|field| field.is_path()).count() | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -119,12 +119,12 @@ impl Request { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Returns the body field.
 |     /// Returns the body field.
 | ||||||
|     pub fn newtype_raw_body_field(&self) -> Option<&Field> { |     fn newtype_raw_body_field(&self) -> Option<&Field> { | ||||||
|         self.fields.iter().find_map(RequestField::as_newtype_raw_body_field) |         self.fields.iter().find_map(RequestField::as_newtype_raw_body_field) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Returns the query map field.
 |     /// Returns the query map field.
 | ||||||
|     pub fn query_map_field(&self) -> Option<&Field> { |     fn query_map_field(&self) -> Option<&Field> { | ||||||
|         self.fields.iter().find_map(RequestField::as_query_map_field) |         self.fields.iter().find_map(RequestField::as_query_map_field) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -135,8 +135,8 @@ impl Request { | |||||||
|         request_field_kind: RequestFieldKind, |         request_field_kind: RequestFieldKind, | ||||||
|         src: TokenStream, |         src: TokenStream, | ||||||
|     ) -> TokenStream { |     ) -> TokenStream { | ||||||
|         let process_field = |f: &RequestField| { |         let fields = | ||||||
|             f.field_of_kind(request_field_kind).map(|field| { |             self.fields.iter().filter_map(|f| f.field_of_kind(request_field_kind)).map(|field| { | ||||||
|                 let field_name = |                 let field_name = | ||||||
|                     field.ident.as_ref().expect("expected field to have an identifier"); |                     field.ident.as_ref().expect("expected field to have an identifier"); | ||||||
|                 let span = field.span(); |                 let span = field.span(); | ||||||
| @ -147,25 +147,43 @@ impl Request { | |||||||
|                     #( #cfg_attrs )* |                     #( #cfg_attrs )* | ||||||
|                     #field_name: #src.#field_name |                     #field_name: #src.#field_name | ||||||
|                 } |                 } | ||||||
|             }) |             }); | ||||||
|         }; |  | ||||||
| 
 |  | ||||||
|         let mut fields = vec![]; |  | ||||||
|         let mut new_type_body = None; |  | ||||||
|         for field in &self.fields { |  | ||||||
|             if let RequestField::NewtypeRawBody(_) = field { |  | ||||||
|                 new_type_body = process_field(field); |  | ||||||
|             } else { |  | ||||||
|                 fields.extend(process_field(field)); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         // Move field that consumes `request` to the end of the init list.
 |  | ||||||
|         fields.extend(new_type_body); |  | ||||||
| 
 | 
 | ||||||
|         quote! { #(#fields,)* } |         quote! { #(#fields,)* } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     /// Produces code for a struct initializer for the given field kind to be accessed through the
 | ||||||
|  |     /// given variable name.
 | ||||||
|  |     fn vars( | ||||||
|  |         &self, | ||||||
|  |         request_field_kind: RequestFieldKind, | ||||||
|  |         src: TokenStream, | ||||||
|  |     ) -> (TokenStream, TokenStream) { | ||||||
|  |         let (decls, names): (TokenStream, Vec<_>) = self | ||||||
|  |             .fields | ||||||
|  |             .iter() | ||||||
|  |             .filter_map(|f| f.field_of_kind(request_field_kind)) | ||||||
|  |             .map(|field| { | ||||||
|  |                 let field_name = | ||||||
|  |                     field.ident.as_ref().expect("expected field to have an identifier"); | ||||||
|  |                 let span = field.span(); | ||||||
|  |                 let cfg_attrs = | ||||||
|  |                     field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>(); | ||||||
|  | 
 | ||||||
|  |                 let decl = quote_spanned! {span=> | ||||||
|  |                     #( #cfg_attrs )* | ||||||
|  |                     let #field_name = #src.#field_name; | ||||||
|  |                 }; | ||||||
|  | 
 | ||||||
|  |                 (decl, field_name) | ||||||
|  |             }) | ||||||
|  |             .unzip(); | ||||||
|  | 
 | ||||||
|  |         let names = quote! { #(#names,)* }; | ||||||
|  | 
 | ||||||
|  |         (decls, names) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     pub(super) fn expand( |     pub(super) fn expand( | ||||||
|         &self, |         &self, | ||||||
|         metadata: &Metadata, |         metadata: &Metadata, | ||||||
| @ -197,16 +215,7 @@ impl Request { | |||||||
|         let incoming_request_type = |         let incoming_request_type = | ||||||
|             if self.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) }; |             if self.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) }; | ||||||
| 
 | 
 | ||||||
|         let extract_request_path = if self.has_path_fields() { |         let (request_path_string, parse_request_path, path_vars) = if self.has_path_fields() { | ||||||
|             quote! { |  | ||||||
|                 let path_segments: ::std::vec::Vec<&::std::primitive::str> = |  | ||||||
|                     request.uri().path()[1..].split('/').collect(); |  | ||||||
|             } |  | ||||||
|         } else { |  | ||||||
|             TokenStream::new() |  | ||||||
|         }; |  | ||||||
| 
 |  | ||||||
|         let (request_path_string, parse_request_path) = if self.has_path_fields() { |  | ||||||
|             let path_string = metadata.path.value(); |             let path_string = metadata.path.value(); | ||||||
| 
 | 
 | ||||||
|             assert!(path_string.starts_with('/'), "path needs to start with '/'"); |             assert!(path_string.starts_with('/'), "path needs to start with '/'"); | ||||||
| @ -246,26 +255,38 @@ impl Request { | |||||||
|                 } |                 } | ||||||
|             }; |             }; | ||||||
| 
 | 
 | ||||||
|             let path_fields = |             let path_var_decls = path_string[1..] | ||||||
|                 path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map( |                 .split('/') | ||||||
|                     |(i, segment)| { |                 .enumerate() | ||||||
|                         let path_var = &segment[1..]; |                 .filter(|(_, seg)| seg.starts_with(':')) | ||||||
|                         let path_var_ident = Ident::new(path_var, Span::call_site()); |                 .map(|(i, seg)| { | ||||||
|  |                     let path_var = Ident::new(&seg[1..], Span::call_site()); | ||||||
|                     quote! { |                     quote! { | ||||||
|                             #path_var_ident: { |                         let #path_var = { | ||||||
|                             let segment = path_segments[#i].as_bytes(); |                             let segment = path_segments[#i].as_bytes(); | ||||||
|                             let decoded = |                             let decoded = | ||||||
|                                 #percent_encoding::percent_decode(segment).decode_utf8()?; |                                 #percent_encoding::percent_decode(segment).decode_utf8()?; | ||||||
| 
 | 
 | ||||||
|                             ::std::convert::TryFrom::try_from(&*decoded)? |                             ::std::convert::TryFrom::try_from(&*decoded)? | ||||||
|  |                         }; | ||||||
|                     } |                     } | ||||||
|                         } |                 }); | ||||||
|                     }, |  | ||||||
|                 ); |  | ||||||
| 
 | 
 | ||||||
|             (format_call, quote! { #(#path_fields,)* }) |             let parse_request_path = quote! { | ||||||
|  |                 let path_segments: ::std::vec::Vec<&::std::primitive::str> = | ||||||
|  |                     request.uri().path()[1..].split('/').collect(); | ||||||
|  | 
 | ||||||
|  |                 #(#path_var_decls)* | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             let path_vars = path_string[1..] | ||||||
|  |                 .split('/') | ||||||
|  |                 .filter(|seg| seg.starts_with(':')) | ||||||
|  |                 .map(|seg| Ident::new(&seg[1..], Span::call_site())); | ||||||
|  | 
 | ||||||
|  |             (format_call, parse_request_path, quote! { #(#path_vars,)* }) | ||||||
|         } else { |         } else { | ||||||
|             (quote! { metadata.path.to_owned() }, TokenStream::new()) |             (quote! { metadata.path.to_owned() }, TokenStream::new(), TokenStream::new()) | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let request_query_string = if let Some(field) = self.query_map_field() { |         let request_query_string = if let Some(field) = self.query_map_field() { | ||||||
| @ -315,31 +336,30 @@ impl Request { | |||||||
|             quote! { "" } |             quote! { "" } | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let extract_request_query = if self.query_map_field().is_some() { |         let (parse_query, query_vars) = if let Some(field) = self.query_map_field() { | ||||||
|             quote! { |             let field_name = field.ident.as_ref().expect("expected field to have an identifier"); | ||||||
|                 let request_query = #ruma_serde::urlencoded::from_str( |             let parse = quote! { | ||||||
|  |                 let #field_name = #ruma_serde::urlencoded::from_str( | ||||||
|                     &request.uri().query().unwrap_or(""), |                     &request.uri().query().unwrap_or(""), | ||||||
|                 )?; |                 )?; | ||||||
|             } |             }; | ||||||
|  | 
 | ||||||
|  |             (parse, quote! { #field_name, }) | ||||||
|         } else if self.has_query_fields() { |         } else if self.has_query_fields() { | ||||||
|             quote! { |             let (decls, names) = self.vars(RequestFieldKind::Query, quote!(request_query)); | ||||||
|  | 
 | ||||||
|  |             let parse = quote! { | ||||||
|                 let request_query: <RequestQuery as #ruma_serde::Outgoing>::Incoming = |                 let request_query: <RequestQuery as #ruma_serde::Outgoing>::Incoming = | ||||||
|                     #ruma_serde::urlencoded::from_str( |                     #ruma_serde::urlencoded::from_str( | ||||||
|                         &request.uri().query().unwrap_or("") |                         &request.uri().query().unwrap_or("") | ||||||
|                     )?; |                     )?; | ||||||
|             } | 
 | ||||||
|         } else { |                 #decls | ||||||
|             TokenStream::new() |  | ||||||
|             }; |             }; | ||||||
| 
 | 
 | ||||||
|         let parse_request_query = if let Some(field) = self.query_map_field() { |             (parse, names) | ||||||
|             let field_name = field.ident.as_ref().expect("expected field to have an identifier"); |  | ||||||
| 
 |  | ||||||
|             quote! { |  | ||||||
|                 #field_name: request_query, |  | ||||||
|             } |  | ||||||
|         } else { |         } else { | ||||||
|             self.struct_init_fields(RequestFieldKind::Query, quote!(request_query)) |             (TokenStream::new(), TokenStream::new()) | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let mut header_kvs: TokenStream = self |         let mut header_kvs: TokenStream = self | ||||||
| @ -395,16 +415,62 @@ impl Request { | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         let extract_request_headers = if self.has_header_fields() { |         let (parse_headers, header_vars) = if self.has_header_fields() { | ||||||
|             quote! { |             let (decls, names): (TokenStream, Vec<_>) = self | ||||||
|                 let headers = request.headers(); |                 .header_fields() | ||||||
|             } |                 .map(|request_field| { | ||||||
|         } else { |                     let (field, header_name) = match request_field { | ||||||
|             TokenStream::new() |                         RequestField::Header(field, header_name) => (field, header_name), | ||||||
|  |                         _ => panic!("expected request field to be header variant"), | ||||||
|                     }; |                     }; | ||||||
| 
 | 
 | ||||||
|         let extract_request_body = if self.has_body_fields() || self.newtype_body_field().is_some() |                     let field_name = &field.ident; | ||||||
|         { |                     let header_name_string = header_name.to_string(); | ||||||
|  | 
 | ||||||
|  |                     let (some_case, none_case) = match &field.ty { | ||||||
|  |                         syn::Type::Path(syn::TypePath { | ||||||
|  |                             path: syn::Path { segments, .. }, .. | ||||||
|  |                         }) if segments.last().unwrap().ident == "Option" => { | ||||||
|  |                             (quote! { Some(str_value.to_owned()) }, quote! { None }) | ||||||
|  |                         } | ||||||
|  |                         _ => ( | ||||||
|  |                             quote! { str_value.to_owned() }, | ||||||
|  |                             quote! { | ||||||
|  |                                 return Err( | ||||||
|  |                                     #ruma_api::error::HeaderDeserializationError::MissingHeader( | ||||||
|  |                                         #header_name_string.into() | ||||||
|  |                                     ).into(), | ||||||
|  |                                 ) | ||||||
|  |                             }, | ||||||
|  |                         ), | ||||||
|  |                     }; | ||||||
|  | 
 | ||||||
|  |                     let decl = quote! { | ||||||
|  |                         let #field_name = match headers.get(#http::header::#header_name) { | ||||||
|  |                             Some(header_value) => { | ||||||
|  |                                 let str_value = header_value.to_str()?; | ||||||
|  |                                 #some_case | ||||||
|  |                             } | ||||||
|  |                             None => #none_case, | ||||||
|  |                         }; | ||||||
|  |                     }; | ||||||
|  | 
 | ||||||
|  |                     (decl, field_name) | ||||||
|  |                 }) | ||||||
|  |                 .unzip(); | ||||||
|  | 
 | ||||||
|  |             let parse = quote! { | ||||||
|  |                 let headers = request.headers(); | ||||||
|  | 
 | ||||||
|  |                 #decls | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             (parse, quote! { #(#names,)* }) | ||||||
|  |         } else { | ||||||
|  |             (TokenStream::new(), TokenStream::new()) | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let extract_body = if self.has_body_fields() || self.newtype_body_field().is_some() { | ||||||
|             let body_lifetimes = if self.has_body_lifetimes() { |             let body_lifetimes = if self.has_body_lifetimes() { | ||||||
|                 // duplicate the anonymous lifetime as many times as needed
 |                 // duplicate the anonymous lifetime as many times as needed
 | ||||||
|                 let lifetimes = std::iter::repeat(quote! { '_ }).take(self.body_lifetime_count()); |                 let lifetimes = std::iter::repeat(quote! { '_ }).take(self.body_lifetime_count()); | ||||||
| @ -412,6 +478,7 @@ impl Request { | |||||||
|             } else { |             } else { | ||||||
|                 TokenStream::new() |                 TokenStream::new() | ||||||
|             }; |             }; | ||||||
|  | 
 | ||||||
|             quote! { |             quote! { | ||||||
|                 let request_body: < |                 let request_body: < | ||||||
|                     RequestBody #body_lifetimes |                     RequestBody #body_lifetimes | ||||||
| @ -432,52 +499,6 @@ impl Request { | |||||||
|             TokenStream::new() |             TokenStream::new() | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let parse_request_headers = if self.has_header_fields() { |  | ||||||
|             let fields = self.header_fields().map(|request_field| { |  | ||||||
|                 let (field, header_name) = match request_field { |  | ||||||
|                     RequestField::Header(field, header_name) => (field, header_name), |  | ||||||
|                     _ => panic!("expected request field to be header variant"), |  | ||||||
|                 }; |  | ||||||
| 
 |  | ||||||
|                 let field_name = &field.ident; |  | ||||||
|                 let header_name_string = header_name.to_string(); |  | ||||||
| 
 |  | ||||||
|                 let (some_case, none_case) = match &field.ty { |  | ||||||
|                     syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) |  | ||||||
|                         if segments.last().unwrap().ident == "Option" => |  | ||||||
|                     { |  | ||||||
|                         (quote! { Some(str_value.to_owned()) }, quote! { None }) |  | ||||||
|                     } |  | ||||||
|                     _ => ( |  | ||||||
|                         quote! { str_value.to_owned() }, |  | ||||||
|                         quote! { |  | ||||||
|                             return Err( |  | ||||||
|                                 #ruma_api::error::HeaderDeserializationError::MissingHeader( |  | ||||||
|                                     #header_name_string.into() |  | ||||||
|                                 ).into(), |  | ||||||
|                             ) |  | ||||||
|                         }, |  | ||||||
|                     ), |  | ||||||
|                 }; |  | ||||||
| 
 |  | ||||||
|                 quote! { |  | ||||||
|                     #field_name: match headers.get(#http::header::#header_name) { |  | ||||||
|                         Some(header_value) => { |  | ||||||
|                             let str_value = header_value.to_str()?; |  | ||||||
|                             #some_case |  | ||||||
|                         } |  | ||||||
|                         None => #none_case, |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             }); |  | ||||||
| 
 |  | ||||||
|             quote! { |  | ||||||
|                 #(#fields,)* |  | ||||||
|             } |  | ||||||
|         } else { |  | ||||||
|             TokenStream::new() |  | ||||||
|         }; |  | ||||||
| 
 |  | ||||||
|         let request_body = if let Some(field) = self.newtype_raw_body_field() { |         let request_body = if let Some(field) = self.newtype_raw_body_field() { | ||||||
|             let field_name = field.ident.as_ref().expect("expected field to have an identifier"); |             let field_name = field.ident.as_ref().expect("expected field to have an identifier"); | ||||||
|             quote! { self.#field_name } |             quote! { self.#field_name } | ||||||
| @ -501,18 +522,22 @@ impl Request { | |||||||
|             quote! { Vec::new() } |             quote! { Vec::new() } | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let parse_request_body = if let Some(field) = self.newtype_body_field() { |         let (parse_body, body_vars) = if let Some(field) = self.newtype_body_field() { | ||||||
|             let field_name = field.ident.as_ref().expect("expected field to have an identifier"); |             let field_name = field.ident.as_ref().expect("expected field to have an identifier"); | ||||||
|             quote! { |             let parse = quote! { | ||||||
|                 #field_name: request_body.0, |                 let #field_name = request_body.0; | ||||||
|             } |             }; | ||||||
|  | 
 | ||||||
|  |             (parse, quote! { #field_name, }) | ||||||
|         } else if let Some(field) = self.newtype_raw_body_field() { |         } else if let Some(field) = self.newtype_raw_body_field() { | ||||||
|             let field_name = field.ident.as_ref().expect("expected field to have an identifier"); |             let field_name = field.ident.as_ref().expect("expected field to have an identifier"); | ||||||
|             quote! { |             let parse = quote! { | ||||||
|                 #field_name: request.into_body(), |                 let #field_name = request.into_body(); | ||||||
|             } |             }; | ||||||
|  | 
 | ||||||
|  |             (parse, quote! { #field_name, }) | ||||||
|         } else { |         } else { | ||||||
|             self.struct_init_fields(RequestFieldKind::Body, quote!(request_body)) |             self.vars(RequestFieldKind::Body, quote!(request_body)) | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let request_generics = self.combine_lifetimes(); |         let request_generics = self.combine_lifetimes(); | ||||||
| @ -699,16 +724,18 @@ impl Request { | |||||||
|                         }); |                         }); | ||||||
|                     } |                     } | ||||||
| 
 | 
 | ||||||
|                     #extract_request_path |                     #parse_request_path | ||||||
|                     #extract_request_query |                     #parse_query | ||||||
|                     #extract_request_headers |                     #parse_headers | ||||||
|                     #extract_request_body | 
 | ||||||
|  |                     #extract_body | ||||||
|  |                     #parse_body | ||||||
| 
 | 
 | ||||||
|                     Ok(Self { |                     Ok(Self { | ||||||
|                         #parse_request_path |                         #path_vars | ||||||
|                         #parse_request_query |                         #query_vars | ||||||
|                         #parse_request_headers |                         #header_vars | ||||||
|                         #parse_request_body |                         #body_vars | ||||||
|                     }) |                     }) | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  | |||||||
| @ -8,7 +8,7 @@ ruma_api! { | |||||||
|         description: "Does something.", |         description: "Does something.", | ||||||
|         method: POST, |         method: POST, | ||||||
|         name: "my_endpoint", |         name: "my_endpoint", | ||||||
|         path: "/_matrix/foo/:bar/:baz", |         path: "/_matrix/foo/:bar/:user", | ||||||
|         rate_limited: false, |         rate_limited: false, | ||||||
|         authentication: None, |         authentication: None, | ||||||
|     } |     } | ||||||
| @ -24,7 +24,7 @@ ruma_api! { | |||||||
|         #[ruma_api(path)] |         #[ruma_api(path)] | ||||||
|         pub bar: String, |         pub bar: String, | ||||||
|         #[ruma_api(path)] |         #[ruma_api(path)] | ||||||
|         pub baz: UserId, |         pub user: UserId, | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     response: { |     response: { | ||||||
| @ -44,7 +44,7 @@ fn request_serde() -> Result<(), Box<dyn std::error::Error + 'static>> { | |||||||
|         q1: "query_param_special_chars %/&@!".to_owned(), |         q1: "query_param_special_chars %/&@!".to_owned(), | ||||||
|         q2: 55, |         q2: 55, | ||||||
|         bar: "barVal".to_owned(), |         bar: "barVal".to_owned(), | ||||||
|         baz: user_id!("@bazme:ruma.io"), |         user: user_id!("@bazme:ruma.io"), | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?; |     let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?; | ||||||
| @ -55,7 +55,7 @@ fn request_serde() -> Result<(), Box<dyn std::error::Error + 'static>> { | |||||||
|     assert_eq!(req.q1, req2.q1); |     assert_eq!(req.q1, req2.q1); | ||||||
|     assert_eq!(req.q2, req2.q2); |     assert_eq!(req.q2, req2.q2); | ||||||
|     assert_eq!(req.bar, req2.bar); |     assert_eq!(req.bar, req2.bar); | ||||||
|     assert_eq!(req.baz, req2.baz); |     assert_eq!(req.user, req2.user); | ||||||
| 
 | 
 | ||||||
|     Ok(()) |     Ok(()) | ||||||
| } | } | ||||||
| @ -68,12 +68,12 @@ fn request_with_user_id_serde() -> Result<(), Box<dyn std::error::Error + 'stati | |||||||
|         q1: "query_param_special_chars %/&@!".to_owned(), |         q1: "query_param_special_chars %/&@!".to_owned(), | ||||||
|         q2: 55, |         q2: 55, | ||||||
|         bar: "barVal".to_owned(), |         bar: "barVal".to_owned(), | ||||||
|         baz: user_id!("@bazme:ruma.io"), |         user: user_id!("@bazme:ruma.io"), | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     let user_id = user_id!("@_virtual_:ruma.io"); |     let user_id = user_id!("@_virtual_:ruma.io"); | ||||||
|     let http_req = |     let http_req = | ||||||
|         req.clone().try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)?; |         req.try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)?; | ||||||
| 
 | 
 | ||||||
|     let query = http_req.uri().query().unwrap(); |     let query = http_req.uri().query().unwrap(); | ||||||
| 
 | 
 | ||||||
| @ -93,7 +93,7 @@ mod without_query { | |||||||
|             description: "Does something without query.", |             description: "Does something without query.", | ||||||
|             method: POST, |             method: POST, | ||||||
|             name: "my_endpoint", |             name: "my_endpoint", | ||||||
|             path: "/_matrix/foo/:bar/:baz", |             path: "/_matrix/foo/:bar/:user", | ||||||
|             rate_limited: false, |             rate_limited: false, | ||||||
|             authentication: None, |             authentication: None, | ||||||
|         } |         } | ||||||
| @ -105,7 +105,7 @@ mod without_query { | |||||||
|             #[ruma_api(path)] |             #[ruma_api(path)] | ||||||
|             pub bar: String, |             pub bar: String, | ||||||
|             #[ruma_api(path)] |             #[ruma_api(path)] | ||||||
|             pub baz: UserId, |             pub user: UserId, | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         response: { |         response: { | ||||||
| @ -124,15 +124,12 @@ mod without_query { | |||||||
|             hello: "hi".to_owned(), |             hello: "hi".to_owned(), | ||||||
|             world: "test".to_owned(), |             world: "test".to_owned(), | ||||||
|             bar: "barVal".to_owned(), |             bar: "barVal".to_owned(), | ||||||
|             baz: user_id!("@bazme:ruma.io"), |             user: user_id!("@bazme:ruma.io"), | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let user_id = user_id!("@_virtual_:ruma.io"); |         let user_id = user_id!("@_virtual_:ruma.io"); | ||||||
|         let http_req = req.clone().try_into_http_request_with_user_id( |         let http_req = | ||||||
|             "https://homeserver.tld", |             req.try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)?; | ||||||
|             None, |  | ||||||
|             user_id, |  | ||||||
|         )?; |  | ||||||
| 
 | 
 | ||||||
|         let query = http_req.uri().query().unwrap(); |         let query = http_req.uri().query().unwrap(); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| pub mod some_endpoint { | pub mod some_endpoint { | ||||||
|     use ruma_api::ruma_api; |     use ruma_api::ruma_api; | ||||||
|     use ruma_events::{tag::TagEvent, AnyRoomEvent}; |     use ruma_events::{tag::TagEvent, AnyRoomEvent}; | ||||||
|  |     use ruma_identifiers::UserId; | ||||||
|     use ruma_serde::Raw; |     use ruma_serde::Raw; | ||||||
| 
 | 
 | ||||||
|     ruma_api! { |     ruma_api! { | ||||||
| @ -8,7 +9,7 @@ pub mod some_endpoint { | |||||||
|             description: "Does something.", |             description: "Does something.", | ||||||
|             method: POST, // An `http::Method` constant. No imports required.
 |             method: POST, // An `http::Method` constant. No imports required.
 | ||||||
|             name: "some_endpoint", |             name: "some_endpoint", | ||||||
|             path: "/_matrix/some/endpoint/:baz", |             path: "/_matrix/some/endpoint/:user", | ||||||
| 
 | 
 | ||||||
|             #[cfg(all())] |             #[cfg(all())] | ||||||
|             rate_limited: true, |             rate_limited: true, | ||||||
| @ -23,7 +24,7 @@ pub mod some_endpoint { | |||||||
| 
 | 
 | ||||||
|         request: { |         request: { | ||||||
|             // With no attribute on the field, it will be put into the body of the request.
 |             // With no attribute on the field, it will be put into the body of the request.
 | ||||||
|             pub foo: String, |             pub a_field: String, | ||||||
| 
 | 
 | ||||||
|             // This value will be put into the "Content-Type" HTTP header.
 |             // This value will be put into the "Content-Type" HTTP header.
 | ||||||
|             #[ruma_api(header = CONTENT_TYPE)] |             #[ruma_api(header = CONTENT_TYPE)] | ||||||
| @ -34,9 +35,9 @@ pub mod some_endpoint { | |||||||
|             pub bar: String, |             pub bar: String, | ||||||
| 
 | 
 | ||||||
|             // This value will be inserted into the request's URL in place of the
 |             // This value will be inserted into the request's URL in place of the
 | ||||||
|             // ":baz" path component.
 |             // ":user" path component.
 | ||||||
|             #[ruma_api(path)] |             #[ruma_api(path)] | ||||||
|             pub baz: String, |             pub user: UserId, | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         response: { |         response: { | ||||||
| @ -65,7 +66,7 @@ pub mod newtype_body_endpoint { | |||||||
| 
 | 
 | ||||||
|     #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] |     #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] | ||||||
|     pub struct MyCustomType { |     pub struct MyCustomType { | ||||||
|         pub foo: String, |         pub a_field: String, | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ruma_api! { |     ruma_api! { | ||||||
| @ -95,7 +96,7 @@ pub mod newtype_raw_body_endpoint { | |||||||
| 
 | 
 | ||||||
|     #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] |     #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] | ||||||
|     pub struct MyCustomType { |     pub struct MyCustomType { | ||||||
|         pub foo: String, |         pub a_field: String, | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ruma_api! { |     ruma_api! { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user