api: Allow types implementing ToString and FromStr for header values

Contrary to what the previous docs said, types implementing Display did not work,
only string types worked.
This commit is contained in:
Kévin Commaille 2024-06-27 18:44:20 +02:00 committed by Kévin Commaille
parent cc56e5277b
commit f73ba5556c
7 changed files with 75 additions and 32 deletions

View File

@ -13,6 +13,8 @@ Breaking changes:
This allows to use a struct or enum as well as a map to represent the list of
query parameters. Note that the (de)serialization of the type used must work
with `serde_html_form`.
- The `header` attribute for the `request` and `response` macros accepts any
type that implements `ToString` and `FromStr`.
Improvements:

View File

@ -113,10 +113,12 @@ macro_rules! metadata {
/// To declare which part of the request a field belongs to:
///
/// * `#[ruma_api(header = HEADER_NAME)]`: Fields with this attribute will be treated as HTTP
/// headers on the request. The value must implement `Display`. Generally this is a `String`.
/// The attribute value shown above as `HEADER_NAME` must be a `const` expression of the type
/// `http::header::HeaderName`, like one of the constants from `http::header`, e.g.
/// `CONTENT_TYPE`.
/// headers on the request. The value must implement `ToString` and `FromStr`. Generally this
/// is a `String`. The attribute value shown above as `HEADER_NAME` must be a `const`
/// expression of the type `http::header::HeaderName`, like one of the constants from
/// `http::header`, e.g. `CONTENT_TYPE`. During deserialization of the request, if the field
/// is an `Option` and parsing the header fails, the error will be ignored and the value will
/// be `None`.
/// * `#[ruma_api(path)]`: Fields with this attribute will be inserted into the matching path
/// component of the request URL. If there are multiple of these fields, the order in which
/// they are declared must match the order in which they occur in the request path.
@ -230,9 +232,11 @@ pub use ruma_macros::request;
/// To declare which part of the response a field belongs to:
///
/// * `#[ruma_api(header = HEADER_NAME)]`: Fields with this attribute will be treated as HTTP
/// headers on the response. The value must implement `Display`. Generally this is a
/// `String`. The attribute value shown above as `HEADER_NAME` must be a header name constant
/// from `http::header`, e.g. `CONTENT_TYPE`.
/// headers on the response. The value must implement `ToString` and `FromStr`. Generally
/// this is a `String`. The attribute value shown above as `HEADER_NAME` must be a header
/// name constant from `http::header`, e.g. `CONTENT_TYPE`. During deserialization of the
/// response, if the field is an `Option` and parsing the header fails, the error will be
/// ignored and the value will be `None`.
/// * No attribute: Fields without an attribute are part of the body. They can use `#[serde]`
/// attributes to customize (de)serialization.
/// * `#[ruma_api(body)]`: Use this if multiple endpoints should share a response body type, or

View File

@ -277,6 +277,10 @@ pub enum HeaderDeserializationError {
#[error("missing header `{0}`")]
MissingHeader(String),
/// The given header failed to parse.
#[error("invalid header: {0}")]
InvalidHeader(Box<dyn std::error::Error + Send + Sync + 'static>),
/// A header was received with a unexpected value.
#[error(
"The {header} header was received with an unexpected value, \

View File

@ -80,18 +80,38 @@ impl Request {
syn::Type::Path(syn::TypePath {
path: syn::Path { segments, .. }, ..
}) if segments.last().unwrap().ident == "Option" => {
(quote! { Some(str_value.to_owned()) }, quote! { None })
let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
args: option_args, ..
}) = &segments.last().unwrap().arguments else {
panic!("Option should use angle brackets");
};
let syn::GenericArgument::Type(field_type) = option_args.first().unwrap() else {
panic!("Option brackets should contain type");
};
(
quote! {
str_value.parse::<#field_type>().ok()
},
quote! { None }
)
}
_ => {
let field_type = &field.ty;
(
quote! {
str_value
.parse::<#field_type>()
.map_err(|e| #ruma_common::api::error::HeaderDeserializationError::InvalidHeader(e.into()))?
},
quote! {
return Err(
#ruma_common::api::error::HeaderDeserializationError::MissingHeader(
#header_name_string.into()
).into(),
)
},
)
}
_ => (
quote! { str_value.to_owned() },
quote! {
return Err(
#ruma_common::api::error::HeaderDeserializationError::MissingHeader(
#header_name_string.into()
).into(),
)
},
),
};
let decl = quote! {

View File

@ -66,7 +66,7 @@ impl Request {
if let Some(header_val) = self.#field_name.as_ref() {
req_headers.insert(
#header_name,
#http::header::HeaderValue::from_str(header_val)?,
#http::header::HeaderValue::from_str(&header_val.to_string())?,
);
}
}
@ -74,7 +74,7 @@ impl Request {
_ => quote! {
req_headers.insert(
#header_name,
#http::header::HeaderValue::from_str(self.#field_name.as_ref())?,
#http::header::HeaderValue::from_str(&self.#field_name.to_string())?,
);
},
}

View File

@ -56,24 +56,37 @@ impl Response {
Type::Path(syn::TypePath {
path: syn::Path { segments, .. }, ..
}) if segments.last().unwrap().ident == "Option" => {
let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
args: option_args, ..
}) = &segments.last().unwrap().arguments else {
panic!("Option should use angle brackets");
};
let syn::GenericArgument::Type(field_type) = option_args.first().unwrap() else {
panic!("Option brackets should contain type");
};
quote! {
#( #cfg_attrs )*
#field_name: {
headers.remove(#header_name)
.map(|h| h.to_str().map(|s| s.to_owned()))
.transpose()?
.and_then(|h| { h.to_str().ok()?.parse::<#field_type>().ok() })
}
}
}
_ => quote! {
#( #cfg_attrs )*
#field_name: {
headers.remove(#header_name)
.expect("response missing expected header")
.to_str()?
.to_owned()
_ => {
let field_type = &field.ty;
quote! {
#( #cfg_attrs )*
#field_name: {
headers.remove(#header_name)
.ok_or_else(|| #ruma_common::api::error::HeaderDeserializationError::MissingHeader(
#header_name.to_string()
))?
.to_str()?
.parse::<#field_type>()
.map_err(|e| #ruma_common::api::error::HeaderDeserializationError::InvalidHeader(e.into()))?
}
}
},
}
};
quote! { #optional_header }
}

View File

@ -22,7 +22,7 @@ impl Response {
if let Some(header) = self.#field_name {
headers.insert(
#header_name,
header.parse()?,
header.to_string().parse()?,
);
}
}
@ -30,7 +30,7 @@ impl Response {
_ => quote! {
headers.insert(
#header_name,
self.#field_name.parse()?,
self.#field_name.to_string().parse()?,
);
},
}