client-api: Use an enum for user-interactive auth stage type
This commit is contained in:
parent
36462694e6
commit
a7be60d9eb
@ -11,7 +11,7 @@ use ruma_api::{
|
|||||||
};
|
};
|
||||||
use ruma_common::thirdparty::Medium;
|
use ruma_common::thirdparty::Medium;
|
||||||
use ruma_identifiers::{ClientSecret, SessionId};
|
use ruma_identifiers::{ClientSecret, SessionId};
|
||||||
use ruma_serde::Outgoing;
|
use ruma_serde::{Outgoing, StringEnum};
|
||||||
use serde::{
|
use serde::{
|
||||||
de::{self, DeserializeOwned},
|
de::{self, DeserializeOwned},
|
||||||
Deserialize, Deserializer, Serialize,
|
Deserialize, Deserializer, Serialize,
|
||||||
@ -25,7 +25,7 @@ use crate::error::{Error as MatrixError, ErrorBody};
|
|||||||
pub mod authorize_fallback;
|
pub mod authorize_fallback;
|
||||||
mod user_serde;
|
mod user_serde;
|
||||||
|
|
||||||
/// Additional authentication information for the user-interactive authentication API.
|
/// Information for one authentication stage.
|
||||||
#[derive(Clone, Debug, Outgoing, Serialize)]
|
#[derive(Clone, Debug, Outgoing, Serialize)]
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
#[incoming_derive(!Deserialize)]
|
#[incoming_derive(!Deserialize)]
|
||||||
@ -52,7 +52,7 @@ pub enum AuthData<'a> {
|
|||||||
/// Dummy authentication (`m.login.dummy`).
|
/// Dummy authentication (`m.login.dummy`).
|
||||||
Dummy(Dummy<'a>),
|
Dummy(Dummy<'a>),
|
||||||
|
|
||||||
/// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`)
|
/// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`).
|
||||||
#[cfg(feature = "unstable-pre-spec")]
|
#[cfg(feature = "unstable-pre-spec")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "unstable-pre-spec")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "unstable-pre-spec")))]
|
||||||
RegistrationToken(RegistrationToken<'a>),
|
RegistrationToken(RegistrationToken<'a>),
|
||||||
@ -71,19 +71,19 @@ impl<'a> AuthData<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the value of the `type` field, if it exists.
|
/// Returns the value of the `type` field, if it exists.
|
||||||
pub fn auth_type(&self) -> Option<&'a str> {
|
pub fn auth_type(&self) -> Option<AuthType> {
|
||||||
match self {
|
match self {
|
||||||
Self::Password(_) => Some("m.login.password"),
|
Self::Password(_) => Some(AuthType::Password),
|
||||||
Self::ReCaptcha(_) => Some("m.login.recaptcha"),
|
Self::ReCaptcha(_) => Some(AuthType::ReCaptcha),
|
||||||
Self::Token(_) => Some("m.login.token"),
|
Self::Token(_) => Some(AuthType::Token),
|
||||||
Self::OAuth2(_) => Some("m.login.oauth2"),
|
Self::OAuth2(_) => Some(AuthType::OAuth2),
|
||||||
Self::EmailIdentity(_) => Some("m.login.email.identity"),
|
Self::EmailIdentity(_) => Some(AuthType::EmailIdentity),
|
||||||
Self::Msisdn(_) => Some("m.login.msisdn"),
|
Self::Msisdn(_) => Some(AuthType::Msisdn),
|
||||||
Self::Dummy(_) => Some("m.login.dummy"),
|
Self::Dummy(_) => Some(AuthType::Dummy),
|
||||||
#[cfg(feature = "unstable-pre-spec")]
|
#[cfg(feature = "unstable-pre-spec")]
|
||||||
Self::RegistrationToken(_) => Some("org.matrix.msc3231.login.registration_token"),
|
Self::RegistrationToken(_) => Some(AuthType::RegistrationToken),
|
||||||
Self::FallbackAcknowledgement(_) => None,
|
Self::FallbackAcknowledgement(_) => None,
|
||||||
Self::_Custom(c) => Some(c.auth_type),
|
Self::_Custom(c) => Some(AuthType::_Custom(c.auth_type.to_owned())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,19 +107,19 @@ impl<'a> AuthData<'a> {
|
|||||||
|
|
||||||
impl IncomingAuthData {
|
impl IncomingAuthData {
|
||||||
/// Returns the value of the `type` field, if it exists.
|
/// Returns the value of the `type` field, if it exists.
|
||||||
pub fn auth_type(&self) -> Option<&str> {
|
pub fn auth_type(&self) -> Option<AuthType> {
|
||||||
match self {
|
match self {
|
||||||
Self::Password(_) => Some("m.login.password"),
|
Self::Password(_) => Some(AuthType::Password),
|
||||||
Self::ReCaptcha(_) => Some("m.login.recaptcha"),
|
Self::ReCaptcha(_) => Some(AuthType::ReCaptcha),
|
||||||
Self::Token(_) => Some("m.login.token"),
|
Self::Token(_) => Some(AuthType::Token),
|
||||||
Self::OAuth2(_) => Some("m.login.oauth2"),
|
Self::OAuth2(_) => Some(AuthType::OAuth2),
|
||||||
Self::EmailIdentity(_) => Some("m.login.email.identity"),
|
Self::EmailIdentity(_) => Some(AuthType::EmailIdentity),
|
||||||
Self::Msisdn(_) => Some("m.login.msisdn"),
|
Self::Msisdn(_) => Some(AuthType::Msisdn),
|
||||||
Self::Dummy(_) => Some("m.login.dummy"),
|
Self::Dummy(_) => Some(AuthType::Dummy),
|
||||||
#[cfg(feature = "unstable-pre-spec")]
|
#[cfg(feature = "unstable-pre-spec")]
|
||||||
Self::RegistrationToken(_) => Some("org.matrix.msc3231.login.registration_token"),
|
Self::RegistrationToken(_) => Some(AuthType::RegistrationToken),
|
||||||
Self::FallbackAcknowledgement(_) => None,
|
Self::FallbackAcknowledgement(_) => None,
|
||||||
Self::_Custom(c) => Some(&c.auth_type),
|
Self::_Custom(c) => Some(AuthType::_Custom(c.auth_type.clone())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -182,6 +182,48 @@ impl<'de> Deserialize<'de> for IncomingAuthData {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The type of an authentication stage.
|
||||||
|
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, StringEnum)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub enum AuthType {
|
||||||
|
/// Password-based authentication (`m.login.password`).
|
||||||
|
#[ruma_enum(rename = "m.login.password")]
|
||||||
|
Password,
|
||||||
|
|
||||||
|
/// Google ReCaptcha 2.0 authentication (`m.login.recaptcha`).
|
||||||
|
#[ruma_enum(rename = "m.login.recaptcha")]
|
||||||
|
ReCaptcha,
|
||||||
|
|
||||||
|
/// Token-based authentication (`m.login.token`).
|
||||||
|
#[ruma_enum(rename = "m.login.token")]
|
||||||
|
Token,
|
||||||
|
|
||||||
|
/// OAuth2-based authentication (`m.login.oauth2`).
|
||||||
|
#[ruma_enum(rename = "m.login.oauth2")]
|
||||||
|
OAuth2,
|
||||||
|
|
||||||
|
/// Email-based authentication (`m.login.email.identity`).
|
||||||
|
#[ruma_enum(rename = "m.login.email.identity")]
|
||||||
|
EmailIdentity,
|
||||||
|
|
||||||
|
/// Phone number-based authentication (`m.login.msisdn`).
|
||||||
|
#[ruma_enum(rename = "m.login.msisdn")]
|
||||||
|
Msisdn,
|
||||||
|
|
||||||
|
/// Dummy authentication (`m.login.dummy`).
|
||||||
|
#[ruma_enum(rename = "m.login.dummy")]
|
||||||
|
Dummy,
|
||||||
|
|
||||||
|
/// Registration token-based authentication (`org.matrix.msc3231.login.registration_token`).
|
||||||
|
#[cfg(feature = "unstable-pre-spec")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "unstable-pre-spec")))]
|
||||||
|
#[ruma_enum(rename = "org.matrix.msc3231.login.registration_token")]
|
||||||
|
RegistrationToken,
|
||||||
|
|
||||||
|
#[doc(hidden)]
|
||||||
|
_Custom(String),
|
||||||
|
}
|
||||||
|
|
||||||
/// Data for password-based UIAA flow.
|
/// Data for password-based UIAA flow.
|
||||||
///
|
///
|
||||||
/// See [the spec] for how to use this.
|
/// See [the spec] for how to use this.
|
||||||
@ -475,7 +517,7 @@ pub struct UiaaInfo {
|
|||||||
|
|
||||||
/// List of stages in the current flow completed by the client.
|
/// List of stages in the current flow completed by the client.
|
||||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
pub completed: Vec<String>,
|
pub completed: Vec<AuthType>,
|
||||||
|
|
||||||
/// Authentication parameters required for the client to complete authentication.
|
/// Authentication parameters required for the client to complete authentication.
|
||||||
///
|
///
|
||||||
@ -504,14 +546,14 @@ impl UiaaInfo {
|
|||||||
pub struct AuthFlow {
|
pub struct AuthFlow {
|
||||||
/// Ordered list of stages required to complete authentication.
|
/// Ordered list of stages required to complete authentication.
|
||||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
pub stages: Vec<String>,
|
pub stages: Vec<AuthType>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AuthFlow {
|
impl AuthFlow {
|
||||||
/// Creates a new `AuthFlow` with the given stages.
|
/// Creates a new `AuthFlow` with the given stages.
|
||||||
///
|
///
|
||||||
/// To create an empty `AuthFlow`, use `AuthFlow::default()`.
|
/// To create an empty `AuthFlow`, use `AuthFlow::default()`.
|
||||||
pub fn new(stages: Vec<String>) -> Self {
|
pub fn new(stages: Vec<AuthType>) -> Self {
|
||||||
Self { stages }
|
Self { stages }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,8 @@ use serde_json::{
|
|||||||
use ruma_client_api::{
|
use ruma_client_api::{
|
||||||
error::{ErrorBody, ErrorKind},
|
error::{ErrorBody, ErrorKind},
|
||||||
r0::uiaa::{
|
r0::uiaa::{
|
||||||
self, AuthData, AuthFlow, IncomingAuthData, IncomingUserIdentifier, UiaaInfo, UiaaResponse,
|
self, AuthData, AuthFlow, AuthType, IncomingAuthData, IncomingUserIdentifier, UiaaInfo,
|
||||||
|
UiaaResponse,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -148,13 +149,13 @@ fn deserialize_uiaa_info() {
|
|||||||
let json = json!({
|
let json = json!({
|
||||||
"errcode": "M_FORBIDDEN",
|
"errcode": "M_FORBIDDEN",
|
||||||
"error": "Invalid password",
|
"error": "Invalid password",
|
||||||
"completed": ["example.type.foo"],
|
"completed": ["m.login.recaptcha"],
|
||||||
"flows": [
|
"flows": [
|
||||||
{
|
{
|
||||||
"stages": ["example.type.foo", "example.type.bar"]
|
"stages": ["m.login.password"]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"stages": ["example.type.foo", "example.type.baz"]
|
"stages": ["m.login.email.identity", "m.login.msisdn"]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"params": {
|
"params": {
|
||||||
@ -178,17 +179,12 @@ fn deserialize_uiaa_info() {
|
|||||||
session: Some(session),
|
session: Some(session),
|
||||||
..
|
..
|
||||||
} if error_message == "Invalid password"
|
} if error_message == "Invalid password"
|
||||||
&& completed == vec!["example.type.foo".to_owned()]
|
&& completed == vec![AuthType::ReCaptcha]
|
||||||
&& matches!(
|
&& matches!(
|
||||||
flows.as_slice(),
|
flows.as_slice(),
|
||||||
[f1, f2]
|
[f1, f2]
|
||||||
if f1.stages == vec![
|
if f1.stages == vec![AuthType::Password]
|
||||||
"example.type.foo".to_owned(),
|
&& f2.stages == vec![AuthType::EmailIdentity, AuthType::Msisdn]
|
||||||
"example.type.bar".to_owned()
|
|
||||||
] && f2.stages == vec![
|
|
||||||
"example.type.foo".to_owned(),
|
|
||||||
"example.type.baz".to_owned()
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({
|
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({
|
||||||
"example.type.baz": {
|
"example.type.baz": {
|
||||||
@ -201,7 +197,7 @@ fn deserialize_uiaa_info() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn try_uiaa_response_into_http_response() {
|
fn try_uiaa_response_into_http_response() {
|
||||||
let flows = vec![AuthFlow::new(vec!["m.login.password".into(), "m.login.dummy".into()])];
|
let flows = vec![AuthFlow::new(vec![AuthType::Password, AuthType::Dummy])];
|
||||||
let params = to_raw_json_value(&json!({
|
let params = to_raw_json_value(&json!({
|
||||||
"example.type.baz": {
|
"example.type.baz": {
|
||||||
"example_key": "foobar"
|
"example_key": "foobar"
|
||||||
@ -209,7 +205,7 @@ fn try_uiaa_response_into_http_response() {
|
|||||||
}))
|
}))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let uiaa_info = assign!(UiaaInfo::new(flows, params), {
|
let uiaa_info = assign!(UiaaInfo::new(flows, params), {
|
||||||
completed: vec!["m.login.password".into()],
|
completed: vec![AuthType::ReCaptcha],
|
||||||
});
|
});
|
||||||
let uiaa_response =
|
let uiaa_response =
|
||||||
UiaaResponse::AuthResponse(uiaa_info).try_into_http_response::<Vec<u8>>().unwrap();
|
UiaaResponse::AuthResponse(uiaa_info).try_into_http_response::<Vec<u8>>().unwrap();
|
||||||
@ -225,8 +221,8 @@ fn try_uiaa_response_into_http_response() {
|
|||||||
..
|
..
|
||||||
} if matches!(
|
} if matches!(
|
||||||
flows.as_slice(),
|
flows.as_slice(),
|
||||||
[flow] if flow.stages == vec!["m.login.password".to_owned(), "m.login.dummy".to_owned()]
|
[flow] if flow.stages == vec![AuthType::Password, AuthType::Dummy]
|
||||||
) && completed == vec!["m.login.password".to_owned()]
|
) && completed == vec![AuthType::ReCaptcha]
|
||||||
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({
|
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({
|
||||||
"example.type.baz": {
|
"example.type.baz": {
|
||||||
"example_key": "foobar"
|
"example_key": "foobar"
|
||||||
@ -241,13 +237,13 @@ fn try_uiaa_response_from_http_response() {
|
|||||||
let json = serde_json::to_string(&json!({
|
let json = serde_json::to_string(&json!({
|
||||||
"errcode": "M_FORBIDDEN",
|
"errcode": "M_FORBIDDEN",
|
||||||
"error": "Invalid password",
|
"error": "Invalid password",
|
||||||
"completed": [ "example.type.foo" ],
|
"completed": ["m.login.recaptcha"],
|
||||||
"flows": [
|
"flows": [
|
||||||
{
|
{
|
||||||
"stages": [ "example.type.foo", "example.type.bar" ]
|
"stages": ["m.login.password"]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"stages": [ "example.type.foo", "example.type.baz" ]
|
"stages": ["m.login.email.identity", "m.login.msisdn"]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"params": {
|
"params": {
|
||||||
@ -282,17 +278,12 @@ fn try_uiaa_response_from_http_response() {
|
|||||||
session: Some(session),
|
session: Some(session),
|
||||||
..
|
..
|
||||||
} if error_message == "Invalid password"
|
} if error_message == "Invalid password"
|
||||||
&& completed == vec!["example.type.foo".to_owned()]
|
&& completed == vec![AuthType::ReCaptcha]
|
||||||
&& matches!(
|
&& matches!(
|
||||||
flows.as_slice(),
|
flows.as_slice(),
|
||||||
[f1, f2]
|
[f1, f2]
|
||||||
if f1.stages == vec![
|
if f1.stages == vec![AuthType::Password]
|
||||||
"example.type.foo".to_owned(),
|
&& f2.stages == vec![AuthType::EmailIdentity, AuthType::Msisdn]
|
||||||
"example.type.bar".to_owned()
|
|
||||||
] && f2.stages == vec![
|
|
||||||
"example.type.foo".to_owned(),
|
|
||||||
"example.type.baz".to_owned()
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({
|
&& from_json_str::<JsonValue>(params.get()).unwrap() == json!({
|
||||||
"example.type.baz": {
|
"example.type.baz": {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user