diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index 05bc57cd..155a2c6f 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -152,8 +152,6 @@ impl Request { quote! { #(#fields,)* } } - /// Produces code for a struct initializer for the given field kind to be accessed through the - /// given variable name. fn vars( &self, request_field_kind: RequestFieldKind, @@ -190,6 +188,7 @@ impl Request { error_ty: &TokenStream, ruma_api: &TokenStream, ) -> TokenStream { + let bytes = quote! { #ruma_api::exports::bytes }; let http = quote! { #ruma_api::exports::http }; let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; @@ -484,15 +483,15 @@ impl Request { RequestBody #body_lifetimes as #ruma_serde::Outgoing >::Incoming = { - // If the request body is completely empty, pretend it is an empty JSON object - // instead. This allows requests with only optional body parameters to be - // deserialized in that case. - let json = match request.body().as_slice() { - b"" => b"{}", - body => body, - }; - - #serde_json::from_slice(json)? + let body = request.into_body(); + if #bytes::Buf::has_remaining(&body) { + #serde_json::from_reader(#bytes::Buf::reader(body))? + } else { + // If the request body is completely empty, pretend it is an empty JSON + // object instead. This allows requests with only optional body parameters + // to be deserialized in that case. + #serde_json::from_str("{}")? + } }; } } else { @@ -532,7 +531,13 @@ impl Request { } else if let Some(field) = self.newtype_raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let parse = quote! { - let #field_name = request.into_body(); + let #field_name = { + let mut reader = #bytes::Buf::reader(request.into_body()); + let mut vec = ::std::vec::Vec::new(); + ::std::io::Read::read_to_end(&mut reader, &mut vec) + .expect("reading from a bytes::Buf never fails"); + vec + }; }; (parse, quote! { #field_name, }) @@ -714,8 +719,8 @@ impl Request { const METADATA: #ruma_api::Metadata = self::METADATA; - fn try_from_http_request( - request: #http::Request> + fn try_from_http_request( + request: #http::Request ) -> ::std::result::Result { if request.method() != #http::Method::#method { return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch { diff --git a/ruma-api/src/lib.rs b/ruma-api/src/lib.rs index 1f97268d..a5e2547d 100644 --- a/ruma-api/src/lib.rs +++ b/ruma-api/src/lib.rs @@ -312,7 +312,7 @@ pub trait IncomingRequest: Sized { const METADATA: Metadata; /// Tries to turn the given `http::Request` into this request type. - fn try_from_http_request(req: http::Request>) -> Result; + fn try_from_http_request(req: http::Request) -> Result; } /// A request type for a Matrix API endpoint, used for sending responses. diff --git a/ruma-api/tests/conversions.rs b/ruma-api/tests/conversions.rs index 839bda63..f32064fc 100644 --- a/ruma-api/tests/conversions.rs +++ b/ruma-api/tests/conversions.rs @@ -48,7 +48,7 @@ fn request_serde() -> Result<(), Box> { }; let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?; - let req2 = Request::try_from_http_request(http_req)?; + let req2 = Request::try_from_http_request(http_req.map(std::io::Cursor::new))?; assert_eq!(req.hello, req2.hello); assert_eq!(req.world, req2.world); diff --git a/ruma-api/tests/manual_endpoint_impl.rs b/ruma-api/tests/manual_endpoint_impl.rs index 6b5623eb..9de917b6 100644 --- a/ruma-api/tests/manual_endpoint_impl.rs +++ b/ruma-api/tests/manual_endpoint_impl.rs @@ -67,21 +67,20 @@ impl IncomingRequest for Request { const METADATA: Metadata = METADATA; - fn try_from_http_request( - request: http::Request>, + fn try_from_http_request( + request: http::Request, ) -> Result { - let request_body: RequestBody = serde_json::from_slice(request.body().as_slice())?; let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); + let room_alias = { + let decoded = + percent_encoding::percent_decode(path_segments[5].as_bytes()).decode_utf8()?; - Ok(Request { - room_id: request_body.room_id, - room_alias: { - let decoded = - percent_encoding::percent_decode(path_segments[5].as_bytes()).decode_utf8()?; + TryFrom::try_from(&*decoded)? + }; - TryFrom::try_from(&*decoded)? - }, - }) + let request_body: RequestBody = serde_json::from_reader(request.into_body().reader())?; + + Ok(Request { room_id: request_body.room_id, room_alias }) } } diff --git a/ruma-client-api/src/r0/directory/get_public_rooms.rs b/ruma-client-api/src/r0/directory/get_public_rooms.rs index 17d563b5..442fcf89 100644 --- a/ruma-client-api/src/r0/directory/get_public_rooms.rs +++ b/ruma-client-api/src/r0/directory/get_public_rooms.rs @@ -71,19 +71,18 @@ impl Response { #[cfg(all(test, any(feature = "client", feature = "server")))] mod tests { - use std::convert::TryInto; - use js_int::uint; #[cfg(feature = "client")] #[test] fn construct_request_from_refs() { - use ruma_api::OutgoingRequest; + use ruma_api::OutgoingRequest as _; + use ruma_identifiers::server_name; - let req: http::Request> = super::Request { + let req = super::Request { limit: Some(uint!(10)), since: Some("hello"), - server: Some("address".try_into().unwrap()), + server: Some(&server_name!("test.tld")), } .try_into_http_request("https://homeserver.tld", Some("auth_tok")) .unwrap(); @@ -94,19 +93,21 @@ mod tests { assert_eq!(uri.path(), "/_matrix/client/r0/publicRooms"); assert!(query.contains("since=hello")); assert!(query.contains("limit=10")); - assert!(query.contains("server=address")); + assert!(query.contains("server=test.tld")); } #[cfg(feature = "server")] #[test] fn construct_response_from_refs() { - let res: http::Response> = super::Response { + use ruma_api::OutgoingResponse as _; + + let res = super::Response { chunk: vec![], next_batch: Some("next_batch_token".into()), prev_batch: Some("prev_batch_token".into()), total_room_count_estimate: Some(uint!(10)), } - .try_into() + .try_into_http_response() .unwrap(); assert_eq!( diff --git a/ruma-client-api/src/r0/membership/get_member_events.rs b/ruma-client-api/src/r0/membership/get_member_events.rs index 37a049a4..a9911b42 100644 --- a/ruma-client-api/src/r0/membership/get_member_events.rs +++ b/ruma-client-api/src/r0/membership/get_member_events.rs @@ -103,7 +103,7 @@ mod tests { .unwrap(); let req = IncomingRequest::try_from_http_request( - http::Request::builder().uri(uri).body(Vec::::new()).unwrap(), + http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), ); assert_matches!( diff --git a/ruma-client-api/src/r0/message/send_message_event.rs b/ruma-client-api/src/r0/message/send_message_event.rs index 75ad5300..832143db 100644 --- a/ruma-client-api/src/r0/message/send_message_event.rs +++ b/ruma-client-api/src/r0/message/send_message_event.rs @@ -104,15 +104,16 @@ impl ruma_api::IncomingRequest for IncomingRequest { const METADATA: ruma_api::Metadata = METADATA; - fn try_from_http_request( - request: http::Request>, + fn try_from_http_request( + request: http::Request, ) -> Result { use std::convert::TryFrom; use ruma_events::EventContent as _; use serde_json::value::RawValue as RawJsonValue; - let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); + let (parts, body) = request.into_parts(); + let path_segments: Vec<&str> = parts.uri.path()[1..].split('/').collect(); let room_id = { let decoded = @@ -126,13 +127,11 @@ impl ruma_api::IncomingRequest for IncomingRequest { .into_owned(); let content = { - let request_body: Box = - serde_json::from_slice(request.body().as_slice())?; - let event_type = percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8()?; + let body: Box = serde_json::from_reader(body.reader())?; - AnyMessageEventContent::from_parts(&event_type, request_body)? + AnyMessageEventContent::from_parts(&event_type, body)? }; Ok(Self { room_id, txn_id, content }) diff --git a/ruma-client-api/src/r0/profile/set_avatar_url.rs b/ruma-client-api/src/r0/profile/set_avatar_url.rs index 5e9ee3b9..83c08540 100644 --- a/ruma-client-api/src/r0/profile/set_avatar_url.rs +++ b/ruma-client-api/src/r0/profile/set_avatar_url.rs @@ -66,7 +66,7 @@ mod tests { http::Request::builder() .method("PUT") .uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url") - .body(Vec::::new())?, + .body(&[] as &[u8])?, )?, IncomingRequest { user_id, avatar_url: None } if user_id == "@foo:bar.org" ); @@ -77,7 +77,9 @@ mod tests { http::Request::builder() .method("PUT") .uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url") - .body(serde_json::to_vec(&serde_json::json!({ "avatar_url": "" }))?)?, + .body(std::io::Cursor::new( + serde_json::to_vec(&serde_json::json!({ "avatar_url": "" }))?, + ))?, )?, IncomingRequest { user_id, avatar_url: None } if user_id == "@foo:bar.org" ); diff --git a/ruma-client-api/src/r0/state/get_state_events_for_key.rs b/ruma-client-api/src/r0/state/get_state_events_for_key.rs index cf668e82..6c631755 100644 --- a/ruma-client-api/src/r0/state/get_state_events_for_key.rs +++ b/ruma-client-api/src/r0/state/get_state_events_for_key.rs @@ -108,8 +108,8 @@ impl ruma_api::IncomingRequest for IncomingRequest { const METADATA: ruma_api::Metadata = METADATA; - fn try_from_http_request( - request: http::Request>, + fn try_from_http_request( + request: http::Request, ) -> Result { use std::convert::TryFrom; diff --git a/ruma-client-api/src/r0/state/send_state_event.rs b/ruma-client-api/src/r0/state/send_state_event.rs index e21f89ef..9eb7bc86 100644 --- a/ruma-client-api/src/r0/state/send_state_event.rs +++ b/ruma-client-api/src/r0/state/send_state_event.rs @@ -108,15 +108,16 @@ impl ruma_api::IncomingRequest for IncomingRequest { const METADATA: ruma_api::Metadata = METADATA; - fn try_from_http_request( - request: http::Request>, + fn try_from_http_request( + request: http::Request, ) -> Result { use std::{borrow::Cow, convert::TryFrom}; - use ruma_events::EventContent; + use ruma_events::EventContent as _; use serde_json::value::RawValue as RawJsonValue; - let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); + let (parts, body) = request.into_parts(); + let path_segments: Vec<&str> = parts.uri.path()[1..].split('/').collect(); let room_id = { let decoded = @@ -133,13 +134,11 @@ impl ruma_api::IncomingRequest for IncomingRequest { .into_owned(); let content = { - let request_body: Box = - serde_json::from_slice(request.body().as_slice())?; - let event_type = percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8()?; + let body: Box = serde_json::from_reader(body.reader())?; - AnyStateEventContent::from_parts(&event_type, request_body)? + AnyStateEventContent::from_parts(&event_type, body)? }; Ok(Self { room_id, state_key, content }) diff --git a/ruma-client-api/src/r0/sync/sync_events.rs b/ruma-client-api/src/r0/sync/sync_events.rs index da078523..dff257b9 100644 --- a/ruma-client-api/src/r0/sync/sync_events.rs +++ b/ruma-client-api/src/r0/sync/sync_events.rs @@ -618,7 +618,7 @@ mod server_tests { .unwrap(); let req = IncomingRequest::try_from_http_request( - http::Request::builder().uri(uri).body(Vec::::new()).unwrap(), + http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), ) .unwrap(); @@ -639,7 +639,7 @@ mod server_tests { .unwrap(); let req = IncomingRequest::try_from_http_request( - http::Request::builder().uri(uri).body(Vec::::new()).unwrap(), + http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), ) .unwrap(); @@ -664,7 +664,7 @@ mod server_tests { .unwrap(); let req = IncomingRequest::try_from_http_request( - http::Request::builder().uri(uri).body(Vec::::new()).unwrap(), + http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), ) .unwrap(); diff --git a/ruma-client-api/src/r0/tag/get_tags.rs b/ruma-client-api/src/r0/tag/get_tags.rs index 964361c6..283c36cc 100644 --- a/ruma-client-api/src/r0/tag/get_tags.rs +++ b/ruma-client-api/src/r0/tag/get_tags.rs @@ -48,9 +48,8 @@ impl Response { #[cfg(all(test, feature = "server"))] mod server_tests { - use std::convert::TryFrom; - use assign::assign; + use ruma_api::OutgoingResponse; use ruma_events::tag::{TagInfo, Tags}; use serde_json::json; @@ -63,7 +62,7 @@ mod server_tests { tags.insert("u.user_tag".into(), assign!(TagInfo::new(), { order: Some(0.11) })); let response = Response { tags }; - let http_response = http::Response::>::try_from(response).unwrap(); + let http_response = response.try_into_http_response().unwrap(); let json_response: serde_json::Value = serde_json::from_slice(http_response.body()).unwrap();