diff --git a/ruma-common/Cargo.toml b/ruma-common/Cargo.toml index 92abf09e..75fe68bb 100644 --- a/ruma-common/Cargo.toml +++ b/ruma-common/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://github.com/ruma/ruma" edition = "2018" [dependencies] +js_int = { version = "0.1.7", features = ["serde"] } matches = "0.1.8" ruma-serde = { version = "0.2.2", path = "../ruma-serde" } serde = { version = "1.0.113", features = ["derive"] } diff --git a/ruma-common/src/push.rs b/ruma-common/src/push.rs index 653d8e25..101b3642 100644 --- a/ruma-common/src/push.rs +++ b/ruma-common/src/push.rs @@ -7,6 +7,9 @@ use std::fmt::{self, Formatter}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::value::RawValue as RawJsonValue; +pub use room_member_count_is::{ComparisonOperator, RoomMemberCountIs}; + +mod room_member_count_is; mod tweak_serde; /// A push ruleset scopes a set of rules according to some criteria. @@ -229,11 +232,8 @@ pub enum PushCondition { /// This matches the current number of members in the room. RoomMemberCount { - /// A decimal integer optionally prefixed by one of `==`, `<`, `>`, `>=` or `<=`. - /// - /// A prefix of `<` matches rooms where the member count is strictly less than the given - /// number and so forth. If no prefix is present, this parameter defaults to `==`. - is: String, + /// The condition on the current number of members in the room. + is: RoomMemberCountIs, }, /// This takes into account the current power levels in the room, ensuring the sender of the @@ -249,10 +249,11 @@ pub enum PushCondition { #[cfg(test)] mod tests { + use js_int::uint; use matches::assert_matches; use serde_json::{from_value as from_json_value, json, to_value as to_json_value}; - use super::{Action, PushCondition, Tweak}; + use super::{Action, PushCondition, RoomMemberCountIs, Tweak}; #[test] fn serialize_string_action() { @@ -349,7 +350,10 @@ mod tests { "kind": "room_member_count" }); assert_eq!( - to_json_value(&PushCondition::RoomMemberCount { is: "2".to_string() }).unwrap(), + to_json_value(&PushCondition::RoomMemberCount { + is: RoomMemberCountIs::from(uint!(2)) + }) + .unwrap(), json_data ); } @@ -398,7 +402,7 @@ mod tests { assert_matches!( from_json_value::(json_data).unwrap(), PushCondition::RoomMemberCount { is } - if is == "2" + if is == RoomMemberCountIs::from(uint!(2)) ); } diff --git a/ruma-common/src/push/room_member_count_is.rs b/ruma-common/src/push/room_member_count_is.rs new file mode 100644 index 00000000..74aa6eef --- /dev/null +++ b/ruma-common/src/push/room_member_count_is.rs @@ -0,0 +1,209 @@ +use std::{ + fmt::{self, Display, Formatter}, + ops::{Bound, RangeBounds, RangeFrom, RangeTo, RangeToInclusive}, + str::FromStr, +}; + +use js_int::UInt; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// One of `==`, `<`, `>`, `>=` or `<=`. +/// +/// Used by `RoomMemberCountIs`. Defaults to `==`. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum ComparisonOperator { + /// Equals + Eq, + /// Less than + Lt, + /// Greater than + Gt, + /// Greater or equal + Ge, + /// Less or equal + Le, +} + +impl Default for ComparisonOperator { + fn default() -> Self { + ComparisonOperator::Eq + } +} + +/// A decimal integer optionally prefixed by one of `==`, `<`, `>`, `>=` or `<=`. +/// +/// A prefix of `<` matches rooms where the member count is strictly less than the given +/// number and so forth. If no prefix is present, this parameter defaults to `==`. +/// +/// Can be constructed from a number or a range: +/// ``` +/// use js_int::uint; +/// use ruma_common::push::RoomMemberCountIs; +/// +/// // equivalent to `is: "3"` or `is: "==3"` +/// let exact = RoomMemberCountIs::from(uint!(3)); +/// +/// // equivalent to `is: ">=3"` +/// let greater_or_equal = RoomMemberCountIs::from(uint!(3)..); +/// +/// // equivalent to `is: "<3"` +/// let less = RoomMemberCountIs::from(..uint!(3)); +/// +/// // equivalent to `is: "<=3"` +/// let less_or_equal = RoomMemberCountIs::from(..=uint!(3)); +/// +/// // An exclusive range can be constructed with `RoomMemberCountIs::gt`: +/// // (equivalent to `is: ">3"`) +/// let greater = RoomMemberCountIs::gt(uint!(3)); +/// ``` +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct RoomMemberCountIs { + /// One of `==`, `<`, `>`, `>=`, `<=`, or no prefix. + pub prefix: ComparisonOperator, + /// The number of people in the room. + pub count: UInt, +} + +impl RoomMemberCountIs { + /// Creates an instance of `RoomMemberCount` equivalent to ` Self { + RoomMemberCountIs { prefix: ComparisonOperator::Gt, count } + } +} + +impl From for RoomMemberCountIs { + fn from(x: UInt) -> Self { + RoomMemberCountIs { prefix: ComparisonOperator::Eq, count: x } + } +} + +impl From> for RoomMemberCountIs { + fn from(x: RangeFrom) -> Self { + RoomMemberCountIs { prefix: ComparisonOperator::Ge, count: x.start } + } +} + +impl From> for RoomMemberCountIs { + fn from(x: RangeTo) -> Self { + RoomMemberCountIs { prefix: ComparisonOperator::Lt, count: x.end } + } +} + +impl From> for RoomMemberCountIs { + fn from(x: RangeToInclusive) -> Self { + RoomMemberCountIs { prefix: ComparisonOperator::Le, count: x.end } + } +} + +impl Display for RoomMemberCountIs { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + use ComparisonOperator::*; + + let prefix = match self.prefix { + Eq => "", + Lt => "<", + Gt => ">", + Ge => ">=", + Le => "<=", + }; + + write!(f, "{}{}", prefix, self.count) + } +} + +impl Serialize for RoomMemberCountIs { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let s = self.to_string(); + s.serialize(serializer) + } +} + +impl FromStr for RoomMemberCountIs { + type Err = js_int::ParseIntError; + + fn from_str(s: &str) -> Result { + use ComparisonOperator::*; + + let (prefix, count_str) = match s { + s if s.starts_with("<=") => (Le, &s[2..]), + s if s.starts_with('<') => (Lt, &s[1..]), + s if s.starts_with(">=") => (Ge, &s[2..]), + s if s.starts_with('>') => (Gt, &s[1..]), + s if s.starts_with("==") => (Eq, &s[2..]), + s => (Eq, s), + }; + + Ok(RoomMemberCountIs { prefix, count: UInt::from_str(count_str)? }) + } +} + +impl<'de> Deserialize<'de> for RoomMemberCountIs { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + FromStr::from_str(&s).map_err(serde::de::Error::custom) + } +} + +impl RangeBounds for RoomMemberCountIs { + fn start_bound(&self) -> Bound<&UInt> { + use ComparisonOperator::*; + + match self.prefix { + Eq => Bound::Included(&self.count), + Lt | Le => Bound::Unbounded, + Gt => Bound::Excluded(&self.count), + Ge => Bound::Included(&self.count), + } + } + + fn end_bound(&self) -> Bound<&UInt> { + use ComparisonOperator::*; + + match self.prefix { + Eq => Bound::Included(&self.count), + Gt | Ge => Bound::Unbounded, + Lt => Bound::Excluded(&self.count), + Le => Bound::Included(&self.count), + } + } +} + +#[cfg(test)] +mod tests { + use std::ops::RangeBounds; + + use js_int::uint; + + use super::RoomMemberCountIs; + + #[test] + fn eq_range_contains_its_own_count() { + let count = 2u32.into(); + let range = RoomMemberCountIs::from(count); + + assert!(range.contains(&count)); + } + + #[test] + fn ge_range_contains_large_number() { + let range = RoomMemberCountIs::from(uint!(2)..); + let large_number = 9001u32.into(); + + assert!(range.contains(&large_number)); + } + + #[test] + fn gt_range_does_not_contain_initial_point() { + let range = RoomMemberCountIs::gt(uint!(2)); + let initial_point = uint!(2); + + assert!(!range.contains(&initial_point)); + } +}