api: Make EndpointError construction infallible

Simplifies error matching and preserves more information for
non-spec-compliant server errors.
This commit is contained in:
Jonas Platte 2022-11-09 19:24:16 +01:00 committed by Jonas Platte
parent 7d018897b0
commit a8ba82d585
9 changed files with 169 additions and 151 deletions

View File

@ -13,6 +13,8 @@ Breaking changes:
* Make `push::PusherKind` contain the pusher's `data` * Make `push::PusherKind` contain the pusher's `data`
* Use an enum for the `scope` of the `push` endpoints * Use an enum for the `scope` of the `push` endpoints
* Use `NewPushRule` to construct a `push::set_pushrule::v3::Request` * Use `NewPushRule` to construct a `push::set_pushrule::v3::Request`
* `Error` is now an enum because endpoint error construction is infallible (see changelog for
`ruma-common`); the previous fields are in the `Standard` variant
Improvements: Improvements:

View File

@ -1,11 +1,11 @@
//! Errors that can be sent from the homeserver. //! Errors that can be sent from the homeserver.
use std::{collections::BTreeMap, fmt, time::Duration}; use std::{collections::BTreeMap, fmt, sync::Arc, time::Duration};
use bytes::BufMut; use bytes::{BufMut, Bytes};
use ruma_common::{ use ruma_common::{
api::{ api::{
error::{DeserializationError, IntoHttpError}, error::{IntoHttpError, MatrixErrorBody},
EndpointError, OutgoingResponse, EndpointError, OutgoingResponse,
}, },
RoomVersionId, RoomVersionId,
@ -220,10 +220,37 @@ impl fmt::Display for ErrorKind {
} }
} }
/// A Matrix Error without a status code. /// The body of a Matrix Client API error.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone)]
#[allow(clippy::exhaustive_enums)]
pub enum ErrorBody {
/// A JSON body with the fields expected for Client API errors.
Standard {
/// A value which can be used to handle an error message.
kind: ErrorKind,
/// A human-readable error message, usually a sentence explaining what went wrong.
message: String,
},
/// A JSON body with an unexpected structure.
Json(JsonValue),
/// A response body that is not valid JSON.
#[non_exhaustive]
NotJson {
/// The raw bytes of the response body.
bytes: Bytes,
/// The error from trying to deserialize the bytes as JSON.
deserialization_error: Arc<serde_json::Error>,
},
}
/// A JSON body with the fields expected for Client API errors.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[allow(clippy::exhaustive_structs)] #[allow(clippy::exhaustive_structs)]
pub struct ErrorBody { pub struct StandardErrorBody {
/// A value which can be used to handle an error message. /// A value which can be used to handle an error message.
#[serde(flatten)] #[serde(flatten)]
pub kind: ErrorKind, pub kind: ErrorKind,
@ -237,73 +264,72 @@ pub struct ErrorBody {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
#[allow(clippy::exhaustive_structs)] #[allow(clippy::exhaustive_structs)]
pub struct Error { pub struct Error {
/// A value which can be used to handle an error message.
pub kind: ErrorKind,
/// A human-readable error message, usually a sentence explaining what went wrong.
pub message: String,
/// The http status code. /// The http status code.
pub status_code: http::StatusCode, pub status_code: http::StatusCode,
/// The `WWW-Authenticate` header error message. /// The `WWW-Authenticate` header error message.
#[cfg(feature = "unstable-msc2967")] #[cfg(feature = "unstable-msc2967")]
pub authenticate: Option<AuthenticateError>, pub authenticate: Option<AuthenticateError>,
/// The http response's body.
pub body: ErrorBody,
} }
impl EndpointError for Error { impl EndpointError for Error {
fn try_from_http_response<T: AsRef<[u8]>>( fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
response: http::Response<T>,
) -> Result<Self, DeserializationError> {
let status = response.status(); let status = response.status();
let error_body: ErrorBody = from_json_slice(response.body().as_ref())?;
#[cfg(not(feature = "unstable-msc2967"))]
{
Ok(error_body.into_error(status))
}
#[cfg(feature = "unstable-msc2967")] #[cfg(feature = "unstable-msc2967")]
{ let authenticate = response
use ruma_common::api::error::HeaderDeserializationError; .headers()
.get(http::header::WWW_AUTHENTICATE)
.and_then(|val| val.to_str().ok())
.and_then(AuthenticateError::from_str);
let mut error = error_body.into_error(status); let body_bytes = &response.body().as_ref();
let error_body: ErrorBody = match from_json_slice(body_bytes) {
Ok(StandardErrorBody { kind, message }) => ErrorBody::Standard { kind, message },
Err(_) => match MatrixErrorBody::from_bytes(body_bytes) {
MatrixErrorBody::Json(json) => ErrorBody::Json(json),
MatrixErrorBody::NotJson { bytes, deserialization_error, .. } => {
ErrorBody::NotJson { bytes, deserialization_error }
}
},
};
error.authenticate = response let error = error_body.into_error(status);
.headers()
.get(http::header::WWW_AUTHENTICATE)
.map(|val| val.to_str().map_err(HeaderDeserializationError::ToStrError))
.transpose()?
.and_then(AuthenticateError::from_str);
Ok(error) #[cfg(not(feature = "unstable-msc2967"))]
} return error;
#[cfg(feature = "unstable-msc2967")]
Self { authenticate, ..error }
} }
} }
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "[{} / {}] {}", self.status_code.as_u16(), self.kind, self.message) let status_code = self.status_code.as_u16();
match &self.body {
ErrorBody::Standard { kind, message } => {
write!(f, "[{status_code} / {kind}] {message}")
}
ErrorBody::Json(json) => write!(f, "[{status_code}] {json}"),
ErrorBody::NotJson { .. } => write!(f, "[{status_code}] <non-json bytes>"),
}
} }
} }
impl std::error::Error for Error {} impl std::error::Error for Error {}
impl From<Error> for ErrorBody {
fn from(error: Error) -> Self {
Self { kind: error.kind, message: error.message }
}
}
impl ErrorBody { impl ErrorBody {
/// Convert the ErrorBody into an Error by adding the http status code. /// Convert the ErrorBody into an Error by adding the http status code.
pub fn into_error(self, status_code: http::StatusCode) -> Error { pub fn into_error(self, status_code: http::StatusCode) -> Error {
Error { Error {
kind: self.kind,
message: self.message,
status_code, status_code,
#[cfg(feature = "unstable-msc2967")] #[cfg(feature = "unstable-msc2967")]
authenticate: None, authenticate: None,
body: self,
} }
} }
} }
@ -323,7 +349,19 @@ impl OutgoingResponse for Error {
builder builder
}; };
builder.body(ruma_common::serde::json_to_buf(&ErrorBody::from(self))?).map_err(Into::into) builder
.body(match self.body {
ErrorBody::Standard { kind, message } => {
ruma_common::serde::json_to_buf(&StandardErrorBody { kind, message })?
}
ErrorBody::Json(json) => ruma_common::serde::json_to_buf(&json)?,
ErrorBody::NotJson { .. } => {
return Err(IntoHttpError::Json(serde::ser::Error::custom(
"attempted to serialize ErrorBody::NotJson",
)));
}
})
.map_err(Into::into)
} }
} }
@ -426,11 +464,11 @@ impl TryFrom<&AuthenticateError> for http::HeaderValue {
mod tests { mod tests {
use serde_json::{from_value as from_json_value, json}; use serde_json::{from_value as from_json_value, json};
use super::{ErrorBody, ErrorKind}; use super::{ErrorKind, StandardErrorBody};
#[test] #[test]
fn deserialize_forbidden() { fn deserialize_forbidden() {
let deserialized: ErrorBody = from_json_value(json!({ let deserialized: StandardErrorBody = from_json_value(json!({
"errcode": "M_FORBIDDEN", "errcode": "M_FORBIDDEN",
"error": "You are not authorized to ban users in this room.", "error": "You are not authorized to ban users in this room.",
})) }))
@ -471,9 +509,10 @@ mod tests {
#[cfg(feature = "unstable-msc2967")] #[cfg(feature = "unstable-msc2967")]
#[test] #[test]
fn deserialize_insufficient_scope() { fn deserialize_insufficient_scope() {
use assert_matches::assert_matches;
use ruma_common::api::EndpointError; use ruma_common::api::EndpointError;
use super::{AuthenticateError, Error}; use super::{AuthenticateError, Error, ErrorBody};
let response = http::Response::builder() let response = http::Response::builder()
.header( .header(
@ -489,11 +528,13 @@ mod tests {
.unwrap(), .unwrap(),
) )
.unwrap(); .unwrap();
let error = Error::try_from_http_response(response).unwrap(); let error = Error::from_http_response(response);
assert_eq!(error.status_code, http::StatusCode::UNAUTHORIZED); assert_eq!(error.status_code, http::StatusCode::UNAUTHORIZED);
assert_eq!(error.kind, ErrorKind::Forbidden); let (kind, message) =
assert_eq!(error.message, "Insufficient privilege"); assert_matches!(error.body, ErrorBody::Standard { kind, message } => (kind, message));
assert_eq!(kind, ErrorKind::Forbidden);
assert_eq!(message, "Insufficient privilege");
let scope = assert_matches::assert_matches!( let scope = assert_matches::assert_matches!(
error.authenticate, error.authenticate,
Some(AuthenticateError::InsufficientScope { scope }) => scope Some(AuthenticateError::InsufficientScope { scope }) => scope

View File

@ -6,10 +6,7 @@ use std::{borrow::Cow, fmt};
use bytes::BufMut; use bytes::BufMut;
use ruma_common::{ use ruma_common::{
api::{ api::{error::IntoHttpError, EndpointError, OutgoingResponse},
error::{DeserializationError, IntoHttpError},
EndpointError, OutgoingResponse,
},
serde::{from_raw_json_value, Incoming, JsonObject, StringEnum}, serde::{from_raw_json_value, Incoming, JsonObject, StringEnum},
thirdparty::Medium, thirdparty::Medium,
ClientSecret, OwnedSessionId, OwnedUserId, UserId, ClientSecret, OwnedSessionId, OwnedUserId, UserId,
@ -23,7 +20,7 @@ use serde_json::{
}; };
use crate::{ use crate::{
error::{Error as MatrixError, ErrorBody}, error::{Error as MatrixError, StandardErrorBody},
PrivOwnedStr, PrivOwnedStr,
}; };
@ -838,7 +835,7 @@ pub struct UiaaInfo {
/// Authentication-related errors for previous request returned by homeserver. /// Authentication-related errors for previous request returned by homeserver.
#[serde(flatten, skip_serializing_if = "Option::is_none")] #[serde(flatten, skip_serializing_if = "Option::is_none")]
pub auth_error: Option<ErrorBody>, pub auth_error: Option<StandardErrorBody>,
} }
impl UiaaInfo { impl UiaaInfo {
@ -893,14 +890,14 @@ impl From<MatrixError> for UiaaResponse {
} }
impl EndpointError for UiaaResponse { impl EndpointError for UiaaResponse {
fn try_from_http_response<T: AsRef<[u8]>>( fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
response: http::Response<T>,
) -> Result<Self, DeserializationError> {
if response.status() == http::StatusCode::UNAUTHORIZED { if response.status() == http::StatusCode::UNAUTHORIZED {
Ok(UiaaResponse::AuthResponse(from_json_slice(response.body().as_ref())?)) if let Ok(uiaa_info) = from_json_slice(response.body().as_ref()) {
} else { return Self::AuthResponse(uiaa_info);
MatrixError::try_from_http_response(response).map(From::from) }
} }
Self::MatrixError(MatrixError::from_http_response(response))
} }
} }

View File

@ -205,8 +205,8 @@ fn try_uiaa_response_from_http_response() {
.unwrap(); .unwrap();
let info = assert_matches!( let info = assert_matches!(
UiaaResponse::try_from_http_response(http_response), UiaaResponse::from_http_response(http_response),
Ok(UiaaResponse::AuthResponse(info)) => info UiaaResponse::AuthResponse(info) => info
); );
assert_eq!(info.completed, vec![AuthType::ReCaptcha]); assert_eq!(info.completed, vec![AuthType::ReCaptcha]);
assert_eq!(info.flows.len(), 2); assert_eq!(info.flows.len(), 2);

View File

@ -21,6 +21,11 @@ Breaking changes:
adjusted as well to not require this field. adjusted as well to not require this field.
* Rename `push::PusherData` to `HttpPusherData` and make the `url` field required * Rename `push::PusherData` to `HttpPusherData` and make the `url` field required
* Remove `Ruleset::add` and the implementation of `Extend<AnyPushRule>` for `Ruleset` * Remove `Ruleset::add` and the implementation of `Extend<AnyPushRule>` for `Ruleset`
* Make `EndpointError` construction infallible
* `EndpointError::try_from_http_request` has been replaced by `EndpointError::from_http_request`
* `FromHttpResponseError<E>::Server` now contains `E` instead of `ServerError<E>`
* `ServerError<E>` has been removed
* `MatrixError` is now an enum with the `Json` variant containing the previous fields
Improvements: Improvements:

View File

@ -584,9 +584,7 @@ pub trait EndpointError: OutgoingResponse + StdError + Sized + Send + 'static {
/// ///
/// This will always return `Err` variant when no `error` field is defined in /// This will always return `Err` variant when no `error` field is defined in
/// the `ruma_api` macro. /// the `ruma_api` macro.
fn try_from_http_response<T: AsRef<[u8]>>( fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self;
response: http::Response<T>,
) -> Result<Self, error::DeserializationError>;
} }
/// Authentication scheme used by the endpoint. /// Authentication scheme used by the endpoint.

View File

@ -2,9 +2,9 @@
//! converting between http requests / responses and ruma's representation of //! converting between http requests / responses and ruma's representation of
//! matrix API requests / responses. //! matrix API requests / responses.
use std::{error::Error as StdError, fmt}; use std::{error::Error as StdError, fmt, sync::Arc};
use bytes::BufMut; use bytes::{BufMut, Bytes};
use serde_json::{from_slice as from_json_slice, Value as JsonValue}; use serde_json::{from_slice as from_json_slice, Value as JsonValue};
use thiserror::Error; use thiserror::Error;
@ -20,13 +20,16 @@ pub struct MatrixError {
pub status_code: http::StatusCode, pub status_code: http::StatusCode,
/// The http response's body. /// The http response's body.
pub body: JsonValue, pub body: MatrixErrorBody,
} }
impl fmt::Display for MatrixError { impl fmt::Display for MatrixError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "[{}] ", self.status_code.as_u16())?; let status_code = self.status_code.as_u16();
fmt::Display::fmt(&self.body, f) match &self.body {
MatrixErrorBody::Json(json) => write!(f, "[{status_code}] {json}"),
MatrixErrorBody::NotJson { .. } => write!(f, "[{status_code}] <non-json bytes>"),
}
} }
} }
@ -39,19 +42,54 @@ impl OutgoingResponse for MatrixError {
http::Response::builder() http::Response::builder()
.header(http::header::CONTENT_TYPE, "application/json") .header(http::header::CONTENT_TYPE, "application/json")
.status(self.status_code) .status(self.status_code)
.body(crate::serde::json_to_buf(&self.body)?) .body(match self.body {
MatrixErrorBody::Json(json) => crate::serde::json_to_buf(&json)?,
MatrixErrorBody::NotJson { .. } => {
return Err(IntoHttpError::Json(serde::ser::Error::custom(
"attempted to serialize MatrixErrorBody::NotJson",
)));
}
})
.map_err(Into::into) .map_err(Into::into)
} }
} }
impl EndpointError for MatrixError { impl EndpointError for MatrixError {
fn try_from_http_response<T: AsRef<[u8]>>( fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
response: http::Response<T>, let status_code = response.status();
) -> Result<Self, DeserializationError> { let body = MatrixErrorBody::from_bytes(response.body().as_ref());
Ok(Self { Self { status_code, body }
status_code: response.status(), }
body: from_json_slice(response.body().as_ref())?, }
})
/// The body of an error response.
#[derive(Clone, Debug)]
#[allow(clippy::exhaustive_enums)]
pub enum MatrixErrorBody {
/// A JSON body, as intended.
Json(JsonValue),
/// A response body that is not valid JSON.
#[non_exhaustive]
NotJson {
/// The raw bytes of the response body.
bytes: Bytes,
/// The error from trying to deserialize the bytes as JSON.
deserialization_error: Arc<serde_json::Error>,
},
}
impl MatrixErrorBody {
/// Create a `MatrixErrorBody` from the given HTTP body bytes.
pub fn from_bytes(body_bytes: &[u8]) -> Self {
match from_json_slice(body_bytes) {
Ok(json) => MatrixErrorBody::Json(json),
Err(e) => MatrixErrorBody::NotJson {
bytes: Bytes::copy_from_slice(body_bytes),
deserialization_error: Arc::new(e),
},
}
} }
} }
@ -131,16 +169,13 @@ pub enum FromHttpResponseError<E> {
Deserialization(DeserializationError), Deserialization(DeserializationError),
/// The server returned a non-success status /// The server returned a non-success status
Server(ServerError<E>), Server(E),
} }
impl<E> FromHttpResponseError<E> { impl<E> FromHttpResponseError<E> {
/// Map `FromHttpResponseError<E>` to `FromHttpResponseError<F>` by applying a function to a /// Map `FromHttpResponseError<E>` to `FromHttpResponseError<F>` by applying a function to a
/// contained `Server` value, leaving a `Deserialization` value untouched. /// contained `Server` value, leaving a `Deserialization` value untouched.
pub fn map<F>( pub fn map<F>(self, f: impl FnOnce(E) -> F) -> FromHttpResponseError<F> {
self,
f: impl FnOnce(ServerError<E>) -> ServerError<F>,
) -> FromHttpResponseError<F> {
match self { match self {
Self::Deserialization(d) => FromHttpResponseError::Deserialization(d), Self::Deserialization(d) => FromHttpResponseError::Deserialization(d),
Self::Server(s) => FromHttpResponseError::Server(f(s)), Self::Server(s) => FromHttpResponseError::Server(f(s)),
@ -153,7 +188,7 @@ impl<E, F> FromHttpResponseError<Result<E, F>> {
pub fn transpose(self) -> Result<FromHttpResponseError<E>, F> { pub fn transpose(self) -> Result<FromHttpResponseError<E>, F> {
match self { match self {
Self::Deserialization(d) => Ok(FromHttpResponseError::Deserialization(d)), Self::Deserialization(d) => Ok(FromHttpResponseError::Deserialization(d)),
Self::Server(s) => s.transpose().map(FromHttpResponseError::Server), Self::Server(s) => s.map(FromHttpResponseError::Server),
} }
} }
} }
@ -167,12 +202,6 @@ impl<E: fmt::Display> fmt::Display for FromHttpResponseError<E> {
} }
} }
impl<E> From<ServerError<E>> for FromHttpResponseError<E> {
fn from(err: ServerError<E>) -> Self {
Self::Server(err)
}
}
impl<E, T> From<T> for FromHttpResponseError<E> impl<E, T> From<T> for FromHttpResponseError<E>
where where
T: Into<DeserializationError>, T: Into<DeserializationError>,
@ -184,51 +213,6 @@ where
impl<E: StdError> StdError for FromHttpResponseError<E> {} impl<E: StdError> StdError for FromHttpResponseError<E> {}
/// An error was reported by the server (HTTP status code 4xx or 5xx)
#[derive(Debug)]
#[allow(clippy::exhaustive_enums)]
pub enum ServerError<E> {
/// An error that is expected to happen under certain circumstances and
/// that has a well-defined structure
Known(E),
/// An error of unexpected type of structure
Unknown(DeserializationError),
}
impl<E> ServerError<E> {
/// Map `ServerError<E>` to `ServerError<F>` by applying a function to a contained `Known`
/// value, leaving an `Unknown` value untouched.
pub fn map<F>(self, f: impl FnOnce(E) -> F) -> ServerError<F> {
match self {
Self::Known(k) => ServerError::Known(f(k)),
Self::Unknown(u) => ServerError::Unknown(u),
}
}
}
impl<E, F> ServerError<Result<E, F>> {
/// Transpose `ServerError<Result<E, F>>` to `Result<ServerError<E>, F>`.
pub fn transpose(self) -> Result<ServerError<E>, F> {
match self {
Self::Known(Ok(k)) => Ok(ServerError::Known(k)),
Self::Known(Err(e)) => Err(e),
Self::Unknown(u) => Ok(ServerError::Unknown(u)),
}
}
}
impl<E: fmt::Display> fmt::Display for ServerError<E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ServerError::Known(e) => fmt::Display::fmt(e, f),
ServerError::Unknown(res_err) => fmt::Display::fmt(res_err, f),
}
}
}
impl<E: StdError> StdError for ServerError<E> {}
/// An error when converting a http request / response to one of ruma's endpoint-specific request / /// An error when converting a http request / response to one of ruma's endpoint-specific request /
/// response types. /// response types.
#[derive(Debug, Error)] #[derive(Debug, Error)]

