state-res: refactor tiebreaking logic for clarity

This commit is contained in:
Charles Hall 2024-11-06 12:29:08 -08:00 committed by strawberry
parent 97e2fb6df1
commit e31b9dd3a4

View File

@ -1,6 +1,6 @@
use std::{ use std::{
borrow::Borrow, borrow::Borrow,
cmp::Reverse, cmp::{Ordering, Reverse},
collections::{BinaryHeap, HashMap, HashSet}, collections::{BinaryHeap, HashMap, HashSet},
hash::Hash, hash::Hash,
}; };
@ -289,14 +289,44 @@ where
Fut: Future<Output = Result<(Int, MilliSecondsSinceUnixEpoch)>> + Send, Fut: Future<Output = Result<(Int, MilliSecondsSinceUnixEpoch)>> + Send,
Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send, Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send,
{ {
#[derive(Eq, Ord, PartialEq, PartialOrd)] #[derive(PartialEq, Eq)]
struct TieBreaker<'a, Id> { struct TieBreaker<'a, Id> {
inv_power_level: Int, power_level: Int,
age: MilliSecondsSinceUnixEpoch, origin_server_ts: MilliSecondsSinceUnixEpoch,
event_id: &'a Id, event_id: &'a Id,
} }
impl<'a, Id> Ord for TieBreaker<'a, Id>
where
Id: Ord,
{
fn cmp(&self, other: &Self) -> Ordering {
// NOTE: the power level comparison is "backwards" intentionally.
// See the "Mainline ordering" section of the Matrix specification
// around where it says the following:
//
// > for events `x` and `y`, `x<y` if [...]
//
// <https://spec.matrix.org/v1.12/rooms/v11/#definitions>
other
.power_level
.cmp(&self.power_level)
.then(self.origin_server_ts.cmp(&other.origin_server_ts))
.then(self.event_id.cmp(other.event_id))
}
}
impl<'a, Id> PartialOrd for TieBreaker<'a, Id>
where
Id: Ord,
{
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
debug!("starting lexicographical topological sort"); debug!("starting lexicographical topological sort");
// NOTE: an event that has no incoming edges happened most recently, // NOTE: an event that has no incoming edges happened most recently,
// and an event that has no outgoing edges happened least recently. // and an event that has no outgoing edges happened least recently.
@ -317,12 +347,12 @@ where
for (node, edges) in graph { for (node, edges) in graph {
if edges.is_empty() { if edges.is_empty() {
let (power_level, age) = key_fn(node.clone()).await?; let (power_level, origin_server_ts) = key_fn(node.clone()).await?;
// The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need // The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need
// smallest -> largest // smallest -> largest
zero_outdegree.push(Reverse(TieBreaker { zero_outdegree.push(Reverse(TieBreaker {
inv_power_level: -power_level, power_level,
age, origin_server_ts,
event_id: node, event_id: node,
})); }));
} }
@ -350,12 +380,8 @@ where
// Only push on the heap once older events have been cleared // Only push on the heap once older events have been cleared
out.remove(node.borrow()); out.remove(node.borrow());
if out.is_empty() { if out.is_empty() {
let (power_level, age) = key_fn(parent.clone()).await?; let (power_level, origin_server_ts) = key_fn(parent.clone()).await?;
heap.push(Reverse(TieBreaker { heap.push(Reverse(TieBreaker { power_level, origin_server_ts, event_id: parent }));
inv_power_level: -power_level,
age,
event_id: parent,
}));
} }
} }