use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; use super::{Request, RequestField, RequestFieldKind}; use crate::auth_scheme::AuthScheme; impl Request { pub fn expand_incoming(&self, ruma_api: &TokenStream) -> TokenStream { let http = quote! { #ruma_api::exports::http }; let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let serde_json = quote! { #ruma_api::exports::serde_json }; let method = &self.method; let error_ty = &self.error_ty; let incoming_request_type = if self.has_lifetimes() { quote! { IncomingRequest } } else { quote! { Request } }; // FIXME: the rest of the field initializer expansions are gated `cfg(...)` // except this one. If we get errors about missing fields in IncomingRequest for // a path field look here. let (parse_request_path, path_vars) = if self.has_path_fields() { let path_string = self.path.value(); assert!(path_string.starts_with('/'), "path needs to start with '/'"); assert!( path_string.chars().filter(|c| *c == ':').count() == self.path_field_count(), "number of declared path parameters needs to match amount of placeholders in path" ); let path_var_decls = path_string[1..] .split('/') .enumerate() .filter(|(_, seg)| seg.starts_with(':')) .map(|(i, seg)| { let path_var = Ident::new(&seg[1..], Span::call_site()); quote! { let #path_var = { let segment = path_segments[#i].as_bytes(); let decoded = #percent_encoding::percent_decode(segment).decode_utf8()?; ::std::convert::TryFrom::try_from(&*decoded)? }; } }); let parse_request_path = quote! { let path_segments: ::std::vec::Vec<&::std::primitive::str> = request.uri().path()[1..].split('/').collect(); #(#path_var_decls)* }; let path_vars = path_string[1..] .split('/') .filter(|seg| seg.starts_with(':')) .map(|seg| Ident::new(&seg[1..], Span::call_site())); (parse_request_path, quote! { #(#path_vars,)* }) } else { (TokenStream::new(), TokenStream::new()) }; let (parse_query, query_vars) = if let Some(field) = self.query_map_field() { let cfg_attrs = field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let parse = quote! { #( #cfg_attrs )* let #field_name = #ruma_serde::urlencoded::from_str( &request.uri().query().unwrap_or(""), )?; }; ( parse, quote! { #( #cfg_attrs )* #field_name, }, ) } else if self.has_query_fields() { let (decls, names) = self.vars(RequestFieldKind::Query, quote! { request_query }); let parse = quote! { let request_query: ::Incoming = #ruma_serde::urlencoded::from_str( &request.uri().query().unwrap_or("") )?; #decls }; (parse, names) } else { (TokenStream::new(), TokenStream::new()) }; let (parse_headers, header_vars) = if self.has_header_fields() { let (decls, names): (TokenStream, Vec<_>) = self .header_fields() .map(|request_field| { let (field, header_name) = match request_field { RequestField::Header(field, header_name) => (field, header_name), _ => panic!("expected request field to be header variant"), }; let cfg_attrs = field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); let field_name = &field.ident; let header_name_string = header_name.to_string(); let (some_case, none_case) = match &field.ty { syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) if segments.last().unwrap().ident == "Option" => { (quote! { Some(str_value.to_owned()) }, quote! { None }) } _ => ( quote! { str_value.to_owned() }, quote! { return Err( #ruma_api::error::HeaderDeserializationError::MissingHeader( #header_name_string.into() ).into(), ) }, ), }; let decl = quote! { #( #cfg_attrs )* let #field_name = match headers.get(#http::header::#header_name) { Some(header_value) => { let str_value = header_value.to_str()?; #some_case } None => #none_case, }; }; ( decl, quote! { #( #cfg_attrs )* #field_name }, ) }) .unzip(); let parse = quote! { let headers = request.headers(); #decls }; (parse, quote! { #(#names,)* }) } else { (TokenStream::new(), TokenStream::new()) }; let extract_body = (self.has_body_fields() || self.newtype_body_field().is_some()).then(|| { let body_lifetimes = (!self.lifetimes.body.is_empty()).then(|| { // duplicate the anonymous lifetime as many times as needed let lifetimes = std::iter::repeat(quote! { '_ }).take(self.lifetimes.body.len()); quote! { < #( #lifetimes, )* > } }); quote! { let request_body: < RequestBody #body_lifetimes as #ruma_serde::Outgoing >::Incoming = { let body = ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref( request.body(), ); #serde_json::from_slice(match body { // If the request body is completely empty, pretend it is an empty JSON // object instead. This allows requests with only optional body parameters // to be deserialized in that case. [] => b"{}", b => b, })? }; } }); let (parse_body, body_vars) = if let Some(field) = self.newtype_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let parse = quote! { let #field_name = request_body.0; }; (parse, quote! { #field_name, }) } else if let Some(field) = self.raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let parse = quote! { let #field_name = ::std::convert::AsRef::<[u8]>::as_ref(request.body()).to_vec(); }; (parse, quote! { #field_name, }) } else { self.vars(RequestFieldKind::Body, quote! { request_body }) }; let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| { quote! { #[automatically_derived] #[cfg(feature = "server")] impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {} } }); quote! { #[automatically_derived] #[cfg(feature = "server")] impl #ruma_api::IncomingRequest for #incoming_request_type { type EndpointError = #error_ty; type OutgoingResponse = Response; const METADATA: #ruma_api::Metadata = self::METADATA; fn try_from_http_request>( request: #http::Request ) -> ::std::result::Result { if request.method() != #http::Method::#method { return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch { expected: #http::Method::#method, received: request.method().clone(), }); } #parse_request_path #parse_query #parse_headers #extract_body #parse_body ::std::result::Result::Ok(Self { #path_vars #query_vars #header_vars #body_vars }) } } #non_auth_impl } } fn vars( &self, request_field_kind: RequestFieldKind, src: TokenStream, ) -> (TokenStream, TokenStream) { self.fields .iter() .filter_map(|f| f.field_of_kind(request_field_kind)) .map(|field| { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let cfg_attrs = field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); let decl = quote! { #( #cfg_attrs )* let #field_name = #src.#field_name; }; ( decl, quote! { #( #cfg_attrs )* #field_name, }, ) }) .unzip() } }