From 0c21f38cb1d7e32276ed9e554a7d93b3749d430d Mon Sep 17 00:00:00 2001 From: Devin R Date: Mon, 20 Jul 2020 22:02:29 -0400 Subject: [PATCH] Fixing failing first failing state res test lexicographical_topological_sort test passes. Chasing bug somewhere in resolve. --- Cargo.toml | 6 +- README.md | 11 +- src/event_auth.rs | 4 +- src/lib.rs | 88 ++++-- src/state_event.rs | 17 +- src/state_store.rs | 23 +- tests/event_auth.rs | 0 tests/init.rs | 726 ++++++++++++++++++++++++++++++++++---------- tests/resolve.rs | 1 + 9 files changed, 655 insertions(+), 221 deletions(-) create mode 100644 tests/event_auth.rs create mode 100644 tests/resolve.rs diff --git a/Cargo.toml b/Cargo.toml index bb725f44..1745e72d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,10 +12,12 @@ petgraph = "0.5.1" serde = { version = "1.0.114", features = ["derive"] } serde_json = "1.0.56" tracing = "0.1.16" +maplit = "1.0.2" [dependencies.ruma] -git = "https://github.com/ruma/ruma" +# git = "https://github.com/ruma/ruma" +path = "../__forks__/ruma/ruma" features = ["client-api", "federation-api", "appservice-api"] [dev-dependencies] -maplit = "1.0.2" +lazy_static = "1.4.0" diff --git a/README.md b/README.md index 759c0565..d5a37561 100644 --- a/README.md +++ b/README.md @@ -26,16 +26,7 @@ trait StateStore { fn get_events(&self, event_ids: &[EventId]) -> Result, String>; /// Returns a Vec of the related auth events to the given `event`. - fn auth_event_ids(&self, room_id: &RoomId, event_id: &EventId) -> Result, String>; - - /// Returns a tuple of requested state events from `event_id` and the auth chain events that - /// they relate to the. - fn get_remote_state_for_room( - &self, - room_id: &RoomId, - version: &RoomVersionId, - event_id: &EventId, - ) -> Result<(Vec, Vec), String>; + fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result, String>; } diff --git a/src/event_auth.rs b/src/event_auth.rs index 4b0ebe11..f9c40259 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -20,7 +20,7 @@ pub enum RedactAllowed { No, } -pub(crate) fn auth_types_for_event(event: &StateEvent) -> Vec<(EventType, String)> { +pub fn auth_types_for_event(event: &StateEvent) -> Vec<(EventType, String)> { if event.kind() == EventType::RoomCreate { return vec![]; } @@ -50,7 +50,7 @@ pub(crate) fn auth_types_for_event(event: &StateEvent) -> Vec<(EventType, String auth_types } -pub(crate) fn auth_check( +pub fn auth_check( room_version: &RoomVersionId, event: &StateEvent, auth_events: StateMap, diff --git a/src/lib.rs b/src/lib.rs index de63dfa6..36053264 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,10 @@ use std::{ - collections::{BTreeMap, BinaryHeap}, + cmp::Reverse, + collections::{BTreeMap, BTreeSet, BinaryHeap}, time::SystemTime, }; +use maplit::btreeset; use ruma::{ events::EventType, identifiers::{EventId, RoomId, RoomVersionId}, @@ -14,6 +16,7 @@ mod room_version; mod state_event; mod state_store; +pub use event_auth::{auth_check, auth_types_for_event}; pub use state_event::StateEvent; pub use state_store::StateStore; @@ -75,7 +78,10 @@ impl StateResolution { tracing::debug!("computing {} conflicting events", conflicting.len()); // the set of auth events that are not common across server forks - let mut auth_diff = self.get_auth_chain_diff(&state_sets, &mut event_map, store)?; + let mut auth_diff = + self.get_auth_chain_diff(room_id, &state_sets, &mut event_map, store)?; + + println!("{:?}", auth_diff); // add the auth_diff to conflicting now we have a full set of conflicting events auth_diff.extend(conflicting.values().cloned().flatten()); @@ -181,7 +187,7 @@ impl StateResolution { /// Split the events that have no conflicts from those that are conflicting. /// /// The tuple looks like `(unconflicted, conflicted)`. - fn separate( + pub fn separate( &mut self, state_sets: &[StateMap], ) -> (StateMap, StateMap>) { @@ -206,8 +212,9 @@ impl StateResolution { } /// Returns a Vec of deduped EventIds that appear in some chains but no others. - fn get_auth_chain_diff( + pub fn get_auth_chain_diff( &mut self, + room_id: &RoomId, state_sets: &[StateMap], _event_map: &EventMap, store: &dyn StateStore, @@ -216,6 +223,7 @@ impl StateResolution { tracing::debug!("calculating auth chain difference"); store.auth_chain_diff( + room_id, &state_sets .iter() .flat_map(|map| map.values()) @@ -224,7 +232,7 @@ impl StateResolution { ) } - fn reverse_topological_power_sort( + pub fn reverse_topological_power_sort( &mut self, room_id: &RoomId, power_events: &[EventId], @@ -272,7 +280,7 @@ impl StateResolution { /// Sorts the event graph based on number of outgoing/incoming edges, where /// `key_fn` is used as a tie breaker. The tie breaker happens based on /// power level, age, and event_id. - fn lexicographical_topological_sort( + pub fn lexicographical_topological_sort( &mut self, graph: &BTreeMap>, key_fn: F, @@ -286,7 +294,12 @@ impl StateResolution { // NOTE: this is basically Kahn's algorithm except we look at nodes with no // outgoing edges, c.f. // https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm - let outdegree_map = graph; + + // TODO make the BTreeSet conversion cleaner ?? + let mut outdegree_map: BTreeMap> = graph + .into_iter() + .map(|(k, v)| (k.clone(), v.into_iter().cloned().collect())) + .collect(); let mut reverse_graph = BTreeMap::new(); // Vec of nodes that have zero out degree, least recent events. @@ -294,34 +307,46 @@ impl StateResolution { for (node, edges) in graph.iter() { if edges.is_empty() { - zero_outdegree.push((key_fn(node), node)); + // the `Reverse` is because rusts bin heap sorts largest -> smallest we need + // smallest -> largest + zero_outdegree.push(Reverse((key_fn(node), node))); } - reverse_graph.insert(node, vec![]); + reverse_graph.entry(node).or_insert(btreeset![]); for edge in edges { - reverse_graph.entry(edge).or_insert(vec![]).push(node); + reverse_graph + .entry(edge) + .or_insert(btreeset![]) + .insert(node); } } let mut heap = BinaryHeap::from(zero_outdegree); // we remove the oldest node (most incoming edges) and check against all other - // - while let Some((_, node)) = heap.pop() { + let mut sorted = vec![]; + // match out the `Reverse` and take the smallest `node` each time + while let Some(Reverse((_, node))) = heap.pop() { + let node: &EventId = node; for parent in reverse_graph.get(node).unwrap() { - let out = outdegree_map.get(parent).unwrap(); - if out.iter().filter(|id| *id == node).count() == 0 { - heap.push((key_fn(parent), parent)); + // the number of outgoing edges this node has + let out = outdegree_map.get_mut(parent).unwrap(); + + // only push on the heap once older events have been cleared + out.remove(node); + if out.is_empty() { + heap.push(Reverse((key_fn(parent), parent))); } } + + // synapse yields we push then return the vec + sorted.push(node.clone()); } - // rust BinaryHeap does not iter in order so we gotta do it the long way - let mut sorted = vec![]; - while let Some((_, id)) = heap.pop() { - sorted.push(id.clone()) - } - + // println!( + // "{:#?}", + // sorted.iter().map(ToString::to_string).collect::>() + // ); sorted } @@ -333,7 +358,7 @@ impl StateResolution { store: &dyn StateStore, ) -> i64 { let mut pl = None; - for aid in store.auth_event_ids(room_id, event_id).unwrap() { + for aid in store.auth_event_ids(room_id, &[event_id.clone()]).unwrap() { if let Ok(aev) = store.get_event(&aid) { if aev.is_type_and_key(EventType::RoomPowerLevels, "") { pl = Some(aev); @@ -343,7 +368,7 @@ impl StateResolution { } if pl.is_none() { - for aid in store.auth_event_ids(room_id, event_id).unwrap() { + for aid in store.auth_event_ids(room_id, &[event_id.clone()]).unwrap() { if let Ok(aev) = store.get_event(&aid) { if aev.is_type_and_key(EventType::RoomCreate, "") { if let Ok(content) = aev @@ -384,12 +409,14 @@ impl StateResolution { store: &dyn StateStore, ) -> Result, String> { tracing::debug!("starting iter auth check"); + let resolved_state = unconflicted_state.clone(); + for (idx, event_id) in power_events.iter().enumerate() { let event = store.get_event(event_id).unwrap(); let mut auth_events = BTreeMap::new(); - for aid in store.auth_event_ids(room_id, event_id).unwrap() { + for aid in store.auth_event_ids(room_id, &[event_id.clone()]).unwrap() { if let Ok(ev) = store.get_event(&aid) { // TODO is None the same as "" for state_key, pretty sure it is NOT auth_events.insert((ev.kind(), ev.state_key().unwrap_or_default()), ev); @@ -408,7 +435,11 @@ impl StateResolution { } } - if !event_auth::auth_check(room_version, &event, auth_events).ok_or("".to_string())? {} + if !event_auth::auth_check(room_version, &event, auth_events) + .ok_or("Auth check failed due to deserialization most likely".to_string())? + { + // TODO synapse passes here on AuthError ?? + } // We yield occasionally when we're working with large data sets to // ensure that we don't block the reactor loop for too long. @@ -441,7 +472,7 @@ impl StateResolution { while let Some(p) = pl { mainline.push(p.clone()); // We don't need the actual pl_ev here since we delegate to the store - let auth_events = store.auth_event_ids(room_id, &p).unwrap(); + let auth_events = store.auth_event_ids(room_id, &[p]).unwrap(); pl = None; for aid in auth_events { let ev = store.get_event(&aid).unwrap(); @@ -490,6 +521,7 @@ impl StateResolution { } // sort the event_ids by their depth, timestamp and EventId + // unwrap is OK order map and sort_event_ids are from to_sort (the same Vec) sort_event_ids.sort_by_key(|sort_id| order_map.get(sort_id).unwrap()); sort_event_ids @@ -510,7 +542,7 @@ impl StateResolution { } let auth_events = if let Some(id) = sort_ev.event_id() { - store.auth_event_ids(room_id, id).unwrap() + store.auth_event_ids(room_id, &[id.clone()]).unwrap() } else { vec![] }; @@ -542,7 +574,7 @@ impl StateResolution { let eid = state.pop().unwrap(); graph.insert(eid.clone(), vec![]); - for aid in store.auth_event_ids(room_id, &eid).unwrap() { + for aid in store.auth_event_ids(room_id, &[eid.clone()]).unwrap() { if auth_diff.contains(&aid) { if !graph.contains_key(&aid) { state.push(aid.clone()); diff --git a/src/state_event.rs b/src/state_event.rs index d8df36ce..d24f5e3b 100644 --- a/src/state_event.rs +++ b/src/state_event.rs @@ -171,7 +171,7 @@ impl StateEvent { pub fn prev_event_ids(&self) -> Vec { match self { Self::Full(ev) => match ev { - Pdu::RoomV1Pdu(ev) => ev.prev_events.iter().map(|(id, _)| id).cloned().collect(), + Pdu::RoomV1Pdu(ev) => ev.prev_events.iter().cloned().collect(), Pdu::RoomV3Pdu(ev) => ev.prev_events.clone(), }, Self::Sync(ev) => match ev { @@ -183,6 +183,21 @@ impl StateEvent { } } + pub fn auth_event_ids(&self) -> Vec { + match self { + Self::Full(ev) => match ev { + Pdu::RoomV1Pdu(ev) => ev.auth_events.iter().cloned().collect(), + Pdu::RoomV3Pdu(ev) => ev.auth_events.clone(), + }, + Self::Sync(ev) => match ev { + PduStub::RoomV1PduStub(ev) => { + ev.auth_events.iter().map(|(id, _)| id).cloned().collect() + } + PduStub::RoomV3PduStub(ev) => ev.auth_events.clone(), + }, + } + } + pub fn content(&self) -> serde_json::Value { match self { Self::Full(ev) => match ev { diff --git a/src/state_store.rs b/src/state_store.rs index 37c357e0..757444ab 100644 --- a/src/state_store.rs +++ b/src/state_store.rs @@ -1,4 +1,4 @@ -use ruma::identifiers::{EventId, RoomId, RoomVersionId}; +use ruma::identifiers::{EventId, RoomId}; use crate::StateEvent; @@ -10,17 +10,16 @@ pub trait StateStore { fn get_events(&self, event_ids: &[EventId]) -> Result, String>; /// Returns a Vec of the related auth events to the given `event`. - fn auth_event_ids(&self, room_id: &RoomId, event_id: &EventId) -> Result, String>; - - /// Returns a Vec representing the difference in auth chains of the given `events`. - fn auth_chain_diff(&self, event_id: &[&EventId]) -> Result, String>; - - /// Returns a tuple of requested state events from `event_id` and the auth chain events that - /// relate to the. - fn get_remote_state_for_room( + fn auth_event_ids( &self, room_id: &RoomId, - version: &RoomVersionId, - event_id: &EventId, - ) -> Result<(Vec, Vec), String>; + event_ids: &[EventId], + ) -> Result, String>; + + /// Returns a Vec representing the difference in auth chains of the given `events`. + fn auth_chain_diff( + &self, + room_id: &RoomId, + event_id: &[&EventId], + ) -> Result, String>; } diff --git a/tests/event_auth.rs b/tests/event_auth.rs new file mode 100644 index 00000000..e69de29b diff --git a/tests/init.rs b/tests/init.rs index f0aa8839..7ab1d278 100644 --- a/tests/init.rs +++ b/tests/init.rs @@ -1,208 +1,602 @@ -use std::{collections::BTreeMap, convert::TryFrom}; +#![allow(unused)] + +use std::{ + collections::{BTreeMap, BTreeSet}, + convert::TryFrom, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; use maplit::btreemap; use ruma::{ events::{ - room::{self}, - AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, EventType, + pdu::Pdu, + room::{ + join_rules::JoinRule, + member::{MemberEventContent, MembershipState}, + }, + EventType, }, - identifiers::{EventId, RoomId, RoomVersionId}, + identifiers::{EventId, RoomId, RoomVersionId, UserId}, }; use serde_json::{from_value as from_json_value, json, Value as JsonValue}; -use state_res::{ResolutionResult, StateEvent, StateResolution, StateStore}; +use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore}; -// TODO make this an array of events -fn federated_json() -> JsonValue { - json!({ - "content": { - "creator": "@example:example.org", - "m.federate": true, - "predecessor": { - "event_id": "$something:example.org", - "room_id": "!oldroom:example.org" - }, - "room_version": "6" - }, - "event_id": "$aaa:example.org", - "origin_server_ts": 1, - "room_id": "!room_id:example.org", - "sender": "@alice:example.org", - "state_key": "", - "type": "m.room.create", - "unsigned": { - "age": 1234 - } - }) +static mut SERVER_TIMESTAMP: i32 = 0; + +fn id(id: &str) -> EventId { + EventId::try_from(format!("${}:foo", id)).unwrap() } -fn room_create() -> JsonValue { - json!({ - "content": { - "creator": "@example:example.org", - "m.federate": true, - "predecessor": { - "event_id": "$something:example.org", - "room_id": "!oldroom:example.org" - }, - "room_version": "6" - }, - "event_id": "$aaa:example.org", - "origin_server_ts": 1, - "room_id": "!room_id:example.org", - "sender": "@alice:example.org", - "state_key": "", - "type": "m.room.create", - "unsigned": { - "age": 1234 - } - }) +fn alice() -> UserId { + UserId::try_from("@alice:example.com").unwrap() +} +fn bobo() -> UserId { + UserId::try_from("@bobo:example.com").unwrap() +} +fn devin() -> UserId { + UserId::try_from("@devin:example.com").unwrap() +} +fn zera() -> UserId { + UserId::try_from("@zera:example.com").unwrap() } -fn join_rules() -> JsonValue { - json!({ - "content": { - "join_rule": "public" - }, - "event_id": "$bbb:example.org", - "origin_server_ts": 2, - "room_id": "!room_id:example.org", - "sender": "@alice:example.org", - "state_key": "", - "type": "m.room.join_rules", - "unsigned": { - "age": 1234 - } - }) +fn room_id() -> RoomId { + RoomId::try_from("!test:example.com").unwrap() } -fn join_event() -> JsonValue { - json!({ - "content": { - "avatar_url": null, - "displayname": "example", - "membership": "join" - }, - "event_id": "$ccc:example.org", - "membership": "join", - "room_id": "!room_id:example.org", - "origin_server_ts": 3, - "sender": "@alice:example.org", - "state_key": "@alice:example.org", - "type": "m.room.member", - "unsigned": { - "age": 1, - "replaces_state": "$151800111315tsynI:example.org", - "prev_content": { - "avatar_url": null, - "displayname": "example", - "membership": "invite" - } - } +fn member_content_ban() -> JsonValue { + serde_json::to_value(MemberEventContent { + membership: MembershipState::Ban, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, }) + .unwrap() +} +fn member_content_join() -> JsonValue { + serde_json::to_value(MemberEventContent { + membership: MembershipState::Join, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + }) + .unwrap() } -fn power_levels() -> JsonValue { - json!({ - "content": { - "ban": 50, - "events": { - "m.room.name": 100, - "m.room.power_levels": 100 - }, - "events_default": 0, - "invite": 50, - "kick": 50, - "notifications": { - "room": 20 - }, - "redact": 50, - "state_default": 50, - "users": { - "@example:example.org": 100 - }, - "users_default": 0 - }, - "event_id": "$ddd:example.org", - "origin_server_ts": 4, - "room_id": "!room_id:example.org", - "sender": "@example:example.org", - "state_key": "", - "type": "m.room.power_levels", - "unsigned": { - "age": 1234 - } - }) +fn to_pdu_event( + id: &str, + sender: UserId, + ev_type: EventType, + state_key: Option<&str>, + content: JsonValue, + auth_events: &[EventId], + prev_events: &[EventId], +) -> StateEvent { + let ts = unsafe { + let ts = SERVER_TIMESTAMP; + // increment the "origin_server_ts" value + SERVER_TIMESTAMP += 1; + ts + }; + let id = if id.contains("$") { + id.to_string() + } else { + format!("${}:foo", id) + }; + + let json = if let Some(state_key) = state_key { + json!({ + "auth_events": auth_events, + "prev_events": prev_events, + "event_id": id, + "sender": sender, + "type": ev_type, + "state_key": state_key, + "content": content, + "origin_server_ts": ts, + "room_id": room_id(), + "origin": "foo", + "depth": 0, + "hashes": { "sha256": "hello" }, + "signatures": {}, + }) + } else { + json!({ + "auth_events": auth_events, + "prev_events": prev_events, + "event_id": id, + "sender": sender, + "type": ev_type, + "content": content, + "origin_server_ts": ts, + "room_id": room_id(), + "origin": "foo", + "depth": 0, + "hashes": { "sha256": "hello" }, + "signatures": {}, + }) + }; + serde_json::from_value(json).unwrap() } -pub struct TestStore; +fn to_init_pdu_event( + id: &str, + sender: UserId, + ev_type: EventType, + state_key: Option<&str>, + content: JsonValue, +) -> StateEvent { + let ts = unsafe { + let ts = SERVER_TIMESTAMP; + // increment the "origin_server_ts" value + SERVER_TIMESTAMP += 1; + ts + }; + let id = if id.contains("$") { + id.to_string() + } else { + format!("${}:foo", id) + }; + let json = if let Some(state_key) = state_key { + json!({ + "auth_events": [], + "prev_events": [], + "event_id": id, + "sender": sender, + "type": ev_type, + "state_key": state_key, + "content": content, + "origin_server_ts": ts, + "room_id": room_id(), + "origin": "foo", + "depth": 0, + "hashes": { "sha256": "hello" }, + "signatures": {}, + }) + } else { + json!({ + "auth_events": [], + "prev_events": [], + "event_id": id, + "sender": sender, + "type": ev_type, + "content": content, + "origin_server_ts": ts, + "room_id": room_id(), + "origin": "foo", + "depth": 0, + "hashes": { "sha256": "hello" }, + "signatures": {}, + }) + }; + serde_json::from_value(json).unwrap() +} + +// all graphs start with these input events +#[allow(non_snake_case)] +fn INITIAL_EVENTS() -> BTreeMap { + vec![ + to_init_pdu_event( + "CREATE", + alice(), + EventType::RoomCreate, + Some(""), + json!({ "creator": alice() }), + ), + to_init_pdu_event( + "IMA", + alice(), + EventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + ), + to_init_pdu_event( + "IPOWER", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice().to_string(): 100}}), + ), + to_init_pdu_event( + "IJR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({ "join_rule": JoinRule::Public }), + ), + to_init_pdu_event( + "IMB", + bobo(), + EventType::RoomMember, + Some(bobo().to_string().as_str()), + member_content_join(), + ), + to_init_pdu_event( + "IMC", + devin(), + EventType::RoomMember, + Some(devin().to_string().as_str()), + member_content_join(), + ), + to_init_pdu_event( + "IMZ", + zera(), + EventType::RoomMember, + Some(zera().to_string().as_str()), + member_content_join(), + ), + to_init_pdu_event("START", zera(), EventType::RoomMessage, None, json!({})), + to_init_pdu_event("END", zera(), EventType::RoomMessage, None, json!({})), + ] + .into_iter() + .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .collect() +} + +#[allow(non_snake_case)] +fn INITIAL_EDGES() -> Vec { + vec![ + "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE", + ] + .into_iter() + .map(|s| format!("${}:foo", s)) + .map(EventId::try_from) + .collect::, _>>() + .unwrap() +} + +pub struct TestStore(BTreeMap); + +#[allow(unused)] impl StateStore for TestStore { fn get_events(&self, events: &[EventId]) -> Result, String> { - vec![room_create(), join_rules(), join_event(), power_levels()] - .into_iter() - .map(from_json_value) - .collect::>>() - .map_err(|e| e.to_string()) + Ok(self + .0 + .iter() + .filter(|e| events.contains(e.0)) + .map(|(_, s)| s) + .cloned() + .collect()) } fn get_event(&self, event_id: &EventId) -> Result { - from_json_value(power_levels()).map_err(|e| e.to_string()) + self.0 + .get(event_id) + .cloned() + .ok_or(format!("{} not found", event_id.to_string())) } - fn auth_event_ids(&self, room_id: &RoomId, event_id: &EventId) -> Result, String> { - Ok(vec![ - EventId::try_from("$aaa:example.org").map_err(|e| e.to_string())? - ]) - } - - fn auth_chain_diff(&self, event_id: &[&EventId]) -> Result, String> { - Ok(vec![]) - } - - fn get_remote_state_for_room( + fn auth_event_ids( &self, room_id: &RoomId, - version: &RoomVersionId, - event_id: &EventId, - ) -> Result<(Vec, Vec), String> { - Ok(( - vec![from_json_value(federated_json()).map_err(|e| e.to_string())?], - vec![from_json_value(power_levels()).map_err(|e| e.to_string())?], - )) + event_ids: &[EventId], + ) -> Result, String> { + let mut result = vec![]; + let mut stack = event_ids.to_vec(); + + while !stack.is_empty() { + let ev_id = stack.pop().unwrap(); + if result.contains(&ev_id) { + continue; + } + + result.push(ev_id.clone()); + + let event = self.get_event(&ev_id).unwrap(); + for aid in event.auth_event_ids() { + stack.push(aid); + } + } + + Ok(result) } + + fn auth_chain_diff( + &self, + room_id: &RoomId, + event_ids: &[&EventId], + ) -> Result, String> { + let mut chains = BTreeSet::new(); + let mut list = vec![]; + for id in event_ids { + let chain = self + .auth_event_ids(room_id, &[(*id).clone()])? + .into_iter() + .collect::>(); + list.push(chain.clone()); + chains.insert(chain); + } + if let Some(chain) = list.first() { + let set = maplit::btreeset!(chain.clone()); + let common = set.intersection(&chains).flatten().collect::>(); + Ok(chains + .iter() + .flatten() + .filter(|id| common.contains(&id)) + .cloned() + .collect()) + } else { + Ok(vec![]) + } + } +} + +fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: Vec) { + let mut resolver = StateResolution::default(); + // TODO what do we fill this with, everything ?? + let store = TestStore( + INITIAL_EVENTS() + .values() + .chain(events) + .map(|ev| (ev.event_id().unwrap().clone(), ev.clone())) + .collect(), + ); + + // This will be lexi_topo_sorted for resolution + let mut graph = BTreeMap::new(); + // this is the same as in `resolve` event_id -> StateEvent + let mut fake_event_map = BTreeMap::new(); + + // TODO maybe clean up some of these clones it is just tests but... + for ev in INITIAL_EVENTS().values().chain(events) { + graph.insert(ev.event_id().unwrap().clone(), vec![]); + fake_event_map.insert(ev.event_id().unwrap().clone(), ev.clone()); + } + + for edge_list in edges { + for pair in edge_list.windows(2) { + if let &[a, b] = &pair { + graph.entry(a.clone()).or_insert(vec![]).push(b.clone()); + } + } + } + + // event_id -> StateEvent + let mut event_map: BTreeMap = BTreeMap::new(); + // event_id -> StateMap + let mut state_at_event: BTreeMap> = BTreeMap::new(); + + // resolve the current state and add it to the state_at_event map then continue + // on in "time"? + for node in resolver + // TODO is this `key_fn` return correct ?? + .lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone()))) + { + let fake_event = fake_event_map.get(&node).unwrap(); + let event_id = fake_event.event_id().unwrap(); + + let prev_events = graph.get(&node).unwrap(); + + let state_before: StateMap = if prev_events.is_empty() { + BTreeMap::new() + } else if prev_events.len() == 1 { + state_at_event.get(&prev_events[0]).unwrap().clone() + } else { + let state_sets = prev_events + .iter() + .filter_map(|k| state_at_event.get(k)) + .cloned() + .collect::>(); + + // println!( + // "resolving {:#?}", + // state_sets + // .iter() + // .map(|map| map + // .iter() + // .map(|((t, s), id)| (t, s, id.to_string())) + // .collect::>()) + // .collect::>() + // ); + + let resolved = + resolver.resolve(&room_id(), &RoomVersionId::version_1(), &state_sets, &store); + match resolved { + Ok(ResolutionResult::Resolved(state)) => state, + _ => panic!("resolution for {} failed", node), + } + }; + + let mut state_after = state_before.clone(); + if fake_event.state_key().is_some() { + let ty = fake_event.kind().clone(); + let key = fake_event.state_key().unwrap().clone(); + state_after.insert((ty, key), event_id.clone()); + } + + let auth_types = state_res::auth_types_for_event(fake_event) + .into_iter() + .collect::>(); + // println!( + // "{:?}", + // auth_types + // .iter() + // .map(|(t, id)| (t, id.to_string())) + // .collect::>() + // ); + + let mut auth_events = vec![]; + for key in auth_types { + if state_before.contains_key(&key) { + auth_events.push(state_before[&key].clone()) + } + } + + // TODO The event is just remade, adding the auth_events and prev_events here + // UPDATE: the `to_pdu_event` was split into `init` and the fn below, could be better + let e = fake_event; + let event = to_pdu_event( + &e.event_id().unwrap().to_string(), + e.sender().clone(), + e.kind(), + e.state_key().as_deref(), + e.content(), + &auth_events, + prev_events, + ); + + state_at_event.insert(node, state_after); + event_map.insert(event_id.clone(), event); + } + + let mut expected_state = BTreeMap::new(); + for node in expected_state_ids { + let ev = event_map.get(&node).expect(&format!( + "{} not found in {:?}", + node.to_string(), + event_map + .keys() + .map(ToString::to_string) + .collect::>(), + )); + + let key = (ev.kind(), ev.state_key().unwrap()); + + expected_state.insert(key, node); + } + + let start_state = state_at_event + .get(&EventId::try_from("$START:foo").unwrap()) + .unwrap(); + + println!("{:?}", start_state); + + let end_state = state_at_event + .get(&EventId::try_from("$END:foo").unwrap()) + .unwrap() + .iter() + .filter(|(k, v)| expected_state.contains_key(k) || start_state.get(k) != Some(*v)) + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>(); + + assert_eq!(expected_state, end_state); } #[test] -fn it_works() { - let mut store = TestStore; +fn ban_vs_power_level() { + let events = &[ + to_init_pdu_event( + "PA", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bobo(): 50}}), + ), + to_init_pdu_event( + "MA", + alice(), + EventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + ), + to_init_pdu_event( + "MB", + alice(), + EventType::RoomMember, + Some(bobo().to_string().as_str()), + member_content_ban(), + ), + to_init_pdu_event( + "PB", + bobo(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bobo(): 50}}), + ), + ]; - let room_id = RoomId::try_from("!room_id:example.org").unwrap(); - let room_version = RoomVersionId::version_6(); + let edges = vec![ + vec!["END", "MB", "MA", "PA", "START"], + vec!["END", "PB", "PA"], + ] + .into_iter() + .map(|list| { + list.into_iter() + .map(|s| format!("${}:foo", s)) + .map(EventId::try_from) + .collect::, _>>() + .unwrap() + }) + .collect::>(); - let initial_state = btreemap! { - (EventType::RoomCreate, "".into()) => EventId::try_from("$aaa:example.org").unwrap(), - }; + let expected_state_ids = vec!["PA", "MA", "MB"] + .into_iter() + .map(|s| format!("${}:foo", s)) + .map(EventId::try_from) + .collect::, _>>() + .unwrap(); - let state_to_resolve = btreemap! { - (EventType::RoomCreate, "".into()) => EventId::try_from("$bbb:example.org").unwrap(), - }; + do_check(events, edges, expected_state_ids) +} +// #[test] +fn topic_reset() { + let events = &[ + to_init_pdu_event("T1", alice(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event( + "PA", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bobo(): 50}}), + ), + to_init_pdu_event("T2", bobo(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event( + "MB", + alice(), + EventType::RoomMember, + Some(bobo().to_string().as_str()), + member_content_ban(), + ), + ]; + + let edges = vec![ + vec!["END", "MB", "T2", "PA", "T1", "START"], + vec!["END", "T1"], + ] + .into_iter() + .map(|list| { + list.into_iter() + .map(|s| format!("${}:foo", s)) + .map(EventId::try_from) + .collect::, _>>() + .unwrap() + }) + .collect::>(); + + let expected_state_ids = vec!["T1", "MB", "PA"] + .into_iter() + .map(|s| format!("${}:foo", s)) + .map(EventId::try_from) + .collect::, _>>() + .unwrap(); + + do_check(events, edges, expected_state_ids) +} + +#[test] +fn test_lexicographical_sort() { let mut resolver = StateResolution::default(); - let res = resolver - .resolve(&room_id, &room_version, &[initial_state], &mut store) - .unwrap(); - assert!(if let ResolutionResult::Resolved(_) = res { - true - } else { - false - }); + let graph = btreemap! { + id("l") => vec![id("o")], + id("m") => vec![id("n"), id("o")], + id("n") => vec![id("o")], + id("o") => vec![], // "o" has zero outgoing edges but 4 incoming edges + id("p") => vec![id("o")], + }; - let resolved = resolver - .resolve(&room_id, &room_version, &[state_to_resolve], &mut store) - .unwrap(); + let res = + resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone()))); - assert!(resolver.conflicting_events.is_empty()); - assert_eq!(resolver.resolved_events.len(), 3); - assert_eq!(resolver.resolved_events.len(), 3); + assert_eq!( + vec!["o", "l", "n", "m", "p"], + res.iter() + .map(ToString::to_string) + .map(|s| s.replace("$", "").replace(":foo", "")) + .collect::>() + ) } diff --git a/tests/resolve.rs b/tests/resolve.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/resolve.rs @@ -0,0 +1 @@ +