api-macros: Split request code generation into more files
This commit is contained in:
parent
0e197aae0b
commit
59d47227a6
@ -2,14 +2,17 @@
|
|||||||
|
|
||||||
use std::collections::BTreeSet;
|
use std::collections::BTreeSet;
|
||||||
|
|
||||||
use proc_macro2::{Span, TokenStream};
|
use proc_macro2::TokenStream;
|
||||||
use quote::{quote, quote_spanned};
|
use quote::quote;
|
||||||
use syn::{spanned::Spanned, Attribute, Field, Ident, Lifetime};
|
use syn::{Attribute, Field, Ident, Lifetime};
|
||||||
|
|
||||||
use crate::util;
|
use crate::util;
|
||||||
|
|
||||||
use super::metadata::Metadata;
|
use super::metadata::Metadata;
|
||||||
|
|
||||||
|
mod incoming;
|
||||||
|
mod outgoing;
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
pub(super) struct RequestLifetimes {
|
pub(super) struct RequestLifetimes {
|
||||||
pub body: BTreeSet<Lifetime>,
|
pub body: BTreeSet<Lifetime>,
|
||||||
@ -56,11 +59,6 @@ impl Request {
|
|||||||
self.fields.iter().filter_map(|field| field.as_body_field())
|
self.fields.iter().filter_map(|field| field.as_body_field())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The number of unique lifetime annotations for `body` fields.
|
|
||||||
fn body_lifetime_count(&self) -> usize {
|
|
||||||
self.lifetimes.body.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Whether any `body` field has a lifetime annotation.
|
/// Whether any `body` field has a lifetime annotation.
|
||||||
fn has_body_lifetimes(&self) -> bool {
|
fn has_body_lifetimes(&self) -> bool {
|
||||||
!self.lifetimes.body.is_empty()
|
!self.lifetimes.body.is_empty()
|
||||||
@ -128,74 +126,14 @@ impl Request {
|
|||||||
self.fields.iter().find_map(RequestField::as_query_map_field)
|
self.fields.iter().find_map(RequestField::as_query_map_field)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Produces code for a struct initializer for the given field kind to be accessed through the
|
|
||||||
/// given variable name.
|
|
||||||
fn struct_init_fields(
|
|
||||||
&self,
|
|
||||||
request_field_kind: RequestFieldKind,
|
|
||||||
src: TokenStream,
|
|
||||||
) -> TokenStream {
|
|
||||||
let fields =
|
|
||||||
self.fields.iter().filter_map(|f| f.field_of_kind(request_field_kind)).map(|field| {
|
|
||||||
let field_name =
|
|
||||||
field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
let span = field.span();
|
|
||||||
let cfg_attrs =
|
|
||||||
field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>();
|
|
||||||
|
|
||||||
quote_spanned! {span=>
|
|
||||||
#( #cfg_attrs )*
|
|
||||||
#field_name: #src.#field_name
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
quote! { #(#fields,)* }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn vars(
|
|
||||||
&self,
|
|
||||||
request_field_kind: RequestFieldKind,
|
|
||||||
src: TokenStream,
|
|
||||||
) -> (TokenStream, TokenStream) {
|
|
||||||
let (decls, names): (TokenStream, Vec<_>) = self
|
|
||||||
.fields
|
|
||||||
.iter()
|
|
||||||
.filter_map(|f| f.field_of_kind(request_field_kind))
|
|
||||||
.map(|field| {
|
|
||||||
let field_name =
|
|
||||||
field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
let span = field.span();
|
|
||||||
let cfg_attrs =
|
|
||||||
field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let decl = quote_spanned! {span=>
|
|
||||||
#( #cfg_attrs )*
|
|
||||||
let #field_name = #src.#field_name;
|
|
||||||
};
|
|
||||||
|
|
||||||
(decl, field_name)
|
|
||||||
})
|
|
||||||
.unzip();
|
|
||||||
|
|
||||||
let names = quote! { #(#names,)* };
|
|
||||||
|
|
||||||
(decls, names)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn expand(
|
pub(super) fn expand(
|
||||||
&self,
|
&self,
|
||||||
metadata: &Metadata,
|
metadata: &Metadata,
|
||||||
error_ty: &TokenStream,
|
error_ty: &TokenStream,
|
||||||
ruma_api: &TokenStream,
|
ruma_api: &TokenStream,
|
||||||
) -> TokenStream {
|
) -> TokenStream {
|
||||||
let bytes = quote! { #ruma_api::exports::bytes };
|
|
||||||
let http = quote! { #ruma_api::exports::http };
|
|
||||||
let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
|
|
||||||
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
||||||
let serde = quote! { #ruma_api::exports::serde };
|
let serde = quote! { #ruma_api::exports::serde };
|
||||||
let serde_json = quote! { #ruma_api::exports::serde_json };
|
|
||||||
|
|
||||||
let method = &metadata.method;
|
|
||||||
|
|
||||||
let docs = format!(
|
let docs = format!(
|
||||||
"Data for a request to the `{}` API endpoint.\n\n{}",
|
"Data for a request to the `{}` API endpoint.\n\n{}",
|
||||||
@ -211,342 +149,6 @@ impl Request {
|
|||||||
quote! { { #(#fields),* } }
|
quote! { { #(#fields),* } }
|
||||||
};
|
};
|
||||||
|
|
||||||
let incoming_request_type =
|
|
||||||
if self.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) };
|
|
||||||
|
|
||||||
let (request_path_string, parse_request_path, path_vars) = if self.has_path_fields() {
|
|
||||||
let path_string = metadata.path.value();
|
|
||||||
|
|
||||||
assert!(path_string.starts_with('/'), "path needs to start with '/'");
|
|
||||||
assert!(
|
|
||||||
path_string.chars().filter(|c| *c == ':').count() == self.path_field_count(),
|
|
||||||
"number of declared path parameters needs to match amount of placeholders in path"
|
|
||||||
);
|
|
||||||
|
|
||||||
let format_call = {
|
|
||||||
let mut format_string = path_string.clone();
|
|
||||||
let mut format_args = Vec::new();
|
|
||||||
|
|
||||||
while let Some(start_of_segment) = format_string.find(':') {
|
|
||||||
// ':' should only ever appear at the start of a segment
|
|
||||||
assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/");
|
|
||||||
|
|
||||||
let end_of_segment = match format_string[start_of_segment..].find('/') {
|
|
||||||
Some(rel_pos) => start_of_segment + rel_pos,
|
|
||||||
None => format_string.len(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let path_var = Ident::new(
|
|
||||||
&format_string[start_of_segment + 1..end_of_segment],
|
|
||||||
Span::call_site(),
|
|
||||||
);
|
|
||||||
format_args.push(quote! {
|
|
||||||
#percent_encoding::utf8_percent_encode(
|
|
||||||
&self.#path_var.to_string(),
|
|
||||||
#percent_encoding::NON_ALPHANUMERIC,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
format_string.replace_range(start_of_segment..end_of_segment, "{}");
|
|
||||||
}
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
format_args!(#format_string, #(#format_args),*)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let path_var_decls = path_string[1..]
|
|
||||||
.split('/')
|
|
||||||
.enumerate()
|
|
||||||
.filter(|(_, seg)| seg.starts_with(':'))
|
|
||||||
.map(|(i, seg)| {
|
|
||||||
let path_var = Ident::new(&seg[1..], Span::call_site());
|
|
||||||
quote! {
|
|
||||||
let #path_var = {
|
|
||||||
let segment = path_segments[#i].as_bytes();
|
|
||||||
let decoded =
|
|
||||||
#percent_encoding::percent_decode(segment).decode_utf8()?;
|
|
||||||
|
|
||||||
::std::convert::TryFrom::try_from(&*decoded)?
|
|
||||||
};
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let parse_request_path = quote! {
|
|
||||||
let path_segments: ::std::vec::Vec<&::std::primitive::str> =
|
|
||||||
request.uri().path()[1..].split('/').collect();
|
|
||||||
|
|
||||||
#(#path_var_decls)*
|
|
||||||
};
|
|
||||||
|
|
||||||
let path_vars = path_string[1..]
|
|
||||||
.split('/')
|
|
||||||
.filter(|seg| seg.starts_with(':'))
|
|
||||||
.map(|seg| Ident::new(&seg[1..], Span::call_site()));
|
|
||||||
|
|
||||||
(format_call, parse_request_path, quote! { #(#path_vars,)* })
|
|
||||||
} else {
|
|
||||||
(quote! { metadata.path.to_owned() }, TokenStream::new(), TokenStream::new())
|
|
||||||
};
|
|
||||||
|
|
||||||
let request_query_string = if let Some(field) = self.query_map_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have identifier");
|
|
||||||
|
|
||||||
quote!({
|
|
||||||
// This function exists so that the compiler will throw an error when the type of
|
|
||||||
// the field with the query_map attribute doesn't implement
|
|
||||||
// `IntoIterator<Item = (String, String)>`.
|
|
||||||
//
|
|
||||||
// This is necessary because the `ruma_serde::urlencoded::to_string` call will
|
|
||||||
// result in a runtime error when the type cannot be encoded as a list key-value
|
|
||||||
// pairs (?key1=value1&key2=value2).
|
|
||||||
//
|
|
||||||
// By asserting that it implements the iterator trait, we can ensure that it won't
|
|
||||||
// fail.
|
|
||||||
fn assert_trait_impl<T>(_: &T)
|
|
||||||
where
|
|
||||||
T: ::std::iter::IntoIterator<
|
|
||||||
Item = (::std::string::String, ::std::string::String),
|
|
||||||
>,
|
|
||||||
{}
|
|
||||||
|
|
||||||
let request_query = RequestQuery(self.#field_name);
|
|
||||||
assert_trait_impl(&request_query.0);
|
|
||||||
|
|
||||||
format_args!(
|
|
||||||
"?{}",
|
|
||||||
#ruma_serde::urlencoded::to_string(request_query)?
|
|
||||||
)
|
|
||||||
})
|
|
||||||
} else if self.has_query_fields() {
|
|
||||||
let request_query_init_fields =
|
|
||||||
self.struct_init_fields(RequestFieldKind::Query, quote!(self));
|
|
||||||
|
|
||||||
quote!({
|
|
||||||
let request_query = RequestQuery {
|
|
||||||
#request_query_init_fields
|
|
||||||
};
|
|
||||||
|
|
||||||
format_args!(
|
|
||||||
"?{}",
|
|
||||||
#ruma_serde::urlencoded::to_string(request_query)?
|
|
||||||
)
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
quote! { "" }
|
|
||||||
};
|
|
||||||
|
|
||||||
let (parse_query, query_vars) = if let Some(field) = self.query_map_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
let parse = quote! {
|
|
||||||
let #field_name = #ruma_serde::urlencoded::from_str(
|
|
||||||
&request.uri().query().unwrap_or(""),
|
|
||||||
)?;
|
|
||||||
};
|
|
||||||
|
|
||||||
(parse, quote! { #field_name, })
|
|
||||||
} else if self.has_query_fields() {
|
|
||||||
let (decls, names) = self.vars(RequestFieldKind::Query, quote!(request_query));
|
|
||||||
|
|
||||||
let parse = quote! {
|
|
||||||
let request_query: <RequestQuery as #ruma_serde::Outgoing>::Incoming =
|
|
||||||
#ruma_serde::urlencoded::from_str(
|
|
||||||
&request.uri().query().unwrap_or("")
|
|
||||||
)?;
|
|
||||||
|
|
||||||
#decls
|
|
||||||
};
|
|
||||||
|
|
||||||
(parse, names)
|
|
||||||
} else {
|
|
||||||
(TokenStream::new(), TokenStream::new())
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut header_kvs: TokenStream = self
|
|
||||||
.header_fields()
|
|
||||||
.map(|request_field| {
|
|
||||||
let (field, header_name) = match request_field {
|
|
||||||
RequestField::Header(field, header_name) => (field, header_name),
|
|
||||||
_ => unreachable!("expected request field to be header variant"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let field_name = &field.ident;
|
|
||||||
|
|
||||||
match &field.ty {
|
|
||||||
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
|
|
||||||
if segments.last().unwrap().ident == "Option" =>
|
|
||||||
{
|
|
||||||
quote! {
|
|
||||||
if let Some(header_val) = self.#field_name.as_ref() {
|
|
||||||
req_headers.insert(
|
|
||||||
#http::header::#header_name,
|
|
||||||
#http::header::HeaderValue::from_str(header_val)?,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => quote! {
|
|
||||||
req_headers.insert(
|
|
||||||
#http::header::#header_name,
|
|
||||||
#http::header::HeaderValue::from_str(self.#field_name.as_ref())?,
|
|
||||||
);
|
|
||||||
},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
for auth in &metadata.authentication {
|
|
||||||
if auth.value == "AccessToken" {
|
|
||||||
let attrs = &auth.attrs;
|
|
||||||
header_kvs.extend(quote! {
|
|
||||||
#( #attrs )*
|
|
||||||
req_headers.insert(
|
|
||||||
#http::header::AUTHORIZATION,
|
|
||||||
#http::header::HeaderValue::from_str(
|
|
||||||
&::std::format!(
|
|
||||||
"Bearer {}",
|
|
||||||
access_token.ok_or(
|
|
||||||
#ruma_api::error::IntoHttpError::NeedsAuthentication
|
|
||||||
)?
|
|
||||||
)
|
|
||||||
)?
|
|
||||||
);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let (parse_headers, header_vars) = if self.has_header_fields() {
|
|
||||||
let (decls, names): (TokenStream, Vec<_>) = self
|
|
||||||
.header_fields()
|
|
||||||
.map(|request_field| {
|
|
||||||
let (field, header_name) = match request_field {
|
|
||||||
RequestField::Header(field, header_name) => (field, header_name),
|
|
||||||
_ => panic!("expected request field to be header variant"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let field_name = &field.ident;
|
|
||||||
let header_name_string = header_name.to_string();
|
|
||||||
|
|
||||||
let (some_case, none_case) = match &field.ty {
|
|
||||||
syn::Type::Path(syn::TypePath {
|
|
||||||
path: syn::Path { segments, .. }, ..
|
|
||||||
}) if segments.last().unwrap().ident == "Option" => {
|
|
||||||
(quote! { Some(str_value.to_owned()) }, quote! { None })
|
|
||||||
}
|
|
||||||
_ => (
|
|
||||||
quote! { str_value.to_owned() },
|
|
||||||
quote! {
|
|
||||||
return Err(
|
|
||||||
#ruma_api::error::HeaderDeserializationError::MissingHeader(
|
|
||||||
#header_name_string.into()
|
|
||||||
).into(),
|
|
||||||
)
|
|
||||||
},
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
let decl = quote! {
|
|
||||||
let #field_name = match headers.get(#http::header::#header_name) {
|
|
||||||
Some(header_value) => {
|
|
||||||
let str_value = header_value.to_str()?;
|
|
||||||
#some_case
|
|
||||||
}
|
|
||||||
None => #none_case,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
(decl, field_name)
|
|
||||||
})
|
|
||||||
.unzip();
|
|
||||||
|
|
||||||
let parse = quote! {
|
|
||||||
let headers = request.headers();
|
|
||||||
|
|
||||||
#decls
|
|
||||||
};
|
|
||||||
|
|
||||||
(parse, quote! { #(#names,)* })
|
|
||||||
} else {
|
|
||||||
(TokenStream::new(), TokenStream::new())
|
|
||||||
};
|
|
||||||
|
|
||||||
let extract_body = if self.has_body_fields() || self.newtype_body_field().is_some() {
|
|
||||||
let body_lifetimes = if self.has_body_lifetimes() {
|
|
||||||
// duplicate the anonymous lifetime as many times as needed
|
|
||||||
let lifetimes = std::iter::repeat(quote! { '_ }).take(self.body_lifetime_count());
|
|
||||||
quote! { < #( #lifetimes, )* >}
|
|
||||||
} else {
|
|
||||||
TokenStream::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
let request_body: <
|
|
||||||
RequestBody #body_lifetimes
|
|
||||||
as #ruma_serde::Outgoing
|
|
||||||
>::Incoming = {
|
|
||||||
let body = request.into_body();
|
|
||||||
if #bytes::Buf::has_remaining(&body) {
|
|
||||||
#serde_json::from_reader(#bytes::Buf::reader(body))?
|
|
||||||
} else {
|
|
||||||
// If the request body is completely empty, pretend it is an empty JSON
|
|
||||||
// object instead. This allows requests with only optional body parameters
|
|
||||||
// to be deserialized in that case.
|
|
||||||
#serde_json::from_str("{}")?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
TokenStream::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
let request_body = if let Some(field) = self.newtype_raw_body_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
quote! { self.#field_name }
|
|
||||||
} else if self.has_body_fields() || self.newtype_body_field().is_some() {
|
|
||||||
let request_body_initializers = if let Some(field) = self.newtype_body_field() {
|
|
||||||
let field_name =
|
|
||||||
field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
quote! { (self.#field_name) }
|
|
||||||
} else {
|
|
||||||
let initializers = self.struct_init_fields(RequestFieldKind::Body, quote!(self));
|
|
||||||
quote! { { #initializers } }
|
|
||||||
};
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
{
|
|
||||||
let request_body = RequestBody #request_body_initializers;
|
|
||||||
#serde_json::to_vec(&request_body)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
quote! { Vec::new() }
|
|
||||||
};
|
|
||||||
|
|
||||||
let (parse_body, body_vars) = if let Some(field) = self.newtype_body_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
let parse = quote! {
|
|
||||||
let #field_name = request_body.0;
|
|
||||||
};
|
|
||||||
|
|
||||||
(parse, quote! { #field_name, })
|
|
||||||
} else if let Some(field) = self.newtype_raw_body_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
let parse = quote! {
|
|
||||||
let #field_name = {
|
|
||||||
let mut reader = #bytes::Buf::reader(request.into_body());
|
|
||||||
let mut vec = ::std::vec::Vec::new();
|
|
||||||
::std::io::Read::read_to_end(&mut reader, &mut vec)
|
|
||||||
.expect("reading from a bytes::Buf never fails");
|
|
||||||
vec
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
(parse, quote! { #field_name, })
|
|
||||||
} else {
|
|
||||||
self.vars(RequestFieldKind::Body, quote!(request_body))
|
|
||||||
};
|
|
||||||
|
|
||||||
let request_generics = self.combine_lifetimes();
|
|
||||||
|
|
||||||
let request_body_struct =
|
let request_body_struct =
|
||||||
if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) {
|
if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) {
|
||||||
let field = Field { ident: None, colon_token: None, ..body_field.field().clone() };
|
let field = Field { ident: None, colon_token: None, ..body_field.field().clone() };
|
||||||
@ -629,31 +231,9 @@ impl Request {
|
|||||||
TokenStream::new()
|
TokenStream::new()
|
||||||
};
|
};
|
||||||
|
|
||||||
let request_lifetimes = self.combine_lifetimes();
|
let lifetimes = self.combine_lifetimes();
|
||||||
let non_auth_endpoint_impls: TokenStream = metadata
|
let outgoing_request_impl = self.expand_outgoing(metadata, error_ty, &lifetimes, ruma_api);
|
||||||
.authentication
|
let incoming_request_impl = self.expand_incoming(metadata, error_ty, ruma_api);
|
||||||
.iter()
|
|
||||||
.map(|auth| {
|
|
||||||
if auth.value != "None" {
|
|
||||||
TokenStream::new()
|
|
||||||
} else {
|
|
||||||
let attrs = &auth.attrs;
|
|
||||||
quote! {
|
|
||||||
#( #attrs )*
|
|
||||||
#[automatically_derived]
|
|
||||||
#[cfg(feature = "client")]
|
|
||||||
impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest
|
|
||||||
for Request #request_lifetimes
|
|
||||||
{}
|
|
||||||
|
|
||||||
#( #attrs )*
|
|
||||||
#[automatically_derived]
|
|
||||||
#[cfg(feature = "server")]
|
|
||||||
impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
#[doc = #docs]
|
#[doc = #docs]
|
||||||
@ -661,89 +241,13 @@ impl Request {
|
|||||||
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
|
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
|
||||||
#[incoming_derive(!Deserialize)]
|
#[incoming_derive(!Deserialize)]
|
||||||
#( #struct_attributes )*
|
#( #struct_attributes )*
|
||||||
pub struct Request #request_generics #request_def
|
pub struct Request #lifetimes #request_def
|
||||||
|
|
||||||
#non_auth_endpoint_impls
|
|
||||||
|
|
||||||
#request_body_struct
|
#request_body_struct
|
||||||
#request_query_struct
|
#request_query_struct
|
||||||
|
|
||||||
#[automatically_derived]
|
#outgoing_request_impl
|
||||||
#[cfg(feature = "client")]
|
#incoming_request_impl
|
||||||
impl #request_lifetimes #ruma_api::OutgoingRequest for Request #request_lifetimes {
|
|
||||||
type EndpointError = #error_ty;
|
|
||||||
type IncomingResponse = <Response as #ruma_serde::Outgoing>::Incoming;
|
|
||||||
|
|
||||||
const METADATA: #ruma_api::Metadata = self::METADATA;
|
|
||||||
|
|
||||||
fn try_into_http_request(
|
|
||||||
self,
|
|
||||||
base_url: &::std::primitive::str,
|
|
||||||
access_token: ::std::option::Option<&str>,
|
|
||||||
) -> ::std::result::Result<
|
|
||||||
#http::Request<Vec<u8>>,
|
|
||||||
#ruma_api::error::IntoHttpError,
|
|
||||||
> {
|
|
||||||
let metadata = self::METADATA;
|
|
||||||
|
|
||||||
let mut req_builder = #http::Request::builder()
|
|
||||||
.method(#http::Method::#method)
|
|
||||||
.uri(::std::format!(
|
|
||||||
"{}{}{}",
|
|
||||||
base_url.strip_suffix('/').unwrap_or(base_url),
|
|
||||||
#request_path_string,
|
|
||||||
#request_query_string,
|
|
||||||
))
|
|
||||||
.header(
|
|
||||||
#ruma_api::exports::http::header::CONTENT_TYPE,
|
|
||||||
"application/json",
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut req_headers = req_builder
|
|
||||||
.headers_mut()
|
|
||||||
.expect("`http::RequestBuilder` is in unusable state");
|
|
||||||
|
|
||||||
#header_kvs
|
|
||||||
|
|
||||||
let http_request = req_builder.body(#request_body)?;
|
|
||||||
|
|
||||||
Ok(http_request)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[automatically_derived]
|
|
||||||
#[cfg(feature = "server")]
|
|
||||||
impl #ruma_api::IncomingRequest for #incoming_request_type {
|
|
||||||
type EndpointError = #error_ty;
|
|
||||||
type OutgoingResponse = Response;
|
|
||||||
|
|
||||||
const METADATA: #ruma_api::Metadata = self::METADATA;
|
|
||||||
|
|
||||||
fn try_from_http_request<T: #bytes::Buf>(
|
|
||||||
request: #http::Request<T>
|
|
||||||
) -> ::std::result::Result<Self, #ruma_api::error::FromHttpRequestError> {
|
|
||||||
if request.method() != #http::Method::#method {
|
|
||||||
return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch {
|
|
||||||
expected: #http::Method::#method,
|
|
||||||
received: request.method().clone(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#parse_request_path
|
|
||||||
#parse_query
|
|
||||||
#parse_headers
|
|
||||||
|
|
||||||
#extract_body
|
|
||||||
#parse_body
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
#path_vars
|
|
||||||
#query_vars
|
|
||||||
#header_vars
|
|
||||||
#body_vars
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -774,7 +278,7 @@ pub(crate) enum RequestField {
|
|||||||
|
|
||||||
impl RequestField {
|
impl RequestField {
|
||||||
/// Creates a new `RequestField`.
|
/// Creates a new `RequestField`.
|
||||||
pub fn new(kind: RequestFieldKind, field: Field, header: Option<Ident>) -> Self {
|
pub(super) fn new(kind: RequestFieldKind, field: Field, header: Option<Ident>) -> Self {
|
||||||
match kind {
|
match kind {
|
||||||
RequestFieldKind::Body => RequestField::Body(field),
|
RequestFieldKind::Body => RequestField::Body(field),
|
||||||
RequestFieldKind::Header => {
|
RequestFieldKind::Header => {
|
||||||
@ -788,71 +292,58 @@ impl RequestField {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets the kind of the request field.
|
|
||||||
pub fn kind(&self) -> RequestFieldKind {
|
|
||||||
match self {
|
|
||||||
RequestField::Body(..) => RequestFieldKind::Body,
|
|
||||||
RequestField::Header(..) => RequestFieldKind::Header,
|
|
||||||
RequestField::NewtypeBody(..) => RequestFieldKind::NewtypeBody,
|
|
||||||
RequestField::NewtypeRawBody(..) => RequestFieldKind::NewtypeRawBody,
|
|
||||||
RequestField::Path(..) => RequestFieldKind::Path,
|
|
||||||
RequestField::Query(..) => RequestFieldKind::Query,
|
|
||||||
RequestField::QueryMap(..) => RequestFieldKind::QueryMap,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Whether or not this request field is a body kind.
|
/// Whether or not this request field is a body kind.
|
||||||
pub fn is_body(&self) -> bool {
|
pub(super) fn is_body(&self) -> bool {
|
||||||
self.kind() == RequestFieldKind::Body
|
matches!(self, RequestField::Body(..))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Whether or not this request field is a header kind.
|
/// Whether or not this request field is a header kind.
|
||||||
pub fn is_header(&self) -> bool {
|
fn is_header(&self) -> bool {
|
||||||
self.kind() == RequestFieldKind::Header
|
matches!(self, RequestField::Header(..))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Whether or not this request field is a newtype body kind.
|
/// Whether or not this request field is a newtype body kind.
|
||||||
pub fn is_newtype_body(&self) -> bool {
|
fn is_newtype_body(&self) -> bool {
|
||||||
self.kind() == RequestFieldKind::NewtypeBody
|
matches!(self, RequestField::NewtypeBody(..))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Whether or not this request field is a path kind.
|
/// Whether or not this request field is a path kind.
|
||||||
pub fn is_path(&self) -> bool {
|
fn is_path(&self) -> bool {
|
||||||
self.kind() == RequestFieldKind::Path
|
matches!(self, RequestField::Path(..))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Whether or not this request field is a query string kind.
|
/// Whether or not this request field is a query string kind.
|
||||||
pub fn is_query(&self) -> bool {
|
pub(super) fn is_query(&self) -> bool {
|
||||||
self.kind() == RequestFieldKind::Query
|
matches!(self, RequestField::Query(..))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the contained field if this request field is a body kind.
|
/// Return the contained field if this request field is a body kind.
|
||||||
pub fn as_body_field(&self) -> Option<&Field> {
|
fn as_body_field(&self) -> Option<&Field> {
|
||||||
self.field_of_kind(RequestFieldKind::Body)
|
self.field_of_kind(RequestFieldKind::Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the contained field if this request field is a body kind.
|
/// Return the contained field if this request field is a body kind.
|
||||||
pub fn as_newtype_body_field(&self) -> Option<&Field> {
|
fn as_newtype_body_field(&self) -> Option<&Field> {
|
||||||
self.field_of_kind(RequestFieldKind::NewtypeBody)
|
self.field_of_kind(RequestFieldKind::NewtypeBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the contained field if this request field is a raw body kind.
|
/// Return the contained field if this request field is a raw body kind.
|
||||||
pub fn as_newtype_raw_body_field(&self) -> Option<&Field> {
|
fn as_newtype_raw_body_field(&self) -> Option<&Field> {
|
||||||
self.field_of_kind(RequestFieldKind::NewtypeRawBody)
|
self.field_of_kind(RequestFieldKind::NewtypeRawBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the contained field if this request field is a query kind.
|
/// Return the contained field if this request field is a query kind.
|
||||||
pub fn as_query_field(&self) -> Option<&Field> {
|
fn as_query_field(&self) -> Option<&Field> {
|
||||||
self.field_of_kind(RequestFieldKind::Query)
|
self.field_of_kind(RequestFieldKind::Query)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the contained field if this request field is a query map kind.
|
/// Return the contained field if this request field is a query map kind.
|
||||||
pub fn as_query_map_field(&self) -> Option<&Field> {
|
fn as_query_map_field(&self) -> Option<&Field> {
|
||||||
self.field_of_kind(RequestFieldKind::QueryMap)
|
self.field_of_kind(RequestFieldKind::QueryMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets the inner `Field` value.
|
/// Gets the inner `Field` value.
|
||||||
pub fn field(&self) -> &Field {
|
fn field(&self) -> &Field {
|
||||||
match self {
|
match self {
|
||||||
RequestField::Body(field)
|
RequestField::Body(field)
|
||||||
| RequestField::Header(field, _)
|
| RequestField::Header(field, _)
|
||||||
@ -865,11 +356,16 @@ impl RequestField {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Gets the inner `Field` value if it's of the provided kind.
|
/// Gets the inner `Field` value if it's of the provided kind.
|
||||||
pub fn field_of_kind(&self, kind: RequestFieldKind) -> Option<&Field> {
|
fn field_of_kind(&self, kind: RequestFieldKind) -> Option<&Field> {
|
||||||
if self.kind() == kind {
|
match (self, kind) {
|
||||||
Some(self.field())
|
(RequestField::Body(field), RequestFieldKind::Body)
|
||||||
} else {
|
| (RequestField::Header(field, _), RequestFieldKind::Header)
|
||||||
None
|
| (RequestField::NewtypeBody(field), RequestFieldKind::NewtypeBody)
|
||||||
|
| (RequestField::NewtypeRawBody(field), RequestFieldKind::NewtypeRawBody)
|
||||||
|
| (RequestField::Path(field), RequestFieldKind::Path)
|
||||||
|
| (RequestField::Query(field), RequestFieldKind::Query)
|
||||||
|
| (RequestField::QueryMap(field), RequestFieldKind::QueryMap) => Some(field),
|
||||||
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -877,24 +373,11 @@ impl RequestField {
|
|||||||
/// The types of fields that a request can have, without their values.
|
/// The types of fields that a request can have, without their values.
|
||||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||||
pub(crate) enum RequestFieldKind {
|
pub(crate) enum RequestFieldKind {
|
||||||
/// See the similarly named variant of `RequestField`.
|
|
||||||
Body,
|
Body,
|
||||||
|
|
||||||
/// See the similarly named variant of `RequestField`.
|
|
||||||
Header,
|
Header,
|
||||||
|
|
||||||
/// See the similarly named variant of `RequestField`.
|
|
||||||
NewtypeBody,
|
NewtypeBody,
|
||||||
|
|
||||||
/// See the similarly named variant of `RequestField`.
|
|
||||||
NewtypeRawBody,
|
NewtypeRawBody,
|
||||||
|
|
||||||
/// See the similarly named variant of `RequestField`.
|
|
||||||
Path,
|
Path,
|
||||||
|
|
||||||
/// See the similarly named variant of `RequestField`.
|
|
||||||
Query,
|
Query,
|
||||||
|
|
||||||
/// See the similarly named variant of `RequestField`.
|
|
||||||
QueryMap,
|
QueryMap,
|
||||||
}
|
}
|
||||||
|
277
ruma-api-macros/src/api/request/incoming.rs
Normal file
277
ruma-api-macros/src/api/request/incoming.rs
Normal file
@ -0,0 +1,277 @@
|
|||||||
|
use proc_macro2::{Ident, Span, TokenStream};
|
||||||
|
use quote::quote;
|
||||||
|
|
||||||
|
use super::{Metadata, Request, RequestField, RequestFieldKind};
|
||||||
|
|
||||||
|
impl Request {
|
||||||
|
pub fn expand_incoming(
|
||||||
|
&self,
|
||||||
|
metadata: &Metadata,
|
||||||
|
error_ty: &TokenStream,
|
||||||
|
ruma_api: &TokenStream,
|
||||||
|
) -> TokenStream {
|
||||||
|
let bytes = quote! { #ruma_api::exports::bytes };
|
||||||
|
let http = quote! { #ruma_api::exports::http };
|
||||||
|
let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
|
||||||
|
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
||||||
|
let serde_json = quote! { #ruma_api::exports::serde_json };
|
||||||
|
|
||||||
|
let method = &metadata.method;
|
||||||
|
|
||||||
|
let incoming_request_type =
|
||||||
|
if self.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) };
|
||||||
|
|
||||||
|
let (parse_request_path, path_vars) = if self.has_path_fields() {
|
||||||
|
let path_string = metadata.path.value();
|
||||||
|
|
||||||
|
assert!(path_string.starts_with('/'), "path needs to start with '/'");
|
||||||
|
assert!(
|
||||||
|
path_string.chars().filter(|c| *c == ':').count() == self.path_field_count(),
|
||||||
|
"number of declared path parameters needs to match amount of placeholders in path"
|
||||||
|
);
|
||||||
|
|
||||||
|
let path_var_decls = path_string[1..]
|
||||||
|
.split('/')
|
||||||
|
.enumerate()
|
||||||
|
.filter(|(_, seg)| seg.starts_with(':'))
|
||||||
|
.map(|(i, seg)| {
|
||||||
|
let path_var = Ident::new(&seg[1..], Span::call_site());
|
||||||
|
quote! {
|
||||||
|
let #path_var = {
|
||||||
|
let segment = path_segments[#i].as_bytes();
|
||||||
|
let decoded =
|
||||||
|
#percent_encoding::percent_decode(segment).decode_utf8()?;
|
||||||
|
|
||||||
|
::std::convert::TryFrom::try_from(&*decoded)?
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let parse_request_path = quote! {
|
||||||
|
let path_segments: ::std::vec::Vec<&::std::primitive::str> =
|
||||||
|
request.uri().path()[1..].split('/').collect();
|
||||||
|
|
||||||
|
#(#path_var_decls)*
|
||||||
|
};
|
||||||
|
|
||||||
|
let path_vars = path_string[1..]
|
||||||
|
.split('/')
|
||||||
|
.filter(|seg| seg.starts_with(':'))
|
||||||
|
.map(|seg| Ident::new(&seg[1..], Span::call_site()));
|
||||||
|
|
||||||
|
(parse_request_path, quote! { #(#path_vars,)* })
|
||||||
|
} else {
|
||||||
|
(TokenStream::new(), TokenStream::new())
|
||||||
|
};
|
||||||
|
|
||||||
|
let (parse_query, query_vars) = if let Some(field) = self.query_map_field() {
|
||||||
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
let parse = quote! {
|
||||||
|
let #field_name = #ruma_serde::urlencoded::from_str(
|
||||||
|
&request.uri().query().unwrap_or(""),
|
||||||
|
)?;
|
||||||
|
};
|
||||||
|
|
||||||
|
(parse, quote! { #field_name, })
|
||||||
|
} else if self.has_query_fields() {
|
||||||
|
let (decls, names) = self.vars(RequestFieldKind::Query, quote!(request_query));
|
||||||
|
|
||||||
|
let parse = quote! {
|
||||||
|
let request_query: <RequestQuery as #ruma_serde::Outgoing>::Incoming =
|
||||||
|
#ruma_serde::urlencoded::from_str(
|
||||||
|
&request.uri().query().unwrap_or("")
|
||||||
|
)?;
|
||||||
|
|
||||||
|
#decls
|
||||||
|
};
|
||||||
|
|
||||||
|
(parse, names)
|
||||||
|
} else {
|
||||||
|
(TokenStream::new(), TokenStream::new())
|
||||||
|
};
|
||||||
|
|
||||||
|
let (parse_headers, header_vars) = if self.has_header_fields() {
|
||||||
|
let (decls, names): (TokenStream, Vec<_>) = self
|
||||||
|
.header_fields()
|
||||||
|
.map(|request_field| {
|
||||||
|
let (field, header_name) = match request_field {
|
||||||
|
RequestField::Header(field, header_name) => (field, header_name),
|
||||||
|
_ => panic!("expected request field to be header variant"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let field_name = &field.ident;
|
||||||
|
let header_name_string = header_name.to_string();
|
||||||
|
|
||||||
|
let (some_case, none_case) = match &field.ty {
|
||||||
|
syn::Type::Path(syn::TypePath {
|
||||||
|
path: syn::Path { segments, .. }, ..
|
||||||
|
}) if segments.last().unwrap().ident == "Option" => {
|
||||||
|
(quote! { Some(str_value.to_owned()) }, quote! { None })
|
||||||
|
}
|
||||||
|
_ => (
|
||||||
|
quote! { str_value.to_owned() },
|
||||||
|
quote! {
|
||||||
|
return Err(
|
||||||
|
#ruma_api::error::HeaderDeserializationError::MissingHeader(
|
||||||
|
#header_name_string.into()
|
||||||
|
).into(),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
let decl = quote! {
|
||||||
|
let #field_name = match headers.get(#http::header::#header_name) {
|
||||||
|
Some(header_value) => {
|
||||||
|
let str_value = header_value.to_str()?;
|
||||||
|
#some_case
|
||||||
|
}
|
||||||
|
None => #none_case,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
(decl, field_name)
|
||||||
|
})
|
||||||
|
.unzip();
|
||||||
|
|
||||||
|
let parse = quote! {
|
||||||
|
let headers = request.headers();
|
||||||
|
|
||||||
|
#decls
|
||||||
|
};
|
||||||
|
|
||||||
|
(parse, quote! { #(#names,)* })
|
||||||
|
} else {
|
||||||
|
(TokenStream::new(), TokenStream::new())
|
||||||
|
};
|
||||||
|
|
||||||
|
let extract_body = if self.has_body_fields() || self.newtype_body_field().is_some() {
|
||||||
|
let body_lifetimes = if self.has_body_lifetimes() {
|
||||||
|
// duplicate the anonymous lifetime as many times as needed
|
||||||
|
let lifetimes = std::iter::repeat(quote! { '_ }).take(self.lifetimes.body.len());
|
||||||
|
quote! { < #( #lifetimes, )* >}
|
||||||
|
} else {
|
||||||
|
TokenStream::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
let request_body: <
|
||||||
|
RequestBody #body_lifetimes
|
||||||
|
as #ruma_serde::Outgoing
|
||||||
|
>::Incoming = {
|
||||||
|
let body = request.into_body();
|
||||||
|
if #bytes::Buf::has_remaining(&body) {
|
||||||
|
#serde_json::from_reader(#bytes::Buf::reader(body))?
|
||||||
|
} else {
|
||||||
|
// If the request body is completely empty, pretend it is an empty JSON
|
||||||
|
// object instead. This allows requests with only optional body parameters
|
||||||
|
// to be deserialized in that case.
|
||||||
|
#serde_json::from_str("{}")?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
TokenStream::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
let (parse_body, body_vars) = if let Some(field) = self.newtype_body_field() {
|
||||||
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
let parse = quote! {
|
||||||
|
let #field_name = request_body.0;
|
||||||
|
};
|
||||||
|
|
||||||
|
(parse, quote! { #field_name, })
|
||||||
|
} else if let Some(field) = self.newtype_raw_body_field() {
|
||||||
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
let parse = quote! {
|
||||||
|
let #field_name = {
|
||||||
|
let mut reader = #bytes::Buf::reader(request.into_body());
|
||||||
|
let mut vec = ::std::vec::Vec::new();
|
||||||
|
::std::io::Read::read_to_end(&mut reader, &mut vec)
|
||||||
|
.expect("reading from a bytes::Buf never fails");
|
||||||
|
vec
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
(parse, quote! { #field_name, })
|
||||||
|
} else {
|
||||||
|
self.vars(RequestFieldKind::Body, quote!(request_body))
|
||||||
|
};
|
||||||
|
|
||||||
|
let non_auth_impls = metadata.authentication.iter().map(|auth| {
|
||||||
|
if auth.value == "None" {
|
||||||
|
let attrs = &auth.attrs;
|
||||||
|
quote! {
|
||||||
|
#( #attrs )*
|
||||||
|
#[automatically_derived]
|
||||||
|
#[cfg(feature = "server")]
|
||||||
|
impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
TokenStream::new()
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
#[automatically_derived]
|
||||||
|
#[cfg(feature = "server")]
|
||||||
|
impl #ruma_api::IncomingRequest for #incoming_request_type {
|
||||||
|
type EndpointError = #error_ty;
|
||||||
|
type OutgoingResponse = Response;
|
||||||
|
|
||||||
|
const METADATA: #ruma_api::Metadata = self::METADATA;
|
||||||
|
|
||||||
|
fn try_from_http_request<T: #bytes::Buf>(
|
||||||
|
request: #http::Request<T>
|
||||||
|
) -> ::std::result::Result<Self, #ruma_api::error::FromHttpRequestError> {
|
||||||
|
if request.method() != #http::Method::#method {
|
||||||
|
return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch {
|
||||||
|
expected: #http::Method::#method,
|
||||||
|
received: request.method().clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#parse_request_path
|
||||||
|
#parse_query
|
||||||
|
#parse_headers
|
||||||
|
|
||||||
|
#extract_body
|
||||||
|
#parse_body
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
#path_vars
|
||||||
|
#query_vars
|
||||||
|
#header_vars
|
||||||
|
#body_vars
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#(#non_auth_impls)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn vars(
|
||||||
|
&self,
|
||||||
|
request_field_kind: RequestFieldKind,
|
||||||
|
src: TokenStream,
|
||||||
|
) -> (TokenStream, TokenStream) {
|
||||||
|
self.fields
|
||||||
|
.iter()
|
||||||
|
.filter_map(|f| f.field_of_kind(request_field_kind))
|
||||||
|
.map(|field| {
|
||||||
|
let field_name =
|
||||||
|
field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
let cfg_attrs =
|
||||||
|
field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let decl = quote! {
|
||||||
|
#( #cfg_attrs )*
|
||||||
|
let #field_name = #src.#field_name;
|
||||||
|
};
|
||||||
|
|
||||||
|
(decl, quote! { #field_name, })
|
||||||
|
})
|
||||||
|
.unzip()
|
||||||
|
}
|
||||||
|
}
|
261
ruma-api-macros/src/api/request/outgoing.rs
Normal file
261
ruma-api-macros/src/api/request/outgoing.rs
Normal file
@ -0,0 +1,261 @@
|
|||||||
|
use proc_macro2::{Ident, Span, TokenStream};
|
||||||
|
use quote::quote;
|
||||||
|
|
||||||
|
use super::{Metadata, Request, RequestField, RequestFieldKind};
|
||||||
|
|
||||||
|
impl Request {
|
||||||
|
pub fn expand_outgoing(
|
||||||
|
&self,
|
||||||
|
metadata: &Metadata,
|
||||||
|
error_ty: &TokenStream,
|
||||||
|
lifetimes: &TokenStream,
|
||||||
|
ruma_api: &TokenStream,
|
||||||
|
) -> TokenStream {
|
||||||
|
let http = quote! { #ruma_api::exports::http };
|
||||||
|
let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
|
||||||
|
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
||||||
|
let serde_json = quote! { #ruma_api::exports::serde_json };
|
||||||
|
|
||||||
|
let method = &metadata.method;
|
||||||
|
let request_path_string = if self.has_path_fields() {
|
||||||
|
let mut format_string = metadata.path.value();
|
||||||
|
let mut format_args = Vec::new();
|
||||||
|
|
||||||
|
while let Some(start_of_segment) = format_string.find(':') {
|
||||||
|
// ':' should only ever appear at the start of a segment
|
||||||
|
assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/");
|
||||||
|
|
||||||
|
let end_of_segment = match format_string[start_of_segment..].find('/') {
|
||||||
|
Some(rel_pos) => start_of_segment + rel_pos,
|
||||||
|
None => format_string.len(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let path_var = Ident::new(
|
||||||
|
&format_string[start_of_segment + 1..end_of_segment],
|
||||||
|
Span::call_site(),
|
||||||
|
);
|
||||||
|
format_args.push(quote! {
|
||||||
|
#percent_encoding::utf8_percent_encode(
|
||||||
|
&self.#path_var.to_string(),
|
||||||
|
#percent_encoding::NON_ALPHANUMERIC,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
format_string.replace_range(start_of_segment..end_of_segment, "{}");
|
||||||
|
}
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
format_args!(#format_string, #(#format_args),*)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! { metadata.path.to_owned() }
|
||||||
|
};
|
||||||
|
|
||||||
|
let request_query_string = if let Some(field) = self.query_map_field() {
|
||||||
|
let field_name = field.ident.as_ref().expect("expected field to have identifier");
|
||||||
|
|
||||||
|
quote!({
|
||||||
|
// This function exists so that the compiler will throw an error when the type of
|
||||||
|
// the field with the query_map attribute doesn't implement
|
||||||
|
// `IntoIterator<Item = (String, String)>`.
|
||||||
|
//
|
||||||
|
// This is necessary because the `ruma_serde::urlencoded::to_string` call will
|
||||||
|
// result in a runtime error when the type cannot be encoded as a list key-value
|
||||||
|
// pairs (?key1=value1&key2=value2).
|
||||||
|
//
|
||||||
|
// By asserting that it implements the iterator trait, we can ensure that it won't
|
||||||
|
// fail.
|
||||||
|
fn assert_trait_impl<T>(_: &T)
|
||||||
|
where
|
||||||
|
T: ::std::iter::IntoIterator<
|
||||||
|
Item = (::std::string::String, ::std::string::String),
|
||||||
|
>,
|
||||||
|
{}
|
||||||
|
|
||||||
|
let request_query = RequestQuery(self.#field_name);
|
||||||
|
assert_trait_impl(&request_query.0);
|
||||||
|
|
||||||
|
format_args!(
|
||||||
|
"?{}",
|
||||||
|
#ruma_serde::urlencoded::to_string(request_query)?
|
||||||
|
)
|
||||||
|
})
|
||||||
|
} else if self.has_query_fields() {
|
||||||
|
let request_query_init_fields =
|
||||||
|
self.struct_init_fields(RequestFieldKind::Query, quote!(self));
|
||||||
|
|
||||||
|
quote!({
|
||||||
|
let request_query = RequestQuery {
|
||||||
|
#request_query_init_fields
|
||||||
|
};
|
||||||
|
|
||||||
|
format_args!(
|
||||||
|
"?{}",
|
||||||
|
#ruma_serde::urlencoded::to_string(request_query)?
|
||||||
|
)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
quote! { "" }
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut header_kvs: TokenStream = self
|
||||||
|
.header_fields()
|
||||||
|
.map(|request_field| {
|
||||||
|
let (field, header_name) = match request_field {
|
||||||
|
RequestField::Header(field, header_name) => (field, header_name),
|
||||||
|
_ => unreachable!("expected request field to be header variant"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let field_name = &field.ident;
|
||||||
|
|
||||||
|
match &field.ty {
|
||||||
|
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
|
||||||
|
if segments.last().unwrap().ident == "Option" =>
|
||||||
|
{
|
||||||
|
quote! {
|
||||||
|
if let Some(header_val) = self.#field_name.as_ref() {
|
||||||
|
req_headers.insert(
|
||||||
|
#http::header::#header_name,
|
||||||
|
#http::header::HeaderValue::from_str(header_val)?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => quote! {
|
||||||
|
req_headers.insert(
|
||||||
|
#http::header::#header_name,
|
||||||
|
#http::header::HeaderValue::from_str(self.#field_name.as_ref())?,
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for auth in &metadata.authentication {
|
||||||
|
if auth.value == "AccessToken" {
|
||||||
|
let attrs = &auth.attrs;
|
||||||
|
header_kvs.extend(quote! {
|
||||||
|
#( #attrs )*
|
||||||
|
req_headers.insert(
|
||||||
|
#http::header::AUTHORIZATION,
|
||||||
|
#http::header::HeaderValue::from_str(
|
||||||
|
&::std::format!(
|
||||||
|
"Bearer {}",
|
||||||
|
access_token.ok_or(
|
||||||
|
#ruma_api::error::IntoHttpError::NeedsAuthentication
|
||||||
|
)?
|
||||||
|
)
|
||||||
|
)?
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let request_body = if let Some(field) = self.newtype_raw_body_field() {
|
||||||
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
quote! { self.#field_name }
|
||||||
|
} else if self.has_body_fields() || self.newtype_body_field().is_some() {
|
||||||
|
let request_body_initializers = if let Some(field) = self.newtype_body_field() {
|
||||||
|
let field_name =
|
||||||
|
field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
quote! { (self.#field_name) }
|
||||||
|
} else {
|
||||||
|
let initializers = self.struct_init_fields(RequestFieldKind::Body, quote!(self));
|
||||||
|
quote! { { #initializers } }
|
||||||
|
};
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
{
|
||||||
|
let request_body = RequestBody #request_body_initializers;
|
||||||
|
#serde_json::to_vec(&request_body)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! { Vec::new() }
|
||||||
|
};
|
||||||
|
|
||||||
|
let non_auth_impls = metadata.authentication.iter().map(|auth| {
|
||||||
|
if auth.value == "None" {
|
||||||
|
let attrs = &auth.attrs;
|
||||||
|
quote! {
|
||||||
|
#( #attrs )*
|
||||||
|
#[automatically_derived]
|
||||||
|
#[cfg(feature = "client")]
|
||||||
|
impl #lifetimes #ruma_api::OutgoingNonAuthRequest for Request #lifetimes {}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
TokenStream::new()
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
#[automatically_derived]
|
||||||
|
#[cfg(feature = "client")]
|
||||||
|
impl #lifetimes #ruma_api::OutgoingRequest for Request #lifetimes {
|
||||||
|
type EndpointError = #error_ty;
|
||||||
|
type IncomingResponse = <Response as #ruma_serde::Outgoing>::Incoming;
|
||||||
|
|
||||||
|
const METADATA: #ruma_api::Metadata = self::METADATA;
|
||||||
|
|
||||||
|
fn try_into_http_request(
|
||||||
|
self,
|
||||||
|
base_url: &::std::primitive::str,
|
||||||
|
access_token: ::std::option::Option<&str>,
|
||||||
|
) -> ::std::result::Result<
|
||||||
|
#http::Request<Vec<u8>>,
|
||||||
|
#ruma_api::error::IntoHttpError,
|
||||||
|
> {
|
||||||
|
let metadata = self::METADATA;
|
||||||
|
|
||||||
|
let mut req_builder = #http::Request::builder()
|
||||||
|
.method(#http::Method::#method)
|
||||||
|
.uri(::std::format!(
|
||||||
|
"{}{}{}",
|
||||||
|
base_url.strip_suffix('/').unwrap_or(base_url),
|
||||||
|
#request_path_string,
|
||||||
|
#request_query_string,
|
||||||
|
))
|
||||||
|
.header(
|
||||||
|
#ruma_api::exports::http::header::CONTENT_TYPE,
|
||||||
|
"application/json",
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut req_headers = req_builder
|
||||||
|
.headers_mut()
|
||||||
|
.expect("`http::RequestBuilder` is in unusable state");
|
||||||
|
|
||||||
|
#header_kvs
|
||||||
|
|
||||||
|
let http_request = req_builder.body(#request_body)?;
|
||||||
|
|
||||||
|
Ok(http_request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#(#non_auth_impls)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Produces code for a struct initializer for the given field kind to be accessed through the
|
||||||
|
/// given variable name.
|
||||||
|
fn struct_init_fields(
|
||||||
|
&self,
|
||||||
|
request_field_kind: RequestFieldKind,
|
||||||
|
src: TokenStream,
|
||||||
|
) -> TokenStream {
|
||||||
|
self.fields
|
||||||
|
.iter()
|
||||||
|
.filter_map(|f| f.field_of_kind(request_field_kind))
|
||||||
|
.map(|field| {
|
||||||
|
let field_name =
|
||||||
|
field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
let cfg_attrs =
|
||||||
|
field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
#( #cfg_attrs )*
|
||||||
|
#field_name: #src.#field_name,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user