diff --git a/crates/ruma-state-res/benches/state_res_bench.rs b/crates/ruma-state-res/benches/state_res_bench.rs index 508f8be6..7c8cbcf1 100644 --- a/crates/ruma-state-res/benches/state_res_bench.rs +++ b/crates/ruma-state-res/benches/state_res_bench.rs @@ -30,7 +30,7 @@ use ruma_events::{ EventType, }; use ruma_identifiers::{EventId, RoomId, RoomVersionId, UserId}; -use ruma_state_res::{Error, Event, EventMap, Result, StateMap, StateResolution}; +use ruma_state_res::{self as state_res, Error, Event, EventMap, Result, StateMap}; use serde_json::{json, Value as JsonValue}; static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0); @@ -45,7 +45,7 @@ fn lexico_topo_sort(c: &mut Criterion) { event_id("p") => hashset![event_id("o")], }; b.iter(|| { - let _ = StateResolution::lexicographical_topological_sort(&graph, |id| { + let _ = state_res::lexicographical_topological_sort(&graph, |id| { Ok((0, MilliSecondsSinceUnixEpoch(uint!(0)), id.clone())) }); }) @@ -62,7 +62,7 @@ fn resolution_shallow_auth_chain(c: &mut Criterion) { b.iter(|| { let ev_map: EventMap> = store.0.clone(); let state_sets = vec![state_at_bob.clone(), state_at_charlie.clone()]; - let _ = match StateResolution::resolve::( + let _ = match state_res::resolve::( &room_id(), &RoomVersionId::Version6, &state_sets, @@ -119,7 +119,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) { b.iter(|| { let state_sets = vec![state_set_a.clone(), state_set_b.clone()]; - let _ = match StateResolution::resolve::( + let _ = match state_res::resolve::( &room_id(), &RoomVersionId::Version6, &state_sets, diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 27ee673c..3aff2ae5 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -32,602 +32,580 @@ pub type StateMap = HashMap<(EventType, String), T>; /// A mapping of `EventId` to `T`, usually a `ServerPdu`. 
pub type EventMap = HashMap; -#[derive(Default)] -#[allow(clippy::exhaustive_structs)] -pub struct StateResolution; +/// Resolve sets of state events as they come in. Internally `StateResolution` builds a graph and an +/// auth chain to allow for state conflict resolution. +/// +/// ## Arguments +/// +/// * `state_sets` - The incoming state to resolve. Each `StateMap` represents a possible fork in +/// the state of a room. +/// +/// * `auth_chain_sets` - The full recursive set of `auth_events` for each event in the +/// `state_sets`. +/// +/// * `fetch_event` - Any event not found in the `event_map` will defer to this closure to find the +/// event. +/// +/// ## Invariants +/// +/// The caller of `resolve` must ensure that all the events are from the same room. Although this +/// function takes a `RoomId` it does not check that each event is part of the same room. +pub fn resolve( + room_id: &RoomId, + room_version: &RoomVersionId, + state_sets: &[StateMap], + auth_chain_sets: Vec>, + fetch_event: F, +) -> Result> +where + E: Event, + F: Fn(&EventId) -> Option>, +{ + info!("State resolution starting"); -impl StateResolution { - /// Resolve sets of state events as they come in. Internally `StateResolution` builds a graph - /// and an auth chain to allow for state conflict resolution. - /// - /// ## Arguments - /// - /// * `state_sets` - The incoming state to resolve. Each `StateMap` represents a possible fork - /// in the state of a room. - /// - /// * `auth_chain_sets` - The full recursive set of `auth_events` for each event in the - /// `state_sets`. - /// - /// * `fetch_event` - Any event not found in the `event_map` will defer to this closure to find - /// the event. - /// - /// ## Invariants - /// - /// The caller of `StateResolution::resolve` must ensure that all the events are from the same - /// room. Although this function takes a `RoomId` it does not check that each event is part - /// of the same room. 
- pub fn resolve( - room_id: &RoomId, - room_version: &RoomVersionId, - state_sets: &[StateMap], - auth_chain_sets: Vec>, - fetch_event: F, - ) -> Result> - where - E: Event, - F: Fn(&EventId) -> Option>, - { - info!("State resolution starting"); + // Split non-conflicting and conflicting state + let (clean, conflicting) = separate(state_sets); - // Split non-conflicting and conflicting state - let (clean, conflicting) = StateResolution::separate(state_sets); + info!("non conflicting events: {}", clean.len()); + trace!("{:?}", clean); - info!("non conflicting events: {}", clean.len()); - trace!("{:?}", clean); - - if conflicting.is_empty() { - info!("no conflicting state found"); - return Ok(clean); - } - - info!("conflicting events: {}", conflicting.len()); - debug!("{:?}", conflicting); - - // The set of auth events that are not common across server forks - let mut auth_diff = StateResolution::get_auth_chain_diff(room_id, auth_chain_sets)?; - - // Add the auth_diff to conflicting now we have a full set of conflicting events - auth_diff.extend(conflicting.values().cloned().flatten().flatten()); - - debug!("auth diff: {}", auth_diff.len()); - trace!("{:?}", auth_diff); - - // `all_conflicted` contains unique items - // synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}` - // - // Don't honor events we cannot "verify" - // TODO: BTreeSet::retain() when stable 1.53 - let all_conflicted = - auth_diff.into_iter().filter(|id| fetch_event(id).is_some()).collect::>(); - - info!("full conflicted set: {}", all_conflicted.len()); - debug!("{:?}", all_conflicted); - - // We used to check that all events are events from the correct room - // this is now a check the caller of `resolve` must make. 
- - // Get only the control events with a state_key: "" or ban/kick event (sender != state_key) - let control_events = all_conflicted - .iter() - .filter(|id| is_power_event_id(id, &fetch_event)) - .cloned() - .collect::>(); - - // Sort the control events based on power_level/clock/event_id and outgoing/incoming edges - let sorted_control_levels = StateResolution::reverse_topological_power_sort( - control_events, - &all_conflicted, - &fetch_event, - )?; - - debug!("sorted control events: {}", sorted_control_levels.len()); - trace!("{:?}", sorted_control_levels); - - let room_version = RoomVersion::new(room_version)?; - // Sequentially auth check each control event. - let resolved_control = StateResolution::iterative_auth_check( - &room_version, - &sorted_control_levels, - &clean, - &fetch_event, - )?; - - debug!("resolved control events: {}", resolved_control.len()); - trace!("{:?}", resolved_control); - - // At this point the control_events have been resolved we now have to - // sort the remaining events using the mainline of the resolved power level. 
- let deduped_power_ev = sorted_control_levels.into_iter().collect::>(); - - // This removes the control events that passed auth and more importantly those that failed - // auth - let events_to_resolve = all_conflicted - .iter() - .filter(|id| !deduped_power_ev.contains(id)) - .cloned() - .collect::>(); - - debug!("events left to resolve: {}", events_to_resolve.len()); - trace!("{:?}", events_to_resolve); - - // This "epochs" power level event - let power_event = resolved_control.get(&(EventType::RoomPowerLevels, "".into())); - - debug!("power event: {:?}", power_event); - - let sorted_left_events = - StateResolution::mainline_sort(&events_to_resolve, power_event, &fetch_event)?; - - trace!("events left, sorted: {:?}", sorted_left_events.iter().collect::>()); - - let mut resolved_state = StateResolution::iterative_auth_check( - &room_version, - &sorted_left_events, - &resolved_control, // The control events are added to the final resolved state - &fetch_event, - )?; - - // Add unconflicted state to the resolved state - // We priorities the unconflicting state - resolved_state.extend(clean); - Ok(resolved_state) + if conflicting.is_empty() { + info!("no conflicting state found"); + return Ok(clean); } - /// Split the events that have no conflicts from those that are conflicting. - /// The return tuple looks like `(unconflicted, conflicted)`. - /// - /// State is determined to be conflicting if for the given key (EventType, StateKey) there - /// is not exactly one eventId. This includes missing events, if one state_set includes an event - /// that none of the other have this is a conflicting event. 
- pub fn separate( - state_sets: &[StateMap], - ) -> (StateMap, StateMap>>) { - info!("separating {} sets of events into conflicted/unconflicted", state_sets.len()); + info!("conflicting events: {}", conflicting.len()); + debug!("{:?}", conflicting); - let mut unconflicted_state = StateMap::new(); - let mut conflicted_state = StateMap::new(); + // The set of auth events that are not common across server forks + let mut auth_diff = get_auth_chain_diff(room_id, auth_chain_sets)?; - for key in state_sets.iter().flat_map(|map| map.keys()).unique() { - let mut event_ids = - state_sets.iter().map(|state_set| state_set.get(key)).collect::>(); + // Add the auth_diff to conflicting now we have a full set of conflicting events + auth_diff.extend(conflicting.values().cloned().flatten().flatten()); - if event_ids.iter().all_equal() { - // First .unwrap() is okay because - // * event_ids has the same length as state_sets - // * we never enter the loop this code is in if state_sets is empty - let id = event_ids.pop().unwrap().expect("unconflicting `EventId` is not None"); - unconflicted_state.insert(key.clone(), id.clone()); - } else { - conflicted_state - .insert(key.clone(), event_ids.into_iter().map(|o| o.cloned()).collect()); - } - } + debug!("auth diff: {}", auth_diff.len()); + trace!("{:?}", auth_diff); - (unconflicted_state, conflicted_state) - } + // `all_conflicted` contains unique items + // synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}` + // + // Don't honor events we cannot "verify" + // TODO: BTreeSet::retain() when stable 1.53 + let all_conflicted = + auth_diff.into_iter().filter(|id| fetch_event(id).is_some()).collect::>(); - /// Returns a Vec of deduped EventIds that appear in some chains but not others. 
- pub fn get_auth_chain_diff( - _room_id: &RoomId, - auth_chain_sets: Vec>, - ) -> Result> { - if let Some(first) = auth_chain_sets.first().cloned() { - let common = auth_chain_sets - .iter() - .skip(1) - .fold(first, |a, b| a.intersection(b).cloned().collect::>()); + info!("full conflicted set: {}", all_conflicted.len()); + debug!("{:?}", all_conflicted); - Ok(auth_chain_sets.into_iter().flatten().filter(|id| !common.contains(id)).collect()) + // We used to check that all events are events from the correct room + // this is now a check the caller of `resolve` must make. + + // Get only the control events with a state_key: "" or ban/kick event (sender != state_key) + let control_events = all_conflicted + .iter() + .filter(|id| is_power_event_id(id, &fetch_event)) + .cloned() + .collect::>(); + + // Sort the control events based on power_level/clock/event_id and outgoing/incoming edges + let sorted_control_levels = + reverse_topological_power_sort(control_events, &all_conflicted, &fetch_event)?; + + debug!("sorted control events: {}", sorted_control_levels.len()); + trace!("{:?}", sorted_control_levels); + + let room_version = RoomVersion::new(room_version)?; + // Sequentially auth check each control event. + let resolved_control = + iterative_auth_check(&room_version, &sorted_control_levels, &clean, &fetch_event)?; + + debug!("resolved control events: {}", resolved_control.len()); + trace!("{:?}", resolved_control); + + // At this point the control_events have been resolved we now have to + // sort the remaining events using the mainline of the resolved power level. 
+ let deduped_power_ev = sorted_control_levels.into_iter().collect::>(); + + // This removes the control events that passed auth and more importantly those that failed + // auth + let events_to_resolve = all_conflicted + .iter() + .filter(|id| !deduped_power_ev.contains(id)) + .cloned() + .collect::>(); + + debug!("events left to resolve: {}", events_to_resolve.len()); + trace!("{:?}", events_to_resolve); + + // This "epochs" power level event + let power_event = resolved_control.get(&(EventType::RoomPowerLevels, "".into())); + + debug!("power event: {:?}", power_event); + + let sorted_left_events = mainline_sort(&events_to_resolve, power_event, &fetch_event)?; + + trace!("events left, sorted: {:?}", sorted_left_events.iter().collect::>()); + + let mut resolved_state = iterative_auth_check( + &room_version, + &sorted_left_events, + &resolved_control, // The control events are added to the final resolved state + &fetch_event, + )?; + + // Add unconflicted state to the resolved state + // We priorities the unconflicting state + resolved_state.extend(clean); + Ok(resolved_state) +} + +/// Split the events that have no conflicts from those that are conflicting. +/// +/// The return tuple looks like `(unconflicted, conflicted)`. +/// +/// State is determined to be conflicting if for the given key (EventType, StateKey) there is not +/// exactly one eventId. This includes missing events, if one state_set includes an event that none +/// of the other have this is a conflicting event. 
+pub fn separate( + state_sets: &[StateMap], +) -> (StateMap, StateMap>>) { + info!("separating {} sets of events into conflicted/unconflicted", state_sets.len()); + + let mut unconflicted_state = StateMap::new(); + let mut conflicted_state = StateMap::new(); + + for key in state_sets.iter().flat_map(|map| map.keys()).unique() { + let mut event_ids = + state_sets.iter().map(|state_set| state_set.get(key)).collect::>(); + + if event_ids.iter().all_equal() { + // First .unwrap() is okay because + // * event_ids has the same length as state_sets + // * we never enter the loop this code is in if state_sets is empty + let id = event_ids.pop().unwrap().expect("unconflicting `EventId` is not None"); + unconflicted_state.insert(key.clone(), id.clone()); } else { - Ok(HashSet::new()) + conflicted_state + .insert(key.clone(), event_ids.into_iter().map(|o| o.cloned()).collect()); } } - /// Events are sorted from "earliest" to "latest". They are compared using - /// the negative power level (reverse topological ordering), the - /// origin server timestamp and incase of a tie the `EventId`s - /// are compared lexicographically. - /// - /// The power level is negative because a higher power level is equated to an - /// earlier (further back in time) origin server timestamp. 
- pub fn reverse_topological_power_sort( - events_to_sort: Vec, - auth_diff: &HashSet, - fetch_event: F, - ) -> Result> - where - E: Event, - F: Fn(&EventId) -> Option>, - { - debug!("reverse topological sort of power events"); + (unconflicted_state, conflicted_state) +} - let mut graph = HashMap::new(); - for event_id in events_to_sort { - StateResolution::add_event_and_auth_chain_to_graph( - &mut graph, - event_id, - auth_diff, - &fetch_event, - ); - - // TODO: if these functions are ever made async here - // is a good place to yield every once in a while so other - // tasks can make progress - } - - // This is used in the `key_fn` passed to the lexico_topo_sort fn - let mut event_to_pl = HashMap::new(); - for event_id in graph.keys() { - let pl = StateResolution::get_power_level_for_sender(event_id, &fetch_event); - info!("{} power level {}", event_id, pl); - - event_to_pl.insert(event_id.clone(), pl); - - // TODO: if these functions are ever made async here - // is a good place to yield every once in a while so other - // tasks can make progress - } - - StateResolution::lexicographical_topological_sort(&graph, |event_id| { - let ev = fetch_event(event_id).ok_or_else(|| Error::NotFound("".into()))?; - let pl = event_to_pl.get(event_id).ok_or_else(|| Error::NotFound("".into()))?; - - debug!("{:?}", (-*pl, ev.origin_server_ts(), &ev.event_id())); - - // This return value is the key used for sorting events, - // events are then sorted by power level, time, - // and lexically by event_id. - Ok((-*pl, ev.origin_server_ts(), ev.event_id().clone())) - }) - } - - /// Sorts the event graph based on number of outgoing/incoming edges, where - /// `key_fn` is used as a tie breaker. The tie breaker happens based on - /// power level, age, and event_id. 
- pub fn lexicographical_topological_sort( - graph: &HashMap>, - key_fn: F, - ) -> Result> - where - F: Fn(&EventId) -> Result<(i64, MilliSecondsSinceUnixEpoch, EventId)>, - { - info!("starting lexicographical topological sort"); - // NOTE: an event that has no incoming edges happened most recently, - // and an event that has no outgoing edges happened least recently. - - // NOTE: this is basically Kahn's algorithm except we look at nodes with no - // outgoing edges, c.f. - // https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm - - // outdegree_map is an event referring to the events before it, the - // more outdegree's the more recent the event. - let mut outdegree_map = graph.clone(); - - // The number of events that depend on the given event (the EventId key) - // How many events reference this event in the DAG as a parent - let mut reverse_graph: HashMap<&EventId, HashSet<&EventId>> = HashMap::new(); - - // Vec of nodes that have zero out degree, least recent events. - let mut zero_outdegree = vec![]; - - for (node, edges) in graph.iter() { - if edges.is_empty() { - // The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need - // smallest -> largest - zero_outdegree.push(Reverse((key_fn(node)?, node))); - } - - reverse_graph.entry(node).or_default(); - for edge in edges { - reverse_graph.entry(edge).or_default().insert(node); - } - } - - let mut heap = BinaryHeap::from(zero_outdegree); - - // We remove the oldest node (most incoming edges) and check against all other - let mut sorted = vec![]; - // Destructure the `Reverse` and take the smallest `node` each time - while let Some(Reverse((_, node))) = heap.pop() { - let node: &EventId = node; - for parent in reverse_graph.get(node).expect("EventId in heap is also in reverse_graph") - { - // The number of outgoing edges this node has - let out = outdegree_map - .get_mut(parent) - .expect("outdegree_map knows of all referenced EventIds"); - - // Only push on the heap once older 
events have been cleared - out.remove(node); - if out.is_empty() { - heap.push(Reverse((key_fn(parent)?, parent))); - } - } - - // synapse yields we push then return the vec - sorted.push(node.clone()); - } - - Ok(sorted) - } - - /// Find the power level for the sender of `event_id` or return a default value of zero. - fn get_power_level_for_sender(event_id: &EventId, fetch_event: F) -> i64 - where - E: Event, - F: Fn(&EventId) -> Option>, - { - info!("fetch event ({}) senders power level", event_id); - - let event = fetch_event(event_id); - let mut pl = None; - - for aid in event.as_ref().map(|pdu| pdu.auth_events()).unwrap_or_default() { - if let Some(aev) = fetch_event(&aid) { - if is_type_and_key(&aev, EventType::RoomPowerLevels, "") { - pl = Some(aev); - break; - } - } - } - - if pl.is_none() { - return 0; - } - - if let Some(content) = - pl.and_then(|pl| serde_json::from_value::(pl.content()).ok()) - { - if let Some(ev) = event { - if let Some(user) = content.users.get(ev.sender()) { - debug!("found {} at power_level {}", ev.sender(), user); - return (*user).into(); - } - } - content.users_default.into() - } else { - 0 - } - } - - /// Check the that each event is authenticated based on the events before it. - /// - /// ## Returns - /// - /// The `unconflicted_state` combined with the newly auth'ed events. So any event that - /// fails the `event_auth::auth_check` will be excluded from the returned `StateMap`. - /// - /// For each `events_to_check` event we gather the events needed to auth it from the - /// the `fetch_event` closure and verify each event using the `event_auth::auth_check` - /// function. 
- pub fn iterative_auth_check( - room_version: &RoomVersion, - events_to_check: &[EventId], - unconflicted_state: &StateMap, - fetch_event: F, - ) -> Result> - where - E: Event, - F: Fn(&EventId) -> Option>, - { - info!("starting iterative auth check"); - - debug!("performing auth checks on {:?}", events_to_check.iter().collect::>()); - - let mut resolved_state = unconflicted_state.clone(); - - for event_id in events_to_check.iter() { - let event = fetch_event(event_id) - .ok_or_else(|| Error::NotFound(format!("Failed to find {}", event_id)))?; - let state_key = event - .state_key() - .ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; - - let mut auth_events = HashMap::new(); - for aid in &event.auth_events() { - if let Some(ev) = fetch_event(aid) { - // TODO synapse check "rejected_reason" which is most likely - // related to soft-failing - auth_events.insert( - ( - ev.kind(), - ev.state_key().ok_or_else(|| { - Error::InvalidPdu("State event had no state key".to_owned()) - })?, - ), - ev, - ); - } else { - warn!("auth event id for {} is missing {}", aid, event_id); - } - } - - for key in auth_types_for_event( - &event.kind(), - event.sender(), - Some(state_key.clone()), - event.content(), - ) { - if let Some(ev_id) = resolved_state.get(&key) { - if let Some(event) = fetch_event(ev_id) { - // TODO synapse checks `rejected_reason` is None here - auth_events.insert(key.clone(), event); - } - } - } - - debug!("event to check {:?}", event.event_id()); - - let most_recent_prev_event = - event.prev_events().iter().filter_map(|id| fetch_event(id)).next_back(); - - // The key for this is (eventType + a state_key of the signed token not sender) so - // search for it - let current_third_party = auth_events.iter().find_map(|(_, pdu)| { - (pdu.kind() == EventType::RoomThirdPartyInvite).then(|| { - // TODO no clone, auth_events is borrowed while moved - pdu.clone() - }) - }); - - if auth_check( - room_version, - &event, - most_recent_prev_event, - 
current_third_party, - |ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(), - )? { - // add event to resolved state map - resolved_state.insert((event.kind(), state_key), event_id.clone()); - } else { - // synapse passes here on AuthError. We do not add this event to resolved_state. - warn!("event {} failed the authentication check", event_id); - } - - // TODO: if these functions are ever made async here - // is a good place to yield every once in a while so other - // tasks can make progress - } - Ok(resolved_state) - } - - /// Returns the sorted `to_sort` list of `EventId`s based on a mainline sort using - /// the depth of `resolved_power_level`, the server timestamp, and the eventId. - /// - /// The depth of the given event is calculated based on the depth of it's closest "parent" - /// power_level event. If there have been two power events the after the most recent are - /// depth 0, the events before (with the first power level as a parent) will be marked - /// as depth 1. depth 1 is "older" than depth 0. - pub fn mainline_sort( - to_sort: &[EventId], - resolved_power_level: Option<&EventId>, - fetch_event: F, - ) -> Result> - where - E: Event, - F: Fn(&EventId) -> Option>, - { - debug!("mainline sort of events"); - - // There are no EventId's to sort, bail. 
- if to_sort.is_empty() { - return Ok(vec![]); - } - - let mut mainline = vec![]; - let mut pl = resolved_power_level.cloned(); - while let Some(p) = pl { - mainline.push(p.clone()); - - let event = - fetch_event(&p).ok_or_else(|| Error::NotFound(format!("Failed to find {}", p)))?; - let auth_events = &event.auth_events(); - pl = None; - for aid in auth_events { - let ev = fetch_event(aid) - .ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?; - if is_type_and_key(&ev, EventType::RoomPowerLevels, "") { - pl = Some(aid.clone()); - break; - } - } - // TODO: if these functions are ever made async here - // is a good place to yield every once in a while so other - // tasks can make progress - } - - let mainline_map = mainline +/// Returns a Vec of deduped EventIds that appear in some chains but not others. +pub fn get_auth_chain_diff( + _room_id: &RoomId, + auth_chain_sets: Vec>, +) -> Result> { + if let Some(first) = auth_chain_sets.first().cloned() { + let common = auth_chain_sets .iter() - .rev() - .enumerate() - .map(|(idx, eid)| ((*eid).clone(), idx)) - .collect::>(); + .skip(1) + .fold(first, |a, b| a.intersection(b).cloned().collect::>()); - let mut order_map = HashMap::new(); - for ev_id in to_sort.iter() { - if let Some(event) = fetch_event(ev_id) { - if let Ok(depth) = - StateResolution::get_mainline_depth(Some(event), &mainline_map, &fetch_event) - { - order_map.insert( - ev_id, - ( - depth, - fetch_event(ev_id).map(|ev| ev.origin_server_ts()), - ev_id, // TODO should this be a &str to sort lexically?? - ), - ); - } - } + Ok(auth_chain_sets.into_iter().flatten().filter(|id| !common.contains(id)).collect()) + } else { + Ok(HashSet::new()) + } +} - // TODO: if these functions are ever made async here - // is a good place to yield every once in a while so other - // tasks can make progress - } +/// Events are sorted from "earliest" to "latest". 
+/// +/// They are compared using the negative power level (reverse topological ordering), the origin +/// server timestamp and in case of a tie the `EventId`s are compared lexicographically. +/// +/// The power level is negative because a higher power level is equated to an earlier (further back +/// in time) origin server timestamp. +pub fn reverse_topological_power_sort( + events_to_sort: Vec, + auth_diff: &HashSet, + fetch_event: F, +) -> Result> +where + E: Event, + F: Fn(&EventId) -> Option>, +{ + debug!("reverse topological sort of power events"); - // Sort the event_ids by their depth, timestamp and EventId - // unwrap is OK order map and sort_event_ids are from to_sort (the same Vec) - let mut sort_event_ids = order_map.keys().map(|&k| k.clone()).collect::>(); - sort_event_ids.sort_by_key(|sort_id| order_map.get(sort_id).unwrap()); + let mut graph = HashMap::new(); + for event_id in events_to_sort { + add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, &fetch_event); - Ok(sort_event_ids) + // TODO: if these functions are ever made async here + // is a good place to yield every once in a while so other + // tasks can make progress } - /// Get the mainline depth from the `mainline_map` or finds a power_level event - /// that has an associated mainline depth. 
- fn get_mainline_depth( - mut event: Option>, - mainline_map: &EventMap, - fetch_event: F, - ) -> Result - where - E: Event, - F: Fn(&EventId) -> Option>, - { - while let Some(sort_ev) = event { - debug!("mainline event_id {}", sort_ev.event_id()); - let id = &sort_ev.event_id(); - if let Some(depth) = mainline_map.get(id) { - return Ok(*depth); - } + // This is used in the `key_fn` passed to the lexico_topo_sort fn + let mut event_to_pl = HashMap::new(); + for event_id in graph.keys() { + let pl = get_power_level_for_sender(event_id, &fetch_event); + info!("{} power level {}", event_id, pl); - let auth_events = &sort_ev.auth_events(); - event = None; - for aid in auth_events { - let aev = fetch_event(aid) - .ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?; - if is_type_and_key(&aev, EventType::RoomPowerLevels, "") { - event = Some(aev); - break; + event_to_pl.insert(event_id.clone(), pl); + + // TODO: if these functions are ever made async here + // is a good place to yield every once in a while so other + // tasks can make progress + } + + lexicographical_topological_sort(&graph, |event_id| { + let ev = fetch_event(event_id).ok_or_else(|| Error::NotFound("".into()))?; + let pl = event_to_pl.get(event_id).ok_or_else(|| Error::NotFound("".into()))?; + + debug!("{:?}", (-*pl, ev.origin_server_ts(), &ev.event_id())); + + // This return value is the key used for sorting events, + // events are then sorted by power level, time, + // and lexically by event_id. + Ok((-*pl, ev.origin_server_ts(), ev.event_id().clone())) + }) +} + +/// Sorts the event graph based on number of outgoing/incoming edges. +/// +/// `key_fn` is used as a tie breaker. The tie breaker happens based on power level, age, and +/// event_id. 
+pub fn lexicographical_topological_sort( + graph: &HashMap>, + key_fn: F, +) -> Result> +where + F: Fn(&EventId) -> Result<(i64, MilliSecondsSinceUnixEpoch, EventId)>, +{ + info!("starting lexicographical topological sort"); + // NOTE: an event that has no incoming edges happened most recently, + // and an event that has no outgoing edges happened least recently. + + // NOTE: this is basically Kahn's algorithm except we look at nodes with no + // outgoing edges, c.f. + // https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + + // outdegree_map is an event referring to the events before it, the + // more outdegree's the more recent the event. + let mut outdegree_map = graph.clone(); + + // The number of events that depend on the given event (the EventId key) + // How many events reference this event in the DAG as a parent + let mut reverse_graph: HashMap<&EventId, HashSet<&EventId>> = HashMap::new(); + + // Vec of nodes that have zero out degree, least recent events. + let mut zero_outdegree = vec![]; + + for (node, edges) in graph.iter() { + if edges.is_empty() { + // The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need + // smallest -> largest + zero_outdegree.push(Reverse((key_fn(node)?, node))); + } + + reverse_graph.entry(node).or_default(); + for edge in edges { + reverse_graph.entry(edge).or_default().insert(node); + } + } + + let mut heap = BinaryHeap::from(zero_outdegree); + + // We remove the oldest node (most incoming edges) and check against all other + let mut sorted = vec![]; + // Destructure the `Reverse` and take the smallest `node` each time + while let Some(Reverse((_, node))) = heap.pop() { + let node: &EventId = node; + for parent in reverse_graph.get(node).expect("EventId in heap is also in reverse_graph") { + // The number of outgoing edges this node has + let out = outdegree_map + .get_mut(parent) + .expect("outdegree_map knows of all referenced EventIds"); + + // Only push on the heap once older events 
have been cleared + out.remove(node); + if out.is_empty() { + heap.push(Reverse((key_fn(parent)?, parent))); + } + } + + // synapse yields we push then return the vec + sorted.push(node.clone()); + } + + Ok(sorted) +} + +/// Find the power level for the sender of `event_id` or return a default value of zero. +fn get_power_level_for_sender(event_id: &EventId, fetch_event: F) -> i64 +where + E: Event, + F: Fn(&EventId) -> Option>, +{ + info!("fetch event ({}) senders power level", event_id); + + let event = fetch_event(event_id); + let mut pl = None; + + for aid in event.as_ref().map(|pdu| pdu.auth_events()).unwrap_or_default() { + if let Some(aev) = fetch_event(&aid) { + if is_type_and_key(&aev, EventType::RoomPowerLevels, "") { + pl = Some(aev); + break; + } + } + } + + if pl.is_none() { + return 0; + } + + if let Some(content) = + pl.and_then(|pl| serde_json::from_value::(pl.content()).ok()) + { + if let Some(ev) = event { + if let Some(user) = content.users.get(ev.sender()) { + debug!("found {} at power_level {}", ev.sender(), user); + return (*user).into(); + } + } + content.users_default.into() + } else { + 0 + } +} + +/// Check the that each event is authenticated based on the events before it. +/// +/// ## Returns +/// +/// The `unconflicted_state` combined with the newly auth'ed events. So any event that fails the +/// `event_auth::auth_check` will be excluded from the returned `StateMap`. +/// +/// For each `events_to_check` event we gather the events needed to auth it from the the +/// `fetch_event` closure and verify each event using the `event_auth::auth_check` function. 
+pub fn iterative_auth_check( + room_version: &RoomVersion, + events_to_check: &[EventId], + unconflicted_state: &StateMap, + fetch_event: F, +) -> Result> +where + E: Event, + F: Fn(&EventId) -> Option>, +{ + info!("starting iterative auth check"); + + debug!("performing auth checks on {:?}", events_to_check.iter().collect::>()); + + let mut resolved_state = unconflicted_state.clone(); + + for event_id in events_to_check.iter() { + let event = fetch_event(event_id) + .ok_or_else(|| Error::NotFound(format!("Failed to find {}", event_id)))?; + let state_key = event + .state_key() + .ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; + + let mut auth_events = HashMap::new(); + for aid in &event.auth_events() { + if let Some(ev) = fetch_event(aid) { + // TODO synapse check "rejected_reason" which is most likely + // related to soft-failing + auth_events.insert( + ( + ev.kind(), + ev.state_key().ok_or_else(|| { + Error::InvalidPdu("State event had no state key".to_owned()) + })?, + ), + ev, + ); + } else { + warn!("auth event id for {} is missing {}", aid, event_id); + } + } + + for key in auth_types_for_event( + &event.kind(), + event.sender(), + Some(state_key.clone()), + event.content(), + ) { + if let Some(ev_id) = resolved_state.get(&key) { + if let Some(event) = fetch_event(ev_id) { + // TODO synapse checks `rejected_reason` is None here + auth_events.insert(key.clone(), event); } } } - // Did not find a power level event so we default to zero - Ok(0) + + debug!("event to check {:?}", event.event_id()); + + let most_recent_prev_event = + event.prev_events().iter().filter_map(|id| fetch_event(id)).next_back(); + + // The key for this is (eventType + a state_key of the signed token not sender) so + // search for it + let current_third_party = auth_events.iter().find_map(|(_, pdu)| { + (pdu.kind() == EventType::RoomThirdPartyInvite).then(|| { + // TODO no clone, auth_events is borrowed while moved + pdu.clone() + }) + }); + + if 
auth_check( + room_version, + &event, + most_recent_prev_event, + current_third_party, + |ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(), + )? { + // add event to resolved state map + resolved_state.insert((event.kind(), state_key), event_id.clone()); + } else { + // synapse passes here on AuthError. We do not add this event to resolved_state. + warn!("event {} failed the authentication check", event_id); + } + + // TODO: if these functions are ever made async here + // is a good place to yield every once in a while so other + // tasks can make progress + } + Ok(resolved_state) +} + +/// Returns the sorted `to_sort` list of `EventId`s based on a mainline sort using the depth of +/// `resolved_power_level`, the server timestamp, and the eventId. +/// +/// The depth of the given event is calculated based on the depth of its closest "parent" +/// power_level event. If there have been two power events, the events after the most recent are depth 0, +/// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is +/// "older" than depth 0. +pub fn mainline_sort( + to_sort: &[EventId], + resolved_power_level: Option<&EventId>, + fetch_event: F, +) -> Result> +where + E: Event, + F: Fn(&EventId) -> Option>, +{ + debug!("mainline sort of events"); + + // There are no EventId's to sort, bail.
+ if to_sort.is_empty() { + return Ok(vec![]); } - fn add_event_and_auth_chain_to_graph( - graph: &mut HashMap>, - event_id: EventId, - auth_diff: &HashSet, - fetch_event: F, - ) where - E: Event, - F: Fn(&EventId) -> Option>, - { - let mut state = vec![event_id]; - while let Some(eid) = state.pop() { - graph.entry(eid.clone()).or_default(); - // Prefer the store to event as the store filters dedups the events - for aid in &fetch_event(&eid).map(|ev| ev.auth_events()).unwrap_or_default() { - if auth_diff.contains(aid) { - if !graph.contains_key(aid) { - state.push(aid.clone()); - } + let mut mainline = vec![]; + let mut pl = resolved_power_level.cloned(); + while let Some(p) = pl { + mainline.push(p.clone()); - // We just inserted this at the start of the while loop - graph.get_mut(&eid).unwrap().insert(aid.clone()); + let event = + fetch_event(&p).ok_or_else(|| Error::NotFound(format!("Failed to find {}", p)))?; + let auth_events = &event.auth_events(); + pl = None; + for aid in auth_events { + let ev = fetch_event(aid) + .ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?; + if is_type_and_key(&ev, EventType::RoomPowerLevels, "") { + pl = Some(aid.clone()); + break; + } + } + // TODO: if these functions are ever made async here + // is a good place to yield every once in a while so other + // tasks can make progress + } + + let mainline_map = mainline + .iter() + .rev() + .enumerate() + .map(|(idx, eid)| ((*eid).clone(), idx)) + .collect::>(); + + let mut order_map = HashMap::new(); + for ev_id in to_sort.iter() { + if let Some(event) = fetch_event(ev_id) { + if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, &fetch_event) { + order_map.insert( + ev_id, + ( + depth, + fetch_event(ev_id).map(|ev| ev.origin_server_ts()), + ev_id, // TODO should this be a &str to sort lexically?? 
+ ), + ); + } + } + + // TODO: if these functions are ever made async here + // is a good place to yield every once in a while so other + // tasks can make progress + } + + // Sort the event_ids by their depth, timestamp and EventId + // unwrap is OK order map and sort_event_ids are from to_sort (the same Vec) + let mut sort_event_ids = order_map.keys().map(|&k| k.clone()).collect::>(); + sort_event_ids.sort_by_key(|sort_id| order_map.get(sort_id).unwrap()); + + Ok(sort_event_ids) +} + +/// Get the mainline depth from the `mainline_map` or finds a power_level event that has an +/// associated mainline depth. +fn get_mainline_depth( + mut event: Option>, + mainline_map: &EventMap, + fetch_event: F, +) -> Result +where + E: Event, + F: Fn(&EventId) -> Option>, +{ + while let Some(sort_ev) = event { + debug!("mainline event_id {}", sort_ev.event_id()); + let id = &sort_ev.event_id(); + if let Some(depth) = mainline_map.get(id) { + return Ok(*depth); + } + + let auth_events = &sort_ev.auth_events(); + event = None; + for aid in auth_events { + let aev = fetch_event(aid) + .ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?; + if is_type_and_key(&aev, EventType::RoomPowerLevels, "") { + event = Some(aev); + break; + } + } + } + // Did not find a power level event so we default to zero + Ok(0) +} + +fn add_event_and_auth_chain_to_graph( + graph: &mut HashMap>, + event_id: EventId, + auth_diff: &HashSet, + fetch_event: F, +) where + E: Event, + F: Fn(&EventId) -> Option>, +{ + let mut state = vec![event_id]; + while let Some(eid) = state.pop() { + graph.entry(eid.clone()).or_default(); + // Prefer the store to event as the store filters dedups the events + for aid in &fetch_event(&eid).map(|ev| ev.auth_events()).unwrap_or_default() { + if auth_diff.contains(aid) { + if !graph.contains_key(aid) { + state.push(aid.clone()); } + + // We just inserted this at the start of the while loop + graph.get_mut(&eid).unwrap().insert(aid.clone()); } } } diff --git 
a/crates/ruma-state-res/tests/event_sorting.rs b/crates/ruma-state-res/tests/event_sorting.rs index b3b8d6c5..21c09aad 100644 --- a/crates/ruma-state-res/tests/event_sorting.rs +++ b/crates/ruma-state-res/tests/event_sorting.rs @@ -5,7 +5,7 @@ use std::{ use rand::seq::SliceRandom; use ruma_events::EventType; -use ruma_state_res::{is_power_event, room_version::RoomVersion, StateMap, StateResolution}; +use ruma_state_res::{self as state_res, is_power_event, room_version::RoomVersion, StateMap}; mod utils; use utils::INITIAL_EVENTS; @@ -27,12 +27,12 @@ fn test_event_sort() { .collect::>(); let sorted_power_events = - StateResolution::reverse_topological_power_sort(power_events, &auth_chain, |id| { + state_res::reverse_topological_power_sort(power_events, &auth_chain, |id| { events.get(id).map(Arc::clone) }) .unwrap(); - let resolved_power = StateResolution::iterative_auth_check( + let resolved_power = state_res::iterative_auth_check( &RoomVersion::version_6(), &sorted_power_events, &HashMap::new(), // unconflicted events @@ -47,10 +47,9 @@ fn test_event_sort() { let power_level = resolved_power.get(&(EventType::RoomPowerLevels, "".to_owned())); - let sorted_event_ids = StateResolution::mainline_sort(&events_to_sort, power_level, |id| { - events.get(id).map(Arc::clone) - }) - .unwrap(); + let sorted_event_ids = + state_res::mainline_sort(&events_to_sort, power_level, |id| events.get(id).map(Arc::clone)) + .unwrap(); assert_eq!( vec![ diff --git a/crates/ruma-state-res/tests/res_with_auth_ids.rs b/crates/ruma-state-res/tests/res_with_auth_ids.rs index 872ab369..14c93859 100644 --- a/crates/ruma-state-res/tests/res_with_auth_ids.rs +++ b/crates/ruma-state-res/tests/res_with_auth_ids.rs @@ -4,7 +4,7 @@ use std::{collections::HashMap, sync::Arc}; use ruma_events::EventType; use ruma_identifiers::{EventId, RoomVersionId}; -use ruma_state_res::{EventMap, StateMap, StateResolution}; +use ruma_state_res::{self as state_res, EventMap, StateMap}; use serde_json::json; use 
tracing::debug; @@ -65,7 +65,7 @@ fn ban_with_auth_chains2() { let ev_map: EventMap> = store.0.clone(); let state_sets = vec![state_set_a, state_set_b]; - let resolved = match StateResolution::resolve::( + let resolved = match state_res::resolve::( &room_id(), &RoomVersionId::Version6, &state_sets, diff --git a/crates/ruma-state-res/tests/state_res.rs b/crates/ruma-state-res/tests/state_res.rs index 46de5098..18fd99d3 100644 --- a/crates/ruma-state-res/tests/state_res.rs +++ b/crates/ruma-state-res/tests/state_res.rs @@ -5,7 +5,7 @@ use maplit::{hashmap, hashset}; use ruma_common::MilliSecondsSinceUnixEpoch; use ruma_events::{room::join_rules::JoinRule, EventType}; use ruma_identifiers::{EventId, RoomVersionId}; -use ruma_state_res::{EventMap, StateMap, StateResolution}; +use ruma_state_res::{self as state_res, EventMap, StateMap}; use serde_json::json; use tracing_subscriber as tracer; @@ -254,7 +254,7 @@ fn test_event_map_none() { let ev_map: EventMap> = store.0.clone(); let state_sets = vec![state_at_bob, state_at_charlie]; - let resolved = match StateResolution::resolve::( + let resolved = match state_res::resolve::( &room_id(), &RoomVersionId::Version2, &state_sets, @@ -285,7 +285,7 @@ fn test_lexicographical_sort() { event_id("p") => hashset![event_id("o")], }; - let res = StateResolution::lexicographical_topological_sort(&graph, |id| { + let res = state_res::lexicographical_topological_sort(&graph, |id| { Ok((0, MilliSecondsSinceUnixEpoch(uint!(0)), id.clone())) }) .unwrap(); diff --git a/crates/ruma-state-res/tests/utils.rs b/crates/ruma-state-res/tests/utils.rs index 662f0fa0..390c0da0 100644 --- a/crates/ruma-state-res/tests/utils.rs +++ b/crates/ruma-state-res/tests/utils.rs @@ -21,7 +21,7 @@ use ruma_events::{ EventType, }; use ruma_identifiers::{EventId, RoomId, RoomVersionId, UserId}; -use ruma_state_res::{auth_types_for_event, Error, Event, Result, StateMap, StateResolution}; +use ruma_state_res::{self as state_res, auth_types_for_event, Error, 
Event, Result, StateMap}; use serde_json::{json, Value as JsonValue}; use tracing::info; use tracing_subscriber as tracer; @@ -79,7 +79,7 @@ pub fn do_check( // Resolve the current state and add it to the state_at_event map then continue // on in "time" - for node in StateResolution::lexicographical_topological_sort(&graph, |id| { + for node in state_res::lexicographical_topological_sort(&graph, |id| { Ok((0, MilliSecondsSinceUnixEpoch(uint!(0)), id.clone())) }) .unwrap() @@ -111,7 +111,7 @@ pub fn do_check( .collect::>() ); - let resolved = StateResolution::resolve( + let resolved = state_res::resolve( &room_id(), &RoomVersionId::Version6, &state_sets,