diff --git a/Cargo.toml b/Cargo.toml index 779fec77..f110b0d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,15 +22,17 @@ maplit = "1.0.2" thiserror = "1.0.20" tracing-subscriber = "0.2.11" -# [dependencies.ruma] -# path = "../__forks__/ruma/ruma" -# features = ["client-api", "federation-api", "appservice-api"] - [dependencies.ruma] -git = "https://github.com/ruma/ruma" -rev = "aff914050eb297bd82b8aafb12158c88a9e480e1" +path = "../ruma/ruma" features = ["client-api", "federation-api", "appservice-api"] +#[dependencies.ruma] +#git = "https://github.com/ruma/ruma" +#rev = "aff914050eb297bd82b8aafb12158c88a9e480e1" +#features = ["client-api", "federation-api", "appservice-api"] + +[features] +unstable-pre-spec = ["ruma/unstable-pre-spec"] [dev-dependencies] criterion = "0.3.3" @@ -38,4 +40,5 @@ rand = "0.7.3" [[bench]] name = "state_res_bench" -harness = false \ No newline at end of file +harness = false + diff --git a/benches/state_res_bench.rs b/benches/state_res_bench.rs index eca7d013..1a7cd76b 100644 --- a/benches/state_res_bench.rs +++ b/benches/state_res_bench.rs @@ -3,7 +3,7 @@ // `cargo bench unknown option --save-baseline`. // To pass args to criterion, use this form // `cargo bench --bench -- --save-baseline `. -use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH}; +use std::{collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH, sync::Arc}; use criterion::{criterion_group, criterion_main, Criterion}; use maplit::btreemap; @@ -42,7 +42,7 @@ fn lexico_topo_sort(c: &mut Criterion) { fn resolution_shallow_auth_chain(c: &mut Criterion) { c.bench_function("resolve state of 5 events one fork", |b| { - let store = TestStore(RefCell::new(btreemap! {})); + let mut store = TestStore(btreemap! {}); // build up the DAG let (state_at_bob, state_at_charlie, _) = store.set_up(); @@ -69,7 +69,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) { let mut inner = init; inner.extend(ban); - let store = TestStore(RefCell::new(inner.clone())); + let store = TestStore(inner.clone()); let state_set_a = [ inner.get(&event_id("CREATE")).unwrap(), @@ -126,22 +126,21 @@ criterion_main!(benches); // IMPLEMENTATION DETAILS AHEAD // /////////////////////////////////////////////////////////////////////*/ -pub struct TestStore(RefCell>); +pub struct TestStore(BTreeMap>); #[allow(unused)] impl StateStore for TestStore { - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result> { self.0 - .borrow() .get(event_id) - .cloned() + .map(Arc::clone) .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } impl TestStore { - pub fn set_up(&self) -> (StateMap, StateMap, StateMap) { - let create_event = to_pdu_event::( + pub fn set_up(&mut self) -> (StateMap, StateMap, StateMap) { + let create_event = Arc::new(to_pdu_event::( "CREATE", alice(), EventType::RoomCreate, @@ -149,11 +148,10 @@ impl TestStore { json!({ "creator": alice() }), &[], &[], - ); + )); let cre = create_event.event_id(); self.0 - .borrow_mut() - .insert(cre.clone(), create_event.clone()); + .insert(cre.clone(), Arc::clone(&create_event)); let alice_mem = to_pdu_event( "IMA", @@ -165,8 +163,7 @@ impl TestStore { &[cre.clone()], ); self.0 - .borrow_mut() - .insert(alice_mem.event_id(), alice_mem.clone()); + .insert(alice_mem.event_id(), Arc::clone(&alice_mem)); let join_rules = to_pdu_event( "IJR", @@ -178,8 +175,7 @@ impl TestStore { &[alice_mem.event_id()], ); self.0 - .borrow_mut() - .insert(join_rules.event_id(), join_rules.clone()); + .insert(join_rules.event_id(), Arc::clone(&join_rules)); // Bob and Charlie join at the same time, so there is a fork // this will be represented in the state_sets when we resolve @@ -193,8 +189,7 @@ impl TestStore { &[join_rules.event_id()], ); self.0 - .borrow_mut() - .insert(bob_mem.event_id(), bob_mem.clone()); + .insert(bob_mem.event_id(), Arc::clone(&bob_mem)); let charlie_mem = to_pdu_event( "IMC", @@ -206,8 +201,7 @@ impl TestStore { &[join_rules.event_id()], ); self.0 - .borrow_mut() - .insert(charlie_mem.event_id(), charlie_mem.clone()); + .insert(charlie_mem.event_id(), Arc::clone(&charlie_mem)); let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] .iter() @@ -288,7 +282,7 @@ fn to_pdu_event( content: JsonValue, auth_events: &[S], prev_events: &[S], -) -> StateEvent +) -> Arc where S: AsRef, { @@ -362,12 +356,12 @@ where "signatures": {}, }) }; - serde_json::from_value(json).unwrap() + Arc::new(serde_json::from_value(json).unwrap()) } // all graphs start with these input events #[allow(non_snake_case)] -fn INITIAL_EVENTS() -> BTreeMap { +fn INITIAL_EVENTS() -> BTreeMap> { vec![ to_pdu_event::( "CREATE", @@ -449,7 +443,7 @@ fn INITIAL_EVENTS() -> BTreeMap { // all graphs start with these input events #[allow(non_snake_case)] -fn BAN_STATE_SET() -> BTreeMap { +fn BAN_STATE_SET() -> BTreeMap> { vec![ to_pdu_event( "PA", diff --git a/src/event_auth.rs b/src/event_auth.rs index 04acf141..462ec042 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, convert::TryFrom}; +use std::{collections::BTreeMap, convert::TryFrom, sync::Arc}; use maplit::btreeset; use ruma::{ @@ -72,10 +72,10 @@ pub fn auth_types_for_event( /// * then there are checks for specific event types pub fn auth_check( room_version: &RoomVersionId, - incoming_event: &StateEvent, - prev_event: Option<&StateEvent>, - auth_events: StateMap, - current_third_party_invite: Option<&StateEvent>, + incoming_event: &Arc, + prev_event: Option>, + auth_events: StateMap>, + current_third_party_invite: Option>, ) -> Result { tracing::info!("auth_check beginning for {}", incoming_event.kind()); @@ -206,8 +206,8 @@ pub fn auth_check( if !valid_membership_change( incoming_event.to_requester(), - current_third_party_invite, prev_event, + current_third_party_invite, &auth_events, )? { return Ok(false); @@ -241,7 +241,7 @@ pub fn auth_check( // If the event type's required power level is greater than the sender's power level, reject // If the event has a state_key that starts with an @ and does not match the sender, reject. - if !can_send_event(incoming_event, &auth_events)? { + if !can_send_event(&incoming_event, &auth_events)? { tracing::warn!("user cannot send event"); return Ok(false); } @@ -250,7 +250,7 @@ pub fn auth_check( tracing::info!("starting m.room.power_levels check"); if let Some(required_pwr_lvl) = - check_power_levels(room_version, incoming_event, &auth_events) + check_power_levels(room_version, &incoming_event, &auth_events) { if !required_pwr_lvl { tracing::warn!("power level was not allowed"); @@ -284,9 +284,9 @@ pub fn auth_check( /// the current State. pub fn valid_membership_change( user: Requester<'_>, - prev_event: Option<&StateEvent>, - current_third_party_invite: Option<&StateEvent>, - auth_events: &StateMap, + prev_event: Option>, + current_third_party_invite: Option>, + auth_events: &StateMap>, ) -> Result { let state_key = if let Some(s) = user.state_key.as_ref() { s @@ -377,7 +377,7 @@ pub fn valid_membership_change( .join_rule; } - if let Some(prev) = prev_event { + if let Some(prev) = dbg!(prev_event) { if prev.kind() == EventType::RoomCreate && prev.prev_event_ids().is_empty() { return Ok(true); } @@ -440,7 +440,7 @@ pub fn valid_membership_change( /// Is the event's sender in the room that they sent the event to. pub fn check_event_sender_in_room( sender: &UserId, - auth_events: &StateMap, + auth_events: &StateMap>, ) -> Option { let mem = auth_events.get(&(EventType::RoomMember, Some(sender.to_string())))?; Some( @@ -453,7 +453,7 @@ pub fn check_event_sender_in_room( /// Is the user allowed to send a specific event based on the rooms power levels. Does the event /// have the correct userId as it's state_key if it's not the "" state_key. -pub fn can_send_event(event: &StateEvent, auth_events: &StateMap) -> Result { +pub fn can_send_event(event: &Arc, auth_events: &StateMap>) -> Result { let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))); let event_type_power_level = get_send_level(event.kind(), event.state_key(), ple); @@ -481,8 +481,8 @@ pub fn can_send_event(event: &StateEvent, auth_events: &StateMap) -> /// Confirm that the event sender has the required power levels. pub fn check_power_levels( _: &RoomVersionId, - power_event: &StateEvent, - auth_events: &StateMap, + power_event: &Arc, + auth_events: &StateMap>, ) -> Option { let key = (power_event.kind(), power_event.state_key()); let current_state = if let Some(current_state) = auth_events.get(&key) { @@ -627,8 +627,8 @@ fn get_deserialize_levels( /// Does the event redacting come from a user with enough power to redact the given event. pub fn check_redaction( room_version: &RoomVersionId, - redaction_event: &StateEvent, - auth_events: &StateMap, + redaction_event: &Arc, + auth_events: &StateMap>, ) -> Result { let user_level = get_user_power_level(redaction_event.sender(), auth_events); let redact_level = get_named_level(auth_events, "redact", 50); @@ -662,7 +662,7 @@ pub fn check_redaction( /// Check that the member event matches `state`. /// /// This function returns false instead of failing when deserialization fails. -pub fn check_membership(member_event: Option<&StateEvent>, state: MembershipState) -> bool { +pub fn check_membership(member_event: Option>, state: MembershipState) -> bool { if let Some(event) = member_event { if let Ok(content) = serde_json::from_value::(event.content().clone()) @@ -677,7 +677,7 @@ pub fn check_membership(member_event: Option<&StateEvent>, state: MembershipStat } /// Can this room federate based on its m.room.create event. -pub fn can_federate(auth_events: &StateMap) -> bool { +pub fn can_federate(auth_events: &StateMap>) -> bool { let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into()))); if let Some(ev) = creation_event { if let Some(fed) = ev.content().get("m.federate") { @@ -692,7 +692,7 @@ pub fn can_federate(auth_events: &StateMap) -> bool { /// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content. /// or return `default` if no power level event is found or zero if no field matches `name`. -pub fn get_named_level(auth_events: &StateMap, name: &str, default: i64) -> i64 { +pub fn get_named_level(auth_events: &StateMap>, name: &str, default: i64) -> i64 { let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))); if let Some(pl) = power_level_event { // TODO do this the right way and deserialize @@ -708,7 +708,7 @@ pub fn get_named_level(auth_events: &StateMap, name: &str, default: /// Helper function to fetch a users default power level from a "m.room.power_level" event's `users` /// object. -pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap) -> i64 { +pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap>) -> i64 { if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))) { if let Ok(content) = pl.deserialize_content::() { @@ -744,7 +744,7 @@ pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap pub fn get_send_level( e_type: EventType, state_key: Option, - power_lvl: Option<&StateEvent>, + power_lvl: Option<&Arc>, ) -> i64 { tracing::debug!("{:?} {:?}", e_type, state_key); if let Some(ple) = power_lvl { @@ -772,7 +772,7 @@ pub fn get_send_level( } /// Check user can send invite. -pub fn can_send_invite(event: &Requester<'_>, auth_events: &StateMap) -> Result { +pub fn can_send_invite(event: &Requester<'_>, auth_events: &StateMap>) -> Result { let user_level = get_user_power_level(event.sender, auth_events); let key = (EventType::RoomPowerLevels, Some("".into())); let invite_level = auth_events @@ -794,7 +794,7 @@ pub fn can_send_invite(event: &Requester<'_>, auth_events: &StateMap pub fn verify_third_party_invite( event: &Requester<'_>, tp_id: &member::ThirdPartyInvite, - current_third_party_invite: Option<&StateEvent>, + current_third_party_invite: Option>, ) -> bool { // 1. check for user being banned happens before this is called // checking for mxid and token keys is done by ruma when deserializing diff --git a/src/lib.rs b/src/lib.rs index c782c41d..8f282139 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,7 @@ use std::{ cmp::Reverse, collections::{BTreeMap, BTreeSet, BinaryHeap}, time::SystemTime, -}; +sync::Arc}; use maplit::btreeset; use ruma::{ @@ -58,7 +58,7 @@ impl StateResolution { room_id: &RoomId, room_version: &RoomVersionId, state_sets: &[StateMap], - event_map: Option>, + event_map: Option>>, store: &dyn StateStore, ) -> Result> { tracing::info!("State resolution starting"); @@ -292,7 +292,7 @@ impl StateResolution { pub fn reverse_topological_power_sort( room_id: &RoomId, events_to_sort: &[EventId], - event_map: &mut EventMap, + event_map: &mut EventMap>, store: &dyn StateStore, auth_diff: &[EventId], ) -> Vec { @@ -418,7 +418,7 @@ impl StateResolution { fn get_power_level_for_sender( room_id: &RoomId, event_id: &EventId, - event_map: &mut EventMap, + event_map: &mut EventMap>, store: &dyn StateStore, ) -> i64 { tracing::info!("fetch event ({}) senders power level", event_id.to_string()); @@ -476,7 +476,7 @@ impl StateResolution { room_version: &RoomVersionId, events_to_check: &[EventId], unconflicted_state: &StateMap, - event_map: &mut EventMap, + event_map: &mut EventMap>, store: &dyn StateStore, ) -> Result> { tracing::info!("starting iterative auth check"); @@ -526,8 +526,8 @@ impl StateResolution { tracing::debug!("event to check {:?}", event.event_id().as_str()); - let most_recent_prev_event = event - .prev_event_ids() + let most_recent_prev_event = dbg!(event + .prev_event_ids()) .iter() .filter_map(|id| StateResolution::get_or_load_event(room_id, id, event_map, store)) .next_back(); @@ -545,9 +545,9 @@ impl StateResolution { if event_auth::auth_check( room_version, &event, - most_recent_prev_event.as_ref(), + most_recent_prev_event, auth_events, - current_third_party.as_ref(), + current_third_party, )? { // add event to resolved state map resolved_state.insert((event.kind(), event.state_key()), event_id.clone()); @@ -579,7 +579,7 @@ impl StateResolution { room_id: &RoomId, to_sort: &[EventId], resolved_power_level: Option<&EventId>, - event_map: &mut EventMap, + event_map: &mut EventMap>, store: &dyn StateStore, ) -> Vec { tracing::debug!("mainline sort of events"); @@ -658,14 +658,13 @@ impl StateResolution { sort_event_ids } - // TODO make `event` not clone every loop /// Get the mainline depth from the `mainline_map` or finds a power_level event /// that has an associated mainline depth. fn get_mainline_depth( room_id: &RoomId, - mut event: Option, + mut event: Option>, mainline_map: &EventMap, - event_map: &mut EventMap, + event_map: &mut EventMap>, store: &dyn StateStore, ) -> usize { while let Some(sort_ev) = event { @@ -681,7 +680,7 @@ impl StateResolution { let aev = StateResolution::get_or_load_event(room_id, &aid, event_map, store).unwrap(); if aev.is_type_and_key(EventType::RoomPowerLevels, "") { - event = Some(aev.clone()); + event = Some(aev); break; } } @@ -694,7 +693,7 @@ impl StateResolution { room_id: &RoomId, graph: &mut BTreeMap>, event_id: &EventId, - event_map: &mut EventMap, + event_map: &mut EventMap>, store: &dyn StateStore, auth_diff: &[EventId], ) { @@ -730,21 +729,22 @@ impl StateResolution { fn get_or_load_event( room_id: &RoomId, ev_id: &EventId, - event_map: &mut EventMap, + event_map: &mut EventMap>, store: &dyn StateStore, - ) -> Option { - // TODO can we cut down on the clones? - if !event_map.contains_key(ev_id) { - let event = store.get_event(room_id, ev_id).ok()?; - event_map.insert(ev_id.clone(), event.clone()); - Some(event) - } else { - event_map.get(ev_id).cloned() + ) -> Option> { + if let Some(e) = event_map.get(ev_id) { + return Some(Arc::clone(e)); } + + if let Ok(e) = store.get_event(room_id, ev_id) { + return Some(Arc::clone(event_map.entry(ev_id.clone()).or_insert(e))) + } + + None } } -pub fn is_power_event(event_id: &EventId, event_map: &EventMap) -> bool { +pub fn is_power_event(event_id: &EventId, event_map: &EventMap>) -> bool { match event_map.get(event_id) { Some(state) => state.is_power_event(), _ => false, diff --git a/src/state_event.rs b/src/state_event.rs index e15d0531..96b2aae7 100644 --- a/src/state_event.rs +++ b/src/state_event.rs @@ -202,6 +202,8 @@ impl StateEvent { }, } } + + #[cfg(not(feature = "unstable-pre-spec"))] pub fn origin(&self) -> String { match self { Self::Full(ev) => match ev { diff --git a/src/state_store.rs b/src/state_store.rs index 777913ed..c1695fa4 100644 --- a/src/state_store.rs +++ b/src/state_store.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeSet; +use std::{collections::BTreeSet, sync::Arc}; use ruma::identifiers::{EventId, RoomId}; @@ -6,10 +6,10 @@ use crate::{Result, StateEvent}; pub trait StateStore { /// Return a single event based on the EventId. - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result; + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result>; /// Returns the events that correspond to the `event_ids` sorted in the same order. - fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result> { + fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result>> { let mut events = vec![]; for id in event_ids { events.push(self.get_event(room_id, id)?); diff --git a/tests/event_auth.rs b/tests/event_auth.rs index 47d04727..7ed7d44a 100644 --- a/tests/event_auth.rs +++ b/tests/event_auth.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom}; +use std::{collections::BTreeMap, convert::TryFrom, sync::Arc}; use ruma::{ events::{ @@ -71,15 +71,14 @@ fn member_content_join() -> JsonValue { .unwrap() } -pub struct TestStore(RefCell>); +pub struct TestStore(BTreeMap>); #[allow(unused)] impl StateStore for TestStore { - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result> { self.0 - .borrow() .get(event_id) - .cloned() + .map(Arc::clone) .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } @@ -92,7 +91,7 @@ fn to_pdu_event( content: JsonValue, auth_events: &[S], prev_events: &[S], -) -> StateEvent +) -> Arc where S: AsRef, { @@ -166,12 +165,12 @@ where "signatures": {}, }) }; - serde_json::from_value(json).unwrap() + Arc::new(serde_json::from_value(json).unwrap()) } // all graphs start with these input events #[allow(non_snake_case)] -fn INITIAL_EVENTS() -> BTreeMap { +fn INITIAL_EVENTS() -> BTreeMap> { // this is always called so we can init the logger here let _ = LOGGER.call_once(|| { tracer::fmt() @@ -246,11 +245,12 @@ fn test_ban_pass() { let prev = events .values() - .find(|ev| ev.event_id().as_str().contains("IMC")); + .find(|ev| ev.event_id().as_str().contains("IMC")) + .map(Arc::clone); let auth_events = events .values() - .map(|ev| ((ev.kind(), ev.state_key()), ev.clone())) + .map(|ev| ((ev.kind(), ev.state_key()), Arc::clone(ev))) .collect::>(); let requester = Requester { @@ -270,11 +270,12 @@ fn test_ban_fail() { let prev = events .values() - .find(|ev| ev.event_id().as_str().contains("IMC")); + .find(|ev| ev.event_id().as_str().contains("IMC")) + .map(Arc::clone); let auth_events = events .values() - .map(|ev| ((ev.kind(), ev.state_key()), ev.clone())) + .map(|ev| ((ev.kind(), ev.state_key()), Arc::clone(ev))) .collect::>(); let requester = Requester { diff --git a/tests/event_sorting.rs b/tests/event_sorting.rs index c9d9248a..e4ac5b44 100644 --- a/tests/event_sorting.rs +++ b/tests/event_sorting.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom}; +use std::{collections::BTreeMap, convert::TryFrom, sync::Arc}; use ruma::{ events::{ @@ -53,15 +53,14 @@ fn member_content_join() -> JsonValue { .unwrap() } -pub struct TestStore(RefCell>); +pub struct TestStore(BTreeMap>); #[allow(unused)] impl StateStore for TestStore { - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result> { self.0 - .borrow() .get(event_id) - .cloned() + .map(Arc::clone) .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } @@ -74,7 +73,7 @@ fn to_pdu_event( content: JsonValue, auth_events: &[S], prev_events: &[S], -) -> StateEvent +) -> Arc where S: AsRef, { @@ -148,12 +147,12 @@ where "signatures": {}, }) }; - serde_json::from_value(json).unwrap() + Arc::new(serde_json::from_value(json).unwrap()) } // all graphs start with these input events #[allow(non_snake_case)] -fn INITIAL_EVENTS() -> BTreeMap { +fn INITIAL_EVENTS() -> BTreeMap> { // this is always called so we can init the logger here let _ = LOGGER.call_once(|| { tracer::fmt() @@ -243,8 +242,7 @@ fn shuffle(list: &mut [EventId]) { fn test_event_sort() { let mut events = INITIAL_EVENTS(); - - let store = TestStore(RefCell::new(events.clone())); + let store = TestStore(events.clone()); let event_map = events .values() diff --git a/tests/res_with_auth_ids.rs b/tests/res_with_auth_ids.rs index c00d5ffa..0787eb7a 100644 --- a/tests/res_with_auth_ids.rs +++ b/tests/res_with_auth_ids.rs @@ -1,6 +1,6 @@ #![allow(clippy::or_fun_call, clippy::expect_fun_call)] -use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, sync::Once, time::UNIX_EPOCH}; +use std::{collections::BTreeMap, convert::TryFrom, sync::Once, time::UNIX_EPOCH, sync::Arc}; use ruma::{ events::{ @@ -21,7 +21,7 @@ static LOGGER: Once = Once::new(); static mut SERVER_TIMESTAMP: i32 = 0; -fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: Vec) { +fn do_check(events: &[Arc], edges: Vec>, expected_state_ids: Vec) { // to activate logging use `RUST_LOG=debug cargo t` let _ = LOGGER.call_once(|| { tracer::fmt() @@ -29,13 +29,13 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: .init() }); - let store = TestStore(RefCell::new( + let mut store = TestStore( INITIAL_EVENTS() .values() .chain(events) .map(|ev| (ev.event_id(), ev.clone())) .collect(), - )); + ); // This will be lexi_topo_sorted for resolution let mut graph = BTreeMap::new(); @@ -64,7 +64,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: } // event_id -> StateEvent - let mut event_map: BTreeMap = BTreeMap::new(); + let mut event_map: BTreeMap> = BTreeMap::new(); // event_id -> StateMap let mut state_at_event: BTreeMap> = BTreeMap::new(); @@ -152,10 +152,10 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: // we have to update our store, an actual user of this lib would // be giving us state from a DB. - *store.0.borrow_mut().get_mut(&ev_id).unwrap() = event.clone(); + store.0.insert(ev_id.clone(), event.clone()); state_at_event.insert(node, state_after); - event_map.insert(event_id.clone(), event); + event_map.insert(event_id.clone(), Arc::clone(store.0.get(&ev_id).unwrap())); } let mut expected_state = StateMap::new(); @@ -186,15 +186,14 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: assert_eq!(expected_state, end_state); } -pub struct TestStore(RefCell>); +pub struct TestStore(BTreeMap>); #[allow(unused)] impl StateStore for TestStore { - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result> { self.0 - .borrow() .get(event_id) - .cloned() + .map(Arc::clone) .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } @@ -256,7 +255,7 @@ fn to_pdu_event( content: JsonValue, auth_events: &[S], prev_events: &[S], -) -> StateEvent +) -> Arc where S: AsRef, { @@ -330,12 +329,12 @@ where "signatures": {}, }) }; - serde_json::from_value(json).unwrap() + Arc::new(serde_json::from_value(json).unwrap()) } // all graphs start with these input events #[allow(non_snake_case)] -fn INITIAL_EVENTS() -> BTreeMap { +fn INITIAL_EVENTS() -> BTreeMap> { // this is always called so we can init the logger here let _ = LOGGER.call_once(|| { tracer::fmt() @@ -432,7 +431,7 @@ fn INITIAL_EDGES() -> Vec { // all graphs start with these input events #[allow(non_snake_case)] -fn BAN_STATE_SET() -> BTreeMap { +fn BAN_STATE_SET() -> BTreeMap> { vec![ to_pdu_event( "PA", @@ -499,7 +498,7 @@ fn ban_with_auth_chains() { #[test] fn base_with_auth_chains() { - let store = TestStore(RefCell::new(INITIAL_EVENTS())); + let store = TestStore(INITIAL_EVENTS()); let resolved: BTreeMap<_, EventId> = match StateResolution::resolve(&room_id(), &RoomVersionId::Version2, &[], None, &store) { @@ -537,7 +536,7 @@ fn ban_with_auth_chains2() { let mut inner = init.clone(); inner.extend(ban); - let store = TestStore(RefCell::new(inner.clone())); + let store = TestStore(inner.clone()); let state_set_a = [ inner.get(&event_id("CREATE")).unwrap(), @@ -607,7 +606,7 @@ fn ban_with_auth_chains2() { // all graphs start with these input events #[allow(non_snake_case)] -fn JOIN_RULE() -> BTreeMap { +fn JOIN_RULE() -> BTreeMap> { vec![ to_pdu_event( "JR", diff --git a/tests/state_res.rs b/tests/state_res.rs index 7dbe30f8..6d72617e 100644 --- a/tests/state_res.rs +++ b/tests/state_res.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH}; +use std::{collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH, sync::Arc}; use maplit::btreemap; use ruma::{ @@ -78,7 +78,7 @@ fn to_pdu_event( content: JsonValue, auth_events: &[S], prev_events: &[S], -) -> StateEvent +) -> Arc where S: AsRef, { @@ -152,7 +152,7 @@ where "signatures": {}, }) }; - serde_json::from_value(json).unwrap() + Arc::new(serde_json::from_value(json).unwrap()) } fn to_init_pdu_event( @@ -161,7 +161,7 @@ fn to_init_pdu_event( ev_type: EventType, state_key: Option<&str>, content: JsonValue, -) -> StateEvent { +) -> Arc { let ts = unsafe { let ts = SERVER_TIMESTAMP; // increment the "origin_server_ts" value @@ -206,12 +206,12 @@ fn to_init_pdu_event( "signatures": {}, }) }; - serde_json::from_value(json).unwrap() + Arc::new(serde_json::from_value(json).unwrap()) } // all graphs start with these input events #[allow(non_snake_case)] -fn INITIAL_EVENTS() -> BTreeMap { +fn INITIAL_EVENTS() -> BTreeMap> { vec![ to_init_pdu_event( "CREATE", @@ -280,7 +280,7 @@ fn INITIAL_EDGES() -> Vec { .collect::>() } -fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: Vec) { +fn do_check(events: &[Arc], edges: Vec>, expected_state_ids: Vec) { // to activate logging use `RUST_LOG=debug cargo t one_test_only` let _ = LOGGER.call_once(|| { tracer::fmt() @@ -288,13 +288,13 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: .init() }); - let store = TestStore(RefCell::new( + let mut store = TestStore( INITIAL_EVENTS() .values() .chain(events) .map(|ev| (ev.event_id(), ev.clone())) .collect(), - )); + ); // This will be lexi_topo_sorted for resolution let mut graph = BTreeMap::new(); @@ -329,7 +329,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: } // event_id -> StateEvent - let mut event_map: BTreeMap = BTreeMap::new(); + let mut event_map: BTreeMap> = BTreeMap::new(); // event_id -> StateMap let mut state_at_event: BTreeMap> = BTreeMap::new(); @@ -420,7 +420,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: // TODO // TODO we need to convert the `StateResolution::resolve` to use the event_map // because the user of this crate cannot update their DB's state. - *store.0.borrow_mut().get_mut(&ev_id).unwrap() = event.clone(); + store.0.insert(ev_id.clone(), Arc::clone(&event)); state_at_event.insert(node, state_after); event_map.insert(event_id.clone(), event); @@ -702,7 +702,7 @@ fn topic_setting() { #[test] fn test_event_map_none() { - let store = TestStore(RefCell::new(btreemap! {})); + let mut store = TestStore(btreemap! {}); // build up the DAG let (state_at_bob, state_at_charlie, expected) = store.set_up(); @@ -748,21 +748,20 @@ fn test_lexicographical_sort() { // /// The test state store. -pub struct TestStore(RefCell>); +pub struct TestStore(BTreeMap>); #[allow(unused)] impl StateStore for TestStore { - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result> { self.0 - .borrow() .get(event_id) - .cloned() + .map(Arc::clone) .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } impl TestStore { - pub fn set_up(&self) -> (StateMap, StateMap, StateMap) { + pub fn set_up(&mut self) -> (StateMap, StateMap, StateMap) { // to activate logging use `RUST_LOG=debug cargo t one_test_only` let _ = LOGGER.call_once(|| { tracer::fmt() @@ -780,8 +779,7 @@ impl TestStore { ); let cre = create_event.event_id(); self.0 - .borrow_mut() - .insert(cre.clone(), create_event.clone()); + .insert(cre.clone(), Arc::clone(&create_event)); let alice_mem = to_pdu_event( "IMA", @@ -793,8 +791,7 @@ impl TestStore { &[cre.clone()], ); self.0 - .borrow_mut() - .insert(alice_mem.event_id(), alice_mem.clone()); + .insert(alice_mem.event_id(), Arc::clone(&alice_mem)); let join_rules = to_pdu_event( "IJR", @@ -806,7 +803,6 @@ impl TestStore { &[alice_mem.event_id()], ); self.0 - .borrow_mut() .insert(join_rules.event_id(), join_rules.clone()); // Bob and Charlie join at the same time, so there is a fork @@ -821,7 +817,6 @@ impl TestStore { &[join_rules.event_id()], ); self.0 - .borrow_mut() .insert(bob_mem.event_id(), bob_mem.clone()); let charlie_mem = to_pdu_event( @@ -834,7 +829,6 @@ impl TestStore { &[join_rules.event_id()], ); self.0 - .borrow_mut() .insert(charlie_mem.event_id(), charlie_mem.clone()); let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem]