async state-res
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
1d0b06b581
commit
e7db44989d
@ -18,9 +18,10 @@ all-features = true
|
||||
unstable-exhaustive-types = []
|
||||
|
||||
[dependencies]
|
||||
futures-util = "0.3"
|
||||
itertools = "0.12.1"
|
||||
js_int = { workspace = true }
|
||||
ruma-common = { workspace = true }
|
||||
ruma-common = { workspace = true, features = ["api"] }
|
||||
ruma-events = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
@ -34,6 +35,7 @@ criterion = { workspace = true, optional = true }
|
||||
maplit = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
ruma-events = { workspace = true, features = ["unstable-pdu"] }
|
||||
tokio = { version = "1", features = ["rt", "macros"] }
|
||||
tracing-subscriber = "0.3.16"
|
||||
|
||||
[[bench]]
|
||||
|
@ -1,5 +1,6 @@
|
||||
use std::{borrow::Borrow, collections::BTreeSet};
|
||||
|
||||
use futures_util::Future;
|
||||
use js_int::{int, Int};
|
||||
use ruma_common::{
|
||||
serde::{Base64, Raw},
|
||||
@ -121,12 +122,18 @@ pub fn auth_types_for_event(
|
||||
///
|
||||
/// The `fetch_state` closure should gather state from a state snapshot. We need to know if the
|
||||
/// event passes auth against some state not a recursive collection of auth_events fields.
|
||||
pub fn auth_check<E: Event>(
|
||||
pub async fn auth_check<F, Fut, Fetched, Incoming>(
|
||||
room_version: &RoomVersion,
|
||||
incoming_event: impl Event,
|
||||
current_third_party_invite: Option<impl Event>,
|
||||
fetch_state: impl Fn(&StateEventType, &str) -> Option<E>,
|
||||
) -> Result<bool> {
|
||||
incoming_event: &Incoming,
|
||||
current_third_party_invite: Option<&Incoming>,
|
||||
fetch_state: F,
|
||||
) -> Result<bool>
|
||||
where
|
||||
F: Fn(&'static StateEventType, &str) -> Fut,
|
||||
Fut: Future<Output = Option<Fetched>> + Send,
|
||||
Fetched: Event + Send,
|
||||
Incoming: Event + Send,
|
||||
{
|
||||
debug!(
|
||||
"auth_check beginning for {} ({})",
|
||||
incoming_event.event_id(),
|
||||
@ -216,7 +223,7 @@ pub fn auth_check<E: Event>(
|
||||
}
|
||||
*/
|
||||
|
||||
let room_create_event = match fetch_state(&StateEventType::RoomCreate, "") {
|
||||
let room_create_event = match fetch_state(&StateEventType::RoomCreate, "").await {
|
||||
None => {
|
||||
warn!("no m.room.create event in auth chain");
|
||||
return Ok(false);
|
||||
@ -265,8 +272,8 @@ pub fn auth_check<E: Event>(
|
||||
}
|
||||
|
||||
// If type is m.room.member
|
||||
let power_levels_event = fetch_state(&StateEventType::RoomPowerLevels, "");
|
||||
let sender_member_event = fetch_state(&StateEventType::RoomMember, sender.as_str());
|
||||
let power_levels_event = fetch_state(&StateEventType::RoomPowerLevels, "").await;
|
||||
let sender_member_event = fetch_state(&StateEventType::RoomMember, sender.as_str()).await;
|
||||
|
||||
if *incoming_event.event_type() == TimelineEventType::RoomMember {
|
||||
debug!("starting m.room.member check");
|
||||
@ -290,9 +297,13 @@ pub fn auth_check<E: Event>(
|
||||
let user_for_join_auth =
|
||||
content.join_authorised_via_users_server.as_ref().and_then(|u| u.deserialize().ok());
|
||||
|
||||
let user_for_join_auth_membership = user_for_join_auth
|
||||
.as_ref()
|
||||
.and_then(|auth_user| fetch_state(&StateEventType::RoomMember, auth_user.as_str()))
|
||||
let user_for_join_auth_event = if let Some(auth_user) = user_for_join_auth.as_ref() {
|
||||
fetch_state(&StateEventType::RoomMember, auth_user.as_str()).await
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let user_for_join_auth_membership = user_for_join_auth_event
|
||||
.and_then(|mem| from_json_str::<GetMembership>(mem.content().get()).ok())
|
||||
.map(|mem| mem.membership)
|
||||
.unwrap_or(MembershipState::Leave);
|
||||
@ -300,13 +311,13 @@ pub fn auth_check<E: Event>(
|
||||
if !valid_membership_change(
|
||||
room_version,
|
||||
target_user,
|
||||
fetch_state(&StateEventType::RoomMember, target_user.as_str()).as_ref(),
|
||||
fetch_state(&StateEventType::RoomMember, target_user.as_str()).await.as_ref(),
|
||||
sender,
|
||||
sender_member_event.as_ref(),
|
||||
&incoming_event,
|
||||
current_third_party_invite,
|
||||
power_levels_event.as_ref(),
|
||||
fetch_state(&StateEventType::RoomJoinRules, "").as_ref(),
|
||||
fetch_state(&StateEventType::RoomJoinRules, "").await.as_ref(),
|
||||
user_for_join_auth.as_deref(),
|
||||
&user_for_join_auth_membership,
|
||||
room_create_event,
|
||||
|
@ -5,6 +5,7 @@ use std::{
|
||||
hash::Hash,
|
||||
};
|
||||
|
||||
use futures_util::{future, stream, Future, StreamExt};
|
||||
use itertools::Itertools;
|
||||
use js_int::{int, Int};
|
||||
use ruma_common::{EventId, MilliSecondsSinceUnixEpoch, RoomVersionId};
|
||||
@ -52,16 +53,22 @@ pub type StateMap<T> = HashMap<(StateEventType, String), T>;
|
||||
///
|
||||
/// The caller of `resolve` must ensure that all the events are from the same room. Although this
|
||||
/// function takes a `RoomId` it does not check that each event is part of the same room.
|
||||
pub fn resolve<'a, E, SetIter>(
|
||||
pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>(
|
||||
room_version: &RoomVersionId,
|
||||
state_sets: impl IntoIterator<IntoIter = SetIter>,
|
||||
auth_chain_sets: Vec<HashSet<E::Id>>,
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
state_sets: impl IntoIterator<IntoIter = SetIter> + Send,
|
||||
auth_chain_sets: &'a Vec<HashSet<E::Id>>,
|
||||
event_fetch: &Fetch,
|
||||
event_exists: &Exists,
|
||||
) -> Result<StateMap<E::Id>>
|
||||
where
|
||||
E: Event + Clone,
|
||||
E::Id: 'a,
|
||||
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone,
|
||||
Fetch: Fn(E::Id) -> FetchFut + Sync,
|
||||
FetchFut: Future<Output = Option<E>> + Send,
|
||||
Exists: Fn(E::Id) -> ExistsFut,
|
||||
ExistsFut: Future<Output = bool> + Send,
|
||||
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone + Send,
|
||||
E: Event + Send,
|
||||
E::Id: Borrow<EventId> + Send + Sync,
|
||||
for<'b> &'b E: Send,
|
||||
{
|
||||
debug!("State resolution starting");
|
||||
|
||||
@ -79,13 +86,16 @@ where
|
||||
debug!("conflicting events: {}", conflicting.len());
|
||||
debug!("{conflicting:?}");
|
||||
|
||||
let auth_chain_diff =
|
||||
get_auth_chain_diff(&auth_chain_sets).chain(conflicting.into_values().flatten());
|
||||
|
||||
// `all_conflicted` contains unique items
|
||||
// synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}`
|
||||
let all_conflicted: HashSet<_> = get_auth_chain_diff(auth_chain_sets)
|
||||
.chain(conflicting.into_values().flatten())
|
||||
let all_conflicted: HashSet<E::Id> = stream::iter(auth_chain_diff)
|
||||
// Don't honor events we cannot "verify"
|
||||
.filter(|id| fetch_event(id.borrow()).is_some())
|
||||
.collect();
|
||||
.filter(|id| event_exists(id.clone()))
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
debug!("full conflicted set: {}", all_conflicted.len());
|
||||
debug!("{all_conflicted:?}");
|
||||
@ -94,15 +104,15 @@ where
|
||||
// this is now a check the caller of `resolve` must make.
|
||||
|
||||
// Get only the control events with a state_key: "" or ban/kick event (sender != state_key)
|
||||
let control_events = all_conflicted
|
||||
.iter()
|
||||
.filter(|&id| is_power_event_id(id.borrow(), &fetch_event))
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
let control_events = stream::iter(all_conflicted.iter())
|
||||
.filter(|&id| is_power_event_id(id, &event_fetch))
|
||||
.map(Clone::clone)
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
|
||||
// Sort the control events based on power_level/clock/event_id and outgoing/incoming edges
|
||||
let sorted_control_levels =
|
||||
reverse_topological_power_sort(control_events, &all_conflicted, &fetch_event)?;
|
||||
reverse_topological_power_sort(control_events, &all_conflicted, &event_fetch).await?;
|
||||
|
||||
debug!("sorted control events: {}", sorted_control_levels.len());
|
||||
trace!("{sorted_control_levels:?}");
|
||||
@ -110,7 +120,8 @@ where
|
||||
let room_version = RoomVersion::new(room_version)?;
|
||||
// Sequentially auth check each control event.
|
||||
let resolved_control =
|
||||
iterative_auth_check(&room_version, &sorted_control_levels, clean.clone(), &fetch_event)?;
|
||||
iterative_auth_check(&room_version, &sorted_control_levels, clean.clone(), &event_fetch)
|
||||
.await?;
|
||||
|
||||
debug!("resolved control events: {}", resolved_control.len());
|
||||
trace!("{resolved_control:?}");
|
||||
@ -135,7 +146,8 @@ where
|
||||
|
||||
debug!("power event: {power_event:?}");
|
||||
|
||||
let sorted_left_events = mainline_sort(&events_to_resolve, power_event.cloned(), &fetch_event)?;
|
||||
let sorted_left_events =
|
||||
mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).await?;
|
||||
|
||||
trace!("events left, sorted: {sorted_left_events:?}");
|
||||
|
||||
@ -143,8 +155,9 @@ where
|
||||
&room_version,
|
||||
&sorted_left_events,
|
||||
resolved_control, // The control events are added to the final resolved state
|
||||
&fetch_event,
|
||||
)?;
|
||||
&event_fetch,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Add unconflicted state to the resolved state
|
||||
// We priorities the unconflicting state
|
||||
@ -188,15 +201,14 @@ where
|
||||
}
|
||||
|
||||
/// Returns a Vec of deduped EventIds that appear in some chains but not others.
|
||||
fn get_auth_chain_diff<Id>(auth_chain_sets: Vec<HashSet<Id>>) -> impl Iterator<Item = Id>
|
||||
fn get_auth_chain_diff<Id>(auth_chain_sets: &Vec<HashSet<Id>>) -> impl Iterator<Item = Id>
|
||||
where
|
||||
Id: Eq + Hash,
|
||||
Id: Clone + Eq + Hash,
|
||||
{
|
||||
let num_sets = auth_chain_sets.len();
|
||||
|
||||
let mut id_counts: HashMap<Id, usize> = HashMap::new();
|
||||
for id in auth_chain_sets.into_iter().flatten() {
|
||||
*id_counts.entry(id).or_default() += 1;
|
||||
*id_counts.entry(id.clone()).or_default() += 1;
|
||||
}
|
||||
|
||||
id_counts.into_iter().filter_map(move |(id, count)| (count < num_sets).then_some(id))
|
||||
@ -209,16 +221,22 @@ where
|
||||
///
|
||||
/// The power level is negative because a higher power level is equated to an earlier (further back
|
||||
/// in time) origin server timestamp.
|
||||
fn reverse_topological_power_sort<E: Event>(
|
||||
async fn reverse_topological_power_sort<E, F, Fut>(
|
||||
events_to_sort: Vec<E::Id>,
|
||||
auth_diff: &HashSet<E::Id>,
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) -> Result<Vec<E::Id>> {
|
||||
fetch_event: &F,
|
||||
) -> Result<Vec<E::Id>>
|
||||
where
|
||||
F: Fn(E::Id) -> Fut + Sync,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send,
|
||||
E::Id: Borrow<EventId> + Send + Sync,
|
||||
{
|
||||
debug!("reverse topological sort of power events");
|
||||
|
||||
let mut graph = HashMap::new();
|
||||
for event_id in events_to_sort {
|
||||
add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, &fetch_event);
|
||||
add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, fetch_event).await;
|
||||
|
||||
// TODO: if these functions are ever made async here
|
||||
// is a good place to yield every once in a while so other
|
||||
@ -228,7 +246,7 @@ fn reverse_topological_power_sort<E: Event>(
|
||||
// This is used in the `key_fn` passed to the lexico_topo_sort fn
|
||||
let mut event_to_pl = HashMap::new();
|
||||
for event_id in graph.keys() {
|
||||
let pl = get_power_level_for_sender(event_id.borrow(), &fetch_event)?;
|
||||
let pl = get_power_level_for_sender(event_id, fetch_event).await?;
|
||||
debug!("{event_id} power level {pl}");
|
||||
|
||||
event_to_pl.insert(event_id.clone(), pl);
|
||||
@ -238,26 +256,30 @@ fn reverse_topological_power_sort<E: Event>(
|
||||
// tasks can make progress
|
||||
}
|
||||
|
||||
lexicographical_topological_sort(&graph, |event_id| {
|
||||
let ev = fetch_event(event_id).ok_or_else(|| Error::NotFound("".into()))?;
|
||||
let pl = *event_to_pl.get(event_id).ok_or_else(|| Error::NotFound("".into()))?;
|
||||
let event_to_pl = &event_to_pl;
|
||||
let fetcher = |event_id: E::Id| async move {
|
||||
let pl = *event_to_pl.get(event_id.borrow()).ok_or_else(|| Error::NotFound("".into()))?;
|
||||
let ev = fetch_event(event_id).await.ok_or_else(|| Error::NotFound("".into()))?;
|
||||
Ok((pl, ev.origin_server_ts()))
|
||||
})
|
||||
};
|
||||
|
||||
lexicographical_topological_sort(&graph, &fetcher).await
|
||||
}
|
||||
|
||||
/// Sorts the event graph based on number of outgoing/incoming edges.
|
||||
///
|
||||
/// `key_fn` is used as to obtain the power level and age of an event for breaking ties (together
|
||||
/// with the event ID).
|
||||
pub fn lexicographical_topological_sort<Id, F>(
|
||||
pub async fn lexicographical_topological_sort<Id, F, Fut>(
|
||||
graph: &HashMap<Id, HashSet<Id>>,
|
||||
key_fn: F,
|
||||
key_fn: &F,
|
||||
) -> Result<Vec<Id>>
|
||||
where
|
||||
F: Fn(&EventId) -> Result<(Int, MilliSecondsSinceUnixEpoch)>,
|
||||
Id: Clone + Eq + Ord + Hash + Borrow<EventId>,
|
||||
F: Fn(Id) -> Fut,
|
||||
Fut: Future<Output = Result<(Int, MilliSecondsSinceUnixEpoch)>> + Send,
|
||||
Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send,
|
||||
{
|
||||
#[derive(PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[derive(Eq, Ord, PartialEq, PartialOrd)]
|
||||
struct TieBreaker<'a, Id> {
|
||||
inv_power_level: Int,
|
||||
age: MilliSecondsSinceUnixEpoch,
|
||||
@ -285,7 +307,7 @@ where
|
||||
|
||||
for (node, edges) in graph {
|
||||
if edges.is_empty() {
|
||||
let (power_level, age) = key_fn(node.borrow())?;
|
||||
let (power_level, age) = key_fn(node.clone()).await?;
|
||||
// The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need
|
||||
// smallest -> largest
|
||||
zero_outdegree.push(Reverse(TieBreaker {
|
||||
@ -318,7 +340,7 @@ where
|
||||
// Only push on the heap once older events have been cleared
|
||||
out.remove(node.borrow());
|
||||
if out.is_empty() {
|
||||
let (power_level, age) = key_fn(node.borrow())?;
|
||||
let (power_level, age) = key_fn(node.clone()).await?;
|
||||
heap.push(Reverse(TieBreaker {
|
||||
inv_power_level: -power_level,
|
||||
age,
|
||||
@ -339,17 +361,23 @@ where
|
||||
/// Do NOT use this any where but topological sort, we find the power level for the eventId
|
||||
/// at the eventId's generation (we walk backwards to `EventId`s most recent previous power level
|
||||
/// event).
|
||||
fn get_power_level_for_sender<E: Event>(
|
||||
event_id: &EventId,
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) -> serde_json::Result<Int> {
|
||||
async fn get_power_level_for_sender<E, F, Fut>(
|
||||
event_id: &E::Id,
|
||||
fetch_event: &F,
|
||||
) -> serde_json::Result<Int>
|
||||
where
|
||||
F: Fn(E::Id) -> Fut,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send,
|
||||
E::Id: Borrow<EventId> + Send,
|
||||
{
|
||||
debug!("fetch event ({event_id}) senders power level");
|
||||
|
||||
let event = fetch_event(event_id);
|
||||
let event = fetch_event(event_id.clone()).await;
|
||||
let mut pl = None;
|
||||
|
||||
for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() {
|
||||
if let Some(aev) = fetch_event(aid.borrow()) {
|
||||
if let Some(aev) = fetch_event(aid.clone()).await {
|
||||
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
|
||||
pl = Some(aev);
|
||||
break;
|
||||
@ -381,12 +409,19 @@ fn get_power_level_for_sender<E: Event>(
|
||||
///
|
||||
/// For each `events_to_check` event we gather the events needed to auth it from the the
|
||||
/// `fetch_event` closure and verify each event using the `event_auth::auth_check` function.
|
||||
fn iterative_auth_check<E: Event + Clone>(
|
||||
async fn iterative_auth_check<E, F, Fut>(
|
||||
room_version: &RoomVersion,
|
||||
events_to_check: &[E::Id],
|
||||
unconflicted_state: StateMap<E::Id>,
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) -> Result<StateMap<E::Id>> {
|
||||
fetch_event: &F,
|
||||
) -> Result<StateMap<E::Id>>
|
||||
where
|
||||
F: Fn(E::Id) -> Fut,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send,
|
||||
E::Id: Borrow<EventId> + Clone + Send,
|
||||
for<'a> &'a E: Send,
|
||||
{
|
||||
debug!("starting iterative auth check");
|
||||
|
||||
debug!("performing auth checks on {events_to_check:?}");
|
||||
@ -394,7 +429,8 @@ fn iterative_auth_check<E: Event + Clone>(
|
||||
let mut resolved_state = unconflicted_state;
|
||||
|
||||
for event_id in events_to_check {
|
||||
let event = fetch_event(event_id.borrow())
|
||||
let event = fetch_event(event_id.clone())
|
||||
.await
|
||||
.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))?;
|
||||
let state_key = event
|
||||
.state_key()
|
||||
@ -402,7 +438,7 @@ fn iterative_auth_check<E: Event + Clone>(
|
||||
|
||||
let mut auth_events = StateMap::new();
|
||||
for aid in event.auth_events() {
|
||||
if let Some(ev) = fetch_event(aid.borrow()) {
|
||||
if let Some(ev) = fetch_event(aid.clone()).await {
|
||||
// TODO synapse check "rejected_reason" which is most likely
|
||||
// related to soft-failing
|
||||
auth_events.insert(
|
||||
@ -423,7 +459,7 @@ fn iterative_auth_check<E: Event + Clone>(
|
||||
event.content(),
|
||||
)? {
|
||||
if let Some(ev_id) = resolved_state.get(&key) {
|
||||
if let Some(event) = fetch_event(ev_id.borrow()) {
|
||||
if let Some(event) = fetch_event(ev_id.clone()).await {
|
||||
// TODO synapse checks `rejected_reason` is None here
|
||||
auth_events.insert(key.to_owned(), event);
|
||||
}
|
||||
@ -438,9 +474,11 @@ fn iterative_auth_check<E: Event + Clone>(
|
||||
(*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu)
|
||||
});
|
||||
|
||||
if auth_check(room_version, &event, current_third_party, |ty, key| {
|
||||
auth_events.get(&ty.with_state_key(key))
|
||||
})? {
|
||||
let fetch_state = |ty: &StateEventType, key: &str| {
|
||||
future::ready(auth_events.get(&ty.with_state_key(key)))
|
||||
};
|
||||
|
||||
if auth_check(room_version, &event, current_third_party, fetch_state).await? {
|
||||
// add event to resolved state map
|
||||
resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone());
|
||||
} else {
|
||||
@ -462,11 +500,17 @@ fn iterative_auth_check<E: Event + Clone>(
|
||||
/// power_level event. If there have been two power events the after the most recent are depth 0,
|
||||
/// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is
|
||||
/// "older" than depth 0.
|
||||
fn mainline_sort<E: Event>(
|
||||
async fn mainline_sort<E, F, Fut>(
|
||||
to_sort: &[E::Id],
|
||||
resolved_power_level: Option<E::Id>,
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) -> Result<Vec<E::Id>> {
|
||||
fetch_event: &F,
|
||||
) -> Result<Vec<E::Id>>
|
||||
where
|
||||
F: Fn(E::Id) -> Fut,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send,
|
||||
E::Id: Borrow<EventId> + Clone + Send,
|
||||
{
|
||||
debug!("mainline sort of events");
|
||||
|
||||
// There are no EventId's to sort, bail.
|
||||
@ -479,11 +523,13 @@ fn mainline_sort<E: Event>(
|
||||
while let Some(p) = pl {
|
||||
mainline.push(p.clone());
|
||||
|
||||
let event = fetch_event(p.borrow())
|
||||
let event = fetch_event(p.clone())
|
||||
.await
|
||||
.ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?;
|
||||
pl = None;
|
||||
for aid in event.auth_events() {
|
||||
let ev = fetch_event(aid.borrow())
|
||||
let ev = fetch_event(aid.clone())
|
||||
.await
|
||||
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
|
||||
if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") {
|
||||
pl = Some(aid.to_owned());
|
||||
@ -504,11 +550,15 @@ fn mainline_sort<E: Event>(
|
||||
|
||||
let mut order_map = HashMap::new();
|
||||
for ev_id in to_sort.iter() {
|
||||
if let Some(event) = fetch_event(ev_id.borrow()) {
|
||||
if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, &fetch_event) {
|
||||
if let Some(event) = fetch_event(ev_id.clone()).await {
|
||||
if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, fetch_event).await {
|
||||
order_map.insert(
|
||||
ev_id,
|
||||
(depth, fetch_event(ev_id.borrow()).map(|ev| ev.origin_server_ts()), ev_id),
|
||||
(
|
||||
depth,
|
||||
fetch_event(ev_id.clone()).await.map(|ev| ev.origin_server_ts()),
|
||||
ev_id,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -528,11 +578,17 @@ fn mainline_sort<E: Event>(
|
||||
|
||||
/// Get the mainline depth from the `mainline_map` or finds a power_level event that has an
|
||||
/// associated mainline depth.
|
||||
fn get_mainline_depth<E: Event>(
|
||||
async fn get_mainline_depth<E, F, Fut>(
|
||||
mut event: Option<E>,
|
||||
mainline_map: &HashMap<E::Id, usize>,
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) -> Result<usize> {
|
||||
fetch_event: &F,
|
||||
) -> Result<usize>
|
||||
where
|
||||
F: Fn(E::Id) -> Fut,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send,
|
||||
E::Id: Borrow<EventId> + Send,
|
||||
{
|
||||
while let Some(sort_ev) = event {
|
||||
debug!("mainline event_id {}", sort_ev.event_id());
|
||||
let id = sort_ev.event_id();
|
||||
@ -542,7 +598,8 @@ fn get_mainline_depth<E: Event>(
|
||||
|
||||
event = None;
|
||||
for aid in sort_ev.auth_events() {
|
||||
let aev = fetch_event(aid.borrow())
|
||||
let aev = fetch_event(aid.clone())
|
||||
.await
|
||||
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
|
||||
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
|
||||
event = Some(aev);
|
||||
@ -554,18 +611,23 @@ fn get_mainline_depth<E: Event>(
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
fn add_event_and_auth_chain_to_graph<E: Event>(
|
||||
async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
|
||||
graph: &mut HashMap<E::Id, HashSet<E::Id>>,
|
||||
event_id: E::Id,
|
||||
auth_diff: &HashSet<E::Id>,
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) {
|
||||
fetch_event: &F,
|
||||
) where
|
||||
F: Fn(E::Id) -> Fut,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send,
|
||||
E::Id: Borrow<EventId> + Clone + Send,
|
||||
{
|
||||
let mut state = vec![event_id];
|
||||
while let Some(eid) = state.pop() {
|
||||
graph.entry(eid.clone()).or_default();
|
||||
// Prefer the store to event as the store filters dedups the events
|
||||
for aid in
|
||||
fetch_event(eid.borrow()).as_ref().map(|ev| ev.auth_events()).into_iter().flatten()
|
||||
fetch_event(eid.clone()).await.as_ref().map(|ev| ev.auth_events()).into_iter().flatten()
|
||||
{
|
||||
if auth_diff.contains(aid.borrow()) {
|
||||
if !graph.contains_key(aid.borrow()) {
|
||||
@ -579,8 +641,14 @@ fn add_event_and_auth_chain_to_graph<E: Event>(
|
||||
}
|
||||
}
|
||||
|
||||
fn is_power_event_id<E: Event>(event_id: &EventId, fetch: impl Fn(&EventId) -> Option<E>) -> bool {
|
||||
match fetch(event_id).as_ref() {
|
||||
async fn is_power_event_id<E, F, Fut>(event_id: &E::Id, fetch: &F) -> bool
|
||||
where
|
||||
F: Fn(E::Id) -> Fut,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send,
|
||||
E::Id: Borrow<EventId> + Send,
|
||||
{
|
||||
match fetch(event_id.clone()).await.as_ref() {
|
||||
Some(state) => is_power_event(state),
|
||||
_ => false,
|
||||
}
|
||||
@ -609,7 +677,7 @@ fn is_power_event(event: impl Event) -> bool {
|
||||
}
|
||||
|
||||
/// Convenience trait for adding event type plus state key to state maps.
|
||||
trait EventTypeExt {
|
||||
pub trait EventTypeExt {
|
||||
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String);
|
||||
}
|
||||
|
||||
@ -662,7 +730,9 @@ mod tests {
|
||||
Event, EventTypeExt, StateMap,
|
||||
};
|
||||
|
||||
fn test_event_sort() {
|
||||
async fn test_event_sort() {
|
||||
use futures_util::future::ready;
|
||||
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
let events = INITIAL_EVENTS();
|
||||
@ -680,18 +750,19 @@ mod tests {
|
||||
.map(|pdu| pdu.event_id.clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let fetcher = |id| ready(events.get(&id).cloned());
|
||||
let sorted_power_events =
|
||||
crate::reverse_topological_power_sort(power_events, &auth_chain, |id| {
|
||||
events.get(id).cloned()
|
||||
})
|
||||
crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let resolved_power = crate::iterative_auth_check(
|
||||
&RoomVersion::V6,
|
||||
&sorted_power_events,
|
||||
HashMap::new(), // unconflicted events
|
||||
|id| events.get(id).cloned(),
|
||||
&fetcher,
|
||||
)
|
||||
.await
|
||||
.expect("iterative auth check failed on resolved events");
|
||||
|
||||
// don't remove any events so we know it sorts them all correctly
|
||||
@ -703,8 +774,7 @@ mod tests {
|
||||
resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned();
|
||||
|
||||
let sorted_event_ids =
|
||||
crate::mainline_sort(&events_to_sort, power_level, |id| events.get(id).cloned())
|
||||
.unwrap();
|
||||
crate::mainline_sort(&events_to_sort, power_level, &fetcher).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
vec![
|
||||
@ -721,17 +791,17 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort() {
|
||||
#[tokio::test]
|
||||
async fn test_sort() {
|
||||
for _ in 0..20 {
|
||||
// since we shuffle the eventIds before we sort them introducing randomness
|
||||
// seems like we should test this a few times
|
||||
test_event_sort();
|
||||
test_event_sort().await;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ban_vs_power_level() {
|
||||
#[tokio::test]
|
||||
async fn ban_vs_power_level() {
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
|
||||
@ -774,11 +844,11 @@ mod tests {
|
||||
let expected_state_ids =
|
||||
vec!["PA", "MA", "MB"].into_iter().map(event_id).collect::<Vec<_>>();
|
||||
|
||||
do_check(events, edges, expected_state_ids);
|
||||
do_check(events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn topic_basic() {
|
||||
#[tokio::test]
|
||||
async fn topic_basic() {
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
|
||||
@ -835,11 +905,11 @@ mod tests {
|
||||
|
||||
let expected_state_ids = vec!["PA2", "T2"].into_iter().map(event_id).collect::<Vec<_>>();
|
||||
|
||||
do_check(events, edges, expected_state_ids);
|
||||
do_check(events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn topic_reset() {
|
||||
#[tokio::test]
|
||||
async fn topic_reset() {
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
|
||||
@ -882,11 +952,11 @@ mod tests {
|
||||
let expected_state_ids =
|
||||
vec!["T1", "MB", "PA"].into_iter().map(event_id).collect::<Vec<_>>();
|
||||
|
||||
do_check(events, edges, expected_state_ids);
|
||||
do_check(events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn join_rule_evasion() {
|
||||
#[tokio::test]
|
||||
async fn join_rule_evasion() {
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
|
||||
@ -914,11 +984,11 @@ mod tests {
|
||||
|
||||
let expected_state_ids = vec![event_id("JR")];
|
||||
|
||||
do_check(events, edges, expected_state_ids);
|
||||
do_check(events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn offtopic_power_level() {
|
||||
#[tokio::test]
|
||||
async fn offtopic_power_level() {
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
|
||||
@ -955,11 +1025,11 @@ mod tests {
|
||||
|
||||
let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::<Vec<_>>();
|
||||
|
||||
do_check(events, edges, expected_state_ids);
|
||||
do_check(events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn topic_setting() {
|
||||
#[tokio::test]
|
||||
async fn topic_setting() {
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
|
||||
@ -1032,11 +1102,13 @@ mod tests {
|
||||
|
||||
let expected_state_ids = vec!["T4", "PA2"].into_iter().map(event_id).collect::<Vec<_>>();
|
||||
|
||||
do_check(events, edges, expected_state_ids);
|
||||
do_check(events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_event_map_none() {
|
||||
#[tokio::test]
|
||||
async fn test_event_map_none() {
|
||||
use futures_util::future::ready;
|
||||
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
|
||||
@ -1046,18 +1118,20 @@ mod tests {
|
||||
let (state_at_bob, state_at_charlie, expected) = store.set_up();
|
||||
|
||||
let ev_map = store.0.clone();
|
||||
let fetcher = |id| ready(ev_map.get(&id).cloned());
|
||||
|
||||
let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&*id).is_some());
|
||||
|
||||
let state_sets = [state_at_bob, state_at_charlie];
|
||||
let resolved = match crate::resolve(
|
||||
&RoomVersionId::V2,
|
||||
&state_sets,
|
||||
state_sets
|
||||
let auth_chain = state_sets
|
||||
.iter()
|
||||
.map(|map| {
|
||||
store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()
|
||||
})
|
||||
.collect(),
|
||||
|id| ev_map.get(id).cloned(),
|
||||
) {
|
||||
.map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap())
|
||||
.collect();
|
||||
|
||||
let resolved =
|
||||
match crate::resolve(&RoomVersionId::V2, &state_sets, &auth_chain, &fetcher, &exists)
|
||||
.await
|
||||
{
|
||||
Ok(state) => state,
|
||||
Err(e) => panic!("{e}"),
|
||||
};
|
||||
@ -1065,8 +1139,8 @@ mod tests {
|
||||
assert_eq!(expected, resolved);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lexicographical_sort() {
|
||||
#[tokio::test]
|
||||
async fn test_lexicographical_sort() {
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
|
||||
@ -1078,9 +1152,10 @@ mod tests {
|
||||
event_id("p") => hashset![event_id("o")],
|
||||
};
|
||||
|
||||
let res = crate::lexicographical_topological_sort(&graph, |_id| {
|
||||
let res = crate::lexicographical_topological_sort(&graph, &|_id| async {
|
||||
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
@ -1092,8 +1167,8 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ban_with_auth_chains() {
|
||||
#[tokio::test]
|
||||
async fn ban_with_auth_chains() {
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
let ban = BAN_STATE_SET();
|
||||
@ -1105,11 +1180,13 @@ mod tests {
|
||||
|
||||
let expected_state_ids = vec!["PA", "MB"].into_iter().map(event_id).collect::<Vec<_>>();
|
||||
|
||||
do_check(&ban.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids);
|
||||
do_check(&ban.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ban_with_auth_chains2() {
|
||||
#[tokio::test]
|
||||
async fn ban_with_auth_chains2() {
|
||||
use futures_util::future::ready;
|
||||
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
let init = INITIAL_EVENTS();
|
||||
@ -1147,17 +1224,17 @@ mod tests {
|
||||
|
||||
let ev_map = &store.0;
|
||||
let state_sets = [state_set_a, state_set_b];
|
||||
let resolved = match crate::resolve(
|
||||
&RoomVersionId::V6,
|
||||
&state_sets,
|
||||
state_sets
|
||||
let auth_chain = state_sets
|
||||
.iter()
|
||||
.map(|map| {
|
||||
store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()
|
||||
})
|
||||
.collect(),
|
||||
|id| ev_map.get(id).cloned(),
|
||||
) {
|
||||
.map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap())
|
||||
.collect();
|
||||
|
||||
let fetcher = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).cloned());
|
||||
let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).is_some());
|
||||
let resolved =
|
||||
match crate::resolve(&RoomVersionId::V6, &state_sets, &auth_chain, &fetcher, &exists)
|
||||
.await
|
||||
{
|
||||
Ok(state) => state,
|
||||
Err(e) => panic!("{e}"),
|
||||
};
|
||||
@ -1180,8 +1257,8 @@ mod tests {
|
||||
assert_eq!(expected.len(), resolved.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn join_rule_with_auth_chain() {
|
||||
#[tokio::test]
|
||||
async fn join_rule_with_auth_chain() {
|
||||
let join_rule = JOIN_RULE();
|
||||
|
||||
let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]]
|
||||
@ -1191,7 +1268,7 @@ mod tests {
|
||||
|
||||
let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::<Vec<_>>();
|
||||
|
||||
do_check(&join_rule.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids);
|
||||
do_check(&join_rule.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
|
@ -7,6 +7,7 @@ use std::{
|
||||
},
|
||||
};
|
||||
|
||||
use futures_util::future::ready;
|
||||
use js_int::{int, uint};
|
||||
use ruma_common::{
|
||||
event_id, room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId,
|
||||
@ -31,7 +32,7 @@ use crate::{auth_types_for_event, Error, Event, EventTypeExt, Result, StateMap};
|
||||
|
||||
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
pub(crate) fn do_check(
|
||||
pub(crate) async fn do_check(
|
||||
events: &[Arc<PduEvent>],
|
||||
edges: Vec<Vec<OwnedEventId>>,
|
||||
expected_state_ids: Vec<OwnedEventId>,
|
||||
@ -81,9 +82,10 @@ pub(crate) fn do_check(
|
||||
|
||||
// Resolve the current state and add it to the state_at_event map then continue
|
||||
// on in "time"
|
||||
for node in crate::lexicographical_topological_sort(&graph, |_id| {
|
||||
for node in crate::lexicographical_topological_sort(&graph, &|_id| async {
|
||||
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
{
|
||||
let fake_event = fake_event_map.get(&node).unwrap();
|
||||
@ -117,9 +119,13 @@ pub(crate) fn do_check(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let resolved = crate::resolve(&RoomVersionId::V6, state_sets, auth_chain_sets, |id| {
|
||||
event_map.get(id).cloned()
|
||||
});
|
||||
let event_map = &event_map;
|
||||
let fetch = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).cloned());
|
||||
let exists = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).is_some());
|
||||
let resolved =
|
||||
crate::resolve(&RoomVersionId::V6, state_sets, &auth_chain_sets, &fetch, &exists)
|
||||
.await;
|
||||
|
||||
match resolved {
|
||||
Ok(state) => state,
|
||||
Err(e) => panic!("resolution for {node} failed: {e}"),
|
||||
@ -614,7 +620,7 @@ pub(crate) mod event {
|
||||
}
|
||||
}
|
||||
|
||||
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
|
||||
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + Send + '_> {
|
||||
match &self.rest {
|
||||
Pdu::RoomV1Pdu(ev) => Box::new(ev.prev_events.iter().map(|(id, _)| id)),
|
||||
Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter()),
|
||||
@ -623,7 +629,7 @@ pub(crate) mod event {
|
||||
}
|
||||
}
|
||||
|
||||
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
|
||||
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + Send + '_> {
|
||||
match &self.rest {
|
||||
Pdu::RoomV1Pdu(ev) => Box::new(ev.auth_events.iter().map(|(id, _)| id)),
|
||||
Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter()),
|
||||
|
Loading…
x
Reference in New Issue
Block a user