From a8ba82d5859b42293c78c58b04840b67f7e77ef4 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Wed, 9 Nov 2022 19:24:16 +0100 Subject: [PATCH] api: Make EndpointError construction infallible Simplifies error matching and preserves more information for non-spec-compliant server errors. --- crates/ruma-client-api/CHANGELOG.md | 2 + crates/ruma-client-api/src/error.rs | 137 ++++++++++++------ crates/ruma-client-api/src/uiaa.rs | 21 ++- crates/ruma-client-api/tests/uiaa.rs | 4 +- crates/ruma-common/CHANGELOG.md | 5 + crates/ruma-common/src/api.rs | 4 +- crates/ruma-common/src/api/error.rs | 124 +++++++--------- .../tests/api/manual_endpoint_impl.rs | 8 +- .../ruma-macros/src/api/response/incoming.rs | 15 +- 9 files changed, 169 insertions(+), 151 deletions(-) diff --git a/crates/ruma-client-api/CHANGELOG.md b/crates/ruma-client-api/CHANGELOG.md index da174e2a..a97481de 100644 --- a/crates/ruma-client-api/CHANGELOG.md +++ b/crates/ruma-client-api/CHANGELOG.md @@ -13,6 +13,8 @@ Breaking changes: * Make `push::PusherKind` contain the pusher's `data` * Use an enum for the `scope` of the `push` endpoints * Use `NewPushRule` to construct a `push::set_pushrule::v3::Request` +* `Error` is now an enum because endpoint error construction is infallible (see changelog for + `ruma-common`); the previous fields are in the `Standard` variant Improvements: diff --git a/crates/ruma-client-api/src/error.rs b/crates/ruma-client-api/src/error.rs index b26db667..c199a210 100644 --- a/crates/ruma-client-api/src/error.rs +++ b/crates/ruma-client-api/src/error.rs @@ -1,11 +1,11 @@ //! Errors that can be sent from the homeserver. -use std::{collections::BTreeMap, fmt, time::Duration}; +use std::{collections::BTreeMap, fmt, sync::Arc, time::Duration}; -use bytes::BufMut; +use bytes::{BufMut, Bytes}; use ruma_common::{ api::{ - error::{DeserializationError, IntoHttpError}, + error::{IntoHttpError, MatrixErrorBody}, EndpointError, OutgoingResponse, }, RoomVersionId, @@ -220,10 +220,37 @@ impl fmt::Display for ErrorKind { } } -/// A Matrix Error without a status code. -#[derive(Debug, Clone, Serialize, Deserialize)] +/// The body of a Matrix Client API error. +#[derive(Debug, Clone)] +#[allow(clippy::exhaustive_enums)] +pub enum ErrorBody { + /// A JSON body with the fields expected for Client API errors. + Standard { + /// A value which can be used to handle an error message. + kind: ErrorKind, + + /// A human-readable error message, usually a sentence explaining what went wrong. + message: String, + }, + + /// A JSON body with an unexpected structure. + Json(JsonValue), + + /// A response body that is not valid JSON. + #[non_exhaustive] + NotJson { + /// The raw bytes of the response body. + bytes: Bytes, + + /// The error from trying to deserialize the bytes as JSON. + deserialization_error: Arc, + }, +} + +/// A JSON body with the fields expected for Client API errors. +#[derive(Clone, Debug, Deserialize, Serialize)] #[allow(clippy::exhaustive_structs)] -pub struct ErrorBody { +pub struct StandardErrorBody { /// A value which can be used to handle an error message. #[serde(flatten)] pub kind: ErrorKind, @@ -237,73 +264,72 @@ pub struct ErrorBody { #[derive(Debug, Clone)] #[allow(clippy::exhaustive_structs)] pub struct Error { - /// A value which can be used to handle an error message. - pub kind: ErrorKind, - - /// A human-readable error message, usually a sentence explaining what went wrong. - pub message: String, - /// The http status code. pub status_code: http::StatusCode, /// The `WWW-Authenticate` header error message. #[cfg(feature = "unstable-msc2967")] pub authenticate: Option, + + /// The http response's body. + pub body: ErrorBody, } impl EndpointError for Error { - fn try_from_http_response>( - response: http::Response, - ) -> Result { + fn from_http_response>(response: http::Response) -> Self { let status = response.status(); - let error_body: ErrorBody = from_json_slice(response.body().as_ref())?; - - #[cfg(not(feature = "unstable-msc2967"))] - { - Ok(error_body.into_error(status)) - } #[cfg(feature = "unstable-msc2967")] - { - use ruma_common::api::error::HeaderDeserializationError; + let authenticate = response + .headers() + .get(http::header::WWW_AUTHENTICATE) + .and_then(|val| val.to_str().ok()) + .and_then(AuthenticateError::from_str); - let mut error = error_body.into_error(status); + let body_bytes = &response.body().as_ref(); + let error_body: ErrorBody = match from_json_slice(body_bytes) { + Ok(StandardErrorBody { kind, message }) => ErrorBody::Standard { kind, message }, + Err(_) => match MatrixErrorBody::from_bytes(body_bytes) { + MatrixErrorBody::Json(json) => ErrorBody::Json(json), + MatrixErrorBody::NotJson { bytes, deserialization_error, .. } => { + ErrorBody::NotJson { bytes, deserialization_error } + } + }, + }; - error.authenticate = response - .headers() - .get(http::header::WWW_AUTHENTICATE) - .map(|val| val.to_str().map_err(HeaderDeserializationError::ToStrError)) - .transpose()? - .and_then(AuthenticateError::from_str); + let error = error_body.into_error(status); - Ok(error) - } + #[cfg(not(feature = "unstable-msc2967"))] + return error; + + #[cfg(feature = "unstable-msc2967")] + Self { authenticate, ..error } } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "[{} / {}] {}", self.status_code.as_u16(), self.kind, self.message) + let status_code = self.status_code.as_u16(); + match &self.body { + ErrorBody::Standard { kind, message } => { + write!(f, "[{status_code} / {kind}] {message}") + } + ErrorBody::Json(json) => write!(f, "[{status_code}] {json}"), + ErrorBody::NotJson { .. } => write!(f, "[{status_code}] "), + } } } impl std::error::Error for Error {} -impl From for ErrorBody { - fn from(error: Error) -> Self { - Self { kind: error.kind, message: error.message } - } -} - impl ErrorBody { /// Convert the ErrorBody into an Error by adding the http status code. pub fn into_error(self, status_code: http::StatusCode) -> Error { Error { - kind: self.kind, - message: self.message, status_code, #[cfg(feature = "unstable-msc2967")] authenticate: None, + body: self, } } } @@ -323,7 +349,19 @@ impl OutgoingResponse for Error { builder }; - builder.body(ruma_common::serde::json_to_buf(&ErrorBody::from(self))?).map_err(Into::into) + builder + .body(match self.body { + ErrorBody::Standard { kind, message } => { + ruma_common::serde::json_to_buf(&StandardErrorBody { kind, message })? + } + ErrorBody::Json(json) => ruma_common::serde::json_to_buf(&json)?, + ErrorBody::NotJson { .. } => { + return Err(IntoHttpError::Json(serde::ser::Error::custom( + "attempted to serialize ErrorBody::NotJson", + ))); + } + }) + .map_err(Into::into) } } @@ -426,11 +464,11 @@ impl TryFrom<&AuthenticateError> for http::HeaderValue { mod tests { use serde_json::{from_value as from_json_value, json}; - use super::{ErrorBody, ErrorKind}; + use super::{ErrorKind, StandardErrorBody}; #[test] fn deserialize_forbidden() { - let deserialized: ErrorBody = from_json_value(json!({ + let deserialized: StandardErrorBody = from_json_value(json!({ "errcode": "M_FORBIDDEN", "error": "You are not authorized to ban users in this room.", })) @@ -471,9 +509,10 @@ mod tests { #[cfg(feature = "unstable-msc2967")] #[test] fn deserialize_insufficient_scope() { + use assert_matches::assert_matches; use ruma_common::api::EndpointError; - use super::{AuthenticateError, Error}; + use super::{AuthenticateError, Error, ErrorBody}; let response = http::Response::builder() .header( @@ -489,11 +528,13 @@ mod tests { .unwrap(), ) .unwrap(); - let error = Error::try_from_http_response(response).unwrap(); + let error = Error::from_http_response(response); assert_eq!(error.status_code, http::StatusCode::UNAUTHORIZED); - assert_eq!(error.kind, ErrorKind::Forbidden); - assert_eq!(error.message, "Insufficient privilege"); + let (kind, message) = + assert_matches!(error.body, ErrorBody::Standard { kind, message } => (kind, message)); + assert_eq!(kind, ErrorKind::Forbidden); + assert_eq!(message, "Insufficient privilege"); let scope = assert_matches::assert_matches!( error.authenticate, Some(AuthenticateError::InsufficientScope { scope }) => scope diff --git a/crates/ruma-client-api/src/uiaa.rs b/crates/ruma-client-api/src/uiaa.rs index 9aced0e5..b4061eeb 100644 --- a/crates/ruma-client-api/src/uiaa.rs +++ b/crates/ruma-client-api/src/uiaa.rs @@ -6,10 +6,7 @@ use std::{borrow::Cow, fmt}; use bytes::BufMut; use ruma_common::{ - api::{ - error::{DeserializationError, IntoHttpError}, - EndpointError, OutgoingResponse, - }, + api::{error::IntoHttpError, EndpointError, OutgoingResponse}, serde::{from_raw_json_value, Incoming, JsonObject, StringEnum}, thirdparty::Medium, ClientSecret, OwnedSessionId, OwnedUserId, UserId, @@ -23,7 +20,7 @@ use serde_json::{ }; use crate::{ - error::{Error as MatrixError, ErrorBody}, + error::{Error as MatrixError, StandardErrorBody}, PrivOwnedStr, }; @@ -838,7 +835,7 @@ pub struct UiaaInfo { /// Authentication-related errors for previous request returned by homeserver. #[serde(flatten, skip_serializing_if = "Option::is_none")] - pub auth_error: Option, + pub auth_error: Option, } impl UiaaInfo { @@ -893,14 +890,14 @@ impl From for UiaaResponse { } impl EndpointError for UiaaResponse { - fn try_from_http_response>( - response: http::Response, - ) -> Result { + fn from_http_response>(response: http::Response) -> Self { if response.status() == http::StatusCode::UNAUTHORIZED { - Ok(UiaaResponse::AuthResponse(from_json_slice(response.body().as_ref())?)) - } else { - MatrixError::try_from_http_response(response).map(From::from) + if let Ok(uiaa_info) = from_json_slice(response.body().as_ref()) { + return Self::AuthResponse(uiaa_info); + } } + + Self::MatrixError(MatrixError::from_http_response(response)) } } diff --git a/crates/ruma-client-api/tests/uiaa.rs b/crates/ruma-client-api/tests/uiaa.rs index 40aef15c..c34a2e20 100644 --- a/crates/ruma-client-api/tests/uiaa.rs +++ b/crates/ruma-client-api/tests/uiaa.rs @@ -205,8 +205,8 @@ fn try_uiaa_response_from_http_response() { .unwrap(); let info = assert_matches!( - UiaaResponse::try_from_http_response(http_response), - Ok(UiaaResponse::AuthResponse(info)) => info + UiaaResponse::from_http_response(http_response), + UiaaResponse::AuthResponse(info) => info ); assert_eq!(info.completed, vec![AuthType::ReCaptcha]); assert_eq!(info.flows.len(), 2); diff --git a/crates/ruma-common/CHANGELOG.md b/crates/ruma-common/CHANGELOG.md index 32a29344..b1c6270b 100644 --- a/crates/ruma-common/CHANGELOG.md +++ b/crates/ruma-common/CHANGELOG.md @@ -21,6 +21,11 @@ Breaking changes: adjusted as well to not require this field. * Rename `push::PusherData` to `HttpPusherData` and make the `url` field required * Remove `Ruleset::add` and the implementation of `Extend` for `Ruleset` +* Make `EndpointError` construction infallible + * `EndpointError::try_from_http_request` has been replaced by `EndpointError::from_http_request` + * `FromHttpResponseError::Server` now contains `E` instead of `ServerError` + * `ServerError` has been removed + * `MatrixError` is now an enum with the `Json` variant containing the previous fields Improvements: diff --git a/crates/ruma-common/src/api.rs b/crates/ruma-common/src/api.rs index 01d622a0..5e32864d 100644 --- a/crates/ruma-common/src/api.rs +++ b/crates/ruma-common/src/api.rs @@ -584,9 +584,7 @@ pub trait EndpointError: OutgoingResponse + StdError + Sized + Send + 'static { /// /// This will always return `Err` variant when no `error` field is defined in /// the `ruma_api` macro. - fn try_from_http_response>( - response: http::Response, - ) -> Result; + fn from_http_response>(response: http::Response) -> Self; } /// Authentication scheme used by the endpoint. diff --git a/crates/ruma-common/src/api/error.rs b/crates/ruma-common/src/api/error.rs index 3f3ae0b8..1e38f439 100644 --- a/crates/ruma-common/src/api/error.rs +++ b/crates/ruma-common/src/api/error.rs @@ -2,9 +2,9 @@ //! converting between http requests / responses and ruma's representation of //! matrix API requests / responses. -use std::{error::Error as StdError, fmt}; +use std::{error::Error as StdError, fmt, sync::Arc}; -use bytes::BufMut; +use bytes::{BufMut, Bytes}; use serde_json::{from_slice as from_json_slice, Value as JsonValue}; use thiserror::Error; @@ -20,13 +20,16 @@ pub struct MatrixError { pub status_code: http::StatusCode, /// The http response's body. - pub body: JsonValue, + pub body: MatrixErrorBody, } impl fmt::Display for MatrixError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "[{}] ", self.status_code.as_u16())?; - fmt::Display::fmt(&self.body, f) + let status_code = self.status_code.as_u16(); + match &self.body { + MatrixErrorBody::Json(json) => write!(f, "[{status_code}] {json}"), + MatrixErrorBody::NotJson { .. } => write!(f, "[{status_code}] "), + } } } @@ -39,19 +42,54 @@ impl OutgoingResponse for MatrixError { http::Response::builder() .header(http::header::CONTENT_TYPE, "application/json") .status(self.status_code) - .body(crate::serde::json_to_buf(&self.body)?) + .body(match self.body { + MatrixErrorBody::Json(json) => crate::serde::json_to_buf(&json)?, + MatrixErrorBody::NotJson { .. } => { + return Err(IntoHttpError::Json(serde::ser::Error::custom( + "attempted to serialize MatrixErrorBody::NotJson", + ))); + } + }) .map_err(Into::into) } } impl EndpointError for MatrixError { - fn try_from_http_response>( - response: http::Response, - ) -> Result { - Ok(Self { - status_code: response.status(), - body: from_json_slice(response.body().as_ref())?, - }) + fn from_http_response>(response: http::Response) -> Self { + let status_code = response.status(); + let body = MatrixErrorBody::from_bytes(response.body().as_ref()); + Self { status_code, body } + } +} + +/// The body of an error response. +#[derive(Clone, Debug)] +#[allow(clippy::exhaustive_enums)] +pub enum MatrixErrorBody { + /// A JSON body, as intended. + Json(JsonValue), + + /// A response body that is not valid JSON. + #[non_exhaustive] + NotJson { + /// The raw bytes of the response body. + bytes: Bytes, + + /// The error from trying to deserialize the bytes as JSON. + deserialization_error: Arc, + }, +} + +impl MatrixErrorBody { + /// Create a `MatrixErrorBody` from the given HTTP body bytes. + pub fn from_bytes(body_bytes: &[u8]) -> Self { + match from_json_slice(body_bytes) { + Ok(json) => MatrixErrorBody::Json(json), + Err(e) => MatrixErrorBody::NotJson { + bytes: Bytes::copy_from_slice(body_bytes), + deserialization_error: Arc::new(e), + }, + } } } @@ -131,16 +169,13 @@ pub enum FromHttpResponseError { Deserialization(DeserializationError), /// The server returned a non-success status - Server(ServerError), + Server(E), } impl FromHttpResponseError { /// Map `FromHttpResponseError` to `FromHttpResponseError` by applying a function to a /// contained `Server` value, leaving a `Deserialization` value untouched. - pub fn map( - self, - f: impl FnOnce(ServerError) -> ServerError, - ) -> FromHttpResponseError { + pub fn map(self, f: impl FnOnce(E) -> F) -> FromHttpResponseError { match self { Self::Deserialization(d) => FromHttpResponseError::Deserialization(d), Self::Server(s) => FromHttpResponseError::Server(f(s)), @@ -153,7 +188,7 @@ impl FromHttpResponseError> { pub fn transpose(self) -> Result, F> { match self { Self::Deserialization(d) => Ok(FromHttpResponseError::Deserialization(d)), - Self::Server(s) => s.transpose().map(FromHttpResponseError::Server), + Self::Server(s) => s.map(FromHttpResponseError::Server), } } } @@ -167,12 +202,6 @@ impl fmt::Display for FromHttpResponseError { } } -impl From> for FromHttpResponseError { - fn from(err: ServerError) -> Self { - Self::Server(err) - } -} - impl From for FromHttpResponseError where T: Into, @@ -184,51 +213,6 @@ where impl StdError for FromHttpResponseError {} -/// An error was reported by the server (HTTP status code 4xx or 5xx) -#[derive(Debug)] -#[allow(clippy::exhaustive_enums)] -pub enum ServerError { - /// An error that is expected to happen under certain circumstances and - /// that has a well-defined structure - Known(E), - - /// An error of unexpected type of structure - Unknown(DeserializationError), -} - -impl ServerError { - /// Map `ServerError` to `ServerError` by applying a function to a contained `Known` - /// value, leaving an `Unknown` value untouched. - pub fn map(self, f: impl FnOnce(E) -> F) -> ServerError { - match self { - Self::Known(k) => ServerError::Known(f(k)), - Self::Unknown(u) => ServerError::Unknown(u), - } - } -} - -impl ServerError> { - /// Transpose `ServerError>` to `Result, F>`. - pub fn transpose(self) -> Result, F> { - match self { - Self::Known(Ok(k)) => Ok(ServerError::Known(k)), - Self::Known(Err(e)) => Err(e), - Self::Unknown(u) => Ok(ServerError::Unknown(u)), - } - } -} - -impl fmt::Display for ServerError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ServerError::Known(e) => fmt::Display::fmt(e, f), - ServerError::Unknown(res_err) => fmt::Display::fmt(res_err, f), - } - } -} - -impl StdError for ServerError {} - /// An error when converting a http request / response to one of ruma's endpoint-specific request / /// response types. #[derive(Debug, Error)] diff --git a/crates/ruma-common/tests/api/manual_endpoint_impl.rs b/crates/ruma-common/tests/api/manual_endpoint_impl.rs index 34b4cbd9..b4ee3997 100644 --- a/crates/ruma-common/tests/api/manual_endpoint_impl.rs +++ b/crates/ruma-common/tests/api/manual_endpoint_impl.rs @@ -6,9 +6,7 @@ use bytes::BufMut; use http::{header::CONTENT_TYPE, method::Method}; use ruma_common::{ api::{ - error::{ - FromHttpRequestError, FromHttpResponseError, IntoHttpError, MatrixError, ServerError, - }, + error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError, MatrixError}, AuthScheme, EndpointError, IncomingRequest, IncomingResponse, MatrixVersion, Metadata, OutgoingRequest, OutgoingResponse, SendAccessToken, VersionHistory, }, @@ -119,9 +117,7 @@ impl IncomingResponse for Response { if http_response.status().as_u16() < 400 { Ok(Response) } else { - Err(FromHttpResponseError::Server(ServerError::Known( - ::try_from_http_response(http_response)?, - ))) + Err(FromHttpResponseError::Server(MatrixError::from_http_response(http_response))) } } } diff --git a/crates/ruma-macros/src/api/response/incoming.rs b/crates/ruma-macros/src/api/response/incoming.rs index 59260c53..1b0b8890 100644 --- a/crates/ruma-macros/src/api/response/incoming.rs +++ b/crates/ruma-macros/src/api/response/incoming.rs @@ -124,16 +124,11 @@ impl Response { #response_init_fields }) } else { - match <#error_ty as #ruma_common::api::EndpointError>::try_from_http_response( - response - ) { - ::std::result::Result::Ok(err) => { - Err(#ruma_common::api::error::ServerError::Known(err).into()) - } - ::std::result::Result::Err(response_err) => { - Err(#ruma_common::api::error::ServerError::Unknown(response_err).into()) - } - } + Err(#ruma_common::api::error::FromHttpResponseError::Server( + <#error_ty as #ruma_common::api::EndpointError>::from_http_response( + response, + ) + )) } } }