api: Introduce IncomingResponse trait

This commit is contained in:
Jonas Platte 2021-04-09 15:29:39 +02:00
parent effb53444d
commit 2ac020173b
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
10 changed files with 94 additions and 73 deletions

View File

@ -28,6 +28,7 @@ impl Response {
/// Produces code for a response struct initializer. /// Produces code for a response struct initializer.
fn init_fields(&self, ruma_api: &TokenStream) -> TokenStream { fn init_fields(&self, ruma_api: &TokenStream) -> TokenStream {
let bytes = quote! { #ruma_api::exports::bytes };
let http = quote! { #ruma_api::exports::http }; let http = quote! { #ruma_api::exports::http };
let mut fields = vec![]; let mut fields = vec![];
@ -79,7 +80,13 @@ impl Response {
// We are guaranteed only one new body field because of a check in `try_from`. // We are guaranteed only one new body field because of a check in `try_from`.
ResponseField::NewtypeRawBody(_) => { ResponseField::NewtypeRawBody(_) => {
new_type_raw_body = Some(quote_spanned! {span=> new_type_raw_body = Some(quote_spanned! {span=>
#field_name: response.into_body() #field_name: {
let mut reader = #bytes::Buf::reader(response.into_body());
let mut vec = ::std::vec::Vec::new();
::std::io::Read::read_to_end(&mut reader, &mut vec)
.expect("reading from a bytes::Buf never fails");
vec
}
}); });
// skip adding to the vec // skip adding to the vec
continue; continue;
@ -194,6 +201,7 @@ impl Response {
error_ty: &TokenStream, error_ty: &TokenStream,
ruma_api: &TokenStream, ruma_api: &TokenStream,
) -> TokenStream { ) -> TokenStream {
let bytes = quote! { #ruma_api::exports::bytes };
let http = quote! { #ruma_api::exports::http }; let http = quote! { #ruma_api::exports::http };
let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
let serde = quote! { #ruma_api::exports::serde }; let serde = quote! { #ruma_api::exports::serde };
@ -218,15 +226,15 @@ impl Response {
ResponseBody ResponseBody
as #ruma_serde::Outgoing as #ruma_serde::Outgoing
>::Incoming = { >::Incoming = {
// If the reponse body is completely empty, pretend it is an empty JSON object let body = response.into_body();
// instead. This allows reponses with only optional body parameters to be if #bytes::Buf::has_remaining(&body) {
// deserialized in that case. #serde_json::from_reader(#bytes::Buf::reader(body))?
let json = match response.body().as_slice() { } else {
b"" => b"{}", // If the reponse body is completely empty, pretend it is an empty JSON
body => body, // object instead. This allows reponses with only optional body
}; // parameters to be deserialized in that case.
#serde_json::from_str("{}")?
#serde_json::from_slice(json)? }
}; };
} }
} else { } else {
@ -299,12 +307,15 @@ impl Response {
#[automatically_derived] #[automatically_derived]
#[cfg(feature = "client")] #[cfg(feature = "client")]
impl ::std::convert::TryFrom<#http::Response<Vec<u8>>> for Response { impl #ruma_api::IncomingResponse for Response {
type Error = #ruma_api::error::FromHttpResponseError<#error_ty>; type EndpointError = #error_ty;
fn try_from( fn try_from_http_response<T: #bytes::Buf>(
response: #http::Response<Vec<u8>>, response: #http::Response<T>,
) -> ::std::result::Result<Self, Self::Error> { ) -> ::std::result::Result<
Self,
#ruma_api::error::FromHttpResponseError<#error_ty>,
> {
if response.status().as_u16() < 400 { if response.status().as_u16() < 400 {
#extract_response_headers #extract_response_headers

View File

@ -19,6 +19,7 @@ all-features = true
rustdoc-args = ["--cfg", "docsrs"] rustdoc-args = ["--cfg", "docsrs"]
[dependencies] [dependencies]
bytes = "1.0.1"
http = "0.2.2" http = "0.2.2"
percent-encoding = "2.1.0" percent-encoding = "2.1.0"
ruma-api-macros = { version = "=0.17.0-alpha.2", path = "../ruma-api-macros" } ruma-api-macros = { version = "=0.17.0-alpha.2", path = "../ruma-api-macros" }

View File

@ -4,6 +4,7 @@
use std::{error::Error as StdError, fmt}; use std::{error::Error as StdError, fmt};
use bytes::Buf;
use thiserror::Error; use thiserror::Error;
use crate::EndpointError; use crate::EndpointError;
@ -14,8 +15,8 @@ use crate::EndpointError;
pub enum Void {} pub enum Void {}
impl EndpointError for Void { impl EndpointError for Void {
fn try_from_response( fn try_from_response<T: Buf>(
_response: http::Response<Vec<u8>>, _response: http::Response<T>,
) -> Result<Self, ResponseDeserializationError> { ) -> Result<Self, ResponseDeserializationError> {
Err(ResponseDeserializationError::none()) Err(ResponseDeserializationError::none())
} }
@ -121,18 +122,12 @@ impl<E> From<ServerError<E>> for FromHttpResponseError<E> {
} }
} }
impl<E> From<ResponseDeserializationError> for FromHttpResponseError<E> {
fn from(err: ResponseDeserializationError) -> Self {
Self::Deserialization(err)
}
}
impl<E, T> From<T> for FromHttpResponseError<E> impl<E, T> From<T> for FromHttpResponseError<E>
where where
T: Into<DeserializationError>, T: Into<ResponseDeserializationError>,
{ {
fn from(err: T) -> Self { fn from(err: T) -> Self {
Self::Deserialization(ResponseDeserializationError::new(err)) Self::Deserialization(err.into())
} }
} }
@ -145,17 +140,20 @@ pub struct ResponseDeserializationError {
} }
impl ResponseDeserializationError { impl ResponseDeserializationError {
/// Creates a new `ResponseDeserializationError` from the given deserialization error and http
/// response.
pub fn new(inner: impl Into<DeserializationError>) -> Self {
Self { inner: Some(inner.into()) }
}
fn none() -> Self { fn none() -> Self {
Self { inner: None } Self { inner: None }
} }
} }
impl<T> From<T> for ResponseDeserializationError
where
T: Into<DeserializationError>,
{
fn from(err: T) -> Self {
Self { inner: Some(err.into()) }
}
}
impl fmt::Display for ResponseDeserializationError { impl fmt::Display for ResponseDeserializationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(ref inner) = self.inner { if let Some(ref inner) = self.inner {

View File

@ -20,12 +20,10 @@
#[cfg(not(all(feature = "client", feature = "server")))] #[cfg(not(all(feature = "client", feature = "server")))]
compile_error!("ruma_api's Cargo features only exist as a workaround are not meant to be disabled"); compile_error!("ruma_api's Cargo features only exist as a workaround are not meant to be disabled");
use std::{ use std::{convert::TryInto, error::Error as StdError};
convert::{TryFrom, TryInto},
error::Error as StdError,
};
use http::{uri::PathAndQuery, Method}; use bytes::Buf;
use http::Method;
use ruma_identifiers::UserId; use ruma_identifiers::UserId;
/// Generates a `ruma_api::Endpoint` from a concise definition. /// Generates a `ruma_api::Endpoint` from a concise definition.
@ -206,6 +204,7 @@ pub mod error;
/// It is not considered part of ruma-api's public API. /// It is not considered part of ruma-api's public API.
#[doc(hidden)] #[doc(hidden)]
pub mod exports { pub mod exports {
pub use bytes;
pub use http; pub use http;
pub use percent_encoding; pub use percent_encoding;
pub use ruma_serde; pub use ruma_serde;
@ -221,8 +220,8 @@ pub trait EndpointError: StdError + Sized + 'static {
/// ///
/// This will always return `Err` variant when no `error` field is defined in /// This will always return `Err` variant when no `error` field is defined in
/// the `ruma_api` macro. /// the `ruma_api` macro.
fn try_from_response( fn try_from_response<T: Buf>(
response: http::Response<Vec<u8>>, response: http::Response<T>,
) -> Result<Self, error::ResponseDeserializationError>; ) -> Result<Self, error::ResponseDeserializationError>;
} }
@ -232,10 +231,7 @@ pub trait OutgoingRequest: Sized {
type EndpointError: EndpointError; type EndpointError: EndpointError;
/// Response type returned when the request is successful. /// Response type returned when the request is successful.
type IncomingResponse: TryFrom< type IncomingResponse: IncomingResponse<EndpointError = Self::EndpointError>;
http::Response<Vec<u8>>,
Error = FromHttpResponseError<Self::EndpointError>,
>;
/// Metadata about the endpoint. /// Metadata about the endpoint.
const METADATA: Metadata; const METADATA: Metadata;
@ -255,6 +251,17 @@ pub trait OutgoingRequest: Sized {
) -> Result<http::Request<Vec<u8>>, IntoHttpError>; ) -> Result<http::Request<Vec<u8>>, IntoHttpError>;
} }
/// A response type for a Matrix API endpoint, used for receiving responses.
pub trait IncomingResponse: Sized {
/// A type capturing the expected error conditions the server can return.
type EndpointError: EndpointError;
/// Tries to convert the given `http::Response` into this response type.
fn try_from_http_response<T: Buf>(
response: http::Response<T>,
) -> Result<Self, FromHttpResponseError<Self::EndpointError>>;
}
/// An extension to `OutgoingRequest` which provides Appservice specific methods /// An extension to `OutgoingRequest` which provides Appservice specific methods
pub trait OutgoingRequestAppserviceExt: OutgoingRequest { pub trait OutgoingRequestAppserviceExt: OutgoingRequest {
/// Tries to convert this request into an `http::Request` and appends a virtual `user_id` to /// Tries to convert this request into an `http::Request` and appends a virtual `user_id` to
@ -283,7 +290,7 @@ pub trait OutgoingRequestAppserviceExt: OutgoingRequest {
}; };
parts.path_and_query = parts.path_and_query =
Some(PathAndQuery::try_from(path_and_query_with_user_id).map_err(http::Error::from)?); Some(path_and_query_with_user_id.try_into().map_err(http::Error::from)?);
*http_request.uri_mut() = parts.try_into().map_err(http::Error::from)?; *http_request.uri_mut() = parts.try_into().map_err(http::Error::from)?;

View File

@ -2,10 +2,12 @@
use std::convert::TryFrom; use std::convert::TryFrom;
use bytes::Buf;
use http::{header::CONTENT_TYPE, method::Method}; use http::{header::CONTENT_TYPE, method::Method};
use ruma_api::{ use ruma_api::{
error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError, ServerError, Void}, error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError, ServerError, Void},
try_deserialize, AuthScheme, EndpointError, IncomingRequest, Metadata, OutgoingRequest, try_deserialize, AuthScheme, EndpointError, IncomingRequest, IncomingResponse, Metadata,
OutgoingRequest,
}; };
use ruma_identifiers::{RoomAliasId, RoomId}; use ruma_identifiers::{RoomAliasId, RoomId};
use ruma_serde::Outgoing; use ruma_serde::Outgoing;
@ -99,10 +101,12 @@ impl Outgoing for Response {
type Incoming = Self; type Incoming = Self;
} }
impl TryFrom<http::Response<Vec<u8>>> for Response { impl IncomingResponse for Response {
type Error = FromHttpResponseError<Void>; type EndpointError = Void;
fn try_from(http_response: http::Response<Vec<u8>>) -> Result<Response, Self::Error> { fn try_from_http_response<T: Buf>(
http_response: http::Response<T>,
) -> Result<Self, FromHttpResponseError<Void>> {
if http_response.status().as_u16() < 400 { if http_response.status().as_u16() < 400 {
Ok(Response) Ok(Response)
} else { } else {

View File

@ -1,8 +1,9 @@
use std::convert::TryFrom; use std::convert::TryFrom;
use bytes::Buf;
use ruma_api::{ use ruma_api::{
error::{FromHttpResponseError, IntoHttpError, Void}, error::{FromHttpResponseError, IntoHttpError, Void},
ruma_api, ruma_api, IncomingResponse,
}; };
use ruma_serde::Outgoing; use ruma_serde::Outgoing;
@ -26,10 +27,12 @@ ruma_api! {
#[derive(Outgoing)] #[derive(Outgoing)]
pub struct Response; pub struct Response;
impl TryFrom<http::Response<Vec<u8>>> for Response { impl IncomingResponse for Response {
type Error = FromHttpResponseError<Void>; type EndpointError = Void;
fn try_from(_: http::Response<Vec<u8>>) -> Result<Self, Self::Error> { fn try_from_http_response<T: Buf>(
_: http::Response<T>,
) -> Result<Self, FromHttpResponseError<Void>> {
todo!() todo!()
} }
} }

View File

@ -17,6 +17,7 @@ edition = "2018"
[dependencies] [dependencies]
assign = "1.1.1" assign = "1.1.1"
bytes = "1.0.1"
http = "0.2.2" http = "0.2.2"
js_int = { version = "0.2.0", features = ["serde"] } js_int = { version = "0.2.0", features = ["serde"] }
maplit = "1.0.2" maplit = "1.0.2"

View File

@ -2,10 +2,11 @@
use std::{collections::BTreeMap, fmt, time::Duration}; use std::{collections::BTreeMap, fmt, time::Duration};
use bytes::Buf;
use ruma_api::{error::ResponseDeserializationError, EndpointError}; use ruma_api::{error::ResponseDeserializationError, EndpointError};
use ruma_identifiers::RoomVersionId; use ruma_identifiers::RoomVersionId;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{from_slice as from_json_slice, to_vec as to_json_vec, Value as JsonValue}; use serde_json::{from_reader as from_json_reader, to_vec as to_json_vec, Value as JsonValue};
/// Deserialize and Serialize implementations for ErrorKind. /// Deserialize and Serialize implementations for ErrorKind.
/// Separate module because it's a lot of code. /// Separate module because it's a lot of code.
@ -202,13 +203,12 @@ pub struct Error {
} }
impl EndpointError for Error { impl EndpointError for Error {
fn try_from_response( fn try_from_response<T: Buf>(
response: http::Response<Vec<u8>>, response: http::Response<T>,
) -> Result<Self, ResponseDeserializationError> { ) -> Result<Self, ResponseDeserializationError> {
match from_json_slice::<ErrorBody>(response.body()) { let status = response.status();
Ok(error_body) => Ok(error_body.into_error(response.status())), let error_body: ErrorBody = from_json_reader(response.into_body().reader())?;
Err(de_error) => Err(ResponseDeserializationError::new(de_error)), Ok(error_body.into_error(status))
}
} }
} }

View File

@ -2,11 +2,12 @@
use std::{collections::BTreeMap, fmt}; use std::{collections::BTreeMap, fmt};
use bytes::Buf;
use ruma_api::{error::ResponseDeserializationError, EndpointError}; use ruma_api::{error::ResponseDeserializationError, EndpointError};
use ruma_serde::Outgoing; use ruma_serde::Outgoing;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{ use serde_json::{
from_slice as from_json_slice, to_vec as to_json_vec, value::RawValue as RawJsonValue, from_reader as from_json_reader, to_vec as to_json_vec, value::RawValue as RawJsonValue,
Value as JsonValue, Value as JsonValue,
}; };
@ -133,17 +134,15 @@ impl From<MatrixError> for UiaaResponse {
} }
impl EndpointError for UiaaResponse { impl EndpointError for UiaaResponse {
fn try_from_response( fn try_from_response<T: Buf>(
response: http::Response<Vec<u8>>, response: http::Response<T>,
) -> Result<Self, ResponseDeserializationError> { ) -> Result<Self, ResponseDeserializationError> {
if response.status() == http::StatusCode::UNAUTHORIZED { if response.status() == http::StatusCode::UNAUTHORIZED {
if let Ok(authentication_info) = from_json_slice::<UiaaInfo>(response.body()) { Ok(UiaaResponse::AuthResponse(from_json_reader(response.into_body().reader())?))
return Ok(UiaaResponse::AuthResponse(authentication_info)); } else {
}
}
MatrixError::try_from_response(response).map(From::from) MatrixError::try_from_response(response).map(From::from)
} }
}
} }
impl std::error::Error for UiaaResponse {} impl std::error::Error for UiaaResponse {}
@ -383,7 +382,7 @@ mod tests {
let http_response = http::Response::builder() let http_response = http::Response::builder()
.status(http::StatusCode::UNAUTHORIZED) .status(http::StatusCode::UNAUTHORIZED)
.body(json.into()) .body(json.as_bytes())
.unwrap(); .unwrap();
let parsed_uiaa_info = match UiaaResponse::try_from_response(http_response).unwrap() { let parsed_uiaa_info = match UiaaResponse::try_from_response(http_response).unwrap() {

View File

@ -104,7 +104,6 @@
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
convert::TryFrom,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
time::Duration, time::Duration,
}; };
@ -374,11 +373,9 @@ impl Client {
let hyper_response = client.hyper.request(http_request.map(hyper::Body::from)).await?; let hyper_response = client.hyper.request(http_request.map(hyper::Body::from)).await?;
let (head, body) = hyper_response.into_parts(); let (head, body) = hyper_response.into_parts();
// FIXME: We read the response into a contiguous buffer here (not actually required for let full_body = hyper::body::aggregate(body).await?;
// deserialization) and then copy the whole thing to convert from Bytes to Vec<u8>. let full_response = HttpResponse::from_parts(head, full_body);
let full_body = hyper::body::to_bytes(body).await?;
let full_response = HttpResponse::from_parts(head, full_body.as_ref().to_owned());
Ok(Request::IncomingResponse::try_from(full_response)?) Ok(ruma_api::IncomingResponse::try_from_http_response(full_response)?)
} }
} }