From 29d86ebf3cdc0cb92df502e927bbc133dfe9fe71 Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Fri, 24 Jul 2020 23:14:30 -0400 Subject: [PATCH] Fix separate ignoring missing ids and auth_check details --- Cargo.toml | 5 + benches/state_bench.rs | 15 +++ src/event_auth.rs | 162 ++++++++++++++++++++--------- src/lib.rs | 97 +++++++----------- src/state_event.rs | 27 +++-- tests/state_res.rs | 224 +++++++++++++++++++++++++++++------------ 6 files changed, 355 insertions(+), 175 deletions(-) create mode 100644 benches/state_bench.rs diff --git a/Cargo.toml b/Cargo.toml index cb14f3cc..8da3bb0b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,3 +23,8 @@ features = ["client-api", "federation-api", "appservice-api"] [dev-dependencies] lazy_static = "1.4.0" +criterion = "0.3.3" + +[[bench]] +name = "state_bench" +harness = false \ No newline at end of file diff --git a/benches/state_bench.rs b/benches/state_bench.rs new file mode 100644 index 00000000..37f7ff88 --- /dev/null +++ b/benches/state_bench.rs @@ -0,0 +1,15 @@ +// `cargo bench` works, but if you use `cargo bench -- --save-baseline ` +// or pass any other args to it, it fails with the error +// `cargo bench unknown option --save-baseline`. +// To pass args to criterion, use this form +// `cargo bench --bench -- --save-baseline `. + +use criterion::{criterion_group, criterion_main, Criterion}; + +fn state_res(c: &mut Criterion) { + c.bench_function("resolve state of 10 events", |b| b.iter(|| {})); +} + +criterion_group!(benches, state_res,); + +criterion_main!(benches); diff --git a/src/event_auth.rs b/src/event_auth.rs index 9459d351..976608e8 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -7,6 +7,7 @@ use ruma::{ }, identifiers::{RoomVersionId, UserId}, }; +use serde_json::json; use crate::{room_version::RoomVersion, state_event::StateEvent, StateMap}; @@ -64,17 +65,39 @@ pub fn auth_check( room_version: &RoomVersionId, event: &StateEvent, auth_events: StateMap, + do_sig_check: bool, ) -> Option { tracing::info!("auth_check beginning"); // don't let power from other rooms be used for auth_event in auth_events.values() { if auth_event.room_id() != event.room_id() { + tracing::info!("found auth event that did not match event's room_id"); return Some(false); } } - // TODO do_sig_check, do_size_check is false when called by `iterative_auth_check` + if do_sig_check { + let sender_domain = event.sender().server_name(); + + let is_invite_via_3pid = if event.kind() == EventType::RoomMember { + event + .deserialize_content::() + .map(|c| c.membership == MembershipState::Invite && c.third_party_invite.is_some()) + .unwrap_or_default() + } else { + false + }; + + if !event.signatures().get(sender_domain).is_some() && !is_invite_via_3pid { + tracing::info!("event not signed by sender's server"); + return Some(false); + } + } + + // TODO do_size_check is false when called by `iterative_auth_check` + // do_size_check is also mostly accomplished by ruma with the exception of checking event_type, + // state_key, and json are below a certain size (255 and 65536 respectivly) // Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules // @@ -148,9 +171,8 @@ pub fn auth_check( if event.kind() == EventType::RoomMember { tracing::info!("starting m.room.member check"); - if is_membership_change_allowed(event, &auth_events)? { - tracing::info!("m.room.member membership change was allowed"); - return Some(true); + if !is_membership_change_allowed(event, &auth_events)? { + return Some(false); } tracing::info!("m.room.member event was allowed"); @@ -244,10 +266,11 @@ fn is_membership_change_allowed( let target_user_id = UserId::try_from(event.state_key().unwrap()).ok().unwrap(); // if the server_names are different and federation is NOT allowed - if event.room_id().unwrap().server_name() != target_user_id.server_name() { - if !can_federate(auth_events) { - return Some(false); - } + if event.room_id().unwrap().server_name() != target_user_id.server_name() + && !can_federate(auth_events) + { + tracing::info!("server cannot federate"); + return Some(false); } let key = (EventType::RoomMember, event.sender().to_string()); @@ -264,16 +287,18 @@ fn is_membership_change_allowed( let key = (EventType::RoomJoinRules, "".to_string()); let join_rules_event = auth_events.get(&key); + let mut join_rule = JoinRule::Invite; if let Some(jr) = join_rules_event { join_rule = jr .deserialize_content::() - .ok()? // TODO these are errors? and should be treated as a DB failure? + .ok() + .unwrap() // TODO these are errors? and should be treated as a DB failure? .join_rule; } let user_level = get_user_power_level(event.sender(), auth_events); - let target_level = get_user_power_level(event.sender(), auth_events); + let target_level = get_user_power_level(&target_user_id, auth_events); // synapse has a not "what to do for default here 50" let ban_level = get_named_level(auth_events, "ban", 50); @@ -281,7 +306,7 @@ fn is_membership_change_allowed( // TODO clean this up tracing::debug!( "_is_membership_change_allowed: {}", - serde_json::json!({ + serde_json::to_string_pretty(&json!({ "caller_in_room": caller_in_room, "caller_invited": caller_invited, "target_banned": target_banned, @@ -290,12 +315,34 @@ fn is_membership_change_allowed( "join_rule": join_rule, "target_user_id": target_user_id, "event.user_id": event.sender(), - }), + })) + .unwrap(), ); if membership == MembershipState::Invite && content.third_party_invite.is_some() { - // TODO impl this - unimplemented!("third party invite") + // TODO this is unimpled + if !verify_third_party_invite(event, auth_events) { + tracing::info!( + "{} was not invited to this room", + event + .event_id() + .map(ToString::to_string) + .unwrap_or("Unknow".into()) + ); + return Some(false); + } + if target_banned { + tracing::info!( + "{} is banned", + event + .event_id() + .map(ToString::to_string) + .unwrap_or("Unknow".into()) + ); + return Some(false); + } + tracing::info!("invite succeded"); + return Some(true); } if membership != MembershipState::Join { @@ -303,15 +350,27 @@ fn is_membership_change_allowed( && membership == MembershipState::Leave && &target_user_id == event.sender() { + tracing::info!("join event succeded"); return Some(true); } + + if !caller_in_room { + tracing::info!( + "{} is not in this room {:?}", + event.sender(), + event.room_id() + ); + return Some(false); // caller is not joined + } } if membership == MembershipState::Invite { if target_banned { + tracing::info!("target has been banned"); return Some(false); } else if target_in_room { - return Some(false); + tracing::info!("already in room"); + return Some(false); // already in room } else { let invite_level = get_named_level(auth_events, "invite", 0); if user_level < invite_level { @@ -320,39 +379,57 @@ fn is_membership_change_allowed( } } else if membership == MembershipState::Join { if event.sender() != &target_user_id { + tracing::info!("cannot force another user to join"); return Some(false); // cannot force another user to join } else if target_banned { + tracing::info!("cannot join when banned"); return Some(false); // cannot joined when banned } else if join_rule == JoinRule::Public { - // pass + tracing::info!("join rule public") + // pass } else if join_rule == JoinRule::Invite { if !caller_in_room && !caller_invited { + tracing::info!("user has not been invited to this room"); return Some(false); // you are not invited to this room } } else { + tracing::info!("the join rule is Private or yet to be spec'ed by Matrix"); // synapse has 2 TODO's may_join list and private rooms + + // the join_rule is Private or Knock which means it is not yet spec'ed return Some(false); } } else if membership == MembershipState::Leave { if target_banned && user_level < ban_level { + tracing::info!("not enough power to unban"); return Some(false); // you cannot unban this user } else if &target_user_id != event.sender() { let kick_level = get_named_level(auth_events, "kick", 50); if user_level < kick_level || user_level <= target_level { + tracing::info!("not enough power to kick user"); return Some(false); // you do not have the power to kick user } } } else if membership == MembershipState::Ban { + tracing::debug!( + "{} < {} || {} <= {}", + user_level, + ban_level, + user_level, + target_level + ); if user_level < ban_level || user_level <= target_level { + tracing::info!("not enough power to ban"); return Some(false); } } else { + tracing::warn!("unknown membership status"); // Unknown membership status return Some(false); } - Some(false) + Some(true) } /// Is the event's sender in the room that they sent the event to. @@ -391,10 +468,8 @@ fn can_send_event(event: &StateEvent, auth_events: &StateMap) -> Opt } if let Some(sk) = event.state_key() { - if sk.starts_with("@") { - if sk != event.sender().to_string() { - return Some(false); // permission required to post in this room - } + if sk.starts_with("@") && sk != event.sender().to_string() { + return Some(false); // permission required to post in this room } } Some(true) @@ -467,16 +542,12 @@ fn check_power_levels( for user in user_levels_to_check { let old_level = old_state.users.get(user); let new_level = new_state.users.get(user); - if old_level.is_some() && new_level.is_some() { - if old_level == new_level { - continue; - } + if old_level.is_some() && new_level.is_some() && old_level == new_level { + continue; } - if user != power_event.sender() { - if old_level.map(|int| (*int).into()) == Some(user_level) { - tracing::info!("m.room.power_level cannot remove ops == to own"); - return Some(false); // cannot remove ops level == to own - } + if user != power_event.sender() && old_level.map(|int| (*int).into()) == Some(user_level) { + tracing::info!("m.room.power_level cannot remove ops == to own"); + return Some(false); // cannot remove ops level == to own } let old_level_too_big = old_level.map(|int| (*int).into()) > Some(user_level); @@ -491,10 +562,8 @@ fn check_power_levels( for ev_type in event_levels_to_check { let old_level = old_state.events.get(ev_type); let new_level = new_state.events.get(ev_type); - if old_level.is_some() && new_level.is_some() { - if old_level == new_level { - continue; - } + if old_level.is_some() && new_level.is_some() && old_level == new_level { + continue; } let old_level_too_big = old_level.map(|int| (*int).into()) > Some(user_level); @@ -539,7 +608,7 @@ fn check_redaction( fn check_membership(member_event: Option<&StateEvent>, state: MembershipState) -> bool { if let Some(event) = member_event { if let Ok(content) = - serde_json::from_value::(event.content()) + serde_json::from_value::(event.content().clone()) { content.membership == state } else { @@ -602,9 +671,9 @@ fn get_send_level( ) -> i64 { tracing::debug!("{:?} {:?}", e_type, state_key); if let Some(ple) = power_lvl { - if let Ok(content) = - serde_json::from_value::(ple.content()) - { + if let Ok(content) = serde_json::from_value::( + ple.content().clone(), + ) { let mut lvl: i64 = content .events .get(&e_type) @@ -613,19 +682,18 @@ fn get_send_level( .into(); let state_def: i64 = content.state_default.into(); let event_def: i64 = content.events_default.into(); - if state_key.is_some() { - if state_def > lvl { - lvl = event_def; - } - } - if event_def > lvl { + if (state_key.is_some() && state_def > lvl) || event_def > lvl { lvl = event_def; } - return lvl; + lvl } else { - return 50; + 50 } } else { - return 0; + 0 } } + +fn verify_third_party_invite(_event: &StateEvent, _auth_events: &StateMap) -> bool { + unimplemented!("impl third party invites") +} diff --git a/src/lib.rs b/src/lib.rs index aefd24d6..04f4cef7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,7 +77,10 @@ impl StateResolution { // split non-conflicting and conflicting state let (clean, conflicting) = self.separate(&state_sets); + tracing::info!("non conflicting {:?}", clean.len()); + if conflicting.is_empty() { + tracing::warn!("no conflicting state found"); return Ok(ResolutionResult::Resolved(clean)); } @@ -86,6 +89,8 @@ impl StateResolution { // the set of auth events that are not common across server forks let mut auth_diff = self.get_auth_chain_diff(room_id, &state_sets, &event_map, store)?; + tracing::debug!("auth diff size {}", auth_diff.len()); + // add the auth_diff to conflicting now we have a full set of conflicting events auth_diff.extend(conflicting.values().cloned().flatten()); let mut all_conflicted = auth_diff @@ -94,14 +99,6 @@ impl StateResolution { .into_iter() .collect::>(); - tracing::debug!( - "FULL CONF {:?}", - all_conflicted - .iter() - .map(ToString::to_string) - .collect::>() - ); - tracing::info!("full conflicted set is {} events", all_conflicted.len()); // gather missing events for the event_map @@ -123,6 +120,8 @@ impl StateResolution { .flat_map(|ev| Some((ev.event_id()?.clone(), ev))), ); + tracing::debug!("event map size: {}", event_map.len()); + for event in event_map.values() { if event.room_id() != Some(room_id) { return Err(format!( @@ -139,16 +138,10 @@ impl StateResolution { // TODO make sure each conflicting event is in event_map?? // synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}` + // + // don't honor events we cannot "verify" all_conflicted.retain(|id| event_map.contains_key(id)); - tracing::debug!( - "ALL {:?}", - all_conflicted - .iter() - .map(ToString::to_string) - .collect::>() - ); - // get only the power events with a state_key: "" or ban/kick event (sender != state_key) let power_events = all_conflicted .iter() @@ -156,14 +149,6 @@ impl StateResolution { .cloned() .collect::>(); - tracing::debug!( - "POWER {:?}", - power_events - .iter() - .map(ToString::to_string) - .collect::>() - ); - // sort the power events based on power_level/clock/event_id and outgoing/incoming edges let mut sorted_power_levels = self.reverse_topological_power_sort( room_id, @@ -204,7 +189,7 @@ impl StateResolution { sorted_power_levels.dedup(); let deduped_power_ev = sorted_power_levels; - // we have resolved the power events so remove them, I'm sure theres other reasons to do so + // we have resolved the power events so remove them, I'm sure there are other reasons to do so let events_to_resolve = all_conflicted .iter() .filter(|id| !deduped_power_ev.contains(id)) @@ -267,18 +252,28 @@ impl StateResolution { let mut unconflicted_state = StateMap::new(); let mut conflicted_state = StateMap::new(); - for key in state_sets.iter().flat_map(|map| map.keys()) { + for key in state_sets + .iter() + .flat_map(|map| map.keys()) + .collect::>() + { let mut event_ids = state_sets .iter() - .flat_map(|map| map.get(key).cloned()) + .map(|state_set| state_set.get(key)) .dedup() - .collect::>(); + .collect::>(); if event_ids.len() == 1 { - // unwrap is ok since we know the len is 1 - unconflicted_state.insert(key.clone(), event_ids.pop().unwrap()); + if let Some(Some(id)) = event_ids.pop() { + unconflicted_state.insert(key.clone(), id.clone()); + } else { + panic!() + } } else { - conflicted_state.insert(key.clone(), event_ids); + conflicted_state.insert( + key.clone(), + event_ids.into_iter().flatten().cloned().collect::>(), + ); } } @@ -348,15 +343,6 @@ impl StateResolution { let ev = event_map.get(event_id).unwrap(); let pl = event_to_pl.get(event_id).unwrap(); - tracing::debug!( - "{:?}", - ( - -*pl, - ev.origin_server_ts().clone(), - ev.event_id().unwrap().to_string() - ) - ); - // This return value is the key used for sorting events, // events are then sorted by power level, time, // and lexically by event_id. @@ -531,15 +517,17 @@ impl StateResolution { } tracing::debug!("event to check {:?}", event.event_id().unwrap().to_string()); - if !event_auth::auth_check(room_version, &event, auth_events) - .ok_or("Auth check failed due to deserialization most likely".to_string()) - .unwrap() + if event_auth::auth_check(room_version, &event, auth_events, false) + .ok_or("Auth check failed due to deserialization most likely".to_string())? { - // TODO synapse passes here on AuthError ?? - tracing::warn!("event {} failed the authentication", event_id.to_string()); - } else { // add event to resolved state map resolved_state.insert((event.kind(), event.state_key().unwrap()), event_id.clone()); + } else { + // TODO synapse passes here on AuthError ?? + tracing::warn!( + "event {} failed the authentication check", + event_id.to_string() + ); } // We yield occasionally when we're working with large data sets to @@ -576,13 +564,15 @@ impl StateResolution { let mut idx = 0; while let Some(p) = pl { mainline.push(p.clone()); + // We don't need the actual pl_ev here since we delegate to the store - let auth_events = store.get_event(&p).unwrap().auth_event_ids(); + let event = store.get_event(&p).unwrap(); + let auth_events = event.auth_event_ids(); pl = None; for aid in auth_events { let ev = store.get_event(&aid).unwrap(); if ev.is_type_and_key(EventType::RoomPowerLevels, "") { - pl = Some(aid); + pl = Some(aid.clone()); break; } } @@ -646,16 +636,7 @@ impl StateResolution { } let auth_events = sort_ev.auth_event_ids(); - tracing::debug!( - "mainline AUTH EV {:?}", - auth_events - .iter() - .map(ToString::to_string) - .collect::>() - ); - event = None; - for aid in auth_events { let aev = store.get_event(&aid).unwrap(); if aev.is_type_and_key(EventType::RoomPowerLevels, "") { @@ -690,7 +671,7 @@ impl StateResolution { } // we just inserted this at the start of the while loop - graph.get_mut(&eid).unwrap().push(aid); + graph.get_mut(&eid).unwrap().push(aid.clone()); } } } diff --git a/src/state_event.rs b/src/state_event.rs index f15a72cc..897834f6 100644 --- a/src/state_event.rs +++ b/src/state_event.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use ruma::{ events::{ from_raw_json_value, @@ -5,7 +7,7 @@ use ruma::{ room::member::{MemberEventContent, MembershipState}, EventDeHelper, EventType, }, - identifiers::{EventId, RoomId, UserId}, + identifiers::{EventId, RoomId, ServerName, UserId}, }; use serde::{de, Serialize}; use serde_json::value::RawValue as RawJsonValue; @@ -198,15 +200,28 @@ impl StateEvent { } } - pub fn content(&self) -> serde_json::Value { + pub fn content(&self) -> &serde_json::Value { match self { Self::Full(ev) => match ev { - Pdu::RoomV1Pdu(ev) => ev.content.clone(), - Pdu::RoomV3Pdu(ev) => ev.content.clone(), + Pdu::RoomV1Pdu(ev) => &ev.content, + Pdu::RoomV3Pdu(ev) => &ev.content, }, Self::Sync(ev) => match ev { - PduStub::RoomV1PduStub(ev) => ev.content.clone(), - PduStub::RoomV3PduStub(ev) => ev.content.clone(), + PduStub::RoomV1PduStub(ev) => &ev.content, + PduStub::RoomV3PduStub(ev) => &ev.content, + }, + } + } + + pub fn signatures(&self) -> BTreeMap, BTreeMap> { + match self { + Self::Full(ev) => match ev { + Pdu::RoomV1Pdu(_) => maplit::btreemap! {}, + Pdu::RoomV3Pdu(ev) => ev.signatures.clone(), + }, + Self::Sync(ev) => match ev { + PduStub::RoomV1PduStub(ev) => ev.signatures.clone(), + PduStub::RoomV3PduStub(ev) => ev.signatures.clone(), }, } } diff --git a/tests/state_res.rs b/tests/state_res.rs index 221839d7..ac5541f9 100644 --- a/tests/state_res.rs +++ b/tests/state_res.rs @@ -23,9 +23,16 @@ use serde_json::{from_value as from_json_value, json, Value as JsonValue}; use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore}; use tracing_subscriber as tracer; +use std::sync::Once; + +static LOGGER: Once = Once::new(); + static mut SERVER_TIMESTAMP: i32 = 0; -fn id(id: &str) -> EventId { +fn event_id(id: &str) -> EventId { + if id.contains("$") { + return EventId::try_from(id).unwrap(); + } EventId::try_from(format!("${}:foo", id)).unwrap() } @@ -38,6 +45,9 @@ fn bob() -> UserId { fn charlie() -> UserId { UserId::try_from("@charlie:foo").unwrap() } +fn ella() -> UserId { + UserId::try_from("@ella:foo").unwrap() +} fn zera() -> UserId { UserId::try_from("@zera:foo").unwrap() } @@ -277,19 +287,19 @@ fn INITIAL_EDGES() -> Vec { "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE", ] .into_iter() - .map(|s| format!("${}:foo", s)) - .map(EventId::try_from) - .collect::, _>>() - .unwrap() + .map(event_id) + .collect::>() } fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: Vec) { use itertools::Itertools; // to activate logging use `RUST_LOG=debug cargo t one_test_only` - // tracer::fmt() - // .with_env_filter(tracer::EnvFilter::from_default_env()) - // .init(); + let _ = LOGGER.call_once(|| { + tracer::fmt() + .with_env_filter(tracer::EnvFilter::from_default_env()) + .init() + }); let mut resolver = StateResolution::default(); @@ -354,17 +364,6 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: .cloned() .collect::>(); - tracing::debug!( - "RESOLVING {:?}", - state_sets - .iter() - .map(|map| map - .iter() - .map(|((t, s), id)| (t, s, id.to_string())) - .collect::>()) - .collect::>() - ); - let resolved = resolver.resolve( &room_id(), &RoomVersionId::version_1(), @@ -389,6 +388,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: }; let mut state_after = state_before.clone(); + if fake_event.state_key().is_some() { let ty = fake_event.kind().clone(); // we know there is a state_key unwrap OK @@ -414,12 +414,16 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: e.sender().clone(), e.kind(), e.state_key().as_deref(), - e.content(), + e.content().clone(), &auth_events, prev_events, ); - // we have to update our store, an actual user of this lib would do this - // with the result of the resolution> + // we have to update our store, an actual user of this lib would + // be giving us state from a DB. + // + // TODO + // TODO we need to convert the `StateResolution::resolve` to use the event_map + // because the user of this crate cannot update their DB's state. *store.0.borrow_mut().get_mut(ev_id).unwrap() = event.clone(); state_at_event.insert(node, state_after); @@ -442,12 +446,10 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: expected_state.insert(key, node); } - let start_state = state_at_event - .get(&EventId::try_from("$START:foo").unwrap()) - .unwrap(); + let start_state = state_at_event.get(&event_id("$START:foo")).unwrap(); let end_state = state_at_event - .get(&EventId::try_from("$END:foo").unwrap()) + .get(&event_id("$END:foo")) .unwrap() .iter() .filter(|(k, v)| expected_state.contains_key(k) || start_state.get(k) != Some(*v)) @@ -495,21 +497,13 @@ fn ban_vs_power_level() { vec!["END", "PB", "PA"], ] .into_iter() - .map(|list| { - list.into_iter() - .map(|s| format!("${}:foo", s)) - .map(EventId::try_from) - .collect::, _>>() - .unwrap() - }) + .map(|list| list.into_iter().map(event_id).collect::>()) .collect::>(); let expected_state_ids = vec!["PA", "MA", "MB"] .into_iter() - .map(|s| format!("${}:foo", s)) - .map(EventId::try_from) - .collect::, _>>() - .unwrap(); + .map(event_id) + .collect::>(); do_check(events, edges, expected_state_ids) } @@ -548,21 +542,13 @@ fn topic_basic() { vec!["END", "T3", "PB", "PA1"], ] .into_iter() - .map(|list| { - list.into_iter() - .map(|s| format!("${}:foo", s)) - .map(EventId::try_from) - .collect::, _>>() - .unwrap() - }) + .map(|list| list.into_iter().map(event_id).collect::>()) .collect::>(); let expected_state_ids = vec!["PA2", "T2"] .into_iter() - .map(|s| format!("${}:foo", s)) - .map(EventId::try_from) - .collect::, _>>() - .unwrap(); + .map(event_id) + .collect::>(); do_check(events, edges, expected_state_ids) } @@ -593,21 +579,125 @@ fn topic_reset() { vec!["END", "T1"], ] .into_iter() - .map(|list| { - list.into_iter() - .map(|s| format!("${}:foo", s)) - .map(EventId::try_from) - .collect::, _>>() - .unwrap() - }) + .map(|list| list.into_iter().map(event_id).collect::>()) .collect::>(); let expected_state_ids = vec!["T1", "MB", "PA"] .into_iter() - .map(|s| format!("${}:foo", s)) - .map(EventId::try_from) - .collect::, _>>() - .unwrap(); + .map(event_id) + .collect::>(); + + do_check(events, edges, expected_state_ids) +} + +#[test] +fn join_rule_evasion() { + let events = &[ + to_init_pdu_event( + "JR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({ "join_rule": JoinRule::Private }), + ), + to_init_pdu_event( + "ME", + ella(), + EventType::RoomMember, + Some(ella().to_string().as_str()), + member_content_join(), + ), + ]; + + let edges = vec![vec!["END", "JR", "START"], vec!["END", "ME", "START"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec![event_id("JR")]; + + do_check(events, edges, expected_state_ids) +} + +#[test] +fn offtopic_power_level() { + let events = &[ + to_init_pdu_event( + "PA", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + ), + to_init_pdu_event( + "PB", + bob(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50, charlie(): 50}}), + ), + to_init_pdu_event( + "PC", + charlie(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50, charlie(): 0}}), + ), + ]; + + let edges = vec![vec!["END", "PC", "PB", "PA", "START"], vec!["END", "PA"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::>(); + + do_check(events, edges, expected_state_ids) +} + +#[test] +fn topic_setting() { + let events = &[ + to_init_pdu_event("T1", alice(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event( + "PA1", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + ), + to_init_pdu_event("T2", alice(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event( + "PA2", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 0}}), + ), + to_init_pdu_event( + "PB", + bob(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice(): 100, bob(): 50}}), + ), + to_init_pdu_event("T3", bob(), EventType::RoomTopic, Some(""), json!({})), + to_init_pdu_event("MZ1", zera(), EventType::RoomMessage, None, json!({})), + to_init_pdu_event("T4", alice(), EventType::RoomTopic, Some(""), json!({})), + ]; + + let edges = vec![ + vec!["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"], + vec!["END", "MZ1", "T3", "PB", "PA1"], + ] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["T4", "PA2"] + .into_iter() + .map(event_id) + .collect::>(); do_check(events, edges, expected_state_ids) } @@ -641,11 +731,11 @@ fn test_lexicographical_sort() { let mut resolver = StateResolution::default(); let graph = btreemap! { - id("l") => vec![id("o")], - id("m") => vec![id("n"), id("o")], - id("n") => vec![id("o")], - id("o") => vec![], // "o" has zero outgoing edges but 4 incoming edges - id("p") => vec![id("o")], + event_id("l") => vec![event_id("o")], + event_id("m") => vec![event_id("n"), event_id("o")], + event_id("n") => vec![event_id("o")], + event_id("o") => vec![], // "o" has zero outgoing edges but 4 incoming edges + event_id("p") => vec![event_id("o")], }; let res = @@ -750,6 +840,12 @@ impl StateStore for TestStore { impl TestStore { pub fn set_up(&self) -> (StateMap, StateMap, StateMap) { + // to activate logging use `RUST_LOG=debug cargo t one_test_only` + let _ = LOGGER.call_once(|| { + tracer::fmt() + .with_env_filter(tracer::EnvFilter::from_default_env()) + .init() + }); let create_event = to_pdu_event::( "CREATE", alice(),