From 1a550585bf025cce48ef8b734339245092bc986e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 3 Dec 2024 23:03:23 +0000 Subject: [PATCH] state-res: parallelize fetches within some loops Signed-off-by: Jason Volk --- crates/ruma-state-res/src/lib.rs | 170 ++++++++++++++++--------------- 1 file changed, 90 insertions(+), 80 deletions(-) diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 99a8dac3..92ad02c3 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -5,7 +5,7 @@ use std::{ hash::Hash, }; -use futures_util::{future, stream, Future, StreamExt}; +use futures_util::{future, stream, Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use js_int::{int, Int}; use ruma_common::{EventId, MilliSecondsSinceUnixEpoch, RoomVersionId}; use ruma_events::{ @@ -30,7 +30,14 @@ pub use room_version::RoomVersion; pub use state_event::Event; /// A mapping of event type and state_key to some value `T`, usually an `EventId`. -pub type StateMap = HashMap<(StateEventType, String), T>; +pub type StateMap = HashMap; +pub type StateMapItem = (TypeStateKey, T); +pub type TypeStateKey = (StateEventType, String); + +/// Limit the number of asynchronous fetch requests in-flight for any given operation. This is a +/// local maximum which could be multiplied over several macro-operations, therefor the total number +/// of requests demanded from the callbacks could be far greater. +const PARALLEL_FETCHES: usize = 16; /// Resolve sets of state events as they come in. /// @@ -63,10 +70,10 @@ pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>( where Fetch: Fn(E::Id) -> FetchFut + Sync, FetchFut: Future> + Send, - Exists: Fn(E::Id) -> ExistsFut, + Exists: Fn(E::Id) -> ExistsFut + Sync, ExistsFut: Future + Send, SetIter: Iterator> + Clone + Send, - E: Event + Send, + E: Event + Clone + Send + Sync, E::Id: Borrow + Send + Sync, for<'b> &'b E: Send, { @@ -91,9 +98,11 @@ where // `all_conflicted` contains unique items // synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}` - let all_conflicted: HashSet = stream::iter(auth_chain_diff) + let all_conflicted: HashSet<_> = stream::iter(auth_chain_diff) // Don't honor events we cannot "verify" - .filter(|id| event_exists(id.clone())) + .map(|id| event_exists(id.clone()).map(move |exists| (id, exists))) + .buffer_unordered(PARALLEL_FETCHES) + .filter_map(|(id, exists)| future::ready(exists.then_some(id.clone()))) .collect() .await; @@ -104,10 +113,12 @@ where // this is now a check the caller of `resolve` must make. // Get only the control events with a state_key: "" or ban/kick event (sender != state_key) - let control_events = stream::iter(all_conflicted.iter()) - .filter(|&id| is_power_event_id(id, &event_fetch)) - .map(Clone::clone) - .collect::>() + let control_events: Vec<_> = stream::iter(all_conflicted.iter()) + .map(|id| is_power_event_id(id, &event_fetch).map(move |is| (id, is))) + .buffer_unordered(PARALLEL_FETCHES) + .filter_map(|(id, is)| future::ready(is.then_some(id.clone()))) + .collect() + .boxed() .await; // Sort the control events based on power_level/clock/event_id and outgoing/incoming edges @@ -209,9 +220,9 @@ where } /// Returns a Vec of deduped EventIds that appear in some chains but not others. -fn get_auth_chain_diff(auth_chain_sets: &Vec>) -> impl Iterator +fn get_auth_chain_diff(auth_chain_sets: &Vec>) -> impl Iterator + Send where - Id: Clone + Eq + Hash, + Id: Clone + Eq + Hash + Send, { let num_sets = auth_chain_sets.len(); let mut id_counts: HashMap = HashMap::new(); @@ -246,28 +257,27 @@ where let mut graph = HashMap::new(); for event_id in events_to_sort { add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, fetch_event).await; - - // TODO: if these functions are ever made async here - // is a good place to yield every once in a while so other - // tasks can make progress } // This is used in the `key_fn` passed to the lexico_topo_sort fn - let mut event_to_pl = HashMap::new(); - for event_id in graph.keys() { - let pl = get_power_level_for_sender(event_id, fetch_event).await?; - debug!( - event_id = event_id.borrow().as_str(), - power_level = i64::from(pl), - "found the power level of an event's sender", - ); + let event_to_pl = stream::iter(graph.keys()) + .map(|event_id| { + get_power_level_for_sender(event_id.clone(), fetch_event) + .map(move |res| res.map(|pl| (event_id, pl))) + }) + .buffer_unordered(PARALLEL_FETCHES) + .try_fold(HashMap::new(), |mut event_to_pl, (event_id, pl)| { + debug!( + event_id = event_id.borrow().as_str(), + power_level = i64::from(pl), + "found the power level of an event's sender", + ); - event_to_pl.insert(event_id.clone(), pl); - - // TODO: if these functions are ever made async here - // is a good place to yield every once in a while so other - // tasks can make progress - } + event_to_pl.insert(event_id.clone(), pl); + future::ok(event_to_pl) + }) + .boxed() + .await?; let event_to_pl = &event_to_pl; let fetcher = |event_id: E::Id| async move { @@ -289,7 +299,7 @@ pub async fn lexicographical_topological_sort( key_fn: &F, ) -> Result> where - F: Fn(Id) -> Fut, + F: Fn(Id) -> Fut + Sync, Fut: Future> + Send, Id: Borrow + Clone + Eq + Hash + Ord + Send, { @@ -402,11 +412,11 @@ where /// at the eventId's generation (we walk backwards to `EventId`s most recent previous power level /// event). async fn get_power_level_for_sender( - event_id: &E::Id, + event_id: E::Id, fetch_event: &F, ) -> serde_json::Result where - F: Fn(E::Id) -> Fut, + F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, E: Event + Send, E::Id: Borrow + Send, @@ -456,9 +466,9 @@ async fn iterative_auth_check( fetch_event: &F, ) -> Result> where - F: Fn(E::Id) -> Fut, + F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, - E: Event + Send, + E: Event + Send + Sync, E::Id: Borrow + Clone + Send, for<'a> &'a E: Send, { @@ -479,8 +489,8 @@ where let mut auth_events = StateMap::new(); for aid in event.auth_events() { if let Some(ev) = fetch_event(aid.clone()).await { - // TODO synapse check "rejected_reason" which is most likely - // related to soft-failing + //TODO: synapse checks "rejected_reason" which is most likely related to + // soft-failing auth_events.insert( ev.event_type().with_state_key(ev.state_key().ok_or_else(|| { Error::InvalidPdu("State event had no state key".to_owned()) @@ -492,19 +502,26 @@ where } } - for key in auth_types_for_event( + let auth_types = auth_types_for_event( event.event_type(), event.sender(), Some(state_key), event.content(), - )? { - if let Some(ev_id) = resolved_state.get(&key) { - if let Some(event) = fetch_event(ev_id.clone()).await { - // TODO synapse checks `rejected_reason` is None here - auth_events.insert(key.to_owned(), event); - } - } - } + )?; + + let auth_types = + auth_types.iter().filter_map(|key| Some((key, resolved_state.get(key)?))).into_iter(); + + stream::iter(auth_types) + .filter_map(|(key, ev_id)| { + fetch_event(ev_id.clone()).map(move |event| event.map(|event| (key, event))) + }) + .for_each(|(key, event)| { + //TODO: synapse checks "rejected_reason" is None here + auth_events.insert(key.to_owned(), event); + future::ready(()) + }) + .await; debug!("event to check {:?}", event.event_id()); @@ -525,11 +542,8 @@ where // synapse passes here on AuthError. We do not add this event to resolved_state. warn!("event {event_id} failed the authentication check"); } - - // TODO: if these functions are ever made async here - // is a good place to yield every once in a while so other - // tasks can make progress } + Ok(resolved_state) } @@ -546,10 +560,10 @@ async fn mainline_sort( fetch_event: &F, ) -> Result> where - F: Fn(E::Id) -> Fut, + F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, - E: Event + Send, - E::Id: Borrow + Clone + Send, + E: Event + Clone + Send + Sync, + E::Id: Borrow + Clone + Send + Sync, { debug!("mainline sort of events"); @@ -576,9 +590,6 @@ where break; } } - // TODO: if these functions are ever made async here - // is a good place to yield every once in a while so other - // tasks can make progress } let mainline_map = mainline @@ -588,25 +599,23 @@ where .map(|(idx, eid)| ((*eid).clone(), idx)) .collect::>(); - let mut order_map = HashMap::new(); - for ev_id in to_sort.iter() { - if let Some(event) = fetch_event(ev_id.clone()).await { - if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, fetch_event).await { - order_map.insert( - ev_id, - ( - depth, - fetch_event(ev_id.clone()).await.map(|ev| ev.origin_server_ts()), - ev_id, - ), - ); - } - } - - // TODO: if these functions are ever made async here - // is a good place to yield every once in a while so other - // tasks can make progress - } + let order_map = stream::iter(to_sort.into_iter()) + .map(|ev_id| fetch_event(ev_id.clone()).map(move |event| event.map(|event| (event, ev_id)))) + .buffer_unordered(PARALLEL_FETCHES) + .filter_map(|result| future::ready(result)) + .map(|(event, ev_id)| { + get_mainline_depth(Some(event.clone()), &mainline_map, fetch_event) + .map_ok(move |depth| (depth, event, ev_id)) + .map(Result::ok) + }) + .buffer_unordered(PARALLEL_FETCHES) + .filter_map(|result| future::ready(result)) + .fold(HashMap::new(), |mut order_map, (depth, event, ev_id)| { + order_map.insert(ev_id, (depth, event.origin_server_ts(), ev_id)); + future::ready(order_map) + }) + .boxed() + .await; // Sort the event_ids by their depth, timestamp and EventId // unwrap is OK order map and sort_event_ids are from to_sort (the same Vec) @@ -624,7 +633,7 @@ async fn get_mainline_depth( fetch_event: &F, ) -> Result where - F: Fn(E::Id) -> Fut, + F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, E: Event + Send, E::Id: Borrow + Send, @@ -665,10 +674,11 @@ async fn add_event_and_auth_chain_to_graph( let mut state = vec![event_id]; while let Some(eid) = state.pop() { graph.entry(eid.clone()).or_default(); + let event = fetch_event(eid.clone()).await; + let auth_events = event.as_ref().map(|ev| ev.auth_events()).into_iter().flatten(); + // Prefer the store to event as the store filters dedups the events - for aid in - fetch_event(eid.clone()).await.as_ref().map(|ev| ev.auth_events()).into_iter().flatten() - { + for aid in auth_events { if auth_diff.contains(aid.borrow()) { if !graph.contains_key(aid.borrow()) { state.push(aid.to_owned()); @@ -683,7 +693,7 @@ async fn add_event_and_auth_chain_to_graph( async fn is_power_event_id(event_id: &E::Id, fetch: &F) -> bool where - F: Fn(E::Id) -> Fut, + F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, E: Event + Send, E::Id: Borrow + Send,