client-api: Move Error authenticate field to ErrorKind::Forbidden

This commit is contained in:
Kévin Commaille 2024-03-26 11:46:49 +01:00 committed by Kévin Commaille
parent 4f4085a013
commit 917584e0ca
4 changed files with 76 additions and 42 deletions

View File

@ -14,6 +14,8 @@ Breaking changes:
- The query parameter of `check_registration_token_validity` endpoint - The query parameter of `check_registration_token_validity` endpoint
has been renamed from `registration_token` to `token` has been renamed from `registration_token` to `token`
- `Error` is now non-exhaustive. - `Error` is now non-exhaustive.
- `ErrorKind::Forbidden` is now a non-exhaustive struct variant that can be
constructed with `ErrorKind::forbidden()`.
Improvements: Improvements:

View File

@ -27,7 +27,12 @@ mod kind_serde;
#[non_exhaustive] #[non_exhaustive]
pub enum ErrorKind { pub enum ErrorKind {
/// M_FORBIDDEN /// M_FORBIDDEN
Forbidden, #[non_exhaustive]
Forbidden {
/// The `WWW-Authenticate` header error message.
#[cfg(feature = "unstable-msc2967")]
authenticate: Option<AuthenticateError>,
},
/// M_UNKNOWN_TOKEN /// M_UNKNOWN_TOKEN
UnknownToken { UnknownToken {
@ -192,6 +197,23 @@ pub enum ErrorKind {
_Custom { errcode: PrivOwnedStr, extra: Extra }, _Custom { errcode: PrivOwnedStr, extra: Extra },
} }
impl ErrorKind {
/// Constructs an empty [`ErrorKind::Forbidden`] variant.
pub fn forbidden() -> Self {
Self::Forbidden {
#[cfg(feature = "unstable-msc2967")]
authenticate: None,
}
}
/// Constructs an [`ErrorKind::Forbidden`] variant with the given `WWW-Authenticate` header
/// error message.
#[cfg(feature = "unstable-msc2967")]
pub fn forbidden_with_authenticate(authenticate: AuthenticateError) -> Self {
Self::Forbidden { authenticate: Some(authenticate) }
}
}
#[doc(hidden)] #[doc(hidden)]
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct Extra(BTreeMap<String, JsonValue>); pub struct Extra(BTreeMap<String, JsonValue>);
@ -199,7 +221,7 @@ pub struct Extra(BTreeMap<String, JsonValue>);
impl AsRef<str> for ErrorKind { impl AsRef<str> for ErrorKind {
fn as_ref(&self) -> &str { fn as_ref(&self) -> &str {
match self { match self {
Self::Forbidden => "M_FORBIDDEN", Self::Forbidden { .. } => "M_FORBIDDEN",
Self::UnknownToken { .. } => "M_UNKNOWN_TOKEN", Self::UnknownToken { .. } => "M_UNKNOWN_TOKEN",
Self::MissingToken => "M_MISSING_TOKEN", Self::MissingToken => "M_MISSING_TOKEN",
Self::BadJson => "M_BAD_JSON", Self::BadJson => "M_BAD_JSON",
@ -303,10 +325,6 @@ pub struct Error {
/// The http status code. /// The http status code.
pub status_code: http::StatusCode, pub status_code: http::StatusCode,
/// The `WWW-Authenticate` header error message.
#[cfg(feature = "unstable-msc2967")]
pub authenticate: Option<AuthenticateError>,
/// The http response's body. /// The http response's body.
pub body: ErrorBody, pub body: ErrorBody,
} }
@ -316,12 +334,7 @@ impl Error {
/// ///
/// This is equivalent to calling `body.into_error(status_code)`. /// This is equivalent to calling `body.into_error(status_code)`.
pub fn new(status_code: http::StatusCode, body: ErrorBody) -> Self { pub fn new(status_code: http::StatusCode, body: ErrorBody) -> Self {
Self { Self { status_code, body }
status_code,
#[cfg(feature = "unstable-msc2967")]
authenticate: None,
body,
}
} }
/// If `self` is a server error in the `errcode` + `error` format expected /// If `self` is a server error in the `errcode` + `error` format expected
@ -335,16 +348,24 @@ impl EndpointError for Error {
fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self { fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
let status = response.status(); let status = response.status();
#[cfg(feature = "unstable-msc2967")]
let authenticate = response
.headers()
.get(http::header::WWW_AUTHENTICATE)
.and_then(|val| val.to_str().ok())
.and_then(AuthenticateError::from_str);
let body_bytes = &response.body().as_ref(); let body_bytes = &response.body().as_ref();
let error_body: ErrorBody = match from_json_slice(body_bytes) { let error_body: ErrorBody = match from_json_slice(body_bytes) {
Ok(StandardErrorBody { kind, message }) => ErrorBody::Standard { kind, message }, Ok(StandardErrorBody { kind, message }) => {
#[cfg(feature = "unstable-msc2967")]
let kind = if let ErrorKind::Forbidden { .. } = kind {
let authenticate = response
.headers()
.get(http::header::WWW_AUTHENTICATE)
.and_then(|val| val.to_str().ok())
.and_then(AuthenticateError::from_str);
ErrorKind::Forbidden { authenticate }
} else {
kind
};
ErrorBody::Standard { kind, message }
}
Err(_) => match MatrixErrorBody::from_bytes(body_bytes) { Err(_) => match MatrixErrorBody::from_bytes(body_bytes) {
MatrixErrorBody::Json(json) => ErrorBody::Json(json), MatrixErrorBody::Json(json) => ErrorBody::Json(json),
MatrixErrorBody::NotJson { bytes, deserialization_error, .. } => { MatrixErrorBody::NotJson { bytes, deserialization_error, .. } => {
@ -353,13 +374,7 @@ impl EndpointError for Error {
}, },
}; };
let error = error_body.into_error(status); error_body.into_error(status)
#[cfg(not(feature = "unstable-msc2967"))]
return error;
#[cfg(feature = "unstable-msc2967")]
Self { authenticate, ..error }
} }
} }
@ -383,12 +398,7 @@ impl ErrorBody {
/// ///
/// This is equivalent to calling `Error::new(status_code, self)`. /// This is equivalent to calling `Error::new(status_code, self)`.
pub fn into_error(self, status_code: http::StatusCode) -> Error { pub fn into_error(self, status_code: http::StatusCode) -> Error {
Error { Error { status_code, body: self }
status_code,
#[cfg(feature = "unstable-msc2967")]
authenticate: None,
body: self,
}
} }
} }
@ -401,7 +411,11 @@ impl OutgoingResponse for Error {
.status(self.status_code); .status(self.status_code);
#[cfg(feature = "unstable-msc2967")] #[cfg(feature = "unstable-msc2967")]
let builder = if let Some(auth_error) = &self.authenticate { let builder = if let ErrorBody::Standard {
kind: ErrorKind::Forbidden { authenticate: Some(auth_error) },
..
} = &self.body
{
builder.header(http::header::WWW_AUTHENTICATE, auth_error) builder.header(http::header::WWW_AUTHENTICATE, auth_error)
} else { } else {
builder builder
@ -546,7 +560,13 @@ mod tests {
})) }))
.unwrap(); .unwrap();
assert_eq!(deserialized.kind, ErrorKind::Forbidden); assert_eq!(
deserialized.kind,
ErrorKind::Forbidden {
#[cfg(feature = "unstable-msc2967")]
authenticate: None
}
);
assert_eq!(deserialized.message, "You are not authorized to ban users in this room."); assert_eq!(deserialized.message, "You are not authorized to ban users in this room.");
} }
@ -617,9 +637,9 @@ mod tests {
assert_eq!(error.status_code, http::StatusCode::UNAUTHORIZED); assert_eq!(error.status_code, http::StatusCode::UNAUTHORIZED);
assert_matches!(error.body, ErrorBody::Standard { kind, message }); assert_matches!(error.body, ErrorBody::Standard { kind, message });
assert_eq!(kind, ErrorKind::Forbidden); assert_matches!(kind, ErrorKind::Forbidden { authenticate });
assert_eq!(message, "Insufficient privilege"); assert_eq!(message, "Insufficient privilege");
assert_matches!(error.authenticate, Some(AuthenticateError::InsufficientScope { scope })); assert_matches!(authenticate, Some(AuthenticateError::InsufficientScope { scope }));
assert_eq!(scope, "something_privileged"); assert_eq!(scope, "something_privileged");
} }
} }

View File

@ -165,7 +165,7 @@ impl<'de> Visitor<'de> for ErrorKindVisitor {
let extra = Extra(extra); let extra = Extra(extra);
Ok(match errcode { Ok(match errcode {
ErrCode::Forbidden => ErrorKind::Forbidden, ErrCode::Forbidden => ErrorKind::forbidden(),
ErrCode::UnknownToken => ErrorKind::UnknownToken { ErrCode::UnknownToken => ErrorKind::UnknownToken {
soft_logout: soft_logout soft_logout: soft_logout
.map(from_json_value) .map(from_json_value)
@ -361,7 +361,13 @@ mod tests {
#[test] #[test]
fn deserialize_forbidden() { fn deserialize_forbidden() {
let deserialized: ErrorKind = from_json_value(json!({ "errcode": "M_FORBIDDEN" })).unwrap(); let deserialized: ErrorKind = from_json_value(json!({ "errcode": "M_FORBIDDEN" })).unwrap();
assert_eq!(deserialized, ErrorKind::Forbidden); assert_eq!(
deserialized,
ErrorKind::Forbidden {
#[cfg(feature = "unstable-msc2967")]
authenticate: None
}
);
} }
#[test] #[test]
@ -372,7 +378,13 @@ mod tests {
})) }))
.unwrap(); .unwrap();
assert_eq!(deserialized, ErrorKind::Forbidden); assert_eq!(
deserialized,
ErrorKind::Forbidden {
#[cfg(feature = "unstable-msc2967")]
authenticate: None
}
);
} }
#[test] #[test]

View File

@ -125,7 +125,7 @@ fn deserialize_uiaa_info() {
assert_eq!(info.flows[1].stages, vec![AuthType::EmailIdentity, AuthType::Msisdn]); assert_eq!(info.flows[1].stages, vec![AuthType::EmailIdentity, AuthType::Msisdn]);
assert_eq!(info.session.as_deref(), Some("xxxxxx")); assert_eq!(info.session.as_deref(), Some("xxxxxx"));
let auth_error = info.auth_error.unwrap(); let auth_error = info.auth_error.unwrap();
assert_eq!(auth_error.kind, ErrorKind::Forbidden); assert_matches!(auth_error.kind, ErrorKind::Forbidden { .. });
assert_eq!(auth_error.message, "Invalid password"); assert_eq!(auth_error.message, "Invalid password");
assert_eq!( assert_eq!(
from_json_str::<JsonValue>(info.params.get()).unwrap(), from_json_str::<JsonValue>(info.params.get()).unwrap(),
@ -207,7 +207,7 @@ fn try_uiaa_response_from_http_response() {
assert_eq!(info.flows[1].stages, vec![AuthType::EmailIdentity, AuthType::Msisdn]); assert_eq!(info.flows[1].stages, vec![AuthType::EmailIdentity, AuthType::Msisdn]);
assert_eq!(info.session.as_deref(), Some("xxxxxx")); assert_eq!(info.session.as_deref(), Some("xxxxxx"));
let auth_error = info.auth_error.unwrap(); let auth_error = info.auth_error.unwrap();
assert_eq!(auth_error.kind, ErrorKind::Forbidden); assert_matches!(auth_error.kind, ErrorKind::Forbidden { .. });
assert_eq!(auth_error.message, "Invalid password"); assert_eq!(auth_error.message, "Invalid password");
assert_eq!( assert_eq!(
from_json_str::<JsonValue>(info.params.get()).unwrap(), from_json_str::<JsonValue>(info.params.get()).unwrap(),