client-api: Use an enum for user-interactive auth stage type

This commit is contained in:
Jonas Platte 2021-09-23 18:04:30 +02:00
parent 36462694e6
commit a7be60d9eb
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
2 changed files with 86 additions and 53 deletions

View File

@ -11,7 +11,7 @@ use ruma_api::{
};
use ruma_common::thirdparty::Medium;
use ruma_identifiers::{ClientSecret, SessionId};
use ruma_serde::Outgoing;
use ruma_serde::{Outgoing, StringEnum};
use serde::{
de::{self, DeserializeOwned},
Deserialize, Deserializer, Serialize,
@ -25,7 +25,7 @@ use crate::error::{Error as MatrixError, ErrorBody};
pub mod authorize_fallback;
mod user_serde;
/// Additional authentication information for the user-interactive authentication API.
/// Information for one authentication stage.
#[derive(Clone, Debug, Outgoing, Serialize)]
#[non_exhaustive]
#[incoming_derive(!Deserialize)]
@ -52,7 +52,7 @@ pub enum AuthData<'a> {
/// Dummy authentication (`m.login.dummy`).
Dummy(Dummy<'a>),
/// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`)
/// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`).
#[cfg(feature = "unstable-pre-spec")]
#[cfg_attr(docsrs, doc(cfg(feature = "unstable-pre-spec")))]
RegistrationToken(RegistrationToken<'a>),
@ -71,19 +71,19 @@ impl<'a> AuthData<'a> {
}
/// Returns the value of the `type` field, if it exists.
pub fn auth_type(&self) -> Option<&'a str> {
pub fn auth_type(&self) -> Option<AuthType> {
match self {
Self::Password(_) => Some("m.login.password"),
Self::ReCaptcha(_) => Some("m.login.recaptcha"),
Self::Token(_) => Some("m.login.token"),
Self::OAuth2(_) => Some("m.login.oauth2"),
Self::EmailIdentity(_) => Some("m.login.email.identity"),
Self::Msisdn(_) => Some("m.login.msisdn"),
Self::Dummy(_) => Some("m.login.dummy"),
Self::Password(_) => Some(AuthType::Password),
Self::ReCaptcha(_) => Some(AuthType::ReCaptcha),
Self::Token(_) => Some(AuthType::Token),
Self::OAuth2(_) => Some(AuthType::OAuth2),
Self::EmailIdentity(_) => Some(AuthType::EmailIdentity),
Self::Msisdn(_) => Some(AuthType::Msisdn),
Self::Dummy(_) => Some(AuthType::Dummy),
#[cfg(feature = "unstable-pre-spec")]
Self::RegistrationToken(_) => Some("org.matrix.msc3231.login.registration_token"),
Self::RegistrationToken(_) => Some(AuthType::RegistrationToken),
Self::FallbackAcknowledgement(_) => None,
Self::_Custom(c) => Some(c.auth_type),
Self::_Custom(c) => Some(AuthType::_Custom(c.auth_type.to_owned())),
}
}
@ -107,19 +107,19 @@ impl<'a> AuthData<'a> {
impl IncomingAuthData {
/// Returns the value of the `type` field, if it exists.
pub fn auth_type(&self) -> Option<&str> {
pub fn auth_type(&self) -> Option<AuthType> {
match self {
Self::Password(_) => Some("m.login.password"),
Self::ReCaptcha(_) => Some("m.login.recaptcha"),
Self::Token(_) => Some("m.login.token"),
Self::OAuth2(_) => Some("m.login.oauth2"),
Self::EmailIdentity(_) => Some("m.login.email.identity"),
Self::Msisdn(_) => Some("m.login.msisdn"),
Self::Dummy(_) => Some("m.login.dummy"),
Self::Password(_) => Some(AuthType::Password),
Self::ReCaptcha(_) => Some(AuthType::ReCaptcha),
Self::Token(_) => Some(AuthType::Token),
Self::OAuth2(_) => Some(AuthType::OAuth2),
Self::EmailIdentity(_) => Some(AuthType::EmailIdentity),
Self::Msisdn(_) => Some(AuthType::Msisdn),
Self::Dummy(_) => Some(AuthType::Dummy),
#[cfg(feature = "unstable-pre-spec")]
Self::RegistrationToken(_) => Some("org.matrix.msc3231.login.registration_token"),
Self::RegistrationToken(_) => Some(AuthType::RegistrationToken),
Self::FallbackAcknowledgement(_) => None,
Self::_Custom(c) => Some(&c.auth_type),
Self::_Custom(c) => Some(AuthType::_Custom(c.auth_type.clone())),
}
}
@ -182,6 +182,48 @@ impl<'de> Deserialize<'de> for IncomingAuthData {
}
}
/// The type of an authentication stage.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, StringEnum)]
#[non_exhaustive]
pub enum AuthType {
/// Password-based authentication (`m.login.password`).
#[ruma_enum(rename = "m.login.password")]
Password,
/// Google ReCaptcha 2.0 authentication (`m.login.recaptcha`).
#[ruma_enum(rename = "m.login.recaptcha")]
ReCaptcha,
/// Token-based authentication (`m.login.token`).
#[ruma_enum(rename = "m.login.token")]
Token,
/// OAuth2-based authentication (`m.login.oauth2`).
#[ruma_enum(rename = "m.login.oauth2")]
OAuth2,
/// Email-based authentication (`m.login.email.identity`).
#[ruma_enum(rename = "m.login.email.identity")]
EmailIdentity,
/// Phone number-based authentication (`m.login.msisdn`).
#[ruma_enum(rename = "m.login.msisdn")]
Msisdn,
/// Dummy authentication (`m.login.dummy`).
#[ruma_enum(rename = "m.login.dummy")]
Dummy,
/// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`).
#[cfg(feature = "unstable-pre-spec")]
#[cfg_attr(docsrs, doc(cfg(feature = "unstable-pre-spec")))]
#[ruma_enum(rename = "org.matrix.msc3231.login.registration_token")]
RegistrationToken,
#[doc(hidden)]
_Custom(String),
}
/// Data for password-based UIAA flow.
///
/// See [the spec] for how to use this.
@ -475,7 +517,7 @@ pub struct UiaaInfo {
/// List of stages in the current flow completed by the client.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub completed: Vec<String>,
pub completed: Vec<AuthType>,
/// Authentication parameters required for the client to complete authentication.
///
@ -504,14 +546,14 @@ impl UiaaInfo {
pub struct AuthFlow {
/// Ordered list of stages required to complete authentication.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stages: Vec<String>,
pub stages: Vec<AuthType>,
}
impl AuthFlow {
/// Creates a new `AuthFlow` with the given stages.
///
/// To create an empty `AuthFlow`, use `AuthFlow::default()`.
pub fn new(stages: Vec<String>) -> Self {
pub fn new(stages: Vec<AuthType>) -> Self {
Self { stages }
}
}

View File

@ -9,7 +9,8 @@ use serde_json::{
use ruma_client_api::{
error::{ErrorBody, ErrorKind},
r0::uiaa::{
self, AuthData, AuthFlow, IncomingAuthData, IncomingUserIdentifier, UiaaInfo, UiaaResponse,
self, AuthData, AuthFlow, AuthType, IncomingAuthData, IncomingUserIdentifier, UiaaInfo,
UiaaResponse,
},
};
@ -148,13 +149,13 @@ fn deserialize_uiaa_info() {
let json = json!({
"errcode": "M_FORBIDDEN",
"error": "Invalid password",
"completed": ["example.type.foo"],
"completed": ["m.login.recaptcha"],
"flows": [
{
"stages": ["example.type.foo", "example.type.bar"]
"stages": ["m.login.password"]
},
{
"stages": ["example.type.foo", "example.type.baz"]
"stages": ["m.login.email.identity", "m.login.msisdn"]
}
],
"params": {
@ -178,17 +179,12 @@ fn deserialize_uiaa_info() {
session: Some(session),
..
} if error_message == "Invalid password"
&& completed == vec!["example.type.foo".to_owned()]
&& completed == vec![AuthType::ReCaptcha]
&& matches!(
flows.as_slice(),
[f1, f2]
if f1.stages == vec![
"example.type.foo".to_owned(),
"example.type.bar".to_owned()
] && f2.stages == vec![
"example.type.foo".to_owned(),
"example.type.baz".to_owned()
]
if f1.stages == vec![AuthType::Password]
&& f2.stages == vec![AuthType::EmailIdentity, AuthType::Msisdn]
)
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({
"example.type.baz": {
@ -201,7 +197,7 @@ fn deserialize_uiaa_info() {
#[test]
fn try_uiaa_response_into_http_response() {
let flows = vec![AuthFlow::new(vec!["m.login.password".into(), "m.login.dummy".into()])];
let flows = vec![AuthFlow::new(vec![AuthType::Password, AuthType::Dummy])];
let params = to_raw_json_value(&json!({
"example.type.baz": {
"example_key": "foobar"
@ -209,7 +205,7 @@ fn try_uiaa_response_into_http_response() {
}))
.unwrap();
let uiaa_info = assign!(UiaaInfo::new(flows, params), {
completed: vec!["m.login.password".into()],
completed: vec![AuthType::ReCaptcha],
});
let uiaa_response =
UiaaResponse::AuthResponse(uiaa_info).try_into_http_response::<Vec<u8>>().unwrap();
@ -225,8 +221,8 @@ fn try_uiaa_response_into_http_response() {
..
} if matches!(
flows.as_slice(),
[flow] if flow.stages == vec!["m.login.password".to_owned(), "m.login.dummy".to_owned()]
) && completed == vec!["m.login.password".to_owned()]
[flow] if flow.stages == vec![AuthType::Password, AuthType::Dummy]
) && completed == vec![AuthType::ReCaptcha]
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({
"example.type.baz": {
"example_key": "foobar"
@ -241,13 +237,13 @@ fn try_uiaa_response_from_http_response() {
let json = serde_json::to_string(&json!({
"errcode": "M_FORBIDDEN",
"error": "Invalid password",
"completed": [ "example.type.foo" ],
"completed": ["m.login.recaptcha"],
"flows": [
{
"stages": [ "example.type.foo", "example.type.bar" ]
"stages": ["m.login.password"]
},
{
"stages": [ "example.type.foo", "example.type.baz" ]
"stages": ["m.login.email.identity", "m.login.msisdn"]
}
],
"params": {
@ -282,17 +278,12 @@ fn try_uiaa_response_from_http_response() {
session: Some(session),
..
} if error_message == "Invalid password"
&& completed == vec!["example.type.foo".to_owned()]
&& completed == vec![AuthType::ReCaptcha]
&& matches!(
flows.as_slice(),
[f1, f2]
if f1.stages == vec![
"example.type.foo".to_owned(),
"example.type.bar".to_owned()
] && f2.stages == vec![
"example.type.foo".to_owned(),
"example.type.baz".to_owned()
]
if f1.stages == vec![AuthType::Password]
&& f2.stages == vec![AuthType::EmailIdentity, AuthType::Msisdn]
)
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({
"example.type.baz": {