diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 0c80292f..66e051d8 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -79,7 +79,7 @@ where debug!("{:?}", conflicting); // The set of auth events that are not common across server forks - let mut auth_diff = get_auth_chain_diff(auth_chain_sets); + let mut auth_diff: HashSet<_> = get_auth_chain_diff(auth_chain_sets).collect(); // Add the auth_diff to conflicting now we have a full set of conflicting events auth_diff.extend(conflicting.values().cloned().flatten().flatten()); @@ -194,17 +194,17 @@ pub fn separate( } /// Returns a Vec of deduped EventIds that appear in some chains but not others. -pub fn get_auth_chain_diff(auth_chain_sets: Vec>) -> HashSet { - if let Some(first) = auth_chain_sets.first().cloned() { - let common = auth_chain_sets - .iter() - .skip(1) - .fold(first, |a, b| a.intersection(b).cloned().collect::>()); +pub fn get_auth_chain_diff( + auth_chain_sets: Vec>, +) -> impl Iterator { + let num_sets = auth_chain_sets.len(); - auth_chain_sets.into_iter().flatten().filter(|id| !common.contains(id)).collect() - } else { - HashSet::new() + let mut id_counts: HashMap = HashMap::new(); + for id in auth_chain_sets.into_iter().flatten() { + *id_counts.entry(id).or_default() += 1; } + + id_counts.into_iter().filter_map(move |(id, count)| (count < num_sets).then(move || id)) } /// Events are sorted from "earliest" to "latest".