diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 245179eb..f226c5e7 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -35,11 +35,6 @@ pub type StateMap = HashMap; pub type StateMapItem = (TypeStateKey, T); pub type TypeStateKey = (StateEventType, String); -/// Limit the number of asynchronous fetch requests in-flight for any given operation. This is a -/// local maximum which could be multiplied over several macro-operations, therefor the total number -/// of requests demanded from the callbacks could be far greater. -const PARALLEL_FETCHES: usize = 16; - /// Resolve sets of state events as they come in. /// /// Internally `StateResolution` builds a graph and an auth chain to allow for state conflict @@ -56,6 +51,9 @@ const PARALLEL_FETCHES: usize = 16; /// * `event_fetch` - Any event not found in the `event_map` will defer to this closure to find the /// event. /// +/// * `parallel_fetches` - The number of asynchronous fetch requests in-flight for any given +/// operation. +/// /// ## Invariants /// /// The caller of `resolve` must ensure that all the events are from the same room. Although this @@ -67,6 +65,7 @@ pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>( auth_chain_sets: &'a Vec>, event_fetch: &Fetch, event_exists: &Exists, + parallel_fetches: usize, ) -> Result> where Fetch: Fn(E::Id) -> FetchFut + Sync, @@ -102,7 +101,7 @@ where let all_conflicted: HashSet<_> = stream::iter(auth_chain_diff) // Don't honor events we cannot "verify" .map(|id| event_exists(id.clone()).map(move |exists| (id, exists))) - .buffer_unordered(PARALLEL_FETCHES) + .buffer_unordered(parallel_fetches) .filter_map(|(id, exists)| future::ready(exists.then_some(id.clone()))) .collect() .boxed() @@ -117,17 +116,21 @@ where // Get only the control events with a state_key: "" or ban/kick event (sender != state_key) let control_events: Vec<_> = stream::iter(all_conflicted.iter()) .map(|id| is_power_event_id(id, &event_fetch).map(move |is| (id, is))) - .buffer_unordered(PARALLEL_FETCHES) + .buffer_unordered(parallel_fetches) .filter_map(|(id, is)| future::ready(is.then_some(id.clone()))) .collect() .boxed() .await; // Sort the control events based on power_level/clock/event_id and outgoing/incoming edges - let sorted_control_levels = - reverse_topological_power_sort(control_events, &all_conflicted, &event_fetch) - .boxed() - .await?; + let sorted_control_levels = reverse_topological_power_sort( + control_events, + &all_conflicted, + &event_fetch, + parallel_fetches, + ) + .boxed() + .await?; debug!(count = sorted_control_levels.len(), "power events"); trace!(list = ?sorted_control_levels, "sorted power events"); @@ -139,7 +142,9 @@ where sorted_control_levels.iter(), clean.clone(), &event_fetch, + parallel_fetches, ) + .boxed() .await?; debug!(count = resolved_control.len(), "resolved power events"); @@ -166,7 +171,9 @@ where debug!(event_id = ?power_event, "power event"); let sorted_left_events = - mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).boxed().await?; + mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch, parallel_fetches) + .boxed() + .await?; trace!(list = ?sorted_left_events, "events left, sorted"); @@ -175,7 +182,9 @@ where sorted_left_events.iter(), resolved_control, // The control events are added to the final resolved state &event_fetch, + parallel_fetches, ) + .boxed() .await?; // Add unconflicted state to the resolved state @@ -253,6 +262,7 @@ async fn reverse_topological_power_sort( events_to_sort: Vec, auth_diff: &HashSet, fetch_event: &F, + parallel_fetches: usize, ) -> Result> where F: Fn(E::Id) -> Fut + Sync, @@ -270,10 +280,10 @@ where // This is used in the `key_fn` passed to the lexico_topo_sort fn let event_to_pl = stream::iter(graph.keys()) .map(|event_id| { - get_power_level_for_sender(event_id.clone(), fetch_event) + get_power_level_for_sender(event_id.clone(), fetch_event, parallel_fetches) .map(move |res| res.map(|pl| (event_id, pl))) }) - .buffer_unordered(PARALLEL_FETCHES) + .buffer_unordered(parallel_fetches) .try_fold(HashMap::new(), |mut event_to_pl, (event_id, pl)| { debug!( event_id = event_id.borrow().as_str(), @@ -422,6 +432,7 @@ where async fn get_power_level_for_sender( event_id: E::Id, fetch_event: &F, + parallel_fetches: usize, ) -> serde_json::Result where F: Fn(E::Id) -> Fut + Sync, @@ -437,7 +448,7 @@ where let pl = stream::iter(auth_events) .map(|aid| fetch_event(aid.clone())) - .buffer_unordered(PARALLEL_FETCHES.min(5)) + .buffer_unordered(parallel_fetches.min(5)) .filter_map(|aev| future::ready(aev)) .collect::>() .boxed() @@ -474,6 +485,7 @@ async fn iterative_auth_check<'a, E, F, Fut, I>( events_to_check: I, unconflicted_state: StateMap, fetch_event: &F, + parallel_fetches: usize, ) -> Result> where F: Fn(E::Id) -> Fut + Sync, @@ -495,7 +507,7 @@ where result.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}"))) }) }) - .try_buffer_unordered(PARALLEL_FETCHES) + .try_buffer_unordered(parallel_fetches) .try_collect() .boxed() .await?; @@ -508,7 +520,7 @@ where let auth_events: HashMap = stream::iter(auth_event_ids.into_iter()) .map(|event_id| fetch_event(event_id)) - .buffer_unordered(PARALLEL_FETCHES) + .buffer_unordered(parallel_fetches) .filter_map(|result| future::ready(result)) .map(|auth_event| (auth_event.event_id().clone(), auth_event)) .collect() @@ -597,6 +609,7 @@ async fn mainline_sort( to_sort: &[E::Id], resolved_power_level: Option, fetch_event: &F, + parallel_fetches: usize, ) -> Result> where F: Fn(E::Id) -> Fut + Sync, @@ -640,14 +653,14 @@ where let order_map = stream::iter(to_sort.into_iter()) .map(|ev_id| fetch_event(ev_id.clone()).map(move |event| event.map(|event| (event, ev_id)))) - .buffer_unordered(PARALLEL_FETCHES) + .buffer_unordered(parallel_fetches) .filter_map(|result| future::ready(result)) .map(|(event, ev_id)| { get_mainline_depth(Some(event.clone()), &mainline_map, fetch_event) .map_ok(move |depth| (depth, event, ev_id)) .map(Result::ok) }) - .buffer_unordered(PARALLEL_FETCHES) + .buffer_unordered(parallel_fetches) .filter_map(|result| future::ready(result)) .fold(HashMap::new(), |mut order_map, (depth, event, ev_id)| { order_map.insert(ev_id, (depth, event.origin_server_ts(), ev_id)); @@ -841,7 +854,7 @@ mod tests { let fetcher = |id| ready(events.get(&id).cloned()); let sorted_power_events = - crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher) + crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1) .await .unwrap(); @@ -850,6 +863,7 @@ mod tests { sorted_power_events.iter(), HashMap::new(), // unconflicted events &fetcher, + 1, ) .await .expect("iterative auth check failed on resolved events"); @@ -863,7 +877,7 @@ mod tests { resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned(); let sorted_event_ids = - crate::mainline_sort(&events_to_sort, power_level, &fetcher).await.unwrap(); + crate::mainline_sort(&events_to_sort, power_level, &fetcher, 1).await.unwrap(); assert_eq!( vec![ @@ -1217,13 +1231,19 @@ mod tests { .map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()) .collect(); - let resolved = - match crate::resolve(&RoomVersionId::V2, &state_sets, &auth_chain, &fetcher, &exists) - .await - { - Ok(state) => state, - Err(e) => panic!("{e}"), - }; + let resolved = match crate::resolve( + &RoomVersionId::V2, + &state_sets, + &auth_chain, + &fetcher, + &exists, + 1, + ) + .await + { + Ok(state) => state, + Err(e) => panic!("{e}"), + }; assert_eq!(expected, resolved); } @@ -1320,13 +1340,19 @@ mod tests { let fetcher = |id: ::Id| ready(ev_map.get(&id).cloned()); let exists = |id: ::Id| ready(ev_map.get(&id).is_some()); - let resolved = - match crate::resolve(&RoomVersionId::V6, &state_sets, &auth_chain, &fetcher, &exists) - .await - { - Ok(state) => state, - Err(e) => panic!("{e}"), - }; + let resolved = match crate::resolve( + &RoomVersionId::V6, + &state_sets, + &auth_chain, + &fetcher, + &exists, + 1, + ) + .await + { + Ok(state) => state, + Err(e) => panic!("{e}"), + }; debug!( resolved = ?resolved diff --git a/crates/ruma-state-res/src/test_utils.rs b/crates/ruma-state-res/src/test_utils.rs index 8a8e00a0..5ce7b6cc 100644 --- a/crates/ruma-state-res/src/test_utils.rs +++ b/crates/ruma-state-res/src/test_utils.rs @@ -122,9 +122,15 @@ pub(crate) async fn do_check( let event_map = &event_map; let fetch = |id: ::Id| ready(event_map.get(&id).cloned()); let exists = |id: ::Id| ready(event_map.get(&id).is_some()); - let resolved = - crate::resolve(&RoomVersionId::V6, state_sets, &auth_chain_sets, &fetch, &exists) - .await; + let resolved = crate::resolve( + &RoomVersionId::V6, + state_sets, + &auth_chain_sets, + &fetch, + &exists, + 1, + ) + .await; match resolved { Ok(state) => state,