From a7be60d9ebeadff934c62945c5ba251d523ecca1 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Thu, 23 Sep 2021 18:04:30 +0200 Subject: [PATCH] client-api: Use an enum for user-interactive auth stage type --- crates/ruma-client-api/src/r0/uiaa.rs | 94 +++++++++++++++++++-------- crates/ruma-client-api/tests/uiaa.rs | 45 +++++-------- 2 files changed, 86 insertions(+), 53 deletions(-) diff --git a/crates/ruma-client-api/src/r0/uiaa.rs b/crates/ruma-client-api/src/r0/uiaa.rs index 215d5f56..bb0783fa 100644 --- a/crates/ruma-client-api/src/r0/uiaa.rs +++ b/crates/ruma-client-api/src/r0/uiaa.rs @@ -11,7 +11,7 @@ use ruma_api::{ }; use ruma_common::thirdparty::Medium; use ruma_identifiers::{ClientSecret, SessionId}; -use ruma_serde::Outgoing; +use ruma_serde::{Outgoing, StringEnum}; use serde::{ de::{self, DeserializeOwned}, Deserialize, Deserializer, Serialize, @@ -25,7 +25,7 @@ use crate::error::{Error as MatrixError, ErrorBody}; pub mod authorize_fallback; mod user_serde; -/// Additional authentication information for the user-interactive authentication API. +/// Information for one authentication stage. #[derive(Clone, Debug, Outgoing, Serialize)] #[non_exhaustive] #[incoming_derive(!Deserialize)] @@ -52,7 +52,7 @@ pub enum AuthData<'a> { /// Dummy authentication (`m.login.dummy`). Dummy(Dummy<'a>), - /// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`) + /// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`). #[cfg(feature = "unstable-pre-spec")] #[cfg_attr(docsrs, doc(cfg(feature = "unstable-pre-spec")))] RegistrationToken(RegistrationToken<'a>), @@ -71,19 +71,19 @@ impl<'a> AuthData<'a> { } /// Returns the value of the `type` field, if it exists. - pub fn auth_type(&self) -> Option<&'a str> { + pub fn auth_type(&self) -> Option { match self { - Self::Password(_) => Some("m.login.password"), - Self::ReCaptcha(_) => Some("m.login.recaptcha"), - Self::Token(_) => Some("m.login.token"), - Self::OAuth2(_) => Some("m.login.oauth2"), - Self::EmailIdentity(_) => Some("m.login.email.identity"), - Self::Msisdn(_) => Some("m.login.msisdn"), - Self::Dummy(_) => Some("m.login.dummy"), + Self::Password(_) => Some(AuthType::Password), + Self::ReCaptcha(_) => Some(AuthType::ReCaptcha), + Self::Token(_) => Some(AuthType::Token), + Self::OAuth2(_) => Some(AuthType::OAuth2), + Self::EmailIdentity(_) => Some(AuthType::EmailIdentity), + Self::Msisdn(_) => Some(AuthType::Msisdn), + Self::Dummy(_) => Some(AuthType::Dummy), #[cfg(feature = "unstable-pre-spec")] - Self::RegistrationToken(_) => Some("org.matrix.msc3231.login.registration_token"), + Self::RegistrationToken(_) => Some(AuthType::RegistrationToken), Self::FallbackAcknowledgement(_) => None, - Self::_Custom(c) => Some(c.auth_type), + Self::_Custom(c) => Some(AuthType::_Custom(c.auth_type.to_owned())), } } @@ -107,19 +107,19 @@ impl<'a> AuthData<'a> { impl IncomingAuthData { /// Returns the value of the `type` field, if it exists. - pub fn auth_type(&self) -> Option<&str> { + pub fn auth_type(&self) -> Option { match self { - Self::Password(_) => Some("m.login.password"), - Self::ReCaptcha(_) => Some("m.login.recaptcha"), - Self::Token(_) => Some("m.login.token"), - Self::OAuth2(_) => Some("m.login.oauth2"), - Self::EmailIdentity(_) => Some("m.login.email.identity"), - Self::Msisdn(_) => Some("m.login.msisdn"), - Self::Dummy(_) => Some("m.login.dummy"), + Self::Password(_) => Some(AuthType::Password), + Self::ReCaptcha(_) => Some(AuthType::ReCaptcha), + Self::Token(_) => Some(AuthType::Token), + Self::OAuth2(_) => Some(AuthType::OAuth2), + Self::EmailIdentity(_) => Some(AuthType::EmailIdentity), + Self::Msisdn(_) => Some(AuthType::Msisdn), + Self::Dummy(_) => Some(AuthType::Dummy), #[cfg(feature = "unstable-pre-spec")] - Self::RegistrationToken(_) => Some("org.matrix.msc3231.login.registration_token"), + Self::RegistrationToken(_) => Some(AuthType::RegistrationToken), Self::FallbackAcknowledgement(_) => None, - Self::_Custom(c) => Some(&c.auth_type), + Self::_Custom(c) => Some(AuthType::_Custom(c.auth_type.clone())), } } @@ -182,6 +182,48 @@ impl<'de> Deserialize<'de> for IncomingAuthData { } } +/// The type of an authentication stage. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, StringEnum)] +#[non_exhaustive] +pub enum AuthType { + /// Password-based authentication (`m.login.password`). + #[ruma_enum(rename = "m.login.password")] + Password, + + /// Google ReCaptcha 2.0 authentication (`m.login.recaptcha`). + #[ruma_enum(rename = "m.login.recaptcha")] + ReCaptcha, + + /// Token-based authentication (`m.login.token`). + #[ruma_enum(rename = "m.login.token")] + Token, + + /// OAuth2-based authentication (`m.login.oauth2`). + #[ruma_enum(rename = "m.login.oauth2")] + OAuth2, + + /// Email-based authentication (`m.login.email.identity`). + #[ruma_enum(rename = "m.login.email.identity")] + EmailIdentity, + + /// Phone number-based authentication (`m.login.msisdn`). + #[ruma_enum(rename = "m.login.msisdn")] + Msisdn, + + /// Dummy authentication (`m.login.dummy`). + #[ruma_enum(rename = "m.login.dummy")] + Dummy, + + /// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`). + #[cfg(feature = "unstable-pre-spec")] + #[cfg_attr(docsrs, doc(cfg(feature = "unstable-pre-spec")))] + #[ruma_enum(rename = "org.matrix.msc3231.login.registration_token")] + RegistrationToken, + + #[doc(hidden)] + _Custom(String), +} + /// Data for password-based UIAA flow. /// /// See [the spec] for how to use this. @@ -475,7 +517,7 @@ pub struct UiaaInfo { /// List of stages in the current flow completed by the client. #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub completed: Vec, + pub completed: Vec, /// Authentication parameters required for the client to complete authentication. /// @@ -504,14 +546,14 @@ impl UiaaInfo { pub struct AuthFlow { /// Ordered list of stages required to complete authentication. #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub stages: Vec, + pub stages: Vec, } impl AuthFlow { /// Creates a new `AuthFlow` with the given stages. /// /// To create an empty `AuthFlow`, use `AuthFlow::default()`. - pub fn new(stages: Vec) -> Self { + pub fn new(stages: Vec) -> Self { Self { stages } } } diff --git a/crates/ruma-client-api/tests/uiaa.rs b/crates/ruma-client-api/tests/uiaa.rs index 147795eb..c77bb8ee 100644 --- a/crates/ruma-client-api/tests/uiaa.rs +++ b/crates/ruma-client-api/tests/uiaa.rs @@ -9,7 +9,8 @@ use serde_json::{ use ruma_client_api::{ error::{ErrorBody, ErrorKind}, r0::uiaa::{ - self, AuthData, AuthFlow, IncomingAuthData, IncomingUserIdentifier, UiaaInfo, UiaaResponse, + self, AuthData, AuthFlow, AuthType, IncomingAuthData, IncomingUserIdentifier, UiaaInfo, + UiaaResponse, }, }; @@ -148,13 +149,13 @@ fn deserialize_uiaa_info() { let json = json!({ "errcode": "M_FORBIDDEN", "error": "Invalid password", - "completed": ["example.type.foo"], + "completed": ["m.login.recaptcha"], "flows": [ { - "stages": ["example.type.foo", "example.type.bar"] + "stages": ["m.login.password"] }, { - "stages": ["example.type.foo", "example.type.baz"] + "stages": ["m.login.email.identity", "m.login.msisdn"] } ], "params": { @@ -178,17 +179,12 @@ fn deserialize_uiaa_info() { session: Some(session), .. } if error_message == "Invalid password" - && completed == vec!["example.type.foo".to_owned()] + && completed == vec![AuthType::ReCaptcha] && matches!( flows.as_slice(), [f1, f2] - if f1.stages == vec![ - "example.type.foo".to_owned(), - "example.type.bar".to_owned() - ] && f2.stages == vec![ - "example.type.foo".to_owned(), - "example.type.baz".to_owned() - ] + if f1.stages == vec![AuthType::Password] + && f2.stages == vec![AuthType::EmailIdentity, AuthType::Msisdn] ) && from_json_str::(params.get()).unwrap() == json!({ "example.type.baz": { @@ -201,7 +197,7 @@ fn deserialize_uiaa_info() { #[test] fn try_uiaa_response_into_http_response() { - let flows = vec![AuthFlow::new(vec!["m.login.password".into(), "m.login.dummy".into()])]; + let flows = vec![AuthFlow::new(vec![AuthType::Password, AuthType::Dummy])]; let params = to_raw_json_value(&json!({ "example.type.baz": { "example_key": "foobar" @@ -209,7 +205,7 @@ fn try_uiaa_response_into_http_response() { })) .unwrap(); let uiaa_info = assign!(UiaaInfo::new(flows, params), { - completed: vec!["m.login.password".into()], + completed: vec![AuthType::ReCaptcha], }); let uiaa_response = UiaaResponse::AuthResponse(uiaa_info).try_into_http_response::>().unwrap(); @@ -225,8 +221,8 @@ fn try_uiaa_response_into_http_response() { .. } if matches!( flows.as_slice(), - [flow] if flow.stages == vec!["m.login.password".to_owned(), "m.login.dummy".to_owned()] - ) && completed == vec!["m.login.password".to_owned()] + [flow] if flow.stages == vec![AuthType::Password, AuthType::Dummy] + ) && completed == vec![AuthType::ReCaptcha] && from_json_str::(params.get()).unwrap() == json!({ "example.type.baz": { "example_key": "foobar" @@ -241,13 +237,13 @@ fn try_uiaa_response_from_http_response() { let json = serde_json::to_string(&json!({ "errcode": "M_FORBIDDEN", "error": "Invalid password", - "completed": [ "example.type.foo" ], + "completed": ["m.login.recaptcha"], "flows": [ { - "stages": [ "example.type.foo", "example.type.bar" ] + "stages": ["m.login.password"] }, { - "stages": [ "example.type.foo", "example.type.baz" ] + "stages": ["m.login.email.identity", "m.login.msisdn"] } ], "params": { @@ -282,17 +278,12 @@ fn try_uiaa_response_from_http_response() { session: Some(session), .. } if error_message == "Invalid password" - && completed == vec!["example.type.foo".to_owned()] + && completed == vec![AuthType::ReCaptcha] && matches!( flows.as_slice(), [f1, f2] - if f1.stages == vec![ - "example.type.foo".to_owned(), - "example.type.bar".to_owned() - ] && f2.stages == vec![ - "example.type.foo".to_owned(), - "example.type.baz".to_owned() - ] + if f1.stages == vec![AuthType::Password] + && f2.stages == vec![AuthType::EmailIdentity, AuthType::Msisdn] ) && from_json_str::(params.get()).unwrap() == json!({ "example.type.baz": {