diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 92ad02c3..245179eb 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -2,6 +2,7 @@ use std::{ borrow::Borrow, cmp::{Ordering, Reverse}, collections::{BinaryHeap, HashMap, HashSet}, + fmt::Debug, hash::Hash, }; @@ -104,6 +105,7 @@ where .buffer_unordered(PARALLEL_FETCHES) .filter_map(|(id, exists)| future::ready(exists.then_some(id.clone()))) .collect() + .boxed() .await; debug!(count = all_conflicted.len(), "full conflicted set"); @@ -123,16 +125,22 @@ where // Sort the control events based on power_level/clock/event_id and outgoing/incoming edges let sorted_control_levels = - reverse_topological_power_sort(control_events, &all_conflicted, &event_fetch).await?; + reverse_topological_power_sort(control_events, &all_conflicted, &event_fetch) + .boxed() + .await?; debug!(count = sorted_control_levels.len(), "power events"); trace!(list = ?sorted_control_levels, "sorted power events"); let room_version = RoomVersion::new(room_version)?; // Sequentially auth check each control event. - let resolved_control = - iterative_auth_check(&room_version, &sorted_control_levels, clean.clone(), &event_fetch) - .await?; + let resolved_control = iterative_auth_check( + &room_version, + sorted_control_levels.iter(), + clean.clone(), + &event_fetch, + ) + .await?; debug!(count = resolved_control.len(), "resolved power events"); trace!(map = ?resolved_control, "resolved power events"); @@ -158,13 +166,13 @@ where debug!(event_id = ?power_event, "power event"); let sorted_left_events = - mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).await?; + mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).boxed().await?; trace!(list = ?sorted_left_events, "events left, sorted"); let mut resolved_state = iterative_auth_check( &room_version, - &sorted_left_events, + sorted_left_events.iter(), resolved_control, // The control events are added to the final resolved state &event_fetch, ) @@ -424,16 +432,18 @@ where debug!("fetch event ({event_id}) senders power level"); let event = fetch_event(event_id.clone()).await; - let mut pl = None; - for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() { - if let Some(aev) = fetch_event(aid.clone()).await { - if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") { - pl = Some(aev); - break; - } - } - } + let auth_events = event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten(); + + let pl = stream::iter(auth_events) + .map(|aid| fetch_event(aid.clone())) + .buffer_unordered(PARALLEL_FETCHES.min(5)) + .filter_map(|aev| future::ready(aev)) + .collect::>() + .boxed() + .await + .into_iter() + .find(|aev| is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "")); let content: PowerLevelsContentFields = match pl { None => return Ok(int!(0)), @@ -459,49 +469,60 @@ where /// /// For each `events_to_check` event we gather the events needed to auth it from the the /// `fetch_event` closure and verify each event using the `event_auth::auth_check` function. -async fn iterative_auth_check( +async fn iterative_auth_check<'a, E, F, Fut, I>( room_version: &RoomVersion, - events_to_check: &[E::Id], + events_to_check: I, unconflicted_state: StateMap, fetch_event: &F, ) -> Result> where F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, - E: Event + Send + Sync, - E::Id: Borrow + Clone + Send, - for<'a> &'a E: Send, + E::Id: Borrow + Clone + Eq + Ord + Send + Sync + 'a, + I: Iterator + Debug + Send + 'a, + E: Event + Clone + Send + Sync, { debug!("starting iterative auth check"); + trace!( + list = ?events_to_check, + "events to check" + ); - trace!(list = ?events_to_check, "events to check"); + let events_to_check: Vec<_> = stream::iter(events_to_check) + .map(Result::Ok) + .map_ok(|event_id| { + fetch_event(event_id.clone()).map(move |result| { + result.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}"))) + }) + }) + .try_buffer_unordered(PARALLEL_FETCHES) + .try_collect() + .boxed() + .await?; + let auth_event_ids: HashSet = events_to_check + .iter() + .map(|event: &E| event.auth_events().map(Clone::clone)) + .flatten() + .collect(); + + let auth_events: HashMap = stream::iter(auth_event_ids.into_iter()) + .map(|event_id| fetch_event(event_id)) + .buffer_unordered(PARALLEL_FETCHES) + .filter_map(|result| future::ready(result)) + .map(|auth_event| (auth_event.event_id().clone(), auth_event)) + .collect() + .boxed() + .await; + + let auth_events = &auth_events; let mut resolved_state = unconflicted_state; - - for event_id in events_to_check { - let event = fetch_event(event_id.clone()) - .await - .ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))?; + for event in events_to_check.iter() { + let event_id = event.event_id(); let state_key = event .state_key() .ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; - let mut auth_events = StateMap::new(); - for aid in event.auth_events() { - if let Some(ev) = fetch_event(aid.clone()).await { - //TODO: synapse checks "rejected_reason" which is most likely related to - // soft-failing - auth_events.insert( - ev.event_type().with_state_key(ev.state_key().ok_or_else(|| { - Error::InvalidPdu("State event had no state key".to_owned()) - })?), - ev, - ); - } else { - warn!(event_id = aid.borrow().as_str(), "missing auth event"); - } - } - let auth_types = auth_types_for_event( event.event_type(), event.sender(), @@ -509,33 +530,51 @@ where event.content(), )?; - let auth_types = - auth_types.iter().filter_map(|key| Some((key, resolved_state.get(key)?))).into_iter(); + let mut auth_state = StateMap::new(); + for aid in event.auth_events() { + if let Some(&ref ev) = auth_events.get(aid.borrow()) { + //TODO: synapse checks "rejected_reason" which is most likely related to + // soft-failing + auth_state.insert( + ev.event_type().with_state_key(ev.state_key().ok_or_else(|| { + Error::InvalidPdu("State event had no state key".to_owned()) + })?), + ev.clone(), + ); + } else { + warn!(event_id = aid.borrow().as_str(), "missing auth event"); + } + } - stream::iter(auth_types) - .filter_map(|(key, ev_id)| { - fetch_event(ev_id.clone()).map(move |event| event.map(|event| (key, event))) - }) - .for_each(|(key, event)| { - //TODO: synapse checks "rejected_reason" is None here - auth_events.insert(key.to_owned(), event); - future::ready(()) - }) - .await; + stream::iter( + auth_types.iter().filter_map(|key| Some((key, resolved_state.get(key)?))).into_iter(), + ) + .filter_map(|(key, ev_id)| async move { + if let Some(event) = auth_events.get(ev_id.borrow()) { + Some((key, event.clone())) + } else { + Some((key, fetch_event(ev_id.clone()).await?.clone())) + } + }) + .for_each(|(key, event)| { + //TODO: synapse checks "rejected_reason" is None here + auth_state.insert(key.to_owned(), event); + future::ready(()) + }) + .await; debug!("event to check {:?}", event.event_id()); // The key for this is (eventType + a state_key of the signed token not sender) so // search for it - let current_third_party = auth_events.iter().find_map(|(_, pdu)| { + let current_third_party = auth_state.iter().find_map(|(_, pdu)| { (*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu) }); - let fetch_state = |ty: &StateEventType, key: &str| { - future::ready(auth_events.get(&ty.with_state_key(key))) - }; + let fetch_state = + |ty: &StateEventType, key: &str| future::ready(auth_state.get(&ty.with_state_key(key))); - if auth_check(room_version, &event, current_third_party, fetch_state).await? { + if auth_check(room_version, &event, current_third_party.as_ref(), fetch_state).await? { // add event to resolved state map resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone()); } else { @@ -808,7 +847,7 @@ mod tests { let resolved_power = crate::iterative_auth_check( &RoomVersion::V6, - &sorted_power_events, + sorted_power_events.iter(), HashMap::new(), // unconflicted events &fetcher, )