diff --git a/crates/ruma-state-res/.github/workflows/nightly.yml b/crates/ruma-state-res/.github/workflows/nightly.yml new file mode 100644 index 00000000..9c567b20 --- /dev/null +++ b/crates/ruma-state-res/.github/workflows/nightly.yml @@ -0,0 +1,30 @@ +name: Rust Nightly + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + check: + name: Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + profile: minimal + override: true + components: rustfmt, clippy + - name: Check formatting + uses: actions-rs/cargo@v1 + with: + command: fmt + args: -- --check + - name: Catch common mistakes + uses: actions-rs/cargo@v1 + with: + command: clippy + args: --all-features --all-targets -- -D warnings diff --git a/crates/ruma-state-res/.github/workflows/stable.yml b/crates/ruma-state-res/.github/workflows/stable.yml new file mode 100644 index 00000000..2ebfed57 --- /dev/null +++ b/crates/ruma-state-res/.github/workflows/stable.yml @@ -0,0 +1,28 @@ +name: Rust Stable + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + check: + name: Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + toolchain: stable + profile: minimal + override: true + - name: Run tests + uses: actions-rs/cargo@v1 + with: + command: test + - name: Run tests (unstable-pre-spec) + uses: actions-rs/cargo@v1 + with: + command: test + args: --features unstable-pre-spec diff --git a/crates/ruma-state-res/.gitignore b/crates/ruma-state-res/.gitignore new file mode 100644 index 00000000..96ef6c0b --- /dev/null +++ b/crates/ruma-state-res/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/crates/ruma-state-res/Cargo.toml b/crates/ruma-state-res/Cargo.toml new file mode 100644 index 00000000..79deb406 --- /dev/null +++ b/crates/ruma-state-res/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "state-res" +version = "0.1.0" +authors = ["Devin R "] +edition = "2018" +categories = ["api-bindings", "web-programming"] +description = "An abstraction for Matrix state resolution." +homepage = "https://www.ruma.io/" +keywords = ["matrix", "chat", "state resolution", "ruma"] +license = "MIT" +readme = "README.md" +repository = "https://github.com/ruma/state-res" + +[dependencies] +itertools = "0.10.0" +serde = { version = "1.0.118", features = ["derive"] } +serde_json = "1.0.60" +maplit = "1.0.2" +thiserror = "1.0.22" +log = "0.4.11" + +[features] +unstable-pre-spec = ["ruma/unstable-pre-spec"] + +[dependencies.ruma] +git = "https://github.com/ruma/ruma" +rev = "8c286e78d41770fe431e7304cc2fe23e383793df" +features = ["events", "signatures"] + +[dev-dependencies] +criterion = "0.3.3" +rand = "0.7.3" +tracing-subscriber = "0.2.15" + +[[bench]] +name = "state_res_bench" +harness = false diff --git a/crates/ruma-state-res/LICENSE b/crates/ruma-state-res/LICENSE new file mode 100644 index 00000000..09ad525f --- /dev/null +++ b/crates/ruma-state-res/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2020 Devin Ragotzy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/crates/ruma-state-res/README.md b/crates/ruma-state-res/README.md new file mode 100644 index 00000000..466bee4d --- /dev/null +++ b/crates/ruma-state-res/README.md @@ -0,0 +1,46 @@ +# Matrix State Resolution in Rust! + +```rust +/// Abstraction of a PDU so users can have their own PDU types. +pub trait Event { + /// The `EventId` of this event. + fn event_id(&self) -> &EventId; + /// The `RoomId` of this event. + fn room_id(&self) -> &RoomId; + /// The `UserId` of this event. + fn sender(&self) -> &UserId; + // and so on... +} + +/// A mapping of event type and state_key to some value `T`, usually an `EventId`. +pub type StateMap = BTreeMap<(EventType, Option), T>; + +/// A mapping of `EventId` to `T`, usually a `StateEvent`. +pub type EventMap = BTreeMap; + +struct StateResolution { + // For now the StateResolution struct is empty. If "caching" `event_map` + // between `resolve` calls ends up being more efficient (probably not, as this would eat memory) + // it may have an `event_map` field. The `event_map` is all the events + // `StateResolution` has to know about to resolve state. +} + +impl StateResolution { + /// The point of this all, resolve the possibly conflicting sets of events. + pub fn resolve( + room_id: &RoomId, + room_version: &RoomVersionId, + state_sets: &[StateMap], + auth_events: Vec>, + event_map: &mut EventMap>, + ) -> Result> {; + +} + +``` + + + +The `StateStore` trait is an abstraction around what ever database your server (or maybe even client) uses to store __P__[]()ersistant __D__[]()ata __U__[]()nits. + +We use `ruma`s types when deserializing any PDU or it's contents which helps avoid a lot of type checking logic [synapse](https://github.com/matrix-org/synapse) must do while authenticating event chains. diff --git a/crates/ruma-state-res/architecture.md b/crates/ruma-state-res/architecture.md new file mode 100644 index 00000000..78280b09 --- /dev/null +++ b/crates/ruma-state-res/architecture.md @@ -0,0 +1,70 @@ +# Architecture + +This document describes the high-level architecture of state-res. +If you want to familiarize yourself with the code base, you are just in the right place! + +## Overview + +The state-res crate provides all the necessary algorithms to resolve the state of a +room according to the Matrix spec. Given sets of state and the complete authorization +chain, a final resolved state is calculated. + +The state sets (`BTreeMap<(EventType, StateKey), EventId>`) can be the state of a room +according to different servers or at different points in time. The authorization chain +is the recursive set of all events that authorize events that come after. +Any event that can be referenced needs to be available in the `event_map` argument, +or the call fails. The `StateResolution` struct keeps no state and is only a +collection of associated functions. + +## Important Terms + + - **event** In state-res this refers to a **P**ersistent **D**ata **U**nit which + represents the event and keeps metadata used for resolution + - **state resolution** The process of calculating the final state of a DAG from + conflicting input DAGs + +## Code Map + +This section talks briefly about important files and data structures. + +### `error` + +An enum representing all possible error cases in state-res. Most of the variants are +passing information of failures from other libraries except `Error::NotFound`. +The `NotFound` variant is used when an event was not in the `event_map`. + +### `event_auth` + +This module contains all the logic needed to authenticate and verify events. +The main function for authentication is `auth_check`. There are a few checks +that happen to every event and specific checks for some state events. +Each event is authenticated against the state before the event. +The state is built iteratively with each successive event being checked against +the current state then added. + +**Note:** Any type of event can be check, not just state events. + +### `state_event` + +A trait called `Event` that allows the state-res library to take any PDU type the user +supplies. The main `StateResolution::resolve` function can resolve any user-defined +type that satisfies `Event`. This avoids a lot of unnecessary conversions and +gives more flexibility to users. + +### `lib` + +All the associated functions of `StateResolution` that are needed to resolve state live +here. The focus is `StateResolution::resolve`, given a DAG and new events +`resolve` calculates the end state of the DAG. Everything that is used by `resolve` +is exported giving users access to the pieces of the algorithm. + +**Note:** only state events (events that have a state_key field) are allowed to +participate in resolution. + +## Testing + +state-res has three main test types: event sorting, event authentication, and state +resolution. State resolution tests the whole system. Start by setting up a room with +events and check the resolved state after adding conflicting events. +Event authentication checks that an event passes or fails based on some initial state. +Event sorting tests that given a DAG of events, the events can be predictably sorted. diff --git a/crates/ruma-state-res/benches/event_auth_bench.rs b/crates/ruma-state-res/benches/event_auth_bench.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/crates/ruma-state-res/benches/event_auth_bench.rs @@ -0,0 +1 @@ + diff --git a/crates/ruma-state-res/benches/outcomes.txt b/crates/ruma-state-res/benches/outcomes.txt new file mode 100644 index 00000000..7696c06b --- /dev/null +++ b/crates/ruma-state-res/benches/outcomes.txt @@ -0,0 +1,59 @@ +11/29/2020 BRANCH: timo-spec-comp REV: d2a85669cc6056679ce6ca0fde4658a879ad2b08 +lexicographical topological sort + time: [1.7123 us 1.7157 us 1.7199 us] + change: [-1.7584% -1.5433% -1.3205%] (p = 0.00 < 0.05) + Performance has improved. +Found 8 outliers among 100 measurements (8.00%) + 2 (2.00%) low mild + 5 (5.00%) high mild + 1 (1.00%) high severe + +resolve state of 5 events one fork + time: [10.981 us 10.998 us 11.020 us] +Found 3 outliers among 100 measurements (3.00%) + 3 (3.00%) high mild + +resolve state of 10 events 3 conflicting + time: [26.858 us 26.946 us 27.037 us] + +11/29/2020 BRANCH: event-trait REV: f0eb1310efd49d722979f57f20bd1ac3592b0479 +lexicographical topological sort + time: [1.7686 us 1.7738 us 1.7810 us] + change: [-3.2752% -2.4634% -1.7635%] (p = 0.00 < 0.05) + Performance has improved. +Found 1 outliers among 100 measurements (1.00%) + 1 (1.00%) high severe + +resolve state of 5 events one fork + time: [10.643 us 10.656 us 10.669 us] + change: [-4.9990% -3.8078% -2.8319%] (p = 0.00 < 0.05) + Performance has improved. +Found 1 outliers among 100 measurements (1.00%) + 1 (1.00%) high severe + +resolve state of 10 events 3 conflicting + time: [29.149 us 29.252 us 29.375 us] + change: [-0.8433% -0.3270% +0.2656%] (p = 0.25 > 0.05) + No change in performance detected. +Found 1 outliers among 100 measurements (1.00%) + 1 (1.00%) high mild + +4/26/2020 BRANCH: fix-test-serde REV: +lexicographical topological sort + time: [1.6793 us 1.6823 us 1.6857 us] +Found 9 outliers among 100 measurements (9.00%) + 1 (1.00%) low mild + 4 (4.00%) high mild + 4 (4.00%) high severe + +resolve state of 5 events one fork + time: [9.9993 us 10.062 us 10.159 us] +Found 9 outliers among 100 measurements (9.00%) + 7 (7.00%) high mild + 2 (2.00%) high severe + +resolve state of 10 events 3 conflicting + time: [26.004 us 26.092 us 26.195 us] +Found 16 outliers among 100 measurements (16.00%) + 11 (11.00%) high mild + 5 (5.00%) high severe \ No newline at end of file diff --git a/crates/ruma-state-res/benches/state_res_bench.rs b/crates/ruma-state-res/benches/state_res_bench.rs new file mode 100644 index 00000000..1937d65b --- /dev/null +++ b/crates/ruma-state-res/benches/state_res_bench.rs @@ -0,0 +1,812 @@ +// Because of criterion `cargo bench` works, +// but if you use `cargo bench -- --save-baseline ` +// or pass any other args to it, it fails with the error +// `cargo bench unknown option --save-baseline`. +// To pass args to criterion, use this form +// `cargo bench --bench -- --save-baseline `. +use std::{ + collections::{BTreeMap, BTreeSet}, + convert::TryFrom, + sync::Arc, + time::{Duration, UNIX_EPOCH}, +}; + +use criterion::{criterion_group, criterion_main, Criterion}; +use event::StateEvent; +use maplit::btreemap; +use ruma::{ + events::{ + pdu::{EventHash, Pdu, RoomV3Pdu}, + room::{ + join_rules::JoinRule, + member::{MemberEventContent, MembershipState}, + }, + EventType, + }, + EventId, RoomId, RoomVersionId, UserId, +}; +use serde_json::{json, Value as JsonValue}; +use state_res::{Error, Event, Result, StateMap, StateResolution}; + +static mut SERVER_TIMESTAMP: u64 = 0; + +fn lexico_topo_sort(c: &mut Criterion) { + c.bench_function("lexicographical topological sort", |b| { + let graph = btreemap! { + event_id("l") => vec![event_id("o")], + event_id("m") => vec![event_id("n"), event_id("o")], + event_id("n") => vec![event_id("o")], + event_id("o") => vec![], // "o" has zero outgoing edges but 4 incoming edges + event_id("p") => vec![event_id("o")], + }; + b.iter(|| { + let _ = StateResolution::lexicographical_topological_sort(&graph, |id| { + (0, UNIX_EPOCH, id.clone()) + }); + }) + }); +} + +fn resolution_shallow_auth_chain(c: &mut Criterion) { + c.bench_function("resolve state of 5 events one fork", |b| { + let mut store = TestStore(btreemap! {}); + + // build up the DAG + let (state_at_bob, state_at_charlie, _) = store.set_up(); + + b.iter(|| { + let mut ev_map: state_res::EventMap> = store.0.clone(); + let state_sets = vec![state_at_bob.clone(), state_at_charlie.clone()]; + let _ = match StateResolution::resolve::( + &room_id(), + &RoomVersionId::Version2, + &state_sets, + state_sets + .iter() + .map(|map| { + store + .auth_event_ids(&room_id(), &map.values().cloned().collect::>()) + .unwrap() + }) + .collect(), + &mut ev_map, + ) { + Ok(state) => state, + Err(e) => panic!("{}", e), + }; + }) + }); +} + +fn resolve_deeper_event_set(c: &mut Criterion) { + c.bench_function("resolve state of 10 events 3 conflicting", |b| { + let init = INITIAL_EVENTS(); + let ban = BAN_STATE_SET(); + + let mut inner = init; + inner.extend(ban); + let store = TestStore(inner.clone()); + + let state_set_a = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("MB")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| ((ev.kind(), ev.state_key().unwrap()), ev.event_id().clone())) + .collect::>(); + + let state_set_b = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("IME")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| ((ev.kind(), ev.state_key().unwrap()), ev.event_id().clone())) + .collect::>(); + + b.iter(|| { + let state_sets = vec![state_set_a.clone(), state_set_b.clone()]; + let _ = match StateResolution::resolve::( + &room_id(), + &RoomVersionId::Version2, + &state_sets, + state_sets + .iter() + .map(|map| { + store + .auth_event_ids(&room_id(), &map.values().cloned().collect::>()) + .unwrap() + }) + .collect(), + &mut inner, + ) { + Ok(state) => state, + Err(_) => panic!("resolution failed during benchmarking"), + }; + }) + }); +} + +criterion_group!( + benches, + lexico_topo_sort, + resolution_shallow_auth_chain, + resolve_deeper_event_set +); + +criterion_main!(benches); + +//*///////////////////////////////////////////////////////////////////// +// +// IMPLEMENTATION DETAILS AHEAD +// +/////////////////////////////////////////////////////////////////////*/ +pub struct TestStore(pub BTreeMap>); + +#[allow(unused)] +impl TestStore { + pub fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + self.0 + .get(event_id) + .map(Arc::clone) + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) + } + + /// Returns the events that correspond to the `event_ids` sorted in the same order. + pub fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result>> { + let mut events = vec![]; + for id in event_ids { + events.push(self.get_event(room_id, id)?); + } + Ok(events) + } + + /// Returns a Vec of the related auth events to the given `event`. + pub fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result> { + let mut result = vec![]; + let mut stack = event_ids.to_vec(); + + // DFS for auth event chain + while !stack.is_empty() { + let ev_id = stack.pop().unwrap(); + if result.contains(&ev_id) { + continue; + } + + result.push(ev_id.clone()); + + let event = self.get_event(room_id, &ev_id)?; + + stack.extend(event.auth_events().clone()); + } + + Ok(result) + } + + /// Returns a Vec representing the difference in auth chains of the given `events`. + pub fn auth_chain_diff( + &self, + room_id: &RoomId, + event_ids: Vec>, + ) -> Result> { + let mut chains = vec![]; + for ids in event_ids { + // TODO state store `auth_event_ids` returns self in the event ids list + // when an event returns `auth_event_ids` self is not contained + let chain = self + .auth_event_ids(room_id, &ids)? + .into_iter() + .collect::>(); + chains.push(chain); + } + + if let Some(chain) = chains.first() { + let rest = chains.iter().skip(1).flatten().cloned().collect(); + let common = chain.intersection(&rest).collect::>(); + + Ok(chains + .iter() + .flatten() + .filter(|id| !common.contains(id)) + .cloned() + .collect::>() + .into_iter() + .collect()) + } else { + Ok(vec![]) + } + } +} + +impl TestStore { + pub fn set_up(&mut self) -> (StateMap, StateMap, StateMap) { + let create_event = to_pdu_event::( + "CREATE", + alice(), + EventType::RoomCreate, + Some(""), + json!({ "creator": alice() }), + &[], + &[], + ); + let cre = create_event.event_id().clone(); + self.0.insert(cre.clone(), Arc::clone(&create_event)); + + let alice_mem = to_pdu_event( + "IMA", + alice(), + EventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + &[cre.clone()], + &[cre.clone()], + ); + self.0 + .insert(alice_mem.event_id().clone(), Arc::clone(&alice_mem)); + + let join_rules = to_pdu_event( + "IJR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({ "join_rule": JoinRule::Public }), + &[cre.clone(), alice_mem.event_id().clone()], + &[alice_mem.event_id().clone()], + ); + self.0 + .insert(join_rules.event_id().clone(), join_rules.clone()); + + // Bob and Charlie join at the same time, so there is a fork + // this will be represented in the state_sets when we resolve + let bob_mem = to_pdu_event( + "IMB", + bob(), + EventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_join(), + &[cre.clone(), join_rules.event_id().clone()], + &[join_rules.event_id().clone()], + ); + self.0.insert(bob_mem.event_id().clone(), bob_mem.clone()); + + let charlie_mem = to_pdu_event( + "IMC", + charlie(), + EventType::RoomMember, + Some(charlie().to_string().as_str()), + member_content_join(), + &[cre, join_rules.event_id().clone()], + &[join_rules.event_id().clone()], + ); + self.0 + .insert(charlie_mem.event_id().clone(), charlie_mem.clone()); + + let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] + .iter() + .map(|e| ((e.kind(), e.state_key().unwrap()), e.event_id().clone())) + .collect::>(); + + let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] + .iter() + .map(|e| ((e.kind(), e.state_key().unwrap()), e.event_id().clone())) + .collect::>(); + + let expected = [ + &create_event, + &alice_mem, + &join_rules, + &bob_mem, + &charlie_mem, + ] + .iter() + .map(|e| ((e.kind(), e.state_key().unwrap()), e.event_id().clone())) + .collect::>(); + + (state_at_bob, state_at_charlie, expected) + } +} + +fn event_id(id: &str) -> EventId { + if id.contains('$') { + return EventId::try_from(id).unwrap(); + } + EventId::try_from(format!("${}:foo", id)).unwrap() +} + +fn alice() -> UserId { + UserId::try_from("@alice:foo").unwrap() +} +fn bob() -> UserId { + UserId::try_from("@bob:foo").unwrap() +} +fn charlie() -> UserId { + UserId::try_from("@charlie:foo").unwrap() +} +fn ella() -> UserId { + UserId::try_from("@ella:foo").unwrap() +} + +fn room_id() -> RoomId { + RoomId::try_from("!test:foo").unwrap() +} + +fn member_content_ban() -> JsonValue { + serde_json::to_value(MemberEventContent { + membership: MembershipState::Ban, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + }) + .unwrap() +} + +fn member_content_join() -> JsonValue { + serde_json::to_value(MemberEventContent { + membership: MembershipState::Join, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + }) + .unwrap() +} + +pub fn to_pdu_event( + id: &str, + sender: UserId, + ev_type: EventType, + state_key: Option<&str>, + content: JsonValue, + auth_events: &[S], + prev_events: &[S], +) -> Arc +where + S: AsRef, +{ + let ts = unsafe { + let ts = SERVER_TIMESTAMP; + // increment the "origin_server_ts" value + SERVER_TIMESTAMP += 1; + ts + }; + let id = if id.contains('$') { + id.to_string() + } else { + format!("${}:foo", id) + }; + let auth_events = auth_events + .iter() + .map(AsRef::as_ref) + .map(event_id) + .collect::>(); + let prev_events = prev_events + .iter() + .map(AsRef::as_ref) + .map(event_id) + .collect::>(); + + let state_key = state_key.map(ToString::to_string); + Arc::new(StateEvent { + event_id: EventId::try_from(id).unwrap(), + rest: Pdu::RoomV3Pdu(RoomV3Pdu { + room_id: room_id(), + sender, + origin_server_ts: UNIX_EPOCH + Duration::from_secs(ts), + state_key, + kind: ev_type, + content, + redacts: None, + unsigned: btreemap! {}, + #[cfg(not(feature = "unstable-pre-spec"))] + origin: "foo".into(), + auth_events, + prev_events, + depth: ruma::uint!(0), + hashes: EventHash { sha256: "".into() }, + signatures: btreemap! {}, + }), + }) +} + +// all graphs start with these input events +#[allow(non_snake_case)] +fn INITIAL_EVENTS() -> BTreeMap> { + vec![ + to_pdu_event::( + "CREATE", + alice(), + EventType::RoomCreate, + Some(""), + json!({ "creator": alice() }), + &[], + &[], + ), + to_pdu_event( + "IMA", + alice(), + EventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + &["CREATE"], + &["CREATE"], + ), + to_pdu_event( + "IPOWER", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice().to_string(): 100}}), + &["CREATE", "IMA"], + &["IMA"], + ), + to_pdu_event( + "IJR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({ "join_rule": JoinRule::Public }), + &["CREATE", "IMA", "IPOWER"], + &["IPOWER"], + ), + to_pdu_event( + "IMB", + bob(), + EventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IJR"], + ), + to_pdu_event( + "IMC", + charlie(), + EventType::RoomMember, + Some(charlie().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IMB"], + ), + to_pdu_event::( + "START", + charlie(), + EventType::RoomTopic, + Some(""), + json!({}), + &[], + &[], + ), + to_pdu_event::( + "END", + charlie(), + EventType::RoomTopic, + Some(""), + json!({}), + &[], + &[], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().clone(), ev)) + .collect() +} + +// all graphs start with these input events +#[allow(non_snake_case)] +fn BAN_STATE_SET() -> BTreeMap> { + vec![ + to_pdu_event( + "PA", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + &["CREATE", "IMA", "IPOWER"], // auth_events + &["START"], // prev_events + ), + to_pdu_event( + "PB", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + &["CREATE", "IMA", "IPOWER"], + &["END"], + ), + to_pdu_event( + "MB", + alice(), + EventType::RoomMember, + Some(ella().as_str()), + member_content_ban(), + &["CREATE", "IMA", "PB"], + &["PA"], + ), + to_pdu_event( + "IME", + ella(), + EventType::RoomMember, + Some(ella().as_str()), + member_content_join(), + &["CREATE", "IJR", "PA"], + &["MB"], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().clone(), ev)) + .collect() +} + +pub mod event { + use std::{collections::BTreeMap, time::SystemTime}; + + use ruma::{ + events::{ + pdu::{EventHash, Pdu}, + room::member::MembershipState, + EventType, + }, + EventId, RoomId, RoomVersionId, ServerName, ServerSigningKeyId, UInt, UserId, + }; + use serde::{Deserialize, Serialize}; + use serde_json::Value as JsonValue; + + use state_res::Event; + + impl Event for StateEvent { + fn event_id(&self) -> &EventId { + self.event_id() + } + + fn room_id(&self) -> &RoomId { + self.room_id() + } + + fn sender(&self) -> &UserId { + self.sender() + } + + fn kind(&self) -> EventType { + self.kind() + } + + fn content(&self) -> serde_json::Value { + self.content() + } + + fn origin_server_ts(&self) -> SystemTime { + *self.origin_server_ts() + } + + fn state_key(&self) -> Option { + self.state_key() + } + + fn prev_events(&self) -> Vec { + self.prev_event_ids() + } + + fn depth(&self) -> &UInt { + self.depth() + } + + fn auth_events(&self) -> Vec { + self.auth_events() + } + + fn redacts(&self) -> Option<&EventId> { + self.redacts() + } + + fn hashes(&self) -> &EventHash { + self.hashes() + } + + fn signatures(&self) -> BTreeMap, BTreeMap> { + self.signatures() + } + + fn unsigned(&self) -> &BTreeMap { + self.unsigned() + } + } + + #[derive(Clone, Debug, Deserialize, Serialize)] + pub struct StateEvent { + pub event_id: EventId, + #[serde(flatten)] + pub rest: Pdu, + } + + impl StateEvent { + pub fn from_id_value( + id: EventId, + json: serde_json::Value, + ) -> Result { + Ok(Self { + event_id: id, + rest: Pdu::RoomV3Pdu(serde_json::from_value(json)?), + }) + } + + pub fn from_id_canon_obj( + id: EventId, + json: ruma::serde::CanonicalJsonObject, + ) -> Result { + Ok(Self { + event_id: id, + // TODO: this is unfortunate (from_value(to_value(json)))... + rest: Pdu::RoomV3Pdu(serde_json::from_value(serde_json::to_value(json)?)?), + }) + } + + pub fn is_power_event(&self) -> bool { + match &self.rest { + Pdu::RoomV1Pdu(event) => match event.kind { + EventType::RoomPowerLevels + | EventType::RoomJoinRules + | EventType::RoomCreate => event.state_key == Some("".into()), + EventType::RoomMember => { + // TODO fix clone + if let Ok(membership) = serde_json::from_value::( + event.content["membership"].clone(), + ) { + [MembershipState::Leave, MembershipState::Ban].contains(&membership) + && event.sender.as_str() + // TODO is None here a failure + != event.state_key.as_deref().unwrap_or("NOT A STATE KEY") + } else { + false + } + } + _ => false, + }, + Pdu::RoomV3Pdu(event) => event.state_key == Some("".into()), + } + } + + pub fn deserialize_content( + &self, + ) -> Result { + match &self.rest { + Pdu::RoomV1Pdu(ev) => serde_json::from_value(ev.content.clone()), + Pdu::RoomV3Pdu(ev) => serde_json::from_value(ev.content.clone()), + } + } + + pub fn origin_server_ts(&self) -> &SystemTime { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.origin_server_ts, + Pdu::RoomV3Pdu(ev) => &ev.origin_server_ts, + } + } + + pub fn event_id(&self) -> &EventId { + &self.event_id + } + + pub fn sender(&self) -> &UserId { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.sender, + Pdu::RoomV3Pdu(ev) => &ev.sender, + } + } + + pub fn redacts(&self) -> Option<&EventId> { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.redacts.as_ref(), + Pdu::RoomV3Pdu(ev) => ev.redacts.as_ref(), + } + } + + pub fn room_id(&self) -> &RoomId { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.room_id, + Pdu::RoomV3Pdu(ev) => &ev.room_id, + } + } + pub fn kind(&self) -> EventType { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.kind.clone(), + Pdu::RoomV3Pdu(ev) => ev.kind.clone(), + } + } + pub fn state_key(&self) -> Option { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.state_key.clone(), + Pdu::RoomV3Pdu(ev) => ev.state_key.clone(), + } + } + + #[cfg(not(feature = "unstable-pre-spec"))] + pub fn origin(&self) -> String { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.origin.clone(), + Pdu::RoomV3Pdu(ev) => ev.origin.clone(), + } + } + + pub fn prev_event_ids(&self) -> Vec { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.prev_events.iter().map(|(id, _)| id).cloned().collect(), + Pdu::RoomV3Pdu(ev) => ev.prev_events.clone(), + } + } + + pub fn auth_events(&self) -> Vec { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.auth_events.iter().map(|(id, _)| id).cloned().collect(), + Pdu::RoomV3Pdu(ev) => ev.auth_events.to_vec(), + } + } + + pub fn content(&self) -> serde_json::Value { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.content.clone(), + Pdu::RoomV3Pdu(ev) => ev.content.clone(), + } + } + + pub fn unsigned(&self) -> &BTreeMap { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.unsigned, + Pdu::RoomV3Pdu(ev) => &ev.unsigned, + } + } + + pub fn signatures( + &self, + ) -> BTreeMap, BTreeMap> { + match &self.rest { + Pdu::RoomV1Pdu(_) => maplit::btreemap! {}, + Pdu::RoomV3Pdu(ev) => ev.signatures.clone(), + } + } + + pub fn hashes(&self) -> &EventHash { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.hashes, + Pdu::RoomV3Pdu(ev) => &ev.hashes, + } + } + + pub fn depth(&self) -> &UInt { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.depth, + Pdu::RoomV3Pdu(ev) => &ev.depth, + } + } + + pub fn is_type_and_key(&self, ev_type: EventType, state_key: &str) -> bool { + match &self.rest { + Pdu::RoomV1Pdu(ev) => { + ev.kind == ev_type && ev.state_key.as_deref() == Some(state_key) + } + Pdu::RoomV3Pdu(ev) => { + ev.kind == ev_type && ev.state_key.as_deref() == Some(state_key) + } + } + } + + /// Returns the room version this event is formatted for. + /// + /// Currently either version 1 or 6 is returned, 6 represents + /// version 3 and above. + pub fn room_version(&self) -> RoomVersionId { + // TODO: We have to know the actual room version this is not sufficient + match self.rest { + Pdu::RoomV1Pdu(_) => RoomVersionId::Version1, + Pdu::RoomV3Pdu(_) => RoomVersionId::Version6, + } + } + } +} diff --git a/crates/ruma-state-res/rustfmt.toml b/crates/ruma-state-res/rustfmt.toml new file mode 100644 index 00000000..d3db3454 --- /dev/null +++ b/crates/ruma-state-res/rustfmt.toml @@ -0,0 +1 @@ +imports_granularity="Crate" \ No newline at end of file diff --git a/crates/ruma-state-res/src/error.rs b/crates/ruma-state-res/src/error.rs new file mode 100644 index 00000000..1f27df89 --- /dev/null +++ b/crates/ruma-state-res/src/error.rs @@ -0,0 +1,35 @@ +use serde_json::Error as JsonError; +use thiserror::Error; + +/// Result type for state resolution. +pub type Result = std::result::Result; + +/// Represents the various errors that arise when resolving state. +#[derive(Error, Debug)] +pub enum Error { + /// A deserialization error. + #[error(transparent)] + SerdeJson(#[from] JsonError), + + /// The given option or version is unsupported. + #[error("Unsupported room version: {0}")] + Unsupported(String), + + /// The given event was not found. + #[error("Not found error: {0}")] + NotFound(String), + + /// Invalid fields in the given PDU. + #[error("Invalid PDU: {0}")] + InvalidPdu(String), + + /// A custom error. + #[error("{0}")] + Custom(Box), +} + +impl Error { + pub fn custom(e: E) -> Self { + Self::Custom(Box::new(e)) + } +} diff --git a/crates/ruma-state-res/src/event_auth.rs b/crates/ruma-state-res/src/event_auth.rs new file mode 100644 index 00000000..bad99240 --- /dev/null +++ b/crates/ruma-state-res/src/event_auth.rs @@ -0,0 +1,885 @@ +use std::{convert::TryFrom, sync::Arc}; + +use log::warn; +use maplit::btreeset; +use ruma::{ + events::{ + room::{ + create::CreateEventContent, + join_rules::{JoinRule, JoinRulesEventContent}, + member::{MembershipState, ThirdPartyInvite}, + power_levels::PowerLevelsEventContent, + }, + EventType, + }, + RoomVersionId, UserId, +}; + +use crate::{room_version::RoomVersion, Error, Event, Result, StateMap}; + +/// For the given event `kind` what are the relevant auth events +/// that are needed to authenticate this `content`. +pub fn auth_types_for_event( + kind: &EventType, + sender: &UserId, + state_key: Option, + content: serde_json::Value, +) -> Vec<(EventType, String)> { + if kind == &EventType::RoomCreate { + return vec![]; + } + + let mut auth_types = vec![ + (EventType::RoomPowerLevels, "".to_string()), + (EventType::RoomMember, sender.to_string()), + (EventType::RoomCreate, "".to_string()), + ]; + + if kind == &EventType::RoomMember { + if let Some(state_key) = state_key { + if let Some(Ok(membership)) = content + .get("membership") + .map(|m| serde_json::from_value::(m.clone())) + { + if [MembershipState::Join, MembershipState::Invite].contains(&membership) { + let key = (EventType::RoomJoinRules, "".to_string()); + if !auth_types.contains(&key) { + auth_types.push(key) + } + } + + let key = (EventType::RoomMember, state_key); + if !auth_types.contains(&key) { + auth_types.push(key) + } + + if membership == MembershipState::Invite { + if let Some(Ok(t_id)) = content + .get("third_party_invite") + .map(|t| serde_json::from_value::(t.clone())) + { + let key = (EventType::RoomThirdPartyInvite, t_id.signed.token); + if !auth_types.contains(&key) { + auth_types.push(key) + } + } + } + } + } + } + + auth_types +} + +/// Authenticate the incoming `event`. The steps of authentication are: +/// * check that the event is being authenticated for the correct room +/// * check that the events signatures are valid +/// * then there are checks for specific event types +/// +/// The `auth_events` that are passed to this function should be a state snapshot. +/// We need to know if the event passes auth against some state not a recursive collection +/// of auth_events fields. +/// +/// ## Returns +/// This returns an `Error` only when serialization fails or some other fatal outcome. +pub fn auth_check( + room_version: &RoomVersion, + incoming_event: &Arc, + prev_event: Option>, + auth_events: &StateMap>, + current_third_party_invite: Option>, +) -> Result { + log::info!( + "auth_check beginning for {} ({})", + incoming_event.event_id(), + incoming_event.kind() + ); + + // [synapse] check that all the events are in the same room as `incoming_event` + + // [synapse] do_sig_check check the event has valid signatures for member events + + // TODO do_size_check is false when called by `iterative_auth_check` + // do_size_check is also mostly accomplished by ruma with the exception of checking event_type, + // state_key, and json are below a certain size (255 and 65_536 respectively) + + // Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules + // + // 1. If type is m.room.create: + if incoming_event.kind() == EventType::RoomCreate { + log::info!("start m.room.create check"); + + // If it has any previous events, reject + if !incoming_event.prev_events().is_empty() { + log::warn!("the room creation event had previous events"); + return Ok(false); + } + + // If the domain of the room_id does not match the domain of the sender, reject + if incoming_event.room_id().server_name() != incoming_event.sender().server_name() { + log::warn!("creation events server does not match sender"); + return Ok(false); // creation events room id does not match senders + } + + // If content.room_version is present and is not a recognized version, reject + if serde_json::from_value::( + incoming_event + .content() + .get("room_version") + .cloned() + // TODO synapse defaults to version 1 + .unwrap_or_else(|| serde_json::json!("1")), + ) + .is_err() + { + log::warn!("invalid room version found in m.room.create event"); + return Ok(false); + } + + // If content has no creator field, reject + if incoming_event.content().get("creator").is_none() { + log::warn!("no creator field found in room create content"); + return Ok(false); + } + + log::info!("m.room.create event was allowed"); + return Ok(true); + } + + /* + // 2. Reject if auth_events + // a. auth_events cannot have duplicate keys since it's a BTree + // b. All entries are valid auth events according to spec + let expected_auth = auth_types_for_event( + incoming_event.kind, + incoming_event.sender(), + incoming_event.state_key, + incoming_event.content().clone(), + ); + + dbg!(&expected_auth); + + for ev_key in auth_events.keys() { + // (b) + if !expected_auth.contains(ev_key) { + log::warn!("auth_events contained invalid auth event"); + return Ok(false); + } + } + */ + + // 3. If event does not have m.room.create in auth_events reject + if auth_events + .get(&(EventType::RoomCreate, "".to_string())) + .is_none() + { + log::warn!("no m.room.create event in auth chain"); + + return Ok(false); + } + + // [synapse] checks for federation here + + // 4. if type is m.room.aliases + if incoming_event.kind() == EventType::RoomAliases && room_version.special_case_aliases_auth { + log::info!("starting m.room.aliases check"); + + // If sender's domain doesn't matches state_key, reject + if incoming_event.state_key() != Some(incoming_event.sender().server_name().to_string()) { + log::warn!("state_key does not match sender"); + return Ok(false); + } + + log::info!("m.room.aliases event was allowed"); + return Ok(true); + } + + if incoming_event.kind() == EventType::RoomMember { + log::info!("starting m.room.member check"); + let state_key = match incoming_event.state_key() { + None => { + log::warn!("no statekey in member event"); + return Ok(false); + } + Some(s) => s, + }; + + let membership = incoming_event + .content() + .get("membership") + .map(|m| serde_json::from_value::(m.clone())); + + if !matches!(membership, Some(Ok(_))) { + log::warn!("no valid membership field found for m.room.member event content"); + return Ok(false); + } + + if !valid_membership_change( + &state_key, + incoming_event.sender(), + incoming_event.content(), + prev_event, + current_third_party_invite, + auth_events, + )? { + return Ok(false); + } + + log::info!("m.room.member event was allowed"); + return Ok(true); + } + + // If the sender's current membership state is not join, reject + match check_event_sender_in_room(incoming_event.sender(), auth_events) { + Some(true) => {} // sender in room + Some(false) => { + log::warn!("sender's membership is not join"); + return Ok(false); + } + None => { + log::warn!("sender not found in room"); + return Ok(false); + } + } + + // Allow if and only if sender's current power level is greater than + // or equal to the invite level + if incoming_event.kind() == EventType::RoomThirdPartyInvite + && !can_send_invite(incoming_event, auth_events)? + { + log::warn!("sender's cannot send invites in this room"); + return Ok(false); + } + + // If the event type's required power level is greater than the sender's power level, reject + // If the event has a state_key that starts with an @ and does not match the sender, reject. + if !can_send_event(incoming_event, auth_events) { + log::warn!("user cannot send event"); + return Ok(false); + } + + if incoming_event.kind() == EventType::RoomPowerLevels { + log::info!("starting m.room.power_levels check"); + + if let Some(required_pwr_lvl) = + check_power_levels(room_version, incoming_event, auth_events) + { + if !required_pwr_lvl { + log::warn!("power level was not allowed"); + return Ok(false); + } + } else { + log::warn!("power level was not allowed"); + return Ok(false); + } + log::info!("power levels event allowed"); + } + + // Room version 3: Redaction events are always accepted (provided the event is allowed by `events` and + // `events_default` in the power levels). However, servers should not apply or send redaction's + // to clients until both the redaction event and original event have been seen, and are valid. + // Servers should only apply redaction's to events where the sender's domains match, + // or the sender of the redaction has the appropriate permissions per the power levels. + + if room_version.extra_redaction_checks + && incoming_event.kind() == EventType::RoomRedaction + && !check_redaction(room_version, incoming_event, auth_events)? + { + return Ok(false); + } + + log::info!("allowing event passed all checks"); + Ok(true) +} + +// TODO deserializing the member, power, join_rules event contents is done in conduit +// just before this is called. Could they be passed in? +/// Does the user who sent this member event have required power levels to do so. +/// +/// * `user` - Information about the membership event and user making the request. +/// * `prev_event` - The event that occurred immediately before the `user` event or None. +/// * `auth_events` - The set of auth events that relate to a membership event. +/// this is generated by calling `auth_types_for_event` with the membership event and +/// the current State. +pub fn valid_membership_change( + state_key: &str, + user_sender: &UserId, + content: serde_json::Value, + prev_event: Option>, + current_third_party_invite: Option>, + auth_events: &StateMap>, +) -> Result { + let target_membership = serde_json::from_value::( + content + .get("membership") + .expect("we should test before that this field exists") + .clone(), + )?; + + let third_party_invite = content + .get("third_party_invite") + .map(|t| serde_json::from_value::(t.clone())); + + let target_user_id = + UserId::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{}", e)))?; + + let key = (EventType::RoomMember, user_sender.to_string()); + let sender = auth_events.get(&key); + let sender_membership = sender.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| { + Ok(serde_json::from_value::( + pdu.content() + .get("membership") + .expect("we assume existing events are valid") + .clone(), + )?) + })?; + + let key = (EventType::RoomMember, target_user_id.to_string()); + let current = auth_events.get(&key); + + let current_membership = current.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| { + Ok(serde_json::from_value::( + pdu.content() + .get("membership") + .expect("we assume existing events are valid") + .clone(), + )?) + })?; + + let key = (EventType::RoomPowerLevels, "".into()); + let power_levels = auth_events.get(&key).map_or_else( + || Ok::<_, Error>(PowerLevelsEventContent::default()), + |power_levels| { + serde_json::from_value::(power_levels.content()) + .map_err(Into::into) + }, + )?; + + let sender_power = power_levels.users.get(user_sender).map_or_else( + || { + if sender_membership != MembershipState::Join { + None + } else { + Some(&power_levels.users_default) + } + }, + // If it's okay, wrap with Some(_) + Some, + ); + let target_power = power_levels.users.get(&target_user_id).map_or_else( + || { + if target_membership != MembershipState::Join { + None + } else { + Some(&power_levels.users_default) + } + }, + // If it's okay, wrap with Some(_) + Some, + ); + + let key = (EventType::RoomJoinRules, "".into()); + let join_rules_event = auth_events.get(&key); + let mut join_rules = JoinRule::Invite; + if let Some(jr) = join_rules_event { + join_rules = serde_json::from_value::(jr.content())?.join_rule; + } + + if let Some(prev) = prev_event { + if prev.kind() == EventType::RoomCreate && prev.prev_events().is_empty() { + return Ok(true); + } + } + + Ok(if target_membership == MembershipState::Join { + if user_sender != &target_user_id { + warn!("Can't make other user join"); + false + } else if let MembershipState::Ban = current_membership { + warn!("Banned user can't join"); + false + } else { + let allow = join_rules == JoinRule::Invite + && (current_membership == MembershipState::Join + || current_membership == MembershipState::Invite) + || join_rules == JoinRule::Public; + + if !allow { + warn!("Can't join if join rules is not public and user is not invited/joined"); + } + allow + } + } else if target_membership == MembershipState::Invite { + // If content has third_party_invite key + if let Some(Ok(tp_id)) = third_party_invite { + if current_membership == MembershipState::Ban { + warn!("Can't invite banned user"); + false + } else { + let allow = verify_third_party_invite( + Some(state_key), + user_sender, + &tp_id, + current_third_party_invite, + ); + if !allow { + warn!("Third party invite invalid"); + } + allow + } + } else if sender_membership != MembershipState::Join + || current_membership == MembershipState::Join + || current_membership == MembershipState::Ban + { + warn!( + "Can't invite user if sender not joined or the user is currently joined or banned" + ); + false + } else { + let allow = sender_power + .filter(|&p| p >= &power_levels.invite) + .is_some(); + if !allow { + warn!("User does not have enough power to invite"); + } + allow + } + } else if target_membership == MembershipState::Leave { + if user_sender == &target_user_id { + let allow = current_membership == MembershipState::Join + || current_membership == MembershipState::Invite; + if !allow { + warn!("Can't leave if not invited or joined"); + } + allow + } else if sender_membership != MembershipState::Join + || current_membership == MembershipState::Ban + && sender_power.filter(|&p| p < &power_levels.ban).is_some() + { + warn!("Can't kick if sender not joined or user is already banned"); + false + } else { + let allow = sender_power.filter(|&p| p >= &power_levels.kick).is_some() + && target_power < sender_power; + if !allow { + warn!("User does not have enough power to kick"); + } + allow + } + } else if target_membership == MembershipState::Ban { + if sender_membership != MembershipState::Join { + warn!("Can't ban user if sender is not joined"); + false + } else { + let allow = sender_power.filter(|&p| p >= &power_levels.ban).is_some() + && target_power < sender_power; + if !allow { + warn!("User does not have enough power to ban"); + } + allow + } + } else { + warn!("Unknown membership transition"); + false + }) +} + +/// Is the event's sender in the room that they sent the event to. +pub fn check_event_sender_in_room( + sender: &UserId, + auth_events: &StateMap>, +) -> Option { + let mem = auth_events.get(&(EventType::RoomMember, sender.to_string()))?; + + let membership = serde_json::from_value::( + mem.content() + .get("membership") + .expect("we should test before that this field exists") + .clone(), + ) + .ok()?; + + Some(membership == MembershipState::Join) +} + +/// Is the user allowed to send a specific event based on the rooms power levels. Does the event +/// have the correct userId as it's state_key if it's not the "" state_key. +pub fn can_send_event(event: &Arc, auth_events: &StateMap>) -> bool { + let ple = auth_events.get(&(EventType::RoomPowerLevels, "".into())); + + let event_type_power_level = get_send_level(&event.kind(), event.state_key(), ple); + let user_level = get_user_power_level(event.sender(), auth_events); + + log::debug!( + "{} ev_type {} usr {}", + event.event_id().as_str(), + event_type_power_level, + user_level + ); + + if user_level < event_type_power_level { + return false; + } + + if event + .state_key() + .as_ref() + .map_or(false, |k| k.starts_with('@')) + && event.state_key().as_deref() != Some(event.sender().as_str()) + { + return false; // permission required to post in this room + } + + true +} + +/// Confirm that the event sender has the required power levels. +pub fn check_power_levels( + room_version: &RoomVersion, + power_event: &Arc, + auth_events: &StateMap>, +) -> Option { + let power_event_state_key = power_event + .state_key() + .expect("power events have state keys"); + let key = (power_event.kind(), power_event_state_key); + let current_state = if let Some(current_state) = auth_events.get(&key) { + current_state + } else { + // If there is no previous m.room.power_levels event in the room, allow + return Some(true); + }; + + // If users key in content is not a dictionary with keys that are valid user IDs + // with values that are integers (or a string that is an integer), reject. + let user_content = + serde_json::from_value::(power_event.content()).unwrap(); + + let current_content = + serde_json::from_value::(current_state.content()).unwrap(); + + // validation of users is done in Ruma, synapse for loops validating user_ids and integers here + log::info!("validation of power event finished"); + + let user_level = get_user_power_level(power_event.sender(), auth_events); + + let mut user_levels_to_check = btreeset![]; + let old_list = ¤t_content.users; + let user_list = &user_content.users; + for user in old_list.keys().chain(user_list.keys()) { + let user: &UserId = user; + user_levels_to_check.insert(user); + } + + log::debug!("users to check {:?}", user_levels_to_check); + + let mut event_levels_to_check = btreeset![]; + let old_list = ¤t_content.events; + let new_list = &user_content.events; + for ev_id in old_list.keys().chain(new_list.keys()) { + let ev_id: &EventType = ev_id; + event_levels_to_check.insert(ev_id); + } + + log::debug!("events to check {:?}", event_levels_to_check); + + let old_state = ¤t_content; + let new_state = &user_content; + + // synapse does not have to split up these checks since we can't combine UserIds and + // EventTypes we do 2 loops + + // UserId loop + for user in user_levels_to_check { + let old_level = old_state.users.get(user); + let new_level = new_state.users.get(user); + if old_level.is_some() && new_level.is_some() && old_level == new_level { + continue; + } + + // If the current value is equal to the sender's current power level, reject + if user != power_event.sender() && old_level.map(|int| (*int).into()) == Some(user_level) { + log::warn!("m.room.power_level cannot remove ops == to own"); + return Some(false); // cannot remove ops level == to own + } + + // If the current value is higher than the sender's current power level, reject + // If the new value is higher than the sender's current power level, reject + let old_level_too_big = old_level.map(|int| (*int).into()) > Some(user_level); + let new_level_too_big = new_level.map(|int| (*int).into()) > Some(user_level); + if old_level_too_big || new_level_too_big { + log::warn!("m.room.power_level failed to add ops > than own"); + return Some(false); // cannot add ops greater than own + } + } + + // EventType loop + for ev_type in event_levels_to_check { + let old_level = old_state.events.get(ev_type); + let new_level = new_state.events.get(ev_type); + if old_level.is_some() && new_level.is_some() && old_level == new_level { + continue; + } + + // If the current value is higher than the sender's current power level, reject + // If the new value is higher than the sender's current power level, reject + let old_level_too_big = old_level.map(|int| (*int).into()) > Some(user_level); + let new_level_too_big = new_level.map(|int| (*int).into()) > Some(user_level); + if old_level_too_big || new_level_too_big { + log::warn!("m.room.power_level failed to add ops > than own"); + return Some(false); // cannot add ops greater than own + } + } + + // Notifications, currently there is only @room + if room_version.limit_notifications_power_levels { + let old_level = old_state.notifications.room; + let new_level = new_state.notifications.room; + if old_level != new_level { + // If the current value is higher than the sender's current power level, reject + // If the new value is higher than the sender's current power level, reject + let old_level_too_big = i64::from(old_level) > user_level; + let new_level_too_big = i64::from(new_level) > user_level; + if old_level_too_big || new_level_too_big { + log::warn!("m.room.power_level failed to add ops > than own"); + return Some(false); // cannot add ops greater than own + } + } + } + + let levels = [ + "users_default", + "events_default", + "state_default", + "ban", + "redact", + "kick", + "invite", + ]; + let old_state = serde_json::to_value(old_state).unwrap(); + let new_state = serde_json::to_value(new_state).unwrap(); + for lvl_name in &levels { + if let Some((old_lvl, new_lvl)) = get_deserialize_levels(&old_state, &new_state, lvl_name) { + let old_level_too_big = old_lvl > user_level; + let new_level_too_big = new_lvl > user_level; + + if old_level_too_big || new_level_too_big { + log::warn!("cannot add ops > than own"); + return Some(false); + } + } + } + + Some(true) +} + +fn get_deserialize_levels( + old: &serde_json::Value, + new: &serde_json::Value, + name: &str, +) -> Option<(i64, i64)> { + Some(( + serde_json::from_value(old.get(name)?.clone()).ok()?, + serde_json::from_value(new.get(name)?.clone()).ok()?, + )) +} + +/// Does the event redacting come from a user with enough power to redact the given event. +pub fn check_redaction( + _room_version: &RoomVersion, + redaction_event: &Arc, + auth_events: &StateMap>, +) -> Result { + let user_level = get_user_power_level(redaction_event.sender(), auth_events); + let redact_level = get_named_level(auth_events, "redact", 50); + + if user_level >= redact_level { + log::info!("redaction allowed via power levels"); + return Ok(true); + } + + // If the domain of the event_id of the event being redacted is the same as the + // domain of the event_id of the m.room.redaction, allow + if redaction_event.event_id().server_name() + == redaction_event + .redacts() + .as_ref() + .and_then(|id| id.server_name()) + { + log::info!("redaction event allowed via room version 1 rules"); + return Ok(true); + } + + Ok(false) +} + +/// Check that the member event matches `state`. +/// +/// This function returns false instead of failing when deserialization fails. +pub fn check_membership(member_event: Option>, state: MembershipState) -> bool { + if let Some(event) = member_event { + if let Some(Ok(membership)) = event + .content() + .get("membership") + .map(|m| serde_json::from_value::(m.clone())) + { + membership == state + } else { + false + } + } else { + false + } +} + +/// Can this room federate based on its m.room.create event. +pub fn can_federate(auth_events: &StateMap>) -> bool { + let creation_event = auth_events.get(&(EventType::RoomCreate, "".into())); + if let Some(ev) = creation_event { + if let Some(fed) = ev.content().get("m.federate") { + fed == "true" + } else { + false + } + } else { + false + } +} + +/// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content. +/// or return `default` if no power level event is found or zero if no field matches `name`. +pub fn get_named_level(auth_events: &StateMap>, name: &str, default: i64) -> i64 { + let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, "".into())); + if let Some(pl) = power_level_event { + // TODO do this the right way and deserialize + if let Some(level) = pl.content().get(name) { + level.to_string().parse().unwrap_or(default) + } else { + 0 + } + } else { + default + } +} + +/// Helper function to fetch a users default power level from a "m.room.power_level" event's `users` +/// object. +pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap>) -> i64 { + if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, "".into())) { + if let Ok(content) = serde_json::from_value::(pl.content()) { + if let Some(level) = content.users.get(user_id) { + (*level).into() + } else { + content.users_default.into() + } + } else { + 0 // TODO if this fails DB error? + } + } else { + // if no power level event found the creator gets 100 everyone else gets 0 + let key = (EventType::RoomCreate, "".into()); + if let Some(create) = auth_events.get(&key) { + if let Ok(c) = serde_json::from_value::(create.content()) { + if &c.creator == user_id { + 100 + } else { + 0 + } + } else { + 0 + } + } else { + 0 + } + } +} + +/// Helper function to fetch the power level needed to send an event of type +/// `e_type` based on the rooms "m.room.power_level" event. +pub fn get_send_level( + e_type: &EventType, + state_key: Option, + power_lvl: Option<&Arc>, +) -> i64 { + log::debug!("{:?} {:?}", e_type, state_key); + power_lvl + .and_then(|ple| { + serde_json::from_value::(ple.content()) + .map(|content| { + content.events.get(e_type).copied().unwrap_or_else(|| { + if state_key.is_some() { + content.state_default + } else { + content.events_default + } + }) + }) + .ok() + }) + .map(i64::from) + .unwrap_or_else(|| if state_key.is_some() { 50 } else { 0 }) +} + +/// Check user can send invite. +pub fn can_send_invite(event: &Arc, auth_events: &StateMap>) -> Result { + let user_level = get_user_power_level(event.sender(), auth_events); + let key = (EventType::RoomPowerLevels, "".into()); + let invite_level = auth_events + .get(&key) + .map_or_else( + || Ok::<_, Error>(ruma::int!(50)), + |power_levels| { + serde_json::from_value::(power_levels.content()) + .map(|pl| pl.invite) + .map_err(Into::into) + }, + )? + .into(); + + Ok(user_level >= invite_level) +} + +pub fn verify_third_party_invite( + user_state_key: Option<&str>, + sender: &UserId, + tp_id: &ThirdPartyInvite, + current_third_party_invite: Option>, +) -> bool { + // 1. check for user being banned happens before this is called + // checking for mxid and token keys is done by ruma when deserializing + + if user_state_key != Some(tp_id.signed.mxid.as_str()) { + return false; + } + + // If there is no m.room.third_party_invite event in the current room state + // with state_key matching token, reject + if let Some(current_tpid) = current_third_party_invite { + if current_tpid.state_key().as_ref() != Some(&tp_id.signed.token) { + return false; + } + + if sender != current_tpid.sender() { + return false; + } + + // If any signature in signed matches any public key in the m.room.third_party_invite event, allow + if let Ok(tpid_ev) = serde_json::from_value::< + ruma::events::room::third_party_invite::ThirdPartyInviteEventContent, + >(current_tpid.content()) + { + // A list of public keys in the public_keys field + for key in tpid_ev.public_keys.unwrap_or_default() { + if key.public_key == tp_id.signed.token { + return true; + } + } + // A single public key in the public_key field + tpid_ev.public_key == tp_id.signed.token + } else { + false + } + } else { + false + } +} diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs new file mode 100644 index 00000000..84634e5b --- /dev/null +++ b/crates/ruma-state-res/src/lib.rs @@ -0,0 +1,717 @@ +use std::{ + cmp::Reverse, + collections::{BTreeMap, BTreeSet, BinaryHeap}, + sync::Arc, + time::SystemTime, +}; + +use maplit::btreeset; +use room_version::RoomVersion; +use ruma::{ + events::{ + room::{ + member::{MemberEventContent, MembershipState}, + power_levels::PowerLevelsEventContent, + }, + EventType, + }, + EventId, RoomId, RoomVersionId, +}; + +mod error; +pub mod event_auth; +pub mod room_version; +mod state_event; + +pub use error::{Error, Result}; +pub use event_auth::{auth_check, auth_types_for_event}; +pub use state_event::Event; + +/// A mapping of event type and state_key to some value `T`, usually an `EventId`. +pub type StateMap = BTreeMap<(EventType, String), T>; + +/// A mapping of `EventId` to `T`, usually a `ServerPdu`. +pub type EventMap = BTreeMap; + +#[derive(Default)] +pub struct StateResolution; + +impl StateResolution { + /// Resolve sets of state events as they come in. Internally `StateResolution` builds a graph + /// and an auth chain to allow for state conflict resolution. + /// + /// ## Arguments + /// + /// * `state_sets` - The incoming state to resolve. Each `StateMap` represents a possible fork + /// in the state of a room. + /// + /// * `auth_events` - The full recursive set of `auth_events` for each event in the `state_sets`. + /// + /// * `event_map` - The `EventMap` acts as a local cache of state, any event that is not found + /// in the `event_map` will cause an unrecoverable `Error` in `resolve`. + pub fn resolve( + room_id: &RoomId, + room_version: &RoomVersionId, + state_sets: &[StateMap], + auth_events: Vec>, + event_map: &mut EventMap>, + ) -> Result> { + log::info!("State resolution starting"); + + // split non-conflicting and conflicting state + let (clean, conflicting) = StateResolution::separate(state_sets); + + log::info!("non conflicting {:?}", clean.len()); + + if conflicting.is_empty() { + log::info!("no conflicting state found"); + return Ok(clean); + } + + log::info!("{} conflicting events", conflicting.len()); + + // the set of auth events that are not common across server forks + let mut auth_diff = StateResolution::get_auth_chain_diff(room_id, &auth_events)?; + + log::debug!("auth diff size {:?}", auth_diff); + + // add the auth_diff to conflicting now we have a full set of conflicting events + auth_diff.extend(conflicting.values().cloned().flatten()); + let mut all_conflicted = auth_diff + .into_iter() + .collect::>() + .into_iter() + .collect::>(); + + log::info!("full conflicted set is {} events", all_conflicted.len()); + + // we used to check that all events are events from the correct room + // this is now a check the caller of `resolve` must make. + + // synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}` + // + // don't honor events we cannot "verify" + all_conflicted.retain(|id| event_map.contains_key(id)); + + // get only the control events with a state_key: "" or ban/kick event (sender != state_key) + let control_events = all_conflicted + .iter() + .filter(|id| is_power_event_id(id, event_map)) + .cloned() + .collect::>(); + + // sort the control events based on power_level/clock/event_id and outgoing/incoming edges + let mut sorted_control_levels = StateResolution::reverse_topological_power_sort( + room_id, + &control_events, + event_map, + &all_conflicted, + ); + + log::debug!("SRTD {:?}", sorted_control_levels); + + let room_version = RoomVersion::new(room_version)?; + // sequentially auth check each control event. + let resolved_control = StateResolution::iterative_auth_check( + room_id, + &room_version, + &sorted_control_levels, + &clean, + event_map, + )?; + + log::debug!( + "AUTHED {:?}", + resolved_control + .iter() + .map(|(key, id)| (key, id.to_string())) + .collect::>() + ); + + // At this point the control_events have been resolved we now have to + // sort the remaining events using the mainline of the resolved power level. + sorted_control_levels.dedup(); + let deduped_power_ev = sorted_control_levels; + + // This removes the control events that passed auth and more importantly those that failed auth + let events_to_resolve = all_conflicted + .iter() + .filter(|id| !deduped_power_ev.contains(id)) + .cloned() + .collect::>(); + + log::debug!( + "LEFT {:?}", + events_to_resolve + .iter() + .map(ToString::to_string) + .collect::>() + ); + + // This "epochs" power level event + let power_event = resolved_control.get(&(EventType::RoomPowerLevels, "".into())); + + log::debug!("PL {:?}", power_event); + + let sorted_left_events = + StateResolution::mainline_sort(room_id, &events_to_resolve, power_event, event_map); + + log::debug!( + "SORTED LEFT {:?}", + sorted_left_events + .iter() + .map(ToString::to_string) + .collect::>() + ); + + let mut resolved_state = StateResolution::iterative_auth_check( + room_id, + &room_version, + &sorted_left_events, + &resolved_control, // The control events are added to the final resolved state + event_map, + )?; + + // add unconflicted state to the resolved state + // We priorities the unconflicting state + resolved_state.extend(clean); + Ok(resolved_state) + } + + /// Split the events that have no conflicts from those that are conflicting. + /// The return tuple looks like `(unconflicted, conflicted)`. + /// + /// State is determined to be conflicting if for the given key (EventType, StateKey) there + /// is not exactly one eventId. This includes missing events, if one state_set includes an event + /// that none of the other have this is a conflicting event. + pub fn separate( + state_sets: &[StateMap], + ) -> (StateMap, StateMap>) { + use itertools::Itertools; + + log::info!( + "seperating {} sets of events into conflicted/unconflicted", + state_sets.len() + ); + + let mut unconflicted_state = StateMap::new(); + let mut conflicted_state = StateMap::new(); + + for key in state_sets.iter().flat_map(|map| map.keys()).dedup() { + let mut event_ids = state_sets + .iter() + .map(|state_set| state_set.get(key)) + .dedup() + .collect::>(); + + if event_ids.len() == 1 { + if let Some(Some(id)) = event_ids.pop() { + unconflicted_state.insert(key.clone(), id.clone()); + } else { + panic!() + } + } else { + conflicted_state.insert( + key.clone(), + event_ids.into_iter().flatten().cloned().collect::>(), + ); + } + } + + (unconflicted_state, conflicted_state) + } + + /// Returns a Vec of deduped EventIds that appear in some chains but not others. + pub fn get_auth_chain_diff( + _room_id: &RoomId, + auth_event_ids: &[Vec], + ) -> Result> { + use itertools::Itertools; + + let mut chains = vec![]; + + for ids in auth_event_ids { + // TODO state store `auth_event_ids` returns self in the event ids list + // when an event returns `auth_event_ids` self is not contained + let chain = ids.iter().cloned().collect::>(); + chains.push(chain); + } + + if let Some(chain) = chains.first().cloned() { + let rest = chains.iter().skip(1).flatten().cloned().collect(); + let common = chain.intersection(&rest).collect::>(); + + Ok(chains + .into_iter() + .flatten() + .filter(|id| !common.contains(&id)) + .dedup() + .collect()) + } else { + Ok(vec![]) + } + } + + /// Events are sorted from "earliest" to "latest". They are compared using + /// the negative power level (reverse topological ordering), the + /// origin server timestamp and incase of a tie the `EventId`s + /// are compared lexicographically. + /// + /// The power level is negative because a higher power level is equated to an + /// earlier (further back in time) origin server timestamp. + pub fn reverse_topological_power_sort( + room_id: &RoomId, + events_to_sort: &[EventId], + event_map: &mut EventMap>, + auth_diff: &[EventId], + ) -> Vec { + log::debug!("reverse topological sort of power events"); + + let mut graph = BTreeMap::new(); + for event_id in events_to_sort.iter() { + StateResolution::add_event_and_auth_chain_to_graph( + room_id, &mut graph, event_id, event_map, auth_diff, + ); + + // TODO: if these functions are ever made async here + // is a good place to yield every once in a while so other + // tasks can make progress + } + + // this is used in the `key_fn` passed to the lexico_topo_sort fn + let mut event_to_pl = BTreeMap::new(); + for event_id in graph.keys() { + let pl = StateResolution::get_power_level_for_sender(room_id, event_id, event_map); + log::info!("{} power level {}", event_id.to_string(), pl); + + event_to_pl.insert(event_id.clone(), pl); + + // TODO: if these functions are ever made async here + // is a good place to yield every once in a while so other + // tasks can make progress + } + + StateResolution::lexicographical_topological_sort(&graph, |event_id| { + // log::debug!("{:?}", event_map.get(event_id).unwrap().origin_server_ts()); + let ev = event_map.get(event_id).unwrap(); + let pl = event_to_pl.get(event_id).unwrap(); + + log::debug!("{:?}", (-*pl, ev.origin_server_ts(), &ev.event_id())); + + // This return value is the key used for sorting events, + // events are then sorted by power level, time, + // and lexically by event_id. + (-*pl, ev.origin_server_ts(), ev.event_id().clone()) + }) + } + + /// Sorts the event graph based on number of outgoing/incoming edges, where + /// `key_fn` is used as a tie breaker. The tie breaker happens based on + /// power level, age, and event_id. + pub fn lexicographical_topological_sort( + graph: &BTreeMap>, + key_fn: F, + ) -> Vec + where + F: Fn(&EventId) -> (i64, SystemTime, EventId), + { + log::info!("starting lexicographical topological sort"); + // NOTE: an event that has no incoming edges happened most recently, + // and an event that has no outgoing edges happened least recently. + + // NOTE: this is basically Kahn's algorithm except we look at nodes with no + // outgoing edges, c.f. + // https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + + // TODO make the BTreeSet conversion cleaner ?? + // outdegree_map is an event referring to the events before it, the + // more outdegree's the more recent the event. + let mut outdegree_map: BTreeMap> = graph + .iter() + .map(|(k, v)| (k.clone(), v.iter().cloned().collect())) + .collect(); + + // The number of events that depend on the given event (the eventId key) + let mut reverse_graph = BTreeMap::new(); + + // Vec of nodes that have zero out degree, least recent events. + let mut zero_outdegree = vec![]; + + for (node, edges) in graph.iter() { + if edges.is_empty() { + // the `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need + // smallest -> largest + zero_outdegree.push(Reverse((key_fn(node), node))); + } + + reverse_graph.entry(node).or_insert(btreeset![]); + for edge in edges { + reverse_graph + .entry(edge) + .or_insert(btreeset![]) + .insert(node); + } + } + + let mut heap = BinaryHeap::from(zero_outdegree); + + // we remove the oldest node (most incoming edges) and check against all other + let mut sorted = vec![]; + // destructure the `Reverse` and take the smallest `node` each time + while let Some(Reverse((_, node))) = heap.pop() { + let node: &EventId = node; + for parent in reverse_graph.get(node).unwrap() { + // the number of outgoing edges this node has + let out = outdegree_map.get_mut(parent).unwrap(); + + // only push on the heap once older events have been cleared + out.remove(node); + if out.is_empty() { + heap.push(Reverse((key_fn(parent), parent))); + } + } + + // synapse yields we push then return the vec + sorted.push(node.clone()); + } + + sorted + } + + /// Find the power level for the sender of `event_id` or return a default value of zero. + fn get_power_level_for_sender( + room_id: &RoomId, + event_id: &EventId, + event_map: &mut EventMap>, + ) -> i64 { + log::info!("fetch event ({}) senders power level", event_id.to_string()); + + let event = StateResolution::get_or_load_event(room_id, event_id, event_map); + let mut pl = None; + + // TODO store.auth_event_ids returns "self" with the event ids is this ok + // event.auth_event_ids does not include its own event id ? + for aid in event + .as_ref() + .map(|pdu| pdu.auth_events()) + .unwrap_or_default() + { + if let Ok(aev) = StateResolution::get_or_load_event(room_id, &aid, event_map) { + if is_type_and_key(&aev, EventType::RoomPowerLevels, "") { + pl = Some(aev); + break; + } + } + } + + if pl.is_none() { + return 0; + } + + if let Some(content) = + pl.and_then(|pl| serde_json::from_value::(pl.content()).ok()) + { + if let Ok(ev) = event { + if let Some(user) = content.users.get(ev.sender()) { + log::debug!("found {} at power_level {}", ev.sender().as_str(), user); + return (*user).into(); + } + } + content.users_default.into() + } else { + 0 + } + } + + /// Check the that each event is authenticated based on the events before it. + /// + /// ## Returns + /// + /// The `unconflicted_state` combined with the newly auth'ed events. So any event that + /// fails the `event_auth::auth_check` will be excluded from the returned `StateMap`. + /// + /// For each `events_to_check` event we gather the events needed to auth it from the + /// `event_map` or `store` and verify each event using the `event_auth::auth_check` + /// function. + pub fn iterative_auth_check( + room_id: &RoomId, + room_version: &RoomVersion, + events_to_check: &[EventId], + unconflicted_state: &StateMap, + event_map: &mut EventMap>, + ) -> Result> { + log::info!("starting iterative auth check"); + + log::debug!( + "performing auth checks on {:?}", + events_to_check + .iter() + .map(ToString::to_string) + .collect::>() + ); + + let mut resolved_state = unconflicted_state.clone(); + + for event_id in events_to_check.iter() { + let event = StateResolution::get_or_load_event(room_id, event_id, event_map)?; + let state_key = event + .state_key() + .ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; + + let mut auth_events = BTreeMap::new(); + for aid in &event.auth_events() { + if let Ok(ev) = StateResolution::get_or_load_event(room_id, aid, event_map) { + // TODO synapse check "rejected_reason", I'm guessing this is redacted_because in ruma ?? + auth_events.insert( + ( + ev.kind(), + ev.state_key().ok_or_else(|| { + Error::InvalidPdu("State event had no state key".to_owned()) + })?, + ), + ev, + ); + } else { + log::warn!("auth event id for {} is missing {}", aid, event_id); + } + } + + for key in auth_types_for_event( + &event.kind(), + event.sender(), + Some(state_key.clone()), + event.content(), + ) { + if let Some(ev_id) = resolved_state.get(&key) { + if let Ok(event) = StateResolution::get_or_load_event(room_id, ev_id, event_map) + { + // TODO synapse checks `rejected_reason` is None here + auth_events.insert(key.clone(), event); + } + } + } + + log::debug!("event to check {:?}", event.event_id().as_str()); + + let most_recent_prev_event = event + .prev_events() + .iter() + .filter_map(|id| StateResolution::get_or_load_event(room_id, id, event_map).ok()) + .next_back(); + + // The key for this is (eventType + a state_key of the signed token not sender) so search + // for it + let current_third_party = auth_events.iter().find_map(|(_, pdu)| { + if pdu.kind() == EventType::RoomThirdPartyInvite { + Some(pdu.clone()) // TODO no clone, auth_events is borrowed while moved + } else { + None + } + }); + + if auth_check( + room_version, + &event, + most_recent_prev_event, + &auth_events, + current_third_party, + )? { + // add event to resolved state map + resolved_state.insert((event.kind(), state_key), event_id.clone()); + } else { + // synapse passes here on AuthError. We do not add this event to resolved_state. + log::warn!( + "event {} failed the authentication check", + event_id.to_string() + ); + } + + // TODO: if these functions are ever made async here + // is a good place to yield every once in a while so other + // tasks can make progress + } + Ok(resolved_state) + } + + /// Returns the sorted `to_sort` list of `EventId`s based on a mainline sort using + /// the depth of `resolved_power_level`, the server timestamp, and the eventId. + /// + /// The depth of the given event is calculated based on the depth of it's closest "parent" + /// power_level event. If there have been two power events the after the most recent are + /// depth 0, the events before (with the first power level as a parent) will be marked + /// as depth 1. depth 1 is "older" than depth 0. + pub fn mainline_sort( + room_id: &RoomId, + to_sort: &[EventId], + resolved_power_level: Option<&EventId>, + event_map: &mut EventMap>, + ) -> Vec { + log::debug!("mainline sort of events"); + + // There are no EventId's to sort, bail. + if to_sort.is_empty() { + return vec![]; + } + + let mut mainline = vec![]; + let mut pl = resolved_power_level.cloned(); + while let Some(p) = pl { + mainline.push(p.clone()); + + let event = StateResolution::get_or_load_event(room_id, &p, event_map).unwrap(); + let auth_events = &event.auth_events(); + pl = None; + for aid in auth_events { + let ev = StateResolution::get_or_load_event(room_id, aid, event_map).unwrap(); + if is_type_and_key(&ev, EventType::RoomPowerLevels, "") { + pl = Some(aid.clone()); + break; + } + } + // TODO: if these functions are ever made async here + // is a good place to yield every once in a while so other + // tasks can make progress + } + + let mainline_map = mainline + .iter() + .rev() + .enumerate() + .map(|(idx, eid)| ((*eid).clone(), idx)) + .collect::>(); + + let mut order_map = BTreeMap::new(); + for ev_id in to_sort.iter() { + if let Ok(event) = StateResolution::get_or_load_event(room_id, ev_id, event_map) { + if let Ok(depth) = StateResolution::get_mainline_depth( + room_id, + Some(event), + &mainline_map, + event_map, + ) { + order_map.insert( + ev_id, + ( + depth, + event_map.get(ev_id).map(|ev| ev.origin_server_ts()), + ev_id, // TODO should this be a &str to sort lexically?? + ), + ); + } + } + + // TODO: if these functions are ever made async here + // is a good place to yield every once in a while so other + // tasks can make progress + } + + // sort the event_ids by their depth, timestamp and EventId + // unwrap is OK order map and sort_event_ids are from to_sort (the same Vec) + let mut sort_event_ids = order_map.keys().map(|&k| k.clone()).collect::>(); + sort_event_ids.sort_by_key(|sort_id| order_map.get(sort_id).unwrap()); + + sort_event_ids + } + + /// Get the mainline depth from the `mainline_map` or finds a power_level event + /// that has an associated mainline depth. + fn get_mainline_depth( + room_id: &RoomId, + mut event: Option>, + mainline_map: &EventMap, + event_map: &mut EventMap>, + ) -> Result { + while let Some(sort_ev) = event { + log::debug!("mainline event_id {}", sort_ev.event_id().to_string()); + let id = &sort_ev.event_id(); + if let Some(depth) = mainline_map.get(id) { + return Ok(*depth); + } + + // dbg!(&sort_ev); + let auth_events = &sort_ev.auth_events(); + event = None; + for aid in auth_events { + // dbg!(&aid); + let aev = StateResolution::get_or_load_event(room_id, aid, event_map)?; + if is_type_and_key(&aev, EventType::RoomPowerLevels, "") { + event = Some(aev); + break; + } + } + } + // Did not find a power level event so we default to zero + Ok(0) + } + + fn add_event_and_auth_chain_to_graph( + room_id: &RoomId, + graph: &mut BTreeMap>, + event_id: &EventId, + event_map: &mut EventMap>, + auth_diff: &[EventId], + ) { + let mut state = vec![event_id.clone()]; + while !state.is_empty() { + // we just checked if it was empty so unwrap is fine + let eid = state.pop().unwrap(); + graph.entry(eid.clone()).or_insert_with(Vec::new); + // prefer the store to event as the store filters dedups the events + // otherwise it seems we can loop forever + for aid in &StateResolution::get_or_load_event(room_id, &eid, event_map) + .unwrap() + .auth_events() + { + if auth_diff.contains(aid) { + if !graph.contains_key(aid) { + state.push(aid.clone()); + } + + // we just inserted this at the start of the while loop + graph.get_mut(&eid).unwrap().push(aid.clone()); + } + } + } + } + + /// Uses the `event_map` to return the full PDU or fails. + fn get_or_load_event( + _room_id: &RoomId, + ev_id: &EventId, + event_map: &EventMap>, + ) -> Result> { + event_map.get(ev_id).map_or_else( + || Err(Error::NotFound(format!("EventId: {:?} not found", ev_id))), + |e| Ok(Arc::clone(e)), + ) + } +} + +pub fn is_power_event_id(event_id: &EventId, event_map: &EventMap>) -> bool { + match event_map.get(event_id) { + Some(state) => is_power_event(state), + _ => false, + } +} + +pub fn is_type_and_key(ev: &Arc, ev_type: EventType, state_key: &str) -> bool { + ev.kind() == ev_type && ev.state_key().as_deref() == Some(state_key) +} + +pub fn is_power_event(event: &Arc) -> bool { + match event.kind() { + EventType::RoomPowerLevels | EventType::RoomJoinRules | EventType::RoomCreate => { + event.state_key() == Some("".into()) + } + EventType::RoomMember => { + if let Ok(content) = serde_json::from_value::(event.content()) { + if [MembershipState::Leave, MembershipState::Ban].contains(&content.membership) { + return Some(event.sender().as_str()) != event.state_key().as_deref(); + } + } + + false + } + _ => false, + } +} diff --git a/crates/ruma-state-res/src/room_version.rs b/crates/ruma-state-res/src/room_version.rs new file mode 100644 index 00000000..a4c27087 --- /dev/null +++ b/crates/ruma-state-res/src/room_version.rs @@ -0,0 +1,160 @@ +use ruma::RoomVersionId; + +use crate::{Error, Result}; + +pub enum RoomDisposition { + /// A room version that has a stable specification. + Stable, + /// A room version that is not yet fully specified. + #[allow(dead_code)] + Unstable, +} + +pub enum EventFormatVersion { + /// $id:server event id format + V1, + /// MSC1659-style $hash event id format: introduced for room v3 + V2, + /// MSC1884-style $hash format: introduced for room v4 + V3, +} + +pub enum StateResolutionVersion { + /// State resolution for rooms at version 1. + V1, + /// State resolution for room at version 2 or later. + V2, +} + +pub struct RoomVersion { + /// The version this room is set to. + pub version: RoomVersionId, + /// The stability of this room. + pub disposition: RoomDisposition, + /// The format of the EventId. + pub event_format: EventFormatVersion, + /// Which state resolution algorithm is used. + pub state_res: StateResolutionVersion, + /// not sure + pub enforce_key_validity: bool, + + /// `m.room.aliases` had special auth rules and redaction rules + /// before room version 6. + /// + /// before MSC2261/MSC2432, + pub special_case_aliases_auth: bool, + /// Strictly enforce canonicaljson, do not allow: + /// * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] + /// * Floats + /// * NaN, Infinity, -Infinity + pub strict_canonicaljson: bool, + /// Verify notifications key while checking m.room.power_levels. + /// + /// bool: MSC2209: Check 'notifications' + pub limit_notifications_power_levels: bool, + /// Extra rules when verifying redaction events. + pub extra_redaction_checks: bool, +} + +impl RoomVersion { + pub fn new(version: &RoomVersionId) -> Result { + Ok(match version { + RoomVersionId::Version1 => Self::version_1(), + RoomVersionId::Version2 => Self::version_2(), + RoomVersionId::Version3 => Self::version_3(), + RoomVersionId::Version4 => Self::version_4(), + RoomVersionId::Version5 => Self::version_5(), + RoomVersionId::Version6 => Self::version_6(), + ver => { + return Err(Error::Unsupported(format!( + "found version `{}`", + ver.as_str() + ))) + } + }) + } + + pub fn version_1() -> Self { + Self { + version: RoomVersionId::Version1, + disposition: RoomDisposition::Stable, + event_format: EventFormatVersion::V1, + state_res: StateResolutionVersion::V1, + enforce_key_validity: false, + special_case_aliases_auth: true, + strict_canonicaljson: false, + limit_notifications_power_levels: false, + extra_redaction_checks: false, + } + } + + pub fn version_2() -> Self { + Self { + version: RoomVersionId::Version2, + disposition: RoomDisposition::Stable, + event_format: EventFormatVersion::V1, + state_res: StateResolutionVersion::V2, + enforce_key_validity: false, + special_case_aliases_auth: true, + strict_canonicaljson: false, + limit_notifications_power_levels: false, + extra_redaction_checks: false, + } + } + + pub fn version_3() -> Self { + Self { + version: RoomVersionId::Version3, + disposition: RoomDisposition::Stable, + event_format: EventFormatVersion::V2, + state_res: StateResolutionVersion::V2, + enforce_key_validity: false, + special_case_aliases_auth: true, + strict_canonicaljson: false, + limit_notifications_power_levels: false, + extra_redaction_checks: true, + } + } + + pub fn version_4() -> Self { + Self { + version: RoomVersionId::Version4, + disposition: RoomDisposition::Stable, + event_format: EventFormatVersion::V3, + state_res: StateResolutionVersion::V2, + enforce_key_validity: false, + special_case_aliases_auth: true, + strict_canonicaljson: false, + limit_notifications_power_levels: false, + extra_redaction_checks: true, + } + } + + pub fn version_5() -> Self { + Self { + version: RoomVersionId::Version5, + disposition: RoomDisposition::Stable, + event_format: EventFormatVersion::V3, + state_res: StateResolutionVersion::V2, + enforce_key_validity: true, + special_case_aliases_auth: true, + strict_canonicaljson: false, + limit_notifications_power_levels: false, + extra_redaction_checks: true, + } + } + + pub fn version_6() -> Self { + Self { + version: RoomVersionId::Version6, + disposition: RoomDisposition::Stable, + event_format: EventFormatVersion::V3, + state_res: StateResolutionVersion::V2, + enforce_key_validity: true, + special_case_aliases_auth: false, + strict_canonicaljson: true, + limit_notifications_power_levels: true, + extra_redaction_checks: true, + } + } +} diff --git a/crates/ruma-state-res/src/state_event.rs b/crates/ruma-state-res/src/state_event.rs new file mode 100644 index 00000000..0a81eec4 --- /dev/null +++ b/crates/ruma-state-res/src/state_event.rs @@ -0,0 +1,52 @@ +use std::{collections::BTreeMap, time::SystemTime}; + +use ruma::{ + events::{pdu::EventHash, EventType}, + EventId, RoomId, ServerName, UInt, UserId, +}; +use serde_json::value::Value as JsonValue; + +/// Abstraction of a PDU so users can have their own PDU types. +pub trait Event { + /// The `EventId` of this event. + fn event_id(&self) -> &EventId; + + /// The `RoomId` of this event. + fn room_id(&self) -> &RoomId; + + /// The `UserId` of this event. + fn sender(&self) -> &UserId; + + /// The time of creation on the originating server. + fn origin_server_ts(&self) -> SystemTime; + + /// The kind of event. + fn kind(&self) -> EventType; + + /// The `UserId` of this PDU. + fn content(&self) -> serde_json::Value; + + /// The state key for this event. + fn state_key(&self) -> Option; + + /// The events before this event. + fn prev_events(&self) -> Vec; + + /// The maximum number of `prev_events` plus 1. + /// + /// This is only used in state resolution version 1. + fn depth(&self) -> &UInt; + + /// All the authenticating events for this event. + fn auth_events(&self) -> Vec; + + /// If this event is a redaction event this is the event it redacts. + fn redacts(&self) -> Option<&EventId>; + + /// The `unsigned` content of this event. + fn unsigned(&self) -> &BTreeMap; + + fn hashes(&self) -> &EventHash; + + fn signatures(&self) -> BTreeMap, BTreeMap>; +} diff --git a/crates/ruma-state-res/tests/event_auth.rs b/crates/ruma-state-res/tests/event_auth.rs new file mode 100644 index 00000000..bd91d0ab --- /dev/null +++ b/crates/ruma-state-res/tests/event_auth.rs @@ -0,0 +1,79 @@ +use std::sync::Arc; + +use state_res::{event_auth::valid_membership_change, StateMap}; +// use state_res::event_auth:::{ +// auth_check, auth_types_for_event, can_federate, check_power_levels, check_redaction, +// }; + +mod utils; +use utils::{alice, charlie, event_id, member_content_ban, to_pdu_event, INITIAL_EVENTS}; + +#[test] +fn test_ban_pass() { + let events = INITIAL_EVENTS(); + + let prev = events + .values() + .find(|ev| ev.event_id().as_str().contains("IMC")) + .map(Arc::clone); + + let auth_events = events + .values() + .map(|ev| ((ev.kind(), ev.state_key()), Arc::clone(ev))) + .collect::>(); + + let requester = to_pdu_event( + "HELLO", + alice(), + ruma::events::EventType::RoomMember, + Some(charlie().as_str()), + member_content_ban(), + &[], + &[event_id("IMC")], + ); + + assert!(valid_membership_change( + &requester.state_key(), + requester.sender(), + requester.content(), + prev, + None, + &auth_events + ) + .unwrap()) +} + +#[test] +fn test_ban_fail() { + let events = INITIAL_EVENTS(); + + let prev = events + .values() + .find(|ev| ev.event_id().as_str().contains("IMC")) + .map(Arc::clone); + + let auth_events = events + .values() + .map(|ev| ((ev.kind(), ev.state_key()), Arc::clone(ev))) + .collect::>(); + + let requester = to_pdu_event( + "HELLO", + charlie(), + ruma::events::EventType::RoomMember, + Some(alice().as_str()), + member_content_ban(), + &[], + &[event_id("IMC")], + ); + + assert!(!valid_membership_change( + &requester.state_key(), + requester.sender(), + requester.content(), + prev, + None, + &auth_events + ) + .unwrap()) +} diff --git a/crates/ruma-state-res/tests/event_sorting.rs b/crates/ruma-state-res/tests/event_sorting.rs new file mode 100644 index 00000000..2f50fff5 --- /dev/null +++ b/crates/ruma-state-res/tests/event_sorting.rs @@ -0,0 +1,95 @@ +use std::collections::BTreeMap; + +use ruma::{events::EventType, EventId}; +use state_res::{is_power_event, room_version::RoomVersion, StateMap}; + +mod utils; +use utils::{room_id, INITIAL_EVENTS}; + +fn shuffle(list: &mut [EventId]) { + use rand::Rng; + + let mut rng = rand::thread_rng(); + for i in 1..list.len() { + let j = rng.gen_range(0, list.len()); + list.swap(i, j); + } +} + +fn test_event_sort() { + let mut events = INITIAL_EVENTS(); + + let event_map = events + .values() + .map(|ev| ((ev.kind(), ev.state_key()), ev.clone())) + .collect::>(); + + let auth_chain = &[] as &[_]; + + let power_events = event_map + .values() + .filter(|pdu| is_power_event(pdu)) + .map(|pdu| pdu.event_id().clone()) + .collect::>(); + + // This is a TODO in conduit + // TODO these events are not guaranteed to be sorted but they are resolved, do + // we need the auth_chain + let sorted_power_events = state_res::StateResolution::reverse_topological_power_sort( + &room_id(), + &power_events, + &mut events, + auth_chain, + ); + + // This is a TODO in conduit + // TODO we may be able to skip this since they are resolved according to spec + let resolved_power = state_res::StateResolution::iterative_auth_check( + &room_id(), + &RoomVersion::version_6(), + &sorted_power_events, + &BTreeMap::new(), // unconflicted events + &mut events, + ) + .expect("iterative auth check failed on resolved events"); + + // don't remove any events so we know it sorts them all correctly + let mut events_to_sort = events.keys().cloned().collect::>(); + + shuffle(&mut events_to_sort); + + let power_level = resolved_power.get(&(EventType::RoomPowerLevels, "".to_string())); + + let sorted_event_ids = state_res::StateResolution::mainline_sort( + &room_id(), + &events_to_sort, + power_level, + &mut events, + ); + + assert_eq!( + vec![ + "$CREATE:foo", + "$IMA:foo", + "$IPOWER:foo", + "$IJR:foo", + "$IMB:foo", + "$IMC:foo", + "$START:foo", + "$END:foo" + ], + sorted_event_ids + .iter() + .map(|id| id.to_string()) + .collect::>() + ) +} + +#[test] +fn test_sort() { + for _ in 0..20 { + // since we shuffle the eventIds before we sort them introducing randomness + // seems like we should test this a few times + test_event_sort() + } +} diff --git a/crates/ruma-state-res/tests/res_with_auth_ids.rs b/crates/ruma-state-res/tests/res_with_auth_ids.rs new file mode 100644 index 00000000..4c637abe --- /dev/null +++ b/crates/ruma-state-res/tests/res_with_auth_ids.rs @@ -0,0 +1,208 @@ +#![allow(clippy::or_fun_call, clippy::expect_fun_call)] + +use std::{collections::BTreeMap, sync::Arc}; + +use ruma::{events::EventType, EventId, RoomVersionId}; +use serde_json::json; +use state_res::{EventMap, StateMap, StateResolution}; + +mod utils; +use utils::{ + alice, bob, do_check, ella, event_id, member_content_ban, member_content_join, room_id, + to_pdu_event, zara, StateEvent, TestStore, INITIAL_EVENTS, +}; + +#[test] +fn ban_with_auth_chains() { + let ban = BAN_STATE_SET(); + + let edges = vec![vec!["END", "MB", "PA", "START"], vec!["END", "IME", "MB"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["PA", "MB"] + .into_iter() + .map(event_id) + .collect::>(); + + do_check( + &ban.values().cloned().collect::>(), + edges, + expected_state_ids, + ); +} + +#[test] +fn ban_with_auth_chains2() { + let init = INITIAL_EVENTS(); + let ban = BAN_STATE_SET(); + + let mut inner = init.clone(); + inner.extend(ban); + let store = TestStore(inner.clone()); + + let state_set_a = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("MB")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone())) + .collect::>(); + + let state_set_b = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("IME")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone())) + .collect::>(); + + let mut ev_map: EventMap> = store.0.clone(); + let state_sets = vec![state_set_a, state_set_b]; + let resolved = match StateResolution::resolve::( + &room_id(), + &RoomVersionId::Version2, + &state_sets, + state_sets + .iter() + .map(|map| { + store + .auth_event_ids(&room_id(), &map.values().cloned().collect::>()) + .unwrap() + }) + .collect(), + &mut ev_map, + ) { + Ok(state) => state, + Err(e) => panic!("{}", e), + }; + + log::debug!( + "{:#?}", + resolved + .iter() + .map(|((ty, key), id)| format!("(({}{:?}), {})", ty, key, id)) + .collect::>() + ); + + let expected = vec![ + "$CREATE:foo", + "$IJR:foo", + "$PA:foo", + "$IMA:foo", + "$IMB:foo", + "$IMC:foo", + "$MB:foo", + ]; + + for id in expected.iter().map(|i| event_id(i)) { + // make sure our resolved events are equal to the expected list + assert!( + resolved.values().any(|eid| eid == &id) || init.contains_key(&id), + "{}", + id + ) + } + assert_eq!(expected.len(), resolved.len()) +} + +#[test] +fn join_rule_with_auth_chain() { + let join_rule = JOIN_RULE(); + + let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::>(); + + do_check( + &join_rule.values().cloned().collect::>(), + edges, + expected_state_ids, + ); +} + +#[allow(non_snake_case)] +fn BAN_STATE_SET() -> BTreeMap> { + vec![ + to_pdu_event( + "PA", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + &["CREATE", "IMA", "IPOWER"], // auth_events + &["START"], // prev_events + ), + to_pdu_event( + "PB", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + &["CREATE", "IMA", "IPOWER"], + &["END"], + ), + to_pdu_event( + "MB", + alice(), + EventType::RoomMember, + Some(ella().as_str()), + member_content_ban(), + &["CREATE", "IMA", "PB"], + &["PA"], + ), + to_pdu_event( + "IME", + ella(), + EventType::RoomMember, + Some(ella().as_str()), + member_content_join(), + &["CREATE", "IJR", "PA"], + &["MB"], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().clone(), ev)) + .collect() +} + +#[allow(non_snake_case)] +fn JOIN_RULE() -> BTreeMap> { + vec![ + to_pdu_event( + "JR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({"join_rule": "invite"}), + &["CREATE", "IMA", "IPOWER"], + &["START"], + ), + to_pdu_event( + "IMZ", + zara(), + EventType::RoomPowerLevels, + Some(zara().as_str()), + member_content_join(), + &["CREATE", "JR", "IPOWER"], + &["START"], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().clone(), ev)) + .collect() +} diff --git a/crates/ruma-state-res/tests/state_res.rs b/crates/ruma-state-res/tests/state_res.rs new file mode 100644 index 00000000..11452f53 --- /dev/null +++ b/crates/ruma-state-res/tests/state_res.rs @@ -0,0 +1,408 @@ +use std::{sync::Arc, time::UNIX_EPOCH}; + +use maplit::btreemap; +use ruma::{ + events::{room::join_rules::JoinRule, EventType}, + EventId, RoomVersionId, +}; +use serde_json::json; +use state_res::{StateMap, StateResolution}; +use tracing_subscriber as tracer; + +mod utils; +use utils::{ + alice, bob, charlie, do_check, ella, event_id, member_content_ban, member_content_join, + room_id, to_init_pdu_event, to_pdu_event, zara, StateEvent, TestStore, LOGGER, +}; + +#[test] +fn ban_vs_power_level() { + let events = &[ + to_init_pdu_event( + "PA", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + ), + to_init_pdu_event( + "MA", + alice(), + EventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + ), + to_init_pdu_event( + "MB", + alice(), + EventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_ban(), + ), + to_init_pdu_event( + "PB", + bob(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + ), + ]; + + let edges = vec![ + vec!["END", "MB", "MA", "PA", "START"], + vec!["END", "PA", "PB"], + ] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["PA", "MA", "MB"] + .into_iter() + .map(event_id) + .collect::>(); + + do_check(events, edges, expected_state_ids) +} + +#[test] +fn topic_basic() { + let events = &[ + to_init_pdu_event("T1", alice(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event( + "PA1", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + ), + to_init_pdu_event("T2", alice(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event( + "PA2", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 0}}), + ), + to_init_pdu_event( + "PB", + bob(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + ), + to_init_pdu_event("T3", bob(), EventType::RoomTopic, Some(""), json!({})), + ]; + + let edges = vec![ + vec!["END", "PA2", "T2", "PA1", "T1", "START"], + vec!["END", "T3", "PB", "PA1"], + ] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["PA2", "T2"] + .into_iter() + .map(event_id) + .collect::>(); + + do_check(events, edges, expected_state_ids) +} + +#[test] +fn topic_reset() { + let events = &[ + to_init_pdu_event("T1", alice(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event( + "PA", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + ), + to_init_pdu_event("T2", bob(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event( + "MB", + alice(), + EventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_ban(), + ), + ]; + + let edges = vec![ + vec!["END", "MB", "T2", "PA", "T1", "START"], + vec!["END", "T1"], + ] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["T1", "MB", "PA"] + .into_iter() + .map(event_id) + .collect::>(); + + do_check(events, edges, expected_state_ids) +} + +#[test] +fn join_rule_evasion() { + let events = &[ + to_init_pdu_event( + "JR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({ "join_rule": JoinRule::Private }), + ), + to_init_pdu_event( + "ME", + ella(), + EventType::RoomMember, + Some(ella().to_string().as_str()), + member_content_join(), + ), + ]; + + let edges = vec![vec!["END", "JR", "START"], vec!["END", "ME", "START"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec![event_id("JR")]; + + do_check(events, edges, expected_state_ids) +} + +#[test] +fn offtopic_power_level() { + let events = &[ + to_init_pdu_event( + "PA", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + ), + to_init_pdu_event( + "PB", + bob(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50, charlie(): 50}}), + ), + to_init_pdu_event( + "PC", + charlie(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50, charlie(): 0}}), + ), + ]; + + let edges = vec![vec!["END", "PC", "PB", "PA", "START"], vec!["END", "PA"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::>(); + + do_check(events, edges, expected_state_ids) +} + +#[test] +fn topic_setting() { + let events = &[ + to_init_pdu_event("T1", alice(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event( + "PA1", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + ), + to_init_pdu_event("T2", alice(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event( + "PA2", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 0}}), + ), + to_init_pdu_event( + "PB", + bob(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + ), + to_init_pdu_event("T3", bob(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event("MZ1", zara(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event("T4", alice(), EventType::RoomTopic, Some(""), json!({})), + ]; + + let edges = vec![ + vec!["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"], + vec!["END", "MZ1", "T3", "PB", "PA1"], + ] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["T4", "PA2"] + .into_iter() + .map(event_id) + .collect::>(); + + do_check(events, edges, expected_state_ids) +} + +#[test] +fn test_event_map_none() { + let mut store = TestStore::(btreemap! {}); + + // build up the DAG + let (state_at_bob, state_at_charlie, expected) = store.set_up(); + + let mut ev_map: state_res::EventMap> = store.0.clone(); + let state_sets = vec![state_at_bob, state_at_charlie]; + let resolved = match StateResolution::resolve::( + &room_id(), + &RoomVersionId::Version2, + &state_sets, + state_sets + .iter() + .map(|map| { + store + .auth_event_ids(&room_id(), &map.values().cloned().collect::>()) + .unwrap() + }) + .collect(), + &mut ev_map, + ) { + Ok(state) => state, + Err(e) => panic!("{}", e), + }; + + assert_eq!(expected, resolved) +} + +#[test] +fn test_lexicographical_sort() { + let graph = btreemap! { + event_id("l") => vec![event_id("o")], + event_id("m") => vec![event_id("n"), event_id("o")], + event_id("n") => vec![event_id("o")], + event_id("o") => vec![], // "o" has zero outgoing edges but 4 incoming edges + event_id("p") => vec![event_id("o")], + }; + + let res = + StateResolution::lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, id.clone())); + + assert_eq!( + vec!["o", "l", "n", "m", "p"], + res.iter() + .map(ToString::to_string) + .map(|s| s.replace("$", "").replace(":foo", "")) + .collect::>() + ) +} + +// A StateStore implementation for testing +// +// +impl TestStore { + pub fn set_up(&mut self) -> (StateMap, StateMap, StateMap) { + // to activate logging use `RUST_LOG=debug cargo t one_test_only` + let _ = LOGGER.call_once(|| { + tracer::fmt() + .with_env_filter(tracer::EnvFilter::from_default_env()) + .init() + }); + let create_event = to_pdu_event::( + "CREATE", + alice(), + EventType::RoomCreate, + Some(""), + json!({ "creator": alice() }), + &[], + &[], + ); + let cre = create_event.event_id().clone(); + self.0.insert(cre.clone(), Arc::clone(&create_event)); + + let alice_mem = to_pdu_event( + "IMA", + alice(), + EventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + &[cre.clone()], + &[cre.clone()], + ); + self.0 + .insert(alice_mem.event_id().clone(), Arc::clone(&alice_mem)); + + let join_rules = to_pdu_event( + "IJR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({ "join_rule": JoinRule::Public }), + &[cre.clone(), alice_mem.event_id().clone()], + &[alice_mem.event_id().clone()], + ); + self.0 + .insert(join_rules.event_id().clone(), join_rules.clone()); + + // Bob and Charlie join at the same time, so there is a fork + // this will be represented in the state_sets when we resolve + let bob_mem = to_pdu_event( + "IMB", + bob(), + EventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_join(), + &[cre.clone(), join_rules.event_id().clone()], + &[join_rules.event_id().clone()], + ); + self.0.insert(bob_mem.event_id().clone(), bob_mem.clone()); + + let charlie_mem = to_pdu_event( + "IMC", + charlie(), + EventType::RoomMember, + Some(charlie().to_string().as_str()), + member_content_join(), + &[cre, join_rules.event_id().clone()], + &[join_rules.event_id().clone()], + ); + self.0 + .insert(charlie_mem.event_id().clone(), charlie_mem.clone()); + + let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] + .iter() + .map(|e| ((e.kind(), e.state_key()), e.event_id().clone())) + .collect::>(); + + let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] + .iter() + .map(|e| ((e.kind(), e.state_key()), e.event_id().clone())) + .collect::>(); + + let expected = [ + &create_event, + &alice_mem, + &join_rules, + &bob_mem, + &charlie_mem, + ] + .iter() + .map(|e| ((e.kind(), e.state_key()), e.event_id().clone())) + .collect::>(); + + (state_at_bob, state_at_charlie, expected) + } +} diff --git a/crates/ruma-state-res/tests/utils.rs b/crates/ruma-state-res/tests/utils.rs new file mode 100644 index 00000000..d487749e --- /dev/null +++ b/crates/ruma-state-res/tests/utils.rs @@ -0,0 +1,792 @@ +#![allow(clippy::or_fun_call, clippy::expect_fun_call, dead_code)] + +use std::{ + collections::{BTreeMap, BTreeSet}, + convert::TryFrom, + sync::{Arc, Once}, + time::{Duration, UNIX_EPOCH}, +}; + +use maplit::btreemap; +use ruma::{ + events::{ + pdu::{EventHash, Pdu, RoomV3Pdu}, + room::{ + join_rules::JoinRule, + member::{MemberEventContent, MembershipState}, + }, + EventType, + }, + EventId, RoomId, RoomVersionId, UserId, +}; +use serde_json::{json, Value as JsonValue}; +use state_res::{Error, Event, Result, StateMap, StateResolution}; +use tracing_subscriber as tracer; + +pub use event::StateEvent; + +pub static LOGGER: Once = Once::new(); + +static mut SERVER_TIMESTAMP: u64 = 0; + +pub fn do_check( + events: &[Arc], + edges: Vec>, + expected_state_ids: Vec, +) { + // to activate logging use `RUST_LOG=debug cargo t` + let _ = LOGGER.call_once(|| { + tracer::fmt() + .with_env_filter(tracer::EnvFilter::from_default_env()) + .init() + }); + + let init_events = INITIAL_EVENTS(); + + let mut store = TestStore( + init_events + .values() + .chain(events) + .map(|ev| (ev.event_id().clone(), ev.clone())) + .collect(), + ); + + // This will be lexi_topo_sorted for resolution + let mut graph = BTreeMap::new(); + // this is the same as in `resolve` event_id -> StateEvent + let mut fake_event_map = BTreeMap::new(); + + // create the DB of events that led up to this point + // TODO maybe clean up some of these clones it is just tests but... + for ev in init_events.values().chain(events) { + graph.insert(ev.event_id().clone(), vec![]); + fake_event_map.insert(ev.event_id().clone(), ev.clone()); + } + + for pair in INITIAL_EDGES().windows(2) { + if let [a, b] = &pair { + graph.entry(a.clone()).or_insert(vec![]).push(b.clone()); + } + } + + for edge_list in edges { + for pair in edge_list.windows(2) { + if let [a, b] = &pair { + graph.entry(a.clone()).or_insert(vec![]).push(b.clone()); + } + } + } + + // event_id -> StateEvent + let mut event_map: BTreeMap> = BTreeMap::new(); + // event_id -> StateMap + let mut state_at_event: BTreeMap> = BTreeMap::new(); + + // resolve the current state and add it to the state_at_event map then continue + // on in "time" + for node in + StateResolution::lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, id.clone())) + { + let fake_event = fake_event_map.get(&node).unwrap(); + let event_id = fake_event.event_id().clone(); + + let prev_events = graph.get(&node).unwrap(); + + let state_before: StateMap = if prev_events.is_empty() { + BTreeMap::new() + } else if prev_events.len() == 1 { + state_at_event.get(&prev_events[0]).unwrap().clone() + } else { + let state_sets = prev_events + .iter() + .filter_map(|k| state_at_event.get(k)) + .cloned() + .collect::>(); + + log::info!( + "{:#?}", + state_sets + .iter() + .map(|map| map + .iter() + .map(|((ty, key), id)| format!("(({}{:?}), {})", ty, key, id)) + .collect::>()) + .collect::>() + ); + + let resolved = StateResolution::resolve( + &room_id(), + &RoomVersionId::Version6, + &state_sets, + state_sets + .iter() + .map(|map| { + store + .auth_event_ids(&room_id(), &map.values().cloned().collect::>()) + .unwrap() + }) + .collect(), + &mut event_map, + ); + match resolved { + Ok(state) => state, + Err(e) => panic!("resolution for {} failed: {}", node, e), + } + }; + + let mut state_after = state_before.clone(); + + let ty = fake_event.kind(); + let key = fake_event.state_key(); + state_after.insert((ty, key), event_id.clone()); + + let auth_types = state_res::auth_types_for_event( + &fake_event.kind(), + fake_event.sender(), + Some(fake_event.state_key()), + fake_event.content(), + ); + + let mut auth_events = vec![]; + for key in auth_types { + if state_before.contains_key(&key) { + auth_events.push(state_before[&key].clone()) + } + } + + // TODO The event is just remade, adding the auth_events and prev_events here + // the `to_pdu_event` was split into `init` and the fn below, could be better + let e = fake_event; + let ev_id = e.event_id().clone(); + let event = to_pdu_event( + e.event_id().as_str(), + e.sender().clone(), + e.kind().clone(), + Some(&e.state_key()), + e.content(), + &auth_events, + prev_events, + ); + + // we have to update our store, an actual user of this lib would + // be giving us state from a DB. + store.0.insert(ev_id.clone(), event.clone()); + + state_at_event.insert(node, state_after); + event_map.insert(event_id.clone(), Arc::clone(store.0.get(&ev_id).unwrap())); + } + + let mut expected_state = StateMap::new(); + for node in expected_state_ids { + let ev = event_map.get(&node).expect(&format!( + "{} not found in {:?}", + node.to_string(), + event_map + .keys() + .map(ToString::to_string) + .collect::>(), + )); + + let key = (ev.kind(), ev.state_key()); + + expected_state.insert(key, node); + } + + let start_state = state_at_event.get(&event_id("$START:foo")).unwrap(); + + let end_state = state_at_event + .get(&event_id("$END:foo")) + .unwrap() + .iter() + .filter(|(k, v)| { + expected_state.contains_key(k) + || start_state.get(k) != Some(*v) + // Filter out the dummy messages events. + // These act as points in time where there should be a known state to + // test against. + && k != &&(EventType::RoomMessage, "dummy".to_string()) + }) + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>(); + + assert_eq!(expected_state, end_state); +} + +pub struct TestStore(pub BTreeMap>); + +#[allow(unused)] +impl TestStore { + pub fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + self.0 + .get(event_id) + .map(Arc::clone) + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) + } + + /// Returns the events that correspond to the `event_ids` sorted in the same order. + pub fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result>> { + let mut events = vec![]; + for id in event_ids { + events.push(self.get_event(room_id, id)?); + } + Ok(events) + } + + /// Returns a Vec of the related auth events to the given `event`. + pub fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result> { + let mut result = vec![]; + let mut stack = event_ids.to_vec(); + + // DFS for auth event chain + while !stack.is_empty() { + let ev_id = stack.pop().unwrap(); + if result.contains(&ev_id) { + continue; + } + + result.push(ev_id.clone()); + + let event = self.get_event(room_id, &ev_id)?; + + stack.extend(event.auth_events().clone()); + } + + Ok(result) + } + + /// Returns a Vec representing the difference in auth chains of the given `events`. + pub fn auth_chain_diff( + &self, + room_id: &RoomId, + event_ids: Vec>, + ) -> Result> { + use itertools::Itertools; + let mut chains = vec![]; + for ids in event_ids { + // TODO state store `auth_event_ids` returns self in the event ids list + // when an event returns `auth_event_ids` self is not contained + let chain = self + .auth_event_ids(room_id, &ids)? + .into_iter() + .collect::>(); + chains.push(chain); + } + + if let Some(chain) = chains.first().cloned() { + let rest = chains.iter().skip(1).flatten().cloned().collect(); + let common = chain.intersection(&rest).collect::>(); + + Ok(chains + .into_iter() + .flatten() + .filter(|id| !common.contains(&id)) + .dedup() + .collect()) + } else { + Ok(vec![]) + } + } +} + +pub fn event_id(id: &str) -> EventId { + if id.contains('$') { + return EventId::try_from(id).unwrap(); + } + EventId::try_from(format!("${}:foo", id)).unwrap() +} + +pub fn alice() -> UserId { + UserId::try_from("@alice:foo").unwrap() +} +pub fn bob() -> UserId { + UserId::try_from("@bob:foo").unwrap() +} +pub fn charlie() -> UserId { + UserId::try_from("@charlie:foo").unwrap() +} +pub fn ella() -> UserId { + UserId::try_from("@ella:foo").unwrap() +} +pub fn zara() -> UserId { + UserId::try_from("@zara:foo").unwrap() +} + +pub fn room_id() -> RoomId { + RoomId::try_from("!test:foo").unwrap() +} + +pub fn member_content_ban() -> JsonValue { + serde_json::to_value(MemberEventContent { + membership: MembershipState::Ban, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + }) + .unwrap() +} + +pub fn member_content_join() -> JsonValue { + serde_json::to_value(MemberEventContent { + membership: MembershipState::Join, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + }) + .unwrap() +} + +pub fn to_init_pdu_event( + id: &str, + sender: UserId, + ev_type: EventType, + state_key: Option<&str>, + content: JsonValue, +) -> Arc { + let ts = unsafe { + let ts = SERVER_TIMESTAMP; + // increment the "origin_server_ts" value + SERVER_TIMESTAMP += 1; + ts + }; + let id = if id.contains('$') { + id.to_string() + } else { + format!("${}:foo", id) + }; + + let state_key = state_key.map(ToString::to_string); + Arc::new(StateEvent { + event_id: EventId::try_from(id).unwrap(), + rest: Pdu::RoomV3Pdu(RoomV3Pdu { + room_id: room_id(), + sender, + origin_server_ts: UNIX_EPOCH + Duration::from_secs(ts), + state_key, + kind: ev_type, + content, + redacts: None, + unsigned: btreemap! {}, + #[cfg(not(feature = "unstable-pre-spec"))] + origin: "foo".into(), + auth_events: vec![], + prev_events: vec![], + depth: ruma::uint!(0), + hashes: EventHash { sha256: "".into() }, + signatures: btreemap! {}, + }), + }) +} + +pub fn to_pdu_event( + id: &str, + sender: UserId, + ev_type: EventType, + state_key: Option<&str>, + content: JsonValue, + auth_events: &[S], + prev_events: &[S], +) -> Arc +where + S: AsRef, +{ + let ts = unsafe { + let ts = SERVER_TIMESTAMP; + // increment the "origin_server_ts" value + SERVER_TIMESTAMP += 1; + ts + }; + let id = if id.contains('$') { + id.to_string() + } else { + format!("${}:foo", id) + }; + let auth_events = auth_events + .iter() + .map(AsRef::as_ref) + .map(event_id) + .collect::>(); + let prev_events = prev_events + .iter() + .map(AsRef::as_ref) + .map(event_id) + .collect::>(); + + let state_key = state_key.map(ToString::to_string); + Arc::new(StateEvent { + event_id: EventId::try_from(id).unwrap(), + rest: Pdu::RoomV3Pdu(RoomV3Pdu { + room_id: room_id(), + sender, + origin_server_ts: UNIX_EPOCH + Duration::from_secs(ts), + state_key, + kind: ev_type, + content, + redacts: None, + unsigned: btreemap! {}, + #[cfg(not(feature = "unstable-pre-spec"))] + origin: "foo".into(), + auth_events, + prev_events, + depth: ruma::uint!(0), + hashes: EventHash { sha256: "".into() }, + signatures: btreemap! {}, + }), + }) +} + +// all graphs start with these input events +#[allow(non_snake_case)] +pub fn INITIAL_EVENTS() -> BTreeMap> { + // this is always called so we can init the logger here + let _ = LOGGER.call_once(|| { + tracer::fmt() + .with_env_filter(tracer::EnvFilter::from_default_env()) + .init() + }); + + vec![ + to_pdu_event::( + "CREATE", + alice(), + EventType::RoomCreate, + Some(""), + json!({ "creator": alice() }), + &[], + &[], + ), + to_pdu_event( + "IMA", + alice(), + EventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + &["CREATE"], + &["CREATE"], + ), + to_pdu_event( + "IPOWER", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice().to_string(): 100}}), + &["CREATE", "IMA"], + &["IMA"], + ), + to_pdu_event( + "IJR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({ "join_rule": JoinRule::Public }), + &["CREATE", "IMA", "IPOWER"], + &["IPOWER"], + ), + to_pdu_event( + "IMB", + bob(), + EventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IJR"], + ), + to_pdu_event( + "IMC", + charlie(), + EventType::RoomMember, + Some(charlie().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IMB"], + ), + to_pdu_event::( + "START", + charlie(), + EventType::RoomMessage, + Some("dummy"), + json!({}), + &[], + &[], + ), + to_pdu_event::( + "END", + charlie(), + EventType::RoomMessage, + Some("dummy"), + json!({}), + &[], + &[], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().clone(), ev)) + .collect() +} + +#[allow(non_snake_case)] +pub fn INITIAL_EDGES() -> Vec { + vec!["START", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"] + .into_iter() + .map(event_id) + .collect::>() +} + +pub mod event { + use std::{collections::BTreeMap, time::SystemTime}; + + use ruma::{ + events::{ + pdu::{EventHash, Pdu}, + room::member::{MemberEventContent, MembershipState}, + EventType, + }, + EventId, RoomId, RoomVersionId, ServerName, UInt, UserId, + }; + use serde::{Deserialize, Serialize}; + use serde_json::Value as JsonValue; + + use state_res::Event; + + impl Event for StateEvent { + fn event_id(&self) -> &EventId { + self.event_id() + } + + fn room_id(&self) -> &RoomId { + self.room_id() + } + + fn sender(&self) -> &UserId { + self.sender() + } + fn kind(&self) -> EventType { + self.kind() + } + + fn content(&self) -> serde_json::Value { + self.content() + } + fn origin_server_ts(&self) -> SystemTime { + *self.origin_server_ts() + } + + fn state_key(&self) -> Option { + Some(self.state_key()) + } + fn prev_events(&self) -> Vec { + self.prev_event_ids() + } + fn depth(&self) -> &UInt { + self.depth() + } + fn auth_events(&self) -> Vec { + self.auth_events() + } + fn redacts(&self) -> Option<&EventId> { + self.redacts() + } + fn hashes(&self) -> &EventHash { + self.hashes() + } + fn signatures( + &self, + ) -> BTreeMap, BTreeMap> { + self.signatures() + } + fn unsigned(&self) -> &BTreeMap { + self.unsigned() + } + } + + #[derive(Clone, Debug, Deserialize, Serialize)] + pub struct StateEvent { + pub event_id: EventId, + #[serde(flatten)] + pub rest: Pdu, + } + + impl StateEvent { + pub fn from_id_value( + id: EventId, + json: serde_json::Value, + ) -> Result { + Ok(Self { + event_id: id, + rest: Pdu::RoomV3Pdu(serde_json::from_value(json)?), + }) + } + + pub fn from_id_canon_obj( + id: EventId, + json: ruma::serde::CanonicalJsonObject, + ) -> Result { + Ok(Self { + event_id: id, + // TODO: this is unfortunate (from_value(to_value(json)))... + rest: Pdu::RoomV3Pdu(serde_json::from_value(serde_json::to_value(json)?)?), + }) + } + + pub fn is_power_event(&self) -> bool { + match &self.rest { + Pdu::RoomV1Pdu(event) => match event.kind { + EventType::RoomPowerLevels + | EventType::RoomJoinRules + | EventType::RoomCreate => event.state_key == Some("".into()), + EventType::RoomMember => { + // TODO fix clone + if let Ok(content) = + serde_json::from_value::(event.content.clone()) + { + if [MembershipState::Leave, MembershipState::Ban] + .contains(&content.membership) + { + return event.sender.as_str() + // TODO is None here a failure + != event.state_key.as_deref().unwrap_or("NOT A STATE KEY"); + } + } + + false + } + _ => false, + }, + Pdu::RoomV3Pdu(event) => event.state_key == Some("".into()), + } + } + pub fn deserialize_content( + &self, + ) -> Result { + match &self.rest { + Pdu::RoomV1Pdu(ev) => serde_json::from_value(ev.content.clone()), + Pdu::RoomV3Pdu(ev) => serde_json::from_value(ev.content.clone()), + } + } + pub fn origin_server_ts(&self) -> &SystemTime { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.origin_server_ts, + Pdu::RoomV3Pdu(ev) => &ev.origin_server_ts, + } + } + pub fn event_id(&self) -> &EventId { + &self.event_id + } + + pub fn sender(&self) -> &UserId { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.sender, + Pdu::RoomV3Pdu(ev) => &ev.sender, + } + } + + pub fn redacts(&self) -> Option<&EventId> { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.redacts.as_ref(), + Pdu::RoomV3Pdu(ev) => ev.redacts.as_ref(), + } + } + + pub fn room_id(&self) -> &RoomId { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.room_id, + Pdu::RoomV3Pdu(ev) => &ev.room_id, + } + } + pub fn kind(&self) -> EventType { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.kind.clone(), + Pdu::RoomV3Pdu(ev) => ev.kind.clone(), + } + } + pub fn state_key(&self) -> String { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.state_key.clone().unwrap(), + Pdu::RoomV3Pdu(ev) => ev.state_key.clone().unwrap(), + } + } + + #[cfg(not(feature = "unstable-pre-spec"))] + pub fn origin(&self) -> String { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.origin.clone(), + Pdu::RoomV3Pdu(ev) => ev.origin.clone(), + } + } + + pub fn prev_event_ids(&self) -> Vec { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.prev_events.iter().map(|(id, _)| id).cloned().collect(), + Pdu::RoomV3Pdu(ev) => ev.prev_events.clone(), + } + } + + pub fn auth_events(&self) -> Vec { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.auth_events.iter().map(|(id, _)| id).cloned().collect(), + Pdu::RoomV3Pdu(ev) => ev.auth_events.to_vec(), + } + } + + pub fn content(&self) -> serde_json::Value { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.content.clone(), + Pdu::RoomV3Pdu(ev) => ev.content.clone(), + } + } + + pub fn unsigned(&self) -> &BTreeMap { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.unsigned, + Pdu::RoomV3Pdu(ev) => &ev.unsigned, + } + } + + pub fn signatures( + &self, + ) -> BTreeMap, BTreeMap> { + match &self.rest { + Pdu::RoomV1Pdu(_) => maplit::btreemap! {}, + Pdu::RoomV3Pdu(ev) => ev.signatures.clone(), + } + } + + pub fn hashes(&self) -> &EventHash { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.hashes, + Pdu::RoomV3Pdu(ev) => &ev.hashes, + } + } + + pub fn depth(&self) -> &UInt { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.depth, + Pdu::RoomV3Pdu(ev) => &ev.depth, + } + } + + pub fn is_type_and_key(&self, ev_type: EventType, state_key: &str) -> bool { + match &self.rest { + Pdu::RoomV1Pdu(ev) => { + ev.kind == ev_type && ev.state_key.as_deref() == Some(state_key) + } + Pdu::RoomV3Pdu(ev) => { + ev.kind == ev_type && ev.state_key.as_deref() == Some(state_key) + } + } + } + + /// Returns the room version this event is formatted for. + /// + /// Currently either version 1 or 6 is returned, 6 represents + /// version 3 and above. + pub fn room_version(&self) -> RoomVersionId { + // TODO: We have to know the actual room version this is not sufficient + match self.rest { + Pdu::RoomV1Pdu(_) => RoomVersionId::Version1, + Pdu::RoomV3Pdu(_) => RoomVersionId::Version6, + } + } + } +}