diff --git a/Cargo.toml b/Cargo.toml index 5f77a127..fce5f776 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ edition = "2018" petgraph = "0.5.1" serde = { version = "1.0.114", features = ["derive"] } serde_json = "1.0.56" +maplit = "1.0.2" [dependencies.ruma] git = "https://github.com/ruma/ruma" diff --git a/src/lib.rs b/src/lib.rs index fa2a6c50..1fea5903 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,9 +16,14 @@ mod state_store; pub use state_event::StateEvent; pub use state_store::StateStore; +// We want to yield to the reactor occasionally during state res when dealing +// with large data sets, so that we don't exhaust the reactor. This is done by +// yielding to reactor during loops every N iterations. +const _YIELD_AFTER_ITERATIONS: usize = 100; + pub enum ResolutionResult { Conflicted(Vec>), - Resolved(Vec>), + Resolved(StateMap), } /// A mapping of event type and state_key to some value `T`, usually an `EventId`. @@ -53,15 +58,17 @@ impl StateResolution { room_id: &RoomId, room_version: &RoomVersionId, state_sets: Vec>, - store: &mut dyn StateStore, + store: &dyn StateStore, // TODO actual error handling (`thiserror`??) ) -> Result { let mut event_map = EventMap::new(); // split non-conflicting and conflicting state - let (clean, mut conflicting) = self.seperate(&state_sets); + let (clean, conflicting) = self.seperate(&state_sets); if conflicting.is_empty() { - return Ok(ResolutionResult::Resolved(clean)); + return Ok(ResolutionResult::Resolved( + clean.into_iter().flat_map(|map| map.into_iter()).collect(), + )); } // the set of auth events that are not common across server forks @@ -71,17 +78,18 @@ impl StateResolution { auth_diff.extend(conflicting.iter().flat_map(|map| map.values().cloned())); let all_conflicted = auth_diff; - let all_conflicted = conflicting; + // TODO get events and add to event_map + // TODO throw error if event is not for this room + // TODO make sure each conflicting event is in?? event_map `{eid for eid in full_conflicted_set if eid in event_map}` let power_events = all_conflicted .iter() - .filter(is_power_event) - .flat_map(|map| map.values()) + .filter(|id| is_power_event(id, store)) .cloned() .collect::>(); // sort the power events based on power_level/clock/event_id and outgoing/incoming edges - let sorted_power_levels = self.revers_topological_power_sort( + let mut sorted_power_levels = self.revers_topological_power_sort( room_id, &power_events, &mut event_map, @@ -93,16 +101,48 @@ impl StateResolution { let resolved = self.iterative_auth_check( room_id, room_version, - &power_events, + &sorted_power_levels, &clean, &mut event_map, store, ); + // At this point the power_events have been resolved we now have to + // sort the remaining events using the mainline of the resolved power level. + + sorted_power_levels.dedup(); + let deduped_power_ev = sorted_power_levels; + + let events_to_resolve = all_conflicted + .iter() + .filter(|id| deduped_power_ev.contains(id)) + .cloned() + .collect::>(); + + let power_event = resolved.get(&(EventType::RoomPowerLevels, "".into())); + + let sorted_left_events = + self.mainline_sort(room_id, &events_to_resolve, power_event, &event_map, store); + + let mut resolved_state = self.iterative_auth_check( + room_id, + room_version, + &sorted_left_events, + &[resolved], + &mut event_map, + store, + ); + + // add unconflicted state to the resolved state + resolved_state.extend(clean.into_iter().flat_map(|map| map.into_iter())); + // TODO return something not a place holder - Ok(ResolutionResult::Resolved(vec![])) + Ok(ResolutionResult::Resolved(resolved_state)) } + /// Split the events that have no conflicts from those that are conflicting. + /// + /// The tuple looks like `(unconflicted, conflicted)`. fn seperate( &mut self, state_sets: &[StateMap], @@ -115,7 +155,7 @@ impl StateResolution { &mut self, state_sets: &[StateMap], event_map: &EventMap, - store: &mut dyn StateStore, + store: &dyn StateStore, ) -> Result, serde_json::Error> { panic!() } @@ -125,9 +165,9 @@ impl StateResolution { room_id: &RoomId, power_events: &[EventId], event_map: &EventMap, - store: &mut dyn StateStore, - conflicted_set: &[StateMap], - ) -> Vec { + store: &dyn StateStore, + conflicted_set: &[EventId], + ) -> Vec { panic!() } @@ -138,14 +178,125 @@ impl StateResolution { power_events: &[EventId], unconflicted_state: &[StateMap], event_map: &EventMap, - store: &mut dyn StateStore, - ) -> Vec { + store: &dyn StateStore, + ) -> StateMap { panic!() } + + /// Returns the sorted `to_sort` list of `EventId`s based on a mainline sort using + /// the `resolved_power_level`. + fn mainline_sort( + &mut self, + room_id: &RoomId, + to_sort: &[EventId], + resolved_power_level: Option<&EventId>, + event_map: &EventMap, + store: &dyn StateStore, + ) -> Vec { + // There can be no EventId's to sort, bail. + if to_sort.is_empty() { + return vec![]; + } + + let mut mainline = vec![]; + let mut pl = resolved_power_level.cloned(); + let mut idx = 0; + while let Some(p) = pl { + mainline.push(p.clone()); + // We don't need the actual pl_ev here since we delegate to the store + let auth_events = store.auth_event_ids(room_id, &p).unwrap(); + pl = None; + for aid in auth_events { + let ev = store.get_event(&aid).unwrap(); + if ev.is_type_and_key(EventType::RoomPowerLevels, "") { + pl = Some(aid); + break; + } + } + // We yield occasionally when we're working with large data sets to + // ensure that we don't block the reactor loop for too long. + if idx != 0 && idx % _YIELD_AFTER_ITERATIONS == 0 { + // yield clock.sleep(0) + } + idx += 1; + } + + let mainline_map = mainline + .iter() + .enumerate() + .map(|(idx, eid)| ((*eid).clone(), idx)) + .collect::>(); + let mut sort_event_ids = to_sort.to_vec(); + + let mut order_map = BTreeMap::new(); + for (idx, ev_id) in to_sort.iter().enumerate() { + let depth = self.get_mainline_depth( + room_id, + event_map.get(ev_id).cloned(), + &mainline_map, + store, + ); + order_map.insert( + ev_id, + ( + depth, + event_map.get(ev_id).map(|ev| ev.origin_server_ts()), + ev_id, // TODO should this be a &str to sort lexically?? + ), + ); + + // We yield occasionally when we're working with large data sets to + // ensure that we don't block the reactor loop for too long. + if idx % _YIELD_AFTER_ITERATIONS == 0 { + // yield clock.sleep(0) + } + } + + // sort the event_ids by their depth, timestamp and EventId + sort_event_ids.sort_by_key(|sort_id| order_map.get(sort_id).unwrap()); + + sort_event_ids + } + + fn get_mainline_depth( + &mut self, + room_id: &RoomId, + mut event: Option, + mainline_map: &EventMap, + store: &dyn StateStore, + ) -> usize { + while let Some(sort_ev) = event { + if let Some(id) = sort_ev.event_id() { + if let Some(depth) = mainline_map.get(id) { + return *depth; + } + } + + let auth_events = if let Some(id) = sort_ev.event_id() { + store.auth_event_ids(room_id, id).unwrap() + } else { + vec![] + }; + event = None; + + for aid in auth_events { + let aev = store.get_event(&aid).unwrap(); + if aev.is_type_and_key(EventType::RoomPowerLevels, "") { + event = Some(aev); + break; + } + } + } + // Did not find a power level event so we default to zero + 0 + } } -pub fn is_power_event(event: &&StateMap) -> bool { - true +pub fn is_power_event(event_id: &EventId, store: &dyn StateStore) -> bool { + match store.get_event(event_id) { + Ok(state) => state.is_power_event(), + _ => false, // TODO this shouldn't eat errors + } } #[cfg(test)] diff --git a/src/state_event.rs b/src/state_event.rs index a42c1109..33954df4 100644 --- a/src/state_event.rs +++ b/src/state_event.rs @@ -1,8 +1,13 @@ -use ruma::events::{ - from_raw_json_value, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, EventDeHelper, +use ruma::{ + events::{ + from_raw_json_value, room::member::MembershipState, AnyStateEvent, AnyStrippedStateEvent, + AnySyncStateEvent, EventDeHelper, EventType, + }, + identifiers::{EventId, RoomId}, }; use serde::{de, Serialize}; use serde_json::value::RawValue as RawJsonValue; +use std::{convert::TryFrom, time::SystemTime}; #[derive(Clone, Debug, Serialize)] #[serde(untagged)] @@ -12,6 +17,74 @@ pub enum StateEvent { Stripped(AnyStrippedStateEvent), } +impl StateEvent { + pub fn is_power_event(&self) -> bool { + match self { + Self::Full(any_event) => match any_event { + AnyStateEvent::RoomPowerLevels(event) => event.state_key == "", + AnyStateEvent::RoomJoinRules(event) => event.state_key == "", + AnyStateEvent::RoomCreate(event) => event.state_key == "", + AnyStateEvent::RoomMember(event) => { + if [MembershipState::Leave, MembershipState::Ban] + .contains(&event.content.membership) + { + return event.sender.as_str() != event.state_key; + } + false + } + _ => false, + }, + Self::Sync(any_event) => match any_event { + AnySyncStateEvent::RoomPowerLevels(event) => event.state_key == "", + AnySyncStateEvent::RoomJoinRules(event) => event.state_key == "", + AnySyncStateEvent::RoomCreate(event) => event.state_key == "", + AnySyncStateEvent::RoomMember(event) => { + if [MembershipState::Leave, MembershipState::Ban] + .contains(&event.content.membership) + { + return event.sender.as_str() != event.state_key; + } + false + } + _ => false, + }, + Self::Stripped(any_event) => match any_event { + AnyStrippedStateEvent::RoomPowerLevels(event) => event.state_key == "", + AnyStrippedStateEvent::RoomJoinRules(event) => event.state_key == "", + AnyStrippedStateEvent::RoomCreate(event) => event.state_key == "", + AnyStrippedStateEvent::RoomMember(event) => { + if [MembershipState::Leave, MembershipState::Ban] + .contains(&event.content.membership) + { + return event.sender.as_str() != event.state_key; + } + false + } + _ => false, + }, + _ => false, + } + } + pub fn origin_server_ts(&self) -> Option<&SystemTime> { + match self { + Self::Full(ev) => Some(ev.origin_server_ts()), + Self::Sync(ev) => Some(ev.origin_server_ts()), + Self::Stripped(ev) => None, + } + } + pub fn event_id(&self) -> Option<&EventId> { + match self { + Self::Full(ev) => Some(ev.event_id()), + Self::Sync(ev) => Some(ev.event_id()), + Self::Stripped(ev) => None, + } + } + + pub fn is_type_and_key(&self, ev_type: EventType, state_key: &str) -> bool { + true + } +} + impl<'de> de::Deserialize<'de> for StateEvent { fn deserialize(deserializer: D) -> Result where diff --git a/src/state_store.rs b/src/state_store.rs index f84e8127..d5c7c862 100644 --- a/src/state_store.rs +++ b/src/state_store.rs @@ -12,8 +12,14 @@ use ruma::{ use crate::StateEvent; pub trait StateStore { + /// Return a single event based on the EventId. + fn get_event(&self, event_id: &EventId) -> Result; + /// Returns the events that correspond to the `event_ids` sorted in the same order. - fn get_events(&self, event_ids: &[EventId]) -> Result, serde_json::Error>; + fn get_events(&self, event_ids: &[EventId]) -> Result, String>; + + /// Returns a Vec of the related auth events to the given `event`. + fn auth_event_ids(&self, room_id: &RoomId, event_id: &EventId) -> Result, String>; /// Returns a tuple of requested state events from `event_id` and the auth chain events that /// relate to the. @@ -22,5 +28,5 @@ pub trait StateStore { room_id: &RoomId, version: &RoomVersionId, event_id: &EventId, - ) -> Result<(Vec, Vec), serde_json::Error>; + ) -> Result<(Vec, Vec), String>; } diff --git a/tests/init.rs b/tests/init.rs index 93efbc12..88c64260 100644 --- a/tests/init.rs +++ b/tests/init.rs @@ -1,5 +1,6 @@ -use std::convert::TryFrom; +use std::{collections::BTreeMap, convert::TryFrom}; +use maplit::btreemap; use ruma::{ events::{ room::{self}, @@ -136,8 +137,22 @@ fn power_levels() -> JsonValue { pub struct TestStore; impl StateStore for TestStore { - fn get_events(&self, events: &[EventId]) -> Result, serde_json::Error> { - Ok(vec![from_json_value(power_levels())?]) + fn get_events(&self, events: &[EventId]) -> Result, String> { + vec![room_create(), join_rules(), join_event(), power_levels()] + .into_iter() + .map(from_json_value) + .collect::>>() + .map_err(|e| e.to_string()) + } + + fn get_event(&self, event_id: &EventId) -> Result { + from_json_value(power_levels()).map_err(|e| e.to_string()) + } + + fn auth_event_ids(&self, room_id: &RoomId, event_id: &EventId) -> Result, String> { + Ok(vec![ + EventId::try_from("$aaa:example.org").map_err(|e| e.to_string())? + ]) } fn get_remote_state_for_room( @@ -145,10 +160,10 @@ impl StateStore for TestStore { room_id: &RoomId, version: &RoomVersionId, event_id: &EventId, - ) -> Result<(Vec, Vec), serde_json::Error> { + ) -> Result<(Vec, Vec), String> { Ok(( - vec![from_json_value(federated_json())?], - vec![from_json_value(power_levels())?], + vec![from_json_value(federated_json()).map_err(|e| e.to_string())?], + vec![from_json_value(power_levels()).map_err(|e| e.to_string())?], )) } } @@ -160,14 +175,18 @@ fn it_works() { let room_id = RoomId::try_from("!room_id:example.org").unwrap(); let room_version = RoomVersionId::version_6(); - let a = from_json_value::(room_create()).unwrap(); - let b = from_json_value::(join_rules()).unwrap(); - let c = from_json_value::(join_event()).unwrap(); + let initial_state = btreemap! { + (EventType::RoomCreate, "".into()) => EventId::try_from("").unwrap(), + }; + + let state_to_resolve = btreemap! { + (EventType::RoomCreate, "".into()) => EventId::try_from("").unwrap(), + }; let mut resolver = StateResolution::default(); let res = resolver - .resolve(&room_id, &room_version, vec![a.clone()], &mut store) + .resolve(&room_id, &room_version, vec![initial_state], &mut store) .unwrap(); assert!(if let ResolutionResult::Resolved(_) = res { true @@ -176,9 +195,10 @@ fn it_works() { }); let resolved = resolver - .resolve(&room_id, &room_version, vec![b, c], &mut store) + .resolve(&room_id, &room_version, vec![state_to_resolve], &mut store) .unwrap(); assert!(resolver.conflicting_events.is_empty()); assert_eq!(resolver.resolved_events.len(), 3); + assert_eq!(resolver.resolved_events.len(), 3); }