From e7db44989d68406393270d3a91815597385d3acb Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 21 Sep 2024 00:49:02 +0000 Subject: [PATCH] async state-res Signed-off-by: Jason Volk --- crates/ruma-state-res/Cargo.toml | 4 +- crates/ruma-state-res/src/event_auth.rs | 37 ++- crates/ruma-state-res/src/lib.rs | 367 ++++++++++++++---------- crates/ruma-state-res/src/test_utils.rs | 20 +- 4 files changed, 262 insertions(+), 166 deletions(-) diff --git a/crates/ruma-state-res/Cargo.toml b/crates/ruma-state-res/Cargo.toml index 0b6b7744..f24d57c9 100644 --- a/crates/ruma-state-res/Cargo.toml +++ b/crates/ruma-state-res/Cargo.toml @@ -18,9 +18,10 @@ all-features = true unstable-exhaustive-types = [] [dependencies] +futures-util = "0.3" itertools = "0.12.1" js_int = { workspace = true } -ruma-common = { workspace = true } +ruma-common = { workspace = true, features = ["api"] } ruma-events = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } @@ -34,6 +35,7 @@ criterion = { workspace = true, optional = true } maplit = { workspace = true } rand = { workspace = true } ruma-events = { workspace = true, features = ["unstable-pdu"] } +tokio = { version = "1", features = ["rt", "macros"] } tracing-subscriber = "0.3.16" [[bench]] diff --git a/crates/ruma-state-res/src/event_auth.rs b/crates/ruma-state-res/src/event_auth.rs index 91a386c7..4132c421 100644 --- a/crates/ruma-state-res/src/event_auth.rs +++ b/crates/ruma-state-res/src/event_auth.rs @@ -1,5 +1,6 @@ use std::{borrow::Borrow, collections::BTreeSet}; +use futures_util::Future; use js_int::{int, Int}; use ruma_common::{ serde::{Base64, Raw}, @@ -121,12 +122,18 @@ pub fn auth_types_for_event( /// /// The `fetch_state` closure should gather state from a state snapshot. We need to know if the /// event passes auth against some state not a recursive collection of auth_events fields. -pub fn auth_check( +pub async fn auth_check( room_version: &RoomVersion, - incoming_event: impl Event, - current_third_party_invite: Option, - fetch_state: impl Fn(&StateEventType, &str) -> Option, -) -> Result { + incoming_event: &Incoming, + current_third_party_invite: Option<&Incoming>, + fetch_state: F, +) -> Result +where + F: Fn(&'static StateEventType, &str) -> Fut, + Fut: Future> + Send, + Fetched: Event + Send, + Incoming: Event + Send, +{ debug!( "auth_check beginning for {} ({})", incoming_event.event_id(), @@ -216,7 +223,7 @@ pub fn auth_check( } */ - let room_create_event = match fetch_state(&StateEventType::RoomCreate, "") { + let room_create_event = match fetch_state(&StateEventType::RoomCreate, "").await { None => { warn!("no m.room.create event in auth chain"); return Ok(false); @@ -265,8 +272,8 @@ pub fn auth_check( } // If type is m.room.member - let power_levels_event = fetch_state(&StateEventType::RoomPowerLevels, ""); - let sender_member_event = fetch_state(&StateEventType::RoomMember, sender.as_str()); + let power_levels_event = fetch_state(&StateEventType::RoomPowerLevels, "").await; + let sender_member_event = fetch_state(&StateEventType::RoomMember, sender.as_str()).await; if *incoming_event.event_type() == TimelineEventType::RoomMember { debug!("starting m.room.member check"); @@ -290,9 +297,13 @@ pub fn auth_check( let user_for_join_auth = content.join_authorised_via_users_server.as_ref().and_then(|u| u.deserialize().ok()); - let user_for_join_auth_membership = user_for_join_auth - .as_ref() - .and_then(|auth_user| fetch_state(&StateEventType::RoomMember, auth_user.as_str())) + let user_for_join_auth_event = if let Some(auth_user) = user_for_join_auth.as_ref() { + fetch_state(&StateEventType::RoomMember, auth_user.as_str()).await + } else { + None + }; + + let user_for_join_auth_membership = user_for_join_auth_event .and_then(|mem| from_json_str::(mem.content().get()).ok()) .map(|mem| mem.membership) .unwrap_or(MembershipState::Leave); @@ -300,13 +311,13 @@ pub fn auth_check( if !valid_membership_change( room_version, target_user, - fetch_state(&StateEventType::RoomMember, target_user.as_str()).as_ref(), + fetch_state(&StateEventType::RoomMember, target_user.as_str()).await.as_ref(), sender, sender_member_event.as_ref(), &incoming_event, current_third_party_invite, power_levels_event.as_ref(), - fetch_state(&StateEventType::RoomJoinRules, "").as_ref(), + fetch_state(&StateEventType::RoomJoinRules, "").await.as_ref(), user_for_join_auth.as_deref(), &user_for_join_auth_membership, room_create_event, diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 02e08a8f..36c06e3b 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -5,6 +5,7 @@ use std::{ hash::Hash, }; +use futures_util::{future, stream, Future, StreamExt}; use itertools::Itertools; use js_int::{int, Int}; use ruma_common::{EventId, MilliSecondsSinceUnixEpoch, RoomVersionId}; @@ -52,16 +53,22 @@ pub type StateMap = HashMap<(StateEventType, String), T>; /// /// The caller of `resolve` must ensure that all the events are from the same room. Although this /// function takes a `RoomId` it does not check that each event is part of the same room. -pub fn resolve<'a, E, SetIter>( +pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>( room_version: &RoomVersionId, - state_sets: impl IntoIterator, - auth_chain_sets: Vec>, - fetch_event: impl Fn(&EventId) -> Option, + state_sets: impl IntoIterator + Send, + auth_chain_sets: &'a Vec>, + event_fetch: &Fetch, + event_exists: &Exists, ) -> Result> where - E: Event + Clone, - E::Id: 'a, - SetIter: Iterator> + Clone, + Fetch: Fn(E::Id) -> FetchFut + Sync, + FetchFut: Future> + Send, + Exists: Fn(E::Id) -> ExistsFut, + ExistsFut: Future + Send, + SetIter: Iterator> + Clone + Send, + E: Event + Send, + E::Id: Borrow + Send + Sync, + for<'b> &'b E: Send, { debug!("State resolution starting"); @@ -79,13 +86,16 @@ where debug!("conflicting events: {}", conflicting.len()); debug!("{conflicting:?}"); + let auth_chain_diff = + get_auth_chain_diff(&auth_chain_sets).chain(conflicting.into_values().flatten()); + // `all_conflicted` contains unique items // synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}` - let all_conflicted: HashSet<_> = get_auth_chain_diff(auth_chain_sets) - .chain(conflicting.into_values().flatten()) + let all_conflicted: HashSet = stream::iter(auth_chain_diff) // Don't honor events we cannot "verify" - .filter(|id| fetch_event(id.borrow()).is_some()) - .collect(); + .filter(|id| event_exists(id.clone())) + .collect() + .await; debug!("full conflicted set: {}", all_conflicted.len()); debug!("{all_conflicted:?}"); @@ -94,15 +104,15 @@ where // this is now a check the caller of `resolve` must make. // Get only the control events with a state_key: "" or ban/kick event (sender != state_key) - let control_events = all_conflicted - .iter() - .filter(|&id| is_power_event_id(id.borrow(), &fetch_event)) - .cloned() - .collect::>(); + let control_events = stream::iter(all_conflicted.iter()) + .filter(|&id| is_power_event_id(id, &event_fetch)) + .map(Clone::clone) + .collect::>() + .await; // Sort the control events based on power_level/clock/event_id and outgoing/incoming edges let sorted_control_levels = - reverse_topological_power_sort(control_events, &all_conflicted, &fetch_event)?; + reverse_topological_power_sort(control_events, &all_conflicted, &event_fetch).await?; debug!("sorted control events: {}", sorted_control_levels.len()); trace!("{sorted_control_levels:?}"); @@ -110,7 +120,8 @@ where let room_version = RoomVersion::new(room_version)?; // Sequentially auth check each control event. let resolved_control = - iterative_auth_check(&room_version, &sorted_control_levels, clean.clone(), &fetch_event)?; + iterative_auth_check(&room_version, &sorted_control_levels, clean.clone(), &event_fetch) + .await?; debug!("resolved control events: {}", resolved_control.len()); trace!("{resolved_control:?}"); @@ -135,7 +146,8 @@ where debug!("power event: {power_event:?}"); - let sorted_left_events = mainline_sort(&events_to_resolve, power_event.cloned(), &fetch_event)?; + let sorted_left_events = + mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).await?; trace!("events left, sorted: {sorted_left_events:?}"); @@ -143,8 +155,9 @@ where &room_version, &sorted_left_events, resolved_control, // The control events are added to the final resolved state - &fetch_event, - )?; + &event_fetch, + ) + .await?; // Add unconflicted state to the resolved state // We priorities the unconflicting state @@ -188,15 +201,14 @@ where } /// Returns a Vec of deduped EventIds that appear in some chains but not others. -fn get_auth_chain_diff(auth_chain_sets: Vec>) -> impl Iterator +fn get_auth_chain_diff(auth_chain_sets: &Vec>) -> impl Iterator where - Id: Eq + Hash, + Id: Clone + Eq + Hash, { let num_sets = auth_chain_sets.len(); - let mut id_counts: HashMap = HashMap::new(); for id in auth_chain_sets.into_iter().flatten() { - *id_counts.entry(id).or_default() += 1; + *id_counts.entry(id.clone()).or_default() += 1; } id_counts.into_iter().filter_map(move |(id, count)| (count < num_sets).then_some(id)) @@ -209,16 +221,22 @@ where /// /// The power level is negative because a higher power level is equated to an earlier (further back /// in time) origin server timestamp. -fn reverse_topological_power_sort( +async fn reverse_topological_power_sort( events_to_sort: Vec, auth_diff: &HashSet, - fetch_event: impl Fn(&EventId) -> Option, -) -> Result> { + fetch_event: &F, +) -> Result> +where + F: Fn(E::Id) -> Fut + Sync, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + Send + Sync, +{ debug!("reverse topological sort of power events"); let mut graph = HashMap::new(); for event_id in events_to_sort { - add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, &fetch_event); + add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, fetch_event).await; // TODO: if these functions are ever made async here // is a good place to yield every once in a while so other @@ -228,7 +246,7 @@ fn reverse_topological_power_sort( // This is used in the `key_fn` passed to the lexico_topo_sort fn let mut event_to_pl = HashMap::new(); for event_id in graph.keys() { - let pl = get_power_level_for_sender(event_id.borrow(), &fetch_event)?; + let pl = get_power_level_for_sender(event_id, fetch_event).await?; debug!("{event_id} power level {pl}"); event_to_pl.insert(event_id.clone(), pl); @@ -238,26 +256,30 @@ fn reverse_topological_power_sort( // tasks can make progress } - lexicographical_topological_sort(&graph, |event_id| { - let ev = fetch_event(event_id).ok_or_else(|| Error::NotFound("".into()))?; - let pl = *event_to_pl.get(event_id).ok_or_else(|| Error::NotFound("".into()))?; + let event_to_pl = &event_to_pl; + let fetcher = |event_id: E::Id| async move { + let pl = *event_to_pl.get(event_id.borrow()).ok_or_else(|| Error::NotFound("".into()))?; + let ev = fetch_event(event_id).await.ok_or_else(|| Error::NotFound("".into()))?; Ok((pl, ev.origin_server_ts())) - }) + }; + + lexicographical_topological_sort(&graph, &fetcher).await } /// Sorts the event graph based on number of outgoing/incoming edges. /// /// `key_fn` is used as to obtain the power level and age of an event for breaking ties (together /// with the event ID). -pub fn lexicographical_topological_sort( +pub async fn lexicographical_topological_sort( graph: &HashMap>, - key_fn: F, + key_fn: &F, ) -> Result> where - F: Fn(&EventId) -> Result<(Int, MilliSecondsSinceUnixEpoch)>, - Id: Clone + Eq + Ord + Hash + Borrow, + F: Fn(Id) -> Fut, + Fut: Future> + Send, + Id: Borrow + Clone + Eq + Hash + Ord + Send, { - #[derive(PartialEq, Eq, PartialOrd, Ord)] + #[derive(Eq, Ord, PartialEq, PartialOrd)] struct TieBreaker<'a, Id> { inv_power_level: Int, age: MilliSecondsSinceUnixEpoch, @@ -285,7 +307,7 @@ where for (node, edges) in graph { if edges.is_empty() { - let (power_level, age) = key_fn(node.borrow())?; + let (power_level, age) = key_fn(node.clone()).await?; // The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need // smallest -> largest zero_outdegree.push(Reverse(TieBreaker { @@ -318,7 +340,7 @@ where // Only push on the heap once older events have been cleared out.remove(node.borrow()); if out.is_empty() { - let (power_level, age) = key_fn(node.borrow())?; + let (power_level, age) = key_fn(node.clone()).await?; heap.push(Reverse(TieBreaker { inv_power_level: -power_level, age, @@ -339,17 +361,23 @@ where /// Do NOT use this any where but topological sort, we find the power level for the eventId /// at the eventId's generation (we walk backwards to `EventId`s most recent previous power level /// event). -fn get_power_level_for_sender( - event_id: &EventId, - fetch_event: impl Fn(&EventId) -> Option, -) -> serde_json::Result { +async fn get_power_level_for_sender( + event_id: &E::Id, + fetch_event: &F, +) -> serde_json::Result +where + F: Fn(E::Id) -> Fut, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + Send, +{ debug!("fetch event ({event_id}) senders power level"); - let event = fetch_event(event_id); + let event = fetch_event(event_id.clone()).await; let mut pl = None; for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() { - if let Some(aev) = fetch_event(aid.borrow()) { + if let Some(aev) = fetch_event(aid.clone()).await { if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") { pl = Some(aev); break; @@ -381,12 +409,19 @@ fn get_power_level_for_sender( /// /// For each `events_to_check` event we gather the events needed to auth it from the the /// `fetch_event` closure and verify each event using the `event_auth::auth_check` function. -fn iterative_auth_check( +async fn iterative_auth_check( room_version: &RoomVersion, events_to_check: &[E::Id], unconflicted_state: StateMap, - fetch_event: impl Fn(&EventId) -> Option, -) -> Result> { + fetch_event: &F, +) -> Result> +where + F: Fn(E::Id) -> Fut, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + Clone + Send, + for<'a> &'a E: Send, +{ debug!("starting iterative auth check"); debug!("performing auth checks on {events_to_check:?}"); @@ -394,7 +429,8 @@ fn iterative_auth_check( let mut resolved_state = unconflicted_state; for event_id in events_to_check { - let event = fetch_event(event_id.borrow()) + let event = fetch_event(event_id.clone()) + .await .ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))?; let state_key = event .state_key() @@ -402,7 +438,7 @@ fn iterative_auth_check( let mut auth_events = StateMap::new(); for aid in event.auth_events() { - if let Some(ev) = fetch_event(aid.borrow()) { + if let Some(ev) = fetch_event(aid.clone()).await { // TODO synapse check "rejected_reason" which is most likely // related to soft-failing auth_events.insert( @@ -423,7 +459,7 @@ fn iterative_auth_check( event.content(), )? { if let Some(ev_id) = resolved_state.get(&key) { - if let Some(event) = fetch_event(ev_id.borrow()) { + if let Some(event) = fetch_event(ev_id.clone()).await { // TODO synapse checks `rejected_reason` is None here auth_events.insert(key.to_owned(), event); } @@ -438,9 +474,11 @@ fn iterative_auth_check( (*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu) }); - if auth_check(room_version, &event, current_third_party, |ty, key| { - auth_events.get(&ty.with_state_key(key)) - })? { + let fetch_state = |ty: &StateEventType, key: &str| { + future::ready(auth_events.get(&ty.with_state_key(key))) + }; + + if auth_check(room_version, &event, current_third_party, fetch_state).await? { // add event to resolved state map resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone()); } else { @@ -462,11 +500,17 @@ fn iterative_auth_check( /// power_level event. If there have been two power events the after the most recent are depth 0, /// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is /// "older" than depth 0. -fn mainline_sort( +async fn mainline_sort( to_sort: &[E::Id], resolved_power_level: Option, - fetch_event: impl Fn(&EventId) -> Option, -) -> Result> { + fetch_event: &F, +) -> Result> +where + F: Fn(E::Id) -> Fut, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + Clone + Send, +{ debug!("mainline sort of events"); // There are no EventId's to sort, bail. @@ -479,11 +523,13 @@ fn mainline_sort( while let Some(p) = pl { mainline.push(p.clone()); - let event = fetch_event(p.borrow()) + let event = fetch_event(p.clone()) + .await .ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?; pl = None; for aid in event.auth_events() { - let ev = fetch_event(aid.borrow()) + let ev = fetch_event(aid.clone()) + .await .ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?; if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") { pl = Some(aid.to_owned()); @@ -504,11 +550,15 @@ fn mainline_sort( let mut order_map = HashMap::new(); for ev_id in to_sort.iter() { - if let Some(event) = fetch_event(ev_id.borrow()) { - if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, &fetch_event) { + if let Some(event) = fetch_event(ev_id.clone()).await { + if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, fetch_event).await { order_map.insert( ev_id, - (depth, fetch_event(ev_id.borrow()).map(|ev| ev.origin_server_ts()), ev_id), + ( + depth, + fetch_event(ev_id.clone()).await.map(|ev| ev.origin_server_ts()), + ev_id, + ), ); } } @@ -528,11 +578,17 @@ fn mainline_sort( /// Get the mainline depth from the `mainline_map` or finds a power_level event that has an /// associated mainline depth. -fn get_mainline_depth( +async fn get_mainline_depth( mut event: Option, mainline_map: &HashMap, - fetch_event: impl Fn(&EventId) -> Option, -) -> Result { + fetch_event: &F, +) -> Result +where + F: Fn(E::Id) -> Fut, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + Send, +{ while let Some(sort_ev) = event { debug!("mainline event_id {}", sort_ev.event_id()); let id = sort_ev.event_id(); @@ -542,7 +598,8 @@ fn get_mainline_depth( event = None; for aid in sort_ev.auth_events() { - let aev = fetch_event(aid.borrow()) + let aev = fetch_event(aid.clone()) + .await .ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?; if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") { event = Some(aev); @@ -554,18 +611,23 @@ fn get_mainline_depth( Ok(0) } -fn add_event_and_auth_chain_to_graph( +async fn add_event_and_auth_chain_to_graph( graph: &mut HashMap>, event_id: E::Id, auth_diff: &HashSet, - fetch_event: impl Fn(&EventId) -> Option, -) { + fetch_event: &F, +) where + F: Fn(E::Id) -> Fut, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + Clone + Send, +{ let mut state = vec![event_id]; while let Some(eid) = state.pop() { graph.entry(eid.clone()).or_default(); // Prefer the store to event as the store filters dedups the events for aid in - fetch_event(eid.borrow()).as_ref().map(|ev| ev.auth_events()).into_iter().flatten() + fetch_event(eid.clone()).await.as_ref().map(|ev| ev.auth_events()).into_iter().flatten() { if auth_diff.contains(aid.borrow()) { if !graph.contains_key(aid.borrow()) { @@ -579,8 +641,14 @@ fn add_event_and_auth_chain_to_graph( } } -fn is_power_event_id(event_id: &EventId, fetch: impl Fn(&EventId) -> Option) -> bool { - match fetch(event_id).as_ref() { +async fn is_power_event_id(event_id: &E::Id, fetch: &F) -> bool +where + F: Fn(E::Id) -> Fut, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + Send, +{ + match fetch(event_id.clone()).await.as_ref() { Some(state) => is_power_event(state), _ => false, } @@ -609,7 +677,7 @@ fn is_power_event(event: impl Event) -> bool { } /// Convenience trait for adding event type plus state key to state maps. -trait EventTypeExt { +pub trait EventTypeExt { fn with_state_key(self, state_key: impl Into) -> (StateEventType, String); } @@ -662,7 +730,9 @@ mod tests { Event, EventTypeExt, StateMap, }; - fn test_event_sort() { + async fn test_event_sort() { + use futures_util::future::ready; + let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let events = INITIAL_EVENTS(); @@ -680,18 +750,19 @@ mod tests { .map(|pdu| pdu.event_id.clone()) .collect::>(); + let fetcher = |id| ready(events.get(&id).cloned()); let sorted_power_events = - crate::reverse_topological_power_sort(power_events, &auth_chain, |id| { - events.get(id).cloned() - }) - .unwrap(); + crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher) + .await + .unwrap(); let resolved_power = crate::iterative_auth_check( &RoomVersion::V6, &sorted_power_events, HashMap::new(), // unconflicted events - |id| events.get(id).cloned(), + &fetcher, ) + .await .expect("iterative auth check failed on resolved events"); // don't remove any events so we know it sorts them all correctly @@ -703,8 +774,7 @@ mod tests { resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned(); let sorted_event_ids = - crate::mainline_sort(&events_to_sort, power_level, |id| events.get(id).cloned()) - .unwrap(); + crate::mainline_sort(&events_to_sort, power_level, &fetcher).await.unwrap(); assert_eq!( vec![ @@ -721,17 +791,17 @@ mod tests { ); } - #[test] - fn test_sort() { + #[tokio::test] + async fn test_sort() { for _ in 0..20 { // since we shuffle the eventIds before we sort them introducing randomness // seems like we should test this a few times - test_event_sort(); + test_event_sort().await; } } - #[test] - fn ban_vs_power_level() { + #[tokio::test] + async fn ban_vs_power_level() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); @@ -774,11 +844,11 @@ mod tests { let expected_state_ids = vec!["PA", "MA", "MB"].into_iter().map(event_id).collect::>(); - do_check(events, edges, expected_state_ids); + do_check(events, edges, expected_state_ids).await; } - #[test] - fn topic_basic() { + #[tokio::test] + async fn topic_basic() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); @@ -835,11 +905,11 @@ mod tests { let expected_state_ids = vec!["PA2", "T2"].into_iter().map(event_id).collect::>(); - do_check(events, edges, expected_state_ids); + do_check(events, edges, expected_state_ids).await; } - #[test] - fn topic_reset() { + #[tokio::test] + async fn topic_reset() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); @@ -882,11 +952,11 @@ mod tests { let expected_state_ids = vec!["T1", "MB", "PA"].into_iter().map(event_id).collect::>(); - do_check(events, edges, expected_state_ids); + do_check(events, edges, expected_state_ids).await; } - #[test] - fn join_rule_evasion() { + #[tokio::test] + async fn join_rule_evasion() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); @@ -914,11 +984,11 @@ mod tests { let expected_state_ids = vec![event_id("JR")]; - do_check(events, edges, expected_state_ids); + do_check(events, edges, expected_state_ids).await; } - #[test] - fn offtopic_power_level() { + #[tokio::test] + async fn offtopic_power_level() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); @@ -955,11 +1025,11 @@ mod tests { let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::>(); - do_check(events, edges, expected_state_ids); + do_check(events, edges, expected_state_ids).await; } - #[test] - fn topic_setting() { + #[tokio::test] + async fn topic_setting() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); @@ -1032,11 +1102,13 @@ mod tests { let expected_state_ids = vec!["T4", "PA2"].into_iter().map(event_id).collect::>(); - do_check(events, edges, expected_state_ids); + do_check(events, edges, expected_state_ids).await; } - #[test] - fn test_event_map_none() { + #[tokio::test] + async fn test_event_map_none() { + use futures_util::future::ready; + let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); @@ -1046,27 +1118,29 @@ mod tests { let (state_at_bob, state_at_charlie, expected) = store.set_up(); let ev_map = store.0.clone(); + let fetcher = |id| ready(ev_map.get(&id).cloned()); + + let exists = |id: ::Id| ready(ev_map.get(&*id).is_some()); + let state_sets = [state_at_bob, state_at_charlie]; - let resolved = match crate::resolve( - &RoomVersionId::V2, - &state_sets, - state_sets - .iter() - .map(|map| { - store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap() - }) - .collect(), - |id| ev_map.get(id).cloned(), - ) { - Ok(state) => state, - Err(e) => panic!("{e}"), - }; + let auth_chain = state_sets + .iter() + .map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()) + .collect(); + + let resolved = + match crate::resolve(&RoomVersionId::V2, &state_sets, &auth_chain, &fetcher, &exists) + .await + { + Ok(state) => state, + Err(e) => panic!("{e}"), + }; assert_eq!(expected, resolved); } - #[test] - fn test_lexicographical_sort() { + #[tokio::test] + async fn test_lexicographical_sort() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); @@ -1078,9 +1152,10 @@ mod tests { event_id("p") => hashset![event_id("o")], }; - let res = crate::lexicographical_topological_sort(&graph, |_id| { + let res = crate::lexicographical_topological_sort(&graph, &|_id| async { Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) }) + .await .unwrap(); assert_eq!( @@ -1092,8 +1167,8 @@ mod tests { ); } - #[test] - fn ban_with_auth_chains() { + #[tokio::test] + async fn ban_with_auth_chains() { let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let ban = BAN_STATE_SET(); @@ -1105,11 +1180,13 @@ mod tests { let expected_state_ids = vec!["PA", "MB"].into_iter().map(event_id).collect::>(); - do_check(&ban.values().cloned().collect::>(), edges, expected_state_ids); + do_check(&ban.values().cloned().collect::>(), edges, expected_state_ids).await; } - #[test] - fn ban_with_auth_chains2() { + #[tokio::test] + async fn ban_with_auth_chains2() { + use futures_util::future::ready; + let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); let init = INITIAL_EVENTS(); @@ -1147,20 +1224,20 @@ mod tests { let ev_map = &store.0; let state_sets = [state_set_a, state_set_b]; - let resolved = match crate::resolve( - &RoomVersionId::V6, - &state_sets, - state_sets - .iter() - .map(|map| { - store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap() - }) - .collect(), - |id| ev_map.get(id).cloned(), - ) { - Ok(state) => state, - Err(e) => panic!("{e}"), - }; + let auth_chain = state_sets + .iter() + .map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()) + .collect(); + + let fetcher = |id: ::Id| ready(ev_map.get(&id).cloned()); + let exists = |id: ::Id| ready(ev_map.get(&id).is_some()); + let resolved = + match crate::resolve(&RoomVersionId::V6, &state_sets, &auth_chain, &fetcher, &exists) + .await + { + Ok(state) => state, + Err(e) => panic!("{e}"), + }; debug!( "{:#?}", @@ -1180,8 +1257,8 @@ mod tests { assert_eq!(expected.len(), resolved.len()); } - #[test] - fn join_rule_with_auth_chain() { + #[tokio::test] + async fn join_rule_with_auth_chain() { let join_rule = JOIN_RULE(); let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]] @@ -1191,7 +1268,7 @@ mod tests { let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::>(); - do_check(&join_rule.values().cloned().collect::>(), edges, expected_state_ids); + do_check(&join_rule.values().cloned().collect::>(), edges, expected_state_ids).await; } #[allow(non_snake_case)] diff --git a/crates/ruma-state-res/src/test_utils.rs b/crates/ruma-state-res/src/test_utils.rs index 6b06dc37..4ae23580 100644 --- a/crates/ruma-state-res/src/test_utils.rs +++ b/crates/ruma-state-res/src/test_utils.rs @@ -7,6 +7,7 @@ use std::{ }, }; +use futures_util::future::ready; use js_int::{int, uint}; use ruma_common::{ event_id, room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, @@ -31,7 +32,7 @@ use crate::{auth_types_for_event, Error, Event, EventTypeExt, Result, StateMap}; static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0); -pub(crate) fn do_check( +pub(crate) async fn do_check( events: &[Arc], edges: Vec>, expected_state_ids: Vec, @@ -81,9 +82,10 @@ pub(crate) fn do_check( // Resolve the current state and add it to the state_at_event map then continue // on in "time" - for node in crate::lexicographical_topological_sort(&graph, |_id| { + for node in crate::lexicographical_topological_sort(&graph, &|_id| async { Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) }) + .await .unwrap() { let fake_event = fake_event_map.get(&node).unwrap(); @@ -117,9 +119,13 @@ pub(crate) fn do_check( }) .collect(); - let resolved = crate::resolve(&RoomVersionId::V6, state_sets, auth_chain_sets, |id| { - event_map.get(id).cloned() - }); + let event_map = &event_map; + let fetch = |id: ::Id| ready(event_map.get(&id).cloned()); + let exists = |id: ::Id| ready(event_map.get(&id).is_some()); + let resolved = + crate::resolve(&RoomVersionId::V6, state_sets, &auth_chain_sets, &fetch, &exists) + .await; + match resolved { Ok(state) => state, Err(e) => panic!("resolution for {node} failed: {e}"), @@ -614,7 +620,7 @@ pub(crate) mod event { } } - fn prev_events(&self) -> Box + '_> { + fn prev_events(&self) -> Box + Send + '_> { match &self.rest { Pdu::RoomV1Pdu(ev) => Box::new(ev.prev_events.iter().map(|(id, _)| id)), Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter()), @@ -623,7 +629,7 @@ pub(crate) mod event { } } - fn auth_events(&self) -> Box + '_> { + fn auth_events(&self) -> Box + Send + '_> { match &self.rest { Pdu::RoomV1Pdu(ev) => Box::new(ev.auth_events.iter().map(|(id, _)| id)), Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter()),