diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 3ec7c3a7..6ef07c9c 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -1,6 +1,6 @@ use std::{ borrow::Borrow, - cmp::Reverse, + cmp::{Ordering, Reverse}, collections::{BinaryHeap, HashMap, HashSet}, hash::Hash, }; @@ -289,14 +289,44 @@ where Fut: Future> + Send, Id: Borrow + Clone + Eq + Hash + Ord + Send, { - #[derive(Eq, Ord, PartialEq, PartialOrd)] + #[derive(PartialEq, Eq)] struct TieBreaker<'a, Id> { - inv_power_level: Int, - age: MilliSecondsSinceUnixEpoch, + power_level: Int, + origin_server_ts: MilliSecondsSinceUnixEpoch, event_id: &'a Id, } + impl<'a, Id> Ord for TieBreaker<'a, Id> + where + Id: Ord, + { + fn cmp(&self, other: &Self) -> Ordering { + // NOTE: the power level comparison is "backwards" intentionally. + // See the "Mainline ordering" section of the Matrix specification + // around where it says the following: + // + // > for events `x` and `y`, `x < y` if [...] + // + // + other + .power_level + .cmp(&self.power_level) + .then(self.origin_server_ts.cmp(&other.origin_server_ts)) + .then(self.event_id.cmp(other.event_id)) + } + } + + impl<'a, Id> PartialOrd for TieBreaker<'a, Id> + where + Id: Ord, + { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + debug!("starting lexicographical topological sort"); + // NOTE: an event that has no incoming edges happened most recently, // and an event that has no outgoing edges happened least recently. @@ -317,12 +347,12 @@ where for (node, edges) in graph { if edges.is_empty() { - let (power_level, age) = key_fn(node.clone()).await?; + let (power_level, origin_server_ts) = key_fn(node.clone()).await?; // The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need // smallest -> largest zero_outdegree.push(Reverse(TieBreaker { - inv_power_level: -power_level, - age, + power_level, + origin_server_ts, event_id: node, })); } @@ -350,12 +380,8 @@ where // Only push on the heap once older events have been cleared out.remove(node.borrow()); if out.is_empty() { - let (power_level, age) = key_fn(parent.clone()).await?; - heap.push(Reverse(TieBreaker { - inv_power_level: -power_level, - age, - event_id: parent, - })); + let (power_level, origin_server_ts) = key_fn(parent.clone()).await?; + heap.push(Reverse(TieBreaker { power_level, origin_server_ts, event_id: parent })); } }