Use a dedicated method for conversion from Ruma request type to http::Request
This commit is contained in:
parent
67a2012b85
commit
10184cb9ca
@ -139,15 +139,20 @@ impl ToTokens for Api {
|
||||
self.request.request_init_query_fields()
|
||||
};
|
||||
|
||||
let add_headers_to_request = if self.request.has_header_fields() {
|
||||
let add_headers = self.request.add_headers_to_request();
|
||||
quote! {
|
||||
let headers = http_request.headers_mut();
|
||||
#add_headers
|
||||
}
|
||||
} else {
|
||||
TokenStream::new()
|
||||
};
|
||||
let mut header_kvs = self.request.append_header_kvs();
|
||||
if requires_authentication.value {
|
||||
header_kvs.push(quote! {
|
||||
ruma_api::exports::http::header::AUTHORIZATION,
|
||||
ruma_api::exports::http::header::HeaderValue::from_str(
|
||||
&format!(
|
||||
"Bearer {}",
|
||||
access_token.ok_or_else(
|
||||
ruma_api::error::IntoHttpError::needs_authentication
|
||||
)?
|
||||
)
|
||||
)?
|
||||
});
|
||||
}
|
||||
|
||||
let extract_request_headers = if self.request.has_header_fields() {
|
||||
quote! {
|
||||
@ -245,29 +250,6 @@ impl ToTokens for Api {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::TryFrom<Request> for ::ruma_api::exports::http::Request<Vec<u8>> {
|
||||
type Error = ::ruma_api::error::IntoHttpError;
|
||||
|
||||
#[allow(unused_mut, unused_variables)]
|
||||
fn try_from(request: Request) -> Result<Self, Self::Error> {
|
||||
let metadata = Request::METADATA;
|
||||
let path_and_query = #request_path_string + &#request_query_string;
|
||||
let mut http_request = ::ruma_api::exports::http::Request::new(#request_body);
|
||||
|
||||
*http_request.method_mut() = ::ruma_api::exports::http::Method::#method;
|
||||
*http_request.uri_mut() = ::ruma_api::exports::http::uri::Builder::new()
|
||||
.path_and_query(path_and_query.as_str())
|
||||
.build()
|
||||
// The ruma_api! macro guards against invalid path input, but if there are
|
||||
// invalid (non ASCII) bytes in the fields with the query attribute this will panic.
|
||||
.unwrap();
|
||||
|
||||
{ #add_headers_to_request }
|
||||
|
||||
Ok(http_request)
|
||||
}
|
||||
}
|
||||
|
||||
#[doc = #response_doc]
|
||||
#response_type
|
||||
|
||||
@ -324,6 +306,31 @@ impl ToTokens for Api {
|
||||
rate_limited: #rate_limited,
|
||||
requires_authentication: #requires_authentication,
|
||||
};
|
||||
|
||||
#[allow(unused_mut, unused_variables)]
|
||||
fn try_into_http_request(
|
||||
self,
|
||||
base_url: &str,
|
||||
access_token: ::std::option::Option<&str>,
|
||||
) -> ::std::result::Result<
|
||||
::ruma_api::exports::http::Request<Vec<u8>>,
|
||||
::ruma_api::error::IntoHttpError,
|
||||
> {
|
||||
let metadata = Request::METADATA;
|
||||
|
||||
let http_request = ::ruma_api::exports::http::Request::builder()
|
||||
.method(::ruma_api::exports::http::Method::#method)
|
||||
.uri(::std::format!(
|
||||
"{}{}{}",
|
||||
base_url.strip_suffix("/").unwrap_or(base_url),
|
||||
#request_path_string,
|
||||
#request_query_string,
|
||||
))
|
||||
#( .header(#header_kvs) )*
|
||||
.body(#request_body)?;
|
||||
|
||||
Ok(http_request)
|
||||
}
|
||||
}
|
||||
|
||||
#non_auth_endpoint_impl
|
||||
|
@ -22,8 +22,8 @@ pub struct Request {
|
||||
|
||||
impl Request {
|
||||
/// Produces code to add necessary HTTP headers to an `http::Request`.
|
||||
pub fn add_headers_to_request(&self) -> TokenStream {
|
||||
let append_stmts = self.header_fields().map(|request_field| {
|
||||
pub fn append_header_kvs(&self) -> Vec<TokenStream> {
|
||||
self.header_fields().map(|request_field| {
|
||||
let (field, header_name) = match request_field {
|
||||
RequestField::Header(field, header_name) => (field, header_name),
|
||||
_ => unreachable!("expected request field to be header variant"),
|
||||
@ -32,16 +32,10 @@ impl Request {
|
||||
let field_name = &field.ident;
|
||||
|
||||
quote! {
|
||||
headers.append(
|
||||
::ruma_api::exports::http::header::#header_name,
|
||||
::ruma_api::exports::http::header::HeaderValue::from_str(request.#field_name.as_ref())?,
|
||||
);
|
||||
::ruma_api::exports::http::header::#header_name,
|
||||
::ruma_api::exports::http::header::HeaderValue::from_str(self.#field_name.as_ref())?
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
#(#append_stmts)*
|
||||
}
|
||||
}).collect()
|
||||
}
|
||||
|
||||
/// Produces code to extract fields from the HTTP headers in an `http::Request`.
|
||||
@ -136,12 +130,12 @@ impl Request {
|
||||
|
||||
/// Produces code for a struct initializer for body fields on a variable named `request`.
|
||||
pub fn request_body_init_fields(&self) -> TokenStream {
|
||||
self.struct_init_fields(RequestFieldKind::Body, quote!(request))
|
||||
self.struct_init_fields(RequestFieldKind::Body, quote!(self))
|
||||
}
|
||||
|
||||
/// Produces code for a struct initializer for query string fields on a variable named `request`.
|
||||
pub fn request_query_init_fields(&self) -> TokenStream {
|
||||
self.struct_init_fields(RequestFieldKind::Query, quote!(request))
|
||||
self.struct_init_fields(RequestFieldKind::Query, quote!(self))
|
||||
}
|
||||
|
||||
/// Produces code for a struct initializer for body fields on a variable named `request_body`.
|
||||
|
@ -45,7 +45,7 @@ pub(crate) fn request_path_string_and_parse(
|
||||
);
|
||||
format_args.push(quote! {
|
||||
ruma_api::exports::percent_encoding::utf8_percent_encode(
|
||||
&request.#path_var.to_string(),
|
||||
&self.#path_var.to_string(),
|
||||
ruma_api::exports::percent_encoding::NON_ALPHANUMERIC,
|
||||
)
|
||||
});
|
||||
@ -53,7 +53,7 @@ pub(crate) fn request_path_string_and_parse(
|
||||
}
|
||||
|
||||
quote! {
|
||||
format!(#format_string, #(#format_args),*)
|
||||
format_args!(#format_string, #(#format_args),*)
|
||||
}
|
||||
};
|
||||
|
||||
@ -110,8 +110,11 @@ pub(crate) fn build_query_string(request: &Request) -> TokenStream {
|
||||
{}
|
||||
assert_trait_impl::<#field_type>();
|
||||
|
||||
let request_query = RequestQuery(request.#field_name);
|
||||
format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?)
|
||||
let request_query = RequestQuery(self.#field_name);
|
||||
format_args!(
|
||||
"?{}",
|
||||
ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?
|
||||
)
|
||||
})
|
||||
} else if request.has_query_fields() {
|
||||
let request_query_init_fields = request.request_query_init_fields();
|
||||
@ -121,12 +124,13 @@ pub(crate) fn build_query_string(request: &Request) -> TokenStream {
|
||||
#request_query_init_fields
|
||||
};
|
||||
|
||||
format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?)
|
||||
format_args!(
|
||||
"?{}",
|
||||
ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?
|
||||
)
|
||||
})
|
||||
} else {
|
||||
quote! {
|
||||
String::new()
|
||||
}
|
||||
quote! { "" }
|
||||
}
|
||||
}
|
||||
|
||||
@ -161,11 +165,11 @@ pub(crate) fn extract_request_query(request: &Request) -> TokenStream {
|
||||
pub(crate) fn build_request_body(request: &Request) -> TokenStream {
|
||||
if let Some(field) = request.newtype_raw_body_field() {
|
||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||
quote!(request.#field_name)
|
||||
quote!(self.#field_name)
|
||||
} else if request.has_body_fields() || request.newtype_body_field().is_some() {
|
||||
let request_body_initializers = if let Some(field) = request.newtype_body_field() {
|
||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||
quote! { (request.#field_name) }
|
||||
quote! { (self.#field_name) }
|
||||
} else {
|
||||
let initializers = request.request_body_init_fields();
|
||||
quote! { { #initializers } }
|
||||
|
@ -16,40 +16,62 @@ impl crate::EndpointError for Void {
|
||||
Err(ResponseDeserializationError::from_response(response))
|
||||
}
|
||||
}
|
||||
|
||||
/// An error when converting one of ruma's endpoint-specific request or response
|
||||
/// types to the corresponding http type.
|
||||
#[derive(Debug)]
|
||||
pub struct IntoHttpError(SerializationError);
|
||||
pub struct IntoHttpError(InnerIntoHttpError);
|
||||
|
||||
impl IntoHttpError {
|
||||
// For usage in macros
|
||||
#[doc(hidden)]
|
||||
pub fn needs_authentication() -> Self {
|
||||
Self(InnerIntoHttpError::NeedsAuthentication)
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
impl From<serde_json::Error> for IntoHttpError {
|
||||
fn from(err: serde_json::Error) -> Self {
|
||||
Self(SerializationError::Json(err))
|
||||
Self(InnerIntoHttpError::Json(err))
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
impl From<ruma_serde::urlencoded::ser::Error> for IntoHttpError {
|
||||
fn from(err: ruma_serde::urlencoded::ser::Error) -> Self {
|
||||
Self(SerializationError::Query(err))
|
||||
Self(InnerIntoHttpError::Query(err))
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
impl From<http::header::InvalidHeaderValue> for IntoHttpError {
|
||||
fn from(err: http::header::InvalidHeaderValue) -> Self {
|
||||
Self(SerializationError::Header(err))
|
||||
Self(InnerIntoHttpError::Header(err))
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
impl From<http::Error> for IntoHttpError {
|
||||
fn from(err: http::Error) -> Self {
|
||||
Self(InnerIntoHttpError::Http(err))
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for IntoHttpError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
match &self.0 {
|
||||
SerializationError::Json(err) => write!(f, "JSON serialization failed: {}", err),
|
||||
SerializationError::Query(err) => {
|
||||
InnerIntoHttpError::NeedsAuthentication => write!(
|
||||
f,
|
||||
"This endpoint has to be converted to http::Request using \
|
||||
try_into_authenticated_http_request"
|
||||
),
|
||||
InnerIntoHttpError::Json(err) => write!(f, "JSON serialization failed: {}", err),
|
||||
InnerIntoHttpError::Query(err) => {
|
||||
write!(f, "Query parameter serialization failed: {}", err)
|
||||
}
|
||||
SerializationError::Header(err) => write!(f, "Header serialization failed: {}", err),
|
||||
InnerIntoHttpError::Header(err) => write!(f, "Header serialization failed: {}", err),
|
||||
InnerIntoHttpError::Http(err) => write!(f, "HTTP request construction failed: {}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -201,10 +223,12 @@ impl<E: Display> Display for ServerError<E> {
|
||||
impl<E: std::error::Error> std::error::Error for ServerError<E> {}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum SerializationError {
|
||||
enum InnerIntoHttpError {
|
||||
NeedsAuthentication,
|
||||
Json(serde_json::Error),
|
||||
Query(ruma_serde::urlencoded::ser::Error),
|
||||
Header(http::header::InvalidHeaderValue),
|
||||
Http(http::Error),
|
||||
}
|
||||
|
||||
/// This type is public so it is accessible from `ruma_api!` generated code.
|
||||
|
@ -245,7 +245,7 @@ pub trait EndpointError: Sized {
|
||||
/// A Matrix API endpoint.
|
||||
///
|
||||
/// The type implementing this trait contains any data needed to make a request to the endpoint.
|
||||
pub trait Endpoint: Outgoing + TryInto<http::Request<Vec<u8>>, Error = IntoHttpError>
|
||||
pub trait Endpoint: Outgoing
|
||||
where
|
||||
<Self as Outgoing>::Incoming: TryFrom<http::Request<Vec<u8>>, Error = FromHttpRequestError>,
|
||||
<Self::Response as Outgoing>::Incoming: TryFrom<
|
||||
@ -260,6 +260,20 @@ where
|
||||
|
||||
/// Metadata about the endpoint.
|
||||
const METADATA: Metadata;
|
||||
|
||||
/// Tries to convert this request into an `http::Request`.
|
||||
///
|
||||
/// This method should only fail when called on endpoints that require authentication. It may
|
||||
/// also fail with a serialization error in case of bugs in Ruma though.
|
||||
///
|
||||
/// The endpoints path will be appended to the given `base_url`, for example
|
||||
/// `https://matrix.org`. Since all paths begin with a slash, it is not necessary for the
|
||||
/// `base_url` to have a trailing slash. If it has one however, it will be ignored.
|
||||
fn try_into_http_request(
|
||||
self,
|
||||
base_url: &str,
|
||||
access_token: Option<&str>,
|
||||
) -> Result<http::Request<Vec<u8>>, IntoHttpError>;
|
||||
}
|
||||
|
||||
/// A Matrix API endpoint that doesn't require authentication.
|
||||
@ -357,26 +371,24 @@ mod tests {
|
||||
name: "create_alias",
|
||||
path: "/_matrix/client/r0/directory/room/:room_alias",
|
||||
rate_limited: false,
|
||||
requires_authentication: true,
|
||||
requires_authentication: false,
|
||||
};
|
||||
}
|
||||
|
||||
impl TryFrom<Request> for http::Request<Vec<u8>> {
|
||||
type Error = IntoHttpError;
|
||||
|
||||
fn try_from(request: Request) -> Result<http::Request<Vec<u8>>, Self::Error> {
|
||||
fn try_into_http_request(
|
||||
self,
|
||||
base_url: &str,
|
||||
_access_token: Option<&str>,
|
||||
) -> Result<http::Request<Vec<u8>>, IntoHttpError> {
|
||||
let metadata = Request::METADATA;
|
||||
|
||||
let path = metadata
|
||||
.path
|
||||
.to_string()
|
||||
.replace(":room_alias", &request.room_alias.to_string());
|
||||
let url = (base_url.to_owned() + metadata.path)
|
||||
.replace(":room_alias", &self.room_alias.to_string());
|
||||
|
||||
let request_body = RequestBody { room_id: request.room_id };
|
||||
let request_body = RequestBody { room_id: self.room_id };
|
||||
|
||||
let http_request = http::Request::builder()
|
||||
.method(metadata.method)
|
||||
.uri(path)
|
||||
.uri(url)
|
||||
.body(serde_json::to_vec(&request_body)?)
|
||||
// this cannot fail because we don't give user-supplied data to any of the
|
||||
// builder methods
|
||||
|
@ -47,7 +47,7 @@ fn request_serde() -> Result<(), Box<dyn std::error::Error + 'static>> {
|
||||
baz: UserId::try_from("@bazme:ruma.io")?,
|
||||
};
|
||||
|
||||
let http_req = http::Request::<Vec<u8>>::try_from(req.clone())?;
|
||||
let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?;
|
||||
let req2 = Request::try_from(http_req)?;
|
||||
|
||||
assert_eq!(req.hello, req2.hello);
|
||||
|
@ -1,6 +1,6 @@
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use ruma_api::ruma_api;
|
||||
use ruma_api::{ruma_api, Endpoint};
|
||||
|
||||
ruma_api! {
|
||||
metadata: {
|
||||
@ -19,7 +19,7 @@ ruma_api! {
|
||||
#[test]
|
||||
fn empty_request_http_repr() {
|
||||
let req = Request {};
|
||||
let http_req = http::Request::<Vec<u8>>::try_from(req).unwrap();
|
||||
let http_req = req.try_into_http_request("https://homeserver.tld", None).unwrap();
|
||||
|
||||
assert!(http_req.body().is_empty());
|
||||
}
|
||||
|
@ -117,9 +117,10 @@ pub enum Direction {
|
||||
mod tests {
|
||||
use super::{Direction, Request};
|
||||
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use js_int::uint;
|
||||
use ruma_api::Endpoint;
|
||||
use ruma_identifiers::RoomId;
|
||||
|
||||
use crate::r0::filter::{LazyLoadOptions, RoomEventFilter};
|
||||
@ -143,7 +144,8 @@ mod tests {
|
||||
filter: Some(filter),
|
||||
};
|
||||
|
||||
let request: http::Request<Vec<u8>> = req.try_into().unwrap();
|
||||
let request: http::Request<Vec<u8>> =
|
||||
req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap();
|
||||
assert_eq!(
|
||||
"from=token&to=token2&dir=b&limit=0&filter=%7B%22not_types%22%3A%5B%22type%22%5D%2C%22not_rooms%22%3A%5B%22room%22%2C%22room2%22%2C%22room3%22%5D%2C%22rooms%22%3A%5B%22%21roomid%3Aexample.org%22%5D%2C%22lazy_load_members%22%3Atrue%2C%22include_redundant_members%22%3Atrue%7D",
|
||||
request.uri().query().unwrap()
|
||||
@ -162,7 +164,8 @@ mod tests {
|
||||
filter: None,
|
||||
};
|
||||
|
||||
let request: http::Request<Vec<u8>> = req.try_into().unwrap();
|
||||
let request =
|
||||
req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap();
|
||||
assert_eq!("from=token&to=token2&dir=b&limit=0", request.uri().query().unwrap(),);
|
||||
}
|
||||
|
||||
@ -178,7 +181,8 @@ mod tests {
|
||||
filter: Some(RoomEventFilter::default()),
|
||||
};
|
||||
|
||||
let request: http::Request<Vec<u8>> = req.try_into().unwrap();
|
||||
let request: http::Request<Vec<u8>> =
|
||||
req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap();
|
||||
assert_eq!(
|
||||
"from=token&to=token2&dir=b&limit=0&filter=%7B%7D",
|
||||
request.uri().query().unwrap(),
|
||||
|
@ -141,8 +141,7 @@ mod user_serde;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::convert::TryInto;
|
||||
|
||||
use ruma_api::Endpoint;
|
||||
use serde_json::{from_value as from_json_value, json, Value as JsonValue};
|
||||
|
||||
use super::{LoginInfo, Medium, Request, UserInfo};
|
||||
@ -193,7 +192,7 @@ mod tests {
|
||||
device_id: None,
|
||||
initial_device_display_name: Some("test".into()),
|
||||
}
|
||||
.try_into()
|
||||
.try_into_http_request("https://homeserver.tld", None)
|
||||
.unwrap();
|
||||
|
||||
let req_body_value: JsonValue = serde_json::from_slice(req.body()).unwrap();
|
||||
|
@ -403,6 +403,7 @@ impl DeviceLists {
|
||||
mod tests {
|
||||
use std::{convert::TryInto, time::Duration};
|
||||
|
||||
use ruma_api::Endpoint;
|
||||
use serde_json::{from_value as from_json_value, json, to_value as to_json_value};
|
||||
|
||||
use matches::assert_matches;
|
||||
@ -418,7 +419,7 @@ mod tests {
|
||||
set_presence: PresenceState::Offline,
|
||||
timeout: Some(Duration::from_millis(30000)),
|
||||
}
|
||||
.try_into()
|
||||
.try_into_http_request("https://homeserver.tld", Some("auth_tok"))
|
||||
.unwrap();
|
||||
|
||||
let uri = req.uri();
|
||||
|
Loading…
x
Reference in New Issue
Block a user