api: Refactor macro code and improve error handling

* Inline lots of methods that were only used once
* Create a separate error case for missing headers
This commit is contained in:
Jonas Platte 2021-04-10 14:50:01 +02:00
parent 110aa58300
commit a68b854734
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
3 changed files with 150 additions and 184 deletions

View File

@ -31,97 +31,6 @@ pub(crate) struct Request {
} }
impl Request { impl Request {
/// Produces code to add necessary HTTP headers to an `http::Request`.
fn append_header_kvs(&self, ruma_api: &TokenStream) -> TokenStream {
let http = quote! { #ruma_api::exports::http };
self.header_fields()
.map(|request_field| {
let (field, header_name) = match request_field {
RequestField::Header(field, header_name) => (field, header_name),
_ => unreachable!("expected request field to be header variant"),
};
let field_name = &field.ident;
match &field.ty {
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
if segments.last().unwrap().ident == "Option" =>
{
quote! {
if let Some(header_val) = self.#field_name.as_ref() {
req_headers.insert(
#http::header::#header_name,
#http::header::HeaderValue::from_str(header_val)?,
);
}
}
}
_ => quote! {
req_headers.insert(
#http::header::#header_name,
#http::header::HeaderValue::from_str(self.#field_name.as_ref())?,
);
},
}
})
.collect()
}
/// Produces code to extract fields from the HTTP headers in an `http::Request`.
fn parse_headers_from_request(&self, ruma_api: &TokenStream) -> TokenStream {
let http = quote! { #ruma_api::exports::http };
let serde = quote! { #ruma_api::exports::serde };
let serde_json = quote! { #ruma_api::exports::serde_json };
let fields = self.header_fields().map(|request_field| {
let (field, header_name) = match request_field {
RequestField::Header(field, header_name) => (field, header_name),
_ => panic!("expected request field to be header variant"),
};
let field_name = &field.ident;
let header_name_string = header_name.to_string();
let (some_case, none_case) = match &field.ty {
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
if segments.last().unwrap().ident == "Option" =>
{
(quote! { Some(header.to_owned()) }, quote! { None })
}
_ => (
quote! { header.to_owned() },
quote! {{
use #serde::de::Error as _;
// FIXME: Not a missing json field, a missing header!
return Err(#ruma_api::error::RequestDeserializationError::new(
#serde_json::Error::missing_field(
#header_name_string
),
request,
)
.into());
}},
),
};
quote! {
#field_name: match headers
.get(#http::header::#header_name)
.and_then(|v| v.to_str().ok()) // FIXME: Should have a distinct error message
{
Some(header) => #some_case,
None => #none_case,
}
}
});
quote! {
#(#fields,)*
}
}
/// Whether or not this request has any data in the HTTP body. /// Whether or not this request has any data in the HTTP body.
pub fn has_body_fields(&self) -> bool { pub fn has_body_fields(&self) -> bool {
self.fields.iter().any(|field| field.is_body()) self.fields.iter().any(|field| field.is_body())
@ -173,32 +82,27 @@ impl Request {
/// The combination of every fields unique lifetime annotation. /// The combination of every fields unique lifetime annotation.
pub fn combine_lifetimes(&self) -> TokenStream { pub fn combine_lifetimes(&self) -> TokenStream {
util::unique_lifetimes_to_tokens( util::unique_lifetimes_to_tokens(
self.lifetimes [
.body &self.lifetimes.body,
.iter() &self.lifetimes.path,
.chain(self.lifetimes.path.iter()) &self.lifetimes.query,
.chain(self.lifetimes.query.iter()) &self.lifetimes.header,
.chain(self.lifetimes.header.iter()) ]
.collect::<BTreeSet<_>>() .iter()
.into_iter(), .flat_map(|set| set.iter()),
) )
} }
/// The lifetimes on fields with the `query` attribute. /// The lifetimes on fields with the `query` attribute.
pub fn query_lifetimes(&self) -> TokenStream { pub fn query_lifetimes(&self) -> TokenStream {
util::unique_lifetimes_to_tokens(self.lifetimes.query.iter()) util::unique_lifetimes_to_tokens(&self.lifetimes.query)
} }
/// The lifetimes on fields with the `body` attribute. /// The lifetimes on fields with the `body` attribute.
pub fn body_lifetimes(&self) -> TokenStream { pub fn body_lifetimes(&self) -> TokenStream {
util::unique_lifetimes_to_tokens(self.lifetimes.body.iter()) util::unique_lifetimes_to_tokens(&self.lifetimes.body)
} }
// /// The lifetimes on fields with the `header` attribute.
// pub fn header_lifetimes(&self) -> TokenStream {
// util::generics_to_tokens(self.lifetimes.header.iter())
// }
/// Produces an iterator over all the header fields. /// Produces an iterator over all the header fields.
pub fn header_fields(&self) -> impl Iterator<Item = &RequestField> { pub fn header_fields(&self) -> impl Iterator<Item = &RequestField> {
self.fields.iter().filter(|field| field.is_header()) self.fields.iter().filter(|field| field.is_header())
@ -224,28 +128,6 @@ impl Request {
self.fields.iter().find_map(RequestField::as_query_map_field) self.fields.iter().find_map(RequestField::as_query_map_field)
} }
/// Produces code for a struct initializer for body fields on a variable named `request`.
pub fn request_body_init_fields(&self) -> TokenStream {
self.struct_init_fields(RequestFieldKind::Body, quote!(self))
}
/// Produces code for a struct initializer for query string fields on a variable named
/// `request`.
pub fn request_query_init_fields(&self) -> TokenStream {
self.struct_init_fields(RequestFieldKind::Query, quote!(self))
}
/// Produces code for a struct initializer for body fields on a variable named `request_body`.
pub fn request_init_body_fields(&self) -> TokenStream {
self.struct_init_fields(RequestFieldKind::Body, quote!(request_body))
}
/// Produces code for a struct initializer for query string fields on a variable named
/// `request_query`.
pub fn request_init_query_fields(&self) -> TokenStream {
self.struct_init_fields(RequestFieldKind::Query, quote!(request_query))
}
/// Produces code for a struct initializer for the given field kind to be accessed through the /// Produces code for a struct initializer for the given field kind to be accessed through the
/// given variable name. /// given variable name.
fn struct_init_fields( fn struct_init_fields(
@ -336,10 +218,42 @@ impl Request {
#field_name: request_query, #field_name: request_query,
} }
} else { } else {
self.request_init_query_fields() self.struct_init_fields(RequestFieldKind::Query, quote!(request_query))
}; };
let mut header_kvs = self.append_header_kvs(&ruma_api); let mut header_kvs: TokenStream = self
.header_fields()
.map(|request_field| {
let (field, header_name) = match request_field {
RequestField::Header(field, header_name) => (field, header_name),
_ => unreachable!("expected request field to be header variant"),
};
let field_name = &field.ident;
match &field.ty {
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
if segments.last().unwrap().ident == "Option" =>
{
quote! {
if let Some(header_val) = self.#field_name.as_ref() {
req_headers.insert(
#http::header::#header_name,
#http::header::HeaderValue::from_str(header_val)?,
);
}
}
}
_ => quote! {
req_headers.insert(
#http::header::#header_name,
#http::header::HeaderValue::from_str(self.#field_name.as_ref())?,
);
},
}
})
.collect();
for auth in &metadata.authentication { for auth in &metadata.authentication {
if auth.value == "AccessToken" { if auth.value == "AccessToken" {
let attrs = &auth.attrs; let attrs = &auth.attrs;
@ -398,13 +312,91 @@ impl Request {
}; };
let parse_request_headers = if self.has_header_fields() { let parse_request_headers = if self.has_header_fields() {
self.parse_headers_from_request(&ruma_api) let fields = self.header_fields().map(|request_field| {
let (field, header_name) = match request_field {
RequestField::Header(field, header_name) => (field, header_name),
_ => panic!("expected request field to be header variant"),
};
let field_name = &field.ident;
let header_name_string = header_name.to_string();
let (some_case, none_case) = match &field.ty {
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
if segments.last().unwrap().ident == "Option" =>
{
(quote! { Some(str_value.to_owned()) }, quote! { None })
}
_ => (
quote! { str_value.to_owned() },
quote! {
// FIXME: Not a missing json field, a missing header!
return Err(#ruma_api::error::RequestDeserializationError::new(
#ruma_api::error::HeaderDeserializationError::MissingHeader(
#header_name_string.into()
),
request,
)
.into())
},
),
};
quote! {
#field_name: match headers.get(#http::header::#header_name) {
Some(header_value) => {
let str_value =
#ruma_api::try_deserialize!(request, header_value.to_str());
#some_case
}
None => #none_case,
}
}
});
quote! {
#(#fields,)*
}
} else { } else {
TokenStream::new() TokenStream::new()
}; };
let request_body = self.build_request_body(&ruma_api); let request_body = if let Some(field) = self.newtype_raw_body_field() {
let parse_request_body = self.parse_request_body(); let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! { self.#field_name }
} else if self.has_body_fields() || self.newtype_body_field().is_some() {
let request_body_initializers = if let Some(field) = self.newtype_body_field() {
let field_name =
field.ident.as_ref().expect("expected field to have an identifier");
quote! { (self.#field_name) }
} else {
let initializers = self.struct_init_fields(RequestFieldKind::Body, quote!(self));
quote! { { #initializers } }
};
quote! {
{
let request_body = RequestBody #request_body_initializers;
#serde_json::to_vec(&request_body)?
}
}
} else {
quote! { Vec::new() }
};
let parse_request_body = if let Some(field) = self.newtype_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request_body.0,
}
} else if let Some(field) = self.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request.into_body(),
}
} else {
self.struct_init_fields(RequestFieldKind::Body, quote!(request_body))
};
let request_generics = self.combine_lifetimes(); let request_generics = self.combine_lifetimes();
@ -634,52 +626,6 @@ impl Request {
} }
} }
/// Generates the code to initialize a `Request`.
///
/// Used to construct an `http::Request`s body.
fn build_request_body(&self, ruma_api: &TokenStream) -> TokenStream {
let serde_json = quote! { #ruma_api::exports::serde_json };
if let Some(field) = self.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote!(self.#field_name)
} else if self.has_body_fields() || self.newtype_body_field().is_some() {
let request_body_initializers = if let Some(field) = self.newtype_body_field() {
let field_name =
field.ident.as_ref().expect("expected field to have an identifier");
quote! { (self.#field_name) }
} else {
let initializers = self.request_body_init_fields();
quote! { { #initializers } }
};
quote! {
{
let request_body = RequestBody #request_body_initializers;
#serde_json::to_vec(&request_body)?
}
}
} else {
quote!(Vec::new())
}
}
fn parse_request_body(&self) -> TokenStream {
if let Some(field) = self.newtype_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request_body.0,
}
} else if let Some(field) = self.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request.into_body(),
}
} else {
self.request_init_body_fields()
}
}
/// The function determines the type of query string that needs to be built /// The function determines the type of query string that needs to be built
/// and then builds it using `ruma_serde::urlencoded::to_string`. /// and then builds it using `ruma_serde::urlencoded::to_string`.
fn build_query_string(&self, ruma_api: &TokenStream) -> TokenStream { fn build_query_string(&self, ruma_api: &TokenStream) -> TokenStream {
@ -715,7 +661,8 @@ impl Request {
) )
}) })
} else if self.has_query_fields() { } else if self.has_query_fields() {
let request_query_init_fields = self.request_query_init_fields(); let request_query_init_fields =
self.struct_init_fields(RequestFieldKind::Query, quote!(self));
quote!({ quote!({
let request_query = RequestQuery { let request_query = RequestQuery {
@ -739,7 +686,7 @@ impl Request {
/// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is /// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is
/// used for implementing `TryFrom<http::Request<Vec<u8>>>`, from path strings deserialized to /// used for implementing `TryFrom<http::Request<Vec<u8>>>`, from path strings deserialized to
/// Ruma types. /// Ruma types.
pub(crate) fn path_string_and_parse( fn path_string_and_parse(
&self, &self,
metadata: &Metadata, metadata: &Metadata,
ruma_api: &TokenStream, ruma_api: &TokenStream,

View File

@ -8,10 +8,10 @@ use quote::quote;
use syn::{AttrStyle, Attribute, Ident, Lifetime}; use syn::{AttrStyle, Attribute, Ident, Lifetime};
/// Generates a `TokenStream` of lifetime identifiers `<'lifetime>`. /// Generates a `TokenStream` of lifetime identifiers `<'lifetime>`.
pub(crate) fn unique_lifetimes_to_tokens<'a, I: Iterator<Item = &'a Lifetime>>( pub(crate) fn unique_lifetimes_to_tokens<'a, I: IntoIterator<Item = &'a Lifetime>>(
lifetimes: I, lifetimes: I,
) -> TokenStream { ) -> TokenStream {
let lifetimes = lifetimes.collect::<BTreeSet<_>>(); let lifetimes = lifetimes.into_iter().collect::<BTreeSet<_>>();
if lifetimes.is_empty() { if lifetimes.is_empty() {
TokenStream::new() TokenStream::new()
} else { } else {

View File

@ -209,7 +209,7 @@ pub enum DeserializationError {
/// Header value deserialization failed. /// Header value deserialization failed.
#[error("{0}")] #[error("{0}")]
Header(#[from] http::header::ToStrError), Header(#[from] HeaderDeserializationError),
} }
impl From<std::convert::Infallible> for DeserializationError { impl From<std::convert::Infallible> for DeserializationError {
@ -217,3 +217,22 @@ impl From<std::convert::Infallible> for DeserializationError {
match err {} match err {}
} }
} }
impl From<http::header::ToStrError> for DeserializationError {
fn from(err: http::header::ToStrError) -> Self {
Self::Header(HeaderDeserializationError::ToStrError(err))
}
}
/// An error with the http headers.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum HeaderDeserializationError {
/// Failed to convert `http::header::HeaderValue` to `str`.
#[error("{0}")]
ToStrError(http::header::ToStrError),
/// The given required header is missing.
#[error("Missing header `{0}`")]
MissingHeader(String),
}