api: Update try_from_http_request to be generic like try_from_http_response

This commit is contained in:
Jonas Platte 2021-04-10 16:16:47 +02:00
parent 23ba0bc164
commit 0e197aae0b
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
12 changed files with 65 additions and 61 deletions

View File

@ -152,8 +152,6 @@ impl Request {
quote! { #(#fields,)* }
}
/// Produces code for a struct initializer for the given field kind to be accessed through the
/// given variable name.
fn vars(
&self,
request_field_kind: RequestFieldKind,
@ -190,6 +188,7 @@ impl Request {
error_ty: &TokenStream,
ruma_api: &TokenStream,
) -> TokenStream {
let bytes = quote! { #ruma_api::exports::bytes };
let http = quote! { #ruma_api::exports::http };
let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
@ -484,15 +483,15 @@ impl Request {
RequestBody #body_lifetimes
as #ruma_serde::Outgoing
>::Incoming = {
// If the request body is completely empty, pretend it is an empty JSON object
// instead. This allows requests with only optional body parameters to be
// deserialized in that case.
let json = match request.body().as_slice() {
b"" => b"{}",
body => body,
};
#serde_json::from_slice(json)?
let body = request.into_body();
if #bytes::Buf::has_remaining(&body) {
#serde_json::from_reader(#bytes::Buf::reader(body))?
} else {
// If the request body is completely empty, pretend it is an empty JSON
// object instead. This allows requests with only optional body parameters
// to be deserialized in that case.
#serde_json::from_str("{}")?
}
};
}
} else {
@ -532,7 +531,13 @@ impl Request {
} else if let Some(field) = self.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
let parse = quote! {
let #field_name = request.into_body();
let #field_name = {
let mut reader = #bytes::Buf::reader(request.into_body());
let mut vec = ::std::vec::Vec::new();
::std::io::Read::read_to_end(&mut reader, &mut vec)
.expect("reading from a bytes::Buf never fails");
vec
};
};
(parse, quote! { #field_name, })
@ -714,8 +719,8 @@ impl Request {
const METADATA: #ruma_api::Metadata = self::METADATA;
fn try_from_http_request(
request: #http::Request<Vec<u8>>
fn try_from_http_request<T: #bytes::Buf>(
request: #http::Request<T>
) -> ::std::result::Result<Self, #ruma_api::error::FromHttpRequestError> {
if request.method() != #http::Method::#method {
return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch {

View File

@ -312,7 +312,7 @@ pub trait IncomingRequest: Sized {
const METADATA: Metadata;
/// Tries to turn the given `http::Request` into this request type.
fn try_from_http_request(req: http::Request<Vec<u8>>) -> Result<Self, FromHttpRequestError>;
fn try_from_http_request<T: Buf>(req: http::Request<T>) -> Result<Self, FromHttpRequestError>;
}
/// A request type for a Matrix API endpoint, used for sending responses.

View File

@ -48,7 +48,7 @@ fn request_serde() -> Result<(), Box<dyn std::error::Error + 'static>> {
};
let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?;
let req2 = Request::try_from_http_request(http_req)?;
let req2 = Request::try_from_http_request(http_req.map(std::io::Cursor::new))?;
assert_eq!(req.hello, req2.hello);
assert_eq!(req.world, req2.world);

View File

@ -67,21 +67,20 @@ impl IncomingRequest for Request {
const METADATA: Metadata = METADATA;
fn try_from_http_request(
request: http::Request<Vec<u8>>,
fn try_from_http_request<T: Buf>(
request: http::Request<T>,
) -> Result<Self, FromHttpRequestError> {
let request_body: RequestBody = serde_json::from_slice(request.body().as_slice())?;
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let room_alias = {
let decoded =
percent_encoding::percent_decode(path_segments[5].as_bytes()).decode_utf8()?;
Ok(Request {
room_id: request_body.room_id,
room_alias: {
let decoded =
percent_encoding::percent_decode(path_segments[5].as_bytes()).decode_utf8()?;
TryFrom::try_from(&*decoded)?
};
TryFrom::try_from(&*decoded)?
},
})
let request_body: RequestBody = serde_json::from_reader(request.into_body().reader())?;
Ok(Request { room_id: request_body.room_id, room_alias })
}
}

View File

@ -71,19 +71,18 @@ impl Response {
#[cfg(all(test, any(feature = "client", feature = "server")))]
mod tests {
use std::convert::TryInto;
use js_int::uint;
#[cfg(feature = "client")]
#[test]
fn construct_request_from_refs() {
use ruma_api::OutgoingRequest;
use ruma_api::OutgoingRequest as _;
use ruma_identifiers::server_name;
let req: http::Request<Vec<u8>> = super::Request {
let req = super::Request {
limit: Some(uint!(10)),
since: Some("hello"),
server: Some("address".try_into().unwrap()),
server: Some(&server_name!("test.tld")),
}
.try_into_http_request("https://homeserver.tld", Some("auth_tok"))
.unwrap();
@ -94,19 +93,21 @@ mod tests {
assert_eq!(uri.path(), "/_matrix/client/r0/publicRooms");
assert!(query.contains("since=hello"));
assert!(query.contains("limit=10"));
assert!(query.contains("server=address"));
assert!(query.contains("server=test.tld"));
}
#[cfg(feature = "server")]
#[test]
fn construct_response_from_refs() {
let res: http::Response<Vec<u8>> = super::Response {
use ruma_api::OutgoingResponse as _;
let res = super::Response {
chunk: vec![],
next_batch: Some("next_batch_token".into()),
prev_batch: Some("prev_batch_token".into()),
total_room_count_estimate: Some(uint!(10)),
}
.try_into()
.try_into_http_response()
.unwrap();
assert_eq!(

View File

@ -103,7 +103,7 @@ mod tests {
.unwrap();
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(),
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
);
assert_matches!(

View File

@ -104,15 +104,16 @@ impl ruma_api::IncomingRequest for IncomingRequest {
const METADATA: ruma_api::Metadata = METADATA;
fn try_from_http_request(
request: http::Request<Vec<u8>>,
fn try_from_http_request<T: bytes::Buf>(
request: http::Request<T>,
) -> Result<Self, ruma_api::error::FromHttpRequestError> {
use std::convert::TryFrom;
use ruma_events::EventContent as _;
use serde_json::value::RawValue as RawJsonValue;
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let (parts, body) = request.into_parts();
let path_segments: Vec<&str> = parts.uri.path()[1..].split('/').collect();
let room_id = {
let decoded =
@ -126,13 +127,11 @@ impl ruma_api::IncomingRequest for IncomingRequest {
.into_owned();
let content = {
let request_body: Box<RawJsonValue> =
serde_json::from_slice(request.body().as_slice())?;
let event_type =
percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8()?;
let body: Box<RawJsonValue> = serde_json::from_reader(body.reader())?;
AnyMessageEventContent::from_parts(&event_type, request_body)?
AnyMessageEventContent::from_parts(&event_type, body)?
};
Ok(Self { room_id, txn_id, content })

View File

@ -66,7 +66,7 @@ mod tests {
http::Request::builder()
.method("PUT")
.uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url")
.body(Vec::<u8>::new())?,
.body(&[] as &[u8])?,
)?,
IncomingRequest { user_id, avatar_url: None } if user_id == "@foo:bar.org"
);
@ -77,7 +77,9 @@ mod tests {
http::Request::builder()
.method("PUT")
.uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url")
.body(serde_json::to_vec(&serde_json::json!({ "avatar_url": "" }))?)?,
.body(std::io::Cursor::new(
serde_json::to_vec(&serde_json::json!({ "avatar_url": "" }))?,
))?,
)?,
IncomingRequest { user_id, avatar_url: None } if user_id == "@foo:bar.org"
);

View File

@ -108,8 +108,8 @@ impl ruma_api::IncomingRequest for IncomingRequest {
const METADATA: ruma_api::Metadata = METADATA;
fn try_from_http_request(
request: http::Request<Vec<u8>>,
fn try_from_http_request<T: bytes::Buf>(
request: http::Request<T>,
) -> Result<Self, ruma_api::error::FromHttpRequestError> {
use std::convert::TryFrom;

View File

@ -108,15 +108,16 @@ impl ruma_api::IncomingRequest for IncomingRequest {
const METADATA: ruma_api::Metadata = METADATA;
fn try_from_http_request(
request: http::Request<Vec<u8>>,
fn try_from_http_request<T: bytes::Buf>(
request: http::Request<T>,
) -> Result<Self, ruma_api::error::FromHttpRequestError> {
use std::{borrow::Cow, convert::TryFrom};
use ruma_events::EventContent;
use ruma_events::EventContent as _;
use serde_json::value::RawValue as RawJsonValue;
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let (parts, body) = request.into_parts();
let path_segments: Vec<&str> = parts.uri.path()[1..].split('/').collect();
let room_id = {
let decoded =
@ -133,13 +134,11 @@ impl ruma_api::IncomingRequest for IncomingRequest {
.into_owned();
let content = {
let request_body: Box<RawJsonValue> =
serde_json::from_slice(request.body().as_slice())?;
let event_type =
percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8()?;
let body: Box<RawJsonValue> = serde_json::from_reader(body.reader())?;
AnyStateEventContent::from_parts(&event_type, request_body)?
AnyStateEventContent::from_parts(&event_type, body)?
};
Ok(Self { room_id, state_key, content })

View File

@ -618,7 +618,7 @@ mod server_tests {
.unwrap();
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(),
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
)
.unwrap();
@ -639,7 +639,7 @@ mod server_tests {
.unwrap();
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(),
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
)
.unwrap();
@ -664,7 +664,7 @@ mod server_tests {
.unwrap();
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(),
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
)
.unwrap();

View File

@ -48,9 +48,8 @@ impl Response {
#[cfg(all(test, feature = "server"))]
mod server_tests {
use std::convert::TryFrom;
use assign::assign;
use ruma_api::OutgoingResponse;
use ruma_events::tag::{TagInfo, Tags};
use serde_json::json;
@ -63,7 +62,7 @@ mod server_tests {
tags.insert("u.user_tag".into(), assign!(TagInfo::new(), { order: Some(0.11) }));
let response = Response { tags };
let http_response = http::Response::<Vec<u8>>::try_from(response).unwrap();
let http_response = response.try_into_http_response().unwrap();
let json_response: serde_json::Value =
serde_json::from_slice(http_response.body()).unwrap();