api: Refactor macro code and improve error handling
* Inline lots of methods that were only used once * Create a separate error case for missing headers
This commit is contained in:
parent
110aa58300
commit
a68b854734
@ -31,97 +31,6 @@ pub(crate) struct Request {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Request {
|
impl Request {
|
||||||
/// Produces code to add necessary HTTP headers to an `http::Request`.
|
|
||||||
fn append_header_kvs(&self, ruma_api: &TokenStream) -> TokenStream {
|
|
||||||
let http = quote! { #ruma_api::exports::http };
|
|
||||||
|
|
||||||
self.header_fields()
|
|
||||||
.map(|request_field| {
|
|
||||||
let (field, header_name) = match request_field {
|
|
||||||
RequestField::Header(field, header_name) => (field, header_name),
|
|
||||||
_ => unreachable!("expected request field to be header variant"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let field_name = &field.ident;
|
|
||||||
|
|
||||||
match &field.ty {
|
|
||||||
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
|
|
||||||
if segments.last().unwrap().ident == "Option" =>
|
|
||||||
{
|
|
||||||
quote! {
|
|
||||||
if let Some(header_val) = self.#field_name.as_ref() {
|
|
||||||
req_headers.insert(
|
|
||||||
#http::header::#header_name,
|
|
||||||
#http::header::HeaderValue::from_str(header_val)?,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => quote! {
|
|
||||||
req_headers.insert(
|
|
||||||
#http::header::#header_name,
|
|
||||||
#http::header::HeaderValue::from_str(self.#field_name.as_ref())?,
|
|
||||||
);
|
|
||||||
},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Produces code to extract fields from the HTTP headers in an `http::Request`.
|
|
||||||
fn parse_headers_from_request(&self, ruma_api: &TokenStream) -> TokenStream {
|
|
||||||
let http = quote! { #ruma_api::exports::http };
|
|
||||||
let serde = quote! { #ruma_api::exports::serde };
|
|
||||||
let serde_json = quote! { #ruma_api::exports::serde_json };
|
|
||||||
|
|
||||||
let fields = self.header_fields().map(|request_field| {
|
|
||||||
let (field, header_name) = match request_field {
|
|
||||||
RequestField::Header(field, header_name) => (field, header_name),
|
|
||||||
_ => panic!("expected request field to be header variant"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let field_name = &field.ident;
|
|
||||||
let header_name_string = header_name.to_string();
|
|
||||||
|
|
||||||
let (some_case, none_case) = match &field.ty {
|
|
||||||
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
|
|
||||||
if segments.last().unwrap().ident == "Option" =>
|
|
||||||
{
|
|
||||||
(quote! { Some(header.to_owned()) }, quote! { None })
|
|
||||||
}
|
|
||||||
_ => (
|
|
||||||
quote! { header.to_owned() },
|
|
||||||
quote! {{
|
|
||||||
use #serde::de::Error as _;
|
|
||||||
|
|
||||||
// FIXME: Not a missing json field, a missing header!
|
|
||||||
return Err(#ruma_api::error::RequestDeserializationError::new(
|
|
||||||
#serde_json::Error::missing_field(
|
|
||||||
#header_name_string
|
|
||||||
),
|
|
||||||
request,
|
|
||||||
)
|
|
||||||
.into());
|
|
||||||
}},
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
#field_name: match headers
|
|
||||||
.get(#http::header::#header_name)
|
|
||||||
.and_then(|v| v.to_str().ok()) // FIXME: Should have a distinct error message
|
|
||||||
{
|
|
||||||
Some(header) => #some_case,
|
|
||||||
None => #none_case,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
#(#fields,)*
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Whether or not this request has any data in the HTTP body.
|
/// Whether or not this request has any data in the HTTP body.
|
||||||
pub fn has_body_fields(&self) -> bool {
|
pub fn has_body_fields(&self) -> bool {
|
||||||
self.fields.iter().any(|field| field.is_body())
|
self.fields.iter().any(|field| field.is_body())
|
||||||
@ -173,32 +82,27 @@ impl Request {
|
|||||||
/// The combination of every fields unique lifetime annotation.
|
/// The combination of every fields unique lifetime annotation.
|
||||||
pub fn combine_lifetimes(&self) -> TokenStream {
|
pub fn combine_lifetimes(&self) -> TokenStream {
|
||||||
util::unique_lifetimes_to_tokens(
|
util::unique_lifetimes_to_tokens(
|
||||||
self.lifetimes
|
[
|
||||||
.body
|
&self.lifetimes.body,
|
||||||
.iter()
|
&self.lifetimes.path,
|
||||||
.chain(self.lifetimes.path.iter())
|
&self.lifetimes.query,
|
||||||
.chain(self.lifetimes.query.iter())
|
&self.lifetimes.header,
|
||||||
.chain(self.lifetimes.header.iter())
|
]
|
||||||
.collect::<BTreeSet<_>>()
|
.iter()
|
||||||
.into_iter(),
|
.flat_map(|set| set.iter()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The lifetimes on fields with the `query` attribute.
|
/// The lifetimes on fields with the `query` attribute.
|
||||||
pub fn query_lifetimes(&self) -> TokenStream {
|
pub fn query_lifetimes(&self) -> TokenStream {
|
||||||
util::unique_lifetimes_to_tokens(self.lifetimes.query.iter())
|
util::unique_lifetimes_to_tokens(&self.lifetimes.query)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The lifetimes on fields with the `body` attribute.
|
/// The lifetimes on fields with the `body` attribute.
|
||||||
pub fn body_lifetimes(&self) -> TokenStream {
|
pub fn body_lifetimes(&self) -> TokenStream {
|
||||||
util::unique_lifetimes_to_tokens(self.lifetimes.body.iter())
|
util::unique_lifetimes_to_tokens(&self.lifetimes.body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// /// The lifetimes on fields with the `header` attribute.
|
|
||||||
// pub fn header_lifetimes(&self) -> TokenStream {
|
|
||||||
// util::generics_to_tokens(self.lifetimes.header.iter())
|
|
||||||
// }
|
|
||||||
|
|
||||||
/// Produces an iterator over all the header fields.
|
/// Produces an iterator over all the header fields.
|
||||||
pub fn header_fields(&self) -> impl Iterator<Item = &RequestField> {
|
pub fn header_fields(&self) -> impl Iterator<Item = &RequestField> {
|
||||||
self.fields.iter().filter(|field| field.is_header())
|
self.fields.iter().filter(|field| field.is_header())
|
||||||
@ -224,28 +128,6 @@ impl Request {
|
|||||||
self.fields.iter().find_map(RequestField::as_query_map_field)
|
self.fields.iter().find_map(RequestField::as_query_map_field)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Produces code for a struct initializer for body fields on a variable named `request`.
|
|
||||||
pub fn request_body_init_fields(&self) -> TokenStream {
|
|
||||||
self.struct_init_fields(RequestFieldKind::Body, quote!(self))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Produces code for a struct initializer for query string fields on a variable named
|
|
||||||
/// `request`.
|
|
||||||
pub fn request_query_init_fields(&self) -> TokenStream {
|
|
||||||
self.struct_init_fields(RequestFieldKind::Query, quote!(self))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Produces code for a struct initializer for body fields on a variable named `request_body`.
|
|
||||||
pub fn request_init_body_fields(&self) -> TokenStream {
|
|
||||||
self.struct_init_fields(RequestFieldKind::Body, quote!(request_body))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Produces code for a struct initializer for query string fields on a variable named
|
|
||||||
/// `request_query`.
|
|
||||||
pub fn request_init_query_fields(&self) -> TokenStream {
|
|
||||||
self.struct_init_fields(RequestFieldKind::Query, quote!(request_query))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Produces code for a struct initializer for the given field kind to be accessed through the
|
/// Produces code for a struct initializer for the given field kind to be accessed through the
|
||||||
/// given variable name.
|
/// given variable name.
|
||||||
fn struct_init_fields(
|
fn struct_init_fields(
|
||||||
@ -336,10 +218,42 @@ impl Request {
|
|||||||
#field_name: request_query,
|
#field_name: request_query,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
self.request_init_query_fields()
|
self.struct_init_fields(RequestFieldKind::Query, quote!(request_query))
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut header_kvs = self.append_header_kvs(&ruma_api);
|
let mut header_kvs: TokenStream = self
|
||||||
|
.header_fields()
|
||||||
|
.map(|request_field| {
|
||||||
|
let (field, header_name) = match request_field {
|
||||||
|
RequestField::Header(field, header_name) => (field, header_name),
|
||||||
|
_ => unreachable!("expected request field to be header variant"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let field_name = &field.ident;
|
||||||
|
|
||||||
|
match &field.ty {
|
||||||
|
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
|
||||||
|
if segments.last().unwrap().ident == "Option" =>
|
||||||
|
{
|
||||||
|
quote! {
|
||||||
|
if let Some(header_val) = self.#field_name.as_ref() {
|
||||||
|
req_headers.insert(
|
||||||
|
#http::header::#header_name,
|
||||||
|
#http::header::HeaderValue::from_str(header_val)?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => quote! {
|
||||||
|
req_headers.insert(
|
||||||
|
#http::header::#header_name,
|
||||||
|
#http::header::HeaderValue::from_str(self.#field_name.as_ref())?,
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
for auth in &metadata.authentication {
|
for auth in &metadata.authentication {
|
||||||
if auth.value == "AccessToken" {
|
if auth.value == "AccessToken" {
|
||||||
let attrs = &auth.attrs;
|
let attrs = &auth.attrs;
|
||||||
@ -398,13 +312,91 @@ impl Request {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let parse_request_headers = if self.has_header_fields() {
|
let parse_request_headers = if self.has_header_fields() {
|
||||||
self.parse_headers_from_request(&ruma_api)
|
let fields = self.header_fields().map(|request_field| {
|
||||||
|
let (field, header_name) = match request_field {
|
||||||
|
RequestField::Header(field, header_name) => (field, header_name),
|
||||||
|
_ => panic!("expected request field to be header variant"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let field_name = &field.ident;
|
||||||
|
let header_name_string = header_name.to_string();
|
||||||
|
|
||||||
|
let (some_case, none_case) = match &field.ty {
|
||||||
|
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
|
||||||
|
if segments.last().unwrap().ident == "Option" =>
|
||||||
|
{
|
||||||
|
(quote! { Some(str_value.to_owned()) }, quote! { None })
|
||||||
|
}
|
||||||
|
_ => (
|
||||||
|
quote! { str_value.to_owned() },
|
||||||
|
quote! {
|
||||||
|
// FIXME: Not a missing json field, a missing header!
|
||||||
|
return Err(#ruma_api::error::RequestDeserializationError::new(
|
||||||
|
#ruma_api::error::HeaderDeserializationError::MissingHeader(
|
||||||
|
#header_name_string.into()
|
||||||
|
),
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
.into())
|
||||||
|
},
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
#field_name: match headers.get(#http::header::#header_name) {
|
||||||
|
Some(header_value) => {
|
||||||
|
let str_value =
|
||||||
|
#ruma_api::try_deserialize!(request, header_value.to_str());
|
||||||
|
#some_case
|
||||||
|
}
|
||||||
|
None => #none_case,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
#(#fields,)*
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
TokenStream::new()
|
TokenStream::new()
|
||||||
};
|
};
|
||||||
|
|
||||||
let request_body = self.build_request_body(&ruma_api);
|
let request_body = if let Some(field) = self.newtype_raw_body_field() {
|
||||||
let parse_request_body = self.parse_request_body();
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
quote! { self.#field_name }
|
||||||
|
} else if self.has_body_fields() || self.newtype_body_field().is_some() {
|
||||||
|
let request_body_initializers = if let Some(field) = self.newtype_body_field() {
|
||||||
|
let field_name =
|
||||||
|
field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
quote! { (self.#field_name) }
|
||||||
|
} else {
|
||||||
|
let initializers = self.struct_init_fields(RequestFieldKind::Body, quote!(self));
|
||||||
|
quote! { { #initializers } }
|
||||||
|
};
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
{
|
||||||
|
let request_body = RequestBody #request_body_initializers;
|
||||||
|
#serde_json::to_vec(&request_body)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! { Vec::new() }
|
||||||
|
};
|
||||||
|
|
||||||
|
let parse_request_body = if let Some(field) = self.newtype_body_field() {
|
||||||
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
quote! {
|
||||||
|
#field_name: request_body.0,
|
||||||
|
}
|
||||||
|
} else if let Some(field) = self.newtype_raw_body_field() {
|
||||||
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
quote! {
|
||||||
|
#field_name: request.into_body(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.struct_init_fields(RequestFieldKind::Body, quote!(request_body))
|
||||||
|
};
|
||||||
|
|
||||||
let request_generics = self.combine_lifetimes();
|
let request_generics = self.combine_lifetimes();
|
||||||
|
|
||||||
@ -634,52 +626,6 @@ impl Request {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates the code to initialize a `Request`.
|
|
||||||
///
|
|
||||||
/// Used to construct an `http::Request`s body.
|
|
||||||
fn build_request_body(&self, ruma_api: &TokenStream) -> TokenStream {
|
|
||||||
let serde_json = quote! { #ruma_api::exports::serde_json };
|
|
||||||
|
|
||||||
if let Some(field) = self.newtype_raw_body_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
quote!(self.#field_name)
|
|
||||||
} else if self.has_body_fields() || self.newtype_body_field().is_some() {
|
|
||||||
let request_body_initializers = if let Some(field) = self.newtype_body_field() {
|
|
||||||
let field_name =
|
|
||||||
field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
quote! { (self.#field_name) }
|
|
||||||
} else {
|
|
||||||
let initializers = self.request_body_init_fields();
|
|
||||||
quote! { { #initializers } }
|
|
||||||
};
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
{
|
|
||||||
let request_body = RequestBody #request_body_initializers;
|
|
||||||
#serde_json::to_vec(&request_body)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
quote!(Vec::new())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_request_body(&self) -> TokenStream {
|
|
||||||
if let Some(field) = self.newtype_body_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
quote! {
|
|
||||||
#field_name: request_body.0,
|
|
||||||
}
|
|
||||||
} else if let Some(field) = self.newtype_raw_body_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
quote! {
|
|
||||||
#field_name: request.into_body(),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
self.request_init_body_fields()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The function determines the type of query string that needs to be built
|
/// The function determines the type of query string that needs to be built
|
||||||
/// and then builds it using `ruma_serde::urlencoded::to_string`.
|
/// and then builds it using `ruma_serde::urlencoded::to_string`.
|
||||||
fn build_query_string(&self, ruma_api: &TokenStream) -> TokenStream {
|
fn build_query_string(&self, ruma_api: &TokenStream) -> TokenStream {
|
||||||
@ -715,7 +661,8 @@ impl Request {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
} else if self.has_query_fields() {
|
} else if self.has_query_fields() {
|
||||||
let request_query_init_fields = self.request_query_init_fields();
|
let request_query_init_fields =
|
||||||
|
self.struct_init_fields(RequestFieldKind::Query, quote!(self));
|
||||||
|
|
||||||
quote!({
|
quote!({
|
||||||
let request_query = RequestQuery {
|
let request_query = RequestQuery {
|
||||||
@ -739,7 +686,7 @@ impl Request {
|
|||||||
/// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is
|
/// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is
|
||||||
/// used for implementing `TryFrom<http::Request<Vec<u8>>>`, from path strings deserialized to
|
/// used for implementing `TryFrom<http::Request<Vec<u8>>>`, from path strings deserialized to
|
||||||
/// Ruma types.
|
/// Ruma types.
|
||||||
pub(crate) fn path_string_and_parse(
|
fn path_string_and_parse(
|
||||||
&self,
|
&self,
|
||||||
metadata: &Metadata,
|
metadata: &Metadata,
|
||||||
ruma_api: &TokenStream,
|
ruma_api: &TokenStream,
|
||||||
|
@ -8,10 +8,10 @@ use quote::quote;
|
|||||||
use syn::{AttrStyle, Attribute, Ident, Lifetime};
|
use syn::{AttrStyle, Attribute, Ident, Lifetime};
|
||||||
|
|
||||||
/// Generates a `TokenStream` of lifetime identifiers `<'lifetime>`.
|
/// Generates a `TokenStream` of lifetime identifiers `<'lifetime>`.
|
||||||
pub(crate) fn unique_lifetimes_to_tokens<'a, I: Iterator<Item = &'a Lifetime>>(
|
pub(crate) fn unique_lifetimes_to_tokens<'a, I: IntoIterator<Item = &'a Lifetime>>(
|
||||||
lifetimes: I,
|
lifetimes: I,
|
||||||
) -> TokenStream {
|
) -> TokenStream {
|
||||||
let lifetimes = lifetimes.collect::<BTreeSet<_>>();
|
let lifetimes = lifetimes.into_iter().collect::<BTreeSet<_>>();
|
||||||
if lifetimes.is_empty() {
|
if lifetimes.is_empty() {
|
||||||
TokenStream::new()
|
TokenStream::new()
|
||||||
} else {
|
} else {
|
||||||
|
@ -209,7 +209,7 @@ pub enum DeserializationError {
|
|||||||
|
|
||||||
/// Header value deserialization failed.
|
/// Header value deserialization failed.
|
||||||
#[error("{0}")]
|
#[error("{0}")]
|
||||||
Header(#[from] http::header::ToStrError),
|
Header(#[from] HeaderDeserializationError),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<std::convert::Infallible> for DeserializationError {
|
impl From<std::convert::Infallible> for DeserializationError {
|
||||||
@ -217,3 +217,22 @@ impl From<std::convert::Infallible> for DeserializationError {
|
|||||||
match err {}
|
match err {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<http::header::ToStrError> for DeserializationError {
|
||||||
|
fn from(err: http::header::ToStrError) -> Self {
|
||||||
|
Self::Header(HeaderDeserializationError::ToStrError(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An error with the http headers.
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub enum HeaderDeserializationError {
|
||||||
|
/// Failed to convert `http::header::HeaderValue` to `str`.
|
||||||
|
#[error("{0}")]
|
||||||
|
ToStrError(http::header::ToStrError),
|
||||||
|
|
||||||
|
/// The given required header is missing.
|
||||||
|
#[error("Missing header `{0}`")]
|
||||||
|
MissingHeader(String),
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user