Add benchmark for longer auth chain and Error type

This required that the code being run in the benchmark be tested to
verify it works correctly. Now work can begin cleaning up and optimizing
state-res.
This commit is contained in:
Devin Ragotzy 2020-07-27 00:09:21 -04:00
parent ea0b6ad530
commit d8fb5ca112
9 changed files with 1171 additions and 117 deletions

View File

@ -13,6 +13,7 @@ serde = { version = "1.0.114", features = ["derive"] }
serde_json = "1.0.56" serde_json = "1.0.56"
tracing = "0.1.16" tracing = "0.1.16"
maplit = "1.0.2" maplit = "1.0.2"
thiserror = "1.0.20"
tracing-subscriber = "0.2.8" tracing-subscriber = "0.2.8"
[dependencies.ruma] [dependencies.ruma]

View File

@ -1,4 +1,5 @@
Would it be possible to abstract state res into a `ruma-state-res` crate? I've been thinking about something along the lines of ### Matrix state resolution in rust!
```rust ```rust
/// StateMap is just a wrapper/deserialize target for a PDU. /// StateMap is just a wrapper/deserialize target for a PDU.
struct StateEvent { struct StateEvent {
@ -41,3 +42,9 @@ trait StateStore {
} }
``` ```
The `StateStore` trait is an abstraction around what ever database your server (or maybe even client) uses to store __P__[]()ersistant __D__[]()ata __U__[]()nits.
We use `ruma`s types when deserializing any PDU or it's contents which helps avoid a lot of type checking logic [synapse](https://github.com/matrix-org/synapse) must do while authenticating event chains.

View File

@ -70,10 +70,83 @@ fn resolution_shallow_auth_chain(c: &mut Criterion) {
}); });
} }
criterion_group!(benches, lexico_topo_sort, resolution_shallow_auth_chain); fn resolve_deeper_event_set(c: &mut Criterion) {
c.bench_function("resolve state of 10 events 3 conflicting", |b| {
let mut resolver = StateResolution::default();
let init = INITIAL_EVENTS();
let ban = BAN_STATE_SET();
let mut inner = init;
inner.extend(ban);
let store = TestStore(RefCell::new(inner.clone()));
let state_set_a = [
inner.get(&event_id("CREATE")).unwrap(),
inner.get(&event_id("IJR")).unwrap(),
inner.get(&event_id("IMA")).unwrap(),
inner.get(&event_id("IMB")).unwrap(),
inner.get(&event_id("IMC")).unwrap(),
inner.get(&event_id("MB")).unwrap(),
inner.get(&event_id("PA")).unwrap(),
]
.iter()
.map(|ev| {
(
(ev.kind(), ev.state_key().unwrap()),
ev.event_id().unwrap().clone(),
)
})
.collect::<BTreeMap<_, _>>();
let state_set_b = [
inner.get(&event_id("CREATE")).unwrap(),
inner.get(&event_id("IJR")).unwrap(),
inner.get(&event_id("IMA")).unwrap(),
inner.get(&event_id("IMB")).unwrap(),
inner.get(&event_id("IMC")).unwrap(),
inner.get(&event_id("IME")).unwrap(),
inner.get(&event_id("PA")).unwrap(),
]
.iter()
.map(|ev| {
(
(ev.kind(), ev.state_key().unwrap()),
ev.event_id().unwrap().clone(),
)
})
.collect::<BTreeMap<_, _>>();
b.iter(|| {
let _resolved = match resolver.resolve(
&room_id(),
&RoomVersionId::version_2(),
&[state_set_a.clone(), state_set_b.clone()],
Some(inner.clone()),
&store,
) {
Ok(ResolutionResult::Resolved(state)) => state,
Err(_) => panic!("resolution failed during benchmarking"),
_ => panic!("resolution failed during benchmarking"),
};
})
});
}
criterion_group!(
benches,
lexico_topo_sort,
resolution_shallow_auth_chain,
resolve_deeper_event_set
);
criterion_main!(benches); criterion_main!(benches);
//*/////////////////////////////////////////////////////////////////////
//
// IMPLEMENTATION DETAILS AHEAD
//
/////////////////////////////////////////////////////////////////////*/
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>); pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
#[allow(unused)] #[allow(unused)]
@ -115,7 +188,7 @@ impl StateStore for TestStore {
result.push(ev_id.clone()); result.push(ev_id.clone());
let event = self.get_event(&ev_id).unwrap(); let event = self.get_event(&ev_id).unwrap();
stack.extend(event.auth_event_ids()); stack.extend(event.auth_events());
} }
Ok(result) Ok(result)
@ -220,7 +293,7 @@ impl TestStore {
EventType::RoomMember, EventType::RoomMember,
Some(charlie().to_string().as_str()), Some(charlie().to_string().as_str()),
member_content_join(), member_content_join(),
&[cre.clone(), join_rules.event_id().unwrap().clone()], &[cre, join_rules.event_id().unwrap().clone()],
&[join_rules.event_id().unwrap().clone()], &[join_rules.event_id().unwrap().clone()],
); );
self.0 self.0
@ -231,7 +304,7 @@ impl TestStore {
.iter() .iter()
.map(|e| { .map(|e| {
( (
(e.kind(), e.state_key().unwrap().clone()), (e.kind(), e.state_key().unwrap()),
e.event_id().unwrap().clone(), e.event_id().unwrap().clone(),
) )
}) })
@ -241,7 +314,7 @@ impl TestStore {
.iter() .iter()
.map(|e| { .map(|e| {
( (
(e.kind(), e.state_key().unwrap().clone()), (e.kind(), e.state_key().unwrap()),
e.event_id().unwrap().clone(), e.event_id().unwrap().clone(),
) )
}) })
@ -257,7 +330,7 @@ impl TestStore {
.iter() .iter()
.map(|e| { .map(|e| {
( (
(e.kind(), e.state_key().unwrap().clone()), (e.kind(), e.state_key().unwrap()),
e.event_id().unwrap().clone(), e.event_id().unwrap().clone(),
) )
}) })
@ -268,7 +341,7 @@ impl TestStore {
} }
fn event_id(id: &str) -> EventId { fn event_id(id: &str) -> EventId {
if id.contains("$") { if id.contains('$') {
return EventId::try_from(id).unwrap(); return EventId::try_from(id).unwrap();
} }
EventId::try_from(format!("${}:foo", id)).unwrap() EventId::try_from(format!("${}:foo", id)).unwrap()
@ -283,11 +356,25 @@ fn bob() -> UserId {
fn charlie() -> UserId { fn charlie() -> UserId {
UserId::try_from("@charlie:foo").unwrap() UserId::try_from("@charlie:foo").unwrap()
} }
fn ella() -> UserId {
UserId::try_from("@ella:foo").unwrap()
}
fn room_id() -> RoomId { fn room_id() -> RoomId {
RoomId::try_from("!test:foo").unwrap() RoomId::try_from("!test:foo").unwrap()
} }
fn member_content_ban() -> JsonValue {
serde_json::to_value(MemberEventContent {
membership: MembershipState::Ban,
displayname: None,
avatar_url: None,
is_direct: None,
third_party_invite: None,
})
.unwrap()
}
fn member_content_join() -> JsonValue { fn member_content_join() -> JsonValue {
serde_json::to_value(MemberEventContent { serde_json::to_value(MemberEventContent {
membership: MembershipState::Join, membership: MembershipState::Join,
@ -317,7 +404,7 @@ where
SERVER_TIMESTAMP += 1; SERVER_TIMESTAMP += 1;
ts ts
}; };
let id = if id.contains("$") { let id = if id.contains('$') {
id.to_string() id.to_string()
} else { } else {
format!("${}:foo", id) format!("${}:foo", id)
@ -325,33 +412,13 @@ where
let auth_events = auth_events let auth_events = auth_events
.iter() .iter()
.map(AsRef::as_ref) .map(AsRef::as_ref)
.map(|s| { .map(event_id)
EventId::try_from( .collect::<Vec<_>>();
if s.contains("$") {
s.to_owned()
} else {
format!("${}:foo", s)
}
.as_str(),
)
})
.collect::<Result<Vec<_>, _>>()
.unwrap();
let prev_events = prev_events let prev_events = prev_events
.iter() .iter()
.map(AsRef::as_ref) .map(AsRef::as_ref)
.map(|s| { .map(event_id)
EventId::try_from( .collect::<Vec<_>>();
if s.contains("$") {
s.to_owned()
} else {
format!("${}:foo", s)
}
.as_str(),
)
})
.collect::<Result<Vec<_>, _>>()
.unwrap();
let json = if let Some(state_key) = state_key { let json = if let Some(state_key) = state_key {
json!({ json!({
@ -387,3 +454,131 @@ where
}; };
serde_json::from_value(json).unwrap() serde_json::from_value(json).unwrap()
} }
// all graphs start with these input events
#[allow(non_snake_case)]
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
vec![
to_pdu_event::<EventId>(
"CREATE",
alice(),
EventType::RoomCreate,
Some(""),
json!({ "creator": alice() }),
&[],
&[],
),
to_pdu_event(
"IMA",
alice(),
EventType::RoomMember,
Some(alice().to_string().as_str()),
member_content_join(),
&["CREATE"],
&["CREATE"],
),
to_pdu_event(
"IPOWER",
alice(),
EventType::RoomPowerLevels,
Some(""),
json!({"users": {alice().to_string(): 100}}),
&["CREATE", "IMA"],
&["IMA"],
),
to_pdu_event(
"IJR",
alice(),
EventType::RoomJoinRules,
Some(""),
json!({ "join_rule": JoinRule::Public }),
&["CREATE", "IMA", "IPOWER"],
&["IPOWER"],
),
to_pdu_event(
"IMB",
bob(),
EventType::RoomMember,
Some(bob().to_string().as_str()),
member_content_join(),
&["CREATE", "IJR", "IPOWER"],
&["IJR"],
),
to_pdu_event(
"IMC",
charlie(),
EventType::RoomMember,
Some(charlie().to_string().as_str()),
member_content_join(),
&["CREATE", "IJR", "IPOWER"],
&["IMB"],
),
to_pdu_event::<EventId>(
"START",
charlie(),
EventType::RoomMessage,
None,
json!({}),
&[],
&[],
),
to_pdu_event::<EventId>(
"END",
charlie(),
EventType::RoomMessage,
None,
json!({}),
&[],
&[],
),
]
.into_iter()
.map(|ev| (ev.event_id().unwrap().clone(), ev))
.collect()
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn BAN_STATE_SET() -> BTreeMap<EventId, StateEvent> {
vec![
to_pdu_event(
"PA",
alice(),
EventType::RoomPowerLevels,
Some(""),
json!({"users": {alice(): 100, bob(): 50}}),
&["CREATE", "IMA", "IPOWER"], // auth_events
&["START"], // prev_events
),
to_pdu_event(
"PB",
alice(),
EventType::RoomPowerLevels,
Some(""),
json!({"users": {alice(): 100, bob(): 50}}),
&["CREATE", "IMA", "IPOWER"],
&["END"],
),
to_pdu_event(
"MB",
alice(),
EventType::RoomMember,
Some(ella().as_str()),
member_content_ban(),
&["CREATE", "IMA", "PB"],
&["PA"],
),
to_pdu_event(
"IME",
ella(),
EventType::RoomMember,
Some(ella().as_str()),
member_content_join(),
&["CREATE", "IJR", "PA"],
&["MB"],
),
]
.into_iter()
.map(|ev| (ev.event_id().unwrap().clone(), ev))
.collect()
}

23
src/error.rs Normal file
View File

@ -0,0 +1,23 @@
use std::num::ParseIntError;
use serde_json::Error as JsonError;
use thiserror::Error;
/// Result type for state resolution.
pub type Result<T> = std::result::Result<T, Error>;
/// Represents the various errors that arise when resolving state.
#[derive(Error, Debug)]
pub enum Error {
/// A deserialization error.
#[error(transparent)]
SerdeJson(#[from] JsonError),
/// An error that occurs when converting from JSON numbers to rust.
#[error(transparent)]
IntParseError(#[from] ParseIntError),
// TODO remove once the correct errors are used
#[error("an error occured {0}")]
TempString(String),
}

View File

@ -1,5 +1,6 @@
use std::convert::TryFrom; use std::convert::TryFrom;
use maplit::btreeset;
use ruma::{ use ruma::{
events::{ events::{
room::{self, join_rules::JoinRule, member::MembershipState}, room::{self, join_rules::JoinRule, member::MembershipState},
@ -89,7 +90,7 @@ pub fn auth_check(
false false
}; };
if !event.signatures().get(sender_domain).is_some() && !is_invite_via_3pid { if event.signatures().get(sender_domain).is_none() && !is_invite_via_3pid {
tracing::info!("event not signed by sender's server"); tracing::info!("event not signed by sender's server");
return Some(false); return Some(false);
} }
@ -107,6 +108,7 @@ pub fn auth_check(
// domain of room_id must match domain of sender. // domain of room_id must match domain of sender.
if event.room_id().map(|id| id.server_name()) != Some(event.sender().server_name()) { if event.room_id().map(|id| id.server_name()) != Some(event.sender().server_name()) {
tracing::info!("creation events server does not match sender");
return Some(false); // creation events room id does not match senders return Some(false); // creation events room id does not match senders
} }
@ -117,7 +119,8 @@ pub fn auth_check(
.content() .content()
.get("room_version") .get("room_version")
.cloned() .cloned()
.unwrap_or(serde_json::json!({})), // synapse defaults to version 1
.unwrap_or(serde_json::json!("1")),
) )
.is_err() .is_err()
{ {
@ -231,7 +234,7 @@ fn can_federate(auth_events: &StateMap<StateEvent>) -> bool {
let creation_event = auth_events.get(&(EventType::RoomCreate, "".into())); let creation_event = auth_events.get(&(EventType::RoomCreate, "".into()));
if let Some(ev) = creation_event { if let Some(ev) = creation_event {
if let Some(fed) = ev.content().get("m.federate") { if let Some(fed) = ev.content().get("m.federate") {
fed.to_string() == "true" fed == "true"
} else { } else {
false false
} }
@ -468,7 +471,7 @@ fn can_send_event(event: &StateEvent, auth_events: &StateMap<StateEvent>) -> Opt
} }
if let Some(sk) = event.state_key() { if let Some(sk) = event.state_key() {
if sk.starts_with("@") && sk != event.sender().to_string() { if sk.starts_with('@') && sk != event.sender().as_str() {
return Some(false); // permission required to post in this room return Some(false); // permission required to post in this room
} }
} }
@ -484,7 +487,13 @@ fn check_power_levels(
use itertools::Itertools; use itertools::Itertools;
let key = (power_event.kind(), power_event.state_key().unwrap()); let key = (power_event.kind(), power_event.state_key().unwrap());
let current_state = auth_events.get(&key)?;
let current_state = if let Some(current_state) = auth_events.get(&key) {
current_state
} else {
// TODO synapse returns here, shouldn't this be an error ??
return Some(true);
};
let user_content = power_event let user_content = power_event
.deserialize_content::<room::power_levels::PowerLevelsEventContent>() .deserialize_content::<room::power_levels::PowerLevelsEventContent>()
@ -493,25 +502,27 @@ fn check_power_levels(
.deserialize_content::<room::power_levels::PowerLevelsEventContent>() .deserialize_content::<room::power_levels::PowerLevelsEventContent>()
.unwrap(); .unwrap();
tracing::info!("validation of power event finished");
// validation of users is done in Ruma, synapse for loops validating user_ids and integers here // validation of users is done in Ruma, synapse for loops validating user_ids and integers here
tracing::info!("validation of power event finished");
let user_level = get_user_power_level(power_event.sender(), auth_events); let user_level = get_user_power_level(power_event.sender(), auth_events);
let mut user_levels_to_check = vec![]; let mut user_levels_to_check = btreeset![];
let old_list = &current_content.users; let old_list = &current_content.users;
let user_list = &user_content.users; let user_list = &user_content.users;
for user in old_list.keys().chain(user_list.keys()).dedup() { for user in old_list.keys().chain(user_list.keys()).dedup() {
let user: &UserId = user; let user: &UserId = user;
user_levels_to_check.push(user); user_levels_to_check.insert(user);
} }
let mut event_levels_to_check = vec![]; tracing::debug!("users to check {:?}", user_levels_to_check);
let mut event_levels_to_check = btreeset![];
let old_list = &current_content.events; let old_list = &current_content.events;
let new_list = &user_content.events; let new_list = &user_content.events;
for ev_id in old_list.keys().chain(new_list.keys()).dedup() { for ev_id in old_list.keys().chain(new_list.keys()).dedup() {
let ev_id: &EventType = ev_id; let ev_id: &EventType = ev_id;
event_levels_to_check.push(ev_id); event_levels_to_check.insert(ev_id);
} }
tracing::debug!("events to check {:?}", event_levels_to_check); tracing::debug!("events to check {:?}", event_levels_to_check);
@ -574,9 +585,43 @@ fn check_power_levels(
} }
} }
let levels = [
"users_default",
"events_default",
"state_default",
"ban",
"redact",
"kick",
"invite",
];
let old_state = serde_json::to_value(old_state).unwrap();
let new_state = serde_json::to_value(new_state).unwrap();
for lvl_name in &levels {
if let Some((old_lvl, new_lvl)) = get_deserialize_levels(&old_state, &new_state, lvl_name) {
let old_level_too_big = old_lvl > user_level;
let new_level_too_big = new_lvl > user_level;
if old_level_too_big || new_level_too_big {
tracing::info!("cannot add ops > than own");
return Some(false);
}
}
}
Some(true) Some(true)
} }
fn get_deserialize_levels(
old: &serde_json::Value,
new: &serde_json::Value,
name: &str,
) -> Option<(i64, i64)> {
Some((
serde_json::from_value(old.get(name)?.clone()).ok()?,
serde_json::from_value(new.get(name)?.clone()).ok()?,
))
}
/// Does the event redacting come from a user with enough power to redact the given event. /// Does the event redacting come from a user with enough power to redact the given event.
fn check_redaction( fn check_redaction(
room_version: &RoomVersionId, room_version: &RoomVersionId,

View File

@ -1,3 +1,5 @@
#![allow(clippy::or_fun_call)]
use std::{ use std::{
cmp::Reverse, cmp::Reverse,
collections::{BTreeMap, BTreeSet, BinaryHeap}, collections::{BTreeMap, BTreeSet, BinaryHeap},
@ -11,11 +13,13 @@ use ruma::{
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
mod error;
mod event_auth; mod event_auth;
mod room_version; mod room_version;
mod state_event; mod state_event;
mod state_store; mod state_store;
pub use error::{Error, Result};
pub use event_auth::{auth_check, auth_types_for_event}; pub use event_auth::{auth_check, auth_types_for_event};
pub use state_event::StateEvent; pub use state_event::StateEvent;
pub use state_store::StateStore; pub use state_store::StateStore;
@ -66,7 +70,7 @@ impl StateResolution {
event_map: Option<EventMap<StateEvent>>, event_map: Option<EventMap<StateEvent>>,
store: &dyn StateStore, store: &dyn StateStore,
// TODO actual error handling (`thiserror`??) // TODO actual error handling (`thiserror`??)
) -> Result<ResolutionResult, String> { ) -> Result<ResolutionResult> {
tracing::info!("State resolution starting"); tracing::info!("State resolution starting");
let mut event_map = if let Some(ev_map) = event_map { let mut event_map = if let Some(ev_map) = event_map {
@ -76,7 +80,26 @@ impl StateResolution {
}; };
// split non-conflicting and conflicting state // split non-conflicting and conflicting state
let (clean, conflicting) = self.separate(&state_sets); let (clean, conflicting) = self.separate(&state_sets);
tracing::debug!(
"CLEAN {:#?}",
clean
.iter()
.map(|((ty, key), id)| format!("(({}{}), {})", ty, key, id))
.collect::<Vec<_>>()
);
tracing::debug!(
"CONFLICT {:#?}",
conflicting
.iter()
.map(|((ty, key), ids)| format!(
"(({} `{}`), {:?})",
ty,
key,
ids.iter().map(ToString::to_string).collect::<Vec<_>>()
))
.collect::<Vec<_>>()
);
tracing::info!("non conflicting {:?}", clean.len()); tracing::info!("non conflicting {:?}", clean.len());
if conflicting.is_empty() { if conflicting.is_empty() {
@ -124,7 +147,7 @@ impl StateResolution {
for event in event_map.values() { for event in event_map.values() {
if event.room_id() != Some(room_id) { if event.room_id() != Some(room_id) {
return Err(format!( return Err(Error::TempString(format!(
"resolving event {} in room {}, when correct room is {}", "resolving event {} in room {}, when correct room is {}",
event event
.event_id() .event_id()
@ -132,7 +155,7 @@ impl StateResolution {
.unwrap_or("`unknown`"), .unwrap_or("`unknown`"),
event.room_id().map(|id| id.as_str()).unwrap_or("`unknown`"), event.room_id().map(|id| id.as_str()).unwrap_or("`unknown`"),
room_id.as_str() room_id.as_str()
)); )));
} }
} }
@ -153,7 +176,7 @@ impl StateResolution {
let mut sorted_power_levels = self.reverse_topological_power_sort( let mut sorted_power_levels = self.reverse_topological_power_sort(
room_id, room_id,
&power_events, &power_events,
&mut event_map, &event_map, // TODO use event_map
store, store,
&all_conflicted, &all_conflicted,
); );
@ -172,7 +195,7 @@ impl StateResolution {
room_version, room_version,
&sorted_power_levels, &sorted_power_levels,
&clean, &clean,
&mut event_map, &event_map,
store, store,
)?; )?;
@ -224,7 +247,7 @@ impl StateResolution {
room_version, room_version,
&sorted_left_events, &sorted_left_events,
&resolved, &resolved,
&mut event_map, &event_map,
store, store,
)?; )?;
@ -255,7 +278,8 @@ impl StateResolution {
for key in state_sets for key in state_sets
.iter() .iter()
.flat_map(|map| map.keys()) .flat_map(|map| map.keys())
.collect::<BTreeSet<_>>() .dedup()
.collect::<Vec<_>>()
{ {
let mut event_ids = state_sets let mut event_ids = state_sets
.iter() .iter()
@ -263,6 +287,14 @@ impl StateResolution {
.dedup() .dedup()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
tracing::debug!(
"SEP {:?}",
event_ids
.iter()
.map(|i| i.map(ToString::to_string).unwrap_or("None".into()))
.collect::<Vec<_>>()
);
if event_ids.len() == 1 { if event_ids.len() == 1 {
if let Some(Some(id)) = event_ids.pop() { if let Some(Some(id)) = event_ids.pop() {
unconflicted_state.insert(key.clone(), id.clone()); unconflicted_state.insert(key.clone(), id.clone());
@ -270,6 +302,7 @@ impl StateResolution {
panic!() panic!()
} }
} else { } else {
tracing::warn!("{:?}", key);
conflicted_state.insert( conflicted_state.insert(
key.clone(), key.clone(),
event_ids.into_iter().flatten().cloned().collect::<Vec<_>>(), event_ids.into_iter().flatten().cloned().collect::<Vec<_>>(),
@ -287,19 +320,21 @@ impl StateResolution {
state_sets: &[StateMap<EventId>], state_sets: &[StateMap<EventId>],
_event_map: &EventMap<StateEvent>, _event_map: &EventMap<StateEvent>,
store: &dyn StateStore, store: &dyn StateStore,
) -> Result<Vec<EventId>, String> { ) -> Result<Vec<EventId>> {
use itertools::Itertools; use itertools::Itertools;
tracing::debug!("calculating auth chain difference"); tracing::debug!("calculating auth chain difference");
store.auth_chain_diff( store
room_id, .auth_chain_diff(
state_sets room_id,
.iter() state_sets
.map(|map| map.values().cloned().collect()) .iter()
.dedup() .map(|map| map.values().cloned().collect())
.collect::<Vec<_>>(), .dedup()
) .collect::<Vec<_>>(),
)
.map_err(Error::TempString)
} }
pub fn reverse_topological_power_sort( pub fn reverse_topological_power_sort(
@ -338,15 +373,20 @@ impl StateResolution {
} }
} }
self.lexicographical_topological_sort(&mut graph, |event_id| { self.lexicographical_topological_sort(&graph, |event_id| {
// tracing::debug!("{:?}", event_map.get(event_id).unwrap().origin_server_ts()); // tracing::debug!("{:?}", event_map.get(event_id).unwrap().origin_server_ts());
let ev = event_map.get(event_id).unwrap(); let ev = event_map.get(event_id).unwrap();
let pl = event_to_pl.get(event_id).unwrap(); let pl = event_to_pl.get(event_id).unwrap();
tracing::warn!(
"{:?}",
(-*pl, *ev.origin_server_ts(), ev.event_id().cloned())
);
// This return value is the key used for sorting events, // This return value is the key used for sorting events,
// events are then sorted by power level, time, // events are then sorted by power level, time,
// and lexically by event_id. // and lexically by event_id.
(-*pl, ev.origin_server_ts().clone(), ev.event_id().cloned()) (-*pl, *ev.origin_server_ts(), ev.event_id().cloned())
}) })
} }
@ -371,8 +411,8 @@ impl StateResolution {
// TODO make the BTreeSet conversion cleaner ?? // TODO make the BTreeSet conversion cleaner ??
let mut outdegree_map: BTreeMap<EventId, BTreeSet<EventId>> = graph let mut outdegree_map: BTreeMap<EventId, BTreeSet<EventId>> = graph
.into_iter() .iter()
.map(|(k, v)| (k.clone(), v.into_iter().cloned().collect())) .map(|(k, v)| (k.clone(), v.iter().cloned().collect()))
.collect(); .collect();
let mut reverse_graph = BTreeMap::new(); let mut reverse_graph = BTreeMap::new();
@ -432,7 +472,7 @@ impl StateResolution {
let mut pl = None; let mut pl = None;
// TODO store.auth_event_ids returns "self" with the event ids is this ok // TODO store.auth_event_ids returns "self" with the event ids is this ok
// event.auth_event_ids does not include its own event id ? // event.auth_event_ids does not include its own event id ?
for aid in store.get_event(event_id).unwrap().auth_event_ids() { for aid in store.get_event(event_id).unwrap().auth_events() {
if let Ok(aev) = store.get_event(&aid) { if let Ok(aev) = store.get_event(&aid) {
if aev.is_type_and_key(EventType::RoomPowerLevels, "") { if aev.is_type_and_key(EventType::RoomPowerLevels, "") {
pl = Some(aev); pl = Some(aev);
@ -442,7 +482,7 @@ impl StateResolution {
} }
if pl.is_none() { if pl.is_none() {
for aid in store.get_event(event_id).unwrap().auth_event_ids() { for aid in store.get_event(event_id).unwrap().auth_events() {
if let Ok(aev) = store.get_event(&aid) { if let Ok(aev) = store.get_event(&aid) {
if aev.is_type_and_key(EventType::RoomCreate, "") { if aev.is_type_and_key(EventType::RoomCreate, "") {
if let Ok(content) = aev if let Ok(content) = aev
@ -487,16 +527,25 @@ impl StateResolution {
unconflicted_state: &StateMap<EventId>, unconflicted_state: &StateMap<EventId>,
_event_map: &EventMap<StateEvent>, // TODO use event_map over store ?? _event_map: &EventMap<StateEvent>, // TODO use event_map over store ??
store: &dyn StateStore, store: &dyn StateStore,
) -> Result<StateMap<EventId>, String> { ) -> Result<StateMap<EventId>> {
tracing::info!("starting iterative auth check"); tracing::info!("starting iterative auth check");
tracing::debug!(
"{:?}",
power_events
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
);
let mut resolved_state = unconflicted_state.clone(); let mut resolved_state = unconflicted_state.clone();
for (idx, event_id) in power_events.iter().enumerate() { for (idx, event_id) in power_events.iter().enumerate() {
tracing::warn!("POWER EVENTS {}", event_id.as_str());
let event = store.get_event(event_id).unwrap(); let event = store.get_event(event_id).unwrap();
let mut auth_events = BTreeMap::new(); let mut auth_events = BTreeMap::new();
for aid in event.auth_event_ids() { for aid in event.auth_events() {
if let Ok(ev) = store.get_event(&aid) { if let Ok(ev) = store.get_event(&aid) {
// TODO what to do when no state_key is found ?? // TODO what to do when no state_key is found ??
// TODO check "rejected_reason", I'm guessing this is redacted_because for ruma ?? // TODO check "rejected_reason", I'm guessing this is redacted_because for ruma ??
@ -508,9 +557,8 @@ impl StateResolution {
for key in event_auth::auth_types_for_event(&event) { for key in event_auth::auth_types_for_event(&event) {
if let Some(ev_id) = resolved_state.get(&key) { if let Some(ev_id) = resolved_state.get(&key) {
// TODO synapse gets the event from the store then checks its not None
// then pulls the same `ev_id` event from the event_map??
if let Ok(event) = store.get_event(ev_id) { if let Ok(event) = store.get_event(ev_id) {
// TODO synapse checks `rejected_reason` is None here
auth_events.insert(key.clone(), event); auth_events.insert(key.clone(), event);
} }
} }
@ -518,7 +566,8 @@ impl StateResolution {
tracing::debug!("event to check {:?}", event.event_id().unwrap().to_string()); tracing::debug!("event to check {:?}", event.event_id().unwrap().to_string());
if event_auth::auth_check(room_version, &event, auth_events, false) if event_auth::auth_check(room_version, &event, auth_events, false)
.ok_or("Auth check failed due to deserialization most likely".to_string())? .ok_or("Auth check failed due to deserialization most likely".to_string())
.map_err(Error::TempString)?
{ {
// add event to resolved state map // add event to resolved state map
resolved_state.insert((event.kind(), event.state_key().unwrap()), event_id.clone()); resolved_state.insert((event.kind(), event.state_key().unwrap()), event_id.clone());
@ -567,7 +616,7 @@ impl StateResolution {
// We don't need the actual pl_ev here since we delegate to the store // We don't need the actual pl_ev here since we delegate to the store
let event = store.get_event(&p).unwrap(); let event = store.get_event(&p).unwrap();
let auth_events = event.auth_event_ids(); let auth_events = event.auth_events();
pl = None; pl = None;
for aid in auth_events { for aid in auth_events {
let ev = store.get_event(&aid).unwrap(); let ev = store.get_event(&aid).unwrap();
@ -635,7 +684,7 @@ impl StateResolution {
} }
} }
let auth_events = sort_ev.auth_event_ids(); let auth_events = sort_ev.auth_events();
event = None; event = None;
for aid in auth_events { for aid in auth_events {
let aev = store.get_event(&aid).unwrap(); let aev = store.get_event(&aid).unwrap();
@ -664,7 +713,7 @@ impl StateResolution {
graph.entry(eid.clone()).or_insert(vec![]); graph.entry(eid.clone()).or_insert(vec![]);
// prefer the store to event as the store filters dedups the events // prefer the store to event as the store filters dedups the events
// otherwise it seems we can loop forever // otherwise it seems we can loop forever
for aid in store.get_event(&eid).unwrap().auth_event_ids() { for aid in store.get_event(&eid).unwrap().auth_events() {
if auth_diff.contains(&aid) { if auth_diff.contains(&aid) {
if !graph.contains_key(&aid) { if !graph.contains_key(&aid) {
state.push(aid.clone()); state.push(aid.clone());

View File

@ -173,29 +173,29 @@ impl StateEvent {
pub fn prev_event_ids(&self) -> Vec<EventId> { pub fn prev_event_ids(&self) -> Vec<EventId> {
match self { match self {
Self::Full(ev) => match ev { Self::Full(ev) => match ev {
Pdu::RoomV1Pdu(ev) => ev.prev_events.iter().cloned().collect(), Pdu::RoomV1Pdu(ev) => ev.prev_events.to_vec(),
Pdu::RoomV3Pdu(ev) => ev.prev_events.clone(), Pdu::RoomV3Pdu(ev) => ev.prev_events.clone(),
}, },
Self::Sync(ev) => match ev { Self::Sync(ev) => match ev {
PduStub::RoomV1PduStub(ev) => { PduStub::RoomV1PduStub(ev) => {
ev.prev_events.iter().map(|(id, _)| id).cloned().collect() ev.prev_events.iter().map(|(id, _)| id).cloned().collect()
} }
PduStub::RoomV3PduStub(ev) => ev.prev_events.clone(), PduStub::RoomV3PduStub(ev) => ev.prev_events.to_vec(),
}, },
} }
} }
pub fn auth_event_ids(&self) -> Vec<EventId> { pub fn auth_events(&self) -> Vec<EventId> {
match self { match self {
Self::Full(ev) => match ev { Self::Full(ev) => match ev {
Pdu::RoomV1Pdu(ev) => ev.auth_events.iter().cloned().collect(), Pdu::RoomV1Pdu(ev) => ev.auth_events.to_vec(),
Pdu::RoomV3Pdu(ev) => ev.auth_events.clone(), Pdu::RoomV3Pdu(ev) => ev.auth_events.to_vec(),
}, },
Self::Sync(ev) => match ev { Self::Sync(ev) => match ev {
PduStub::RoomV1PduStub(ev) => { PduStub::RoomV1PduStub(ev) => {
ev.auth_events.iter().map(|(id, _)| id).cloned().collect() ev.auth_events.iter().map(|(id, _)| id).cloned().collect()
} }
PduStub::RoomV3PduStub(ev) => ev.auth_events.clone(), PduStub::RoomV3PduStub(ev) => ev.auth_events.to_vec(),
}, },
} }
} }

741
tests/auth_ids.rs Normal file
View File

@ -0,0 +1,741 @@
#![allow(clippy::or_fun_call, clippy::expect_fun_call)]
use std::{
cell::RefCell,
collections::{BTreeMap, BTreeSet},
convert::TryFrom,
sync::Once,
time::UNIX_EPOCH,
};
use ruma::{
events::{
room::{
join_rules::JoinRule,
member::{MemberEventContent, MembershipState},
},
EventType,
},
identifiers::{EventId, RoomId, RoomVersionId, UserId},
};
use serde_json::{json, Value as JsonValue};
use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore};
use tracing_subscriber as tracer;
static LOGGER: Once = Once::new();
static mut SERVER_TIMESTAMP: i32 = 0;
fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids: Vec<EventId>) {
// to activate logging use `RUST_LOG=debug cargo t one_test_only`
let _ = LOGGER.call_once(|| {
tracer::fmt()
.with_env_filter(tracer::EnvFilter::from_default_env())
.init()
});
let mut resolver = StateResolution::default();
let store = TestStore(RefCell::new(
INITIAL_EVENTS()
.values()
.chain(events)
.map(|ev| (ev.event_id().unwrap().clone(), ev.clone()))
.collect(),
));
// This will be lexi_topo_sorted for resolution
let mut graph = BTreeMap::new();
// this is the same as in `resolve` event_id -> StateEvent
let mut fake_event_map = BTreeMap::new();
// create the DB of events that led up to this point
// TODO maybe clean up some of these clones it is just tests but...
for ev in INITIAL_EVENTS().values().chain(events) {
graph.insert(ev.event_id().unwrap().clone(), vec![]);
fake_event_map.insert(ev.event_id().unwrap().clone(), ev.clone());
}
for pair in INITIAL_EDGES().windows(2) {
if let [a, b] = &pair {
graph.entry(a.clone()).or_insert(vec![]).push(b.clone());
}
}
for edge_list in edges {
for pair in edge_list.windows(2) {
if let [a, b] = &pair {
graph.entry(a.clone()).or_insert(vec![]).push(b.clone());
}
}
}
// event_id -> StateEvent
let mut event_map: BTreeMap<EventId, StateEvent> = BTreeMap::new();
// event_id -> StateMap<EventId>
let mut state_at_event: BTreeMap<EventId, StateMap<EventId>> = BTreeMap::new();
// resolve the current state and add it to the state_at_event map then continue
// on in "time"
for node in
resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone())))
{
let fake_event = fake_event_map.get(&node).unwrap();
let event_id = fake_event.event_id().unwrap();
let prev_events = graph.get(&node).unwrap();
let state_before: StateMap<EventId> = if prev_events.is_empty() {
BTreeMap::new()
} else if prev_events.len() == 1 {
state_at_event.get(&prev_events[0]).unwrap().clone()
} else {
let state_sets = prev_events
.iter()
.filter_map(|k| state_at_event.get(k))
.cloned()
.collect::<Vec<_>>();
tracing::info!(
"{:#?}",
state_sets
.iter()
.map(|map| map
.iter()
.map(|((ty, key), id)| format!("(({}{}), {})", ty, key, id))
.collect::<Vec<_>>())
.collect::<Vec<_>>()
);
let resolved = resolver.resolve(
&room_id(),
&RoomVersionId::version_1(),
&state_sets,
Some(event_map.clone()),
&store,
);
match resolved {
Ok(ResolutionResult::Resolved(state)) => state,
Ok(ResolutionResult::Conflicted(state)) => panic!(
"conflicted: {:?}",
state
.iter()
.map(|map| map
.iter()
.map(|(key, id)| (key, id.to_string()))
.collect::<Vec<_>>())
.collect::<Vec<_>>()
),
Err(e) => panic!("resolution for {} failed: {}", node, e),
}
};
let mut state_after = state_before.clone();
if fake_event.state_key().is_some() {
let ty = fake_event.kind().clone();
// we know there is a state_key unwrap OK
let key = fake_event.state_key().unwrap().clone();
state_after.insert((ty, key), event_id.clone());
}
let auth_types = state_res::auth_types_for_event(fake_event);
let mut auth_events = vec![];
for key in auth_types {
if state_before.contains_key(&key) {
auth_events.push(state_before[&key].clone())
}
}
// TODO The event is just remade, adding the auth_events and prev_events here
// UPDATE: the `to_pdu_event` was split into `init` and the fn below, could be better
let e = fake_event;
let ev_id = e.event_id().unwrap();
let event = to_pdu_event(
&e.event_id().unwrap().to_string(),
e.sender().clone(),
e.kind(),
e.state_key().as_deref(),
e.content().clone(),
&auth_events,
prev_events,
);
// we have to update our store, an actual user of this lib would
// be giving us state from a DB.
//
// TODO
// TODO we need to convert the `StateResolution::resolve` to use the event_map
// because the user of this crate cannot update their DB's state.
*store.0.borrow_mut().get_mut(ev_id).unwrap() = event.clone();
state_at_event.insert(node, state_after);
event_map.insert(event_id.clone(), event);
}
let mut expected_state = BTreeMap::new();
for node in expected_state_ids {
let ev = event_map.get(&node).expect(&format!(
"{} not found in {:?}",
node.to_string(),
event_map
.keys()
.map(ToString::to_string)
.collect::<Vec<_>>(),
));
let key = (ev.kind(), ev.state_key().unwrap());
expected_state.insert(key, node);
}
let start_state = state_at_event.get(&event_id("$START:foo")).unwrap();
let end_state = state_at_event
.get(&event_id("$END:foo"))
.unwrap()
.iter()
.filter(|(k, v)| expected_state.contains_key(k) || start_state.get(k) != Some(*v))
.map(|(k, v)| (k.clone(), v.clone()))
.collect::<StateMap<EventId>>();
assert_eq!(expected_state, end_state);
}
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
#[allow(unused)]
impl StateStore for TestStore {
fn get_events(&self, events: &[EventId]) -> Result<Vec<StateEvent>, String> {
Ok(self
.0
.borrow()
.iter()
.filter(|e| events.contains(e.0))
.map(|(_, s)| s)
.cloned()
.collect())
}
fn get_event(&self, event_id: &EventId) -> Result<StateEvent, String> {
self.0
.borrow()
.get(event_id)
.cloned()
.ok_or(format!("{} not found", event_id.to_string()))
}
fn auth_event_ids(
&self,
room_id: &RoomId,
event_ids: &[EventId],
) -> Result<Vec<EventId>, String> {
let mut result = vec![];
let mut stack = event_ids.to_vec();
// DFS for auth event chain
while !stack.is_empty() {
let ev_id = stack.pop().unwrap();
if result.contains(&ev_id) {
continue;
}
result.push(ev_id.clone());
let event = self.get_event(&ev_id).unwrap();
stack.extend(event.auth_events());
}
Ok(result)
}
fn auth_chain_diff(
&self,
room_id: &RoomId,
event_ids: Vec<Vec<EventId>>,
) -> Result<Vec<EventId>, String> {
use itertools::Itertools;
let mut chains = vec![];
for ids in event_ids {
// TODO state store `auth_event_ids` returns self in the event ids list
// when an event returns `auth_event_ids` self is not contained
let chain = self
.auth_event_ids(room_id, &ids)?
.into_iter()
.collect::<BTreeSet<_>>();
chains.push(chain);
}
if let Some(chain) = chains.first() {
let rest = chains.iter().skip(1).flatten().cloned().collect();
let common = chain.intersection(&rest).collect::<Vec<_>>();
Ok(chains
.iter()
.flatten()
.filter(|id| !common.contains(&id))
.cloned()
.collect::<BTreeSet<_>>()
.into_iter()
.collect())
} else {
Ok(vec![])
}
}
}
fn event_id(id: &str) -> EventId {
if id.contains('$') {
return EventId::try_from(id).unwrap();
}
EventId::try_from(format!("${}:foo", id)).unwrap()
}
fn alice() -> UserId {
UserId::try_from("@alice:foo").unwrap()
}
fn bob() -> UserId {
UserId::try_from("@bob:foo").unwrap()
}
fn charlie() -> UserId {
UserId::try_from("@charlie:foo").unwrap()
}
fn ella() -> UserId {
UserId::try_from("@ella:foo").unwrap()
}
fn zara() -> UserId {
UserId::try_from("@zara:foo").unwrap()
}
fn room_id() -> RoomId {
RoomId::try_from("!test:foo").unwrap()
}
fn member_content_ban() -> JsonValue {
serde_json::to_value(MemberEventContent {
membership: MembershipState::Ban,
displayname: None,
avatar_url: None,
is_direct: None,
third_party_invite: None,
})
.unwrap()
}
fn member_content_join() -> JsonValue {
serde_json::to_value(MemberEventContent {
membership: MembershipState::Join,
displayname: None,
avatar_url: None,
is_direct: None,
third_party_invite: None,
})
.unwrap()
}
fn to_pdu_event<S>(
id: &str,
sender: UserId,
ev_type: EventType,
state_key: Option<&str>,
content: JsonValue,
auth_events: &[S],
prev_events: &[S],
) -> StateEvent
where
S: AsRef<str>,
{
let ts = unsafe {
let ts = SERVER_TIMESTAMP;
// increment the "origin_server_ts" value
SERVER_TIMESTAMP += 1;
ts
};
let id = if id.contains('$') {
id.to_string()
} else {
format!("${}:foo", id)
};
let auth_events = auth_events
.iter()
.map(AsRef::as_ref)
.map(event_id)
.collect::<Vec<_>>();
let prev_events = prev_events
.iter()
.map(AsRef::as_ref)
.map(event_id)
.collect::<Vec<_>>();
let json = if let Some(state_key) = state_key {
json!({
"auth_events": auth_events,
"prev_events": prev_events,
"event_id": id,
"sender": sender,
"type": ev_type,
"state_key": state_key,
"content": content,
"origin_server_ts": ts,
"room_id": room_id(),
"origin": "foo",
"depth": 0,
"hashes": { "sha256": "hello" },
"signatures": {},
})
} else {
json!({
"auth_events": auth_events,
"prev_events": prev_events,
"event_id": id,
"sender": sender,
"type": ev_type,
"content": content,
"origin_server_ts": ts,
"room_id": room_id(),
"origin": "foo",
"depth": 0,
"hashes": { "sha256": "hello" },
"signatures": {},
})
};
serde_json::from_value(json).unwrap()
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
// this is always called so we can init the logger here
let _ = LOGGER.call_once(|| {
tracer::fmt()
.with_env_filter(tracer::EnvFilter::from_default_env())
.init()
});
vec![
to_pdu_event::<EventId>(
"CREATE",
alice(),
EventType::RoomCreate,
Some(""),
json!({ "creator": alice() }),
&[],
&[],
),
to_pdu_event(
"IMA",
alice(),
EventType::RoomMember,
Some(alice().to_string().as_str()),
member_content_join(),
&["CREATE"],
&["CREATE"],
),
to_pdu_event(
"IPOWER",
alice(),
EventType::RoomPowerLevels,
Some(""),
json!({"users": {alice().to_string(): 100}}),
&["CREATE", "IMA"],
&["IMA"],
),
to_pdu_event(
"IJR",
alice(),
EventType::RoomJoinRules,
Some(""),
json!({ "join_rule": JoinRule::Public }),
&["CREATE", "IMA", "IPOWER"],
&["IPOWER"],
),
to_pdu_event(
"IMB",
bob(),
EventType::RoomMember,
Some(bob().to_string().as_str()),
member_content_join(),
&["CREATE", "IJR", "IPOWER"],
&["IJR"],
),
to_pdu_event(
"IMC",
charlie(),
EventType::RoomMember,
Some(charlie().to_string().as_str()),
member_content_join(),
&["CREATE", "IJR", "IPOWER"],
&["IMB"],
),
to_pdu_event::<EventId>(
"START",
charlie(),
EventType::RoomMessage,
None,
json!({}),
&[],
&[],
),
to_pdu_event::<EventId>(
"END",
charlie(),
EventType::RoomMessage,
None,
json!({}),
&[],
&[],
),
]
.into_iter()
.map(|ev| (ev.event_id().unwrap().clone(), ev))
.collect()
}
#[allow(non_snake_case)]
fn INITIAL_EDGES() -> Vec<EventId> {
vec!["START", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"]
.into_iter()
.map(event_id)
.collect::<Vec<_>>()
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn BAN_STATE_SET() -> BTreeMap<EventId, StateEvent> {
vec![
to_pdu_event(
"PA",
alice(),
EventType::RoomPowerLevels,
Some(""),
json!({"users": {alice(): 100, bob(): 50}}),
&["CREATE", "IMA", "IPOWER"], // auth_events
&["START"], // prev_events
),
to_pdu_event(
"PB",
alice(),
EventType::RoomPowerLevels,
Some(""),
json!({"users": {alice(): 100, bob(): 50}}),
&["CREATE", "IMA", "IPOWER"],
&["END"],
),
to_pdu_event(
"MB",
alice(),
EventType::RoomMember,
Some(ella().as_str()),
member_content_ban(),
&["CREATE", "IMA", "PB"],
&["PA"],
),
to_pdu_event(
"IME",
ella(),
EventType::RoomMember,
Some(ella().as_str()),
member_content_join(),
&["CREATE", "IJR", "PA"],
&["MB"],
),
]
.into_iter()
.map(|ev| (ev.event_id().unwrap().clone(), ev))
.collect()
}
#[test]
fn ban_with_auth_chains() {
let ban = BAN_STATE_SET();
let edges = vec![vec!["END", "MB", "PA", "START"], vec!["END", "IME", "MB"]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids = vec!["PA", "MB"]
.into_iter()
.map(event_id)
.collect::<Vec<_>>();
do_check(
&ban.values().cloned().collect::<Vec<_>>(),
edges,
expected_state_ids,
);
}
#[test]
fn base_with_auth_chains() {
let mut resolver = StateResolution::default();
let store = TestStore(RefCell::new(INITIAL_EVENTS()));
let resolved: BTreeMap<_, EventId> =
match resolver.resolve(&room_id(), &RoomVersionId::version_2(), &[], None, &store) {
Ok(ResolutionResult::Resolved(state)) => state,
Err(e) => panic!("{}", e),
_ => panic!("conflicted state left"),
};
let resolved = resolved
.values()
.cloned()
.chain(
INITIAL_EVENTS()
.values()
.map(|e| e.event_id().unwrap().clone()),
)
.collect::<Vec<_>>();
let expected = vec![
"$CREATE:foo",
"$IJR:foo",
"$IPOWER:foo",
"$IMA:foo",
"$IMB:foo",
"$IMC:foo",
"START",
"END",
];
for id in expected.iter().map(|i| event_id(i)) {
// make sure our resolved events are equall to the expected list
assert!(resolved.iter().any(|eid| eid == &id), "{}", id)
}
assert_eq!(expected.len(), resolved.len())
}
#[test]
fn ban_with_auth_chains2() {
let mut resolver = StateResolution::default();
let init = INITIAL_EVENTS();
let ban = BAN_STATE_SET();
let mut inner = init.clone();
inner.extend(ban);
let store = TestStore(RefCell::new(inner.clone()));
let state_set_a = [
inner.get(&event_id("CREATE")).unwrap(),
inner.get(&event_id("IJR")).unwrap(),
inner.get(&event_id("IMA")).unwrap(),
inner.get(&event_id("IMB")).unwrap(),
inner.get(&event_id("IMC")).unwrap(),
inner.get(&event_id("MB")).unwrap(),
inner.get(&event_id("PA")).unwrap(),
]
.iter()
.map(|ev| {
(
(ev.kind(), ev.state_key().unwrap()),
ev.event_id().unwrap().clone(),
)
})
.collect::<BTreeMap<_, _>>();
let state_set_b = [
inner.get(&event_id("CREATE")).unwrap(),
inner.get(&event_id("IJR")).unwrap(),
inner.get(&event_id("IMA")).unwrap(),
inner.get(&event_id("IMB")).unwrap(),
inner.get(&event_id("IMC")).unwrap(),
inner.get(&event_id("IME")).unwrap(),
inner.get(&event_id("PA")).unwrap(),
]
.iter()
.map(|ev| {
(
(ev.kind(), ev.state_key().unwrap()),
ev.event_id().unwrap().clone(),
)
})
.collect::<BTreeMap<_, _>>();
let resolved: BTreeMap<_, EventId> = match resolver.resolve(
&room_id(),
&RoomVersionId::version_2(),
&[state_set_a, state_set_b],
None,
&store,
) {
Ok(ResolutionResult::Resolved(state)) => state,
Err(e) => panic!("{}", e),
_ => panic!("conflicted state left"),
};
tracing::debug!(
"{:#?}",
resolved
.iter()
.map(|((ty, key), id)| format!("(({}{}), {})", ty, key, id))
.collect::<Vec<_>>()
);
let expected = vec![
"$CREATE:foo",
"$IJR:foo",
"$PA:foo",
"$IMA:foo",
"$IMB:foo",
"$IMC:foo",
"$MB:foo",
];
for id in expected.iter().map(|i| event_id(i)) {
// make sure our resolved events are equall to the expected list
assert!(
resolved.values().any(|eid| eid == &id) || init.contains_key(&id),
"{}",
id
)
}
assert_eq!(expected.len(), resolved.len())
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn JOIN_RULE() -> BTreeMap<EventId, StateEvent> {
vec![
to_pdu_event(
"JR",
alice(),
EventType::RoomJoinRules,
Some(""),
json!({"join_rule": "invite"}),
&["CREATE", "IMA", "IPOWER"],
&["START"],
),
to_pdu_event(
"IMZ",
zara(),
EventType::RoomPowerLevels,
Some(zara().as_str()),
member_content_join(),
&["CREATE", "JR", "IPOWER"],
&["START"],
),
]
.into_iter()
.map(|ev| (ev.event_id().unwrap().clone(), ev))
.collect()
}
#[test]
fn join_rule_with_auth_chain() {
let join_rule = JOIN_RULE();
let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(
&join_rule.values().cloned().collect::<Vec<_>>(),
edges,
expected_state_ids,
);
}

View File

@ -1,3 +1,5 @@
#![allow(clippy::or_fun_call, clippy::expect_fun_call)]
use std::{ use std::{
cell::RefCell, cell::RefCell,
collections::{BTreeMap, BTreeSet}, collections::{BTreeMap, BTreeSet},
@ -27,7 +29,7 @@ static LOGGER: Once = Once::new();
static mut SERVER_TIMESTAMP: i32 = 0; static mut SERVER_TIMESTAMP: i32 = 0;
fn event_id(id: &str) -> EventId { fn event_id(id: &str) -> EventId {
if id.contains("$") { if id.contains('$') {
return EventId::try_from(id).unwrap(); return EventId::try_from(id).unwrap();
} }
EventId::try_from(format!("${}:foo", id)).unwrap() EventId::try_from(format!("${}:foo", id)).unwrap()
@ -92,7 +94,7 @@ where
SERVER_TIMESTAMP += 1; SERVER_TIMESTAMP += 1;
ts ts
}; };
let id = if id.contains("$") { let id = if id.contains('$') {
id.to_string() id.to_string()
} else { } else {
format!("${}:foo", id) format!("${}:foo", id)
@ -100,33 +102,13 @@ where
let auth_events = auth_events let auth_events = auth_events
.iter() .iter()
.map(AsRef::as_ref) .map(AsRef::as_ref)
.map(|s| { .map(event_id)
EventId::try_from( .collect::<Vec<_>>();
if s.contains("$") {
s.to_owned()
} else {
format!("${}:foo", s)
}
.as_str(),
)
})
.collect::<Result<Vec<_>, _>>()
.unwrap();
let prev_events = prev_events let prev_events = prev_events
.iter() .iter()
.map(AsRef::as_ref) .map(AsRef::as_ref)
.map(|s| { .map(event_id)
EventId::try_from( .collect::<Vec<_>>();
if s.contains("$") {
s.to_owned()
} else {
format!("${}:foo", s)
}
.as_str(),
)
})
.collect::<Result<Vec<_>, _>>()
.unwrap();
let json = if let Some(state_key) = state_key { let json = if let Some(state_key) = state_key {
json!({ json!({
@ -176,7 +158,7 @@ fn to_init_pdu_event(
SERVER_TIMESTAMP += 1; SERVER_TIMESTAMP += 1;
ts ts
}; };
let id = if id.contains("$") { let id = if id.contains('$') {
id.to_string() id.to_string()
} else { } else {
format!("${}:foo", id) format!("${}:foo", id)
@ -319,14 +301,14 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
} }
for pair in INITIAL_EDGES().windows(2) { for pair in INITIAL_EDGES().windows(2) {
if let &[a, b] = &pair { if let [a, b] = &pair {
graph.entry(a.clone()).or_insert(vec![]).push(b.clone()); graph.entry(a.clone()).or_insert(vec![]).push(b.clone());
} }
} }
for edge_list in edges { for edge_list in edges {
for pair in edge_list.windows(2) { for pair in edge_list.windows(2) {
if let &[a, b] = &pair { if let [a, b] = &pair {
graph.entry(a.clone()).or_insert(vec![]).push(b.clone()); graph.entry(a.clone()).or_insert(vec![]).push(b.clone());
} }
} }
@ -338,10 +320,9 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
let mut state_at_event: BTreeMap<EventId, StateMap<EventId>> = BTreeMap::new(); let mut state_at_event: BTreeMap<EventId, StateMap<EventId>> = BTreeMap::new();
// resolve the current state and add it to the state_at_event map then continue // resolve the current state and add it to the state_at_event map then continue
// on in "time"? // on in "time"
for node in resolver for node in
// TODO is this `key_fn` return correct ?? resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone())))
.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone())))
{ {
let fake_event = fake_event_map.get(&node).unwrap(); let fake_event = fake_event_map.get(&node).unwrap();
let event_id = fake_event.event_id().unwrap(); let event_id = fake_event.event_id().unwrap();
@ -359,6 +340,17 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
tracing::warn!(
"{:#?}",
state_sets
.iter()
.map(|map| map
.iter()
.map(|((ty, key), id)| format!("(({}{}), {})", ty, key, id))
.collect::<Vec<_>>())
.collect::<Vec<_>>()
);
let resolved = resolver.resolve( let resolved = resolver.resolve(
&room_id(), &room_id(),
&RoomVersionId::version_1(), &RoomVersionId::version_1(),
@ -791,7 +783,8 @@ impl StateStore for TestStore {
result.push(ev_id.clone()); result.push(ev_id.clone());
let event = self.get_event(&ev_id).unwrap(); let event = self.get_event(&ev_id).unwrap();
stack.extend(event.auth_event_ids());
stack.extend(event.auth_events());
} }
Ok(result) Ok(result)
@ -902,7 +895,7 @@ impl TestStore {
EventType::RoomMember, EventType::RoomMember,
Some(charlie().to_string().as_str()), Some(charlie().to_string().as_str()),
member_content_join(), member_content_join(),
&[cre.clone(), join_rules.event_id().unwrap().clone()], &[cre, join_rules.event_id().unwrap().clone()],
&[join_rules.event_id().unwrap().clone()], &[join_rules.event_id().unwrap().clone()],
); );
self.0 self.0
@ -913,7 +906,7 @@ impl TestStore {
.iter() .iter()
.map(|e| { .map(|e| {
( (
(e.kind(), e.state_key().unwrap().clone()), (e.kind(), e.state_key().unwrap()),
e.event_id().unwrap().clone(), e.event_id().unwrap().clone(),
) )
}) })
@ -923,7 +916,7 @@ impl TestStore {
.iter() .iter()
.map(|e| { .map(|e| {
( (
(e.kind(), e.state_key().unwrap().clone()), (e.kind(), e.state_key().unwrap()),
e.event_id().unwrap().clone(), e.event_id().unwrap().clone(),
) )
}) })
@ -939,7 +932,7 @@ impl TestStore {
.iter() .iter()
.map(|e| { .map(|e| {
( (
(e.kind(), e.state_key().unwrap().clone()), (e.kind(), e.state_key().unwrap()),
e.event_id().unwrap().clone(), e.event_id().unwrap().clone(),
) )
}) })