254 lines
10 KiB
Rust
254 lines
10 KiB
Rust
use proc_macro2::{Ident, Span, TokenStream};
|
|
use quote::quote;
|
|
|
|
use crate::auth_scheme::AuthScheme;
|
|
|
|
use super::{Request, RequestField, RequestFieldKind};
|
|
|
|
impl Request {
|
|
pub fn expand_outgoing(&self, ruma_api: &TokenStream) -> TokenStream {
|
|
let bytes = quote! { #ruma_api::exports::bytes };
|
|
let http = quote! { #ruma_api::exports::http };
|
|
let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
|
|
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
|
|
|
let method = &self.method;
|
|
let error_ty = &self.error_ty;
|
|
let request_path_string = if self.has_path_fields() {
|
|
let mut format_string = self.path.value();
|
|
let mut format_args = Vec::new();
|
|
|
|
while let Some(start_of_segment) = format_string.find(':') {
|
|
// ':' should only ever appear at the start of a segment
|
|
assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/");
|
|
|
|
let end_of_segment = match format_string[start_of_segment..].find('/') {
|
|
Some(rel_pos) => start_of_segment + rel_pos,
|
|
None => format_string.len(),
|
|
};
|
|
|
|
let path_var = Ident::new(
|
|
&format_string[start_of_segment + 1..end_of_segment],
|
|
Span::call_site(),
|
|
);
|
|
format_args.push(quote! {
|
|
#percent_encoding::utf8_percent_encode(
|
|
&::std::string::ToString::to_string(&self.#path_var),
|
|
#percent_encoding::NON_ALPHANUMERIC,
|
|
)
|
|
});
|
|
format_string.replace_range(start_of_segment..end_of_segment, "{}");
|
|
}
|
|
|
|
quote! {
|
|
format_args!(#format_string, #(#format_args),*)
|
|
}
|
|
} else {
|
|
quote! { metadata.path.to_owned() }
|
|
};
|
|
|
|
let request_query_string = if let Some(field) = self.query_map_field() {
|
|
let field_name = field.ident.as_ref().expect("expected field to have identifier");
|
|
|
|
quote! {{
|
|
// This function exists so that the compiler will throw an error when the type of
|
|
// the field with the query_map attribute doesn't implement
|
|
// `IntoIterator<Item = (String, String)>`.
|
|
//
|
|
// This is necessary because the `ruma_serde::urlencoded::to_string` call will
|
|
// result in a runtime error when the type cannot be encoded as a list key-value
|
|
// pairs (?key1=value1&key2=value2).
|
|
//
|
|
// By asserting that it implements the iterator trait, we can ensure that it won't
|
|
// fail.
|
|
fn assert_trait_impl<T>(_: &T)
|
|
where
|
|
T: ::std::iter::IntoIterator<
|
|
Item = (::std::string::String, ::std::string::String),
|
|
>,
|
|
{}
|
|
|
|
let request_query = RequestQuery(self.#field_name);
|
|
assert_trait_impl(&request_query.0);
|
|
|
|
format_args!(
|
|
"?{}",
|
|
#ruma_serde::urlencoded::to_string(request_query)?
|
|
)
|
|
}}
|
|
} else if self.has_query_fields() {
|
|
let request_query_init_fields =
|
|
self.struct_init_fields(RequestFieldKind::Query, quote! { self });
|
|
|
|
quote! {{
|
|
let request_query = RequestQuery {
|
|
#request_query_init_fields
|
|
};
|
|
|
|
format_args!(
|
|
"?{}",
|
|
#ruma_serde::urlencoded::to_string(request_query)?
|
|
)
|
|
}}
|
|
} else {
|
|
quote! { "" }
|
|
};
|
|
|
|
let mut header_kvs: TokenStream = self
|
|
.header_fields()
|
|
.map(|request_field| {
|
|
let (field, header_name) = match request_field {
|
|
RequestField::Header(field, header_name) => (field, header_name),
|
|
_ => unreachable!("expected request field to be header variant"),
|
|
};
|
|
|
|
let field_name = &field.ident;
|
|
|
|
match &field.ty {
|
|
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
|
|
if segments.last().unwrap().ident == "Option" =>
|
|
{
|
|
quote! {
|
|
if let Some(header_val) = self.#field_name.as_ref() {
|
|
req_headers.insert(
|
|
#http::header::#header_name,
|
|
#http::header::HeaderValue::from_str(header_val)?,
|
|
);
|
|
}
|
|
}
|
|
}
|
|
_ => quote! {
|
|
req_headers.insert(
|
|
#http::header::#header_name,
|
|
#http::header::HeaderValue::from_str(self.#field_name.as_ref())?,
|
|
);
|
|
},
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
let hdr_kv = match self.authentication {
|
|
AuthScheme::AccessToken(_) => quote! {
|
|
req_headers.insert(
|
|
#http::header::AUTHORIZATION,
|
|
::std::convert::TryFrom::<_>::try_from(::std::format!(
|
|
"Bearer {}",
|
|
access_token
|
|
.get_required_for_endpoint()
|
|
.ok_or(#ruma_api::error::IntoHttpError::NeedsAuthentication)?,
|
|
))?,
|
|
);
|
|
},
|
|
AuthScheme::None(_) => quote! {
|
|
if let Some(access_token) = access_token.get_not_required_for_endpoint() {
|
|
req_headers.insert(
|
|
#http::header::AUTHORIZATION,
|
|
::std::convert::TryFrom::<_>::try_from(
|
|
::std::format!("Bearer {}", access_token),
|
|
)?
|
|
);
|
|
}
|
|
},
|
|
AuthScheme::QueryOnlyAccessToken(_) | AuthScheme::ServerSignatures(_) => quote! {},
|
|
};
|
|
header_kvs.extend(hdr_kv);
|
|
|
|
let request_body = if let Some(field) = self.newtype_raw_body_field() {
|
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
quote! { #ruma_serde::slice_to_buf(&self.#field_name) }
|
|
} else if self.has_body_fields() || self.newtype_body_field().is_some() {
|
|
let request_body_initializers = if let Some(field) = self.newtype_body_field() {
|
|
let field_name =
|
|
field.ident.as_ref().expect("expected field to have an identifier");
|
|
quote! { (self.#field_name) }
|
|
} else {
|
|
let initializers = self.struct_init_fields(RequestFieldKind::Body, quote! { self });
|
|
quote! { { #initializers } }
|
|
};
|
|
|
|
quote! {
|
|
#ruma_serde::json_to_buf(&RequestBody #request_body_initializers)?
|
|
}
|
|
} else {
|
|
quote! { <T as ::std::default::Default>::default() }
|
|
};
|
|
|
|
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
|
|
|
|
let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| {
|
|
quote! {
|
|
#[automatically_derived]
|
|
#[cfg(feature = "client")]
|
|
impl #impl_generics #ruma_api::OutgoingNonAuthRequest
|
|
for Request #ty_generics #where_clause {}
|
|
}
|
|
});
|
|
|
|
quote! {
|
|
#[automatically_derived]
|
|
#[cfg(feature = "client")]
|
|
impl #impl_generics #ruma_api::OutgoingRequest for Request #ty_generics #where_clause {
|
|
type EndpointError = #error_ty;
|
|
type IncomingResponse = <Response as #ruma_serde::Outgoing>::Incoming;
|
|
|
|
const METADATA: #ruma_api::Metadata = self::METADATA;
|
|
|
|
fn try_into_http_request<T: ::std::default::Default + #bytes::BufMut>(
|
|
self,
|
|
base_url: &::std::primitive::str,
|
|
access_token: #ruma_api::SendAccessToken<'_>,
|
|
) -> ::std::result::Result<#http::Request<T>, #ruma_api::error::IntoHttpError> {
|
|
let metadata = self::METADATA;
|
|
|
|
let mut req_builder = #http::Request::builder()
|
|
.method(#http::Method::#method)
|
|
.uri(::std::format!(
|
|
"{}{}{}",
|
|
base_url.strip_suffix('/').unwrap_or(base_url),
|
|
#request_path_string,
|
|
#request_query_string,
|
|
))
|
|
.header(
|
|
#ruma_api::exports::http::header::CONTENT_TYPE,
|
|
"application/json",
|
|
);
|
|
|
|
if let Some(mut req_headers) = req_builder.headers_mut() {
|
|
#header_kvs
|
|
}
|
|
|
|
let http_request = req_builder.body(#request_body)?;
|
|
|
|
Ok(http_request)
|
|
}
|
|
}
|
|
|
|
#non_auth_impl
|
|
}
|
|
}
|
|
|
|
/// Produces code for a struct initializer for the given field kind to be accessed through the
|
|
/// given variable name.
|
|
fn struct_init_fields(
|
|
&self,
|
|
request_field_kind: RequestFieldKind,
|
|
src: TokenStream,
|
|
) -> TokenStream {
|
|
self.fields
|
|
.iter()
|
|
.filter_map(|f| f.field_of_kind(request_field_kind))
|
|
.map(|field| {
|
|
let field_name =
|
|
field.ident.as_ref().expect("expected field to have an identifier");
|
|
let cfg_attrs =
|
|
field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>();
|
|
|
|
quote! {
|
|
#( #cfg_attrs )*
|
|
#field_name: #src.#field_name,
|
|
}
|
|
})
|
|
.collect()
|
|
}
|
|
}
|