replace constant with function parameter for io-parallelism

Signed-off-by: Jason Volk <jason@zemos.net>
Jason Volk 2024-12-24 13:48:25 +00:00
parent 307186ebdc
commit d3ed3194eb
2 changed files with 70 additions and 38 deletions
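
For orientation, a minimal sketch of a call site after this change, modeled on the test harness in the diff below; the in-memory `event_map`, the closure shapes, and the value 16 (the old constant's value) are illustrative, not prescriptive:

    // Illustrative caller, mirroring the test fetchers in this diff: events are
    // looked up in an in-memory map, and the trailing argument now selects the
    // fetch parallelism that was previously fixed by PARALLEL_FETCHES (16).
    let fetch = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).cloned());
    let exists = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).is_some());

    let resolved = crate::resolve(
        &RoomVersionId::V6,
        state_sets,
        &auth_chain_sets,
        &fetch,
        &exists,
        16, // parallel_fetches
    )
    .await;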


@@ -35,11 +35,6 @@ pub type StateMap<T> = HashMap<TypeStateKey, T>;
 pub type StateMapItem<T> = (TypeStateKey, T);
 pub type TypeStateKey = (StateEventType, String);
 
-/// Limit the number of asynchronous fetch requests in-flight for any given operation. This is a
-/// local maximum which could be multiplied over several macro-operations; therefore the total
-/// number of requests demanded from the callbacks could be far greater.
-const PARALLEL_FETCHES: usize = 16;
-
 /// Resolve sets of state events as they come in.
 ///
 /// Internally `StateResolution` builds a graph and an auth chain to allow for state conflict
@@ -56,6 +51,9 @@ const PARALLEL_FETCHES: usize = 16;
 /// * `event_fetch` - Any event not found in the `event_map` will defer to this closure to find the
 ///   event.
 ///
+/// * `parallel_fetches` - The number of asynchronous fetch requests in-flight for any given
+///   operation.
+///
 /// ## Invariants
 ///
 /// The caller of `resolve` must ensure that all the events are from the same room. Although this
@@ -67,6 +65,7 @@ pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>(
     auth_chain_sets: &'a Vec<HashSet<E::Id>>,
     event_fetch: &Fetch,
     event_exists: &Exists,
+    parallel_fetches: usize,
 ) -> Result<StateMap<E::Id>>
 where
     Fetch: Fn(E::Id) -> FetchFut + Sync,
@@ -102,7 +101,7 @@ where
     let all_conflicted: HashSet<_> = stream::iter(auth_chain_diff)
         // Don't honor events we cannot "verify"
         .map(|id| event_exists(id.clone()).map(move |exists| (id, exists)))
-        .buffer_unordered(PARALLEL_FETCHES)
+        .buffer_unordered(parallel_fetches)
         .filter_map(|(id, exists)| future::ready(exists.then_some(id.clone())))
         .collect()
         .boxed()
@@ -117,17 +116,21 @@ where
     // Get only the control events with a state_key: "" or ban/kick event (sender != state_key)
     let control_events: Vec<_> = stream::iter(all_conflicted.iter())
         .map(|id| is_power_event_id(id, &event_fetch).map(move |is| (id, is)))
-        .buffer_unordered(PARALLEL_FETCHES)
+        .buffer_unordered(parallel_fetches)
         .filter_map(|(id, is)| future::ready(is.then_some(id.clone())))
         .collect()
         .boxed()
         .await;
 
     // Sort the control events based on power_level/clock/event_id and outgoing/incoming edges
-    let sorted_control_levels =
-        reverse_topological_power_sort(control_events, &all_conflicted, &event_fetch)
-            .boxed()
-            .await?;
+    let sorted_control_levels = reverse_topological_power_sort(
+        control_events,
+        &all_conflicted,
+        &event_fetch,
+        parallel_fetches,
+    )
+    .boxed()
+    .await?;
 
     debug!(count = sorted_control_levels.len(), "power events");
     trace!(list = ?sorted_control_levels, "sorted power events");
@@ -139,7 +142,9 @@ where
         sorted_control_levels.iter(),
         clean.clone(),
         &event_fetch,
+        parallel_fetches,
     )
+    .boxed()
     .await?;
 
     debug!(count = resolved_control.len(), "resolved power events");
@@ -166,7 +171,9 @@ where
     debug!(event_id = ?power_event, "power event");
 
     let sorted_left_events =
-        mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).boxed().await?;
+        mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch, parallel_fetches)
+            .boxed()
+            .await?;
 
     trace!(list = ?sorted_left_events, "events left, sorted");
@@ -175,7 +182,9 @@ where
         sorted_left_events.iter(),
         resolved_control, // The control events are added to the final resolved state
         &event_fetch,
+        parallel_fetches,
     )
+    .boxed()
     .await?;
 
     // Add unconflicted state to the resolved state
@@ -253,6 +262,7 @@ async fn reverse_topological_power_sort<E, F, Fut>(
     events_to_sort: Vec<E::Id>,
     auth_diff: &HashSet<E::Id>,
     fetch_event: &F,
+    parallel_fetches: usize,
 ) -> Result<Vec<E::Id>>
 where
     F: Fn(E::Id) -> Fut + Sync,
@@ -270,10 +280,10 @@ where
     // This is used in the `key_fn` passed to the lexico_topo_sort fn
     let event_to_pl = stream::iter(graph.keys())
         .map(|event_id| {
-            get_power_level_for_sender(event_id.clone(), fetch_event)
+            get_power_level_for_sender(event_id.clone(), fetch_event, parallel_fetches)
                 .map(move |res| res.map(|pl| (event_id, pl)))
         })
-        .buffer_unordered(PARALLEL_FETCHES)
+        .buffer_unordered(parallel_fetches)
         .try_fold(HashMap::new(), |mut event_to_pl, (event_id, pl)| {
             debug!(
                 event_id = event_id.borrow().as_str(),
@@ -422,6 +432,7 @@ where
 async fn get_power_level_for_sender<E, F, Fut>(
     event_id: E::Id,
     fetch_event: &F,
+    parallel_fetches: usize,
 ) -> serde_json::Result<Int>
 where
     F: Fn(E::Id) -> Fut + Sync,
@@ -437,7 +448,7 @@ where
     let pl = stream::iter(auth_events)
         .map(|aid| fetch_event(aid.clone()))
-        .buffer_unordered(PARALLEL_FETCHES.min(5))
+        .buffer_unordered(parallel_fetches.min(5))
         .filter_map(|aev| future::ready(aev))
         .collect::<Vec<_>>()
         .boxed()
@@ -474,6 +485,7 @@ async fn iterative_auth_check<'a, E, F, Fut, I>(
     events_to_check: I,
     unconflicted_state: StateMap<E::Id>,
     fetch_event: &F,
+    parallel_fetches: usize,
 ) -> Result<StateMap<E::Id>>
 where
     F: Fn(E::Id) -> Fut + Sync,
@@ -495,7 +507,7 @@ where
             result.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))
         })
     })
-    .try_buffer_unordered(PARALLEL_FETCHES)
+    .try_buffer_unordered(parallel_fetches)
     .try_collect()
     .boxed()
     .await?;
@@ -508,7 +520,7 @@ where
     let auth_events: HashMap<E::Id, E> = stream::iter(auth_event_ids.into_iter())
         .map(|event_id| fetch_event(event_id))
-        .buffer_unordered(PARALLEL_FETCHES)
+        .buffer_unordered(parallel_fetches)
         .filter_map(|result| future::ready(result))
         .map(|auth_event| (auth_event.event_id().clone(), auth_event))
         .collect()
@@ -597,6 +609,7 @@ async fn mainline_sort<E, F, Fut>(
     to_sort: &[E::Id],
     resolved_power_level: Option<E::Id>,
     fetch_event: &F,
+    parallel_fetches: usize,
 ) -> Result<Vec<E::Id>>
 where
     F: Fn(E::Id) -> Fut + Sync,
@@ -640,14 +653,14 @@ where
     let order_map = stream::iter(to_sort.into_iter())
         .map(|ev_id| fetch_event(ev_id.clone()).map(move |event| event.map(|event| (event, ev_id))))
-        .buffer_unordered(PARALLEL_FETCHES)
+        .buffer_unordered(parallel_fetches)
         .filter_map(|result| future::ready(result))
         .map(|(event, ev_id)| {
             get_mainline_depth(Some(event.clone()), &mainline_map, fetch_event)
                 .map_ok(move |depth| (depth, event, ev_id))
                 .map(Result::ok)
         })
-        .buffer_unordered(PARALLEL_FETCHES)
+        .buffer_unordered(parallel_fetches)
         .filter_map(|result| future::ready(result))
         .fold(HashMap::new(), |mut order_map, (depth, event, ev_id)| {
             order_map.insert(ev_id, (depth, event.origin_server_ts(), ev_id));
@@ -841,7 +854,7 @@ mod tests {
         let fetcher = |id| ready(events.get(&id).cloned());
 
         let sorted_power_events =
-            crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher)
+            crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1)
                 .await
                 .unwrap();
@@ -850,6 +863,7 @@
             sorted_power_events.iter(),
             HashMap::new(), // unconflicted events
             &fetcher,
+            1,
         )
         .await
         .expect("iterative auth check failed on resolved events");
@@ -863,7 +877,7 @@
             resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned();
 
         let sorted_event_ids =
-            crate::mainline_sort(&events_to_sort, power_level, &fetcher).await.unwrap();
+            crate::mainline_sort(&events_to_sort, power_level, &fetcher, 1).await.unwrap();
 
         assert_eq!(
             vec![
@@ -1217,13 +1231,19 @@ mod tests {
             .map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap())
             .collect();
 
-        let resolved =
-            match crate::resolve(&RoomVersionId::V2, &state_sets, &auth_chain, &fetcher, &exists)
-                .await
-            {
-                Ok(state) => state,
-                Err(e) => panic!("{e}"),
-            };
+        let resolved = match crate::resolve(
+            &RoomVersionId::V2,
+            &state_sets,
+            &auth_chain,
+            &fetcher,
+            &exists,
+            1,
+        )
+        .await
+        {
+            Ok(state) => state,
+            Err(e) => panic!("{e}"),
+        };
 
         assert_eq!(expected, resolved);
     }
@@ -1320,13 +1340,19 @@ mod tests {
         let fetcher = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).cloned());
         let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).is_some());
 
-        let resolved =
-            match crate::resolve(&RoomVersionId::V6, &state_sets, &auth_chain, &fetcher, &exists)
-                .await
-            {
-                Ok(state) => state,
-                Err(e) => panic!("{e}"),
-            };
+        let resolved = match crate::resolve(
+            &RoomVersionId::V6,
+            &state_sets,
+            &auth_chain,
+            &fetcher,
+            &exists,
+            1,
+        )
+        .await
+        {
+            Ok(state) => state,
+            Err(e) => panic!("{e}"),
+        };
 
         debug!(
             resolved = ?resolved


@@ -122,9 +122,15 @@ pub(crate) async fn do_check(
     let event_map = &event_map;
     let fetch = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).cloned());
     let exists = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).is_some());
-    let resolved =
-        crate::resolve(&RoomVersionId::V6, state_sets, &auth_chain_sets, &fetch, &exists)
-            .await;
+    let resolved = crate::resolve(
+        &RoomVersionId::V6,
+        state_sets,
+        &auth_chain_sets,
+        &fetch,
+        &exists,
+        1,
+    )
+    .await;
     match resolved {
         Ok(state) => state,