state-res: Change BTreeMap/Set to HashMap/Set

This commit is contained in:
Timo Kösters 2021-07-16 19:45:13 +02:00 committed by Jonas Platte
parent d970501c85
commit 3a0ee7740f
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
7 changed files with 62 additions and 57 deletions

View File

@ -3,6 +3,7 @@
Breaking changes: Breaking changes:
* state_res::resolve auth_events type has been slightly changed and renamed to auth_chain_sets * state_res::resolve auth_events type has been slightly changed and renamed to auth_chain_sets
* state_res::resolve structs were changed from BTreeMap/Set to HashMap/Set
# 0.2.0 # 0.2.0

View File

@ -175,7 +175,11 @@ impl<E: Event> TestStore<E> {
} }
/// Returns a Vec of the related auth events to the given `event`. /// Returns a Vec of the related auth events to the given `event`.
pub fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<BTreeSet<EventId>> { pub fn auth_event_ids(
&self,
room_id: &RoomId,
event_ids: &[EventId],
) -> Result<BTreeSet<EventId>> {
let mut result = BTreeSet::new(); let mut result = BTreeSet::new();
let mut stack = event_ids.to_vec(); let mut stack = event_ids.to_vec();

View File

@ -1,11 +1,11 @@
use std::{ use std::{
cmp::Reverse, cmp::Reverse,
collections::{BTreeMap, BTreeSet, BinaryHeap}, collections::{BinaryHeap, HashMap, HashSet},
sync::Arc, sync::Arc,
}; };
use itertools::Itertools; use itertools::Itertools;
use maplit::btreeset; use maplit::hashset;
use ruma_common::MilliSecondsSinceUnixEpoch; use ruma_common::MilliSecondsSinceUnixEpoch;
use ruma_events::{ use ruma_events::{
room::{ room::{
@ -28,10 +28,10 @@ pub use room_version::RoomVersion;
pub use state_event::Event; pub use state_event::Event;
/// A mapping of event type and state_key to some value `T`, usually an `EventId`. /// A mapping of event type and state_key to some value `T`, usually an `EventId`.
pub type StateMap<T> = BTreeMap<(EventType, String), T>; pub type StateMap<T> = HashMap<(EventType, String), T>;
/// A mapping of `EventId` to `T`, usually a `ServerPdu`. /// A mapping of `EventId` to `T`, usually a `ServerPdu`.
pub type EventMap<T> = BTreeMap<EventId, T>; pub type EventMap<T> = HashMap<EventId, T>;
#[derive(Default)] #[derive(Default)]
#[allow(clippy::exhaustive_structs)] #[allow(clippy::exhaustive_structs)]
@ -61,7 +61,7 @@ impl StateResolution {
room_id: &RoomId, room_id: &RoomId,
room_version: &RoomVersionId, room_version: &RoomVersionId,
state_sets: &[StateMap<EventId>], state_sets: &[StateMap<EventId>],
auth_chain_sets: Vec<BTreeSet<EventId>>, auth_chain_sets: Vec<HashSet<EventId>>,
fetch_event: F, fetch_event: F,
) -> Result<StateMap<EventId>> ) -> Result<StateMap<EventId>>
where where
@ -89,7 +89,7 @@ impl StateResolution {
.next() .next()
.expect("we made sure conflicting is not empty") .expect("we made sure conflicting is not empty")
.iter() .iter()
.map(|o| if let Some(e) = o { btreeset![e.clone()] } else { BTreeSet::new() }) .map(|o| if let Some(e) = o { hashset![e.clone()] } else { HashSet::new() })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for events in iter { for events in iter {
@ -116,7 +116,7 @@ impl StateResolution {
// Don't honor events we cannot "verify" // Don't honor events we cannot "verify"
// TODO: BTreeSet::retain() when stable 1.53 // TODO: BTreeSet::retain() when stable 1.53
let all_conflicted = let all_conflicted =
auth_diff.into_iter().filter(|id| fetch_event(id).is_some()).collect::<BTreeSet<_>>(); auth_diff.into_iter().filter(|id| fetch_event(id).is_some()).collect::<HashSet<_>>();
info!("full conflicted set: {}", all_conflicted.len()); info!("full conflicted set: {}", all_conflicted.len());
debug!("{:?}", all_conflicted); debug!("{:?}", all_conflicted);
@ -155,7 +155,7 @@ impl StateResolution {
// At this point the control_events have been resolved we now have to // At this point the control_events have been resolved we now have to
// sort the remaining events using the mainline of the resolved power level. // sort the remaining events using the mainline of the resolved power level.
let deduped_power_ev = sorted_control_levels.into_iter().collect::<BTreeSet<_>>(); let deduped_power_ev = sorted_control_levels.into_iter().collect::<HashSet<_>>();
// This removes the control events that passed auth and more importantly those that failed // This removes the control events that passed auth and more importantly those that failed
// auth // auth
@ -224,17 +224,17 @@ impl StateResolution {
/// Returns a Vec of deduped EventIds that appear in some chains but not others. /// Returns a Vec of deduped EventIds that appear in some chains but not others.
pub fn get_auth_chain_diff( pub fn get_auth_chain_diff(
_room_id: &RoomId, _room_id: &RoomId,
auth_chain_sets: Vec<BTreeSet<EventId>>, auth_chain_sets: Vec<HashSet<EventId>>,
) -> Result<BTreeSet<EventId>> { ) -> Result<HashSet<EventId>> {
if let Some(first) = auth_chain_sets.first().cloned() { if let Some(first) = auth_chain_sets.first().cloned() {
let common = auth_chain_sets let common = auth_chain_sets
.iter() .iter()
.skip(1) .skip(1)
.fold(first, |a, b| a.intersection(&b).cloned().collect::<BTreeSet<EventId>>()); .fold(first, |a, b| a.intersection(&b).cloned().collect::<HashSet<EventId>>());
Ok(auth_chain_sets.into_iter().flatten().filter(|id| !common.contains(&id)).collect()) Ok(auth_chain_sets.into_iter().flatten().filter(|id| !common.contains(&id)).collect())
} else { } else {
Ok(btreeset![]) Ok(hashset![])
} }
} }
@ -247,7 +247,7 @@ impl StateResolution {
/// earlier (further back in time) origin server timestamp. /// earlier (further back in time) origin server timestamp.
pub fn reverse_topological_power_sort<E, F>( pub fn reverse_topological_power_sort<E, F>(
events_to_sort: &[EventId], events_to_sort: &[EventId],
auth_diff: &BTreeSet<EventId>, auth_diff: &HashSet<EventId>,
fetch_event: F, fetch_event: F,
) -> Vec<EventId> ) -> Vec<EventId>
where where
@ -256,7 +256,7 @@ impl StateResolution {
{ {
debug!("reverse topological sort of power events"); debug!("reverse topological sort of power events");
let mut graph = BTreeMap::new(); let mut graph = HashMap::new();
for event_id in events_to_sort.iter() { for event_id in events_to_sort.iter() {
StateResolution::add_event_and_auth_chain_to_graph( StateResolution::add_event_and_auth_chain_to_graph(
&mut graph, &mut graph,
@ -271,7 +271,7 @@ impl StateResolution {
} }
// This is used in the `key_fn` passed to the lexico_topo_sort fn // This is used in the `key_fn` passed to the lexico_topo_sort fn
let mut event_to_pl = BTreeMap::new(); let mut event_to_pl = HashMap::new();
for event_id in graph.keys() { for event_id in graph.keys() {
let pl = StateResolution::get_power_level_for_sender(event_id, &fetch_event); let pl = StateResolution::get_power_level_for_sender(event_id, &fetch_event);
info!("{} power level {}", event_id, pl); info!("{} power level {}", event_id, pl);
@ -300,7 +300,7 @@ impl StateResolution {
/// `key_fn` is used as a tie breaker. The tie breaker happens based on /// `key_fn` is used as a tie breaker. The tie breaker happens based on
/// power level, age, and event_id. /// power level, age, and event_id.
pub fn lexicographical_topological_sort<F>( pub fn lexicographical_topological_sort<F>(
graph: &BTreeMap<EventId, BTreeSet<EventId>>, graph: &HashMap<EventId, HashSet<EventId>>,
key_fn: F, key_fn: F,
) -> Vec<EventId> ) -> Vec<EventId>
where where
@ -314,13 +314,13 @@ impl StateResolution {
// outgoing edges, c.f. // outgoing edges, c.f.
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm // https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
// TODO make the BTreeSet conversion cleaner ?? // TODO make the HashSet conversion cleaner ??
// outdegree_map is an event referring to the events before it, the // outdegree_map is an event referring to the events before it, the
// more outdegree's the more recent the event. // more outdegree's the more recent the event.
let mut outdegree_map = graph.clone(); let mut outdegree_map = graph.clone();
// The number of events that depend on the given event (the eventId key) // The number of events that depend on the given event (the eventId key)
let mut reverse_graph = BTreeMap::new(); let mut reverse_graph = HashMap::new();
// Vec of nodes that have zero out degree, least recent events. // Vec of nodes that have zero out degree, least recent events.
let mut zero_outdegree = vec![]; let mut zero_outdegree = vec![];
@ -332,9 +332,9 @@ impl StateResolution {
zero_outdegree.push(Reverse((key_fn(node), node))); zero_outdegree.push(Reverse((key_fn(node), node)));
} }
reverse_graph.entry(node).or_insert(btreeset![]); reverse_graph.entry(node).or_insert(hashset![]);
for edge in edges { for edge in edges {
reverse_graph.entry(edge).or_insert(btreeset![]).insert(node); reverse_graph.entry(edge).or_insert(hashset![]).insert(node);
} }
} }
@ -437,7 +437,7 @@ impl StateResolution {
.state_key() .state_key()
.ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; .ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?;
let mut auth_events = BTreeMap::new(); let mut auth_events = HashMap::new();
for aid in &event.auth_events() { for aid in &event.auth_events() {
if let Some(ev) = fetch_event(aid) { if let Some(ev) = fetch_event(aid) {
// TODO synapse check "rejected_reason", I'm guessing this is redacted_because // TODO synapse check "rejected_reason", I'm guessing this is redacted_because
@ -553,9 +553,9 @@ impl StateResolution {
.rev() .rev()
.enumerate() .enumerate()
.map(|(idx, eid)| ((*eid).clone(), idx)) .map(|(idx, eid)| ((*eid).clone(), idx))
.collect::<BTreeMap<_, _>>(); .collect::<HashMap<_, _>>();
let mut order_map = BTreeMap::new(); let mut order_map = HashMap::new();
for ev_id in to_sort.iter() { for ev_id in to_sort.iter() {
if let Some(event) = fetch_event(ev_id) { if let Some(event) = fetch_event(ev_id) {
if let Ok(depth) = if let Ok(depth) =
@ -619,9 +619,9 @@ impl StateResolution {
} }
fn add_event_and_auth_chain_to_graph<E, F>( fn add_event_and_auth_chain_to_graph<E, F>(
graph: &mut BTreeMap<EventId, BTreeSet<EventId>>, graph: &mut HashMap<EventId, HashSet<EventId>>,
event_id: &EventId, event_id: &EventId,
auth_diff: &BTreeSet<EventId>, auth_diff: &HashSet<EventId>,
fetch_event: F, fetch_event: F,
) where ) where
E: Event, E: Event,
@ -631,7 +631,7 @@ impl StateResolution {
while !state.is_empty() { while !state.is_empty() {
// We just checked if it was empty so unwrap is fine // We just checked if it was empty so unwrap is fine
let eid = state.pop().unwrap(); let eid = state.pop().unwrap();
graph.entry(eid.clone()).or_insert(btreeset![]); graph.entry(eid.clone()).or_insert(hashset![]);
// Prefer the store to event as the store filters dedups the events // Prefer the store to event as the store filters dedups the events
for aid in &fetch_event(&eid).map(|ev| ev.auth_events()).unwrap_or_default() { for aid in &fetch_event(&eid).map(|ev| ev.auth_events()).unwrap_or_default() {
if auth_diff.contains(aid) { if auth_diff.contains(aid) {

View File

@ -1,5 +1,5 @@
use std::{ use std::{
collections::{BTreeMap, BTreeSet}, collections::{HashMap, HashSet},
sync::Arc, sync::Arc,
}; };
@ -18,7 +18,7 @@ fn test_event_sort() {
.map(|ev| ((ev.kind(), ev.state_key()), ev.clone())) .map(|ev| ((ev.kind(), ev.state_key()), ev.clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let auth_chain = BTreeSet::new(); let auth_chain = HashSet::new();
let power_events = event_map let power_events = event_map
.values() .values()
@ -39,7 +39,7 @@ fn test_event_sort() {
let resolved_power = StateResolution::iterative_auth_check( let resolved_power = StateResolution::iterative_auth_check(
&RoomVersion::version_6(), &RoomVersion::version_6(),
&sorted_power_events, &sorted_power_events,
&BTreeMap::new(), // unconflicted events &HashMap::new(), // unconflicted events
|id| events.get(id).map(Arc::clone), |id| events.get(id).map(Arc::clone),
) )
.expect("iterative auth check failed on resolved events"); .expect("iterative auth check failed on resolved events");

View File

@ -1,6 +1,6 @@
#![allow(clippy::or_fun_call, clippy::expect_fun_call)] #![allow(clippy::or_fun_call, clippy::expect_fun_call)]
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use ruma_events::EventType; use ruma_events::EventType;
use ruma_identifiers::{EventId, RoomVersionId}; use ruma_identifiers::{EventId, RoomVersionId};
@ -48,7 +48,7 @@ fn ban_with_auth_chains2() {
] ]
.iter() .iter()
.map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone())) .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone()))
.collect::<BTreeMap<_, _>>(); .collect::<StateMap<_>>();
let state_set_b = [ let state_set_b = [
inner.get(&event_id("CREATE")).unwrap(), inner.get(&event_id("CREATE")).unwrap(),
@ -116,7 +116,7 @@ fn join_rule_with_auth_chain() {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<StateEvent>> { fn BAN_STATE_SET() -> HashMap<EventId, Arc<StateEvent>> {
vec![ vec![
to_pdu_event( to_pdu_event(
"PA", "PA",
@ -161,7 +161,7 @@ fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<StateEvent>> {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn JOIN_RULE() -> BTreeMap<EventId, Arc<StateEvent>> { fn JOIN_RULE() -> HashMap<EventId, Arc<StateEvent>> {
vec![ vec![
to_pdu_event( to_pdu_event(
"JR", "JR",

View File

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use js_int::uint; use js_int::uint;
use maplit::{btreemap, btreeset}; use maplit::{hashmap, hashset};
use ruma_common::MilliSecondsSinceUnixEpoch; use ruma_common::MilliSecondsSinceUnixEpoch;
use ruma_events::{room::join_rules::JoinRule, EventType}; use ruma_events::{room::join_rules::JoinRule, EventType};
use ruma_identifiers::{EventId, RoomVersionId}; use ruma_identifiers::{EventId, RoomVersionId};
@ -247,7 +247,7 @@ fn test_event_map_none() {
let _ = LOGGER let _ = LOGGER
.call_once(|| tracer::fmt().with_env_filter(tracer::EnvFilter::from_default_env()).init()); .call_once(|| tracer::fmt().with_env_filter(tracer::EnvFilter::from_default_env()).init());
let mut store = TestStore::<StateEvent>(btreemap! {}); let mut store = TestStore::<StateEvent>(hashmap! {});
// build up the DAG // build up the DAG
let (state_at_bob, state_at_charlie, expected) = store.set_up(); let (state_at_bob, state_at_charlie, expected) = store.set_up();
@ -277,12 +277,12 @@ fn test_event_map_none() {
#[test] #[test]
fn test_lexicographical_sort() { fn test_lexicographical_sort() {
let graph = btreemap! { let graph = hashmap! {
event_id("l") => btreeset![event_id("o")], event_id("l") => hashset![event_id("o")],
event_id("m") => btreeset![event_id("n"), event_id("o")], event_id("m") => hashset![event_id("n"), event_id("o")],
event_id("n") => btreeset![event_id("o")], event_id("n") => hashset![event_id("o")],
event_id("o") => btreeset![], // "o" has zero outgoing edges but 4 incoming edges event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges
event_id("p") => btreeset![event_id("o")], event_id("p") => hashset![event_id("o")],
}; };
let res = StateResolution::lexicographical_topological_sort(&graph, |id| { let res = StateResolution::lexicographical_topological_sort(&graph, |id| {

View File

@ -1,7 +1,7 @@
#![allow(dead_code)] #![allow(dead_code)]
use std::{ use std::{
collections::{BTreeMap, BTreeSet}, collections::{HashMap, HashSet},
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
sync::{ sync::{
atomic::{AtomicU64, Ordering::SeqCst}, atomic::{AtomicU64, Ordering::SeqCst},
@ -10,7 +10,7 @@ use std::{
}; };
use js_int::uint; use js_int::uint;
use maplit::{btreemap, btreeset}; use maplit::{btreemap, hashset};
use ruma_common::MilliSecondsSinceUnixEpoch; use ruma_common::MilliSecondsSinceUnixEpoch;
use ruma_events::{ use ruma_events::{
pdu::{EventHash, Pdu, RoomV3Pdu}, pdu::{EventHash, Pdu, RoomV3Pdu},
@ -47,35 +47,35 @@ pub fn do_check(
); );
// This will be lexi_topo_sorted for resolution // This will be lexi_topo_sorted for resolution
let mut graph = BTreeMap::new(); let mut graph = HashMap::new();
// This is the same as in `resolve` event_id -> StateEvent // This is the same as in `resolve` event_id -> StateEvent
let mut fake_event_map = BTreeMap::new(); let mut fake_event_map = HashMap::new();
// Create the DB of events that led up to this point // Create the DB of events that led up to this point
// TODO maybe clean up some of these clones it is just tests but... // TODO maybe clean up some of these clones it is just tests but...
for ev in init_events.values().chain(events) { for ev in init_events.values().chain(events) {
graph.insert(ev.event_id().clone(), btreeset![]); graph.insert(ev.event_id().clone(), hashset![]);
fake_event_map.insert(ev.event_id().clone(), ev.clone()); fake_event_map.insert(ev.event_id().clone(), ev.clone());
} }
for pair in INITIAL_EDGES().windows(2) { for pair in INITIAL_EDGES().windows(2) {
if let [a, b] = &pair { if let [a, b] = &pair {
graph.entry(a.clone()).or_insert_with(BTreeSet::new).insert(b.clone()); graph.entry(a.clone()).or_insert_with(HashSet::new).insert(b.clone());
} }
} }
for edge_list in edges { for edge_list in edges {
for pair in edge_list.windows(2) { for pair in edge_list.windows(2) {
if let [a, b] = &pair { if let [a, b] = &pair {
graph.entry(a.clone()).or_insert_with(BTreeSet::new).insert(b.clone()); graph.entry(a.clone()).or_insert_with(HashSet::new).insert(b.clone());
} }
} }
} }
// event_id -> StateEvent // event_id -> StateEvent
let mut event_map: BTreeMap<EventId, Arc<StateEvent>> = BTreeMap::new(); let mut event_map: HashMap<EventId, Arc<StateEvent>> = HashMap::new();
// event_id -> StateMap<EventId> // event_id -> StateMap<EventId>
let mut state_at_event: BTreeMap<EventId, StateMap<EventId>> = BTreeMap::new(); let mut state_at_event: HashMap<EventId, StateMap<EventId>> = HashMap::new();
// Resolve the current state and add it to the state_at_event map then continue // Resolve the current state and add it to the state_at_event map then continue
// on in "time" // on in "time"
@ -88,7 +88,7 @@ pub fn do_check(
let prev_events = graph.get(&node).unwrap(); let prev_events = graph.get(&node).unwrap();
let state_before: StateMap<EventId> = if prev_events.is_empty() { let state_before: StateMap<EventId> = if prev_events.is_empty() {
BTreeMap::new() HashMap::new()
} else if prev_events.len() == 1 { } else if prev_events.len() == 1 {
state_at_event.get(prev_events.iter().next().unwrap()).unwrap().clone() state_at_event.get(prev_events.iter().next().unwrap()).unwrap().clone()
} else { } else {
@ -207,7 +207,7 @@ pub fn do_check(
} }
#[allow(clippy::exhaustive_structs)] #[allow(clippy::exhaustive_structs)]
pub struct TestStore<E: Event>(pub BTreeMap<EventId, Arc<E>>); pub struct TestStore<E: Event>(pub HashMap<EventId, Arc<E>>);
impl<E: Event> TestStore<E> { impl<E: Event> TestStore<E> {
pub fn get_event(&self, _: &RoomId, event_id: &EventId) -> Result<Arc<E>> { pub fn get_event(&self, _: &RoomId, event_id: &EventId) -> Result<Arc<E>> {
@ -231,8 +231,8 @@ impl<E: Event> TestStore<E> {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event_ids: &[EventId], event_ids: &[EventId],
) -> Result<BTreeSet<EventId>> { ) -> Result<HashSet<EventId>> {
let mut result = BTreeSet::new(); let mut result = HashSet::new();
let mut stack = event_ids.to_vec(); let mut stack = event_ids.to_vec();
// DFS for auth event chain // DFS for auth event chain
@ -263,7 +263,7 @@ impl<E: Event> TestStore<E> {
for ids in event_ids { for ids in event_ids {
// TODO state store `auth_event_ids` returns self in the event ids list // TODO state store `auth_event_ids` returns self in the event ids list
// when an event returns `auth_event_ids` self is not contained // when an event returns `auth_event_ids` self is not contained
let chain = self.auth_event_ids(room_id, &ids)?.into_iter().collect::<BTreeSet<_>>(); let chain = self.auth_event_ids(room_id, &ids)?.into_iter().collect::<HashSet<_>>();
chains.push(chain); chains.push(chain);
} }
@ -388,7 +388,7 @@ where
// all graphs start with these input events // all graphs start with these input events
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> { pub fn INITIAL_EVENTS() -> HashMap<EventId, Arc<StateEvent>> {
// this is always called so we can init the logger here // this is always called so we can init the logger here
let _ = LOGGER let _ = LOGGER
.call_once(|| tracer::fmt().with_env_filter(tracer::EnvFilter::from_default_env()).init()); .call_once(|| tracer::fmt().with_env_filter(tracer::EnvFilter::from_default_env()).init());