diff --git a/crates/ruma-client-api/CHANGELOG.md b/crates/ruma-client-api/CHANGELOG.md index b118d277..e4e09f78 100644 --- a/crates/ruma-client-api/CHANGELOG.md +++ b/crates/ruma-client-api/CHANGELOG.md @@ -2,7 +2,10 @@ Breaking changes: -* Change inconsistent types in `rooms` and `not_rooms` fields in `RoomEventFilter` structure. Both types now use `RoomId`. +* Change inconsistent types in `rooms` and `not_rooms` fields in + `RoomEventFilter` structure: both types now use `RoomId` +* Move `r0::{session::login::UserIdentifier => uiaa::UserIdentifier}` +* Add `stages` parameter to `r0::uiaa::AuthFlow::new` Improvements: diff --git a/crates/ruma-client-api/src/r0/session/login.rs b/crates/ruma-client-api/src/r0/session/login.rs index 5a107af3..b5f40c0c 100644 --- a/crates/ruma-client-api/src/r0/session/login.rs +++ b/crates/ruma-client-api/src/r0/session/login.rs @@ -1,11 +1,12 @@ //! [POST /_matrix/client/r0/login](https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-login) use ruma_api::ruma_api; -use ruma_common::thirdparty::Medium; use ruma_identifiers::{DeviceId, DeviceIdBox, ServerNameBox, UserId}; use ruma_serde::Outgoing; use serde::{Deserialize, Serialize}; +use crate::r0::uiaa::{IncomingUserIdentifier, UserIdentifier}; + ruma_api! { metadata: { description: "Login to the homeserver.", @@ -75,35 +76,6 @@ impl Response { } } -/// Identification information for the user. -#[derive(Clone, Debug, PartialEq, Eq, Outgoing, Serialize)] -#[serde(from = "user_serde::IncomingUserIdentifier", into = "user_serde::UserIdentifier<'_>")] -#[allow(clippy::exhaustive_enums)] -pub enum UserIdentifier<'a> { - /// Either a fully qualified Matrix user ID, or just the localpart (as part of the 'identifier' - /// field). - MatrixId(&'a str), - - /// Third party identifier (as part of the 'identifier' field). - ThirdPartyId { - /// Third party identifier for the user. - address: &'a str, - - /// The medium of the identifier. - medium: Medium, - }, - - /// Same as third-party identification with medium == msisdn, but with a non-canonicalised - /// phone number. - PhoneNumber { - /// The country that the phone number is from. - country: &'a str, - - /// The phone number. - phone: &'a str, - }, -} - /// The authentication mechanism. #[derive(Clone, Debug, PartialEq, Eq, Outgoing, Serialize)] #[serde(tag = "type")] @@ -114,6 +86,7 @@ pub enum LoginInfo<'a> { Password { /// Identification information for the user. identifier: UserIdentifier<'a>, + /// The password. password: &'a str, }, @@ -176,14 +149,13 @@ impl IdentityServerInfo { } } -mod user_serde; - #[cfg(test)] mod tests { use matches::assert_matches; use serde_json::{from_value as from_json_value, json}; - use super::{IncomingLoginInfo, IncomingUserIdentifier}; + use super::IncomingLoginInfo; + use crate::r0::uiaa::IncomingUserIdentifier; #[test] fn deserialize_login_type() { @@ -212,26 +184,15 @@ mod tests { ); } - #[test] - fn deserialize_user() { - assert_matches!( - from_json_value(json!({ - "type": "m.id.user", - "user": "cheeky_monkey" - })) - .unwrap(), - IncomingUserIdentifier::MatrixId(id) - if id == "cheeky_monkey" - ); - } - #[test] #[cfg(feature = "client")] fn serialize_login_request_body() { use ruma_api::{OutgoingRequest, SendAccessToken}; + use ruma_common::thirdparty::Medium; use serde_json::Value as JsonValue; - use super::{LoginInfo, Medium, Request, UserIdentifier}; + use super::{LoginInfo, Request}; + use crate::r0::uiaa::UserIdentifier; let req: http::Request> = Request { login_info: LoginInfo::Token { token: "0xdeadbeef" }, diff --git a/crates/ruma-client-api/src/r0/uiaa.rs b/crates/ruma-client-api/src/r0/uiaa.rs index aff80dd2..83a6d87a 100644 --- a/crates/ruma-client-api/src/r0/uiaa.rs +++ b/crates/ruma-client-api/src/r0/uiaa.rs @@ -2,65 +2,333 @@ //! //! [uiaa]: https://matrix.org/docs/spec/client_server/r0.6.1#user-interactive-authentication-api -use std::{collections::BTreeMap, fmt}; +use std::{borrow::Cow, fmt}; use bytes::BufMut; use ruma_api::{ error::{IntoHttpError, ResponseDeserializationError}, EndpointError, OutgoingResponse, }; +use ruma_common::thirdparty::Medium; +use ruma_identifiers::{ClientSecret, SessionId}; use ruma_serde::Outgoing; -use serde::{Deserialize, Serialize}; -use serde_json::{ - from_slice as from_json_slice, value::RawValue as RawJsonValue, Value as JsonValue, +use serde::{ + de::{self, DeserializeOwned}, + Deserialize, Deserializer, Serialize, }; +use serde_json::{from_slice as from_json_slice, value::RawValue as RawJsonValue}; use crate::error::{Error as MatrixError, ErrorBody}; pub mod authorize_fallback; +mod user_serde; /// Additional authentication information for the user-interactive authentication API. #[derive(Clone, Debug, Outgoing, Serialize)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +#[allow(clippy::manual_non_exhaustive)] +#[incoming_derive(!Deserialize)] #[serde(untagged)] pub enum AuthData<'a> { - /// Used for sending UIAA authentication requests to the homeserver directly from the client. - // Could be made non-exhaustive by creating a separate struct and making auth_parameters - // private, but probably not worth the hassle. - DirectRequest { - /// The login type that the client is attempting to complete. - #[serde(rename = "type")] - kind: &'a str, + /// Password-based authentication (`m.login.password`). + Password(Password<'a>), - /// The value of the session key given by the homeserver. - #[serde(skip_serializing_if = "Option::is_none")] - session: Option<&'a str>, + /// Google ReCaptcha 2.0 authentication (`m.login.recaptcha`). + ReCaptcha(ReCaptcha<'a>), - /// Parameters submitted for a particular authentication stage. - #[serde(flatten)] - auth_parameters: BTreeMap, - }, + /// Token-based authentication (`m.login.token`). + Token(Token<'a>), - /// Used by the client to acknowledge that the user has completed a UIAA stage through the - /// fallback method. - // Exhaustiveness is correct here since this variant is a fallback that explicitly only has a - // single field. TODO: #[serde(deny_unknown_fields)] not supported on enum variants - // https://github.com/serde-rs/serde/issues/1982 - FallbackAcknowledgement { - /// The value of the session key given by the homeserver. - session: &'a str, - }, + /// OAuth2-based authentication (`m.login.oauth2`). + OAuth2(OAuth2<'a>), + + /// Email-based authentication (`m.login.email.identity`). + EmailIdentity(EmailIdentity<'a>), + + /// Phone number-based authentication (`m.login.msisdn`). + Msisdn(Msisdn<'a>), + + /// Dummy authentication (`m.login.dummy`). + Dummy(Dummy<'a>), + + /// Fallback acknowledgement. + FallbackAcknowledgement(FallbackAcknowledgement<'a>), + + #[doc(hidden)] + _Custom, } impl<'a> AuthData<'a> { - /// Creates a new `AuthData::DirectRequest` with the given login type. - pub fn direct_request(kind: &'a str) -> Self { - Self::DirectRequest { kind, session: None, auth_parameters: BTreeMap::new() } - } - /// Creates a new `AuthData::FallbackAcknowledgement` with the given session key. pub fn fallback_acknowledgement(session: &'a str) -> Self { - Self::FallbackAcknowledgement { session } + Self::FallbackAcknowledgement(FallbackAcknowledgement::new(session)) + } +} + +impl<'de> Deserialize<'de> for IncomingAuthData { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + fn from_raw_json_value( + raw: &RawJsonValue, + ) -> Result { + serde_json::from_str(raw.get()).map_err(E::custom) + } + + let json = Box::::deserialize(deserializer)?; + + #[derive(Deserialize)] + struct ExtractType<'a> { + #[serde(borrow, rename = "type")] + auth_type: Option>, + } + + let auth_type = serde_json::from_str::>(json.get()) + .map_err(de::Error::custom)? + .auth_type; + + match auth_type.as_deref() { + Some("m.login.password") => from_raw_json_value(&json).map(Self::Password), + Some("m.login.recaptcha") => from_raw_json_value(&json).map(Self::ReCaptcha), + Some("m.login.token") => from_raw_json_value(&json).map(Self::Token), + Some("m.login.oauth2") => from_raw_json_value(&json).map(Self::OAuth2), + Some("m.login.email.identity") => from_raw_json_value(&json).map(Self::EmailIdentity), + Some("m.login.msisdn") => from_raw_json_value(&json).map(Self::Msisdn), + Some("m.login.dummy") => from_raw_json_value(&json).map(Self::Dummy), + None => from_raw_json_value(&json).map(Self::FallbackAcknowledgement), + Some(_) => Ok(Self::_Custom), + } + } +} + +/// Data for password-based UIAA flow. +/// +/// See [the spec] for how to use this. +/// +/// [the spec]: https://matrix.org/docs/spec/client_server/r0.6.1#password-based +#[derive(Clone, Debug, Outgoing, Serialize)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +#[serde(tag = "type", rename = "m.login.password")] +pub struct Password<'a> { + /// One of the user's identifiers. + pub identifier: UserIdentifier<'a>, + + /// The plaintext password. + pub password: &'a str, + + /// The value of the session key given by the homeserver, if any. + pub session: Option<&'a str>, +} + +impl<'a> Password<'a> { + /// Creates a new `Password` with the given identifier and password. + pub fn new(identifier: UserIdentifier<'a>, password: &'a str) -> Self { + Self { identifier, password, session: None } + } +} + +/// Data for ReCaptcha UIAA flow. +/// +/// See [the spec] for how to use this. +/// +/// [the spec]: https://matrix.org/docs/spec/client_server/r0.6.1#google-recaptcha +#[derive(Clone, Debug, Outgoing, Serialize)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +#[serde(tag = "type", rename = "m.login.recaptcha")] +pub struct ReCaptcha<'a> { + /// The captcha response. + pub response: &'a str, + + /// The value of the session key given by the homeserver, if any. + pub session: Option<&'a str>, +} + +impl<'a> ReCaptcha<'a> { + /// Creates a new `ReCaptcha` with the given response string. + pub fn new(response: &'a str) -> Self { + Self { response, session: None } + } +} + +/// Data for token-based UIAA flow. +/// +/// See [the spec] for how to use this. +/// +/// [the spec]: https://matrix.org/docs/spec/client_server/r0.6.1#token-based +#[derive(Clone, Debug, Outgoing, Serialize)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +#[serde(tag = "type", rename = "m.login.token")] +pub struct Token<'a> { + /// The login token. + pub token: &'a str, + + /// The transaction ID. + pub txn_id: &'a str, + + /// The value of the session key given by the homeserver, if any. + pub session: Option<&'a str>, +} + +impl<'a> Token<'a> { + /// Creates a new `Token` with the given token and transaction ID. + pub fn new(token: &'a str, txn_id: &'a str) -> Self { + Self { token, txn_id, session: None } + } +} + +/// Data for OAuth2-based UIAA flow. +/// +/// See [the spec] for how to use this. +/// +/// [the spec]: https://matrix.org/docs/spec/client_server/r0.6.1#oauth2-based +#[derive(Clone, Debug, Outgoing, Serialize)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +#[serde(tag = "type", rename = "m.login.oauth2")] +pub struct OAuth2<'a> { + /// Authorization Request URI or service selection URI. + pub uri: &'a str, + + /// The value of the session key given by the homeserver, if any. + pub session: Option<&'a str>, +} + +impl<'a> OAuth2<'a> { + /// Creates a new `OAuth2` with the given URI. + pub fn new(uri: &'a str) -> Self { + Self { uri, session: None } + } +} + +/// Data for Email-based UIAA flow. +/// +/// See [the spec] for how to use this. +/// +/// [the spec]: https://matrix.org/docs/spec/client_server/r0.6.1#email-based-identity-homeserver +#[derive(Clone, Debug, Outgoing, Serialize)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +#[serde(tag = "type", rename = "m.login.email.identity")] +pub struct EmailIdentity<'a> { + /// Thirdparty identifier credentials. + #[serde(rename = "threepidCreds")] + pub thirdparty_id_creds: &'a [ThirdpartyIdCredentials<'a>], + + /// The value of the session key given by the homeserver, if any. + pub session: Option<&'a str>, +} + +/// Data for phone number-based UIAA flow. +/// +/// See [the spec] for how to use this. +/// +/// [the spec]: https://matrix.org/docs/spec/client_server/r0.6.1#phone-number-msisdn-based-identity-homeserver +#[derive(Clone, Debug, Outgoing, Serialize)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +#[serde(tag = "type", rename = "m.login.msisdn")] +pub struct Msisdn<'a> { + /// Thirdparty identifier credentials. + #[serde(rename = "threepidCreds")] + pub thirdparty_id_creds: &'a [ThirdpartyIdCredentials<'a>], + + /// The value of the session key given by the homeserver, if any. + pub session: Option<&'a str>, +} + +/// Data for dummy UIAA flow. +/// +/// See [the spec] for how to use this. +/// +/// [the spec]: https://matrix.org/docs/spec/client_server/r0.6.1#password-based +#[derive(Clone, Debug, Default, Outgoing, Serialize)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +#[serde(tag = "type", rename = "m.login.dummy")] +pub struct Dummy<'a> { + /// The value of the session key given by the homeserver, if any. + pub session: Option<&'a str>, +} + +impl Dummy<'_> { + /// Creates an empty `Dummy`. + pub fn new() -> Self { + Self::default() + } +} + +/// Data for UIAA fallback acknowledgement. +/// +/// See [the spec] for how to use this. +/// +/// [the spec]: https://matrix.org/docs/spec/client_server/r0.6.1#fallback +#[derive(Clone, Debug, Outgoing, Serialize)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +pub struct FallbackAcknowledgement<'a> { + /// The value of the session key given by the homeserver. + pub session: &'a str, +} + +impl<'a> FallbackAcknowledgement<'a> { + /// Creates a new `FallbackAcknowledgement` with the given session key. + pub fn new(session: &'a str) -> Self { + Self { session } + } +} + +/// Identification information for the user. +#[derive(Clone, Debug, PartialEq, Eq, Outgoing, Serialize)] +#[serde(from = "user_serde::IncomingUserIdentifier", into = "user_serde::UserIdentifier<'_>")] +#[allow(clippy::exhaustive_enums)] +pub enum UserIdentifier<'a> { + /// Either a fully qualified Matrix user ID, or just the localpart (as part of the 'identifier' + /// field). + MatrixId(&'a str), + + /// Third party identifier (as part of the 'identifier' field). + ThirdPartyId { + /// Third party identifier for the user. + address: &'a str, + + /// The medium of the identifier. + medium: Medium, + }, + + /// Same as third-party identification with medium == msisdn, but with a non-canonicalised + /// phone number. + PhoneNumber { + /// The country that the phone number is from. + country: &'a str, + + /// The phone number. + phone: &'a str, + }, +} + +/// Credentials for thirdparty authentification (e.g. email / phone number). +#[derive(Clone, Debug, Outgoing, Serialize)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +pub struct ThirdpartyIdCredentials<'a> { + /// Identity server session ID. + pub sid: &'a SessionId, + + /// Identity server client secret. + pub client_secret: &'a ClientSecret, + + /// Identity server URL. + pub id_server: &'a str, + + /// Identity server access token. + pub id_access_token: &'a str, +} + +impl<'a> ThirdpartyIdCredentials<'a> { + /// Creates a new `ThirdpartyIdCredentials` with the given session ID, client secret, identity + /// server address and access token. + pub fn new( + sid: &'a SessionId, + client_secret: &'a ClientSecret, + id_server: &'a str, + id_access_token: &'a str, + ) -> Self { + Self { sid, client_secret, id_server, id_access_token } } } @@ -100,7 +368,6 @@ impl UiaaInfo { /// Description of steps required to authenticate via the User-Interactive Authentication API. #[derive(Clone, Debug, Default, Deserialize, Serialize)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] -#[cfg_attr(test, derive(PartialEq))] pub struct AuthFlow { /// Ordered list of stages required to complete authentication. #[serde(default, skip_serializing_if = "Vec::is_empty")] @@ -108,9 +375,11 @@ pub struct AuthFlow { } impl AuthFlow { - /// Creates an empty `AuthFlow`. - pub fn new() -> Self { - Self { stages: Vec::new() } + /// Creates a new `AuthFlow` with the given stages. + /// + /// To create an empty `AuthFlow`, use `AuthFlow::default()`. + pub fn new(stages: Vec) -> Self { + Self { stages } } } @@ -168,271 +437,3 @@ impl OutgoingResponse for UiaaResponse { } } } - -#[cfg(test)] -mod tests { - use maplit::btreemap; - use matches::assert_matches; - use ruma_api::{EndpointError, OutgoingResponse}; - use serde_json::{ - from_slice as from_json_slice, from_str as from_json_str, from_value as from_json_value, - json, to_value as to_json_value, value::to_raw_value as to_raw_json_value, - Value as JsonValue, - }; - - use super::{AuthData, AuthFlow, IncomingAuthData, UiaaInfo, UiaaResponse}; - use crate::error::{ErrorBody, ErrorKind}; - - #[test] - fn serialize_authentication_data_direct_request() { - let authentication_data = AuthData::DirectRequest { - kind: "example.type.foo", - session: Some("ZXY000"), - auth_parameters: btreemap! { - "example_credential".to_owned() => json!("verypoorsharedsecret") - }, - }; - - assert_eq!( - json!({ - "type": "example.type.foo", - "session": "ZXY000", - "example_credential": "verypoorsharedsecret", - }), - to_json_value(authentication_data).unwrap() - ); - } - - #[test] - fn deserialize_authentication_data_direct_request() { - let json = json!({ - "type": "example.type.foo", - "session": "opaque_session_id", - "example_credential": "verypoorsharedsecret", - }); - - assert_matches!( - from_json_value(json).unwrap(), - IncomingAuthData::DirectRequest { kind, session: Some(session), auth_parameters } - if kind == "example.type.foo" - && session == "opaque_session_id" - && auth_parameters == btreemap!{ - "example_credential".to_owned() => json!("verypoorsharedsecret") - } - ); - } - - #[test] - fn serialize_authentication_data_fallback() { - let authentication_data = AuthData::FallbackAcknowledgement { session: "ZXY000" }; - - assert_eq!(json!({ "session": "ZXY000" }), to_json_value(authentication_data).unwrap()); - } - - #[test] - fn deserialize_authentication_data_fallback() { - let json = json!({ "session": "opaque_session_id" }); - - assert_matches!( - from_json_value(json).unwrap(), - IncomingAuthData::FallbackAcknowledgement { session } - if session == "opaque_session_id" - ); - } - - #[test] - fn serialize_uiaa_info() { - let uiaa_info = UiaaInfo { - flows: vec![AuthFlow { - stages: vec!["m.login.password".into(), "m.login.dummy".into()], - }], - completed: vec!["m.login.password".into()], - params: to_raw_json_value(&json!({ - "example.type.baz": { - "example_key": "foobar" - } - })) - .unwrap(), - session: None, - auth_error: None, - }; - - let json = json!({ - "flows": [{ "stages": ["m.login.password", "m.login.dummy"] }], - "completed": ["m.login.password"], - "params": { - "example.type.baz": { - "example_key": "foobar" - } - } - }); - assert_eq!(to_json_value(uiaa_info).unwrap(), json); - } - - #[test] - fn deserialize_uiaa_info() { - let json = json!({ - "errcode": "M_FORBIDDEN", - "error": "Invalid password", - "completed": ["example.type.foo"], - "flows": [ - { - "stages": ["example.type.foo", "example.type.bar"] - }, - { - "stages": ["example.type.foo", "example.type.baz"] - } - ], - "params": { - "example.type.baz": { - "example_key": "foobar" - } - }, - "session": "xxxxxx" - }); - - assert_matches!( - from_json_value::(json).unwrap(), - UiaaInfo { - auth_error: Some(ErrorBody { - kind: ErrorKind::Forbidden, - message: error_message, - }), - completed, - flows, - params, - session: Some(session), - } if error_message == "Invalid password" - && completed == vec!["example.type.foo".to_owned()] - && flows == vec![ - AuthFlow { - stages: vec![ - "example.type.foo".into(), - "example.type.bar".into(), - ], - }, - AuthFlow { - stages: vec![ - "example.type.foo".into(), - "example.type.baz".into(), - ], - }, - ] - && from_json_str::(params.get()).unwrap() == json!({ - "example.type.baz": { - "example_key": "foobar" - } - }) - && session == "xxxxxx" - ); - } - - #[test] - fn try_uiaa_response_into_http_response() { - let uiaa_info = UiaaInfo { - flows: vec![AuthFlow { - stages: vec!["m.login.password".into(), "m.login.dummy".into()], - }], - completed: vec!["m.login.password".into()], - params: to_raw_json_value(&json!({ - "example.type.baz": { - "example_key": "foobar" - } - })) - .unwrap(), - session: None, - auth_error: None, - }; - let uiaa_response = - UiaaResponse::AuthResponse(uiaa_info).try_into_http_response::>().unwrap(); - - assert_matches!( - from_json_slice::(uiaa_response.body()).unwrap(), - UiaaInfo { - flows, - completed, - params, - session: None, - auth_error: None, - } if flows == vec![AuthFlow { - stages: vec!["m.login.password".into(), "m.login.dummy".into()], - }] - && completed == vec!["m.login.password".to_owned()] - && from_json_str::(params.get()).unwrap() == json!({ - "example.type.baz": { - "example_key": "foobar" - } - }) - ); - assert_eq!(uiaa_response.status(), http::status::StatusCode::UNAUTHORIZED); - } - - #[test] - fn try_uiaa_response_from_http_response() { - let json = serde_json::to_string(&json!({ - "errcode": "M_FORBIDDEN", - "error": "Invalid password", - "completed": [ "example.type.foo" ], - "flows": [ - { - "stages": [ "example.type.foo", "example.type.bar" ] - }, - { - "stages": [ "example.type.foo", "example.type.baz" ] - } - ], - "params": { - "example.type.baz": { - "example_key": "foobar" - } - }, - "session": "xxxxxx" - })) - .unwrap(); - - let http_response = http::Response::builder() - .status(http::StatusCode::UNAUTHORIZED) - .body(json.as_bytes()) - .unwrap(); - - let parsed_uiaa_info = match UiaaResponse::try_from_http_response(http_response).unwrap() { - UiaaResponse::AuthResponse(uiaa_info) => uiaa_info, - _ => panic!("Expected UiaaResponse::AuthResponse"), - }; - - assert_matches!( - parsed_uiaa_info, - UiaaInfo { - auth_error: Some(ErrorBody { - kind: ErrorKind::Forbidden, - message: error_message, - }), - completed, - flows, - params, - session: Some(session), - } if error_message == "Invalid password" - && completed == vec!["example.type.foo".to_owned()] - && flows == vec![ - AuthFlow { - stages: vec![ - "example.type.foo".into(), - "example.type.bar".into(), - ], - }, - AuthFlow { - stages: vec![ - "example.type.foo".into(), - "example.type.baz".into(), - ], - }, - ] - && from_json_str::(params.get()).unwrap() == json!({ - "example.type.baz": { - "example_key": "foobar" - } - }) - && session == "xxxxxx" - ); - } -} diff --git a/crates/ruma-client-api/src/r0/session/login/user_serde.rs b/crates/ruma-client-api/src/r0/uiaa/user_serde.rs similarity index 100% rename from crates/ruma-client-api/src/r0/session/login/user_serde.rs rename to crates/ruma-client-api/src/r0/uiaa/user_serde.rs diff --git a/crates/ruma-client-api/tests/uiaa.rs b/crates/ruma-client-api/tests/uiaa.rs new file mode 100644 index 00000000..2d224b51 --- /dev/null +++ b/crates/ruma-client-api/tests/uiaa.rs @@ -0,0 +1,269 @@ +use assign::assign; +use matches::assert_matches; +use ruma_api::{EndpointError, OutgoingResponse}; +use serde_json::{ + from_slice as from_json_slice, from_str as from_json_str, from_value as from_json_value, json, + to_value as to_json_value, value::to_raw_value as to_raw_json_value, Value as JsonValue, +}; + +use ruma_client_api::{ + error::{ErrorBody, ErrorKind}, + r0::uiaa::{ + self, AuthData, AuthFlow, IncomingAuthData, IncomingUserIdentifier, UiaaInfo, UiaaResponse, + }, +}; + +#[test] +fn deserialize_user_identifier() { + assert_matches!( + from_json_value(json!({ + "type": "m.id.user", + "user": "cheeky_monkey" + })) + .unwrap(), + IncomingUserIdentifier::MatrixId(id) + if id == "cheeky_monkey" + ); +} + +#[test] +fn serialize_auth_data_token() { + let auth_data = AuthData::Token( + assign!(uiaa::Token::new("mytoken", "txn123"), { session: Some("session") }), + ); + + assert_matches!( + to_json_value(auth_data), + Ok(val) if val == json!({ + "type": "m.login.token", + "token": "mytoken", + "txn_id": "txn123", + "session": "session", + }) + ); +} + +#[test] +fn deserialize_auth_data_direct_request() { + let json = json!({ + "type": "m.login.token", + "token": "mytoken", + "txn_id": "txn123", + "session": "session", + }); + + assert_matches!( + from_json_value(json), + Ok(IncomingAuthData::Token( + uiaa::IncomingToken { token, txn_id, session: Some(session), .. }, + )) + if token == "mytoken" + && txn_id == "txn123" + && session == "session" + ); +} + +#[test] +fn serialize_auth_data_fallback() { + let auth_data = AuthData::FallbackAcknowledgement(uiaa::FallbackAcknowledgement::new("ZXY000")); + + assert_eq!(json!({ "session": "ZXY000" }), to_json_value(auth_data).unwrap()); +} + +#[test] +fn deserialize_auth_data_fallback() { + let json = json!({ "session": "opaque_session_id" }); + + assert_matches!( + from_json_value(json).unwrap(), + IncomingAuthData::FallbackAcknowledgement( + uiaa::IncomingFallbackAcknowledgement { session, .. }, + ) + if session == "opaque_session_id" + ); +} + +#[test] +fn serialize_uiaa_info() { + let flows = vec![AuthFlow::new(vec!["m.login.password".into(), "m.login.dummy".into()])]; + let params = to_raw_json_value(&json!({ + "example.type.baz": { + "example_key": "foobar" + } + })) + .unwrap(); + let uiaa_info = assign!(UiaaInfo::new(flows, params), { + completed: vec!["m.login.password".into()], + }); + + let json = json!({ + "flows": [{ "stages": ["m.login.password", "m.login.dummy"] }], + "completed": ["m.login.password"], + "params": { + "example.type.baz": { + "example_key": "foobar" + } + } + }); + assert_eq!(to_json_value(uiaa_info).unwrap(), json); +} + +#[test] +fn deserialize_uiaa_info() { + let json = json!({ + "errcode": "M_FORBIDDEN", + "error": "Invalid password", + "completed": ["example.type.foo"], + "flows": [ + { + "stages": ["example.type.foo", "example.type.bar"] + }, + { + "stages": ["example.type.foo", "example.type.baz"] + } + ], + "params": { + "example.type.baz": { + "example_key": "foobar" + } + }, + "session": "xxxxxx" + }); + + assert_matches!( + from_json_value::(json).unwrap(), + UiaaInfo { + auth_error: Some(ErrorBody { + kind: ErrorKind::Forbidden, + message: error_message, + }), + completed, + flows, + params, + session: Some(session), + .. + } if error_message == "Invalid password" + && completed == vec!["example.type.foo".to_owned()] + && matches!( + flows.as_slice(), + [f1, f2] + if f1.stages == vec![ + "example.type.foo".to_owned(), + "example.type.bar".to_owned() + ] && f2.stages == vec![ + "example.type.foo".to_owned(), + "example.type.baz".to_owned() + ] + ) + && from_json_str::(params.get()).unwrap() == json!({ + "example.type.baz": { + "example_key": "foobar" + } + }) + && session == "xxxxxx" + ); +} + +#[test] +fn try_uiaa_response_into_http_response() { + let flows = vec![AuthFlow::new(vec!["m.login.password".into(), "m.login.dummy".into()])]; + let params = to_raw_json_value(&json!({ + "example.type.baz": { + "example_key": "foobar" + } + })) + .unwrap(); + let uiaa_info = assign!(UiaaInfo::new(flows, params), { + completed: vec!["m.login.password".into()], + }); + let uiaa_response = + UiaaResponse::AuthResponse(uiaa_info).try_into_http_response::>().unwrap(); + + assert_matches!( + from_json_slice::(uiaa_response.body()).unwrap(), + UiaaInfo { + flows, + completed, + params, + session: None, + auth_error: None, + .. + } if matches!( + flows.as_slice(), + [flow] if flow.stages == vec!["m.login.password".to_owned(), "m.login.dummy".to_owned()] + ) && completed == vec!["m.login.password".to_owned()] + && from_json_str::(params.get()).unwrap() == json!({ + "example.type.baz": { + "example_key": "foobar" + } + }) + ); + assert_eq!(uiaa_response.status(), http::status::StatusCode::UNAUTHORIZED); +} + +#[test] +fn try_uiaa_response_from_http_response() { + let json = serde_json::to_string(&json!({ + "errcode": "M_FORBIDDEN", + "error": "Invalid password", + "completed": [ "example.type.foo" ], + "flows": [ + { + "stages": [ "example.type.foo", "example.type.bar" ] + }, + { + "stages": [ "example.type.foo", "example.type.baz" ] + } + ], + "params": { + "example.type.baz": { + "example_key": "foobar" + } + }, + "session": "xxxxxx" + })) + .unwrap(); + + let http_response = http::Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .body(json.as_bytes()) + .unwrap(); + + let parsed_uiaa_info = match UiaaResponse::try_from_http_response(http_response).unwrap() { + UiaaResponse::AuthResponse(uiaa_info) => uiaa_info, + _ => panic!("Expected UiaaResponse::AuthResponse"), + }; + + assert_matches!( + parsed_uiaa_info, + UiaaInfo { + auth_error: Some(ErrorBody { + kind: ErrorKind::Forbidden, + message: error_message, + }), + completed, + flows, + params, + session: Some(session), + .. + } if error_message == "Invalid password" + && completed == vec!["example.type.foo".to_owned()] + && matches!( + flows.as_slice(), + [f1, f2] + if f1.stages == vec![ + "example.type.foo".to_owned(), + "example.type.bar".to_owned() + ] && f2.stages == vec![ + "example.type.foo".to_owned(), + "example.type.baz".to_owned() + ] + ) + && from_json_str::(params.get()).unwrap() == json!({ + "example.type.baz": { + "example_key": "foobar" + } + }) + && session == "xxxxxx" + ); +} diff --git a/crates/ruma-client/src/client_api.rs b/crates/ruma-client/src/client_api.rs index 39864789..3dc1edf2 100644 --- a/crates/ruma-client/src/client_api.rs +++ b/crates/ruma-client/src/client_api.rs @@ -5,8 +5,9 @@ use async_stream::try_stream; use futures_core::stream::Stream; use ruma_client_api::r0::{ account::register::{self, RegistrationKind}, - session::login::{self, LoginInfo, UserIdentifier}, + session::login::{self, LoginInfo}, sync::sync_events, + uiaa::UserIdentifier, }; use ruma_common::presence::PresenceState; use ruma_identifiers::DeviceId;