async state-res

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-09-21 00:49:02 +00:00
parent 1d0b06b581
commit e7db44989d
4 changed files with 262 additions and 166 deletions

View File

@ -18,9 +18,10 @@ all-features = true
unstable-exhaustive-types = [] unstable-exhaustive-types = []
[dependencies] [dependencies]
futures-util = "0.3"
itertools = "0.12.1" itertools = "0.12.1"
js_int = { workspace = true } js_int = { workspace = true }
ruma-common = { workspace = true } ruma-common = { workspace = true, features = ["api"] }
ruma-events = { workspace = true } ruma-events = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
@ -34,6 +35,7 @@ criterion = { workspace = true, optional = true }
maplit = { workspace = true } maplit = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
ruma-events = { workspace = true, features = ["unstable-pdu"] } ruma-events = { workspace = true, features = ["unstable-pdu"] }
tokio = { version = "1", features = ["rt", "macros"] }
tracing-subscriber = "0.3.16" tracing-subscriber = "0.3.16"
[[bench]] [[bench]]

View File

@ -1,5 +1,6 @@
use std::{borrow::Borrow, collections::BTreeSet}; use std::{borrow::Borrow, collections::BTreeSet};
use futures_util::Future;
use js_int::{int, Int}; use js_int::{int, Int};
use ruma_common::{ use ruma_common::{
serde::{Base64, Raw}, serde::{Base64, Raw},
@ -121,12 +122,18 @@ pub fn auth_types_for_event(
/// ///
/// The `fetch_state` closure should gather state from a state snapshot. We need to know if the /// The `fetch_state` closure should gather state from a state snapshot. We need to know if the
/// event passes auth against some state not a recursive collection of auth_events fields. /// event passes auth against some state not a recursive collection of auth_events fields.
pub fn auth_check<E: Event>( pub async fn auth_check<F, Fut, Fetched, Incoming>(
room_version: &RoomVersion, room_version: &RoomVersion,
incoming_event: impl Event, incoming_event: &Incoming,
current_third_party_invite: Option<impl Event>, current_third_party_invite: Option<&Incoming>,
fetch_state: impl Fn(&StateEventType, &str) -> Option<E>, fetch_state: F,
) -> Result<bool> { ) -> Result<bool>
where
F: Fn(&'static StateEventType, &str) -> Fut,
Fut: Future<Output = Option<Fetched>> + Send,
Fetched: Event + Send,
Incoming: Event + Send,
{
debug!( debug!(
"auth_check beginning for {} ({})", "auth_check beginning for {} ({})",
incoming_event.event_id(), incoming_event.event_id(),
@ -216,7 +223,7 @@ pub fn auth_check<E: Event>(
} }
*/ */
let room_create_event = match fetch_state(&StateEventType::RoomCreate, "") { let room_create_event = match fetch_state(&StateEventType::RoomCreate, "").await {
None => { None => {
warn!("no m.room.create event in auth chain"); warn!("no m.room.create event in auth chain");
return Ok(false); return Ok(false);
@ -265,8 +272,8 @@ pub fn auth_check<E: Event>(
} }
// If type is m.room.member // If type is m.room.member
let power_levels_event = fetch_state(&StateEventType::RoomPowerLevels, ""); let power_levels_event = fetch_state(&StateEventType::RoomPowerLevels, "").await;
let sender_member_event = fetch_state(&StateEventType::RoomMember, sender.as_str()); let sender_member_event = fetch_state(&StateEventType::RoomMember, sender.as_str()).await;
if *incoming_event.event_type() == TimelineEventType::RoomMember { if *incoming_event.event_type() == TimelineEventType::RoomMember {
debug!("starting m.room.member check"); debug!("starting m.room.member check");
@ -290,9 +297,13 @@ pub fn auth_check<E: Event>(
let user_for_join_auth = let user_for_join_auth =
content.join_authorised_via_users_server.as_ref().and_then(|u| u.deserialize().ok()); content.join_authorised_via_users_server.as_ref().and_then(|u| u.deserialize().ok());
let user_for_join_auth_membership = user_for_join_auth let user_for_join_auth_event = if let Some(auth_user) = user_for_join_auth.as_ref() {
.as_ref() fetch_state(&StateEventType::RoomMember, auth_user.as_str()).await
.and_then(|auth_user| fetch_state(&StateEventType::RoomMember, auth_user.as_str())) } else {
None
};
let user_for_join_auth_membership = user_for_join_auth_event
.and_then(|mem| from_json_str::<GetMembership>(mem.content().get()).ok()) .and_then(|mem| from_json_str::<GetMembership>(mem.content().get()).ok())
.map(|mem| mem.membership) .map(|mem| mem.membership)
.unwrap_or(MembershipState::Leave); .unwrap_or(MembershipState::Leave);
@ -300,13 +311,13 @@ pub fn auth_check<E: Event>(
if !valid_membership_change( if !valid_membership_change(
room_version, room_version,
target_user, target_user,
fetch_state(&StateEventType::RoomMember, target_user.as_str()).as_ref(), fetch_state(&StateEventType::RoomMember, target_user.as_str()).await.as_ref(),
sender, sender,
sender_member_event.as_ref(), sender_member_event.as_ref(),
&incoming_event, &incoming_event,
current_third_party_invite, current_third_party_invite,
power_levels_event.as_ref(), power_levels_event.as_ref(),
fetch_state(&StateEventType::RoomJoinRules, "").as_ref(), fetch_state(&StateEventType::RoomJoinRules, "").await.as_ref(),
user_for_join_auth.as_deref(), user_for_join_auth.as_deref(),
&user_for_join_auth_membership, &user_for_join_auth_membership,
room_create_event, room_create_event,

View File

@ -5,6 +5,7 @@ use std::{
hash::Hash, hash::Hash,
}; };
use futures_util::{future, stream, Future, StreamExt};
use itertools::Itertools; use itertools::Itertools;
use js_int::{int, Int}; use js_int::{int, Int};
use ruma_common::{EventId, MilliSecondsSinceUnixEpoch, RoomVersionId}; use ruma_common::{EventId, MilliSecondsSinceUnixEpoch, RoomVersionId};
@ -52,16 +53,22 @@ pub type StateMap<T> = HashMap<(StateEventType, String), T>;
/// ///
/// The caller of `resolve` must ensure that all the events are from the same room. Although this /// The caller of `resolve` must ensure that all the events are from the same room. Although this
/// function takes a `RoomId` it does not check that each event is part of the same room. /// function takes a `RoomId` it does not check that each event is part of the same room.
pub fn resolve<'a, E, SetIter>( pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>(
room_version: &RoomVersionId, room_version: &RoomVersionId,
state_sets: impl IntoIterator<IntoIter = SetIter>, state_sets: impl IntoIterator<IntoIter = SetIter> + Send,
auth_chain_sets: Vec<HashSet<E::Id>>, auth_chain_sets: &'a Vec<HashSet<E::Id>>,
fetch_event: impl Fn(&EventId) -> Option<E>, event_fetch: &Fetch,
event_exists: &Exists,
) -> Result<StateMap<E::Id>> ) -> Result<StateMap<E::Id>>
where where
E: Event + Clone, Fetch: Fn(E::Id) -> FetchFut + Sync,
E::Id: 'a, FetchFut: Future<Output = Option<E>> + Send,
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone, Exists: Fn(E::Id) -> ExistsFut,
ExistsFut: Future<Output = bool> + Send,
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send + Sync,
for<'b> &'b E: Send,
{ {
debug!("State resolution starting"); debug!("State resolution starting");
@ -79,13 +86,16 @@ where
debug!("conflicting events: {}", conflicting.len()); debug!("conflicting events: {}", conflicting.len());
debug!("{conflicting:?}"); debug!("{conflicting:?}");
let auth_chain_diff =
get_auth_chain_diff(&auth_chain_sets).chain(conflicting.into_values().flatten());
// `all_conflicted` contains unique items // `all_conflicted` contains unique items
// synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}` // synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}`
let all_conflicted: HashSet<_> = get_auth_chain_diff(auth_chain_sets) let all_conflicted: HashSet<E::Id> = stream::iter(auth_chain_diff)
.chain(conflicting.into_values().flatten())
// Don't honor events we cannot "verify" // Don't honor events we cannot "verify"
.filter(|id| fetch_event(id.borrow()).is_some()) .filter(|id| event_exists(id.clone()))
.collect(); .collect()
.await;
debug!("full conflicted set: {}", all_conflicted.len()); debug!("full conflicted set: {}", all_conflicted.len());
debug!("{all_conflicted:?}"); debug!("{all_conflicted:?}");
@ -94,15 +104,15 @@ where
// this is now a check the caller of `resolve` must make. // this is now a check the caller of `resolve` must make.
// Get only the control events with a state_key: "" or ban/kick event (sender != state_key) // Get only the control events with a state_key: "" or ban/kick event (sender != state_key)
let control_events = all_conflicted let control_events = stream::iter(all_conflicted.iter())
.iter() .filter(|&id| is_power_event_id(id, &event_fetch))
.filter(|&id| is_power_event_id(id.borrow(), &fetch_event)) .map(Clone::clone)
.cloned() .collect::<Vec<_>>()
.collect::<Vec<_>>(); .await;
// Sort the control events based on power_level/clock/event_id and outgoing/incoming edges // Sort the control events based on power_level/clock/event_id and outgoing/incoming edges
let sorted_control_levels = let sorted_control_levels =
reverse_topological_power_sort(control_events, &all_conflicted, &fetch_event)?; reverse_topological_power_sort(control_events, &all_conflicted, &event_fetch).await?;
debug!("sorted control events: {}", sorted_control_levels.len()); debug!("sorted control events: {}", sorted_control_levels.len());
trace!("{sorted_control_levels:?}"); trace!("{sorted_control_levels:?}");
@ -110,7 +120,8 @@ where
let room_version = RoomVersion::new(room_version)?; let room_version = RoomVersion::new(room_version)?;
// Sequentially auth check each control event. // Sequentially auth check each control event.
let resolved_control = let resolved_control =
iterative_auth_check(&room_version, &sorted_control_levels, clean.clone(), &fetch_event)?; iterative_auth_check(&room_version, &sorted_control_levels, clean.clone(), &event_fetch)
.await?;
debug!("resolved control events: {}", resolved_control.len()); debug!("resolved control events: {}", resolved_control.len());
trace!("{resolved_control:?}"); trace!("{resolved_control:?}");
@ -135,7 +146,8 @@ where
debug!("power event: {power_event:?}"); debug!("power event: {power_event:?}");
let sorted_left_events = mainline_sort(&events_to_resolve, power_event.cloned(), &fetch_event)?; let sorted_left_events =
mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).await?;
trace!("events left, sorted: {sorted_left_events:?}"); trace!("events left, sorted: {sorted_left_events:?}");
@ -143,8 +155,9 @@ where
&room_version, &room_version,
&sorted_left_events, &sorted_left_events,
resolved_control, // The control events are added to the final resolved state resolved_control, // The control events are added to the final resolved state
&fetch_event, &event_fetch,
)?; )
.await?;
// Add unconflicted state to the resolved state // Add unconflicted state to the resolved state
// We priorities the unconflicting state // We priorities the unconflicting state
@ -188,15 +201,14 @@ where
} }
/// Returns a Vec of deduped EventIds that appear in some chains but not others. /// Returns a Vec of deduped EventIds that appear in some chains but not others.
fn get_auth_chain_diff<Id>(auth_chain_sets: Vec<HashSet<Id>>) -> impl Iterator<Item = Id> fn get_auth_chain_diff<Id>(auth_chain_sets: &Vec<HashSet<Id>>) -> impl Iterator<Item = Id>
where where
Id: Eq + Hash, Id: Clone + Eq + Hash,
{ {
let num_sets = auth_chain_sets.len(); let num_sets = auth_chain_sets.len();
let mut id_counts: HashMap<Id, usize> = HashMap::new(); let mut id_counts: HashMap<Id, usize> = HashMap::new();
for id in auth_chain_sets.into_iter().flatten() { for id in auth_chain_sets.into_iter().flatten() {
*id_counts.entry(id).or_default() += 1; *id_counts.entry(id.clone()).or_default() += 1;
} }
id_counts.into_iter().filter_map(move |(id, count)| (count < num_sets).then_some(id)) id_counts.into_iter().filter_map(move |(id, count)| (count < num_sets).then_some(id))
@ -209,16 +221,22 @@ where
/// ///
/// The power level is negative because a higher power level is equated to an earlier (further back /// The power level is negative because a higher power level is equated to an earlier (further back
/// in time) origin server timestamp. /// in time) origin server timestamp.
fn reverse_topological_power_sort<E: Event>( async fn reverse_topological_power_sort<E, F, Fut>(
events_to_sort: Vec<E::Id>, events_to_sort: Vec<E::Id>,
auth_diff: &HashSet<E::Id>, auth_diff: &HashSet<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: &F,
) -> Result<Vec<E::Id>> { ) -> Result<Vec<E::Id>>
where
F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send + Sync,
{
debug!("reverse topological sort of power events"); debug!("reverse topological sort of power events");
let mut graph = HashMap::new(); let mut graph = HashMap::new();
for event_id in events_to_sort { for event_id in events_to_sort {
add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, &fetch_event); add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, fetch_event).await;
// TODO: if these functions are ever made async here // TODO: if these functions are ever made async here
// is a good place to yield every once in a while so other // is a good place to yield every once in a while so other
@ -228,7 +246,7 @@ fn reverse_topological_power_sort<E: Event>(
// This is used in the `key_fn` passed to the lexico_topo_sort fn // This is used in the `key_fn` passed to the lexico_topo_sort fn
let mut event_to_pl = HashMap::new(); let mut event_to_pl = HashMap::new();
for event_id in graph.keys() { for event_id in graph.keys() {
let pl = get_power_level_for_sender(event_id.borrow(), &fetch_event)?; let pl = get_power_level_for_sender(event_id, fetch_event).await?;
debug!("{event_id} power level {pl}"); debug!("{event_id} power level {pl}");
event_to_pl.insert(event_id.clone(), pl); event_to_pl.insert(event_id.clone(), pl);
@ -238,26 +256,30 @@ fn reverse_topological_power_sort<E: Event>(
// tasks can make progress // tasks can make progress
} }
lexicographical_topological_sort(&graph, |event_id| { let event_to_pl = &event_to_pl;
let ev = fetch_event(event_id).ok_or_else(|| Error::NotFound("".into()))?; let fetcher = |event_id: E::Id| async move {
let pl = *event_to_pl.get(event_id).ok_or_else(|| Error::NotFound("".into()))?; let pl = *event_to_pl.get(event_id.borrow()).ok_or_else(|| Error::NotFound("".into()))?;
let ev = fetch_event(event_id).await.ok_or_else(|| Error::NotFound("".into()))?;
Ok((pl, ev.origin_server_ts())) Ok((pl, ev.origin_server_ts()))
}) };
lexicographical_topological_sort(&graph, &fetcher).await
} }
/// Sorts the event graph based on number of outgoing/incoming edges. /// Sorts the event graph based on number of outgoing/incoming edges.
/// ///
/// `key_fn` is used as to obtain the power level and age of an event for breaking ties (together /// `key_fn` is used as to obtain the power level and age of an event for breaking ties (together
/// with the event ID). /// with the event ID).
pub fn lexicographical_topological_sort<Id, F>( pub async fn lexicographical_topological_sort<Id, F, Fut>(
graph: &HashMap<Id, HashSet<Id>>, graph: &HashMap<Id, HashSet<Id>>,
key_fn: F, key_fn: &F,
) -> Result<Vec<Id>> ) -> Result<Vec<Id>>
where where
F: Fn(&EventId) -> Result<(Int, MilliSecondsSinceUnixEpoch)>, F: Fn(Id) -> Fut,
Id: Clone + Eq + Ord + Hash + Borrow<EventId>, Fut: Future<Output = Result<(Int, MilliSecondsSinceUnixEpoch)>> + Send,
Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send,
{ {
#[derive(PartialEq, Eq, PartialOrd, Ord)] #[derive(Eq, Ord, PartialEq, PartialOrd)]
struct TieBreaker<'a, Id> { struct TieBreaker<'a, Id> {
inv_power_level: Int, inv_power_level: Int,
age: MilliSecondsSinceUnixEpoch, age: MilliSecondsSinceUnixEpoch,
@ -285,7 +307,7 @@ where
for (node, edges) in graph { for (node, edges) in graph {
if edges.is_empty() { if edges.is_empty() {
let (power_level, age) = key_fn(node.borrow())?; let (power_level, age) = key_fn(node.clone()).await?;
// The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need // The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need
// smallest -> largest // smallest -> largest
zero_outdegree.push(Reverse(TieBreaker { zero_outdegree.push(Reverse(TieBreaker {
@ -318,7 +340,7 @@ where
// Only push on the heap once older events have been cleared // Only push on the heap once older events have been cleared
out.remove(node.borrow()); out.remove(node.borrow());
if out.is_empty() { if out.is_empty() {
let (power_level, age) = key_fn(node.borrow())?; let (power_level, age) = key_fn(node.clone()).await?;
heap.push(Reverse(TieBreaker { heap.push(Reverse(TieBreaker {
inv_power_level: -power_level, inv_power_level: -power_level,
age, age,
@ -339,17 +361,23 @@ where
/// Do NOT use this any where but topological sort, we find the power level for the eventId /// Do NOT use this any where but topological sort, we find the power level for the eventId
/// at the eventId's generation (we walk backwards to `EventId`s most recent previous power level /// at the eventId's generation (we walk backwards to `EventId`s most recent previous power level
/// event). /// event).
fn get_power_level_for_sender<E: Event>( async fn get_power_level_for_sender<E, F, Fut>(
event_id: &EventId, event_id: &E::Id,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: &F,
) -> serde_json::Result<Int> { ) -> serde_json::Result<Int>
where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send,
{
debug!("fetch event ({event_id}) senders power level"); debug!("fetch event ({event_id}) senders power level");
let event = fetch_event(event_id); let event = fetch_event(event_id.clone()).await;
let mut pl = None; let mut pl = None;
for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() { for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() {
if let Some(aev) = fetch_event(aid.borrow()) { if let Some(aev) = fetch_event(aid.clone()).await {
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") { if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
pl = Some(aev); pl = Some(aev);
break; break;
@ -381,12 +409,19 @@ fn get_power_level_for_sender<E: Event>(
/// ///
/// For each `events_to_check` event we gather the events needed to auth it from the the /// For each `events_to_check` event we gather the events needed to auth it from the the
/// `fetch_event` closure and verify each event using the `event_auth::auth_check` function. /// `fetch_event` closure and verify each event using the `event_auth::auth_check` function.
fn iterative_auth_check<E: Event + Clone>( async fn iterative_auth_check<E, F, Fut>(
room_version: &RoomVersion, room_version: &RoomVersion,
events_to_check: &[E::Id], events_to_check: &[E::Id],
unconflicted_state: StateMap<E::Id>, unconflicted_state: StateMap<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: &F,
) -> Result<StateMap<E::Id>> { ) -> Result<StateMap<E::Id>>
where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Clone + Send,
for<'a> &'a E: Send,
{
debug!("starting iterative auth check"); debug!("starting iterative auth check");
debug!("performing auth checks on {events_to_check:?}"); debug!("performing auth checks on {events_to_check:?}");
@ -394,7 +429,8 @@ fn iterative_auth_check<E: Event + Clone>(
let mut resolved_state = unconflicted_state; let mut resolved_state = unconflicted_state;
for event_id in events_to_check { for event_id in events_to_check {
let event = fetch_event(event_id.borrow()) let event = fetch_event(event_id.clone())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))?; .ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))?;
let state_key = event let state_key = event
.state_key() .state_key()
@ -402,7 +438,7 @@ fn iterative_auth_check<E: Event + Clone>(
let mut auth_events = StateMap::new(); let mut auth_events = StateMap::new();
for aid in event.auth_events() { for aid in event.auth_events() {
if let Some(ev) = fetch_event(aid.borrow()) { if let Some(ev) = fetch_event(aid.clone()).await {
// TODO synapse check "rejected_reason" which is most likely // TODO synapse check "rejected_reason" which is most likely
// related to soft-failing // related to soft-failing
auth_events.insert( auth_events.insert(
@ -423,7 +459,7 @@ fn iterative_auth_check<E: Event + Clone>(
event.content(), event.content(),
)? { )? {
if let Some(ev_id) = resolved_state.get(&key) { if let Some(ev_id) = resolved_state.get(&key) {
if let Some(event) = fetch_event(ev_id.borrow()) { if let Some(event) = fetch_event(ev_id.clone()).await {
// TODO synapse checks `rejected_reason` is None here // TODO synapse checks `rejected_reason` is None here
auth_events.insert(key.to_owned(), event); auth_events.insert(key.to_owned(), event);
} }
@ -438,9 +474,11 @@ fn iterative_auth_check<E: Event + Clone>(
(*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu) (*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu)
}); });
if auth_check(room_version, &event, current_third_party, |ty, key| { let fetch_state = |ty: &StateEventType, key: &str| {
auth_events.get(&ty.with_state_key(key)) future::ready(auth_events.get(&ty.with_state_key(key)))
})? { };
if auth_check(room_version, &event, current_third_party, fetch_state).await? {
// add event to resolved state map // add event to resolved state map
resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone()); resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone());
} else { } else {
@ -462,11 +500,17 @@ fn iterative_auth_check<E: Event + Clone>(
/// power_level event. If there have been two power events the after the most recent are depth 0, /// power_level event. If there have been two power events the after the most recent are depth 0,
/// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is /// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is
/// "older" than depth 0. /// "older" than depth 0.
fn mainline_sort<E: Event>( async fn mainline_sort<E, F, Fut>(
to_sort: &[E::Id], to_sort: &[E::Id],
resolved_power_level: Option<E::Id>, resolved_power_level: Option<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: &F,
) -> Result<Vec<E::Id>> { ) -> Result<Vec<E::Id>>
where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Clone + Send,
{
debug!("mainline sort of events"); debug!("mainline sort of events");
// There are no EventId's to sort, bail. // There are no EventId's to sort, bail.
@ -479,11 +523,13 @@ fn mainline_sort<E: Event>(
while let Some(p) = pl { while let Some(p) = pl {
mainline.push(p.clone()); mainline.push(p.clone());
let event = fetch_event(p.borrow()) let event = fetch_event(p.clone())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?; .ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?;
pl = None; pl = None;
for aid in event.auth_events() { for aid in event.auth_events() {
let ev = fetch_event(aid.borrow()) let ev = fetch_event(aid.clone())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?; .ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") { if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") {
pl = Some(aid.to_owned()); pl = Some(aid.to_owned());
@ -504,11 +550,15 @@ fn mainline_sort<E: Event>(
let mut order_map = HashMap::new(); let mut order_map = HashMap::new();
for ev_id in to_sort.iter() { for ev_id in to_sort.iter() {
if let Some(event) = fetch_event(ev_id.borrow()) { if let Some(event) = fetch_event(ev_id.clone()).await {
if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, &fetch_event) { if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, fetch_event).await {
order_map.insert( order_map.insert(
ev_id, ev_id,
(depth, fetch_event(ev_id.borrow()).map(|ev| ev.origin_server_ts()), ev_id), (
depth,
fetch_event(ev_id.clone()).await.map(|ev| ev.origin_server_ts()),
ev_id,
),
); );
} }
} }
@ -528,11 +578,17 @@ fn mainline_sort<E: Event>(
/// Get the mainline depth from the `mainline_map` or finds a power_level event that has an /// Get the mainline depth from the `mainline_map` or finds a power_level event that has an
/// associated mainline depth. /// associated mainline depth.
fn get_mainline_depth<E: Event>( async fn get_mainline_depth<E, F, Fut>(
mut event: Option<E>, mut event: Option<E>,
mainline_map: &HashMap<E::Id, usize>, mainline_map: &HashMap<E::Id, usize>,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: &F,
) -> Result<usize> { ) -> Result<usize>
where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send,
{
while let Some(sort_ev) = event { while let Some(sort_ev) = event {
debug!("mainline event_id {}", sort_ev.event_id()); debug!("mainline event_id {}", sort_ev.event_id());
let id = sort_ev.event_id(); let id = sort_ev.event_id();
@ -542,7 +598,8 @@ fn get_mainline_depth<E: Event>(
event = None; event = None;
for aid in sort_ev.auth_events() { for aid in sort_ev.auth_events() {
let aev = fetch_event(aid.borrow()) let aev = fetch_event(aid.clone())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?; .ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") { if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
event = Some(aev); event = Some(aev);
@ -554,18 +611,23 @@ fn get_mainline_depth<E: Event>(
Ok(0) Ok(0)
} }
fn add_event_and_auth_chain_to_graph<E: Event>( async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
graph: &mut HashMap<E::Id, HashSet<E::Id>>, graph: &mut HashMap<E::Id, HashSet<E::Id>>,
event_id: E::Id, event_id: E::Id,
auth_diff: &HashSet<E::Id>, auth_diff: &HashSet<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>, fetch_event: &F,
) { ) where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Clone + Send,
{
let mut state = vec![event_id]; let mut state = vec![event_id];
while let Some(eid) = state.pop() { while let Some(eid) = state.pop() {
graph.entry(eid.clone()).or_default(); graph.entry(eid.clone()).or_default();
// Prefer the store to event as the store filters dedups the events // Prefer the store to event as the store filters dedups the events
for aid in for aid in
fetch_event(eid.borrow()).as_ref().map(|ev| ev.auth_events()).into_iter().flatten() fetch_event(eid.clone()).await.as_ref().map(|ev| ev.auth_events()).into_iter().flatten()
{ {
if auth_diff.contains(aid.borrow()) { if auth_diff.contains(aid.borrow()) {
if !graph.contains_key(aid.borrow()) { if !graph.contains_key(aid.borrow()) {
@ -579,8 +641,14 @@ fn add_event_and_auth_chain_to_graph<E: Event>(
} }
} }
fn is_power_event_id<E: Event>(event_id: &EventId, fetch: impl Fn(&EventId) -> Option<E>) -> bool { async fn is_power_event_id<E, F, Fut>(event_id: &E::Id, fetch: &F) -> bool
match fetch(event_id).as_ref() { where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send,
{
match fetch(event_id.clone()).await.as_ref() {
Some(state) => is_power_event(state), Some(state) => is_power_event(state),
_ => false, _ => false,
} }
@ -609,7 +677,7 @@ fn is_power_event(event: impl Event) -> bool {
} }
/// Convenience trait for adding event type plus state key to state maps. /// Convenience trait for adding event type plus state key to state maps.
trait EventTypeExt { pub trait EventTypeExt {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String); fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String);
} }
@ -662,7 +730,9 @@ mod tests {
Event, EventTypeExt, StateMap, Event, EventTypeExt, StateMap,
}; };
fn test_event_sort() { async fn test_event_sort() {
use futures_util::future::ready;
let _ = let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = INITIAL_EVENTS(); let events = INITIAL_EVENTS();
@ -680,18 +750,19 @@ mod tests {
.map(|pdu| pdu.event_id.clone()) .map(|pdu| pdu.event_id.clone())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let fetcher = |id| ready(events.get(&id).cloned());
let sorted_power_events = let sorted_power_events =
crate::reverse_topological_power_sort(power_events, &auth_chain, |id| { crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher)
events.get(id).cloned() .await
}) .unwrap();
.unwrap();
let resolved_power = crate::iterative_auth_check( let resolved_power = crate::iterative_auth_check(
&RoomVersion::V6, &RoomVersion::V6,
&sorted_power_events, &sorted_power_events,
HashMap::new(), // unconflicted events HashMap::new(), // unconflicted events
|id| events.get(id).cloned(), &fetcher,
) )
.await
.expect("iterative auth check failed on resolved events"); .expect("iterative auth check failed on resolved events");
// don't remove any events so we know it sorts them all correctly // don't remove any events so we know it sorts them all correctly
@ -703,8 +774,7 @@ mod tests {
resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned(); resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned();
let sorted_event_ids = let sorted_event_ids =
crate::mainline_sort(&events_to_sort, power_level, |id| events.get(id).cloned()) crate::mainline_sort(&events_to_sort, power_level, &fetcher).await.unwrap();
.unwrap();
assert_eq!( assert_eq!(
vec![ vec![
@ -721,17 +791,17 @@ mod tests {
); );
} }
#[test] #[tokio::test]
fn test_sort() { async fn test_sort() {
for _ in 0..20 { for _ in 0..20 {
// since we shuffle the eventIds before we sort them introducing randomness // since we shuffle the eventIds before we sort them introducing randomness
// seems like we should test this a few times // seems like we should test this a few times
test_event_sort(); test_event_sort().await;
} }
} }
#[test] #[tokio::test]
fn ban_vs_power_level() { async fn ban_vs_power_level() {
let _ = let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -774,11 +844,11 @@ mod tests {
let expected_state_ids = let expected_state_ids =
vec!["PA", "MA", "MB"].into_iter().map(event_id).collect::<Vec<_>>(); vec!["PA", "MA", "MB"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids); do_check(events, edges, expected_state_ids).await;
} }
#[test] #[tokio::test]
fn topic_basic() { async fn topic_basic() {
let _ = let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -835,11 +905,11 @@ mod tests {
let expected_state_ids = vec!["PA2", "T2"].into_iter().map(event_id).collect::<Vec<_>>(); let expected_state_ids = vec!["PA2", "T2"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids); do_check(events, edges, expected_state_ids).await;
} }
#[test] #[tokio::test]
fn topic_reset() { async fn topic_reset() {
let _ = let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -882,11 +952,11 @@ mod tests {
let expected_state_ids = let expected_state_ids =
vec!["T1", "MB", "PA"].into_iter().map(event_id).collect::<Vec<_>>(); vec!["T1", "MB", "PA"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids); do_check(events, edges, expected_state_ids).await;
} }
#[test] #[tokio::test]
fn join_rule_evasion() { async fn join_rule_evasion() {
let _ = let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -914,11 +984,11 @@ mod tests {
let expected_state_ids = vec![event_id("JR")]; let expected_state_ids = vec![event_id("JR")];
do_check(events, edges, expected_state_ids); do_check(events, edges, expected_state_ids).await;
} }
#[test] #[tokio::test]
fn offtopic_power_level() { async fn offtopic_power_level() {
let _ = let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -955,11 +1025,11 @@ mod tests {
let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::<Vec<_>>(); let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids); do_check(events, edges, expected_state_ids).await;
} }
#[test] #[tokio::test]
fn topic_setting() { async fn topic_setting() {
let _ = let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -1032,11 +1102,13 @@ mod tests {
let expected_state_ids = vec!["T4", "PA2"].into_iter().map(event_id).collect::<Vec<_>>(); let expected_state_ids = vec!["T4", "PA2"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids); do_check(events, edges, expected_state_ids).await;
} }
#[test] #[tokio::test]
fn test_event_map_none() { async fn test_event_map_none() {
use futures_util::future::ready;
let _ = let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -1046,27 +1118,29 @@ mod tests {
let (state_at_bob, state_at_charlie, expected) = store.set_up(); let (state_at_bob, state_at_charlie, expected) = store.set_up();
let ev_map = store.0.clone(); let ev_map = store.0.clone();
let fetcher = |id| ready(ev_map.get(&id).cloned());
let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&*id).is_some());
let state_sets = [state_at_bob, state_at_charlie]; let state_sets = [state_at_bob, state_at_charlie];
let resolved = match crate::resolve( let auth_chain = state_sets
&RoomVersionId::V2, .iter()
&state_sets, .map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap())
state_sets .collect();
.iter()
.map(|map| { let resolved =
store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap() match crate::resolve(&RoomVersionId::V2, &state_sets, &auth_chain, &fetcher, &exists)
}) .await
.collect(), {
|id| ev_map.get(id).cloned(), Ok(state) => state,
) { Err(e) => panic!("{e}"),
Ok(state) => state, };
Err(e) => panic!("{e}"),
};
assert_eq!(expected, resolved); assert_eq!(expected, resolved);
} }
#[test] #[tokio::test]
fn test_lexicographical_sort() { async fn test_lexicographical_sort() {
let _ = let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -1078,9 +1152,10 @@ mod tests {
event_id("p") => hashset![event_id("o")], event_id("p") => hashset![event_id("o")],
}; };
let res = crate::lexicographical_topological_sort(&graph, |_id| { let res = crate::lexicographical_topological_sort(&graph, &|_id| async {
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
}) })
.await
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
@ -1092,8 +1167,8 @@ mod tests {
); );
} }
#[test] #[tokio::test]
fn ban_with_auth_chains() { async fn ban_with_auth_chains() {
let _ = let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let ban = BAN_STATE_SET(); let ban = BAN_STATE_SET();
@ -1105,11 +1180,13 @@ mod tests {
let expected_state_ids = vec!["PA", "MB"].into_iter().map(event_id).collect::<Vec<_>>(); let expected_state_ids = vec!["PA", "MB"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(&ban.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids); do_check(&ban.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids).await;
} }
#[test] #[tokio::test]
fn ban_with_auth_chains2() { async fn ban_with_auth_chains2() {
use futures_util::future::ready;
let _ = let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let init = INITIAL_EVENTS(); let init = INITIAL_EVENTS();
@ -1147,20 +1224,20 @@ mod tests {
let ev_map = &store.0; let ev_map = &store.0;
let state_sets = [state_set_a, state_set_b]; let state_sets = [state_set_a, state_set_b];
let resolved = match crate::resolve( let auth_chain = state_sets
&RoomVersionId::V6, .iter()
&state_sets, .map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap())
state_sets .collect();
.iter()
.map(|map| { let fetcher = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).cloned());
store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap() let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).is_some());
}) let resolved =
.collect(), match crate::resolve(&RoomVersionId::V6, &state_sets, &auth_chain, &fetcher, &exists)
|id| ev_map.get(id).cloned(), .await
) { {
Ok(state) => state, Ok(state) => state,
Err(e) => panic!("{e}"), Err(e) => panic!("{e}"),
}; };
debug!( debug!(
"{:#?}", "{:#?}",
@ -1180,8 +1257,8 @@ mod tests {
assert_eq!(expected.len(), resolved.len()); assert_eq!(expected.len(), resolved.len());
} }
#[test] #[tokio::test]
fn join_rule_with_auth_chain() { async fn join_rule_with_auth_chain() {
let join_rule = JOIN_RULE(); let join_rule = JOIN_RULE();
let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]] let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]]
@ -1191,7 +1268,7 @@ mod tests {
let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::<Vec<_>>(); let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(&join_rule.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids); do_check(&join_rule.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids).await;
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]

View File

@ -7,6 +7,7 @@ use std::{
}, },
}; };
use futures_util::future::ready;
use js_int::{int, uint}; use js_int::{int, uint};
use ruma_common::{ use ruma_common::{
event_id, room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, event_id, room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId,
@ -31,7 +32,7 @@ use crate::{auth_types_for_event, Error, Event, EventTypeExt, Result, StateMap};
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0); static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
pub(crate) fn do_check( pub(crate) async fn do_check(
events: &[Arc<PduEvent>], events: &[Arc<PduEvent>],
edges: Vec<Vec<OwnedEventId>>, edges: Vec<Vec<OwnedEventId>>,
expected_state_ids: Vec<OwnedEventId>, expected_state_ids: Vec<OwnedEventId>,
@ -81,9 +82,10 @@ pub(crate) fn do_check(
// Resolve the current state and add it to the state_at_event map then continue // Resolve the current state and add it to the state_at_event map then continue
// on in "time" // on in "time"
for node in crate::lexicographical_topological_sort(&graph, |_id| { for node in crate::lexicographical_topological_sort(&graph, &|_id| async {
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
}) })
.await
.unwrap() .unwrap()
{ {
let fake_event = fake_event_map.get(&node).unwrap(); let fake_event = fake_event_map.get(&node).unwrap();
@ -117,9 +119,13 @@ pub(crate) fn do_check(
}) })
.collect(); .collect();
let resolved = crate::resolve(&RoomVersionId::V6, state_sets, auth_chain_sets, |id| { let event_map = &event_map;
event_map.get(id).cloned() let fetch = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).cloned());
}); let exists = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).is_some());
let resolved =
crate::resolve(&RoomVersionId::V6, state_sets, &auth_chain_sets, &fetch, &exists)
.await;
match resolved { match resolved {
Ok(state) => state, Ok(state) => state,
Err(e) => panic!("resolution for {node} failed: {e}"), Err(e) => panic!("resolution for {node} failed: {e}"),
@ -614,7 +620,7 @@ pub(crate) mod event {
} }
} }
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + Send + '_> {
match &self.rest { match &self.rest {
Pdu::RoomV1Pdu(ev) => Box::new(ev.prev_events.iter().map(|(id, _)| id)), Pdu::RoomV1Pdu(ev) => Box::new(ev.prev_events.iter().map(|(id, _)| id)),
Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter()), Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter()),
@ -623,7 +629,7 @@ pub(crate) mod event {
} }
} }
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + Send + '_> {
match &self.rest { match &self.rest {
Pdu::RoomV1Pdu(ev) => Box::new(ev.auth_events.iter().map(|(id, _)| id)), Pdu::RoomV1Pdu(ev) => Box::new(ev.auth_events.iter().map(|(id, _)| id)),
Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter()), Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter()),