From 3b5f3cb5a6bc26ea25a3182af8af7663ecec12b0 Mon Sep 17 00:00:00 2001 From: Amanda Graven Date: Tue, 28 Sep 2021 15:39:58 +0200 Subject: [PATCH] events: Move JoinRulesEventContent.allow into JoinRules --- crates/ruma-events/src/room/join_rules.rs | 189 +++++++++++++++--- .../ruma-state-res/benches/state_res_bench.rs | 6 +- crates/ruma-state-res/src/lib.rs | 7 +- crates/ruma-state-res/src/test_utils.rs | 6 +- 4 files changed, 171 insertions(+), 37 deletions(-) diff --git a/crates/ruma-events/src/room/join_rules.rs b/crates/ruma-events/src/room/join_rules.rs index 486a6508..e2b45b5f 100644 --- a/crates/ruma-events/src/room/join_rules.rs +++ b/crates/ruma-events/src/room/join_rules.rs @@ -1,17 +1,20 @@ //! Types for the *m.room.join_rules* event. -#[cfg(feature = "unstable-pre-spec")] -use std::collections::BTreeMap; - use ruma_events_macros::EventContent; #[cfg(feature = "unstable-pre-spec")] use ruma_identifiers::RoomId; -use ruma_serde::StringEnum; #[cfg(feature = "unstable-pre-spec")] -use serde::de::{DeserializeOwned, Deserializer, Error}; -use serde::{Deserialize, Serialize}; +use serde::de::DeserializeOwned; +use serde::{ + de::{Deserializer, Error}, + Deserialize, Serialize, +}; +use serde_json::value::RawValue as RawJsonValue; #[cfg(feature = "unstable-pre-spec")] -use serde_json::{value::RawValue as RawJsonValue, Value as JsonValue}; +use serde_json::Value as JsonValue; +use std::borrow::Cow; +#[cfg(feature = "unstable-pre-spec")] +use std::collections::BTreeMap; use crate::StateEvent; @@ -19,36 +22,37 @@ use crate::StateEvent; pub type JoinRulesEvent = StateEvent; /// The payload for `JoinRulesEvent`. -#[derive(Clone, Debug, Deserialize, Serialize, EventContent)] +#[derive(Clone, Debug, Serialize, EventContent)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] #[ruma_event(type = "m.room.join_rules", kind = State)] pub struct JoinRulesEventContent { /// The type of rules used for users wishing to join this room. #[ruma_event(skip_redaction)] + #[serde(flatten)] pub join_rule: JoinRule, - - /// Allow rules used for the `restricted` join rule. - #[cfg(feature = "unstable-pre-spec")] - #[serde(default)] - #[ruma_event(skip_redaction)] - pub allow: Vec, } impl JoinRulesEventContent { /// Creates a new `JoinRulesEventContent` with the given rule. pub fn new(join_rule: JoinRule) -> Self { - Self { - join_rule, - #[cfg(feature = "unstable-pre-spec")] - allow: Vec::new(), - } + Self { join_rule } } /// Creates a new `JoinRulesEventContent` with the restricted rule and the given set of allow /// rules. #[cfg(feature = "unstable-pre-spec")] pub fn restricted(allow: Vec) -> Self { - Self { join_rule: JoinRule::Restricted, allow } + Self { join_rule: JoinRule::Restricted(Restricted::new(allow)) } + } +} + +impl<'de> Deserialize<'de> for JoinRulesEventContent { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let join_rule = JoinRule::deserialize(deserializer)?; + Ok(JoinRulesEventContent { join_rule }) } } @@ -56,36 +60,101 @@ impl JoinRulesEventContent { /// /// This type can hold an arbitrary string. To check for formats that are not available as a /// documented variant here, use its string representation, obtained through `.as_str()`. -#[derive(Clone, Debug, PartialEq, Eq, StringEnum)] -#[ruma_enum(rename_all = "lowercase")] -#[non_exhaustive] +#[derive(Clone, Debug, PartialEq, Eq, Serialize)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +#[serde(tag = "join_rule")] pub enum JoinRule { /// A user who wishes to join the room must first receive an invite to the room from someone /// already inside of the room. + #[serde(rename = "invite")] Invite, /// Reserved but not yet implemented by the Matrix specification. + #[serde(rename = "knock")] Knock, /// Reserved but not yet implemented by the Matrix specification. + #[serde(rename = "private")] Private, /// Users can join the room if they are invited, or if they meet any of the conditions /// described in a set of [`AllowRule`]s. #[cfg(feature = "unstable-pre-spec")] - Restricted, + #[serde(rename = "restricted")] + Restricted(Restricted), /// Anyone can join the room without any prior action. + #[serde(rename = "public")] Public, #[doc(hidden)] + #[serde(skip_serializing)] _Custom(String), } impl JoinRule { - /// Creates a string slice from this `JoinRule`. + /// Returns the string name of this `JoinRule` pub fn as_str(&self) -> &str { - self.as_ref() + match self { + JoinRule::Invite => "invite", + JoinRule::Knock => "knock", + JoinRule::Private => "private", + #[cfg(feature = "unstable-pre-spec")] + JoinRule::Restricted(_) => "restricted", + JoinRule::Public => "public", + JoinRule::_Custom(rule) => rule, + } + } +} + +impl<'de> Deserialize<'de> for JoinRule { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[cfg(feature = "unstable-pre-spec")] + fn from_raw_json_value(raw: &RawJsonValue) -> Result { + serde_json::from_str(raw.get()).map_err(E::custom) + } + + let json: Box = Box::deserialize(deserializer)?; + + #[derive(Deserialize)] + struct ExtractType<'a> { + join_rule: Option>, + } + + let join_rule = serde_json::from_str::>(json.get()) + .map_err(serde::de::Error::custom)? + .join_rule + .ok_or_else(|| D::Error::missing_field("join_rule"))?; + + match join_rule.as_ref() { + "invite" => Ok(Self::Invite), + "knock" => Ok(Self::Knock), + "private" => Ok(Self::Private), + #[cfg(feature = "unstable-pre-spec")] + "restricted" => from_raw_json_value(&json).map(Self::Restricted), + "public" => Ok(Self::Public), + _ => Ok(Self::_Custom(join_rule.into_owned())), + } + } +} + +/// Configuration of the `Restricted` join rule. +#[cfg(feature = "unstable-pre-spec")] +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +pub struct Restricted { + /// Allow rules which describe conditions that allow joining a room. + allow: Vec, +} + +#[cfg(feature = "unstable-pre-spec")] +impl Restricted { + /// Constructs a new rule set for restricted rooms with the given rules. + pub fn new(allow: Vec) -> Self { + Self { allow } } } @@ -103,6 +172,14 @@ pub enum AllowRule { _Custom(CustomAllowRule), } +#[cfg(feature = "unstable-pre-spec")] +impl AllowRule { + /// Constructs an `AllowRule` with membership of the room with the given id as its predicate. + pub fn room_membership(room_id: RoomId) -> Self { + Self::RoomMembership(RoomMembership::new(room_id)) + } +} + /// Allow rule which grants permission to join based on the membership of another room. #[cfg(feature = "unstable-pre-spec")] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -112,6 +189,14 @@ pub struct RoomMembership { pub room_id: RoomId, } +#[cfg(feature = "unstable-pre-spec")] +impl RoomMembership { + /// Constructs a new room membership rule for the given room id. + pub fn new(room_id: RoomId) -> Self { + Self { room_id } + } +} + #[cfg(feature = "unstable-pre-spec")] #[doc(hidden)] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -137,12 +222,13 @@ impl<'de> Deserialize<'de> for AllowRule { // Extracts the `type` value. #[derive(Deserialize)] - struct ExtractType { - rule_type: Option, + struct ExtractType<'a> { + #[serde(borrow, rename = "type")] + rule_type: Option>, } // Get the value of `type` if present. - let rule_type = serde_json::from_str::(json.get()) + let rule_type = serde_json::from_str::>(json.get()) .map_err(serde::de::Error::custom)? .rule_type; @@ -153,3 +239,48 @@ impl<'de> Deserialize<'de> for AllowRule { } } } + +#[cfg(test)] +mod tests { + #[cfg(feature = "unstable-pre-spec")] + use super::AllowRule; + use super::{JoinRule, JoinRulesEventContent}; + #[cfg(feature = "unstable-pre-spec")] + use ruma_identifiers::room_id; + + #[test] + fn deserialize() { + let json = r#"{"join_rule": "public"}"#; + let event: JoinRulesEventContent = serde_json::from_str(json).unwrap(); + assert!(matches!(event, JoinRulesEventContent { join_rule: JoinRule::Public })); + } + + #[cfg(feature = "unstable-pre-spec")] + #[test] + fn deserialize_unstable() { + let json = r#"{ + "join_rule": "restricted", + "allow": [ + { + "type": "m.room_membership", + "room_id": "!mods:example.org" + }, + { + "type": "m.room_membership", + "room_id": "!users:example.org" + } + ] + }"#; + let event: JoinRulesEventContent = serde_json::from_str(json).unwrap(); + match event.join_rule { + JoinRule::Restricted(restricted) => assert_eq!( + restricted.allow, + &[ + AllowRule::room_membership(room_id!("!mods:example.org")), + AllowRule::room_membership(room_id!("!users:example.org")) + ] + ), + rule => panic!("Deserialized to wrong variant: {:?}", rule), + } + } +} diff --git a/crates/ruma-state-res/benches/state_res_bench.rs b/crates/ruma-state-res/benches/state_res_bench.rs index 841a617e..625a5d37 100644 --- a/crates/ruma-state-res/benches/state_res_bench.rs +++ b/crates/ruma-state-res/benches/state_res_bench.rs @@ -24,7 +24,7 @@ use ruma_common::MilliSecondsSinceUnixEpoch; use ruma_events::{ pdu::{EventHash, Pdu, RoomV3Pdu}, room::{ - join_rules::JoinRule, + join_rules::{JoinRule, JoinRulesEventContent}, member::{MemberEventContent, MembershipState}, }, EventType, @@ -264,7 +264,7 @@ impl TestStore { alice(), EventType::RoomJoinRules, Some(""), - to_raw_json_value(&json!({ "join_rule": JoinRule::Public })).unwrap(), + to_raw_json_value(&JoinRulesEventContent::new(JoinRule::Public)).unwrap(), &[cre.clone(), alice_mem.event_id().clone()], &[alice_mem.event_id().clone()], ); @@ -441,7 +441,7 @@ fn INITIAL_EVENTS() -> HashMap> { alice(), EventType::RoomJoinRules, Some(""), - to_raw_json_value(&json!({ "join_rule": JoinRule::Public })).unwrap(), + to_raw_json_value(&JoinRulesEventContent::new(JoinRule::Public)).unwrap(), &["CREATE", "IMA", "IPOWER"], &["IPOWER"], ), diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 69b45961..036719b0 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -623,7 +623,10 @@ mod tests { use maplit::{hashmap, hashset}; use rand::seq::SliceRandom; use ruma_common::MilliSecondsSinceUnixEpoch; - use ruma_events::{room::join_rules::JoinRule, EventType}; + use ruma_events::{ + room::join_rules::{JoinRule, JoinRulesEventContent}, + EventType, + }; use ruma_identifiers::{EventId, RoomVersionId}; use serde_json::{json, value::to_raw_value as to_raw_json_value}; use tracing::debug; @@ -873,7 +876,7 @@ mod tests { alice(), EventType::RoomJoinRules, Some(""), - to_raw_json_value(&json!({ "join_rule": JoinRule::Private })).unwrap(), + to_raw_json_value(&JoinRulesEventContent::new(JoinRule::Private)).unwrap(), ), to_init_pdu_event( "ME", diff --git a/crates/ruma-state-res/src/test_utils.rs b/crates/ruma-state-res/src/test_utils.rs index 0be8489f..fcb33f95 100644 --- a/crates/ruma-state-res/src/test_utils.rs +++ b/crates/ruma-state-res/src/test_utils.rs @@ -12,7 +12,7 @@ use ruma_common::MilliSecondsSinceUnixEpoch; use ruma_events::{ pdu::{EventHash, Pdu, RoomV3Pdu}, room::{ - join_rules::JoinRule, + join_rules::{JoinRule, JoinRulesEventContent}, member::{MemberEventContent, MembershipState}, }, EventType, @@ -269,7 +269,7 @@ impl TestStore { alice(), EventType::RoomJoinRules, Some(""), - to_raw_json_value(&json!({ "join_rule": JoinRule::Public })).unwrap(), + to_raw_json_value(&JoinRulesEventContent::new(JoinRule::Public)).unwrap(), &[cre.clone(), alice_mem.event_id().clone()], &[alice_mem.event_id().clone()], ); @@ -481,7 +481,7 @@ pub fn INITIAL_EVENTS() -> HashMap> { alice(), EventType::RoomJoinRules, Some(""), - to_raw_json_value(&json!({ "join_rule": JoinRule::Public })).unwrap(), + to_raw_json_value(&JoinRulesEventContent::new(JoinRule::Public)).unwrap(), &["CREATE", "IMA", "IPOWER"], &["IPOWER"], ),