client-api: Use an enum for user-interactive auth stage type

This commit is contained in:
Jonas Platte 2021-09-23 18:04:30 +02:00
parent 36462694e6
commit a7be60d9eb
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
2 changed files with 86 additions and 53 deletions

View File

@ -11,7 +11,7 @@ use ruma_api::{
}; };
use ruma_common::thirdparty::Medium; use ruma_common::thirdparty::Medium;
use ruma_identifiers::{ClientSecret, SessionId}; use ruma_identifiers::{ClientSecret, SessionId};
use ruma_serde::Outgoing; use ruma_serde::{Outgoing, StringEnum};
use serde::{ use serde::{
de::{self, DeserializeOwned}, de::{self, DeserializeOwned},
Deserialize, Deserializer, Serialize, Deserialize, Deserializer, Serialize,
@ -25,7 +25,7 @@ use crate::error::{Error as MatrixError, ErrorBody};
pub mod authorize_fallback; pub mod authorize_fallback;
mod user_serde; mod user_serde;
/// Additional authentication information for the user-interactive authentication API. /// Information for one authentication stage.
#[derive(Clone, Debug, Outgoing, Serialize)] #[derive(Clone, Debug, Outgoing, Serialize)]
#[non_exhaustive] #[non_exhaustive]
#[incoming_derive(!Deserialize)] #[incoming_derive(!Deserialize)]
@ -52,7 +52,7 @@ pub enum AuthData<'a> {
/// Dummy authentication (`m.login.dummy`). /// Dummy authentication (`m.login.dummy`).
Dummy(Dummy<'a>), Dummy(Dummy<'a>),
/// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`) /// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`).
#[cfg(feature = "unstable-pre-spec")] #[cfg(feature = "unstable-pre-spec")]
#[cfg_attr(docsrs, doc(cfg(feature = "unstable-pre-spec")))] #[cfg_attr(docsrs, doc(cfg(feature = "unstable-pre-spec")))]
RegistrationToken(RegistrationToken<'a>), RegistrationToken(RegistrationToken<'a>),
@ -71,19 +71,19 @@ impl<'a> AuthData<'a> {
} }
/// Returns the value of the `type` field, if it exists. /// Returns the value of the `type` field, if it exists.
pub fn auth_type(&self) -> Option<&'a str> { pub fn auth_type(&self) -> Option<AuthType> {
match self { match self {
Self::Password(_) => Some("m.login.password"), Self::Password(_) => Some(AuthType::Password),
Self::ReCaptcha(_) => Some("m.login.recaptcha"), Self::ReCaptcha(_) => Some(AuthType::ReCaptcha),
Self::Token(_) => Some("m.login.token"), Self::Token(_) => Some(AuthType::Token),
Self::OAuth2(_) => Some("m.login.oauth2"), Self::OAuth2(_) => Some(AuthType::OAuth2),
Self::EmailIdentity(_) => Some("m.login.email.identity"), Self::EmailIdentity(_) => Some(AuthType::EmailIdentity),
Self::Msisdn(_) => Some("m.login.msisdn"), Self::Msisdn(_) => Some(AuthType::Msisdn),
Self::Dummy(_) => Some("m.login.dummy"), Self::Dummy(_) => Some(AuthType::Dummy),
#[cfg(feature = "unstable-pre-spec")] #[cfg(feature = "unstable-pre-spec")]
Self::RegistrationToken(_) => Some("org.matrix.msc3231.login.registration_token"), Self::RegistrationToken(_) => Some(AuthType::RegistrationToken),
Self::FallbackAcknowledgement(_) => None, Self::FallbackAcknowledgement(_) => None,
Self::_Custom(c) => Some(c.auth_type), Self::_Custom(c) => Some(AuthType::_Custom(c.auth_type.to_owned())),
} }
} }
@ -107,19 +107,19 @@ impl<'a> AuthData<'a> {
impl IncomingAuthData { impl IncomingAuthData {
/// Returns the value of the `type` field, if it exists. /// Returns the value of the `type` field, if it exists.
pub fn auth_type(&self) -> Option<&str> { pub fn auth_type(&self) -> Option<AuthType> {
match self { match self {
Self::Password(_) => Some("m.login.password"), Self::Password(_) => Some(AuthType::Password),
Self::ReCaptcha(_) => Some("m.login.recaptcha"), Self::ReCaptcha(_) => Some(AuthType::ReCaptcha),
Self::Token(_) => Some("m.login.token"), Self::Token(_) => Some(AuthType::Token),
Self::OAuth2(_) => Some("m.login.oauth2"), Self::OAuth2(_) => Some(AuthType::OAuth2),
Self::EmailIdentity(_) => Some("m.login.email.identity"), Self::EmailIdentity(_) => Some(AuthType::EmailIdentity),
Self::Msisdn(_) => Some("m.login.msisdn"), Self::Msisdn(_) => Some(AuthType::Msisdn),
Self::Dummy(_) => Some("m.login.dummy"), Self::Dummy(_) => Some(AuthType::Dummy),
#[cfg(feature = "unstable-pre-spec")] #[cfg(feature = "unstable-pre-spec")]
Self::RegistrationToken(_) => Some("org.matrix.msc3231.login.registration_token"), Self::RegistrationToken(_) => Some(AuthType::RegistrationToken),
Self::FallbackAcknowledgement(_) => None, Self::FallbackAcknowledgement(_) => None,
Self::_Custom(c) => Some(&c.auth_type), Self::_Custom(c) => Some(AuthType::_Custom(c.auth_type.clone())),
} }
} }
@ -182,6 +182,48 @@ impl<'de> Deserialize<'de> for IncomingAuthData {
} }
} }
/// The type of an authentication stage.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, StringEnum)]
#[non_exhaustive]
pub enum AuthType {
/// Password-based authentication (`m.login.password`).
#[ruma_enum(rename = "m.login.password")]
Password,
/// Google ReCaptcha 2.0 authentication (`m.login.recaptcha`).
#[ruma_enum(rename = "m.login.recaptcha")]
ReCaptcha,
/// Token-based authentication (`m.login.token`).
#[ruma_enum(rename = "m.login.token")]
Token,
/// OAuth2-based authentication (`m.login.oauth2`).
#[ruma_enum(rename = "m.login.oauth2")]
OAuth2,
/// Email-based authentication (`m.login.email.identity`).
#[ruma_enum(rename = "m.login.email.identity")]
EmailIdentity,
/// Phone number-based authentication (`m.login.msisdn`).
#[ruma_enum(rename = "m.login.msisdn")]
Msisdn,
/// Dummy authentication (`m.login.dummy`).
#[ruma_enum(rename = "m.login.dummy")]
Dummy,
/// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`).
#[cfg(feature = "unstable-pre-spec")]
#[cfg_attr(docsrs, doc(cfg(feature = "unstable-pre-spec")))]
#[ruma_enum(rename = "org.matrix.msc3231.login.registration_token")]
RegistrationToken,
#[doc(hidden)]
_Custom(String),
}
/// Data for password-based UIAA flow. /// Data for password-based UIAA flow.
/// ///
/// See [the spec] for how to use this. /// See [the spec] for how to use this.
@ -475,7 +517,7 @@ pub struct UiaaInfo {
/// List of stages in the current flow completed by the client. /// List of stages in the current flow completed by the client.
#[serde(default, skip_serializing_if = "Vec::is_empty")] #[serde(default, skip_serializing_if = "Vec::is_empty")]
pub completed: Vec<String>, pub completed: Vec<AuthType>,
/// Authentication parameters required for the client to complete authentication. /// Authentication parameters required for the client to complete authentication.
/// ///
@ -504,14 +546,14 @@ impl UiaaInfo {
pub struct AuthFlow { pub struct AuthFlow {
/// Ordered list of stages required to complete authentication. /// Ordered list of stages required to complete authentication.
#[serde(default, skip_serializing_if = "Vec::is_empty")] #[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stages: Vec<String>, pub stages: Vec<AuthType>,
} }
impl AuthFlow { impl AuthFlow {
/// Creates a new `AuthFlow` with the given stages. /// Creates a new `AuthFlow` with the given stages.
/// ///
/// To create an empty `AuthFlow`, use `AuthFlow::default()`. /// To create an empty `AuthFlow`, use `AuthFlow::default()`.
pub fn new(stages: Vec<String>) -> Self { pub fn new(stages: Vec<AuthType>) -> Self {
Self { stages } Self { stages }
} }
} }

View File

@ -9,7 +9,8 @@ use serde_json::{
use ruma_client_api::{ use ruma_client_api::{
error::{ErrorBody, ErrorKind}, error::{ErrorBody, ErrorKind},
r0::uiaa::{ r0::uiaa::{
self, AuthData, AuthFlow, IncomingAuthData, IncomingUserIdentifier, UiaaInfo, UiaaResponse, self, AuthData, AuthFlow, AuthType, IncomingAuthData, IncomingUserIdentifier, UiaaInfo,
UiaaResponse,
}, },
}; };
@ -148,13 +149,13 @@ fn deserialize_uiaa_info() {
let json = json!({ let json = json!({
"errcode": "M_FORBIDDEN", "errcode": "M_FORBIDDEN",
"error": "Invalid password", "error": "Invalid password",
"completed": ["example.type.foo"], "completed": ["m.login.recaptcha"],
"flows": [ "flows": [
{ {
"stages": ["example.type.foo", "example.type.bar"] "stages": ["m.login.password"]
}, },
{ {
"stages": ["example.type.foo", "example.type.baz"] "stages": ["m.login.email.identity", "m.login.msisdn"]
} }
], ],
"params": { "params": {
@ -178,17 +179,12 @@ fn deserialize_uiaa_info() {
session: Some(session), session: Some(session),
.. ..
} if error_message == "Invalid password" } if error_message == "Invalid password"
&& completed == vec!["example.type.foo".to_owned()] && completed == vec![AuthType::ReCaptcha]
&& matches!( && matches!(
flows.as_slice(), flows.as_slice(),
[f1, f2] [f1, f2]
if f1.stages == vec![ if f1.stages == vec![AuthType::Password]
"example.type.foo".to_owned(), && f2.stages == vec![AuthType::EmailIdentity, AuthType::Msisdn]
"example.type.bar".to_owned()
] && f2.stages == vec![
"example.type.foo".to_owned(),
"example.type.baz".to_owned()
]
) )
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({ && from_json_str::<JsonValue>(params.get()).unwrap() == json!({
"example.type.baz": { "example.type.baz": {
@ -201,7 +197,7 @@ fn deserialize_uiaa_info() {
#[test] #[test]
fn try_uiaa_response_into_http_response() { fn try_uiaa_response_into_http_response() {
let flows = vec![AuthFlow::new(vec!["m.login.password".into(), "m.login.dummy".into()])]; let flows = vec![AuthFlow::new(vec![AuthType::Password, AuthType::Dummy])];
let params = to_raw_json_value(&json!({ let params = to_raw_json_value(&json!({
"example.type.baz": { "example.type.baz": {
"example_key": "foobar" "example_key": "foobar"
@ -209,7 +205,7 @@ fn try_uiaa_response_into_http_response() {
})) }))
.unwrap(); .unwrap();
let uiaa_info = assign!(UiaaInfo::new(flows, params), { let uiaa_info = assign!(UiaaInfo::new(flows, params), {
completed: vec!["m.login.password".into()], completed: vec![AuthType::ReCaptcha],
}); });
let uiaa_response = let uiaa_response =
UiaaResponse::AuthResponse(uiaa_info).try_into_http_response::<Vec<u8>>().unwrap(); UiaaResponse::AuthResponse(uiaa_info).try_into_http_response::<Vec<u8>>().unwrap();
@ -225,8 +221,8 @@ fn try_uiaa_response_into_http_response() {
.. ..
} if matches!( } if matches!(
flows.as_slice(), flows.as_slice(),
[flow] if flow.stages == vec!["m.login.password".to_owned(), "m.login.dummy".to_owned()] [flow] if flow.stages == vec![AuthType::Password, AuthType::Dummy]
) && completed == vec!["m.login.password".to_owned()] ) && completed == vec![AuthType::ReCaptcha]
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({ && from_json_str::<JsonValue>(params.get()).unwrap() == json!({
"example.type.baz": { "example.type.baz": {
"example_key": "foobar" "example_key": "foobar"
@ -241,13 +237,13 @@ fn try_uiaa_response_from_http_response() {
let json = serde_json::to_string(&json!({ let json = serde_json::to_string(&json!({
"errcode": "M_FORBIDDEN", "errcode": "M_FORBIDDEN",
"error": "Invalid password", "error": "Invalid password",
"completed": [ "example.type.foo" ], "completed": ["m.login.recaptcha"],
"flows": [ "flows": [
{ {
"stages": [ "example.type.foo", "example.type.bar" ] "stages": ["m.login.password"]
}, },
{ {
"stages": [ "example.type.foo", "example.type.baz" ] "stages": ["m.login.email.identity", "m.login.msisdn"]
} }
], ],
"params": { "params": {
@ -282,17 +278,12 @@ fn try_uiaa_response_from_http_response() {
session: Some(session), session: Some(session),
.. ..
} if error_message == "Invalid password" } if error_message == "Invalid password"
&& completed == vec!["example.type.foo".to_owned()] && completed == vec![AuthType::ReCaptcha]
&& matches!( && matches!(
flows.as_slice(), flows.as_slice(),
[f1, f2] [f1, f2]
if f1.stages == vec![ if f1.stages == vec![AuthType::Password]
"example.type.foo".to_owned(), && f2.stages == vec![AuthType::EmailIdentity, AuthType::Msisdn]
"example.type.bar".to_owned()
] && f2.stages == vec![
"example.type.foo".to_owned(),
"example.type.baz".to_owned()
]
) )
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({ && from_json_str::<JsonValue>(params.get()).unwrap() == json!({
"example.type.baz": { "example.type.baz": {