api: Update try_from_http_request to be generic like try_from_http_response

This commit is contained in:
Jonas Platte 2021-04-10 16:16:47 +02:00
parent 23ba0bc164
commit 0e197aae0b
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
12 changed files with 65 additions and 61 deletions

View File

@ -152,8 +152,6 @@ impl Request {
quote! { #(#fields,)* } quote! { #(#fields,)* }
} }
/// Produces code for a struct initializer for the given field kind to be accessed through the
/// given variable name.
fn vars( fn vars(
&self, &self,
request_field_kind: RequestFieldKind, request_field_kind: RequestFieldKind,
@ -190,6 +188,7 @@ impl Request {
error_ty: &TokenStream, error_ty: &TokenStream,
ruma_api: &TokenStream, ruma_api: &TokenStream,
) -> TokenStream { ) -> TokenStream {
let bytes = quote! { #ruma_api::exports::bytes };
let http = quote! { #ruma_api::exports::http }; let http = quote! { #ruma_api::exports::http };
let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
@ -484,15 +483,15 @@ impl Request {
RequestBody #body_lifetimes RequestBody #body_lifetimes
as #ruma_serde::Outgoing as #ruma_serde::Outgoing
>::Incoming = { >::Incoming = {
// If the request body is completely empty, pretend it is an empty JSON object let body = request.into_body();
// instead. This allows requests with only optional body parameters to be if #bytes::Buf::has_remaining(&body) {
// deserialized in that case. #serde_json::from_reader(#bytes::Buf::reader(body))?
let json = match request.body().as_slice() { } else {
b"" => b"{}", // If the request body is completely empty, pretend it is an empty JSON
body => body, // object instead. This allows requests with only optional body parameters
}; // to be deserialized in that case.
#serde_json::from_str("{}")?
#serde_json::from_slice(json)? }
}; };
} }
} else { } else {
@ -532,7 +531,13 @@ impl Request {
} else if let Some(field) = self.newtype_raw_body_field() { } else if let Some(field) = self.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let field_name = field.ident.as_ref().expect("expected field to have an identifier");
let parse = quote! { let parse = quote! {
let #field_name = request.into_body(); let #field_name = {
let mut reader = #bytes::Buf::reader(request.into_body());
let mut vec = ::std::vec::Vec::new();
::std::io::Read::read_to_end(&mut reader, &mut vec)
.expect("reading from a bytes::Buf never fails");
vec
};
}; };
(parse, quote! { #field_name, }) (parse, quote! { #field_name, })
@ -714,8 +719,8 @@ impl Request {
const METADATA: #ruma_api::Metadata = self::METADATA; const METADATA: #ruma_api::Metadata = self::METADATA;
fn try_from_http_request( fn try_from_http_request<T: #bytes::Buf>(
request: #http::Request<Vec<u8>> request: #http::Request<T>
) -> ::std::result::Result<Self, #ruma_api::error::FromHttpRequestError> { ) -> ::std::result::Result<Self, #ruma_api::error::FromHttpRequestError> {
if request.method() != #http::Method::#method { if request.method() != #http::Method::#method {
return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch { return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch {

View File

@ -312,7 +312,7 @@ pub trait IncomingRequest: Sized {
const METADATA: Metadata; const METADATA: Metadata;
/// Tries to turn the given `http::Request` into this request type. /// Tries to turn the given `http::Request` into this request type.
fn try_from_http_request(req: http::Request<Vec<u8>>) -> Result<Self, FromHttpRequestError>; fn try_from_http_request<T: Buf>(req: http::Request<T>) -> Result<Self, FromHttpRequestError>;
} }
/// A request type for a Matrix API endpoint, used for sending responses. /// A request type for a Matrix API endpoint, used for sending responses.

View File

@ -48,7 +48,7 @@ fn request_serde() -> Result<(), Box<dyn std::error::Error + 'static>> {
}; };
let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?; let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?;
let req2 = Request::try_from_http_request(http_req)?; let req2 = Request::try_from_http_request(http_req.map(std::io::Cursor::new))?;
assert_eq!(req.hello, req2.hello); assert_eq!(req.hello, req2.hello);
assert_eq!(req.world, req2.world); assert_eq!(req.world, req2.world);

View File

@ -67,21 +67,20 @@ impl IncomingRequest for Request {
const METADATA: Metadata = METADATA; const METADATA: Metadata = METADATA;
fn try_from_http_request( fn try_from_http_request<T: Buf>(
request: http::Request<Vec<u8>>, request: http::Request<T>,
) -> Result<Self, FromHttpRequestError> { ) -> Result<Self, FromHttpRequestError> {
let request_body: RequestBody = serde_json::from_slice(request.body().as_slice())?;
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let room_alias = {
let decoded =
percent_encoding::percent_decode(path_segments[5].as_bytes()).decode_utf8()?;
Ok(Request { TryFrom::try_from(&*decoded)?
room_id: request_body.room_id, };
room_alias: {
let decoded =
percent_encoding::percent_decode(path_segments[5].as_bytes()).decode_utf8()?;
TryFrom::try_from(&*decoded)? let request_body: RequestBody = serde_json::from_reader(request.into_body().reader())?;
},
}) Ok(Request { room_id: request_body.room_id, room_alias })
} }
} }

View File

@ -71,19 +71,18 @@ impl Response {
#[cfg(all(test, any(feature = "client", feature = "server")))] #[cfg(all(test, any(feature = "client", feature = "server")))]
mod tests { mod tests {
use std::convert::TryInto;
use js_int::uint; use js_int::uint;
#[cfg(feature = "client")] #[cfg(feature = "client")]
#[test] #[test]
fn construct_request_from_refs() { fn construct_request_from_refs() {
use ruma_api::OutgoingRequest; use ruma_api::OutgoingRequest as _;
use ruma_identifiers::server_name;
let req: http::Request<Vec<u8>> = super::Request { let req = super::Request {
limit: Some(uint!(10)), limit: Some(uint!(10)),
since: Some("hello"), since: Some("hello"),
server: Some("address".try_into().unwrap()), server: Some(&server_name!("test.tld")),
} }
.try_into_http_request("https://homeserver.tld", Some("auth_tok")) .try_into_http_request("https://homeserver.tld", Some("auth_tok"))
.unwrap(); .unwrap();
@ -94,19 +93,21 @@ mod tests {
assert_eq!(uri.path(), "/_matrix/client/r0/publicRooms"); assert_eq!(uri.path(), "/_matrix/client/r0/publicRooms");
assert!(query.contains("since=hello")); assert!(query.contains("since=hello"));
assert!(query.contains("limit=10")); assert!(query.contains("limit=10"));
assert!(query.contains("server=address")); assert!(query.contains("server=test.tld"));
} }
#[cfg(feature = "server")] #[cfg(feature = "server")]
#[test] #[test]
fn construct_response_from_refs() { fn construct_response_from_refs() {
let res: http::Response<Vec<u8>> = super::Response { use ruma_api::OutgoingResponse as _;
let res = super::Response {
chunk: vec![], chunk: vec![],
next_batch: Some("next_batch_token".into()), next_batch: Some("next_batch_token".into()),
prev_batch: Some("prev_batch_token".into()), prev_batch: Some("prev_batch_token".into()),
total_room_count_estimate: Some(uint!(10)), total_room_count_estimate: Some(uint!(10)),
} }
.try_into() .try_into_http_response()
.unwrap(); .unwrap();
assert_eq!( assert_eq!(

View File

@ -103,7 +103,7 @@ mod tests {
.unwrap(); .unwrap();
let req = IncomingRequest::try_from_http_request( let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(), http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
); );
assert_matches!( assert_matches!(

View File

@ -104,15 +104,16 @@ impl ruma_api::IncomingRequest for IncomingRequest {
const METADATA: ruma_api::Metadata = METADATA; const METADATA: ruma_api::Metadata = METADATA;
fn try_from_http_request( fn try_from_http_request<T: bytes::Buf>(
request: http::Request<Vec<u8>>, request: http::Request<T>,
) -> Result<Self, ruma_api::error::FromHttpRequestError> { ) -> Result<Self, ruma_api::error::FromHttpRequestError> {
use std::convert::TryFrom; use std::convert::TryFrom;
use ruma_events::EventContent as _; use ruma_events::EventContent as _;
use serde_json::value::RawValue as RawJsonValue; use serde_json::value::RawValue as RawJsonValue;
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); let (parts, body) = request.into_parts();
let path_segments: Vec<&str> = parts.uri.path()[1..].split('/').collect();
let room_id = { let room_id = {
let decoded = let decoded =
@ -126,13 +127,11 @@ impl ruma_api::IncomingRequest for IncomingRequest {
.into_owned(); .into_owned();
let content = { let content = {
let request_body: Box<RawJsonValue> =
serde_json::from_slice(request.body().as_slice())?;
let event_type = let event_type =
percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8()?; percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8()?;
let body: Box<RawJsonValue> = serde_json::from_reader(body.reader())?;
AnyMessageEventContent::from_parts(&event_type, request_body)? AnyMessageEventContent::from_parts(&event_type, body)?
}; };
Ok(Self { room_id, txn_id, content }) Ok(Self { room_id, txn_id, content })

View File

@ -66,7 +66,7 @@ mod tests {
http::Request::builder() http::Request::builder()
.method("PUT") .method("PUT")
.uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url") .uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url")
.body(Vec::<u8>::new())?, .body(&[] as &[u8])?,
)?, )?,
IncomingRequest { user_id, avatar_url: None } if user_id == "@foo:bar.org" IncomingRequest { user_id, avatar_url: None } if user_id == "@foo:bar.org"
); );
@ -77,7 +77,9 @@ mod tests {
http::Request::builder() http::Request::builder()
.method("PUT") .method("PUT")
.uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url") .uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url")
.body(serde_json::to_vec(&serde_json::json!({ "avatar_url": "" }))?)?, .body(std::io::Cursor::new(
serde_json::to_vec(&serde_json::json!({ "avatar_url": "" }))?,
))?,
)?, )?,
IncomingRequest { user_id, avatar_url: None } if user_id == "@foo:bar.org" IncomingRequest { user_id, avatar_url: None } if user_id == "@foo:bar.org"
); );

View File

@ -108,8 +108,8 @@ impl ruma_api::IncomingRequest for IncomingRequest {
const METADATA: ruma_api::Metadata = METADATA; const METADATA: ruma_api::Metadata = METADATA;
fn try_from_http_request( fn try_from_http_request<T: bytes::Buf>(
request: http::Request<Vec<u8>>, request: http::Request<T>,
) -> Result<Self, ruma_api::error::FromHttpRequestError> { ) -> Result<Self, ruma_api::error::FromHttpRequestError> {
use std::convert::TryFrom; use std::convert::TryFrom;

View File

@ -108,15 +108,16 @@ impl ruma_api::IncomingRequest for IncomingRequest {
const METADATA: ruma_api::Metadata = METADATA; const METADATA: ruma_api::Metadata = METADATA;
fn try_from_http_request( fn try_from_http_request<T: bytes::Buf>(
request: http::Request<Vec<u8>>, request: http::Request<T>,
) -> Result<Self, ruma_api::error::FromHttpRequestError> { ) -> Result<Self, ruma_api::error::FromHttpRequestError> {
use std::{borrow::Cow, convert::TryFrom}; use std::{borrow::Cow, convert::TryFrom};
use ruma_events::EventContent; use ruma_events::EventContent as _;
use serde_json::value::RawValue as RawJsonValue; use serde_json::value::RawValue as RawJsonValue;
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); let (parts, body) = request.into_parts();
let path_segments: Vec<&str> = parts.uri.path()[1..].split('/').collect();
let room_id = { let room_id = {
let decoded = let decoded =
@ -133,13 +134,11 @@ impl ruma_api::IncomingRequest for IncomingRequest {
.into_owned(); .into_owned();
let content = { let content = {
let request_body: Box<RawJsonValue> =
serde_json::from_slice(request.body().as_slice())?;
let event_type = let event_type =
percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8()?; percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8()?;
let body: Box<RawJsonValue> = serde_json::from_reader(body.reader())?;
AnyStateEventContent::from_parts(&event_type, request_body)? AnyStateEventContent::from_parts(&event_type, body)?
}; };
Ok(Self { room_id, state_key, content }) Ok(Self { room_id, state_key, content })

View File

@ -618,7 +618,7 @@ mod server_tests {
.unwrap(); .unwrap();
let req = IncomingRequest::try_from_http_request( let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(), http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
) )
.unwrap(); .unwrap();
@ -639,7 +639,7 @@ mod server_tests {
.unwrap(); .unwrap();
let req = IncomingRequest::try_from_http_request( let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(), http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
) )
.unwrap(); .unwrap();
@ -664,7 +664,7 @@ mod server_tests {
.unwrap(); .unwrap();
let req = IncomingRequest::try_from_http_request( let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(), http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
) )
.unwrap(); .unwrap();

View File

@ -48,9 +48,8 @@ impl Response {
#[cfg(all(test, feature = "server"))] #[cfg(all(test, feature = "server"))]
mod server_tests { mod server_tests {
use std::convert::TryFrom;
use assign::assign; use assign::assign;
use ruma_api::OutgoingResponse;
use ruma_events::tag::{TagInfo, Tags}; use ruma_events::tag::{TagInfo, Tags};
use serde_json::json; use serde_json::json;
@ -63,7 +62,7 @@ mod server_tests {
tags.insert("u.user_tag".into(), assign!(TagInfo::new(), { order: Some(0.11) })); tags.insert("u.user_tag".into(), assign!(TagInfo::new(), { order: Some(0.11) }));
let response = Response { tags }; let response = Response { tags };
let http_response = http::Response::<Vec<u8>>::try_from(response).unwrap(); let http_response = response.try_into_http_response().unwrap();
let json_response: serde_json::Value = let json_response: serde_json::Value =
serde_json::from_slice(http_response.body()).unwrap(); serde_json::from_slice(http_response.body()).unwrap();