diff --git a/crates/ruma-state-res/CHANGELOG.md b/crates/ruma-state-res/CHANGELOG.md index f5c4bb58..2d16eaa0 100644 --- a/crates/ruma-state-res/CHANGELOG.md +++ b/crates/ruma-state-res/CHANGELOG.md @@ -1,5 +1,9 @@ # [unreleased] +Breaking changes: + +* state_res::resolve now doesn't take auth_events anymore and calculates it on its own instead + # 0.2.0 Breaking changes: diff --git a/crates/ruma-state-res/benches/state_res_bench.rs b/crates/ruma-state-res/benches/state_res_bench.rs index 6d5a7438..22980ac5 100644 --- a/crates/ruma-state-res/benches/state_res_bench.rs +++ b/crates/ruma-state-res/benches/state_res_bench.rs @@ -66,15 +66,7 @@ fn resolution_shallow_auth_chain(c: &mut Criterion) { &room_id(), &RoomVersionId::Version6, &state_sets, - state_sets - .iter() - .map(|map| { - store - .auth_event_ids(&room_id(), &map.values().cloned().collect::>()) - .unwrap() - }) - .collect(), - &|id| ev_map.get(id).map(Arc::clone), + |id| ev_map.get(id).map(Arc::clone), ) { Ok(state) => state, Err(e) => panic!("{}", e), @@ -89,7 +81,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) { let ban = BAN_STATE_SET(); inner.extend(ban); - let store = TestStore(inner.clone()); + let _store = TestStore(inner.clone()); let state_set_a = [ inner.get(&event_id("CREATE")).unwrap(), @@ -123,15 +115,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) { &room_id(), &RoomVersionId::Version6, &state_sets, - state_sets - .iter() - .map(|map| { - store - .auth_event_ids(&room_id(), &map.values().cloned().collect::>()) - .unwrap() - }) - .collect(), - &|id| inner.get(id).map(Arc::clone), + |id| inner.get(id).map(Arc::clone), ) { Ok(state) => state, Err(_) => panic!("resolution failed during benchmarking"), diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 6e7620ea..ce2358f2 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -4,6 +4,7 @@ use std::{ sync::Arc, }; +use itertools::Itertools; use maplit::btreeset; use ruma_common::MilliSecondsSinceUnixEpoch; use ruma_events::{ @@ -14,7 +15,7 @@ use ruma_events::{ EventType, }; use ruma_identifiers::{EventId, RoomId, RoomVersionId}; -use tracing::{debug, info, warn}; +use tracing::{debug, info, trace, warn}; mod error; pub mod event_auth; @@ -45,9 +46,6 @@ impl StateResolution { /// * `state_sets` - The incoming state to resolve. Each `StateMap` represents a possible fork /// in the state of a room. /// - /// * `auth_events` - The full recursive set of `auth_events` for each event in the - /// `state_sets`. - /// /// * `fetch_event` - Any event not found in the `event_map` will defer to this closure to find /// the event. /// @@ -60,7 +58,6 @@ impl StateResolution { room_id: &RoomId, room_version: &RoomVersionId, state_sets: &[StateMap], - auth_events: Vec>, fetch_event: F, ) -> Result> where @@ -72,22 +69,43 @@ impl StateResolution { // Split non-conflicting and conflicting state let (clean, conflicting) = StateResolution::separate(state_sets); - info!("non conflicting {:?}", clean.len()); + info!("non conflicting events: {}", clean.len()); + trace!("{:?}", clean); if conflicting.is_empty() { info!("no conflicting state found"); return Ok(clean); } - info!("{} conflicting events", conflicting.len()); + info!("conflicting events: {}", conflicting.len()); + debug!("{:?}", conflicting); + + let mut iter = conflicting.values(); + let mut conflicting_state_sets = iter + .next() + .expect("we made sure conflicting is not empty") + .iter() + .map(|o| if let Some(e) = o { btreeset![e.clone()] } else { BTreeSet::new() }) + .collect::>(); + + for events in iter { + for i in 0..events.len() { + // This is okay because all vecs have the same length = number of states + if let Some(e) = &events[i] { + conflicting_state_sets[i].insert(e.clone()); + } + } + } // The set of auth events that are not common across server forks - let mut auth_diff = StateResolution::get_auth_chain_diff(room_id, &auth_events)?; - - debug!("auth diff size {:?}", auth_diff); + let mut auth_diff = + StateResolution::get_auth_chain_diff(room_id, &conflicting_state_sets, &fetch_event)?; // Add the auth_diff to conflicting now we have a full set of conflicting events - auth_diff.extend(conflicting.values().cloned().flatten()); + auth_diff.extend(conflicting.values().cloned().flatten().filter_map(|o| o)); + + debug!("auth diff: {}", auth_diff.len()); + trace!("{:?}", auth_diff); // `all_conflicted` contains unique items // synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}` @@ -97,7 +115,8 @@ impl StateResolution { let all_conflicted = auth_diff.into_iter().filter(|id| fetch_event(id).is_some()).collect::>(); - info!("full conflicted set is {} events", all_conflicted.len()); + info!("full conflicted set: {}", all_conflicted.len()); + debug!("{:?}", all_conflicted); // We used to check that all events are events from the correct room // this is now a check the caller of `resolve` must make. @@ -116,7 +135,8 @@ impl StateResolution { &fetch_event, ); - debug!("SRTD {:?}", sorted_control_levels); + debug!("sorted control events: {}", sorted_control_levels.len()); + trace!("{:?}", sorted_control_levels); let room_version = RoomVersion::new(room_version)?; // Sequentially auth check each control event. @@ -127,7 +147,8 @@ impl StateResolution { &fetch_event, )?; - debug!("AUTHED {:?}", resolved_control.iter().collect::>()); + debug!("resolved control events: {}", resolved_control.len()); + trace!("{:?}", resolved_control); // At this point the control_events have been resolved we now have to // sort the remaining events using the mainline of the resolved power level. @@ -141,17 +162,18 @@ impl StateResolution { .cloned() .collect::>(); - debug!("LEFT {:?}", events_to_resolve.iter().collect::>()); + debug!("events left to resolve: {}", events_to_resolve.len()); + trace!("{:?}", events_to_resolve); // This "epochs" power level event let power_event = resolved_control.get(&(EventType::RoomPowerLevels, "".into())); - debug!("PL {:?}", power_event); + debug!("power event: {:?}", power_event); let sorted_left_events = StateResolution::mainline_sort(&events_to_resolve, power_event, &fetch_event); - debug!("SORTED LEFT {:?}", sorted_left_events.iter().collect::>()); + trace!("events left, sorted: {:?}", sorted_left_events.iter().collect::>()); let mut resolved_state = StateResolution::iterative_auth_check( &room_version, @@ -174,9 +196,7 @@ impl StateResolution { /// that none of the other have this is a conflicting event. pub fn separate( state_sets: &[StateMap], - ) -> (StateMap, StateMap>) { - use itertools::Itertools; - + ) -> (StateMap, StateMap>>) { info!("separating {} sets of events into conflicted/unconflicted", state_sets.len()); let mut unconflicted_state = StateMap::new(); @@ -184,16 +204,14 @@ impl StateResolution { for key in state_sets.iter().flat_map(|map| map.keys()).unique() { let mut event_ids = - state_sets.iter().map(|state_set| state_set.get(key)).unique().collect::>(); + state_sets.iter().map(|state_set| state_set.get(key)).collect::>(); - if event_ids.len() == 1 { + if event_ids.iter().all_equal() { let id = event_ids.remove(0).expect("unconflicting `EventId` is not None"); unconflicted_state.insert(key.clone(), id.clone()); } else { - conflicted_state.insert( - key.clone(), - event_ids.into_iter().flatten().cloned().collect::>(), - ); + conflicted_state + .insert(key.clone(), event_ids.into_iter().map(|o| o.cloned()).collect()); } } @@ -201,22 +219,50 @@ impl StateResolution { } /// Returns a Vec of deduped EventIds that appear in some chains but not others. - pub fn get_auth_chain_diff( + pub fn get_auth_chain_diff( _room_id: &RoomId, - auth_event_ids: &[Vec], - ) -> Result> { + conflicting_state_sets: &[BTreeSet], + fetch_event: F, + ) -> Result> + where + E: Event, + F: Fn(&EventId) -> Option>, + { let mut chains = vec![]; - for ids in auth_event_ids { + // Conflicted state sets are just some top level state events. Now we fetch the complete + // auth chain of those events + for ids in conflicting_state_sets { // TODO state store `auth_event_ids` returns self in the event ids list // when an event returns `auth_event_ids` self is not contained - let chain = ids.iter().cloned().collect::>(); - chains.push(chain); + let mut todo = ids.iter().map(|e| e.clone()).collect::>(); + let mut auth_chain_ids = ids.clone(); // we also return the events we started with + + while let Some(event_id) = todo.iter().next().cloned() { + if let Some(pdu) = fetch_event(&event_id) { + todo.extend( + pdu.auth_events() + .clone() + .into_iter() + .collect::>() + .difference(&auth_chain_ids) + .cloned(), + ); + auth_chain_ids.extend(pdu.auth_events().into_iter()); + } else { + warn!("Could not find pdu mentioned in auth events."); + } + + todo.remove(&event_id); + } + + chains.push(auth_chain_ids); } - if let Some(chain) = chains.first().cloned() { - let rest = chains.iter().skip(1).flatten().cloned().collect(); - let common = chain.intersection(&rest).collect::>(); + if let Some(first) = chains.first().cloned() { + let common = chains + .iter() + .fold(first, |a, b| a.intersection(&b).cloned().collect::>()); Ok(chains.into_iter().flatten().filter(|id| !common.contains(&id)).collect()) } else { diff --git a/crates/ruma-state-res/tests/res_with_auth_ids.rs b/crates/ruma-state-res/tests/res_with_auth_ids.rs index 10be20ea..b1500eac 100644 --- a/crates/ruma-state-res/tests/res_with_auth_ids.rs +++ b/crates/ruma-state-res/tests/res_with_auth_ids.rs @@ -69,14 +69,6 @@ fn ban_with_auth_chains2() { &room_id(), &RoomVersionId::Version6, &state_sets, - state_sets - .iter() - .map(|map| { - store - .auth_event_ids(&room_id(), &map.values().cloned().collect::>()) - .unwrap() - }) - .collect(), |id| ev_map.get(id).map(Arc::clone), ) { Ok(state) => state, diff --git a/crates/ruma-state-res/tests/state_res.rs b/crates/ruma-state-res/tests/state_res.rs index fdfb13dc..b7463b59 100644 --- a/crates/ruma-state-res/tests/state_res.rs +++ b/crates/ruma-state-res/tests/state_res.rs @@ -258,14 +258,6 @@ fn test_event_map_none() { &room_id(), &RoomVersionId::Version2, &state_sets, - state_sets - .iter() - .map(|map| { - store - .auth_event_ids(&room_id(), &map.values().cloned().collect::>()) - .unwrap() - }) - .collect(), |id| ev_map.get(id).map(Arc::clone), ) { Ok(state) => state, diff --git a/crates/ruma-state-res/tests/utils.rs b/crates/ruma-state-res/tests/utils.rs index 22074958..a31f4a4d 100644 --- a/crates/ruma-state-res/tests/utils.rs +++ b/crates/ruma-state-res/tests/utils.rs @@ -109,20 +109,10 @@ pub fn do_check( .collect::>() ); - let resolved = StateResolution::resolve( - &room_id(), - &RoomVersionId::Version6, - &state_sets, - state_sets - .iter() - .map(|map| { - store - .auth_event_ids(&room_id(), &map.values().cloned().collect::>()) - .unwrap() - }) - .collect(), - |id| event_map.get(id).map(Arc::clone), - ); + let resolved = + StateResolution::resolve(&room_id(), &RoomVersionId::Version6, &state_sets, |id| { + event_map.get(id).map(Arc::clone) + }); match resolved { Ok(state) => state, Err(e) => panic!("resolution for {} failed: {}", node, e),