state-res: Make the API generic over the event ID storage

This commit is contained in:
Jonas Platte 2021-11-27 18:47:17 +01:00
parent a9c12f0909
commit 16f031fabb
No known key found for this signature in database
GPG Key ID: 7D261D771D915378
5 changed files with 138 additions and 119 deletions

View File

@ -8,6 +8,7 @@
#![allow(clippy::exhaustive_structs)] #![allow(clippy::exhaustive_structs)]
use std::{ use std::{
borrow::Borrow,
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
convert::TryInto, convert::TryInto,
sync::{ sync::{
@ -182,11 +183,7 @@ impl<E: Event> TestStore<E> {
} }
/// Returns a Vec of the related auth events to the given `event`. /// Returns a Vec of the related auth events to the given `event`.
fn auth_event_ids( fn auth_event_ids(&self, room_id: &RoomId, event_ids: Vec<E::Id>) -> Result<HashSet<E::Id>> {
&self,
room_id: &RoomId,
event_ids: Vec<Box<EventId>>,
) -> Result<HashSet<Box<EventId>>> {
let mut result = HashSet::new(); let mut result = HashSet::new();
let mut stack = event_ids; let mut stack = event_ids;
@ -199,7 +196,7 @@ impl<E: Event> TestStore<E> {
result.insert(ev_id.clone()); result.insert(ev_id.clone());
let event = self.get_event(room_id, &ev_id)?; let event = self.get_event(room_id, ev_id.borrow())?;
stack.extend(event.auth_events().map(ToOwned::to_owned)); stack.extend(event.auth_events().map(ToOwned::to_owned));
} }
@ -207,13 +204,8 @@ impl<E: Event> TestStore<E> {
Ok(result) Ok(result)
} }
/// Returns a Vec<Box<EventId>> representing the difference in auth chains of the given /// Returns a vector representing the difference in auth chains of the given `events`.
/// `events`. fn auth_chain_diff(&self, room_id: &RoomId, event_ids: Vec<Vec<E::Id>>) -> Result<Vec<E::Id>> {
fn auth_chain_diff(
&self,
room_id: &RoomId,
event_ids: Vec<Vec<Box<EventId>>>,
) -> Result<Vec<Box<EventId>>> {
let mut auth_chain_sets = vec![]; let mut auth_chain_sets = vec![];
for ids in event_ids { for ids in event_ids {
// TODO state store `auth_event_ids` returns self in the event ids list // TODO state store `auth_event_ids` returns self in the event ids list
@ -226,9 +218,13 @@ impl<E: Event> TestStore<E> {
let common = auth_chain_sets let common = auth_chain_sets
.iter() .iter()
.skip(1) .skip(1)
.fold(first, |a, b| a.intersection(b).cloned().collect::<HashSet<Box<EventId>>>()); .fold(first, |a, b| a.intersection(b).cloned().collect::<HashSet<_>>());
Ok(auth_chain_sets.into_iter().flatten().filter(|id| !common.contains(id)).collect()) Ok(auth_chain_sets
.into_iter()
.flatten()
.filter(|id| !common.contains(id.borrow()))
.collect())
} else { } else {
Ok(vec![]) Ok(vec![])
} }
@ -546,7 +542,9 @@ mod event {
use serde_json::value::RawValue as RawJsonValue; use serde_json::value::RawValue as RawJsonValue;
impl Event for StateEvent { impl Event for StateEvent {
fn event_id(&self) -> &EventId { type Id = Box<EventId>;
fn event_id(&self) -> &Self::Id {
&self.event_id &self.event_id
} }
@ -604,28 +602,28 @@ mod event {
} }
} }
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &EventId> + '_> { fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
match &self.rest { match &self.rest {
Pdu::RoomV1Pdu(ev) => Box::new(ev.prev_events.iter().map(|(id, _)| &**id)), Pdu::RoomV1Pdu(ev) => Box::new(ev.prev_events.iter().map(|(id, _)| id)),
Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter().map(|id| &**id)), Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter()),
#[cfg(not(feature = "unstable-exhaustive-types"))] #[cfg(not(feature = "unstable-exhaustive-types"))]
_ => unreachable!("new PDU version"), _ => unreachable!("new PDU version"),
} }
} }
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &EventId> + '_> { fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
match &self.rest { match &self.rest {
Pdu::RoomV1Pdu(ev) => Box::new(ev.auth_events.iter().map(|(id, _)| &**id)), Pdu::RoomV1Pdu(ev) => Box::new(ev.auth_events.iter().map(|(id, _)| id)),
Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter().map(|id| &**id)), Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter()),
#[cfg(not(feature = "unstable-exhaustive-types"))] #[cfg(not(feature = "unstable-exhaustive-types"))]
_ => unreachable!("new PDU version"), _ => unreachable!("new PDU version"),
} }
} }
fn redacts(&self) -> Option<&EventId> { fn redacts(&self) -> Option<&Self::Id> {
match &self.rest { match &self.rest {
Pdu::RoomV1Pdu(ev) => ev.redacts.as_deref(), Pdu::RoomV1Pdu(ev) => ev.redacts.as_ref(),
Pdu::RoomV3Pdu(ev) => ev.redacts.as_deref(), Pdu::RoomV3Pdu(ev) => ev.redacts.as_ref(),
#[cfg(not(feature = "unstable-exhaustive-types"))] #[cfg(not(feature = "unstable-exhaustive-types"))]
_ => unreachable!("new PDU version"), _ => unreachable!("new PDU version"),
} }

View File

@ -1,4 +1,4 @@
use std::{collections::BTreeSet, convert::TryFrom}; use std::{borrow::Borrow, collections::BTreeSet, convert::TryFrom};
use js_int::{int, Int}; use js_int::{int, Int};
use ruma_events::{ use ruma_events::{
@ -749,8 +749,8 @@ fn check_redaction(
// If the domain of the event_id of the event being redacted is the same as the // If the domain of the event_id of the event being redacted is the same as the
// domain of the event_id of the m.room.redaction, allow // domain of the event_id of the m.room.redaction, allow
if redaction_event.event_id().server_name() if redaction_event.event_id().borrow().server_name()
== redaction_event.redacts().as_ref().and_then(|id| id.server_name()) == redaction_event.redacts().as_ref().and_then(|&id| id.borrow().server_name())
{ {
info!("redaction event allowed via room version 1 rules"); info!("redaction event allowed via room version 1 rules");
return Ok(true); return Ok(true);

View File

@ -1,6 +1,8 @@
use std::{ use std::{
borrow::Borrow,
cmp::Reverse, cmp::Reverse,
collections::{BTreeMap, BinaryHeap, HashMap, HashSet}, collections::{BTreeMap, BinaryHeap, HashMap, HashSet},
hash::Hash,
}; };
use itertools::Itertools; use itertools::Itertools;
@ -53,12 +55,13 @@ pub type StateMap<T> = HashMap<(EventType, String), T>;
pub fn resolve<'a, E, SetIter>( pub fn resolve<'a, E, SetIter>(
room_version: &RoomVersionId, room_version: &RoomVersionId,
state_sets: impl IntoIterator<IntoIter = SetIter>, state_sets: impl IntoIterator<IntoIter = SetIter>,
auth_chain_sets: Vec<HashSet<Box<EventId>>>, auth_chain_sets: Vec<HashSet<E::Id>>,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<Box<EventId>>> ) -> Result<StateMap<E::Id>>
where where
E: Event + Clone, E: Event + Clone,
SetIter: Iterator<Item = &'a StateMap<Box<EventId>>> + Clone, E::Id: 'a,
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone,
{ {
info!("State resolution starting"); info!("State resolution starting");
@ -82,7 +85,7 @@ where
// FIXME: Use into_values() once MSRV >= 1.54 // FIXME: Use into_values() once MSRV >= 1.54
.chain(conflicting.into_iter().flat_map(|(_k, v)| v)) .chain(conflicting.into_iter().flat_map(|(_k, v)| v))
// Don't honor events we cannot "verify" // Don't honor events we cannot "verify"
.filter(|id| fetch_event(id).is_some()) .filter(|id| fetch_event(id.borrow()).is_some())
.collect(); .collect();
info!("full conflicted set: {}", all_conflicted.len()); info!("full conflicted set: {}", all_conflicted.len());
@ -94,7 +97,7 @@ where
// Get only the control events with a state_key: "" or ban/kick event (sender != state_key) // Get only the control events with a state_key: "" or ban/kick event (sender != state_key)
let control_events = all_conflicted let control_events = all_conflicted
.iter() .iter()
.filter(|id| is_power_event_id(id, &fetch_event)) .filter(|&id| is_power_event_id(id.borrow(), &fetch_event))
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -121,7 +124,7 @@ where
// auth // auth
let events_to_resolve = all_conflicted let events_to_resolve = all_conflicted
.iter() .iter()
.filter(|&id| !deduped_power_ev.contains(id)) .filter(|&id| !deduped_power_ev.contains(id.borrow()))
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -133,8 +136,7 @@ where
debug!("power event: {:?}", power_event); debug!("power event: {:?}", power_event);
let sorted_left_events = let sorted_left_events = mainline_sort(&events_to_resolve, power_event.cloned(), &fetch_event)?;
mainline_sort(&events_to_resolve, power_event.map(|id| &**id), &fetch_event)?;
trace!("events left, sorted: {:?}", sorted_left_events); trace!("events left, sorted: {:?}", sorted_left_events);
@ -158,9 +160,12 @@ where
/// State is determined to be conflicting if for the given key (EventType, StateKey) there is not /// State is determined to be conflicting if for the given key (EventType, StateKey) there is not
/// exactly one eventId. This includes missing events, if one state_set includes an event that none /// exactly one eventId. This includes missing events, if one state_set includes an event that none
/// of the other have this is a conflicting event. /// of the other have this is a conflicting event.
fn separate<'a>( fn separate<'a, Id>(
state_sets_iter: impl Iterator<Item = &'a StateMap<Box<EventId>>> + Clone, state_sets_iter: impl Iterator<Item = &'a StateMap<Id>> + Clone,
) -> (StateMap<Box<EventId>>, StateMap<Vec<Box<EventId>>>) { ) -> (StateMap<Id>, StateMap<Vec<Id>>)
where
Id: Clone + Eq + 'a,
{
let mut unconflicted_state = StateMap::new(); let mut unconflicted_state = StateMap::new();
let mut conflicted_state = StateMap::new(); let mut conflicted_state = StateMap::new();
@ -184,12 +189,13 @@ fn separate<'a>(
} }
/// Returns a Vec of deduped EventIds that appear in some chains but not others. /// Returns a Vec of deduped EventIds that appear in some chains but not others.
fn get_auth_chain_diff( fn get_auth_chain_diff<Id>(auth_chain_sets: Vec<HashSet<Id>>) -> impl Iterator<Item = Id>
auth_chain_sets: Vec<HashSet<Box<EventId>>>, where
) -> impl Iterator<Item = Box<EventId>> { Id: Eq + Hash,
{
let num_sets = auth_chain_sets.len(); let num_sets = auth_chain_sets.len();
let mut id_counts: HashMap<Box<EventId>, usize> = HashMap::new(); let mut id_counts: HashMap<Id, usize> = HashMap::new();
for id in auth_chain_sets.into_iter().flatten() { for id in auth_chain_sets.into_iter().flatten() {
*id_counts.entry(id).or_default() += 1; *id_counts.entry(id).or_default() += 1;
} }
@ -205,10 +211,10 @@ fn get_auth_chain_diff(
/// The power level is negative because a higher power level is equated to an earlier (further back /// The power level is negative because a higher power level is equated to an earlier (further back
/// in time) origin server timestamp. /// in time) origin server timestamp.
fn reverse_topological_power_sort<E: Event>( fn reverse_topological_power_sort<E: Event>(
events_to_sort: Vec<Box<EventId>>, events_to_sort: Vec<E::Id>,
auth_diff: &HashSet<Box<EventId>>, auth_diff: &HashSet<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<Box<EventId>>> { ) -> Result<Vec<E::Id>> {
debug!("reverse topological sort of power events"); debug!("reverse topological sort of power events");
let mut graph = HashMap::new(); let mut graph = HashMap::new();
@ -223,7 +229,7 @@ fn reverse_topological_power_sort<E: Event>(
// This is used in the `key_fn` passed to the lexico_topo_sort fn // This is used in the `key_fn` passed to the lexico_topo_sort fn
let mut event_to_pl = HashMap::new(); let mut event_to_pl = HashMap::new();
for event_id in graph.keys() { for event_id in graph.keys() {
let pl = get_power_level_for_sender(event_id, &fetch_event)?; let pl = get_power_level_for_sender(event_id.borrow(), &fetch_event)?;
info!("{} power level {}", event_id, pl); info!("{} power level {}", event_id, pl);
event_to_pl.insert(event_id.clone(), pl); event_to_pl.insert(event_id.clone(), pl);
@ -244,18 +250,19 @@ fn reverse_topological_power_sort<E: Event>(
/// ///
/// `key_fn` is used as to obtain the power level and age of an event for breaking ties (together /// `key_fn` is used as to obtain the power level and age of an event for breaking ties (together
/// with the event ID). /// with the event ID).
pub fn lexicographical_topological_sort<F>( pub fn lexicographical_topological_sort<Id, F>(
graph: &HashMap<Box<EventId>, HashSet<Box<EventId>>>, graph: &HashMap<Id, HashSet<Id>>,
key_fn: F, key_fn: F,
) -> Result<Vec<Box<EventId>>> ) -> Result<Vec<Id>>
where where
F: Fn(&EventId) -> Result<(Int, MilliSecondsSinceUnixEpoch)>, F: Fn(&EventId) -> Result<(Int, MilliSecondsSinceUnixEpoch)>,
Id: Clone + Eq + Ord + Hash + Borrow<EventId>,
{ {
#[derive(PartialEq, Eq, PartialOrd, Ord)] #[derive(PartialEq, Eq, PartialOrd, Ord)]
struct TieBreaker<'a> { struct TieBreaker<'a, Id> {
inv_power_level: Int, inv_power_level: Int,
age: MilliSecondsSinceUnixEpoch, age: MilliSecondsSinceUnixEpoch,
event_id: &'a EventId, event_id: &'a Id,
} }
info!("starting lexicographical topological sort"); info!("starting lexicographical topological sort");
@ -272,14 +279,14 @@ where
// The number of events that depend on the given event (the EventId key) // The number of events that depend on the given event (the EventId key)
// How many events reference this event in the DAG as a parent // How many events reference this event in the DAG as a parent
let mut reverse_graph: HashMap<&EventId, HashSet<&EventId>> = HashMap::new(); let mut reverse_graph: HashMap<_, HashSet<_>> = HashMap::new();
// Vec of nodes that have zero out degree, least recent events. // Vec of nodes that have zero out degree, least recent events.
let mut zero_outdegree: Vec<Reverse<TieBreaker<'_>>> = vec![]; let mut zero_outdegree = Vec::new();
for (node, edges) in graph { for (node, edges) in graph {
if edges.is_empty() { if edges.is_empty() {
let (power_level, age) = key_fn(node)?; let (power_level, age) = key_fn(node.borrow())?;
// The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need // The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need
// smallest -> largest // smallest -> largest
zero_outdegree.push(Reverse(TieBreaker { zero_outdegree.push(Reverse(TieBreaker {
@ -306,13 +313,13 @@ where
for &parent in reverse_graph.get(node).expect("EventId in heap is also in reverse_graph") { for &parent in reverse_graph.get(node).expect("EventId in heap is also in reverse_graph") {
// The number of outgoing edges this node has // The number of outgoing edges this node has
let out = outdegree_map let out = outdegree_map
.get_mut(parent) .get_mut(parent.borrow())
.expect("outdegree_map knows of all referenced EventIds"); .expect("outdegree_map knows of all referenced EventIds");
// Only push on the heap once older events have been cleared // Only push on the heap once older events have been cleared
out.remove(node); out.remove(node.borrow());
if out.is_empty() { if out.is_empty() {
let (power_level, age) = key_fn(node)?; let (power_level, age) = key_fn(node.borrow())?;
heap.push(Reverse(TieBreaker { heap.push(Reverse(TieBreaker {
inv_power_level: -power_level, inv_power_level: -power_level,
age, age,
@ -322,7 +329,7 @@ where
} }
// synapse yields we push then return the vec // synapse yields we push then return the vec
sorted.push(node.to_owned()); sorted.push(node.clone());
} }
Ok(sorted) Ok(sorted)
@ -357,7 +364,7 @@ fn get_power_level_for_sender<E: Event>(
let mut pl = None; let mut pl = None;
for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() { for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() {
if let Some(aev) = fetch_event(aid) { if let Some(aev) = fetch_event(aid.borrow()) {
if is_type_and_key(&aev, &EventType::RoomPowerLevels, "") { if is_type_and_key(&aev, &EventType::RoomPowerLevels, "") {
pl = Some(aev); pl = Some(aev);
break; break;
@ -385,16 +392,16 @@ fn get_power_level_for_sender<E: Event>(
/// ## Returns /// ## Returns
/// ///
/// The `unconflicted_state` combined with the newly auth'ed events. So any event that fails the /// The `unconflicted_state` combined with the newly auth'ed events. So any event that fails the
/// `event_auth::auth_check` will be excluded from the returned `StateMap<Box<EventId>>`. /// `event_auth::auth_check` will be excluded from the returned state map.
/// ///
/// For each `events_to_check` event we gather the events needed to auth it from the the /// For each `events_to_check` event we gather the events needed to auth it from the the
/// `fetch_event` closure and verify each event using the `event_auth::auth_check` function. /// `fetch_event` closure and verify each event using the `event_auth::auth_check` function.
fn iterative_auth_check<E: Event + Clone>( fn iterative_auth_check<E: Event + Clone>(
room_version: &RoomVersion, room_version: &RoomVersion,
events_to_check: &[Box<EventId>], events_to_check: &[E::Id],
unconflicted_state: StateMap<Box<EventId>>, unconflicted_state: StateMap<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<Box<EventId>>> { ) -> Result<StateMap<E::Id>> {
info!("starting iterative auth check"); info!("starting iterative auth check");
debug!("performing auth checks on {:?}", events_to_check); debug!("performing auth checks on {:?}", events_to_check);
@ -402,7 +409,7 @@ fn iterative_auth_check<E: Event + Clone>(
let mut resolved_state = unconflicted_state; let mut resolved_state = unconflicted_state;
for event_id in events_to_check { for event_id in events_to_check {
let event = fetch_event(event_id) let event = fetch_event(event_id.borrow())
.ok_or_else(|| Error::NotFound(format!("Failed to find {}", event_id)))?; .ok_or_else(|| Error::NotFound(format!("Failed to find {}", event_id)))?;
let state_key = event let state_key = event
.state_key() .state_key()
@ -410,7 +417,7 @@ fn iterative_auth_check<E: Event + Clone>(
let mut auth_events = HashMap::new(); let mut auth_events = HashMap::new();
for aid in event.auth_events() { for aid in event.auth_events() {
if let Some(ev) = fetch_event(aid) { if let Some(ev) = fetch_event(aid.borrow()) {
// TODO synapse check "rejected_reason" which is most likely // TODO synapse check "rejected_reason" which is most likely
// related to soft-failing // related to soft-failing
auth_events.insert( auth_events.insert(
@ -436,7 +443,7 @@ fn iterative_auth_check<E: Event + Clone>(
event.content(), event.content(),
)? { )? {
if let Some(ev_id) = resolved_state.get(&key) { if let Some(ev_id) = resolved_state.get(&key) {
if let Some(event) = fetch_event(ev_id) { if let Some(event) = fetch_event(ev_id.borrow()) {
// TODO synapse checks `rejected_reason` is None here // TODO synapse checks `rejected_reason` is None here
auth_events.insert(key.to_owned(), event); auth_events.insert(key.to_owned(), event);
} }
@ -447,7 +454,7 @@ fn iterative_auth_check<E: Event + Clone>(
#[allow(clippy::redundant_closure)] #[allow(clippy::redundant_closure)]
let most_recent_prev_event = let most_recent_prev_event =
event.prev_events().filter_map(|id| fetch_event(id)).next_back(); event.prev_events().filter_map(|id| fetch_event(id.borrow())).next_back();
// The key for this is (eventType + a state_key of the signed token not sender) so // The key for this is (eventType + a state_key of the signed token not sender) so
// search for it // search for it
@ -488,10 +495,10 @@ fn iterative_auth_check<E: Event + Clone>(
/// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is /// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is
/// "older" than depth 0. /// "older" than depth 0.
fn mainline_sort<E: Event>( fn mainline_sort<E: Event>(
to_sort: &[Box<EventId>], to_sort: &[E::Id],
resolved_power_level: Option<&EventId>, resolved_power_level: Option<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<Box<EventId>>> { ) -> Result<Vec<E::Id>> {
debug!("mainline sort of events"); debug!("mainline sort of events");
// There are no EventId's to sort, bail. // There are no EventId's to sort, bail.
@ -500,15 +507,15 @@ fn mainline_sort<E: Event>(
} }
let mut mainline = vec![]; let mut mainline = vec![];
let mut pl = resolved_power_level.map(ToOwned::to_owned); let mut pl = resolved_power_level;
while let Some(p) = pl { while let Some(p) = pl {
mainline.push(p.clone()); mainline.push(p.clone());
let event = let event = fetch_event(p.borrow())
fetch_event(&p).ok_or_else(|| Error::NotFound(format!("Failed to find {}", p)))?; .ok_or_else(|| Error::NotFound(format!("Failed to find {}", p)))?;
pl = None; pl = None;
for aid in event.auth_events() { for aid in event.auth_events() {
let ev = fetch_event(aid) let ev = fetch_event(aid.borrow())
.ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?; .ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?;
if is_type_and_key(&ev, &EventType::RoomPowerLevels, "") { if is_type_and_key(&ev, &EventType::RoomPowerLevels, "") {
pl = Some(aid.to_owned()); pl = Some(aid.to_owned());
@ -529,11 +536,11 @@ fn mainline_sort<E: Event>(
let mut order_map = HashMap::new(); let mut order_map = HashMap::new();
for ev_id in to_sort.iter() { for ev_id in to_sort.iter() {
if let Some(event) = fetch_event(ev_id) { if let Some(event) = fetch_event(ev_id.borrow()) {
if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, &fetch_event) { if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, &fetch_event) {
order_map.insert( order_map.insert(
ev_id, ev_id,
(depth, fetch_event(ev_id).map(|ev| ev.origin_server_ts()), ev_id), (depth, fetch_event(ev_id.borrow()).map(|ev| ev.origin_server_ts()), ev_id),
); );
} }
} }
@ -555,19 +562,19 @@ fn mainline_sort<E: Event>(
/// associated mainline depth. /// associated mainline depth.
fn get_mainline_depth<E: Event>( fn get_mainline_depth<E: Event>(
mut event: Option<E>, mut event: Option<E>,
mainline_map: &HashMap<Box<EventId>, usize>, mainline_map: &HashMap<E::Id, usize>,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<usize> { ) -> Result<usize> {
while let Some(sort_ev) = event { while let Some(sort_ev) = event {
debug!("mainline event_id {}", sort_ev.event_id()); debug!("mainline event_id {}", sort_ev.event_id());
let id = sort_ev.event_id(); let id = sort_ev.event_id();
if let Some(depth) = mainline_map.get(id) { if let Some(depth) = mainline_map.get(id.borrow()) {
return Ok(*depth); return Ok(*depth);
} }
event = None; event = None;
for aid in sort_ev.auth_events() { for aid in sort_ev.auth_events() {
let aev = fetch_event(aid) let aev = fetch_event(aid.borrow())
.ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?; .ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?;
if is_type_and_key(&aev, &EventType::RoomPowerLevels, "") { if is_type_and_key(&aev, &EventType::RoomPowerLevels, "") {
event = Some(aev); event = Some(aev);
@ -580,23 +587,25 @@ fn get_mainline_depth<E: Event>(
} }
fn add_event_and_auth_chain_to_graph<E: Event>( fn add_event_and_auth_chain_to_graph<E: Event>(
graph: &mut HashMap<Box<EventId>, HashSet<Box<EventId>>>, graph: &mut HashMap<E::Id, HashSet<E::Id>>,
event_id: Box<EventId>, event_id: E::Id,
auth_diff: &HashSet<Box<EventId>>, auth_diff: &HashSet<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: impl Fn(&EventId) -> Option<E>,
) { ) {
let mut state = vec![event_id]; let mut state = vec![event_id];
while let Some(eid) = state.pop() { while let Some(eid) = state.pop() {
graph.entry(eid.clone()).or_default(); graph.entry(eid.clone()).or_default();
// Prefer the store to event as the store filters dedups the events // Prefer the store to event as the store filters dedups the events
for aid in fetch_event(&eid).as_ref().map(|ev| ev.auth_events()).into_iter().flatten() { for aid in
if auth_diff.contains(aid) { fetch_event(eid.borrow()).as_ref().map(|ev| ev.auth_events()).into_iter().flatten()
if !graph.contains_key(aid) { {
if auth_diff.contains(aid.borrow()) {
if !graph.contains_key(aid.borrow()) {
state.push(aid.to_owned()); state.push(aid.to_owned());
} }
// We just inserted this at the start of the while loop // We just inserted this at the start of the while loop
graph.get_mut(&eid).unwrap().insert(aid.to_owned()); graph.get_mut(eid.borrow()).unwrap().insert(aid.to_owned());
} }
} }
} }
@ -672,7 +681,7 @@ mod tests {
}) })
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let auth_chain = HashSet::new(); let auth_chain: HashSet<Box<EventId>> = HashSet::new();
let power_events = event_map let power_events = event_map
.values() .values()
@ -699,13 +708,11 @@ mod tests {
events_to_sort.shuffle(&mut rand::thread_rng()); events_to_sort.shuffle(&mut rand::thread_rng());
let power_level = resolved_power.get(&(EventType::RoomPowerLevels, "".to_owned())); let power_level = resolved_power.get(&(EventType::RoomPowerLevels, "".to_owned())).cloned();
let sorted_event_ids = let sorted_event_ids =
crate::mainline_sort(&events_to_sort, power_level.map(|id| &**id), |id| { crate::mainline_sort(&events_to_sort, power_level, |id| events.get(id).map(Arc::clone))
events.get(id).map(Arc::clone) .unwrap();
})
.unwrap();
assert_eq!( assert_eq!(
vec![ vec![

View File

@ -1,4 +1,9 @@
use std::sync::Arc; use std::{
borrow::Borrow,
fmt::{Debug, Display},
hash::Hash,
sync::Arc,
};
use ruma_common::MilliSecondsSinceUnixEpoch; use ruma_common::MilliSecondsSinceUnixEpoch;
use ruma_events::EventType; use ruma_events::EventType;
@ -7,8 +12,10 @@ use serde_json::value::RawValue as RawJsonValue;
/// Abstraction of a PDU so users can have their own PDU types. /// Abstraction of a PDU so users can have their own PDU types.
pub trait Event { pub trait Event {
type Id: Clone + Debug + Display + Eq + Ord + Hash + Borrow<EventId>;
/// The `EventId` of this event. /// The `EventId` of this event.
fn event_id(&self) -> &EventId; fn event_id(&self) -> &Self::Id;
/// The `RoomId` of this event. /// The `RoomId` of this event.
fn room_id(&self) -> &RoomId; fn room_id(&self) -> &RoomId;
@ -30,18 +37,20 @@ pub trait Event {
/// The events before this event. /// The events before this event.
// Requires GATs to avoid boxing (and TAIT for making it convenient). // Requires GATs to avoid boxing (and TAIT for making it convenient).
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &EventId> + '_>; fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_>;
/// All the authenticating events for this event. /// All the authenticating events for this event.
// Requires GATs to avoid boxing (and TAIT for making it convenient). // Requires GATs to avoid boxing (and TAIT for making it convenient).
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &EventId> + '_>; fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_>;
/// If this event is a redaction event this is the event it redacts. /// If this event is a redaction event this is the event it redacts.
fn redacts(&self) -> Option<&EventId>; fn redacts(&self) -> Option<&Self::Id>;
} }
impl<T: Event> Event for &T { impl<T: Event> Event for &T {
fn event_id(&self) -> &EventId { type Id = T::Id;
fn event_id(&self) -> &Self::Id {
(*self).event_id() (*self).event_id()
} }
@ -69,21 +78,23 @@ impl<T: Event> Event for &T {
(*self).state_key() (*self).state_key()
} }
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &EventId> + '_> { fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
(*self).prev_events() (*self).prev_events()
} }
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &EventId> + '_> { fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
(*self).auth_events() (*self).auth_events()
} }
fn redacts(&self) -> Option<&EventId> { fn redacts(&self) -> Option<&Self::Id> {
(*self).redacts() (*self).redacts()
} }
} }
impl<T: Event> Event for Arc<T> { impl<T: Event> Event for Arc<T> {
fn event_id(&self) -> &EventId { type Id = T::Id;
fn event_id(&self) -> &Self::Id {
(&**self).event_id() (&**self).event_id()
} }
@ -111,15 +122,15 @@ impl<T: Event> Event for Arc<T> {
(&**self).state_key() (&**self).state_key()
} }
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &EventId> + '_> { fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
(&**self).prev_events() (&**self).prev_events()
} }
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &EventId> + '_> { fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
(&**self).auth_events() (&**self).auth_events()
} }
fn redacts(&self) -> Option<&EventId> { fn redacts(&self) -> Option<&Self::Id> {
(&**self).redacts() (&**self).redacts()
} }
} }

View File

@ -1,4 +1,5 @@
use std::{ use std::{
borrow::Borrow,
collections::{BTreeMap, HashMap, HashSet}, collections::{BTreeMap, HashMap, HashSet},
convert::TryInto, convert::TryInto,
sync::{ sync::{
@ -218,8 +219,8 @@ impl<E: Event> TestStore<E> {
pub fn auth_event_ids( pub fn auth_event_ids(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event_ids: Vec<Box<EventId>>, event_ids: Vec<E::Id>,
) -> Result<HashSet<Box<EventId>>> { ) -> Result<HashSet<E::Id>> {
let mut result = HashSet::new(); let mut result = HashSet::new();
let mut stack = event_ids; let mut stack = event_ids;
@ -231,7 +232,7 @@ impl<E: Event> TestStore<E> {
result.insert(ev_id.clone()); result.insert(ev_id.clone());
let event = self.get_event(room_id, &ev_id)?; let event = self.get_event(room_id, ev_id.borrow())?;
stack.extend(event.auth_events().map(ToOwned::to_owned)); stack.extend(event.auth_events().map(ToOwned::to_owned));
} }
@ -550,7 +551,9 @@ pub mod event {
use crate::Event; use crate::Event;
impl Event for StateEvent { impl Event for StateEvent {
fn event_id(&self) -> &EventId { type Id = Box<EventId>;
fn event_id(&self) -> &Self::Id {
&self.event_id &self.event_id
} }
@ -608,28 +611,28 @@ pub mod event {
} }
} }
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &EventId> + '_> { fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
match &self.rest { match &self.rest {
Pdu::RoomV1Pdu(ev) => Box::new(ev.prev_events.iter().map(|(id, _)| &**id)), Pdu::RoomV1Pdu(ev) => Box::new(ev.prev_events.iter().map(|(id, _)| id)),
Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter().map(|id| &**id)), Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter()),
#[allow(unreachable_patterns)] #[allow(unreachable_patterns)]
_ => unreachable!("new PDU version"), _ => unreachable!("new PDU version"),
} }
} }
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &EventId> + '_> { fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
match &self.rest { match &self.rest {
Pdu::RoomV1Pdu(ev) => Box::new(ev.auth_events.iter().map(|(id, _)| &**id)), Pdu::RoomV1Pdu(ev) => Box::new(ev.auth_events.iter().map(|(id, _)| id)),
Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter().map(|id| &**id)), Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter()),
#[allow(unreachable_patterns)] #[allow(unreachable_patterns)]
_ => unreachable!("new PDU version"), _ => unreachable!("new PDU version"),
} }
} }
fn redacts(&self) -> Option<&EventId> { fn redacts(&self) -> Option<&Self::Id> {
match &self.rest { match &self.rest {
Pdu::RoomV1Pdu(ev) => ev.redacts.as_deref(), Pdu::RoomV1Pdu(ev) => ev.redacts.as_ref(),
Pdu::RoomV3Pdu(ev) => ev.redacts.as_deref(), Pdu::RoomV3Pdu(ev) => ev.redacts.as_ref(),
#[allow(unreachable_patterns)] #[allow(unreachable_patterns)]
_ => unreachable!("new PDU version"), _ => unreachable!("new PDU version"),
} }