diff --git a/crates/ruma-common/CHANGELOG.md b/crates/ruma-common/CHANGELOG.md index 3e96cea2..87b002a0 100644 --- a/crates/ruma-common/CHANGELOG.md +++ b/crates/ruma-common/CHANGELOG.md @@ -9,6 +9,7 @@ Breaking changes: - Make `in_reply_to` field of `Thread` optional - It was wrong to be mandatory, spec was unclear (clarified [here](https://github.com/matrix-org/matrix-spec/pull/1439)) +- `FlattenedJson::get` returns a `FlattenedJsonValue` instead of a string Improvements: diff --git a/crates/ruma-common/src/push.rs b/crates/ruma-common/src/push.rs index 7cc8bf63..5741ecbc 100644 --- a/crates/ruma-common/src/push.rs +++ b/crates/ruma-common/src/push.rs @@ -38,8 +38,8 @@ pub use self::condition::RoomVersionFeature; pub use self::{ action::{Action, Tweak}, condition::{ - ComparisonOperator, FlattenedJson, PushCondition, PushConditionRoomCtx, RoomMemberCountIs, - _CustomPushCondition, + ComparisonOperator, FlattenedJson, FlattenedJsonValue, PushCondition, PushConditionRoomCtx, + RoomMemberCountIs, _CustomPushCondition, }, iter::{AnyPushRule, AnyPushRuleRef, RulesetIntoIter, RulesetIter}, predefined::{ @@ -291,7 +291,7 @@ impl Ruleset { ) -> Option> { let event = FlattenedJson::from_raw(event); - if event.get("sender").map_or(false, |sender| sender == context.user_id) { + if event.get_str("sender").map_or(false, |sender| sender == context.user_id) { // no need to look at the rules if the event was by the user themselves None } else { @@ -601,7 +601,7 @@ impl PatternedPushRule { event: &FlattenedJson, context: &PushConditionRoomCtx, ) -> bool { - if event.get("sender").map_or(false, |sender| sender == context.user_id) { + if event.get_str("sender").map_or(false, |sender| sender == context.user_id) { return false; } diff --git a/crates/ruma-common/src/push/condition.rs b/crates/ruma-common/src/push/condition.rs index dd78d3fd..00692c88 100644 --- a/crates/ruma-common/src/push/condition.rs +++ b/crates/ruma-common/src/push/condition.rs @@ -17,7 +17,7 @@ mod push_condition_serde; mod room_member_count_is; pub use self::{ - flattened_json::FlattenedJson, + flattened_json::{FlattenedJson, FlattenedJsonValue, ScalarJsonValue}, room_member_count_is::{ComparisonOperator, RoomMemberCountIs}, }; @@ -118,7 +118,7 @@ pub(super) fn check_event_match( ) -> bool { let value = match key { "room_id" => context.room_id.as_str(), - _ => match event.get(key) { + _ => match event.get_str(key) { Some(v) => v, None => return false, }, @@ -135,14 +135,14 @@ impl PushCondition { /// * `event` - The flattened JSON representation of a room message event. /// * `context` - The context of the room at the time of the event. pub fn applies(&self, event: &FlattenedJson, context: &PushConditionRoomCtx) -> bool { - if event.get("sender").map_or(false, |sender| sender == context.user_id) { + if event.get_str("sender").map_or(false, |sender| sender == context.user_id) { return false; } match self { Self::EventMatch { key, pattern } => check_event_match(event, key, pattern, context), Self::ContainsDisplayName => { - let value = match event.get("content.body") { + let value = match event.get_str("content.body") { Some(v) => v, None => return false, }; @@ -151,7 +151,7 @@ impl PushCondition { } Self::RoomMemberCount { is } => is.contains(&context.member_count), Self::SenderNotificationPermission { key } => { - let sender_id = match event.get("sender") { + let sender_id = match event.get_str("sender") { Some(v) => match <&UserId>::try_from(v) { Ok(u) => u, Err(_) => return false, diff --git a/crates/ruma-common/src/push/condition/flattened_json.rs b/crates/ruma-common/src/push/condition/flattened_json.rs index 8b7e599d..53fa674c 100644 --- a/crates/ruma-common/src/push/condition/flattened_json.rs +++ b/crates/ruma-common/src/push/condition/flattened_json.rs @@ -1,5 +1,8 @@ +use js_int::Int; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::{to_value as to_json_value, value::Value as JsonValue}; use std::collections::BTreeMap; +use thiserror::Error; use tracing::{instrument, warn}; use crate::serde::Raw; @@ -8,7 +11,7 @@ use crate::serde::Raw; #[derive(Clone, Debug)] pub struct FlattenedJson { /// The internal map containing the flattened JSON as a pair path, value. - map: BTreeMap, + map: BTreeMap, } impl FlattenedJson { @@ -30,18 +33,24 @@ impl FlattenedJson { self.flatten_value(value, path); } } - JsonValue::String(s) => { - if self.map.insert(path.clone(), s).is_some() { - warn!("Duplicate path in flattened JSON: {path}"); + value => { + if let Some(v) = FlattenedJsonValue::from_json_value(value) { + if self.map.insert(path.clone(), v).is_some() { + warn!("Duplicate path in flattened JSON: {path}"); + } } } - JsonValue::Number(_) | JsonValue::Bool(_) | JsonValue::Array(_) | JsonValue::Null => {} } } - /// Value associated with the given `path`. - pub fn get(&self, path: &str) -> Option<&str> { - self.map.get(path).map(|s| s.as_str()) + /// Get the value associated with the given `path`. + pub fn get(&self, path: &str) -> Option<&FlattenedJsonValue> { + self.map.get(path) + } + + /// Get the value associated with the given `path`, if it is a string. + pub fn get_str(&self, path: &str) -> Option<&str> { + self.map.get(path).and_then(|v| v.as_str()) } } @@ -52,12 +61,257 @@ fn escape_key(key: &str) -> String { key.replace('\\', r"\\").replace('.', r"\.") } +/// The set of possible errors when converting to a JSON subset. +#[derive(Debug, Error)] +#[allow(clippy::exhaustive_enums)] +enum IntoJsonSubsetError { + /// The numeric value failed conversion to js_int::Int. + #[error("number found is not a valid `js_int::Int`")] + IntConvert, + + /// The JSON type is not accepted in this subset. + #[error("JSON type is not accepted in this subset")] + NotInSubset, +} + +/// Scalar (non-compound) JSON values. +#[derive(Debug, Clone, Default, Eq, PartialEq)] +#[allow(clippy::exhaustive_enums)] +pub enum ScalarJsonValue { + /// Represents a `null` value. + #[default] + Null, + + /// Represents a boolean. + Bool(bool), + + /// Represents an integer. + Integer(Int), + + /// Represents a string. + String(String), +} + +impl ScalarJsonValue { + fn try_from_json_value(val: JsonValue) -> Result { + Ok(match val { + JsonValue::Bool(b) => Self::Bool(b), + JsonValue::Number(num) => Self::Integer( + Int::try_from(num.as_i64().ok_or(IntoJsonSubsetError::IntConvert)?) + .map_err(|_| IntoJsonSubsetError::IntConvert)?, + ), + JsonValue::String(string) => Self::String(string), + JsonValue::Null => Self::Null, + _ => Err(IntoJsonSubsetError::NotInSubset)?, + }) + } + + /// If the `ScalarJsonValue` is a `Bool`, return the inner value. + pub fn as_bool(&self) -> Option { + match self { + Self::Bool(b) => Some(*b), + _ => None, + } + } + + /// If the `ScalarJsonValue` is an `Integer`, return the inner value. + pub fn as_integer(&self) -> Option { + match self { + Self::Integer(i) => Some(*i), + _ => None, + } + } + + /// If the `ScalarJsonValue` is a `String`, return a reference to the inner value. + pub fn as_str(&self) -> Option<&str> { + match self { + Self::String(s) => Some(s), + _ => None, + } + } +} + +impl Serialize for ScalarJsonValue { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + Self::Null => serializer.serialize_unit(), + Self::Bool(b) => serializer.serialize_bool(*b), + Self::Integer(n) => n.serialize(serializer), + Self::String(s) => serializer.serialize_str(s), + } + } +} + +impl<'de> Deserialize<'de> for ScalarJsonValue { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let val = JsonValue::deserialize(deserializer)?; + ScalarJsonValue::try_from_json_value(val).map_err(serde::de::Error::custom) + } +} + +impl From for ScalarJsonValue { + fn from(value: bool) -> Self { + Self::Bool(value) + } +} + +impl From for ScalarJsonValue { + fn from(value: Int) -> Self { + Self::Integer(value) + } +} + +impl From for ScalarJsonValue { + fn from(value: String) -> Self { + Self::String(value) + } +} + +impl From<&str> for ScalarJsonValue { + fn from(value: &str) -> Self { + value.to_owned().into() + } +} + +impl PartialEq for ScalarJsonValue { + fn eq(&self, other: &FlattenedJsonValue) -> bool { + match self { + Self::Null => *other == FlattenedJsonValue::Null, + Self::Bool(b) => other.as_bool() == Some(*b), + Self::Integer(i) => other.as_integer() == Some(*i), + Self::String(s) => other.as_str() == Some(s), + } + } +} + +/// Possible JSON values after an object is flattened. +#[derive(Debug, Clone, Default, Eq, PartialEq)] +#[allow(clippy::exhaustive_enums)] +pub enum FlattenedJsonValue { + /// Represents a `null` value. + #[default] + Null, + + /// Represents a boolean. + Bool(bool), + + /// Represents an integer. + Integer(Int), + + /// Represents a string. + String(String), + + /// Represents an array. + Array(Vec), +} + +impl FlattenedJsonValue { + fn from_json_value(val: JsonValue) -> Option { + Some(match val { + JsonValue::Bool(b) => Self::Bool(b), + JsonValue::Number(num) => Self::Integer(Int::try_from(num.as_i64()?).ok()?), + JsonValue::String(string) => Self::String(string), + JsonValue::Null => Self::Null, + JsonValue::Array(vec) => Self::Array( + // Drop values we don't need instead of throwing an error. + vec.into_iter() + .filter_map(|v| ScalarJsonValue::try_from_json_value(v).ok()) + .collect::>(), + ), + _ => None?, + }) + } + + /// If the `FlattenedJsonValue` is a `Bool`, return the inner value. + pub fn as_bool(&self) -> Option { + match self { + Self::Bool(b) => Some(*b), + _ => None, + } + } + + /// If the `FlattenedJsonValue` is an `Integer`, return the inner value. + pub fn as_integer(&self) -> Option { + match self { + Self::Integer(i) => Some(*i), + _ => None, + } + } + + /// If the `FlattenedJsonValue` is a `String`, return a reference to the inner value. + pub fn as_str(&self) -> Option<&str> { + match self { + Self::String(s) => Some(s), + _ => None, + } + } + + /// If the `FlattenedJsonValue` is an `Array`, return a reference to the inner value. + pub fn as_array(&self) -> Option<&[ScalarJsonValue]> { + match self { + Self::Array(a) => Some(a), + _ => None, + } + } +} + +impl From for FlattenedJsonValue { + fn from(value: bool) -> Self { + Self::Bool(value) + } +} + +impl From for FlattenedJsonValue { + fn from(value: Int) -> Self { + Self::Integer(value) + } +} + +impl From for FlattenedJsonValue { + fn from(value: String) -> Self { + Self::String(value) + } +} + +impl From<&str> for FlattenedJsonValue { + fn from(value: &str) -> Self { + value.to_owned().into() + } +} + +impl From> for FlattenedJsonValue { + fn from(value: Vec) -> Self { + Self::Array(value) + } +} + +impl PartialEq for FlattenedJsonValue { + fn eq(&self, other: &ScalarJsonValue) -> bool { + match self { + Self::Null => *other == ScalarJsonValue::Null, + Self::Bool(b) => other.as_bool() == Some(*b), + Self::Integer(i) => other.as_integer() == Some(*i), + Self::String(s) => other.as_str() == Some(s), + Self::Array(_) => false, + } + } +} + #[cfg(test)] mod tests { + use js_int::int; use maplit::btreemap; use serde_json::Value as JsonValue; - use super::FlattenedJson; + use super::{FlattenedJson, FlattenedJsonValue}; use crate::serde::Raw; #[test] @@ -74,7 +328,16 @@ mod tests { .unwrap(); let flattened = FlattenedJson::from_raw(&raw); - assert_eq!(flattened.map, btreemap! { "string".into() => "Hello World".into() }); + assert_eq!( + flattened.map, + btreemap! { + "string".into() => "Hello World".into(), + "number".into() => int!(10).into(), + "array".into() => vec![int!(1).into(), int!(2).into()].into(), + "boolean".into() => true.into(), + "null".into() => FlattenedJsonValue::Null, + } + ); } #[test] @@ -84,11 +347,11 @@ mod tests { "desc": "Level 0", "desc.bis": "Level 0 bis", "up": { - "desc": "Level 1", - "desc.bis": "Level 1 bis", + "desc": 1, + "desc.bis": null, "up": { - "desc": "Level 2", - "desc\\bis": "Level 2 bis" + "desc": ["Level 2a", "Level 2b"], + "desc\\bis": true } } }"#, @@ -101,10 +364,10 @@ mod tests { btreemap! { "desc".into() => "Level 0".into(), r"desc\.bis".into() => "Level 0 bis".into(), - "up.desc".into() => "Level 1".into(), - r"up.desc\.bis".into() => "Level 1 bis".into(), - "up.up.desc".into() => "Level 2".into(), - r"up.up.desc\\bis".into() => "Level 2 bis".into(), + "up.desc".into() => int!(1).into(), + r"up.desc\.bis".into() => FlattenedJsonValue::Null, + "up.up.desc".into() => vec!["Level 2a".into(), "Level 2b".into()].into(), + r"up.up.desc\\bis".into() => true.into(), }, ); } diff --git a/crates/ruma-common/src/push/iter.rs b/crates/ruma-common/src/push/iter.rs index b978a938..f34404d4 100644 --- a/crates/ruma-common/src/push/iter.rs +++ b/crates/ruma-common/src/push/iter.rs @@ -191,7 +191,7 @@ impl<'a> AnyPushRuleRef<'a> { /// * `event` - The flattened JSON representation of a room message event. /// * `context` - The context of the room at the time of the event. pub fn applies(self, event: &FlattenedJson, context: &PushConditionRoomCtx) -> bool { - if event.get("sender").map_or(false, |sender| sender == context.user_id) { + if event.get_str("sender").map_or(false, |sender| sender == context.user_id) { return false; }