View File

@ -6,9 +6,7 @@ use bytes::BufMut;
use http::{header::CONTENT_TYPE, method::Method}; use http::{header::CONTENT_TYPE, method::Method};
use ruma_common::{ use ruma_common::{
api::{ api::{
error::{ error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError, MatrixError},
FromHttpRequestError, FromHttpResponseError, IntoHttpError, MatrixError, ServerError,
},
AuthScheme, EndpointError, IncomingRequest, IncomingResponse, MatrixVersion, Metadata, AuthScheme, EndpointError, IncomingRequest, IncomingResponse, MatrixVersion, Metadata,
OutgoingRequest, OutgoingResponse, SendAccessToken, VersionHistory, OutgoingRequest, OutgoingResponse, SendAccessToken, VersionHistory,
}, },
@ -119,9 +117,7 @@ impl IncomingResponse for Response {
if http_response.status().as_u16() < 400 { if http_response.status().as_u16() < 400 {
Ok(Response) Ok(Response)
} else { } else {
Err(FromHttpResponseError::Server(ServerError::Known( Err(FromHttpResponseError::Server(MatrixError::from_http_response(http_response)))
<MatrixError as EndpointError>::try_from_http_response(http_response)?,
)))
} }
} }
} }

View File

@ -124,16 +124,11 @@ impl Response {
#response_init_fields #response_init_fields
}) })
} else { } else {
match <#error_ty as #ruma_common::api::EndpointError>::try_from_http_response( Err(#ruma_common::api::error::FromHttpResponseError::Server(
response <#error_ty as #ruma_common::api::EndpointError>::from_http_response(
) { response,
::std::result::Result::Ok(err) => { )
Err(#ruma_common::api::error::ServerError::Known(err).into()) ))
}
::std::result::Result::Err(response_err) => {
Err(#ruma_common::api::error::ServerError::Unknown(response_err).into())
}
}
} }
} }
} }