optimize IO for iterative_auth_check and get_power_level_for_sender

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-12-24 13:36:45 +00:00
parent 9bdc048cdb
commit 307186ebdc

View File

@ -2,6 +2,7 @@ use std::{
borrow::Borrow, borrow::Borrow,
cmp::{Ordering, Reverse}, cmp::{Ordering, Reverse},
collections::{BinaryHeap, HashMap, HashSet}, collections::{BinaryHeap, HashMap, HashSet},
fmt::Debug,
hash::Hash, hash::Hash,
}; };
@ -104,6 +105,7 @@ where
.buffer_unordered(PARALLEL_FETCHES) .buffer_unordered(PARALLEL_FETCHES)
.filter_map(|(id, exists)| future::ready(exists.then_some(id.clone()))) .filter_map(|(id, exists)| future::ready(exists.then_some(id.clone())))
.collect() .collect()
.boxed()
.await; .await;
debug!(count = all_conflicted.len(), "full conflicted set"); debug!(count = all_conflicted.len(), "full conflicted set");
@ -123,16 +125,22 @@ where
// Sort the control events based on power_level/clock/event_id and outgoing/incoming edges // Sort the control events based on power_level/clock/event_id and outgoing/incoming edges
let sorted_control_levels = let sorted_control_levels =
reverse_topological_power_sort(control_events, &all_conflicted, &event_fetch).await?; reverse_topological_power_sort(control_events, &all_conflicted, &event_fetch)
.boxed()
.await?;
debug!(count = sorted_control_levels.len(), "power events"); debug!(count = sorted_control_levels.len(), "power events");
trace!(list = ?sorted_control_levels, "sorted power events"); trace!(list = ?sorted_control_levels, "sorted power events");
let room_version = RoomVersion::new(room_version)?; let room_version = RoomVersion::new(room_version)?;
// Sequentially auth check each control event. // Sequentially auth check each control event.
let resolved_control = let resolved_control = iterative_auth_check(
iterative_auth_check(&room_version, &sorted_control_levels, clean.clone(), &event_fetch) &room_version,
.await?; sorted_control_levels.iter(),
clean.clone(),
&event_fetch,
)
.await?;
debug!(count = resolved_control.len(), "resolved power events"); debug!(count = resolved_control.len(), "resolved power events");
trace!(map = ?resolved_control, "resolved power events"); trace!(map = ?resolved_control, "resolved power events");
@ -158,13 +166,13 @@ where
debug!(event_id = ?power_event, "power event"); debug!(event_id = ?power_event, "power event");
let sorted_left_events = let sorted_left_events =
mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).await?; mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).boxed().await?;
trace!(list = ?sorted_left_events, "events left, sorted"); trace!(list = ?sorted_left_events, "events left, sorted");
let mut resolved_state = iterative_auth_check( let mut resolved_state = iterative_auth_check(
&room_version, &room_version,
&sorted_left_events, sorted_left_events.iter(),
resolved_control, // The control events are added to the final resolved state resolved_control, // The control events are added to the final resolved state
&event_fetch, &event_fetch,
) )
@ -424,16 +432,18 @@ where
debug!("fetch event ({event_id}) senders power level"); debug!("fetch event ({event_id}) senders power level");
let event = fetch_event(event_id.clone()).await; let event = fetch_event(event_id.clone()).await;
let mut pl = None;
for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() { let auth_events = event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten();
if let Some(aev) = fetch_event(aid.clone()).await {
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") { let pl = stream::iter(auth_events)
pl = Some(aev); .map(|aid| fetch_event(aid.clone()))
break; .buffer_unordered(PARALLEL_FETCHES.min(5))
} .filter_map(|aev| future::ready(aev))
} .collect::<Vec<_>>()
} .boxed()
.await
.into_iter()
.find(|aev| is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, ""));
let content: PowerLevelsContentFields = match pl { let content: PowerLevelsContentFields = match pl {
None => return Ok(int!(0)), None => return Ok(int!(0)),
@ -459,49 +469,60 @@ where
/// ///
/// For each `events_to_check` event we gather the events needed to auth it from the the /// For each `events_to_check` event we gather the events needed to auth it from the the
/// `fetch_event` closure and verify each event using the `event_auth::auth_check` function. /// `fetch_event` closure and verify each event using the `event_auth::auth_check` function.
async fn iterative_auth_check<E, F, Fut>( async fn iterative_auth_check<'a, E, F, Fut, I>(
room_version: &RoomVersion, room_version: &RoomVersion,
events_to_check: &[E::Id], events_to_check: I,
unconflicted_state: StateMap<E::Id>, unconflicted_state: StateMap<E::Id>,
fetch_event: &F, fetch_event: &F,
) -> Result<StateMap<E::Id>> ) -> Result<StateMap<E::Id>>
where where
F: Fn(E::Id) -> Fut + Sync, F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send, Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync, E::Id: Borrow<EventId> + Clone + Eq + Ord + Send + Sync + 'a,
E::Id: Borrow<EventId> + Clone + Send, I: Iterator<Item = &'a E::Id> + Debug + Send + 'a,
for<'a> &'a E: Send, E: Event + Clone + Send + Sync,
{ {
debug!("starting iterative auth check"); debug!("starting iterative auth check");
trace!(
list = ?events_to_check,
"events to check"
);
trace!(list = ?events_to_check, "events to check"); let events_to_check: Vec<_> = stream::iter(events_to_check)
.map(Result::Ok)
.map_ok(|event_id| {
fetch_event(event_id.clone()).map(move |result| {
result.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))
})
})
.try_buffer_unordered(PARALLEL_FETCHES)
.try_collect()
.boxed()
.await?;
let auth_event_ids: HashSet<E::Id> = events_to_check
.iter()
.map(|event: &E| event.auth_events().map(Clone::clone))
.flatten()
.collect();
let auth_events: HashMap<E::Id, E> = stream::iter(auth_event_ids.into_iter())
.map(|event_id| fetch_event(event_id))
.buffer_unordered(PARALLEL_FETCHES)
.filter_map(|result| future::ready(result))
.map(|auth_event| (auth_event.event_id().clone(), auth_event))
.collect()
.boxed()
.await;
let auth_events = &auth_events;
let mut resolved_state = unconflicted_state; let mut resolved_state = unconflicted_state;
for event in events_to_check.iter() {
for event_id in events_to_check { let event_id = event.event_id();
let event = fetch_event(event_id.clone())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))?;
let state_key = event let state_key = event
.state_key() .state_key()
.ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; .ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?;
let mut auth_events = StateMap::new();
for aid in event.auth_events() {
if let Some(ev) = fetch_event(aid.clone()).await {
//TODO: synapse checks "rejected_reason" which is most likely related to
// soft-failing
auth_events.insert(
ev.event_type().with_state_key(ev.state_key().ok_or_else(|| {
Error::InvalidPdu("State event had no state key".to_owned())
})?),
ev,
);
} else {
warn!(event_id = aid.borrow().as_str(), "missing auth event");
}
}
let auth_types = auth_types_for_event( let auth_types = auth_types_for_event(
event.event_type(), event.event_type(),
event.sender(), event.sender(),
@ -509,33 +530,51 @@ where
event.content(), event.content(),
)?; )?;
let auth_types = let mut auth_state = StateMap::new();
auth_types.iter().filter_map(|key| Some((key, resolved_state.get(key)?))).into_iter(); for aid in event.auth_events() {
if let Some(&ref ev) = auth_events.get(aid.borrow()) {
//TODO: synapse checks "rejected_reason" which is most likely related to
// soft-failing
auth_state.insert(
ev.event_type().with_state_key(ev.state_key().ok_or_else(|| {
Error::InvalidPdu("State event had no state key".to_owned())
})?),
ev.clone(),
);
} else {
warn!(event_id = aid.borrow().as_str(), "missing auth event");
}
}
stream::iter(auth_types) stream::iter(
.filter_map(|(key, ev_id)| { auth_types.iter().filter_map(|key| Some((key, resolved_state.get(key)?))).into_iter(),
fetch_event(ev_id.clone()).map(move |event| event.map(|event| (key, event))) )
}) .filter_map(|(key, ev_id)| async move {
.for_each(|(key, event)| { if let Some(event) = auth_events.get(ev_id.borrow()) {
//TODO: synapse checks "rejected_reason" is None here Some((key, event.clone()))
auth_events.insert(key.to_owned(), event); } else {
future::ready(()) Some((key, fetch_event(ev_id.clone()).await?.clone()))
}) }
.await; })
.for_each(|(key, event)| {
//TODO: synapse checks "rejected_reason" is None here
auth_state.insert(key.to_owned(), event);
future::ready(())
})
.await;
debug!("event to check {:?}", event.event_id()); debug!("event to check {:?}", event.event_id());
// The key for this is (eventType + a state_key of the signed token not sender) so // The key for this is (eventType + a state_key of the signed token not sender) so
// search for it // search for it
let current_third_party = auth_events.iter().find_map(|(_, pdu)| { let current_third_party = auth_state.iter().find_map(|(_, pdu)| {
(*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu) (*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu)
}); });
let fetch_state = |ty: &StateEventType, key: &str| { let fetch_state =
future::ready(auth_events.get(&ty.with_state_key(key))) |ty: &StateEventType, key: &str| future::ready(auth_state.get(&ty.with_state_key(key)));
};
if auth_check(room_version, &event, current_third_party, fetch_state).await? { if auth_check(room_version, &event, current_third_party.as_ref(), fetch_state).await? {
// add event to resolved state map // add event to resolved state map
resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone()); resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone());
} else { } else {
@ -808,7 +847,7 @@ mod tests {
let resolved_power = crate::iterative_auth_check( let resolved_power = crate::iterative_auth_check(
&RoomVersion::V6, &RoomVersion::V6,
&sorted_power_events, sorted_power_events.iter(),
HashMap::new(), // unconflicted events HashMap::new(), // unconflicted events
&fetcher, &fetcher,
) )