use std::{ borrow::Borrow, cmp::{Ordering, Reverse}, collections::{BinaryHeap, HashMap, HashSet}, fmt::Debug, hash::Hash, }; use futures_util::{future, stream, Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use js_int::{int, Int}; use ruma_common::{EventId, MilliSecondsSinceUnixEpoch, RoomVersionId}; use ruma_events::{ room::member::{MembershipState, RoomMemberEventContent}, StateEventType, TimelineEventType, }; use serde_json::from_str as from_json_str; use tracing::{debug, instrument, trace, warn}; mod error; pub mod event_auth; mod power_levels; pub mod room_version; mod state_event; #[cfg(test)] mod test_utils; pub use error::{Error, Result}; pub use event_auth::{auth_check, auth_types_for_event}; use power_levels::PowerLevelsContentFields; pub use room_version::RoomVersion; pub use state_event::Event; /// A mapping of event type and state_key to some value `T`, usually an `EventId`. pub type StateMap = HashMap; pub type StateMapItem = (TypeStateKey, T); pub type TypeStateKey = (StateEventType, String); /// Resolve sets of state events as they come in. /// /// Internally `StateResolution` builds a graph and an auth chain to allow for state conflict /// resolution. /// /// ## Arguments /// /// * `state_sets` - The incoming state to resolve. Each `StateMap` represents a possible fork in /// the state of a room. /// /// * `auth_chain_sets` - The full recursive set of `auth_events` for each event in the /// `state_sets`. /// /// * `event_fetch` - Any event not found in the `event_map` will defer to this closure to find the /// event. /// /// * `parallel_fetches` - The number of asynchronous fetch requests in-flight for any given /// operation. /// /// ## Invariants /// /// The caller of `resolve` must ensure that all the events are from the same room. Although this /// function takes a `RoomId` it does not check that each event is part of the same room. //#[instrument(level = "debug", skip(state_sets, auth_chain_sets, event_fetch))] pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>( room_version: &RoomVersionId, state_sets: impl IntoIterator + Send, auth_chain_sets: &'a [HashSet], event_fetch: &Fetch, event_exists: &Exists, parallel_fetches: usize, ) -> Result> where Fetch: Fn(E::Id) -> FetchFut + Sync, FetchFut: Future> + Send, Exists: Fn(E::Id) -> ExistsFut + Sync, ExistsFut: Future + Send, SetIter: Iterator> + Clone + Send, E: Event + Clone + Send + Sync, E::Id: Borrow + Send + Sync, for<'b> &'b E: Send, { debug!("State resolution starting"); // Split non-conflicting and conflicting state let (clean, conflicting) = separate(state_sets.into_iter()); debug!(count = clean.len(), "non-conflicting events"); trace!(map = ?clean, "non-conflicting events"); if conflicting.is_empty() { debug!("no conflicting state found"); return Ok(clean); } debug!(count = conflicting.len(), "conflicting events"); trace!(map = ?conflicting, "conflicting events"); let auth_chain_diff = get_auth_chain_diff(auth_chain_sets).chain(conflicting.into_values().flatten()); // `all_conflicted` contains unique items // synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}` let all_conflicted: HashSet<_> = stream::iter(auth_chain_diff) // Don't honor events we cannot "verify" .map(|id| event_exists(id.clone()).map(move |exists| (id, exists))) .buffer_unordered(parallel_fetches) .filter_map(|(id, exists)| future::ready(exists.then_some(id.clone()))) .collect() .boxed() .await; debug!(count = all_conflicted.len(), "full conflicted set"); trace!(set = ?all_conflicted, "full conflicted set"); // We used to check that all events are events from the correct room // this is now a check the caller of `resolve` must make. // Get only the control events with a state_key: "" or ban/kick event (sender != state_key) let control_events: Vec<_> = stream::iter(all_conflicted.iter()) .map(|id| is_power_event_id(id, &event_fetch).map(move |is| (id, is))) .buffer_unordered(parallel_fetches) .filter_map(|(id, is)| future::ready(is.then_some(id.clone()))) .collect() .boxed() .await; // Sort the control events based on power_level/clock/event_id and outgoing/incoming edges let sorted_control_levels = reverse_topological_power_sort( control_events, &all_conflicted, &event_fetch, parallel_fetches, ) .boxed() .await?; debug!(count = sorted_control_levels.len(), "power events"); trace!(list = ?sorted_control_levels, "sorted power events"); let room_version = RoomVersion::new(room_version)?; // Sequentially auth check each control event. let resolved_control = iterative_auth_check( &room_version, sorted_control_levels.iter(), clean.clone(), &event_fetch, parallel_fetches, ) .boxed() .await?; debug!(count = resolved_control.len(), "resolved power events"); trace!(map = ?resolved_control, "resolved power events"); // At this point the control_events have been resolved we now have to // sort the remaining events using the mainline of the resolved power level. let deduped_power_ev = sorted_control_levels.into_iter().collect::>(); // This removes the control events that passed auth and more importantly those that failed // auth let events_to_resolve = all_conflicted .iter() .filter(|&id| !deduped_power_ev.contains(id.borrow())) .cloned() .collect::>(); debug!(count = events_to_resolve.len(), "events left to resolve"); trace!(list = ?events_to_resolve, "events left to resolve"); // This "epochs" power level event let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, "".into())); debug!(event_id = ?power_event, "power event"); let sorted_left_events = mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch, parallel_fetches) .boxed() .await?; trace!(list = ?sorted_left_events, "events left, sorted"); let mut resolved_state = iterative_auth_check( &room_version, sorted_left_events.iter(), resolved_control, // The control events are added to the final resolved state &event_fetch, parallel_fetches, ) .boxed() .await?; // Add unconflicted state to the resolved state // We priorities the unconflicting state resolved_state.extend(clean); debug!("state resolution finished"); Ok(resolved_state) } /// Split the events that have no conflicts from those that are conflicting. /// /// The return tuple looks like `(unconflicted, conflicted)`. /// /// State is determined to be conflicting if for the given key (StateEventType, StateKey) there is /// not exactly one event ID. This includes missing events, if one state_set includes an event that /// none of the other have this is a conflicting event. fn separate<'a, Id>( state_sets_iter: impl Iterator>, ) -> (StateMap, StateMap>) where Id: Clone + Eq + Hash + 'a, { let mut state_set_count = 0_usize; let mut occurrences = HashMap::<_, HashMap<_, _>>::new(); let state_sets_iter = state_sets_iter.inspect(|_| state_set_count += 1); for (k, v) in state_sets_iter.flatten() { occurrences.entry(k).or_default().entry(v).and_modify(|x| *x += 1).or_insert(1); } let mut unconflicted_state = StateMap::new(); let mut conflicted_state = StateMap::new(); for (k, v) in occurrences { for (id, occurrence_count) in v { if occurrence_count == state_set_count { unconflicted_state.insert((k.0.clone(), k.1.clone()), id.clone()); } else { conflicted_state .entry((k.0.clone(), k.1.clone())) .and_modify(|x: &mut Vec<_>| x.push(id.clone())) .or_insert(vec![id.clone()]); } } } (unconflicted_state, conflicted_state) } /// Returns a Vec of deduped EventIds that appear in some chains but not others. fn get_auth_chain_diff(auth_chain_sets: &[HashSet]) -> impl Iterator + Send where Id: Clone + Eq + Hash + Send, { let num_sets = auth_chain_sets.len(); let mut id_counts: HashMap = HashMap::new(); for id in auth_chain_sets.iter().flatten() { *id_counts.entry(id.clone()).or_default() += 1; } id_counts.into_iter().filter_map(move |(id, count)| (count < num_sets).then_some(id)) } /// Events are sorted from "earliest" to "latest". /// /// They are compared using the negative power level (reverse topological ordering), the origin /// server timestamp and in case of a tie the `EventId`s are compared lexicographically. /// /// The power level is negative because a higher power level is equated to an earlier (further back /// in time) origin server timestamp. #[instrument(level = "debug", skip_all)] async fn reverse_topological_power_sort( events_to_sort: Vec, auth_diff: &HashSet, fetch_event: &F, parallel_fetches: usize, ) -> Result> where F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, E: Event + Send, E::Id: Borrow + Send + Sync, { debug!("reverse topological sort of power events"); let mut graph = HashMap::new(); for event_id in events_to_sort { add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, fetch_event).await; } // This is used in the `key_fn` passed to the lexico_topo_sort fn let event_to_pl = stream::iter(graph.keys()) .map(|event_id| { get_power_level_for_sender(event_id.clone(), fetch_event, parallel_fetches) .map(move |res| res.map(|pl| (event_id, pl))) }) .buffer_unordered(parallel_fetches) .try_fold(HashMap::new(), |mut event_to_pl, (event_id, pl)| { debug!( event_id = event_id.borrow().as_str(), power_level = i64::from(pl), "found the power level of an event's sender", ); event_to_pl.insert(event_id.clone(), pl); future::ok(event_to_pl) }) .boxed() .await?; let event_to_pl = &event_to_pl; let fetcher = |event_id: E::Id| async move { let pl = *event_to_pl.get(event_id.borrow()).ok_or_else(|| Error::NotFound("".into()))?; let ev = fetch_event(event_id).await.ok_or_else(|| Error::NotFound("".into()))?; Ok((pl, ev.origin_server_ts())) }; lexicographical_topological_sort(&graph, &fetcher).await } /// Sorts the event graph based on number of outgoing/incoming edges. /// /// `key_fn` is used as to obtain the power level and age of an event for breaking ties (together /// with the event ID). #[instrument(level = "debug", skip_all)] pub async fn lexicographical_topological_sort( graph: &HashMap>, key_fn: &F, ) -> Result> where F: Fn(Id) -> Fut + Sync, Fut: Future> + Send, Id: Borrow + Clone + Eq + Hash + Ord + Send, { #[derive(PartialEq, Eq)] struct TieBreaker<'a, Id> { power_level: Int, origin_server_ts: MilliSecondsSinceUnixEpoch, event_id: &'a Id, } impl Ord for TieBreaker<'_, Id> where Id: Ord, { fn cmp(&self, other: &Self) -> Ordering { // NOTE: the power level comparison is "backwards" intentionally. // See the "Mainline ordering" section of the Matrix specification // around where it says the following: // // > for events `x` and `y`, `x < y` if [...] // // other .power_level .cmp(&self.power_level) .then(self.origin_server_ts.cmp(&other.origin_server_ts)) .then(self.event_id.cmp(other.event_id)) } } impl PartialOrd for TieBreaker<'_, Id> where Id: Ord, { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } debug!("starting lexicographical topological sort"); // NOTE: an event that has no incoming edges happened most recently, // and an event that has no outgoing edges happened least recently. // NOTE: this is basically Kahn's algorithm except we look at nodes with no // outgoing edges, c.f. // https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm // outdegree_map is an event referring to the events before it, the // more outdegree's the more recent the event. let mut outdegree_map = graph.clone(); // The number of events that depend on the given event (the EventId key) // How many events reference this event in the DAG as a parent let mut reverse_graph: HashMap<_, HashSet<_>> = HashMap::new(); // Vec of nodes that have zero out degree, least recent events. let mut zero_outdegree = Vec::new(); for (node, edges) in graph { if edges.is_empty() { let (power_level, origin_server_ts) = key_fn(node.clone()).await?; // The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need // smallest -> largest zero_outdegree.push(Reverse(TieBreaker { power_level, origin_server_ts, event_id: node, })); } reverse_graph.entry(node).or_default(); for edge in edges { reverse_graph.entry(edge).or_default().insert(node); } } let mut heap = BinaryHeap::from(zero_outdegree); // We remove the oldest node (most incoming edges) and check against all other let mut sorted = vec![]; // Destructure the `Reverse` and take the smallest `node` each time while let Some(Reverse(item)) = heap.pop() { let node = item.event_id; for &parent in reverse_graph.get(node).expect("EventId in heap is also in reverse_graph") { // The number of outgoing edges this node has let out = outdegree_map .get_mut(parent.borrow()) .expect("outdegree_map knows of all referenced EventIds"); // Only push on the heap once older events have been cleared out.remove(node.borrow()); if out.is_empty() { let (power_level, origin_server_ts) = key_fn(parent.clone()).await?; heap.push(Reverse(TieBreaker { power_level, origin_server_ts, event_id: parent })); } } // synapse yields we push then return the vec sorted.push(node.clone()); } Ok(sorted) } /// Find the power level for the sender of `event_id` or return a default value of zero. /// /// Do NOT use this any where but topological sort, we find the power level for the eventId /// at the eventId's generation (we walk backwards to `EventId`s most recent previous power level /// event). async fn get_power_level_for_sender( event_id: E::Id, fetch_event: &F, parallel_fetches: usize, ) -> serde_json::Result where F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, E: Event + Send, E::Id: Borrow + Send, { debug!("fetch event ({event_id}) senders power level"); let event = fetch_event(event_id.clone()).await; let auth_events = event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten(); let pl = stream::iter(auth_events) .map(|aid| fetch_event(aid.clone())) .buffer_unordered(parallel_fetches.min(5)) .filter_map(future::ready) .collect::>() .boxed() .await .into_iter() .find(|aev| is_type_and_key(aev, &TimelineEventType::RoomPowerLevels, "")); let content: PowerLevelsContentFields = match pl { None => return Ok(int!(0)), Some(ev) => from_json_str(ev.content().get())?, }; if let Some(ev) = event { if let Some(&user_level) = content.get_user_power(ev.sender()) { debug!("found {} at power_level {user_level}", ev.sender()); return Ok(user_level); } } Ok(content.users_default) } /// Check the that each event is authenticated based on the events before it. /// /// ## Returns /// /// The `unconflicted_state` combined with the newly auth'ed events. So any event that fails the /// `event_auth::auth_check` will be excluded from the returned state map. /// /// For each `events_to_check` event we gather the events needed to auth it from the the /// `fetch_event` closure and verify each event using the `event_auth::auth_check` function. async fn iterative_auth_check<'a, E, F, Fut, I>( room_version: &RoomVersion, events_to_check: I, unconflicted_state: StateMap, fetch_event: &F, parallel_fetches: usize, ) -> Result> where F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, E::Id: Borrow + Clone + Eq + Ord + Send + Sync + 'a, I: Iterator + Debug + Send + 'a, E: Event + Clone + Send + Sync, { debug!("starting iterative auth check"); trace!( list = ?events_to_check, "events to check" ); let events_to_check: Vec<_> = stream::iter(events_to_check) .map(Result::Ok) .map_ok(|event_id| { fetch_event(event_id.clone()).map(move |result| { result.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}"))) }) }) .try_buffer_unordered(parallel_fetches) .try_collect() .boxed() .await?; let auth_event_ids: HashSet = events_to_check .iter() .flat_map(|event: &E| event.auth_events().map(Clone::clone)) .collect(); let auth_events: HashMap = stream::iter(auth_event_ids.into_iter()) .map(fetch_event) .buffer_unordered(parallel_fetches) .filter_map(future::ready) .map(|auth_event| (auth_event.event_id().clone(), auth_event)) .collect() .boxed() .await; let auth_events = &auth_events; let mut resolved_state = unconflicted_state; for event in events_to_check.iter() { let event_id = event.event_id(); let state_key = event .state_key() .ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; let auth_types = auth_types_for_event( event.event_type(), event.sender(), Some(state_key), event.content(), )?; let mut auth_state = StateMap::new(); for aid in event.auth_events() { if let Some(ev) = auth_events.get(aid.borrow()) { //TODO: synapse checks "rejected_reason" which is most likely related to // soft-failing auth_state.insert( ev.event_type().with_state_key(ev.state_key().ok_or_else(|| { Error::InvalidPdu("State event had no state key".to_owned()) })?), ev.clone(), ); } else { warn!(event_id = aid.borrow().as_str(), "missing auth event"); } } stream::iter(auth_types.iter().filter_map(|key| Some((key, resolved_state.get(key)?)))) .filter_map(|(key, ev_id)| async move { if let Some(event) = auth_events.get(ev_id.borrow()) { Some((key, event.clone())) } else { Some((key, fetch_event(ev_id.clone()).await?.clone())) } }) .for_each(|(key, event)| { //TODO: synapse checks "rejected_reason" is None here auth_state.insert(key.to_owned(), event); future::ready(()) }) .await; debug!("event to check {:?}", event.event_id()); // The key for this is (eventType + a state_key of the signed token not sender) so // search for it let current_third_party = auth_state.iter().find_map(|(_, pdu)| { (*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu) }); let fetch_state = |ty: &StateEventType, key: &str| future::ready(auth_state.get(&ty.with_state_key(key))); if auth_check(room_version, &event, current_third_party.as_ref(), fetch_state).await? { // add event to resolved state map resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone()); } else { // synapse passes here on AuthError. We do not add this event to resolved_state. warn!("event {event_id} failed the authentication check"); } } Ok(resolved_state) } /// Returns the sorted `to_sort` list of `EventId`s based on a mainline sort using the depth of /// `resolved_power_level`, the server timestamp, and the eventId. /// /// The depth of the given event is calculated based on the depth of it's closest "parent" /// power_level event. If there have been two power events the after the most recent are depth 0, /// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is /// "older" than depth 0. async fn mainline_sort( to_sort: &[E::Id], resolved_power_level: Option, fetch_event: &F, parallel_fetches: usize, ) -> Result> where F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, E: Event + Clone + Send + Sync, E::Id: Borrow + Clone + Send + Sync, { debug!("mainline sort of events"); // There are no EventId's to sort, bail. if to_sort.is_empty() { return Ok(vec![]); } let mut mainline = vec![]; let mut pl = resolved_power_level; while let Some(p) = pl { mainline.push(p.clone()); let event = fetch_event(p.clone()) .await .ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?; pl = None; for aid in event.auth_events() { let ev = fetch_event(aid.clone()) .await .ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?; if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") { pl = Some(aid.to_owned()); break; } } } let mainline_map = mainline .iter() .rev() .enumerate() .map(|(idx, eid)| ((*eid).clone(), idx)) .collect::>(); let order_map = stream::iter(to_sort.iter()) .map(|ev_id| fetch_event(ev_id.clone()).map(move |event| event.map(|event| (event, ev_id)))) .buffer_unordered(parallel_fetches) .filter_map(future::ready) .map(|(event, ev_id)| { get_mainline_depth(Some(event.clone()), &mainline_map, fetch_event) .map_ok(move |depth| (depth, event, ev_id)) .map(Result::ok) }) .buffer_unordered(parallel_fetches) .filter_map(future::ready) .fold(HashMap::new(), |mut order_map, (depth, event, ev_id)| { order_map.insert(ev_id, (depth, event.origin_server_ts(), ev_id)); future::ready(order_map) }) .boxed() .await; // Sort the event_ids by their depth, timestamp and EventId // unwrap is OK order map and sort_event_ids are from to_sort (the same Vec) let mut sort_event_ids = order_map.keys().map(|&k| k.clone()).collect::>(); sort_event_ids.sort_by_key(|sort_id| order_map.get(sort_id).unwrap()); Ok(sort_event_ids) } /// Get the mainline depth from the `mainline_map` or finds a power_level event that has an /// associated mainline depth. async fn get_mainline_depth( mut event: Option, mainline_map: &HashMap, fetch_event: &F, ) -> Result where F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, E: Event + Send, E::Id: Borrow + Send, { while let Some(sort_ev) = event { debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline"); let id = sort_ev.event_id(); if let Some(depth) = mainline_map.get(id.borrow()) { return Ok(*depth); } event = None; for aid in sort_ev.auth_events() { let aev = fetch_event(aid.clone()) .await .ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?; if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") { event = Some(aev); break; } } } // Did not find a power level event so we default to zero Ok(0) } async fn add_event_and_auth_chain_to_graph( graph: &mut HashMap>, event_id: E::Id, auth_diff: &HashSet, fetch_event: &F, ) where F: Fn(E::Id) -> Fut, Fut: Future> + Send, E: Event + Send, E::Id: Borrow + Clone + Send, { let mut state = vec![event_id]; while let Some(eid) = state.pop() { graph.entry(eid.clone()).or_default(); let event = fetch_event(eid.clone()).await; let auth_events = event.as_ref().map(|ev| ev.auth_events()).into_iter().flatten(); // Prefer the store to event as the store filters dedups the events for aid in auth_events { if auth_diff.contains(aid.borrow()) { if !graph.contains_key(aid.borrow()) { state.push(aid.to_owned()); } // We just inserted this at the start of the while loop graph.get_mut(eid.borrow()).unwrap().insert(aid.to_owned()); } } } } async fn is_power_event_id(event_id: &E::Id, fetch: &F) -> bool where F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, E: Event + Send, E::Id: Borrow + Send, { match fetch(event_id.clone()).await.as_ref() { Some(state) => is_power_event(state), _ => false, } } fn is_type_and_key(ev: impl Event, ev_type: &TimelineEventType, state_key: &str) -> bool { ev.event_type() == ev_type && ev.state_key() == Some(state_key) } fn is_power_event(event: impl Event) -> bool { match event.event_type() { TimelineEventType::RoomPowerLevels | TimelineEventType::RoomJoinRules | TimelineEventType::RoomCreate => event.state_key() == Some(""), TimelineEventType::RoomMember => { if let Ok(content) = from_json_str::(event.content().get()) { if [MembershipState::Leave, MembershipState::Ban].contains(&content.membership) { return Some(event.sender().as_str()) != event.state_key(); } } false } _ => false, } } /// Convenience trait for adding event type plus state key to state maps. pub trait EventTypeExt { fn with_state_key(self, state_key: impl Into) -> (StateEventType, String); } impl EventTypeExt for StateEventType { fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { (self, state_key.into()) } } impl EventTypeExt for TimelineEventType { fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { (self.to_string().into(), state_key.into()) } } impl EventTypeExt for &T where T: EventTypeExt + Clone, { fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { self.to_owned().with_state_key(state_key) } } #[cfg(test)] mod tests { use std::{ collections::{HashMap, HashSet}, sync::Arc, }; use js_int::{int, uint}; use maplit::{hashmap, hashset}; use rand::seq::SliceRandom; use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId}; use ruma_events::{ room::join_rules::{JoinRule, RoomJoinRulesEventContent}, StateEventType, TimelineEventType, }; use serde_json::{json, value::to_raw_value as to_raw_json_value}; use tracing::debug; use crate::{ is_power_event, room_version::RoomVersion, test_utils::{ alice, bob, charlie, do_check, ella, event_id, member_content_ban, member_content_join, room_id, to_init_pdu_event, to_pdu_event, zara, PduEvent, TestStore, INITIAL_EVENTS, }, Event, EventTypeExt, StateMap, }; async fn test_event_sort() { use futures_util::future::ready; let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let events = INITIAL_EVENTS(); let event_map = events .values() .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone())) .collect::>(); let auth_chain: HashSet = HashSet::new(); let power_events = event_map .values() .filter(|&pdu| is_power_event(&**pdu)) .map(|pdu| pdu.event_id.clone()) .collect::>(); let fetcher = |id| ready(events.get(&id).cloned()); let sorted_power_events = crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1) .await .unwrap(); let resolved_power = crate::iterative_auth_check( &RoomVersion::V6, sorted_power_events.iter(), HashMap::new(), // unconflicted events &fetcher, 1, ) .await .expect("iterative auth check failed on resolved events"); // don't remove any events so we know it sorts them all correctly let mut events_to_sort = events.keys().cloned().collect::>(); events_to_sort.shuffle(&mut rand::thread_rng()); let power_level = resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned(); let sorted_event_ids = crate::mainline_sort(&events_to_sort, power_level, &fetcher, 1).await.unwrap(); assert_eq!( vec![ "$CREATE:foo", "$IMA:foo", "$IPOWER:foo", "$IJR:foo", "$IMB:foo", "$IMC:foo", "$START:foo", "$END:foo" ], sorted_event_ids.iter().map(|id| id.to_string()).collect::>() ); } #[tokio::test] async fn test_sort() { for _ in 0..20 { // since we shuffle the eventIds before we sort them introducing randomness // seems like we should test this a few times test_event_sort().await; } } #[tokio::test] async fn ban_vs_power_level() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let events = &[ to_init_pdu_event( "PA", alice(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), ), to_init_pdu_event( "MA", alice(), TimelineEventType::RoomMember, Some(alice().to_string().as_str()), member_content_join(), ), to_init_pdu_event( "MB", alice(), TimelineEventType::RoomMember, Some(bob().to_string().as_str()), member_content_ban(), ), to_init_pdu_event( "PB", bob(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), ), ]; let edges = vec![vec!["END", "MB", "MA", "PA", "START"], vec!["END", "PA", "PB"]] .into_iter() .map(|list| list.into_iter().map(event_id).collect::>()) .collect::>(); let expected_state_ids = vec!["PA", "MA", "MB"].into_iter().map(event_id).collect::>(); do_check(events, edges, expected_state_ids).await; } #[tokio::test] async fn topic_basic() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let events = &[ to_init_pdu_event( "T1", alice(), TimelineEventType::RoomTopic, Some(""), to_raw_json_value(&json!({})).unwrap(), ), to_init_pdu_event( "PA1", alice(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), ), to_init_pdu_event( "T2", alice(), TimelineEventType::RoomTopic, Some(""), to_raw_json_value(&json!({})).unwrap(), ), to_init_pdu_event( "PA2", alice(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(), ), to_init_pdu_event( "PB", bob(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), ), to_init_pdu_event( "T3", bob(), TimelineEventType::RoomTopic, Some(""), to_raw_json_value(&json!({})).unwrap(), ), ]; let edges = vec![vec!["END", "PA2", "T2", "PA1", "T1", "START"], vec!["END", "T3", "PB", "PA1"]] .into_iter() .map(|list| list.into_iter().map(event_id).collect::>()) .collect::>(); let expected_state_ids = vec!["PA2", "T2"].into_iter().map(event_id).collect::>(); do_check(events, edges, expected_state_ids).await; } #[tokio::test] async fn topic_reset() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let events = &[ to_init_pdu_event( "T1", alice(), TimelineEventType::RoomTopic, Some(""), to_raw_json_value(&json!({})).unwrap(), ), to_init_pdu_event( "PA", alice(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), ), to_init_pdu_event( "T2", bob(), TimelineEventType::RoomTopic, Some(""), to_raw_json_value(&json!({})).unwrap(), ), to_init_pdu_event( "MB", alice(), TimelineEventType::RoomMember, Some(bob().to_string().as_str()), member_content_ban(), ), ]; let edges = vec![vec!["END", "MB", "T2", "PA", "T1", "START"], vec!["END", "T1"]] .into_iter() .map(|list| list.into_iter().map(event_id).collect::>()) .collect::>(); let expected_state_ids = vec!["T1", "MB", "PA"].into_iter().map(event_id).collect::>(); do_check(events, edges, expected_state_ids).await; } #[tokio::test] async fn join_rule_evasion() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let events = &[ to_init_pdu_event( "JR", alice(), TimelineEventType::RoomJoinRules, Some(""), to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Private)).unwrap(), ), to_init_pdu_event( "ME", ella(), TimelineEventType::RoomMember, Some(ella().to_string().as_str()), member_content_join(), ), ]; let edges = vec![vec!["END", "JR", "START"], vec!["END", "ME", "START"]] .into_iter() .map(|list| list.into_iter().map(event_id).collect::>()) .collect::>(); let expected_state_ids = vec![event_id("JR")]; do_check(events, edges, expected_state_ids).await; } #[tokio::test] async fn offtopic_power_level() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let events = &[ to_init_pdu_event( "PA", alice(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), ), to_init_pdu_event( "PB", bob(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 50 } })) .unwrap(), ), to_init_pdu_event( "PC", charlie(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 0 } })) .unwrap(), ), ]; let edges = vec![vec!["END", "PC", "PB", "PA", "START"], vec!["END", "PA"]] .into_iter() .map(|list| list.into_iter().map(event_id).collect::>()) .collect::>(); let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::>(); do_check(events, edges, expected_state_ids).await; } #[tokio::test] async fn topic_setting() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let events = &[ to_init_pdu_event( "T1", alice(), TimelineEventType::RoomTopic, Some(""), to_raw_json_value(&json!({})).unwrap(), ), to_init_pdu_event( "PA1", alice(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), ), to_init_pdu_event( "T2", alice(), TimelineEventType::RoomTopic, Some(""), to_raw_json_value(&json!({})).unwrap(), ), to_init_pdu_event( "PA2", alice(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(), ), to_init_pdu_event( "PB", bob(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), ), to_init_pdu_event( "T3", bob(), TimelineEventType::RoomTopic, Some(""), to_raw_json_value(&json!({})).unwrap(), ), to_init_pdu_event( "MZ1", zara(), TimelineEventType::RoomTopic, Some(""), to_raw_json_value(&json!({})).unwrap(), ), to_init_pdu_event( "T4", alice(), TimelineEventType::RoomTopic, Some(""), to_raw_json_value(&json!({})).unwrap(), ), ]; let edges = vec![ vec!["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"], vec!["END", "MZ1", "T3", "PB", "PA1"], ] .into_iter() .map(|list| list.into_iter().map(event_id).collect::>()) .collect::>(); let expected_state_ids = vec!["T4", "PA2"].into_iter().map(event_id).collect::>(); do_check(events, edges, expected_state_ids).await; } #[tokio::test] async fn test_event_map_none() { use futures_util::future::ready; let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let mut store = TestStore::(hashmap! {}); // build up the DAG let (state_at_bob, state_at_charlie, expected) = store.set_up(); let ev_map = store.0.clone(); let fetcher = |id| ready(ev_map.get(&id).cloned()); let exists = |id: ::Id| ready(ev_map.get(&*id).is_some()); let state_sets = [state_at_bob, state_at_charlie]; let auth_chain: Vec<_> = state_sets .iter() .map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()) .collect(); let resolved = match crate::resolve( &RoomVersionId::V2, &state_sets, &auth_chain, &fetcher, &exists, 1, ) .await { Ok(state) => state, Err(e) => panic!("{e}"), }; assert_eq!(expected, resolved); } #[tokio::test] async fn test_lexicographical_sort() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let graph = hashmap! { event_id("l") => hashset![event_id("o")], event_id("m") => hashset![event_id("n"), event_id("o")], event_id("n") => hashset![event_id("o")], event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges event_id("p") => hashset![event_id("o")], }; let res = crate::lexicographical_topological_sort(&graph, &|_id| async { Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) }) .await .unwrap(); assert_eq!( vec!["o", "l", "n", "m", "p"], res.iter() .map(ToString::to_string) .map(|s| s.replace('$', "").replace(":foo", "")) .collect::>() ); } #[tokio::test] async fn ban_with_auth_chains() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let ban = BAN_STATE_SET(); let edges = vec![vec!["END", "MB", "PA", "START"], vec!["END", "IME", "MB"]] .into_iter() .map(|list| list.into_iter().map(event_id).collect::>()) .collect::>(); let expected_state_ids = vec!["PA", "MB"].into_iter().map(event_id).collect::>(); do_check(&ban.values().cloned().collect::>(), edges, expected_state_ids).await; } #[tokio::test] async fn ban_with_auth_chains2() { use futures_util::future::ready; let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let init = INITIAL_EVENTS(); let ban = BAN_STATE_SET(); let mut inner = init.clone(); inner.extend(ban); let store = TestStore(inner.clone()); let state_set_a = [ inner.get(&event_id("CREATE")).unwrap(), inner.get(&event_id("IJR")).unwrap(), inner.get(&event_id("IMA")).unwrap(), inner.get(&event_id("IMB")).unwrap(), inner.get(&event_id("IMC")).unwrap(), inner.get(&event_id("MB")).unwrap(), inner.get(&event_id("PA")).unwrap(), ] .iter() .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone())) .collect::>(); let state_set_b = [ inner.get(&event_id("CREATE")).unwrap(), inner.get(&event_id("IJR")).unwrap(), inner.get(&event_id("IMA")).unwrap(), inner.get(&event_id("IMB")).unwrap(), inner.get(&event_id("IMC")).unwrap(), inner.get(&event_id("IME")).unwrap(), inner.get(&event_id("PA")).unwrap(), ] .iter() .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone())) .collect::>(); let ev_map = &store.0; let state_sets = [state_set_a, state_set_b]; let auth_chain: Vec<_> = state_sets .iter() .map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()) .collect(); let fetcher = |id: ::Id| ready(ev_map.get(&id).cloned()); let exists = |id: ::Id| ready(ev_map.get(&id).is_some()); let resolved = match crate::resolve( &RoomVersionId::V6, &state_sets, &auth_chain, &fetcher, &exists, 1, ) .await { Ok(state) => state, Err(e) => panic!("{e}"), }; debug!( resolved = ?resolved .iter() .map(|((ty, key), id)| format!("(({ty}{key:?}), {id})")) .collect::>(), "resolved state", ); let expected = ["$CREATE:foo", "$IJR:foo", "$PA:foo", "$IMA:foo", "$IMB:foo", "$IMC:foo", "$MB:foo"]; for id in expected.iter().map(|i| event_id(i)) { // make sure our resolved events are equal to the expected list assert!(resolved.values().any(|eid| eid == &id) || init.contains_key(&id), "{id}"); } assert_eq!(expected.len(), resolved.len()); } #[tokio::test] async fn join_rule_with_auth_chain() { let join_rule = JOIN_RULE(); let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]] .into_iter() .map(|list| list.into_iter().map(event_id).collect::>()) .collect::>(); let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::>(); do_check(&join_rule.values().cloned().collect::>(), edges, expected_state_ids).await; } #[allow(non_snake_case)] fn BAN_STATE_SET() -> HashMap> { vec![ to_pdu_event( "PA", alice(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), &["CREATE", "IMA", "IPOWER"], // auth_events &["START"], // prev_events ), to_pdu_event( "PB", alice(), TimelineEventType::RoomPowerLevels, Some(""), to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), &["CREATE", "IMA", "IPOWER"], &["END"], ), to_pdu_event( "MB", alice(), TimelineEventType::RoomMember, Some(ella().as_str()), member_content_ban(), &["CREATE", "IMA", "PB"], &["PA"], ), to_pdu_event( "IME", ella(), TimelineEventType::RoomMember, Some(ella().as_str()), member_content_join(), &["CREATE", "IJR", "PA"], &["MB"], ), ] .into_iter() .map(|ev| (ev.event_id.clone(), ev)) .collect() } #[allow(non_snake_case)] fn JOIN_RULE() -> HashMap> { vec![ to_pdu_event( "JR", alice(), TimelineEventType::RoomJoinRules, Some(""), to_raw_json_value(&json!({ "join_rule": "invite" })).unwrap(), &["CREATE", "IMA", "IPOWER"], &["START"], ), to_pdu_event( "IMZ", zara(), TimelineEventType::RoomPowerLevels, Some(zara().as_str()), member_content_join(), &["CREATE", "JR", "IPOWER"], &["START"], ), ] .into_iter() .map(|ev| (ev.event_id.clone(), ev)) .collect() } macro_rules! state_set { ($($kind:expr => $key:expr => $id:expr),* $(,)?) => {{ #[allow(unused_mut)] let mut x = StateMap::new(); $( x.insert(($kind, $key.to_owned()), $id); )* x }}; } #[test] fn separate_unique_conflicted() { let (unconflicted, conflicted) = super::separate( [ state_set![StateEventType::RoomMember => "@a:hs1" => 0], state_set![StateEventType::RoomMember => "@b:hs1" => 1], state_set![StateEventType::RoomMember => "@c:hs1" => 2], ] .iter(), ); assert_eq!(unconflicted, StateMap::new()); assert_eq!( conflicted, state_set![ StateEventType::RoomMember => "@a:hs1" => vec![0], StateEventType::RoomMember => "@b:hs1" => vec![1], StateEventType::RoomMember => "@c:hs1" => vec![2], ], ); } #[test] fn separate_conflicted() { let (unconflicted, mut conflicted) = super::separate( [ state_set![StateEventType::RoomMember => "@a:hs1" => 0], state_set![StateEventType::RoomMember => "@a:hs1" => 1], state_set![StateEventType::RoomMember => "@a:hs1" => 2], ] .iter(), ); // HashMap iteration order is random, so sort this before asserting on it for v in conflicted.values_mut() { v.sort_unstable(); } assert_eq!(unconflicted, StateMap::new()); assert_eq!( conflicted, state_set![ StateEventType::RoomMember => "@a:hs1" => vec![0, 1, 2], ], ); } #[test] fn separate_unconflicted() { let (unconflicted, conflicted) = super::separate( [ state_set![StateEventType::RoomMember => "@a:hs1" => 0], state_set![StateEventType::RoomMember => "@a:hs1" => 0], state_set![StateEventType::RoomMember => "@a:hs1" => 0], ] .iter(), ); assert_eq!( unconflicted, state_set![ StateEventType::RoomMember => "@a:hs1" => 0, ], ); assert_eq!(conflicted, StateMap::new()); } #[test] fn separate_mixed() { let (unconflicted, conflicted) = super::separate( [ state_set![StateEventType::RoomMember => "@a:hs1" => 0], state_set![ StateEventType::RoomMember => "@a:hs1" => 0, StateEventType::RoomMember => "@b:hs1" => 1, ], state_set![ StateEventType::RoomMember => "@a:hs1" => 0, StateEventType::RoomMember => "@c:hs1" => 2, ], ] .iter(), ); assert_eq!( unconflicted, state_set![ StateEventType::RoomMember => "@a:hs1" => 0, ], ); assert_eq!( conflicted, state_set![ StateEventType::RoomMember => "@b:hs1" => vec![1], StateEventType::RoomMember => "@c:hs1" => vec![2], ], ); } }