From 917584e0cae4ae8642625f234f22f049bc159fee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= Date: Tue, 26 Mar 2024 11:46:49 +0100 Subject: [PATCH] client-api: Move Error authenticate field to ErrorKind::Forbidden --- crates/ruma-client-api/CHANGELOG.md | 2 + crates/ruma-client-api/src/error.rs | 94 +++++++++++-------- .../ruma-client-api/src/error/kind_serde.rs | 18 +++- crates/ruma-client-api/tests/uiaa.rs | 4 +- 4 files changed, 76 insertions(+), 42 deletions(-) diff --git a/crates/ruma-client-api/CHANGELOG.md b/crates/ruma-client-api/CHANGELOG.md index 0f6f9005..317e9b85 100644 --- a/crates/ruma-client-api/CHANGELOG.md +++ b/crates/ruma-client-api/CHANGELOG.md @@ -14,6 +14,8 @@ Breaking changes: - The query parameter of `check_registration_token_validity` endpoint has been renamed from `registration_token` to `token` - `Error` is now non-exhaustive. +- `ErrorKind::Forbidden` is now a non-exhaustive struct variant that can be + constructed with `ErrorKind::forbidden()`. Improvements: diff --git a/crates/ruma-client-api/src/error.rs b/crates/ruma-client-api/src/error.rs index 1fa7bc28..99b94260 100644 --- a/crates/ruma-client-api/src/error.rs +++ b/crates/ruma-client-api/src/error.rs @@ -27,7 +27,12 @@ mod kind_serde; #[non_exhaustive] pub enum ErrorKind { /// M_FORBIDDEN - Forbidden, + #[non_exhaustive] + Forbidden { + /// The `WWW-Authenticate` header error message. + #[cfg(feature = "unstable-msc2967")] + authenticate: Option, + }, /// M_UNKNOWN_TOKEN UnknownToken { @@ -192,6 +197,23 @@ pub enum ErrorKind { _Custom { errcode: PrivOwnedStr, extra: Extra }, } +impl ErrorKind { + /// Constructs an empty [`ErrorKind::Forbidden`] variant. + pub fn forbidden() -> Self { + Self::Forbidden { + #[cfg(feature = "unstable-msc2967")] + authenticate: None, + } + } + + /// Constructs an [`ErrorKind::Forbidden`] variant with the given `WWW-Authenticate` header + /// error message. + #[cfg(feature = "unstable-msc2967")] + pub fn forbidden_with_authenticate(authenticate: AuthenticateError) -> Self { + Self::Forbidden { authenticate: Some(authenticate) } + } +} + #[doc(hidden)] #[derive(Clone, Debug, PartialEq, Eq)] pub struct Extra(BTreeMap); @@ -199,7 +221,7 @@ pub struct Extra(BTreeMap); impl AsRef for ErrorKind { fn as_ref(&self) -> &str { match self { - Self::Forbidden => "M_FORBIDDEN", + Self::Forbidden { .. } => "M_FORBIDDEN", Self::UnknownToken { .. } => "M_UNKNOWN_TOKEN", Self::MissingToken => "M_MISSING_TOKEN", Self::BadJson => "M_BAD_JSON", @@ -303,10 +325,6 @@ pub struct Error { /// The http status code. pub status_code: http::StatusCode, - /// The `WWW-Authenticate` header error message. - #[cfg(feature = "unstable-msc2967")] - pub authenticate: Option, - /// The http response's body. pub body: ErrorBody, } @@ -316,12 +334,7 @@ impl Error { /// /// This is equivalent to calling `body.into_error(status_code)`. pub fn new(status_code: http::StatusCode, body: ErrorBody) -> Self { - Self { - status_code, - #[cfg(feature = "unstable-msc2967")] - authenticate: None, - body, - } + Self { status_code, body } } /// If `self` is a server error in the `errcode` + `error` format expected @@ -335,16 +348,24 @@ impl EndpointError for Error { fn from_http_response>(response: http::Response) -> Self { let status = response.status(); - #[cfg(feature = "unstable-msc2967")] - let authenticate = response - .headers() - .get(http::header::WWW_AUTHENTICATE) - .and_then(|val| val.to_str().ok()) - .and_then(AuthenticateError::from_str); - let body_bytes = &response.body().as_ref(); let error_body: ErrorBody = match from_json_slice(body_bytes) { - Ok(StandardErrorBody { kind, message }) => ErrorBody::Standard { kind, message }, + Ok(StandardErrorBody { kind, message }) => { + #[cfg(feature = "unstable-msc2967")] + let kind = if let ErrorKind::Forbidden { .. } = kind { + let authenticate = response + .headers() + .get(http::header::WWW_AUTHENTICATE) + .and_then(|val| val.to_str().ok()) + .and_then(AuthenticateError::from_str); + + ErrorKind::Forbidden { authenticate } + } else { + kind + }; + + ErrorBody::Standard { kind, message } + } Err(_) => match MatrixErrorBody::from_bytes(body_bytes) { MatrixErrorBody::Json(json) => ErrorBody::Json(json), MatrixErrorBody::NotJson { bytes, deserialization_error, .. } => { @@ -353,13 +374,7 @@ impl EndpointError for Error { }, }; - let error = error_body.into_error(status); - - #[cfg(not(feature = "unstable-msc2967"))] - return error; - - #[cfg(feature = "unstable-msc2967")] - Self { authenticate, ..error } + error_body.into_error(status) } } @@ -383,12 +398,7 @@ impl ErrorBody { /// /// This is equivalent to calling `Error::new(status_code, self)`. pub fn into_error(self, status_code: http::StatusCode) -> Error { - Error { - status_code, - #[cfg(feature = "unstable-msc2967")] - authenticate: None, - body: self, - } + Error { status_code, body: self } } } @@ -401,7 +411,11 @@ impl OutgoingResponse for Error { .status(self.status_code); #[cfg(feature = "unstable-msc2967")] - let builder = if let Some(auth_error) = &self.authenticate { + let builder = if let ErrorBody::Standard { + kind: ErrorKind::Forbidden { authenticate: Some(auth_error) }, + .. + } = &self.body + { builder.header(http::header::WWW_AUTHENTICATE, auth_error) } else { builder @@ -546,7 +560,13 @@ mod tests { })) .unwrap(); - assert_eq!(deserialized.kind, ErrorKind::Forbidden); + assert_eq!( + deserialized.kind, + ErrorKind::Forbidden { + #[cfg(feature = "unstable-msc2967")] + authenticate: None + } + ); assert_eq!(deserialized.message, "You are not authorized to ban users in this room."); } @@ -617,9 +637,9 @@ mod tests { assert_eq!(error.status_code, http::StatusCode::UNAUTHORIZED); assert_matches!(error.body, ErrorBody::Standard { kind, message }); - assert_eq!(kind, ErrorKind::Forbidden); + assert_matches!(kind, ErrorKind::Forbidden { authenticate }); assert_eq!(message, "Insufficient privilege"); - assert_matches!(error.authenticate, Some(AuthenticateError::InsufficientScope { scope })); + assert_matches!(authenticate, Some(AuthenticateError::InsufficientScope { scope })); assert_eq!(scope, "something_privileged"); } } diff --git a/crates/ruma-client-api/src/error/kind_serde.rs b/crates/ruma-client-api/src/error/kind_serde.rs index b5a3d8c8..db27a5fd 100644 --- a/crates/ruma-client-api/src/error/kind_serde.rs +++ b/crates/ruma-client-api/src/error/kind_serde.rs @@ -165,7 +165,7 @@ impl<'de> Visitor<'de> for ErrorKindVisitor { let extra = Extra(extra); Ok(match errcode { - ErrCode::Forbidden => ErrorKind::Forbidden, + ErrCode::Forbidden => ErrorKind::forbidden(), ErrCode::UnknownToken => ErrorKind::UnknownToken { soft_logout: soft_logout .map(from_json_value) @@ -361,7 +361,13 @@ mod tests { #[test] fn deserialize_forbidden() { let deserialized: ErrorKind = from_json_value(json!({ "errcode": "M_FORBIDDEN" })).unwrap(); - assert_eq!(deserialized, ErrorKind::Forbidden); + assert_eq!( + deserialized, + ErrorKind::Forbidden { + #[cfg(feature = "unstable-msc2967")] + authenticate: None + } + ); } #[test] @@ -372,7 +378,13 @@ mod tests { })) .unwrap(); - assert_eq!(deserialized, ErrorKind::Forbidden); + assert_eq!( + deserialized, + ErrorKind::Forbidden { + #[cfg(feature = "unstable-msc2967")] + authenticate: None + } + ); } #[test] diff --git a/crates/ruma-client-api/tests/uiaa.rs b/crates/ruma-client-api/tests/uiaa.rs index efe521cc..1a2b0674 100644 --- a/crates/ruma-client-api/tests/uiaa.rs +++ b/crates/ruma-client-api/tests/uiaa.rs @@ -125,7 +125,7 @@ fn deserialize_uiaa_info() { assert_eq!(info.flows[1].stages, vec![AuthType::EmailIdentity, AuthType::Msisdn]); assert_eq!(info.session.as_deref(), Some("xxxxxx")); let auth_error = info.auth_error.unwrap(); - assert_eq!(auth_error.kind, ErrorKind::Forbidden); + assert_matches!(auth_error.kind, ErrorKind::Forbidden { .. }); assert_eq!(auth_error.message, "Invalid password"); assert_eq!( from_json_str::(info.params.get()).unwrap(), @@ -207,7 +207,7 @@ fn try_uiaa_response_from_http_response() { assert_eq!(info.flows[1].stages, vec![AuthType::EmailIdentity, AuthType::Msisdn]); assert_eq!(info.session.as_deref(), Some("xxxxxx")); let auth_error = info.auth_error.unwrap(); - assert_eq!(auth_error.kind, ErrorKind::Forbidden); + assert_matches!(auth_error.kind, ErrorKind::Forbidden { .. }); assert_eq!(auth_error.message, "Invalid password"); assert_eq!( from_json_str::(info.params.get()).unwrap(),