From 2ac020173b293acae32c0b37bd3a4670ff500a20 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Fri, 9 Apr 2021 15:29:39 +0200 Subject: [PATCH] api: Introduce IncomingResponse trait --- ruma-api-macros/src/api/response.rs | 41 ++++++++++++++++---------- ruma-api/Cargo.toml | 1 + ruma-api/src/error.rs | 30 +++++++++---------- ruma-api/src/lib.rs | 31 +++++++++++-------- ruma-api/tests/manual_endpoint_impl.rs | 12 +++++--- ruma-api/tests/ui/05-request-only.rs | 11 ++++--- ruma-client-api/Cargo.toml | 1 + ruma-client-api/src/error.rs | 14 ++++----- ruma-client-api/src/r0/uiaa.rs | 17 +++++------ ruma-client/src/lib.rs | 9 ++---- 10 files changed, 94 insertions(+), 73 deletions(-) diff --git a/ruma-api-macros/src/api/response.rs b/ruma-api-macros/src/api/response.rs index 9eed7f89..9d0fd065 100644 --- a/ruma-api-macros/src/api/response.rs +++ b/ruma-api-macros/src/api/response.rs @@ -28,6 +28,7 @@ impl Response { /// Produces code for a response struct initializer. fn init_fields(&self, ruma_api: &TokenStream) -> TokenStream { + let bytes = quote! { #ruma_api::exports::bytes }; let http = quote! { #ruma_api::exports::http }; let mut fields = vec![]; @@ -79,7 +80,13 @@ impl Response { // We are guaranteed only one new body field because of a check in `try_from`. ResponseField::NewtypeRawBody(_) => { new_type_raw_body = Some(quote_spanned! {span=> - #field_name: response.into_body() + #field_name: { + let mut reader = #bytes::Buf::reader(response.into_body()); + let mut vec = ::std::vec::Vec::new(); + ::std::io::Read::read_to_end(&mut reader, &mut vec) + .expect("reading from a bytes::Buf never fails"); + vec + } }); // skip adding to the vec continue; @@ -194,6 +201,7 @@ impl Response { error_ty: &TokenStream, ruma_api: &TokenStream, ) -> TokenStream { + let bytes = quote! { #ruma_api::exports::bytes }; let http = quote! { #ruma_api::exports::http }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let serde = quote! { #ruma_api::exports::serde }; @@ -218,15 +226,15 @@ impl Response { ResponseBody as #ruma_serde::Outgoing >::Incoming = { - // If the reponse body is completely empty, pretend it is an empty JSON object - // instead. This allows reponses with only optional body parameters to be - // deserialized in that case. - let json = match response.body().as_slice() { - b"" => b"{}", - body => body, - }; - - #serde_json::from_slice(json)? + let body = response.into_body(); + if #bytes::Buf::has_remaining(&body) { + #serde_json::from_reader(#bytes::Buf::reader(body))? + } else { + // If the reponse body is completely empty, pretend it is an empty JSON + // object instead. This allows reponses with only optional body + // parameters to be deserialized in that case. + #serde_json::from_str("{}")? + } }; } } else { @@ -299,12 +307,15 @@ impl Response { #[automatically_derived] #[cfg(feature = "client")] - impl ::std::convert::TryFrom<#http::Response>> for Response { - type Error = #ruma_api::error::FromHttpResponseError<#error_ty>; + impl #ruma_api::IncomingResponse for Response { + type EndpointError = #error_ty; - fn try_from( - response: #http::Response>, - ) -> ::std::result::Result { + fn try_from_http_response( + response: #http::Response, + ) -> ::std::result::Result< + Self, + #ruma_api::error::FromHttpResponseError<#error_ty>, + > { if response.status().as_u16() < 400 { #extract_response_headers diff --git a/ruma-api/Cargo.toml b/ruma-api/Cargo.toml index 36b9464d..3ebbcad7 100644 --- a/ruma-api/Cargo.toml +++ b/ruma-api/Cargo.toml @@ -19,6 +19,7 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [dependencies] +bytes = "1.0.1" http = "0.2.2" percent-encoding = "2.1.0" ruma-api-macros = { version = "=0.17.0-alpha.2", path = "../ruma-api-macros" } diff --git a/ruma-api/src/error.rs b/ruma-api/src/error.rs index bd76272b..f22b219a 100644 --- a/ruma-api/src/error.rs +++ b/ruma-api/src/error.rs @@ -4,6 +4,7 @@ use std::{error::Error as StdError, fmt}; +use bytes::Buf; use thiserror::Error; use crate::EndpointError; @@ -14,8 +15,8 @@ use crate::EndpointError; pub enum Void {} impl EndpointError for Void { - fn try_from_response( - _response: http::Response>, + fn try_from_response( + _response: http::Response, ) -> Result { Err(ResponseDeserializationError::none()) } @@ -121,18 +122,12 @@ impl From> for FromHttpResponseError { } } -impl From for FromHttpResponseError { - fn from(err: ResponseDeserializationError) -> Self { - Self::Deserialization(err) - } -} - impl From for FromHttpResponseError where - T: Into, + T: Into, { fn from(err: T) -> Self { - Self::Deserialization(ResponseDeserializationError::new(err)) + Self::Deserialization(err.into()) } } @@ -145,17 +140,20 @@ pub struct ResponseDeserializationError { } impl ResponseDeserializationError { - /// Creates a new `ResponseDeserializationError` from the given deserialization error and http - /// response. - pub fn new(inner: impl Into) -> Self { - Self { inner: Some(inner.into()) } - } - fn none() -> Self { Self { inner: None } } } +impl From for ResponseDeserializationError +where + T: Into, +{ + fn from(err: T) -> Self { + Self { inner: Some(err.into()) } + } +} + impl fmt::Display for ResponseDeserializationError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Some(ref inner) = self.inner { diff --git a/ruma-api/src/lib.rs b/ruma-api/src/lib.rs index 223ab666..21b3688b 100644 --- a/ruma-api/src/lib.rs +++ b/ruma-api/src/lib.rs @@ -20,12 +20,10 @@ #[cfg(not(all(feature = "client", feature = "server")))] compile_error!("ruma_api's Cargo features only exist as a workaround are not meant to be disabled"); -use std::{ - convert::{TryFrom, TryInto}, - error::Error as StdError, -}; +use std::{convert::TryInto, error::Error as StdError}; -use http::{uri::PathAndQuery, Method}; +use bytes::Buf; +use http::Method; use ruma_identifiers::UserId; /// Generates a `ruma_api::Endpoint` from a concise definition. @@ -206,6 +204,7 @@ pub mod error; /// It is not considered part of ruma-api's public API. #[doc(hidden)] pub mod exports { + pub use bytes; pub use http; pub use percent_encoding; pub use ruma_serde; @@ -221,8 +220,8 @@ pub trait EndpointError: StdError + Sized + 'static { /// /// This will always return `Err` variant when no `error` field is defined in /// the `ruma_api` macro. - fn try_from_response( - response: http::Response>, + fn try_from_response( + response: http::Response, ) -> Result; } @@ -232,10 +231,7 @@ pub trait OutgoingRequest: Sized { type EndpointError: EndpointError; /// Response type returned when the request is successful. - type IncomingResponse: TryFrom< - http::Response>, - Error = FromHttpResponseError, - >; + type IncomingResponse: IncomingResponse; /// Metadata about the endpoint. const METADATA: Metadata; @@ -255,6 +251,17 @@ pub trait OutgoingRequest: Sized { ) -> Result>, IntoHttpError>; } +/// A response type for a Matrix API endpoint, used for receiving responses. +pub trait IncomingResponse: Sized { + /// A type capturing the expected error conditions the server can return. + type EndpointError: EndpointError; + + /// Tries to convert the given `http::Response` into this response type. + fn try_from_http_response( + response: http::Response, + ) -> Result>; +} + /// An extension to `OutgoingRequest` which provides Appservice specific methods pub trait OutgoingRequestAppserviceExt: OutgoingRequest { /// Tries to convert this request into an `http::Request` and appends a virtual `user_id` to @@ -283,7 +290,7 @@ pub trait OutgoingRequestAppserviceExt: OutgoingRequest { }; parts.path_and_query = - Some(PathAndQuery::try_from(path_and_query_with_user_id).map_err(http::Error::from)?); + Some(path_and_query_with_user_id.try_into().map_err(http::Error::from)?); *http_request.uri_mut() = parts.try_into().map_err(http::Error::from)?; diff --git a/ruma-api/tests/manual_endpoint_impl.rs b/ruma-api/tests/manual_endpoint_impl.rs index c22633d6..65e2190c 100644 --- a/ruma-api/tests/manual_endpoint_impl.rs +++ b/ruma-api/tests/manual_endpoint_impl.rs @@ -2,10 +2,12 @@ use std::convert::TryFrom; +use bytes::Buf; use http::{header::CONTENT_TYPE, method::Method}; use ruma_api::{ error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError, ServerError, Void}, - try_deserialize, AuthScheme, EndpointError, IncomingRequest, Metadata, OutgoingRequest, + try_deserialize, AuthScheme, EndpointError, IncomingRequest, IncomingResponse, Metadata, + OutgoingRequest, }; use ruma_identifiers::{RoomAliasId, RoomId}; use ruma_serde::Outgoing; @@ -99,10 +101,12 @@ impl Outgoing for Response { type Incoming = Self; } -impl TryFrom>> for Response { - type Error = FromHttpResponseError; +impl IncomingResponse for Response { + type EndpointError = Void; - fn try_from(http_response: http::Response>) -> Result { + fn try_from_http_response( + http_response: http::Response, + ) -> Result> { if http_response.status().as_u16() < 400 { Ok(Response) } else { diff --git a/ruma-api/tests/ui/05-request-only.rs b/ruma-api/tests/ui/05-request-only.rs index 083d41b7..4035df54 100644 --- a/ruma-api/tests/ui/05-request-only.rs +++ b/ruma-api/tests/ui/05-request-only.rs @@ -1,8 +1,9 @@ use std::convert::TryFrom; +use bytes::Buf; use ruma_api::{ error::{FromHttpResponseError, IntoHttpError, Void}, - ruma_api, + ruma_api, IncomingResponse, }; use ruma_serde::Outgoing; @@ -26,10 +27,12 @@ ruma_api! { #[derive(Outgoing)] pub struct Response; -impl TryFrom>> for Response { - type Error = FromHttpResponseError; +impl IncomingResponse for Response { + type EndpointError = Void; - fn try_from(_: http::Response>) -> Result { + fn try_from_http_response( + _: http::Response, + ) -> Result> { todo!() } } diff --git a/ruma-client-api/Cargo.toml b/ruma-client-api/Cargo.toml index b52e76bc..95ba5172 100644 --- a/ruma-client-api/Cargo.toml +++ b/ruma-client-api/Cargo.toml @@ -17,6 +17,7 @@ edition = "2018" [dependencies] assign = "1.1.1" +bytes = "1.0.1" http = "0.2.2" js_int = { version = "0.2.0", features = ["serde"] } maplit = "1.0.2" diff --git a/ruma-client-api/src/error.rs b/ruma-client-api/src/error.rs index 301de0dd..0fe74627 100644 --- a/ruma-client-api/src/error.rs +++ b/ruma-client-api/src/error.rs @@ -2,10 +2,11 @@ use std::{collections::BTreeMap, fmt, time::Duration}; +use bytes::Buf; use ruma_api::{error::ResponseDeserializationError, EndpointError}; use ruma_identifiers::RoomVersionId; use serde::{Deserialize, Serialize}; -use serde_json::{from_slice as from_json_slice, to_vec as to_json_vec, Value as JsonValue}; +use serde_json::{from_reader as from_json_reader, to_vec as to_json_vec, Value as JsonValue}; /// Deserialize and Serialize implementations for ErrorKind. /// Separate module because it's a lot of code. @@ -202,13 +203,12 @@ pub struct Error { } impl EndpointError for Error { - fn try_from_response( - response: http::Response>, + fn try_from_response( + response: http::Response, ) -> Result { - match from_json_slice::(response.body()) { - Ok(error_body) => Ok(error_body.into_error(response.status())), - Err(de_error) => Err(ResponseDeserializationError::new(de_error)), - } + let status = response.status(); + let error_body: ErrorBody = from_json_reader(response.into_body().reader())?; + Ok(error_body.into_error(status)) } } diff --git a/ruma-client-api/src/r0/uiaa.rs b/ruma-client-api/src/r0/uiaa.rs index e7891fbf..fc8ba0bd 100644 --- a/ruma-client-api/src/r0/uiaa.rs +++ b/ruma-client-api/src/r0/uiaa.rs @@ -2,11 +2,12 @@ use std::{collections::BTreeMap, fmt}; +use bytes::Buf; use ruma_api::{error::ResponseDeserializationError, EndpointError}; use ruma_serde::Outgoing; use serde::{Deserialize, Serialize}; use serde_json::{ - from_slice as from_json_slice, to_vec as to_json_vec, value::RawValue as RawJsonValue, + from_reader as from_json_reader, to_vec as to_json_vec, value::RawValue as RawJsonValue, Value as JsonValue, }; @@ -133,16 +134,14 @@ impl From for UiaaResponse { } impl EndpointError for UiaaResponse { - fn try_from_response( - response: http::Response>, + fn try_from_response( + response: http::Response, ) -> Result { if response.status() == http::StatusCode::UNAUTHORIZED { - if let Ok(authentication_info) = from_json_slice::(response.body()) { - return Ok(UiaaResponse::AuthResponse(authentication_info)); - } + Ok(UiaaResponse::AuthResponse(from_json_reader(response.into_body().reader())?)) + } else { + MatrixError::try_from_response(response).map(From::from) } - - MatrixError::try_from_response(response).map(From::from) } } @@ -383,7 +382,7 @@ mod tests { let http_response = http::Response::builder() .status(http::StatusCode::UNAUTHORIZED) - .body(json.into()) + .body(json.as_bytes()) .unwrap(); let parsed_uiaa_info = match UiaaResponse::try_from_response(http_response).unwrap() { diff --git a/ruma-client/src/lib.rs b/ruma-client/src/lib.rs index 45b25d4c..abb85e01 100644 --- a/ruma-client/src/lib.rs +++ b/ruma-client/src/lib.rs @@ -104,7 +104,6 @@ use std::{ collections::BTreeMap, - convert::TryFrom, sync::{Arc, Mutex}, time::Duration, }; @@ -374,11 +373,9 @@ impl Client { let hyper_response = client.hyper.request(http_request.map(hyper::Body::from)).await?; let (head, body) = hyper_response.into_parts(); - // FIXME: We read the response into a contiguous buffer here (not actually required for - // deserialization) and then copy the whole thing to convert from Bytes to Vec. - let full_body = hyper::body::to_bytes(body).await?; - let full_response = HttpResponse::from_parts(head, full_body.as_ref().to_owned()); + let full_body = hyper::body::aggregate(body).await?; + let full_response = HttpResponse::from_parts(head, full_body); - Ok(Request::IncomingResponse::try_from(full_response)?) + Ok(ruma_api::IncomingResponse::try_from_http_response(full_response)?) } }