async state-res

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-09-21 00:49:02 +00:00
parent 1d0b06b581
commit e7db44989d
4 changed files with 262 additions and 166 deletions

View File

@ -18,9 +18,10 @@ all-features = true
unstable-exhaustive-types = []
[dependencies]
futures-util = "0.3"
itertools = "0.12.1"
js_int = { workspace = true }
ruma-common = { workspace = true }
ruma-common = { workspace = true, features = ["api"] }
ruma-events = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@ -34,6 +35,7 @@ criterion = { workspace = true, optional = true }
maplit = { workspace = true }
rand = { workspace = true }
ruma-events = { workspace = true, features = ["unstable-pdu"] }
tokio = { version = "1", features = ["rt", "macros"] }
tracing-subscriber = "0.3.16"
[[bench]]

View File

@ -1,5 +1,6 @@
use std::{borrow::Borrow, collections::BTreeSet};
use futures_util::Future;
use js_int::{int, Int};
use ruma_common::{
serde::{Base64, Raw},
@ -121,12 +122,18 @@ pub fn auth_types_for_event(
///
/// The `fetch_state` closure should gather state from a state snapshot. We need to know if the
/// event passes auth against some state not a recursive collection of auth_events fields.
pub fn auth_check<E: Event>(
pub async fn auth_check<F, Fut, Fetched, Incoming>(
room_version: &RoomVersion,
incoming_event: impl Event,
current_third_party_invite: Option<impl Event>,
fetch_state: impl Fn(&StateEventType, &str) -> Option<E>,
) -> Result<bool> {
incoming_event: &Incoming,
current_third_party_invite: Option<&Incoming>,
fetch_state: F,
) -> Result<bool>
where
F: Fn(&'static StateEventType, &str) -> Fut,
Fut: Future<Output = Option<Fetched>> + Send,
Fetched: Event + Send,
Incoming: Event + Send,
{
debug!(
"auth_check beginning for {} ({})",
incoming_event.event_id(),
@ -216,7 +223,7 @@ pub fn auth_check<E: Event>(
}
*/
let room_create_event = match fetch_state(&StateEventType::RoomCreate, "") {
let room_create_event = match fetch_state(&StateEventType::RoomCreate, "").await {
None => {
warn!("no m.room.create event in auth chain");
return Ok(false);
@ -265,8 +272,8 @@ pub fn auth_check<E: Event>(
}
// If type is m.room.member
let power_levels_event = fetch_state(&StateEventType::RoomPowerLevels, "");
let sender_member_event = fetch_state(&StateEventType::RoomMember, sender.as_str());
let power_levels_event = fetch_state(&StateEventType::RoomPowerLevels, "").await;
let sender_member_event = fetch_state(&StateEventType::RoomMember, sender.as_str()).await;
if *incoming_event.event_type() == TimelineEventType::RoomMember {
debug!("starting m.room.member check");
@ -290,9 +297,13 @@ pub fn auth_check<E: Event>(
let user_for_join_auth =
content.join_authorised_via_users_server.as_ref().and_then(|u| u.deserialize().ok());
let user_for_join_auth_membership = user_for_join_auth
.as_ref()
.and_then(|auth_user| fetch_state(&StateEventType::RoomMember, auth_user.as_str()))
let user_for_join_auth_event = if let Some(auth_user) = user_for_join_auth.as_ref() {
fetch_state(&StateEventType::RoomMember, auth_user.as_str()).await
} else {
None
};
let user_for_join_auth_membership = user_for_join_auth_event
.and_then(|mem| from_json_str::<GetMembership>(mem.content().get()).ok())
.map(|mem| mem.membership)
.unwrap_or(MembershipState::Leave);
@ -300,13 +311,13 @@ pub fn auth_check<E: Event>(
if !valid_membership_change(
room_version,
target_user,
fetch_state(&StateEventType::RoomMember, target_user.as_str()).as_ref(),
fetch_state(&StateEventType::RoomMember, target_user.as_str()).await.as_ref(),
sender,
sender_member_event.as_ref(),
&incoming_event,
current_third_party_invite,
power_levels_event.as_ref(),
fetch_state(&StateEventType::RoomJoinRules, "").as_ref(),
fetch_state(&StateEventType::RoomJoinRules, "").await.as_ref(),
user_for_join_auth.as_deref(),
&user_for_join_auth_membership,
room_create_event,

View File

@ -5,6 +5,7 @@ use std::{
hash::Hash,
};
use futures_util::{future, stream, Future, StreamExt};
use itertools::Itertools;
use js_int::{int, Int};
use ruma_common::{EventId, MilliSecondsSinceUnixEpoch, RoomVersionId};
@ -52,16 +53,22 @@ pub type StateMap<T> = HashMap<(StateEventType, String), T>;
///
/// The caller of `resolve` must ensure that all the events are from the same room. Although this
/// function takes a `RoomId` it does not check that each event is part of the same room.
pub fn resolve<'a, E, SetIter>(
pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>(
room_version: &RoomVersionId,
state_sets: impl IntoIterator<IntoIter = SetIter>,
auth_chain_sets: Vec<HashSet<E::Id>>,
fetch_event: impl Fn(&EventId) -> Option<E>,
state_sets: impl IntoIterator<IntoIter = SetIter> + Send,
auth_chain_sets: &'a Vec<HashSet<E::Id>>,
event_fetch: &Fetch,
event_exists: &Exists,
) -> Result<StateMap<E::Id>>
where
E: Event + Clone,
E::Id: 'a,
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone,
Fetch: Fn(E::Id) -> FetchFut + Sync,
FetchFut: Future<Output = Option<E>> + Send,
Exists: Fn(E::Id) -> ExistsFut,
ExistsFut: Future<Output = bool> + Send,
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send + Sync,
for<'b> &'b E: Send,
{
debug!("State resolution starting");
@ -79,13 +86,16 @@ where
debug!("conflicting events: {}", conflicting.len());
debug!("{conflicting:?}");
let auth_chain_diff =
get_auth_chain_diff(&auth_chain_sets).chain(conflicting.into_values().flatten());
// `all_conflicted` contains unique items
// synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}`
let all_conflicted: HashSet<_> = get_auth_chain_diff(auth_chain_sets)
.chain(conflicting.into_values().flatten())
let all_conflicted: HashSet<E::Id> = stream::iter(auth_chain_diff)
// Don't honor events we cannot "verify"
.filter(|id| fetch_event(id.borrow()).is_some())
.collect();
.filter(|id| event_exists(id.clone()))
.collect()
.await;
debug!("full conflicted set: {}", all_conflicted.len());
debug!("{all_conflicted:?}");
@ -94,15 +104,15 @@ where
// this is now a check the caller of `resolve` must make.
// Get only the control events with a state_key: "" or ban/kick event (sender != state_key)
let control_events = all_conflicted
.iter()
.filter(|&id| is_power_event_id(id.borrow(), &fetch_event))
.cloned()
.collect::<Vec<_>>();
let control_events = stream::iter(all_conflicted.iter())
.filter(|&id| is_power_event_id(id, &event_fetch))
.map(Clone::clone)
.collect::<Vec<_>>()
.await;
// Sort the control events based on power_level/clock/event_id and outgoing/incoming edges
let sorted_control_levels =
reverse_topological_power_sort(control_events, &all_conflicted, &fetch_event)?;
reverse_topological_power_sort(control_events, &all_conflicted, &event_fetch).await?;
debug!("sorted control events: {}", sorted_control_levels.len());
trace!("{sorted_control_levels:?}");
@ -110,7 +120,8 @@ where
let room_version = RoomVersion::new(room_version)?;
// Sequentially auth check each control event.
let resolved_control =
iterative_auth_check(&room_version, &sorted_control_levels, clean.clone(), &fetch_event)?;
iterative_auth_check(&room_version, &sorted_control_levels, clean.clone(), &event_fetch)
.await?;
debug!("resolved control events: {}", resolved_control.len());
trace!("{resolved_control:?}");
@ -135,7 +146,8 @@ where
debug!("power event: {power_event:?}");
let sorted_left_events = mainline_sort(&events_to_resolve, power_event.cloned(), &fetch_event)?;
let sorted_left_events =
mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).await?;
trace!("events left, sorted: {sorted_left_events:?}");
@ -143,8 +155,9 @@ where
&room_version,
&sorted_left_events,
resolved_control, // The control events are added to the final resolved state
&fetch_event,
)?;
&event_fetch,
)
.await?;
// Add unconflicted state to the resolved state
// We priorities the unconflicting state
@ -188,15 +201,14 @@ where
}
/// Returns a Vec of deduped EventIds that appear in some chains but not others.
fn get_auth_chain_diff<Id>(auth_chain_sets: Vec<HashSet<Id>>) -> impl Iterator<Item = Id>
fn get_auth_chain_diff<Id>(auth_chain_sets: &Vec<HashSet<Id>>) -> impl Iterator<Item = Id>
where
Id: Eq + Hash,
Id: Clone + Eq + Hash,
{
let num_sets = auth_chain_sets.len();
let mut id_counts: HashMap<Id, usize> = HashMap::new();
for id in auth_chain_sets.into_iter().flatten() {
*id_counts.entry(id).or_default() += 1;
*id_counts.entry(id.clone()).or_default() += 1;
}
id_counts.into_iter().filter_map(move |(id, count)| (count < num_sets).then_some(id))
@ -209,16 +221,22 @@ where
///
/// The power level is negative because a higher power level is equated to an earlier (further back
/// in time) origin server timestamp.
fn reverse_topological_power_sort<E: Event>(
async fn reverse_topological_power_sort<E, F, Fut>(
events_to_sort: Vec<E::Id>,
auth_diff: &HashSet<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<E::Id>> {
fetch_event: &F,
) -> Result<Vec<E::Id>>
where
F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send + Sync,
{
debug!("reverse topological sort of power events");
let mut graph = HashMap::new();
for event_id in events_to_sort {
add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, &fetch_event);
add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, fetch_event).await;
// TODO: if these functions are ever made async here
// is a good place to yield every once in a while so other
@ -228,7 +246,7 @@ fn reverse_topological_power_sort<E: Event>(
// This is used in the `key_fn` passed to the lexico_topo_sort fn
let mut event_to_pl = HashMap::new();
for event_id in graph.keys() {
let pl = get_power_level_for_sender(event_id.borrow(), &fetch_event)?;
let pl = get_power_level_for_sender(event_id, fetch_event).await?;
debug!("{event_id} power level {pl}");
event_to_pl.insert(event_id.clone(), pl);
@ -238,26 +256,30 @@ fn reverse_topological_power_sort<E: Event>(
// tasks can make progress
}
lexicographical_topological_sort(&graph, |event_id| {
let ev = fetch_event(event_id).ok_or_else(|| Error::NotFound("".into()))?;
let pl = *event_to_pl.get(event_id).ok_or_else(|| Error::NotFound("".into()))?;
let event_to_pl = &event_to_pl;
let fetcher = |event_id: E::Id| async move {
let pl = *event_to_pl.get(event_id.borrow()).ok_or_else(|| Error::NotFound("".into()))?;
let ev = fetch_event(event_id).await.ok_or_else(|| Error::NotFound("".into()))?;
Ok((pl, ev.origin_server_ts()))
})
};
lexicographical_topological_sort(&graph, &fetcher).await
}
/// Sorts the event graph based on number of outgoing/incoming edges.
///
/// `key_fn` is used as to obtain the power level and age of an event for breaking ties (together
/// with the event ID).
pub fn lexicographical_topological_sort<Id, F>(
pub async fn lexicographical_topological_sort<Id, F, Fut>(
graph: &HashMap<Id, HashSet<Id>>,
key_fn: F,
key_fn: &F,
) -> Result<Vec<Id>>
where
F: Fn(&EventId) -> Result<(Int, MilliSecondsSinceUnixEpoch)>,
Id: Clone + Eq + Ord + Hash + Borrow<EventId>,
F: Fn(Id) -> Fut,
Fut: Future<Output = Result<(Int, MilliSecondsSinceUnixEpoch)>> + Send,
Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send,
{
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(Eq, Ord, PartialEq, PartialOrd)]
struct TieBreaker<'a, Id> {
inv_power_level: Int,
age: MilliSecondsSinceUnixEpoch,
@ -285,7 +307,7 @@ where
for (node, edges) in graph {
if edges.is_empty() {
let (power_level, age) = key_fn(node.borrow())?;
let (power_level, age) = key_fn(node.clone()).await?;
// The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need
// smallest -> largest
zero_outdegree.push(Reverse(TieBreaker {
@ -318,7 +340,7 @@ where
// Only push on the heap once older events have been cleared
out.remove(node.borrow());
if out.is_empty() {
let (power_level, age) = key_fn(node.borrow())?;
let (power_level, age) = key_fn(node.clone()).await?;
heap.push(Reverse(TieBreaker {
inv_power_level: -power_level,
age,
@ -339,17 +361,23 @@ where
/// Do NOT use this any where but topological sort, we find the power level for the eventId
/// at the eventId's generation (we walk backwards to `EventId`s most recent previous power level
/// event).
fn get_power_level_for_sender<E: Event>(
event_id: &EventId,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> serde_json::Result<Int> {
async fn get_power_level_for_sender<E, F, Fut>(
event_id: &E::Id,
fetch_event: &F,
) -> serde_json::Result<Int>
where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send,
{
debug!("fetch event ({event_id}) senders power level");
let event = fetch_event(event_id);
let event = fetch_event(event_id.clone()).await;
let mut pl = None;
for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() {
if let Some(aev) = fetch_event(aid.borrow()) {
if let Some(aev) = fetch_event(aid.clone()).await {
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
pl = Some(aev);
break;
@ -381,12 +409,19 @@ fn get_power_level_for_sender<E: Event>(
///
/// For each `events_to_check` event we gather the events needed to auth it from the the
/// `fetch_event` closure and verify each event using the `event_auth::auth_check` function.
fn iterative_auth_check<E: Event + Clone>(
async fn iterative_auth_check<E, F, Fut>(
room_version: &RoomVersion,
events_to_check: &[E::Id],
unconflicted_state: StateMap<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<E::Id>> {
fetch_event: &F,
) -> Result<StateMap<E::Id>>
where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Clone + Send,
for<'a> &'a E: Send,
{
debug!("starting iterative auth check");
debug!("performing auth checks on {events_to_check:?}");
@ -394,7 +429,8 @@ fn iterative_auth_check<E: Event + Clone>(
let mut resolved_state = unconflicted_state;
for event_id in events_to_check {
let event = fetch_event(event_id.borrow())
let event = fetch_event(event_id.clone())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))?;
let state_key = event
.state_key()
@ -402,7 +438,7 @@ fn iterative_auth_check<E: Event + Clone>(
let mut auth_events = StateMap::new();
for aid in event.auth_events() {
if let Some(ev) = fetch_event(aid.borrow()) {
if let Some(ev) = fetch_event(aid.clone()).await {
// TODO synapse check "rejected_reason" which is most likely
// related to soft-failing
auth_events.insert(
@ -423,7 +459,7 @@ fn iterative_auth_check<E: Event + Clone>(
event.content(),
)? {
if let Some(ev_id) = resolved_state.get(&key) {
if let Some(event) = fetch_event(ev_id.borrow()) {
if let Some(event) = fetch_event(ev_id.clone()).await {
// TODO synapse checks `rejected_reason` is None here
auth_events.insert(key.to_owned(), event);
}
@ -438,9 +474,11 @@ fn iterative_auth_check<E: Event + Clone>(
(*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu)
});
if auth_check(room_version, &event, current_third_party, |ty, key| {
auth_events.get(&ty.with_state_key(key))
})? {
let fetch_state = |ty: &StateEventType, key: &str| {
future::ready(auth_events.get(&ty.with_state_key(key)))
};
if auth_check(room_version, &event, current_third_party, fetch_state).await? {
// add event to resolved state map
resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone());
} else {
@ -462,11 +500,17 @@ fn iterative_auth_check<E: Event + Clone>(
/// power_level event. If there have been two power events the after the most recent are depth 0,
/// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is
/// "older" than depth 0.
fn mainline_sort<E: Event>(
async fn mainline_sort<E, F, Fut>(
to_sort: &[E::Id],
resolved_power_level: Option<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<E::Id>> {
fetch_event: &F,
) -> Result<Vec<E::Id>>
where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Clone + Send,
{
debug!("mainline sort of events");
// There are no EventId's to sort, bail.
@ -479,11 +523,13 @@ fn mainline_sort<E: Event>(
while let Some(p) = pl {
mainline.push(p.clone());
let event = fetch_event(p.borrow())
let event = fetch_event(p.clone())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?;
pl = None;
for aid in event.auth_events() {
let ev = fetch_event(aid.borrow())
let ev = fetch_event(aid.clone())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") {
pl = Some(aid.to_owned());
@ -504,11 +550,15 @@ fn mainline_sort<E: Event>(
let mut order_map = HashMap::new();
for ev_id in to_sort.iter() {
if let Some(event) = fetch_event(ev_id.borrow()) {
if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, &fetch_event) {
if let Some(event) = fetch_event(ev_id.clone()).await {
if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, fetch_event).await {
order_map.insert(
ev_id,
(depth, fetch_event(ev_id.borrow()).map(|ev| ev.origin_server_ts()), ev_id),
(
depth,
fetch_event(ev_id.clone()).await.map(|ev| ev.origin_server_ts()),
ev_id,
),
);
}
}
@ -528,11 +578,17 @@ fn mainline_sort<E: Event>(
/// Get the mainline depth from the `mainline_map` or finds a power_level event that has an
/// associated mainline depth.
fn get_mainline_depth<E: Event>(
async fn get_mainline_depth<E, F, Fut>(
mut event: Option<E>,
mainline_map: &HashMap<E::Id, usize>,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<usize> {
fetch_event: &F,
) -> Result<usize>
where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send,
{
while let Some(sort_ev) = event {
debug!("mainline event_id {}", sort_ev.event_id());
let id = sort_ev.event_id();
@ -542,7 +598,8 @@ fn get_mainline_depth<E: Event>(
event = None;
for aid in sort_ev.auth_events() {
let aev = fetch_event(aid.borrow())
let aev = fetch_event(aid.clone())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
event = Some(aev);
@ -554,18 +611,23 @@ fn get_mainline_depth<E: Event>(
Ok(0)
}
fn add_event_and_auth_chain_to_graph<E: Event>(
async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
graph: &mut HashMap<E::Id, HashSet<E::Id>>,
event_id: E::Id,
auth_diff: &HashSet<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>,
) {
fetch_event: &F,
) where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Clone + Send,
{
let mut state = vec![event_id];
while let Some(eid) = state.pop() {
graph.entry(eid.clone()).or_default();
// Prefer the store to event as the store filters dedups the events
for aid in
fetch_event(eid.borrow()).as_ref().map(|ev| ev.auth_events()).into_iter().flatten()
fetch_event(eid.clone()).await.as_ref().map(|ev| ev.auth_events()).into_iter().flatten()
{
if auth_diff.contains(aid.borrow()) {
if !graph.contains_key(aid.borrow()) {
@ -579,8 +641,14 @@ fn add_event_and_auth_chain_to_graph<E: Event>(
}
}
fn is_power_event_id<E: Event>(event_id: &EventId, fetch: impl Fn(&EventId) -> Option<E>) -> bool {
match fetch(event_id).as_ref() {
async fn is_power_event_id<E, F, Fut>(event_id: &E::Id, fetch: &F) -> bool
where
F: Fn(E::Id) -> Fut,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send,
{
match fetch(event_id.clone()).await.as_ref() {
Some(state) => is_power_event(state),
_ => false,
}
@ -609,7 +677,7 @@ fn is_power_event(event: impl Event) -> bool {
}
/// Convenience trait for adding event type plus state key to state maps.
trait EventTypeExt {
pub trait EventTypeExt {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String);
}
@ -662,7 +730,9 @@ mod tests {
Event, EventTypeExt, StateMap,
};
fn test_event_sort() {
async fn test_event_sort() {
use futures_util::future::ready;
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = INITIAL_EVENTS();
@ -680,18 +750,19 @@ mod tests {
.map(|pdu| pdu.event_id.clone())
.collect::<Vec<_>>();
let fetcher = |id| ready(events.get(&id).cloned());
let sorted_power_events =
crate::reverse_topological_power_sort(power_events, &auth_chain, |id| {
events.get(id).cloned()
})
.unwrap();
crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher)
.await
.unwrap();
let resolved_power = crate::iterative_auth_check(
&RoomVersion::V6,
&sorted_power_events,
HashMap::new(), // unconflicted events
|id| events.get(id).cloned(),
&fetcher,
)
.await
.expect("iterative auth check failed on resolved events");
// don't remove any events so we know it sorts them all correctly
@ -703,8 +774,7 @@ mod tests {
resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned();
let sorted_event_ids =
crate::mainline_sort(&events_to_sort, power_level, |id| events.get(id).cloned())
.unwrap();
crate::mainline_sort(&events_to_sort, power_level, &fetcher).await.unwrap();
assert_eq!(
vec![
@ -721,17 +791,17 @@ mod tests {
);
}
#[test]
fn test_sort() {
#[tokio::test]
async fn test_sort() {
for _ in 0..20 {
// since we shuffle the eventIds before we sort them introducing randomness
// seems like we should test this a few times
test_event_sort();
test_event_sort().await;
}
}
#[test]
fn ban_vs_power_level() {
#[tokio::test]
async fn ban_vs_power_level() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -774,11 +844,11 @@ mod tests {
let expected_state_ids =
vec!["PA", "MA", "MB"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids);
do_check(events, edges, expected_state_ids).await;
}
#[test]
fn topic_basic() {
#[tokio::test]
async fn topic_basic() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -835,11 +905,11 @@ mod tests {
let expected_state_ids = vec!["PA2", "T2"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids);
do_check(events, edges, expected_state_ids).await;
}
#[test]
fn topic_reset() {
#[tokio::test]
async fn topic_reset() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -882,11 +952,11 @@ mod tests {
let expected_state_ids =
vec!["T1", "MB", "PA"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids);
do_check(events, edges, expected_state_ids).await;
}
#[test]
fn join_rule_evasion() {
#[tokio::test]
async fn join_rule_evasion() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -914,11 +984,11 @@ mod tests {
let expected_state_ids = vec![event_id("JR")];
do_check(events, edges, expected_state_ids);
do_check(events, edges, expected_state_ids).await;
}
#[test]
fn offtopic_power_level() {
#[tokio::test]
async fn offtopic_power_level() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -955,11 +1025,11 @@ mod tests {
let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids);
do_check(events, edges, expected_state_ids).await;
}
#[test]
fn topic_setting() {
#[tokio::test]
async fn topic_setting() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -1032,11 +1102,13 @@ mod tests {
let expected_state_ids = vec!["T4", "PA2"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids);
do_check(events, edges, expected_state_ids).await;
}
#[test]
fn test_event_map_none() {
#[tokio::test]
async fn test_event_map_none() {
use futures_util::future::ready;
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -1046,27 +1118,29 @@ mod tests {
let (state_at_bob, state_at_charlie, expected) = store.set_up();
let ev_map = store.0.clone();
let fetcher = |id| ready(ev_map.get(&id).cloned());
let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&*id).is_some());
let state_sets = [state_at_bob, state_at_charlie];
let resolved = match crate::resolve(
&RoomVersionId::V2,
&state_sets,
state_sets
.iter()
.map(|map| {
store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()
})
.collect(),
|id| ev_map.get(id).cloned(),
) {
Ok(state) => state,
Err(e) => panic!("{e}"),
};
let auth_chain = state_sets
.iter()
.map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap())
.collect();
let resolved =
match crate::resolve(&RoomVersionId::V2, &state_sets, &auth_chain, &fetcher, &exists)
.await
{
Ok(state) => state,
Err(e) => panic!("{e}"),
};
assert_eq!(expected, resolved);
}
#[test]
fn test_lexicographical_sort() {
#[tokio::test]
async fn test_lexicographical_sort() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
@ -1078,9 +1152,10 @@ mod tests {
event_id("p") => hashset![event_id("o")],
};
let res = crate::lexicographical_topological_sort(&graph, |_id| {
let res = crate::lexicographical_topological_sort(&graph, &|_id| async {
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
})
.await
.unwrap();
assert_eq!(
@ -1092,8 +1167,8 @@ mod tests {
);
}
#[test]
fn ban_with_auth_chains() {
#[tokio::test]
async fn ban_with_auth_chains() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let ban = BAN_STATE_SET();
@ -1105,11 +1180,13 @@ mod tests {
let expected_state_ids = vec!["PA", "MB"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(&ban.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids);
do_check(&ban.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids).await;
}
#[test]
fn ban_with_auth_chains2() {
#[tokio::test]
async fn ban_with_auth_chains2() {
use futures_util::future::ready;
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let init = INITIAL_EVENTS();
@ -1147,20 +1224,20 @@ mod tests {
let ev_map = &store.0;
let state_sets = [state_set_a, state_set_b];
let resolved = match crate::resolve(
&RoomVersionId::V6,
&state_sets,
state_sets
.iter()
.map(|map| {
store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()
})
.collect(),
|id| ev_map.get(id).cloned(),
) {
Ok(state) => state,
Err(e) => panic!("{e}"),
};
let auth_chain = state_sets
.iter()
.map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap())
.collect();
let fetcher = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).cloned());
let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).is_some());
let resolved =
match crate::resolve(&RoomVersionId::V6, &state_sets, &auth_chain, &fetcher, &exists)
.await
{
Ok(state) => state,
Err(e) => panic!("{e}"),
};
debug!(
"{:#?}",
@ -1180,8 +1257,8 @@ mod tests {
assert_eq!(expected.len(), resolved.len());
}
#[test]
fn join_rule_with_auth_chain() {
#[tokio::test]
async fn join_rule_with_auth_chain() {
let join_rule = JOIN_RULE();
let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]]
@ -1191,7 +1268,7 @@ mod tests {
let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(&join_rule.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids);
do_check(&join_rule.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids).await;
}
#[allow(non_snake_case)]

View File

@ -7,6 +7,7 @@ use std::{
},
};
use futures_util::future::ready;
use js_int::{int, uint};
use ruma_common::{
event_id, room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId,
@ -31,7 +32,7 @@ use crate::{auth_types_for_event, Error, Event, EventTypeExt, Result, StateMap};
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
pub(crate) fn do_check(
pub(crate) async fn do_check(
events: &[Arc<PduEvent>],
edges: Vec<Vec<OwnedEventId>>,
expected_state_ids: Vec<OwnedEventId>,
@ -81,9 +82,10 @@ pub(crate) fn do_check(
// Resolve the current state and add it to the state_at_event map then continue
// on in "time"
for node in crate::lexicographical_topological_sort(&graph, |_id| {
for node in crate::lexicographical_topological_sort(&graph, &|_id| async {
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
})
.await
.unwrap()
{
let fake_event = fake_event_map.get(&node).unwrap();
@ -117,9 +119,13 @@ pub(crate) fn do_check(
})
.collect();
let resolved = crate::resolve(&RoomVersionId::V6, state_sets, auth_chain_sets, |id| {
event_map.get(id).cloned()
});
let event_map = &event_map;
let fetch = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).cloned());
let exists = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).is_some());
let resolved =
crate::resolve(&RoomVersionId::V6, state_sets, &auth_chain_sets, &fetch, &exists)
.await;
match resolved {
Ok(state) => state,
Err(e) => panic!("resolution for {node} failed: {e}"),
@ -614,7 +620,7 @@ pub(crate) mod event {
}
}
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + Send + '_> {
match &self.rest {
Pdu::RoomV1Pdu(ev) => Box::new(ev.prev_events.iter().map(|(id, _)| id)),
Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter()),
@ -623,7 +629,7 @@ pub(crate) mod event {
}
}
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + Send + '_> {
match &self.rest {
Pdu::RoomV1Pdu(ev) => Box::new(ev.auth_events.iter().map(|(id, _)| id)),
Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter()),