diff --git a/crates/ruma-state-res/benches/state_res_bench.rs b/crates/ruma-state-res/benches/state_res_bench.rs index 2673eded..508f8be6 100644 --- a/crates/ruma-state-res/benches/state_res_bench.rs +++ b/crates/ruma-state-res/benches/state_res_bench.rs @@ -46,7 +46,7 @@ fn lexico_topo_sort(c: &mut Criterion) { }; b.iter(|| { let _ = StateResolution::lexicographical_topological_sort(&graph, |id| { - (0, MilliSecondsSinceUnixEpoch(uint!(0)), id.clone()) + Ok((0, MilliSecondsSinceUnixEpoch(uint!(0)), id.clone())) }); }) }); diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 91c49227..b60e691a 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -136,7 +136,7 @@ impl StateResolution { &control_events, &all_conflicted, &fetch_event, - ); + )?; debug!("sorted control events: {}", sorted_control_levels.len()); trace!("{:?}", sorted_control_levels); @@ -174,7 +174,7 @@ impl StateResolution { debug!("power event: {:?}", power_event); let sorted_left_events = - StateResolution::mainline_sort(&events_to_resolve, power_event, &fetch_event); + StateResolution::mainline_sort(&events_to_resolve, power_event, &fetch_event)?; trace!("events left, sorted: {:?}", sorted_left_events.iter().collect::>()); @@ -249,7 +249,7 @@ impl StateResolution { events_to_sort: &[EventId], auth_diff: &HashSet, fetch_event: F, - ) -> Vec + ) -> Result> where E: Event, F: Fn(&EventId) -> Option>, @@ -284,15 +284,15 @@ impl StateResolution { } StateResolution::lexicographical_topological_sort(&graph, |event_id| { - let ev = fetch_event(event_id).unwrap(); - let pl = event_to_pl.get(event_id).unwrap(); + let ev = fetch_event(event_id).ok_or_else(|| Error::NotFound("".into()))?; + let pl = event_to_pl.get(event_id).ok_or_else(|| Error::NotFound("".into()))?; debug!("{:?}", (-*pl, ev.origin_server_ts(), &ev.event_id())); // This return value is the key used for sorting events, // events are then sorted by power level, time, // and lexically by event_id. - (-*pl, ev.origin_server_ts(), ev.event_id().clone()) + Ok((-*pl, ev.origin_server_ts(), ev.event_id().clone())) }) } @@ -302,9 +302,9 @@ impl StateResolution { pub fn lexicographical_topological_sort( graph: &HashMap>, key_fn: F, - ) -> Vec + ) -> Result> where - F: Fn(&EventId) -> (i64, MilliSecondsSinceUnixEpoch, EventId), + F: Fn(&EventId) -> Result<(i64, MilliSecondsSinceUnixEpoch, EventId)>, { info!("starting lexicographical topological sort"); // NOTE: an event that has no incoming edges happened most recently, @@ -314,12 +314,12 @@ impl StateResolution { // outgoing edges, c.f. // https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm - // TODO make the HashSet conversion cleaner ?? // outdegree_map is an event referring to the events before it, the // more outdegree's the more recent the event. let mut outdegree_map = graph.clone(); - // The number of events that depend on the given event (the eventId key) + // The number of events that depend on the given event (the EventId key) + // How many events reference this event in the DAG as a parent let mut reverse_graph = HashMap::new(); // Vec of nodes that have zero out degree, least recent events. @@ -329,7 +329,7 @@ impl StateResolution { if edges.is_empty() { // The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need // smallest -> largest - zero_outdegree.push(Reverse((key_fn(node), node))); + zero_outdegree.push(Reverse((key_fn(node)?, node))); } reverse_graph.entry(node).or_insert(hashset![]); @@ -345,14 +345,17 @@ impl StateResolution { // Destructure the `Reverse` and take the smallest `node` each time while let Some(Reverse((_, node))) = heap.pop() { let node: &EventId = node; - for parent in reverse_graph.get(node).unwrap() { + for parent in reverse_graph.get(node).expect("EventId in heap is also in reverse_graph") + { // The number of outgoing edges this node has - let out = outdegree_map.get_mut(parent).unwrap(); + let out = outdegree_map + .get_mut(parent) + .expect("outdegree_map knows of all referenced EventIds"); // Only push on the heap once older events have been cleared out.remove(node); if out.is_empty() { - heap.push(Reverse((key_fn(parent), parent))); + heap.push(Reverse((key_fn(parent)?, parent))); } } @@ -360,7 +363,7 @@ impl StateResolution { sorted.push(node.clone()); } - sorted + Ok(sorted) } /// Find the power level for the sender of `event_id` or return a default value of zero. @@ -516,7 +519,7 @@ impl StateResolution { to_sort: &[EventId], resolved_power_level: Option<&EventId>, fetch_event: F, - ) -> Vec + ) -> Result> where E: Event, F: Fn(&EventId) -> Option>, @@ -525,7 +528,7 @@ impl StateResolution { // There are no EventId's to sort, bail. if to_sort.is_empty() { - return vec![]; + return Ok(vec![]); } let mut mainline = vec![]; @@ -533,11 +536,13 @@ impl StateResolution { while let Some(p) = pl { mainline.push(p.clone()); - let event = fetch_event(&p).unwrap(); + let event = + fetch_event(&p).ok_or_else(|| Error::NotFound(format!("Failed to find {}", p)))?; let auth_events = &event.auth_events(); pl = None; for aid in auth_events { - let ev = fetch_event(aid).unwrap(); + let ev = fetch_event(aid) + .ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?; if is_type_and_key(&ev, EventType::RoomPowerLevels, "") { pl = Some(aid.clone()); break; @@ -582,7 +587,7 @@ impl StateResolution { let mut sort_event_ids = order_map.keys().map(|&k| k.clone()).collect::>(); sort_event_ids.sort_by_key(|sort_id| order_map.get(sort_id).unwrap()); - sort_event_ids + Ok(sort_event_ids) } /// Get the mainline depth from the `mainline_map` or finds a power_level event diff --git a/crates/ruma-state-res/tests/event_sorting.rs b/crates/ruma-state-res/tests/event_sorting.rs index 39ab675a..31d9bc06 100644 --- a/crates/ruma-state-res/tests/event_sorting.rs +++ b/crates/ruma-state-res/tests/event_sorting.rs @@ -32,10 +32,9 @@ fn test_event_sort() { let sorted_power_events = StateResolution::reverse_topological_power_sort(&power_events, &auth_chain, |id| { events.get(id).map(Arc::clone) - }); + }) + .unwrap(); - // This is a TODO in conduit - // TODO we may be able to skip this since they are resolved according to spec let resolved_power = StateResolution::iterative_auth_check( &RoomVersion::version_6(), &sorted_power_events, @@ -53,7 +52,8 @@ fn test_event_sort() { let sorted_event_ids = StateResolution::mainline_sort(&events_to_sort, power_level, |id| { events.get(id).map(Arc::clone) - }); + }) + .unwrap(); assert_eq!( vec![ diff --git a/crates/ruma-state-res/tests/state_res.rs b/crates/ruma-state-res/tests/state_res.rs index c72873c8..46de5098 100644 --- a/crates/ruma-state-res/tests/state_res.rs +++ b/crates/ruma-state-res/tests/state_res.rs @@ -286,8 +286,9 @@ fn test_lexicographical_sort() { }; let res = StateResolution::lexicographical_topological_sort(&graph, |id| { - (0, MilliSecondsSinceUnixEpoch(uint!(0)), id.clone()) - }); + Ok((0, MilliSecondsSinceUnixEpoch(uint!(0)), id.clone())) + }) + .unwrap(); assert_eq!( vec!["o", "l", "n", "m", "p"], diff --git a/crates/ruma-state-res/tests/utils.rs b/crates/ruma-state-res/tests/utils.rs index ddd20689..5fb12865 100644 --- a/crates/ruma-state-res/tests/utils.rs +++ b/crates/ruma-state-res/tests/utils.rs @@ -80,8 +80,10 @@ pub fn do_check( // Resolve the current state and add it to the state_at_event map then continue // on in "time" for node in StateResolution::lexicographical_topological_sort(&graph, |id| { - (0, MilliSecondsSinceUnixEpoch(uint!(0)), id.clone()) - }) { + Ok((0, MilliSecondsSinceUnixEpoch(uint!(0)), id.clone())) + }) + .unwrap() + { let fake_event = fake_event_map.get(&node).unwrap(); let event_id = fake_event.event_id().clone();