diff --git a/crates/ruma-state-res/benches/state_res_bench.rs b/crates/ruma-state-res/benches/state_res_bench.rs index 707e4248..3cf27ca0 100644 --- a/crates/ruma-state-res/benches/state_res_bench.rs +++ b/crates/ruma-state-res/benches/state_res_bench.rs @@ -61,10 +61,10 @@ fn resolution_shallow_auth_chain(c: &mut Criterion) { b.iter(|| { let ev_map = store.0.clone(); - let state_sets = vec![state_at_bob.clone(), state_at_charlie.clone()]; + let state_sets = [&state_at_bob, &state_at_charlie]; let _ = match state_res::resolve( &RoomVersionId::Version6, - &state_sets, + state_sets, state_sets .iter() .map(|map| { @@ -125,10 +125,10 @@ fn resolve_deeper_event_set(c: &mut Criterion) { .collect::>(); b.iter(|| { - let state_sets = vec![state_set_a.clone(), state_set_b.clone()]; + let state_sets = [&state_set_a, &state_set_b]; let _ = match state_res::resolve( &RoomVersionId::Version6, - &state_sets, + state_sets, state_sets .iter() .map(|map| { diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index dce57abf..98c34526 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -51,20 +51,21 @@ type EventMap = HashMap; /// /// The caller of `resolve` must ensure that all the events are from the same room. Although this /// function takes a `RoomId` it does not check that each event is part of the same room. -pub fn resolve( +pub fn resolve<'a, E, F, SSI>( room_version: &RoomVersionId, - state_sets: &[StateMap], + state_sets: impl IntoIterator, auth_chain_sets: Vec>, fetch_event: F, ) -> Result> where E: Event + Clone, F: Fn(&EventId) -> Option, + SSI: Iterator> + Clone, { info!("State resolution starting"); // Split non-conflicting and conflicting state - let (clean, conflicting) = separate(state_sets); + let (clean, conflicting) = separate(state_sets.into_iter()); info!("non conflicting events: {}", clean.len()); trace!("{:?}", clean); @@ -158,15 +159,15 @@ where /// State is determined to be conflicting if for the given key (EventType, StateKey) there is not /// exactly one eventId. This includes missing events, if one state_set includes an event that none /// of the other have this is a conflicting event. -fn separate(state_sets: &[StateMap]) -> (StateMap, StateMap>) { - info!("separating {} sets of events into conflicted/unconflicted", state_sets.len()); - +fn separate<'a>( + state_sets_iter: impl Iterator> + Clone, +) -> (StateMap, StateMap>) { let mut unconflicted_state = StateMap::new(); let mut conflicted_state = StateMap::new(); - for key in state_sets.iter().flat_map(|map| map.keys()).unique() { + for key in state_sets_iter.clone().flat_map(|map| map.keys()).unique() { let mut event_ids = - state_sets.iter().map(|state_set| state_set.get(key)).collect::>(); + state_sets_iter.clone().map(|state_set| state_set.get(key)).collect::>(); if event_ids.iter().all_equal() { // First .unwrap() is okay because @@ -975,7 +976,7 @@ mod tests { let (state_at_bob, state_at_charlie, expected) = store.set_up(); let ev_map: EventMap> = store.0.clone(); - let state_sets = vec![state_at_bob, state_at_charlie]; + let state_sets = [state_at_bob, state_at_charlie]; let resolved = match crate::resolve( &RoomVersionId::Version2, &state_sets, @@ -1079,7 +1080,7 @@ mod tests { .collect::>(); let ev_map: EventMap> = store.0.clone(); - let state_sets = vec![state_set_a, state_set_b]; + let state_sets = [state_set_a, state_set_b]; let resolved = match crate::resolve( &RoomVersionId::Version6, &state_sets, diff --git a/crates/ruma-state-res/src/test_utils.rs b/crates/ruma-state-res/src/test_utils.rs index 5a68cfaf..21532923 100644 --- a/crates/ruma-state-res/src/test_utils.rs +++ b/crates/ruma-state-res/src/test_utils.rs @@ -89,11 +89,8 @@ pub fn do_check( } else if prev_events.len() == 1 { state_at_event.get(prev_events.iter().next().unwrap()).unwrap().clone() } else { - let state_sets = prev_events - .iter() - .filter_map(|k| state_at_event.get(k)) - .cloned() - .collect::>(); + let state_sets = + prev_events.iter().filter_map(|k| state_at_event.get(k)).collect::>(); info!( "{:#?}", @@ -106,17 +103,17 @@ pub fn do_check( .collect::>() ); - let resolved = crate::resolve( - &RoomVersionId::Version6, - &state_sets, - state_sets - .iter() - .map(|map| { - store.auth_event_ids(&room_id(), map.values().cloned().collect()).unwrap() - }) - .collect(), - |id| event_map.get(id).map(Arc::clone), - ); + let auth_chain_sets = state_sets + .iter() + .map(|map| { + store.auth_event_ids(&room_id(), map.values().cloned().collect()).unwrap() + }) + .collect(); + + let resolved = + crate::resolve(&RoomVersionId::Version6, state_sets, auth_chain_sets, |id| { + event_map.get(id).map(Arc::clone) + }); match resolved { Ok(state) => state, Err(e) => panic!("resolution for {} failed: {}", node, e),