diff --git a/Cargo.toml b/Cargo.toml index 8da3bb0b..53d0c68e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ serde = { version = "1.0.114", features = ["derive"] } serde_json = "1.0.56" tracing = "0.1.16" maplit = "1.0.2" +thiserror = "1.0.20" tracing-subscriber = "0.2.8" [dependencies.ruma] diff --git a/README.md b/README.md index b1365819..f751df51 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ -Would it be possible to abstract state res into a `ruma-state-res` crate? I've been thinking about something along the lines of +### Matrix state resolution in rust! + ```rust /// StateMap is just a wrapper/deserialize target for a PDU. struct StateEvent { @@ -41,3 +42,9 @@ trait StateStore { } ``` + + + +The `StateStore` trait is an abstraction around what ever database your server (or maybe even client) uses to store __P__[]()ersistant __D__[]()ata __U__[]()nits. + +We use `ruma`s types when deserializing any PDU or it's contents which helps avoid a lot of type checking logic [synapse](https://github.com/matrix-org/synapse) must do while authenticating event chains. \ No newline at end of file diff --git a/benches/state_bench.rs b/benches/state_bench.rs index df152bfa..c4fa0e40 100644 --- a/benches/state_bench.rs +++ b/benches/state_bench.rs @@ -70,10 +70,83 @@ fn resolution_shallow_auth_chain(c: &mut Criterion) { }); } -criterion_group!(benches, lexico_topo_sort, resolution_shallow_auth_chain); +fn resolve_deeper_event_set(c: &mut Criterion) { + c.bench_function("resolve state of 10 events 3 conflicting", |b| { + let mut resolver = StateResolution::default(); + + let init = INITIAL_EVENTS(); + let ban = BAN_STATE_SET(); + + let mut inner = init; + inner.extend(ban); + let store = TestStore(RefCell::new(inner.clone())); + + let state_set_a = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("MB")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| { + ( + (ev.kind(), ev.state_key().unwrap()), + ev.event_id().unwrap().clone(), + ) + }) + .collect::>(); + + let state_set_b = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("IME")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| { + ( + (ev.kind(), ev.state_key().unwrap()), + ev.event_id().unwrap().clone(), + ) + }) + .collect::>(); + + b.iter(|| { + let _resolved = match resolver.resolve( + &room_id(), + &RoomVersionId::version_2(), + &[state_set_a.clone(), state_set_b.clone()], + Some(inner.clone()), + &store, + ) { + Ok(ResolutionResult::Resolved(state)) => state, + Err(_) => panic!("resolution failed during benchmarking"), + _ => panic!("resolution failed during benchmarking"), + }; + }) + }); +} + +criterion_group!( + benches, + lexico_topo_sort, + resolution_shallow_auth_chain, + resolve_deeper_event_set +); criterion_main!(benches); +//*///////////////////////////////////////////////////////////////////// +// +// IMPLEMENTATION DETAILS AHEAD +// +/////////////////////////////////////////////////////////////////////*/ pub struct TestStore(RefCell>); #[allow(unused)] @@ -115,7 +188,7 @@ impl StateStore for TestStore { result.push(ev_id.clone()); let event = self.get_event(&ev_id).unwrap(); - stack.extend(event.auth_event_ids()); + stack.extend(event.auth_events()); } Ok(result) @@ -220,7 +293,7 @@ impl TestStore { EventType::RoomMember, Some(charlie().to_string().as_str()), member_content_join(), - &[cre.clone(), join_rules.event_id().unwrap().clone()], + &[cre, join_rules.event_id().unwrap().clone()], &[join_rules.event_id().unwrap().clone()], ); self.0 @@ -231,7 +304,7 @@ impl TestStore { .iter() .map(|e| { ( - (e.kind(), e.state_key().unwrap().clone()), + (e.kind(), e.state_key().unwrap()), e.event_id().unwrap().clone(), ) }) @@ -241,7 +314,7 @@ impl TestStore { .iter() .map(|e| { ( - (e.kind(), e.state_key().unwrap().clone()), + (e.kind(), e.state_key().unwrap()), e.event_id().unwrap().clone(), ) }) @@ -257,7 +330,7 @@ impl TestStore { .iter() .map(|e| { ( - (e.kind(), e.state_key().unwrap().clone()), + (e.kind(), e.state_key().unwrap()), e.event_id().unwrap().clone(), ) }) @@ -268,7 +341,7 @@ impl TestStore { } fn event_id(id: &str) -> EventId { - if id.contains("$") { + if id.contains('$') { return EventId::try_from(id).unwrap(); } EventId::try_from(format!("${}:foo", id)).unwrap() @@ -283,11 +356,25 @@ fn bob() -> UserId { fn charlie() -> UserId { UserId::try_from("@charlie:foo").unwrap() } +fn ella() -> UserId { + UserId::try_from("@ella:foo").unwrap() +} fn room_id() -> RoomId { RoomId::try_from("!test:foo").unwrap() } +fn member_content_ban() -> JsonValue { + serde_json::to_value(MemberEventContent { + membership: MembershipState::Ban, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + }) + .unwrap() +} + fn member_content_join() -> JsonValue { serde_json::to_value(MemberEventContent { membership: MembershipState::Join, @@ -317,7 +404,7 @@ where SERVER_TIMESTAMP += 1; ts }; - let id = if id.contains("$") { + let id = if id.contains('$') { id.to_string() } else { format!("${}:foo", id) @@ -325,33 +412,13 @@ where let auth_events = auth_events .iter() .map(AsRef::as_ref) - .map(|s| { - EventId::try_from( - if s.contains("$") { - s.to_owned() - } else { - format!("${}:foo", s) - } - .as_str(), - ) - }) - .collect::, _>>() - .unwrap(); + .map(event_id) + .collect::>(); let prev_events = prev_events .iter() .map(AsRef::as_ref) - .map(|s| { - EventId::try_from( - if s.contains("$") { - s.to_owned() - } else { - format!("${}:foo", s) - } - .as_str(), - ) - }) - .collect::, _>>() - .unwrap(); + .map(event_id) + .collect::>(); let json = if let Some(state_key) = state_key { json!({ @@ -387,3 +454,131 @@ where }; serde_json::from_value(json).unwrap() } + +// all graphs start with these input events +#[allow(non_snake_case)] +fn INITIAL_EVENTS() -> BTreeMap { + vec![ + to_pdu_event::( + "CREATE", + alice(), + EventType::RoomCreate, + Some(""), + json!({ "creator": alice() }), + &[], + &[], + ), + to_pdu_event( + "IMA", + alice(), + EventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + &["CREATE"], + &["CREATE"], + ), + to_pdu_event( + "IPOWER", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice().to_string(): 100}}), + &["CREATE", "IMA"], + &["IMA"], + ), + to_pdu_event( + "IJR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({ "join_rule": JoinRule::Public }), + &["CREATE", "IMA", "IPOWER"], + &["IPOWER"], + ), + to_pdu_event( + "IMB", + bob(), + EventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IJR"], + ), + to_pdu_event( + "IMC", + charlie(), + EventType::RoomMember, + Some(charlie().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IMB"], + ), + to_pdu_event::( + "START", + charlie(), + EventType::RoomMessage, + None, + json!({}), + &[], + &[], + ), + to_pdu_event::( + "END", + charlie(), + EventType::RoomMessage, + None, + json!({}), + &[], + &[], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .collect() +} + +// all graphs start with these input events +#[allow(non_snake_case)] +fn BAN_STATE_SET() -> BTreeMap { + vec![ + to_pdu_event( + "PA", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + &["CREATE", "IMA", "IPOWER"], // auth_events + &["START"], // prev_events + ), + to_pdu_event( + "PB", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + &["CREATE", "IMA", "IPOWER"], + &["END"], + ), + to_pdu_event( + "MB", + alice(), + EventType::RoomMember, + Some(ella().as_str()), + member_content_ban(), + &["CREATE", "IMA", "PB"], + &["PA"], + ), + to_pdu_event( + "IME", + ella(), + EventType::RoomMember, + Some(ella().as_str()), + member_content_join(), + &["CREATE", "IJR", "PA"], + &["MB"], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .collect() +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 00000000..79f91750 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,23 @@ +use std::num::ParseIntError; + +use serde_json::Error as JsonError; +use thiserror::Error; + +/// Result type for state resolution. +pub type Result = std::result::Result; + +/// Represents the various errors that arise when resolving state. +#[derive(Error, Debug)] +pub enum Error { + /// A deserialization error. + #[error(transparent)] + SerdeJson(#[from] JsonError), + + /// An error that occurs when converting from JSON numbers to rust. + #[error(transparent)] + IntParseError(#[from] ParseIntError), + + // TODO remove once the correct errors are used + #[error("an error occured {0}")] + TempString(String), +} diff --git a/src/event_auth.rs b/src/event_auth.rs index 976608e8..8fad3fdd 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -1,5 +1,6 @@ use std::convert::TryFrom; +use maplit::btreeset; use ruma::{ events::{ room::{self, join_rules::JoinRule, member::MembershipState}, @@ -89,7 +90,7 @@ pub fn auth_check( false }; - if !event.signatures().get(sender_domain).is_some() && !is_invite_via_3pid { + if event.signatures().get(sender_domain).is_none() && !is_invite_via_3pid { tracing::info!("event not signed by sender's server"); return Some(false); } @@ -107,6 +108,7 @@ pub fn auth_check( // domain of room_id must match domain of sender. if event.room_id().map(|id| id.server_name()) != Some(event.sender().server_name()) { + tracing::info!("creation events server does not match sender"); return Some(false); // creation events room id does not match senders } @@ -117,7 +119,8 @@ pub fn auth_check( .content() .get("room_version") .cloned() - .unwrap_or(serde_json::json!({})), + // synapse defaults to version 1 + .unwrap_or(serde_json::json!("1")), ) .is_err() { @@ -231,7 +234,7 @@ fn can_federate(auth_events: &StateMap) -> bool { let creation_event = auth_events.get(&(EventType::RoomCreate, "".into())); if let Some(ev) = creation_event { if let Some(fed) = ev.content().get("m.federate") { - fed.to_string() == "true" + fed == "true" } else { false } @@ -468,7 +471,7 @@ fn can_send_event(event: &StateEvent, auth_events: &StateMap) -> Opt } if let Some(sk) = event.state_key() { - if sk.starts_with("@") && sk != event.sender().to_string() { + if sk.starts_with('@') && sk != event.sender().as_str() { return Some(false); // permission required to post in this room } } @@ -484,7 +487,13 @@ fn check_power_levels( use itertools::Itertools; let key = (power_event.kind(), power_event.state_key().unwrap()); - let current_state = auth_events.get(&key)?; + + let current_state = if let Some(current_state) = auth_events.get(&key) { + current_state + } else { + // TODO synapse returns here, shouldn't this be an error ?? + return Some(true); + }; let user_content = power_event .deserialize_content::() @@ -493,25 +502,27 @@ fn check_power_levels( .deserialize_content::() .unwrap(); - tracing::info!("validation of power event finished"); // validation of users is done in Ruma, synapse for loops validating user_ids and integers here + tracing::info!("validation of power event finished"); let user_level = get_user_power_level(power_event.sender(), auth_events); - let mut user_levels_to_check = vec![]; + let mut user_levels_to_check = btreeset![]; let old_list = ¤t_content.users; let user_list = &user_content.users; for user in old_list.keys().chain(user_list.keys()).dedup() { let user: &UserId = user; - user_levels_to_check.push(user); + user_levels_to_check.insert(user); } - let mut event_levels_to_check = vec![]; + tracing::debug!("users to check {:?}", user_levels_to_check); + + let mut event_levels_to_check = btreeset![]; let old_list = ¤t_content.events; let new_list = &user_content.events; for ev_id in old_list.keys().chain(new_list.keys()).dedup() { let ev_id: &EventType = ev_id; - event_levels_to_check.push(ev_id); + event_levels_to_check.insert(ev_id); } tracing::debug!("events to check {:?}", event_levels_to_check); @@ -574,9 +585,43 @@ fn check_power_levels( } } + let levels = [ + "users_default", + "events_default", + "state_default", + "ban", + "redact", + "kick", + "invite", + ]; + let old_state = serde_json::to_value(old_state).unwrap(); + let new_state = serde_json::to_value(new_state).unwrap(); + for lvl_name in &levels { + if let Some((old_lvl, new_lvl)) = get_deserialize_levels(&old_state, &new_state, lvl_name) { + let old_level_too_big = old_lvl > user_level; + let new_level_too_big = new_lvl > user_level; + + if old_level_too_big || new_level_too_big { + tracing::info!("cannot add ops > than own"); + return Some(false); + } + } + } + Some(true) } +fn get_deserialize_levels( + old: &serde_json::Value, + new: &serde_json::Value, + name: &str, +) -> Option<(i64, i64)> { + Some(( + serde_json::from_value(old.get(name)?.clone()).ok()?, + serde_json::from_value(new.get(name)?.clone()).ok()?, + )) +} + /// Does the event redacting come from a user with enough power to redact the given event. fn check_redaction( room_version: &RoomVersionId, diff --git a/src/lib.rs b/src/lib.rs index 04f4cef7..e5f9da55 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#![allow(clippy::or_fun_call)] + use std::{ cmp::Reverse, collections::{BTreeMap, BTreeSet, BinaryHeap}, @@ -11,11 +13,13 @@ use ruma::{ }; use serde::{Deserialize, Serialize}; +mod error; mod event_auth; mod room_version; mod state_event; mod state_store; +pub use error::{Error, Result}; pub use event_auth::{auth_check, auth_types_for_event}; pub use state_event::StateEvent; pub use state_store::StateStore; @@ -66,7 +70,7 @@ impl StateResolution { event_map: Option>, store: &dyn StateStore, // TODO actual error handling (`thiserror`??) - ) -> Result { + ) -> Result { tracing::info!("State resolution starting"); let mut event_map = if let Some(ev_map) = event_map { @@ -76,7 +80,26 @@ impl StateResolution { }; // split non-conflicting and conflicting state let (clean, conflicting) = self.separate(&state_sets); + tracing::debug!( + "CLEAN {:#?}", + clean + .iter() + .map(|((ty, key), id)| format!("(({}{}), {})", ty, key, id)) + .collect::>() + ); + tracing::debug!( + "CONFLICT {:#?}", + conflicting + .iter() + .map(|((ty, key), ids)| format!( + "(({} `{}`), {:?})", + ty, + key, + ids.iter().map(ToString::to_string).collect::>() + )) + .collect::>() + ); tracing::info!("non conflicting {:?}", clean.len()); if conflicting.is_empty() { @@ -124,7 +147,7 @@ impl StateResolution { for event in event_map.values() { if event.room_id() != Some(room_id) { - return Err(format!( + return Err(Error::TempString(format!( "resolving event {} in room {}, when correct room is {}", event .event_id() @@ -132,7 +155,7 @@ impl StateResolution { .unwrap_or("`unknown`"), event.room_id().map(|id| id.as_str()).unwrap_or("`unknown`"), room_id.as_str() - )); + ))); } } @@ -153,7 +176,7 @@ impl StateResolution { let mut sorted_power_levels = self.reverse_topological_power_sort( room_id, &power_events, - &mut event_map, + &event_map, // TODO use event_map store, &all_conflicted, ); @@ -172,7 +195,7 @@ impl StateResolution { room_version, &sorted_power_levels, &clean, - &mut event_map, + &event_map, store, )?; @@ -224,7 +247,7 @@ impl StateResolution { room_version, &sorted_left_events, &resolved, - &mut event_map, + &event_map, store, )?; @@ -255,7 +278,8 @@ impl StateResolution { for key in state_sets .iter() .flat_map(|map| map.keys()) - .collect::>() + .dedup() + .collect::>() { let mut event_ids = state_sets .iter() @@ -263,6 +287,14 @@ impl StateResolution { .dedup() .collect::>(); + tracing::debug!( + "SEP {:?}", + event_ids + .iter() + .map(|i| i.map(ToString::to_string).unwrap_or("None".into())) + .collect::>() + ); + if event_ids.len() == 1 { if let Some(Some(id)) = event_ids.pop() { unconflicted_state.insert(key.clone(), id.clone()); @@ -270,6 +302,7 @@ impl StateResolution { panic!() } } else { + tracing::warn!("{:?}", key); conflicted_state.insert( key.clone(), event_ids.into_iter().flatten().cloned().collect::>(), @@ -287,19 +320,21 @@ impl StateResolution { state_sets: &[StateMap], _event_map: &EventMap, store: &dyn StateStore, - ) -> Result, String> { + ) -> Result> { use itertools::Itertools; tracing::debug!("calculating auth chain difference"); - store.auth_chain_diff( - room_id, - state_sets - .iter() - .map(|map| map.values().cloned().collect()) - .dedup() - .collect::>(), - ) + store + .auth_chain_diff( + room_id, + state_sets + .iter() + .map(|map| map.values().cloned().collect()) + .dedup() + .collect::>(), + ) + .map_err(Error::TempString) } pub fn reverse_topological_power_sort( @@ -338,15 +373,20 @@ impl StateResolution { } } - self.lexicographical_topological_sort(&mut graph, |event_id| { + self.lexicographical_topological_sort(&graph, |event_id| { // tracing::debug!("{:?}", event_map.get(event_id).unwrap().origin_server_ts()); let ev = event_map.get(event_id).unwrap(); let pl = event_to_pl.get(event_id).unwrap(); + tracing::warn!( + "{:?}", + (-*pl, *ev.origin_server_ts(), ev.event_id().cloned()) + ); + // This return value is the key used for sorting events, // events are then sorted by power level, time, // and lexically by event_id. - (-*pl, ev.origin_server_ts().clone(), ev.event_id().cloned()) + (-*pl, *ev.origin_server_ts(), ev.event_id().cloned()) }) } @@ -371,8 +411,8 @@ impl StateResolution { // TODO make the BTreeSet conversion cleaner ?? let mut outdegree_map: BTreeMap> = graph - .into_iter() - .map(|(k, v)| (k.clone(), v.into_iter().cloned().collect())) + .iter() + .map(|(k, v)| (k.clone(), v.iter().cloned().collect())) .collect(); let mut reverse_graph = BTreeMap::new(); @@ -432,7 +472,7 @@ impl StateResolution { let mut pl = None; // TODO store.auth_event_ids returns "self" with the event ids is this ok // event.auth_event_ids does not include its own event id ? - for aid in store.get_event(event_id).unwrap().auth_event_ids() { + for aid in store.get_event(event_id).unwrap().auth_events() { if let Ok(aev) = store.get_event(&aid) { if aev.is_type_and_key(EventType::RoomPowerLevels, "") { pl = Some(aev); @@ -442,7 +482,7 @@ impl StateResolution { } if pl.is_none() { - for aid in store.get_event(event_id).unwrap().auth_event_ids() { + for aid in store.get_event(event_id).unwrap().auth_events() { if let Ok(aev) = store.get_event(&aid) { if aev.is_type_and_key(EventType::RoomCreate, "") { if let Ok(content) = aev @@ -487,16 +527,25 @@ impl StateResolution { unconflicted_state: &StateMap, _event_map: &EventMap, // TODO use event_map over store ?? store: &dyn StateStore, - ) -> Result, String> { + ) -> Result> { tracing::info!("starting iterative auth check"); + tracing::debug!( + "{:?}", + power_events + .iter() + .map(ToString::to_string) + .collect::>() + ); + let mut resolved_state = unconflicted_state.clone(); for (idx, event_id) in power_events.iter().enumerate() { + tracing::warn!("POWER EVENTS {}", event_id.as_str()); let event = store.get_event(event_id).unwrap(); let mut auth_events = BTreeMap::new(); - for aid in event.auth_event_ids() { + for aid in event.auth_events() { if let Ok(ev) = store.get_event(&aid) { // TODO what to do when no state_key is found ?? // TODO check "rejected_reason", I'm guessing this is redacted_because for ruma ?? @@ -508,9 +557,8 @@ impl StateResolution { for key in event_auth::auth_types_for_event(&event) { if let Some(ev_id) = resolved_state.get(&key) { - // TODO synapse gets the event from the store then checks its not None - // then pulls the same `ev_id` event from the event_map?? if let Ok(event) = store.get_event(ev_id) { + // TODO synapse checks `rejected_reason` is None here auth_events.insert(key.clone(), event); } } @@ -518,7 +566,8 @@ impl StateResolution { tracing::debug!("event to check {:?}", event.event_id().unwrap().to_string()); if event_auth::auth_check(room_version, &event, auth_events, false) - .ok_or("Auth check failed due to deserialization most likely".to_string())? + .ok_or("Auth check failed due to deserialization most likely".to_string()) + .map_err(Error::TempString)? { // add event to resolved state map resolved_state.insert((event.kind(), event.state_key().unwrap()), event_id.clone()); @@ -567,7 +616,7 @@ impl StateResolution { // We don't need the actual pl_ev here since we delegate to the store let event = store.get_event(&p).unwrap(); - let auth_events = event.auth_event_ids(); + let auth_events = event.auth_events(); pl = None; for aid in auth_events { let ev = store.get_event(&aid).unwrap(); @@ -635,7 +684,7 @@ impl StateResolution { } } - let auth_events = sort_ev.auth_event_ids(); + let auth_events = sort_ev.auth_events(); event = None; for aid in auth_events { let aev = store.get_event(&aid).unwrap(); @@ -664,7 +713,7 @@ impl StateResolution { graph.entry(eid.clone()).or_insert(vec![]); // prefer the store to event as the store filters dedups the events // otherwise it seems we can loop forever - for aid in store.get_event(&eid).unwrap().auth_event_ids() { + for aid in store.get_event(&eid).unwrap().auth_events() { if auth_diff.contains(&aid) { if !graph.contains_key(&aid) { state.push(aid.clone()); diff --git a/src/state_event.rs b/src/state_event.rs index 897834f6..f2f8cd95 100644 --- a/src/state_event.rs +++ b/src/state_event.rs @@ -173,29 +173,29 @@ impl StateEvent { pub fn prev_event_ids(&self) -> Vec { match self { Self::Full(ev) => match ev { - Pdu::RoomV1Pdu(ev) => ev.prev_events.iter().cloned().collect(), + Pdu::RoomV1Pdu(ev) => ev.prev_events.to_vec(), Pdu::RoomV3Pdu(ev) => ev.prev_events.clone(), }, Self::Sync(ev) => match ev { PduStub::RoomV1PduStub(ev) => { ev.prev_events.iter().map(|(id, _)| id).cloned().collect() } - PduStub::RoomV3PduStub(ev) => ev.prev_events.clone(), + PduStub::RoomV3PduStub(ev) => ev.prev_events.to_vec(), }, } } - pub fn auth_event_ids(&self) -> Vec { + pub fn auth_events(&self) -> Vec { match self { Self::Full(ev) => match ev { - Pdu::RoomV1Pdu(ev) => ev.auth_events.iter().cloned().collect(), - Pdu::RoomV3Pdu(ev) => ev.auth_events.clone(), + Pdu::RoomV1Pdu(ev) => ev.auth_events.to_vec(), + Pdu::RoomV3Pdu(ev) => ev.auth_events.to_vec(), }, Self::Sync(ev) => match ev { PduStub::RoomV1PduStub(ev) => { ev.auth_events.iter().map(|(id, _)| id).cloned().collect() } - PduStub::RoomV3PduStub(ev) => ev.auth_events.clone(), + PduStub::RoomV3PduStub(ev) => ev.auth_events.to_vec(), }, } } diff --git a/tests/auth_ids.rs b/tests/auth_ids.rs new file mode 100644 index 00000000..45bbd1da --- /dev/null +++ b/tests/auth_ids.rs @@ -0,0 +1,741 @@ +#![allow(clippy::or_fun_call, clippy::expect_fun_call)] + +use std::{ + cell::RefCell, + collections::{BTreeMap, BTreeSet}, + convert::TryFrom, + sync::Once, + time::UNIX_EPOCH, +}; + +use ruma::{ + events::{ + room::{ + join_rules::JoinRule, + member::{MemberEventContent, MembershipState}, + }, + EventType, + }, + identifiers::{EventId, RoomId, RoomVersionId, UserId}, +}; +use serde_json::{json, Value as JsonValue}; +use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore}; +use tracing_subscriber as tracer; + +static LOGGER: Once = Once::new(); + +static mut SERVER_TIMESTAMP: i32 = 0; + +fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: Vec) { + // to activate logging use `RUST_LOG=debug cargo t one_test_only` + let _ = LOGGER.call_once(|| { + tracer::fmt() + .with_env_filter(tracer::EnvFilter::from_default_env()) + .init() + }); + + let mut resolver = StateResolution::default(); + + let store = TestStore(RefCell::new( + INITIAL_EVENTS() + .values() + .chain(events) + .map(|ev| (ev.event_id().unwrap().clone(), ev.clone())) + .collect(), + )); + + // This will be lexi_topo_sorted for resolution + let mut graph = BTreeMap::new(); + // this is the same as in `resolve` event_id -> StateEvent + let mut fake_event_map = BTreeMap::new(); + + // create the DB of events that led up to this point + // TODO maybe clean up some of these clones it is just tests but... + for ev in INITIAL_EVENTS().values().chain(events) { + graph.insert(ev.event_id().unwrap().clone(), vec![]); + fake_event_map.insert(ev.event_id().unwrap().clone(), ev.clone()); + } + + for pair in INITIAL_EDGES().windows(2) { + if let [a, b] = &pair { + graph.entry(a.clone()).or_insert(vec![]).push(b.clone()); + } + } + + for edge_list in edges { + for pair in edge_list.windows(2) { + if let [a, b] = &pair { + graph.entry(a.clone()).or_insert(vec![]).push(b.clone()); + } + } + } + + // event_id -> StateEvent + let mut event_map: BTreeMap = BTreeMap::new(); + // event_id -> StateMap + let mut state_at_event: BTreeMap> = BTreeMap::new(); + + // resolve the current state and add it to the state_at_event map then continue + // on in "time" + for node in + resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone()))) + { + let fake_event = fake_event_map.get(&node).unwrap(); + let event_id = fake_event.event_id().unwrap(); + + let prev_events = graph.get(&node).unwrap(); + + let state_before: StateMap = if prev_events.is_empty() { + BTreeMap::new() + } else if prev_events.len() == 1 { + state_at_event.get(&prev_events[0]).unwrap().clone() + } else { + let state_sets = prev_events + .iter() + .filter_map(|k| state_at_event.get(k)) + .cloned() + .collect::>(); + + tracing::info!( + "{:#?}", + state_sets + .iter() + .map(|map| map + .iter() + .map(|((ty, key), id)| format!("(({}{}), {})", ty, key, id)) + .collect::>()) + .collect::>() + ); + + let resolved = resolver.resolve( + &room_id(), + &RoomVersionId::version_1(), + &state_sets, + Some(event_map.clone()), + &store, + ); + match resolved { + Ok(ResolutionResult::Resolved(state)) => state, + Ok(ResolutionResult::Conflicted(state)) => panic!( + "conflicted: {:?}", + state + .iter() + .map(|map| map + .iter() + .map(|(key, id)| (key, id.to_string())) + .collect::>()) + .collect::>() + ), + Err(e) => panic!("resolution for {} failed: {}", node, e), + } + }; + + let mut state_after = state_before.clone(); + + if fake_event.state_key().is_some() { + let ty = fake_event.kind().clone(); + // we know there is a state_key unwrap OK + let key = fake_event.state_key().unwrap().clone(); + state_after.insert((ty, key), event_id.clone()); + } + + let auth_types = state_res::auth_types_for_event(fake_event); + + let mut auth_events = vec![]; + for key in auth_types { + if state_before.contains_key(&key) { + auth_events.push(state_before[&key].clone()) + } + } + + // TODO The event is just remade, adding the auth_events and prev_events here + // UPDATE: the `to_pdu_event` was split into `init` and the fn below, could be better + let e = fake_event; + let ev_id = e.event_id().unwrap(); + let event = to_pdu_event( + &e.event_id().unwrap().to_string(), + e.sender().clone(), + e.kind(), + e.state_key().as_deref(), + e.content().clone(), + &auth_events, + prev_events, + ); + // we have to update our store, an actual user of this lib would + // be giving us state from a DB. + // + // TODO + // TODO we need to convert the `StateResolution::resolve` to use the event_map + // because the user of this crate cannot update their DB's state. + *store.0.borrow_mut().get_mut(ev_id).unwrap() = event.clone(); + + state_at_event.insert(node, state_after); + event_map.insert(event_id.clone(), event); + } + + let mut expected_state = BTreeMap::new(); + for node in expected_state_ids { + let ev = event_map.get(&node).expect(&format!( + "{} not found in {:?}", + node.to_string(), + event_map + .keys() + .map(ToString::to_string) + .collect::>(), + )); + + let key = (ev.kind(), ev.state_key().unwrap()); + + expected_state.insert(key, node); + } + + let start_state = state_at_event.get(&event_id("$START:foo")).unwrap(); + + let end_state = state_at_event + .get(&event_id("$END:foo")) + .unwrap() + .iter() + .filter(|(k, v)| expected_state.contains_key(k) || start_state.get(k) != Some(*v)) + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>(); + + assert_eq!(expected_state, end_state); +} +pub struct TestStore(RefCell>); + +#[allow(unused)] +impl StateStore for TestStore { + fn get_events(&self, events: &[EventId]) -> Result, String> { + Ok(self + .0 + .borrow() + .iter() + .filter(|e| events.contains(e.0)) + .map(|(_, s)| s) + .cloned() + .collect()) + } + + fn get_event(&self, event_id: &EventId) -> Result { + self.0 + .borrow() + .get(event_id) + .cloned() + .ok_or(format!("{} not found", event_id.to_string())) + } + + fn auth_event_ids( + &self, + room_id: &RoomId, + event_ids: &[EventId], + ) -> Result, String> { + let mut result = vec![]; + let mut stack = event_ids.to_vec(); + + // DFS for auth event chain + while !stack.is_empty() { + let ev_id = stack.pop().unwrap(); + if result.contains(&ev_id) { + continue; + } + + result.push(ev_id.clone()); + + let event = self.get_event(&ev_id).unwrap(); + stack.extend(event.auth_events()); + } + + Ok(result) + } + + fn auth_chain_diff( + &self, + room_id: &RoomId, + event_ids: Vec>, + ) -> Result, String> { + use itertools::Itertools; + + let mut chains = vec![]; + for ids in event_ids { + // TODO state store `auth_event_ids` returns self in the event ids list + // when an event returns `auth_event_ids` self is not contained + let chain = self + .auth_event_ids(room_id, &ids)? + .into_iter() + .collect::>(); + chains.push(chain); + } + + if let Some(chain) = chains.first() { + let rest = chains.iter().skip(1).flatten().cloned().collect(); + let common = chain.intersection(&rest).collect::>(); + + Ok(chains + .iter() + .flatten() + .filter(|id| !common.contains(&id)) + .cloned() + .collect::>() + .into_iter() + .collect()) + } else { + Ok(vec![]) + } + } +} + +fn event_id(id: &str) -> EventId { + if id.contains('$') { + return EventId::try_from(id).unwrap(); + } + EventId::try_from(format!("${}:foo", id)).unwrap() +} + +fn alice() -> UserId { + UserId::try_from("@alice:foo").unwrap() +} +fn bob() -> UserId { + UserId::try_from("@bob:foo").unwrap() +} +fn charlie() -> UserId { + UserId::try_from("@charlie:foo").unwrap() +} +fn ella() -> UserId { + UserId::try_from("@ella:foo").unwrap() +} +fn zara() -> UserId { + UserId::try_from("@zara:foo").unwrap() +} + +fn room_id() -> RoomId { + RoomId::try_from("!test:foo").unwrap() +} + +fn member_content_ban() -> JsonValue { + serde_json::to_value(MemberEventContent { + membership: MembershipState::Ban, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + }) + .unwrap() +} + +fn member_content_join() -> JsonValue { + serde_json::to_value(MemberEventContent { + membership: MembershipState::Join, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + }) + .unwrap() +} + +fn to_pdu_event( + id: &str, + sender: UserId, + ev_type: EventType, + state_key: Option<&str>, + content: JsonValue, + auth_events: &[S], + prev_events: &[S], +) -> StateEvent +where + S: AsRef, +{ + let ts = unsafe { + let ts = SERVER_TIMESTAMP; + // increment the "origin_server_ts" value + SERVER_TIMESTAMP += 1; + ts + }; + let id = if id.contains('$') { + id.to_string() + } else { + format!("${}:foo", id) + }; + let auth_events = auth_events + .iter() + .map(AsRef::as_ref) + .map(event_id) + .collect::>(); + let prev_events = prev_events + .iter() + .map(AsRef::as_ref) + .map(event_id) + .collect::>(); + + let json = if let Some(state_key) = state_key { + json!({ + "auth_events": auth_events, + "prev_events": prev_events, + "event_id": id, + "sender": sender, + "type": ev_type, + "state_key": state_key, + "content": content, + "origin_server_ts": ts, + "room_id": room_id(), + "origin": "foo", + "depth": 0, + "hashes": { "sha256": "hello" }, + "signatures": {}, + }) + } else { + json!({ + "auth_events": auth_events, + "prev_events": prev_events, + "event_id": id, + "sender": sender, + "type": ev_type, + "content": content, + "origin_server_ts": ts, + "room_id": room_id(), + "origin": "foo", + "depth": 0, + "hashes": { "sha256": "hello" }, + "signatures": {}, + }) + }; + serde_json::from_value(json).unwrap() +} + +// all graphs start with these input events +#[allow(non_snake_case)] +fn INITIAL_EVENTS() -> BTreeMap { + // this is always called so we can init the logger here + let _ = LOGGER.call_once(|| { + tracer::fmt() + .with_env_filter(tracer::EnvFilter::from_default_env()) + .init() + }); + + vec![ + to_pdu_event::( + "CREATE", + alice(), + EventType::RoomCreate, + Some(""), + json!({ "creator": alice() }), + &[], + &[], + ), + to_pdu_event( + "IMA", + alice(), + EventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + &["CREATE"], + &["CREATE"], + ), + to_pdu_event( + "IPOWER", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice().to_string(): 100}}), + &["CREATE", "IMA"], + &["IMA"], + ), + to_pdu_event( + "IJR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({ "join_rule": JoinRule::Public }), + &["CREATE", "IMA", "IPOWER"], + &["IPOWER"], + ), + to_pdu_event( + "IMB", + bob(), + EventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IJR"], + ), + to_pdu_event( + "IMC", + charlie(), + EventType::RoomMember, + Some(charlie().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IMB"], + ), + to_pdu_event::( + "START", + charlie(), + EventType::RoomMessage, + None, + json!({}), + &[], + &[], + ), + to_pdu_event::( + "END", + charlie(), + EventType::RoomMessage, + None, + json!({}), + &[], + &[], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .collect() +} + +#[allow(non_snake_case)] +fn INITIAL_EDGES() -> Vec { + vec!["START", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"] + .into_iter() + .map(event_id) + .collect::>() +} + +// all graphs start with these input events +#[allow(non_snake_case)] +fn BAN_STATE_SET() -> BTreeMap { + vec![ + to_pdu_event( + "PA", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + &["CREATE", "IMA", "IPOWER"], // auth_events + &["START"], // prev_events + ), + to_pdu_event( + "PB", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + &["CREATE", "IMA", "IPOWER"], + &["END"], + ), + to_pdu_event( + "MB", + alice(), + EventType::RoomMember, + Some(ella().as_str()), + member_content_ban(), + &["CREATE", "IMA", "PB"], + &["PA"], + ), + to_pdu_event( + "IME", + ella(), + EventType::RoomMember, + Some(ella().as_str()), + member_content_join(), + &["CREATE", "IJR", "PA"], + &["MB"], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .collect() +} + +#[test] +fn ban_with_auth_chains() { + let ban = BAN_STATE_SET(); + + let edges = vec![vec!["END", "MB", "PA", "START"], vec!["END", "IME", "MB"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["PA", "MB"] + .into_iter() + .map(event_id) + .collect::>(); + + do_check( + &ban.values().cloned().collect::>(), + edges, + expected_state_ids, + ); +} + +#[test] +fn base_with_auth_chains() { + let mut resolver = StateResolution::default(); + + let store = TestStore(RefCell::new(INITIAL_EVENTS())); + + let resolved: BTreeMap<_, EventId> = + match resolver.resolve(&room_id(), &RoomVersionId::version_2(), &[], None, &store) { + Ok(ResolutionResult::Resolved(state)) => state, + Err(e) => panic!("{}", e), + _ => panic!("conflicted state left"), + }; + + let resolved = resolved + .values() + .cloned() + .chain( + INITIAL_EVENTS() + .values() + .map(|e| e.event_id().unwrap().clone()), + ) + .collect::>(); + + let expected = vec![ + "$CREATE:foo", + "$IJR:foo", + "$IPOWER:foo", + "$IMA:foo", + "$IMB:foo", + "$IMC:foo", + "START", + "END", + ]; + for id in expected.iter().map(|i| event_id(i)) { + // make sure our resolved events are equall to the expected list + assert!(resolved.iter().any(|eid| eid == &id), "{}", id) + } + assert_eq!(expected.len(), resolved.len()) +} + +#[test] +fn ban_with_auth_chains2() { + let mut resolver = StateResolution::default(); + + let init = INITIAL_EVENTS(); + let ban = BAN_STATE_SET(); + + let mut inner = init.clone(); + inner.extend(ban); + let store = TestStore(RefCell::new(inner.clone())); + + let state_set_a = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("MB")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| { + ( + (ev.kind(), ev.state_key().unwrap()), + ev.event_id().unwrap().clone(), + ) + }) + .collect::>(); + + let state_set_b = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("IME")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| { + ( + (ev.kind(), ev.state_key().unwrap()), + ev.event_id().unwrap().clone(), + ) + }) + .collect::>(); + + let resolved: BTreeMap<_, EventId> = match resolver.resolve( + &room_id(), + &RoomVersionId::version_2(), + &[state_set_a, state_set_b], + None, + &store, + ) { + Ok(ResolutionResult::Resolved(state)) => state, + Err(e) => panic!("{}", e), + _ => panic!("conflicted state left"), + }; + + tracing::debug!( + "{:#?}", + resolved + .iter() + .map(|((ty, key), id)| format!("(({}{}), {})", ty, key, id)) + .collect::>() + ); + + let expected = vec![ + "$CREATE:foo", + "$IJR:foo", + "$PA:foo", + "$IMA:foo", + "$IMB:foo", + "$IMC:foo", + "$MB:foo", + ]; + + for id in expected.iter().map(|i| event_id(i)) { + // make sure our resolved events are equall to the expected list + assert!( + resolved.values().any(|eid| eid == &id) || init.contains_key(&id), + "{}", + id + ) + } + assert_eq!(expected.len(), resolved.len()) +} + +// all graphs start with these input events +#[allow(non_snake_case)] +fn JOIN_RULE() -> BTreeMap { + vec![ + to_pdu_event( + "JR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({"join_rule": "invite"}), + &["CREATE", "IMA", "IPOWER"], + &["START"], + ), + to_pdu_event( + "IMZ", + zara(), + EventType::RoomPowerLevels, + Some(zara().as_str()), + member_content_join(), + &["CREATE", "JR", "IPOWER"], + &["START"], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .collect() +} + +#[test] +fn join_rule_with_auth_chain() { + let join_rule = JOIN_RULE(); + + let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::>(); + + do_check( + &join_rule.values().cloned().collect::>(), + edges, + expected_state_ids, + ); +} diff --git a/tests/state_res.rs b/tests/state_res.rs index e36f30ac..06bfbf5c 100644 --- a/tests/state_res.rs +++ b/tests/state_res.rs @@ -1,3 +1,5 @@ +#![allow(clippy::or_fun_call, clippy::expect_fun_call)] + use std::{ cell::RefCell, collections::{BTreeMap, BTreeSet}, @@ -27,7 +29,7 @@ static LOGGER: Once = Once::new(); static mut SERVER_TIMESTAMP: i32 = 0; fn event_id(id: &str) -> EventId { - if id.contains("$") { + if id.contains('$') { return EventId::try_from(id).unwrap(); } EventId::try_from(format!("${}:foo", id)).unwrap() @@ -92,7 +94,7 @@ where SERVER_TIMESTAMP += 1; ts }; - let id = if id.contains("$") { + let id = if id.contains('$') { id.to_string() } else { format!("${}:foo", id) @@ -100,33 +102,13 @@ where let auth_events = auth_events .iter() .map(AsRef::as_ref) - .map(|s| { - EventId::try_from( - if s.contains("$") { - s.to_owned() - } else { - format!("${}:foo", s) - } - .as_str(), - ) - }) - .collect::, _>>() - .unwrap(); + .map(event_id) + .collect::>(); let prev_events = prev_events .iter() .map(AsRef::as_ref) - .map(|s| { - EventId::try_from( - if s.contains("$") { - s.to_owned() - } else { - format!("${}:foo", s) - } - .as_str(), - ) - }) - .collect::, _>>() - .unwrap(); + .map(event_id) + .collect::>(); let json = if let Some(state_key) = state_key { json!({ @@ -176,7 +158,7 @@ fn to_init_pdu_event( SERVER_TIMESTAMP += 1; ts }; - let id = if id.contains("$") { + let id = if id.contains('$') { id.to_string() } else { format!("${}:foo", id) @@ -319,14 +301,14 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: } for pair in INITIAL_EDGES().windows(2) { - if let &[a, b] = &pair { + if let [a, b] = &pair { graph.entry(a.clone()).or_insert(vec![]).push(b.clone()); } } for edge_list in edges { for pair in edge_list.windows(2) { - if let &[a, b] = &pair { + if let [a, b] = &pair { graph.entry(a.clone()).or_insert(vec![]).push(b.clone()); } } @@ -338,10 +320,9 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: let mut state_at_event: BTreeMap> = BTreeMap::new(); // resolve the current state and add it to the state_at_event map then continue - // on in "time"? - for node in resolver - // TODO is this `key_fn` return correct ?? - .lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone()))) + // on in "time" + for node in + resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone()))) { let fake_event = fake_event_map.get(&node).unwrap(); let event_id = fake_event.event_id().unwrap(); @@ -359,6 +340,17 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: .cloned() .collect::>(); + tracing::warn!( + "{:#?}", + state_sets + .iter() + .map(|map| map + .iter() + .map(|((ty, key), id)| format!("(({}{}), {})", ty, key, id)) + .collect::>()) + .collect::>() + ); + let resolved = resolver.resolve( &room_id(), &RoomVersionId::version_1(), @@ -791,7 +783,8 @@ impl StateStore for TestStore { result.push(ev_id.clone()); let event = self.get_event(&ev_id).unwrap(); - stack.extend(event.auth_event_ids()); + + stack.extend(event.auth_events()); } Ok(result) @@ -902,7 +895,7 @@ impl TestStore { EventType::RoomMember, Some(charlie().to_string().as_str()), member_content_join(), - &[cre.clone(), join_rules.event_id().unwrap().clone()], + &[cre, join_rules.event_id().unwrap().clone()], &[join_rules.event_id().unwrap().clone()], ); self.0 @@ -913,7 +906,7 @@ impl TestStore { .iter() .map(|e| { ( - (e.kind(), e.state_key().unwrap().clone()), + (e.kind(), e.state_key().unwrap()), e.event_id().unwrap().clone(), ) }) @@ -923,7 +916,7 @@ impl TestStore { .iter() .map(|e| { ( - (e.kind(), e.state_key().unwrap().clone()), + (e.kind(), e.state_key().unwrap()), e.event_id().unwrap().clone(), ) }) @@ -939,7 +932,7 @@ impl TestStore { .iter() .map(|e| { ( - (e.kind(), e.state_key().unwrap().clone()), + (e.kind(), e.state_key().unwrap()), e.event_id().unwrap().clone(), ) })