use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; use crate::auth_scheme::AuthScheme; use super::{Request, RequestField, RequestFieldKind}; impl Request { pub fn expand_outgoing(&self, ruma_api: &TokenStream) -> TokenStream { let bytes = quote! { #ruma_api::exports::bytes }; let http = quote! { #ruma_api::exports::http }; let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let method = &self.method; let error_ty = &self.error_ty; let request_path_string = if self.has_path_fields() { let mut format_string = self.path.value(); let mut format_args = Vec::new(); while let Some(start_of_segment) = format_string.find(':') { // ':' should only ever appear at the start of a segment assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/"); let end_of_segment = match format_string[start_of_segment..].find('/') { Some(rel_pos) => start_of_segment + rel_pos, None => format_string.len(), }; let path_var = Ident::new( &format_string[start_of_segment + 1..end_of_segment], Span::call_site(), ); format_args.push(quote! { #percent_encoding::utf8_percent_encode( &::std::string::ToString::to_string(&self.#path_var), #percent_encoding::NON_ALPHANUMERIC, ) }); format_string.replace_range(start_of_segment..end_of_segment, "{}"); } quote! { format_args!(#format_string, #(#format_args),*) } } else { quote! { metadata.path.to_owned() } }; let request_query_string = if let Some(field) = self.query_map_field() { let field_name = field.ident.as_ref().expect("expected field to have identifier"); quote! {{ // This function exists so that the compiler will throw an error when the type of // the field with the query_map attribute doesn't implement // `IntoIterator`. // // This is necessary because the `ruma_serde::urlencoded::to_string` call will // result in a runtime error when the type cannot be encoded as a list key-value // pairs (?key1=value1&key2=value2). // // By asserting that it implements the iterator trait, we can ensure that it won't // fail. fn assert_trait_impl(_: &T) where T: ::std::iter::IntoIterator< Item = (::std::string::String, ::std::string::String), >, {} let request_query = RequestQuery(self.#field_name); assert_trait_impl(&request_query.0); format_args!( "?{}", #ruma_serde::urlencoded::to_string(request_query)? ) }} } else if self.has_query_fields() { let request_query_init_fields = self.struct_init_fields(RequestFieldKind::Query, quote! { self }); quote! {{ let request_query = RequestQuery { #request_query_init_fields }; format_args!( "?{}", #ruma_serde::urlencoded::to_string(request_query)? ) }} } else { quote! { "" } }; let mut header_kvs: TokenStream = self .header_fields() .map(|request_field| { let (field, header_name) = match request_field { RequestField::Header(field, header_name) => (field, header_name), _ => unreachable!("expected request field to be header variant"), }; let field_name = &field.ident; match &field.ty { syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) if segments.last().unwrap().ident == "Option" => { quote! { if let Some(header_val) = self.#field_name.as_ref() { req_headers.insert( #http::header::#header_name, #http::header::HeaderValue::from_str(header_val)?, ); } } } _ => quote! { req_headers.insert( #http::header::#header_name, #http::header::HeaderValue::from_str(self.#field_name.as_ref())?, ); }, } }) .collect(); let hdr_kv = match self.authentication { AuthScheme::AccessToken(_) => quote! { req_headers.insert( #http::header::AUTHORIZATION, ::std::convert::TryFrom::<_>::try_from(::std::format!( "Bearer {}", access_token .get_required_for_endpoint() .ok_or(#ruma_api::error::IntoHttpError::NeedsAuthentication)?, ))?, ); }, AuthScheme::None(_) => quote! { if let Some(access_token) = access_token.get_not_required_for_endpoint() { req_headers.insert( #http::header::AUTHORIZATION, ::std::convert::TryFrom::<_>::try_from( ::std::format!("Bearer {}", access_token), )? ); } }, AuthScheme::QueryOnlyAccessToken(_) | AuthScheme::ServerSignatures(_) => quote! {}, }; header_kvs.extend(hdr_kv); let request_body = if let Some(field) = self.newtype_raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); quote! { #ruma_serde::slice_to_buf(&self.#field_name) } } else if self.has_body_fields() || self.newtype_body_field().is_some() { let request_body_initializers = if let Some(field) = self.newtype_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); quote! { (self.#field_name) } } else { let initializers = self.struct_init_fields(RequestFieldKind::Body, quote! { self }); quote! { { #initializers } } }; quote! { #ruma_serde::json_to_buf(&RequestBody #request_body_initializers)? } } else { quote! { ::default() } }; let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| { quote! { #[automatically_derived] #[cfg(feature = "client")] impl #impl_generics #ruma_api::OutgoingNonAuthRequest for Request #ty_generics #where_clause {} } }); quote! { #[automatically_derived] #[cfg(feature = "client")] impl #impl_generics #ruma_api::OutgoingRequest for Request #ty_generics #where_clause { type EndpointError = #error_ty; type IncomingResponse = ::Incoming; const METADATA: #ruma_api::Metadata = self::METADATA; fn try_into_http_request( self, base_url: &::std::primitive::str, access_token: #ruma_api::SendAccessToken<'_>, ) -> ::std::result::Result<#http::Request, #ruma_api::error::IntoHttpError> { let metadata = self::METADATA; let mut req_builder = #http::Request::builder() .method(#http::Method::#method) .uri(::std::format!( "{}{}{}", base_url.strip_suffix('/').unwrap_or(base_url), #request_path_string, #request_query_string, )) .header( #ruma_api::exports::http::header::CONTENT_TYPE, "application/json", ); if let Some(mut req_headers) = req_builder.headers_mut() { #header_kvs } let http_request = req_builder.body(#request_body)?; Ok(http_request) } } #non_auth_impl } } /// Produces code for a struct initializer for the given field kind to be accessed through the /// given variable name. fn struct_init_fields( &self, request_field_kind: RequestFieldKind, src: TokenStream, ) -> TokenStream { self.fields .iter() .filter_map(|f| f.field_of_kind(request_field_kind)) .map(|field| { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let cfg_attrs = field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); quote! { #( #cfg_attrs )* #field_name: #src.#field_name, } }) .collect() } }