This commit is contained in:
Timo Kösters 2020-09-11 14:36:14 +02:00 committed by Devin Ragotzy
parent acd829336e
commit f587b88a60
10 changed files with 139 additions and 148 deletions

View File

@ -22,15 +22,17 @@ maplit = "1.0.2"
thiserror = "1.0.20"
tracing-subscriber = "0.2.11"
# [dependencies.ruma]
# path = "../__forks__/ruma/ruma"
# features = ["client-api", "federation-api", "appservice-api"]
[dependencies.ruma]
git = "https://github.com/ruma/ruma"
rev = "aff914050eb297bd82b8aafb12158c88a9e480e1"
path = "../ruma/ruma"
features = ["client-api", "federation-api", "appservice-api"]
#[dependencies.ruma]
#git = "https://github.com/ruma/ruma"
#rev = "aff914050eb297bd82b8aafb12158c88a9e480e1"
#features = ["client-api", "federation-api", "appservice-api"]
[features]
unstable-pre-spec = ["ruma/unstable-pre-spec"]
[dev-dependencies]
criterion = "0.3.3"
@ -39,3 +41,4 @@ rand = "0.7.3"
[[bench]]
name = "state_res_bench"
harness = false

View File

@ -3,7 +3,7 @@
// `cargo bench unknown option --save-baseline`.
// To pass args to criterion, use this form
// `cargo bench --bench <name of the bench> -- --save-baseline <name>`.
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH};
use std::{collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH, sync::Arc};
use criterion::{criterion_group, criterion_main, Criterion};
use maplit::btreemap;
@ -42,7 +42,7 @@ fn lexico_topo_sort(c: &mut Criterion) {
fn resolution_shallow_auth_chain(c: &mut Criterion) {
c.bench_function("resolve state of 5 events one fork", |b| {
let store = TestStore(RefCell::new(btreemap! {}));
let mut store = TestStore(btreemap! {});
// build up the DAG
let (state_at_bob, state_at_charlie, _) = store.set_up();
@ -69,7 +69,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) {
let mut inner = init;
inner.extend(ban);
let store = TestStore(RefCell::new(inner.clone()));
let store = TestStore(inner.clone());
let state_set_a = [
inner.get(&event_id("CREATE")).unwrap(),
@ -126,22 +126,21 @@ criterion_main!(benches);
// IMPLEMENTATION DETAILS AHEAD
//
/////////////////////////////////////////////////////////////////////*/
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
pub struct TestStore(BTreeMap<EventId, Arc<StateEvent>>);
#[allow(unused)]
impl StateStore for TestStore {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> {
self.0
.borrow()
.get(event_id)
.cloned()
.map(Arc::clone)
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
}
}
impl TestStore {
pub fn set_up(&self) -> (StateMap<EventId>, StateMap<EventId>, StateMap<EventId>) {
let create_event = to_pdu_event::<EventId>(
pub fn set_up(&mut self) -> (StateMap<EventId>, StateMap<EventId>, StateMap<EventId>) {
let create_event = Arc::new(to_pdu_event::<EventId>(
"CREATE",
alice(),
EventType::RoomCreate,
@ -149,11 +148,10 @@ impl TestStore {
json!({ "creator": alice() }),
&[],
&[],
);
));
let cre = create_event.event_id();
self.0
.borrow_mut()
.insert(cre.clone(), create_event.clone());
.insert(cre.clone(), Arc::clone(&create_event));
let alice_mem = to_pdu_event(
"IMA",
@ -165,8 +163,7 @@ impl TestStore {
&[cre.clone()],
);
self.0
.borrow_mut()
.insert(alice_mem.event_id(), alice_mem.clone());
.insert(alice_mem.event_id(), Arc::clone(&alice_mem));
let join_rules = to_pdu_event(
"IJR",
@ -178,8 +175,7 @@ impl TestStore {
&[alice_mem.event_id()],
);
self.0
.borrow_mut()
.insert(join_rules.event_id(), join_rules.clone());
.insert(join_rules.event_id(), Arc::clone(&join_rules));
// Bob and Charlie join at the same time, so there is a fork
// this will be represented in the state_sets when we resolve
@ -193,8 +189,7 @@ impl TestStore {
&[join_rules.event_id()],
);
self.0
.borrow_mut()
.insert(bob_mem.event_id(), bob_mem.clone());
.insert(bob_mem.event_id(), Arc::clone(&bob_mem));
let charlie_mem = to_pdu_event(
"IMC",
@ -206,8 +201,7 @@ impl TestStore {
&[join_rules.event_id()],
);
self.0
.borrow_mut()
.insert(charlie_mem.event_id(), charlie_mem.clone());
.insert(charlie_mem.event_id(), Arc::clone(&charlie_mem));
let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem]
.iter()
@ -288,7 +282,7 @@ fn to_pdu_event<S>(
content: JsonValue,
auth_events: &[S],
prev_events: &[S],
) -> StateEvent
) -> Arc<StateEvent>
where
S: AsRef<str>,
{
@ -362,12 +356,12 @@ where
"signatures": {},
})
};
serde_json::from_value(json).unwrap()
Arc::new(serde_json::from_value(json).unwrap())
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
vec![
to_pdu_event::<EventId>(
"CREATE",
@ -449,7 +443,7 @@ fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
// all graphs start with these input events
#[allow(non_snake_case)]
fn BAN_STATE_SET() -> BTreeMap<EventId, StateEvent> {
fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<StateEvent>> {
vec![
to_pdu_event(
"PA",

View File

@ -1,4 +1,4 @@
use std::{collections::BTreeMap, convert::TryFrom};
use std::{collections::BTreeMap, convert::TryFrom, sync::Arc};
use maplit::btreeset;
use ruma::{
@ -72,10 +72,10 @@ pub fn auth_types_for_event(
/// * then there are checks for specific event types
pub fn auth_check(
room_version: &RoomVersionId,
incoming_event: &StateEvent,
prev_event: Option<&StateEvent>,
auth_events: StateMap<StateEvent>,
current_third_party_invite: Option<&StateEvent>,
incoming_event: &Arc<StateEvent>,
prev_event: Option<Arc<StateEvent>>,
auth_events: StateMap<Arc<StateEvent>>,
current_third_party_invite: Option<Arc<StateEvent>>,
) -> Result<bool> {
tracing::info!("auth_check beginning for {}", incoming_event.kind());
@ -206,8 +206,8 @@ pub fn auth_check(
if !valid_membership_change(
incoming_event.to_requester(),
current_third_party_invite,
prev_event,
current_third_party_invite,
&auth_events,
)? {
return Ok(false);
@ -241,7 +241,7 @@ pub fn auth_check(
// If the event type's required power level is greater than the sender's power level, reject
// If the event has a state_key that starts with an @ and does not match the sender, reject.
if !can_send_event(incoming_event, &auth_events)? {
if !can_send_event(&incoming_event, &auth_events)? {
tracing::warn!("user cannot send event");
return Ok(false);
}
@ -250,7 +250,7 @@ pub fn auth_check(
tracing::info!("starting m.room.power_levels check");
if let Some(required_pwr_lvl) =
check_power_levels(room_version, incoming_event, &auth_events)
check_power_levels(room_version, &incoming_event, &auth_events)
{
if !required_pwr_lvl {
tracing::warn!("power level was not allowed");
@ -284,9 +284,9 @@ pub fn auth_check(
/// the current State.
pub fn valid_membership_change(
user: Requester<'_>,
prev_event: Option<&StateEvent>,
current_third_party_invite: Option<&StateEvent>,
auth_events: &StateMap<StateEvent>,
prev_event: Option<Arc<StateEvent>>,
current_third_party_invite: Option<Arc<StateEvent>>,
auth_events: &StateMap<Arc<StateEvent>>,
) -> Result<bool> {
let state_key = if let Some(s) = user.state_key.as_ref() {
s
@ -377,7 +377,7 @@ pub fn valid_membership_change(
.join_rule;
}
if let Some(prev) = prev_event {
if let Some(prev) = dbg!(prev_event) {
if prev.kind() == EventType::RoomCreate && prev.prev_event_ids().is_empty() {
return Ok(true);
}
@ -440,7 +440,7 @@ pub fn valid_membership_change(
/// Is the event's sender in the room that they sent the event to.
pub fn check_event_sender_in_room(
sender: &UserId,
auth_events: &StateMap<StateEvent>,
auth_events: &StateMap<Arc<StateEvent>>,
) -> Option<bool> {
let mem = auth_events.get(&(EventType::RoomMember, Some(sender.to_string())))?;
Some(
@ -453,7 +453,7 @@ pub fn check_event_sender_in_room(
/// Is the user allowed to send a specific event based on the rooms power levels. Does the event
/// have the correct userId as it's state_key if it's not the "" state_key.
pub fn can_send_event(event: &StateEvent, auth_events: &StateMap<StateEvent>) -> Result<bool> {
pub fn can_send_event(event: &Arc<StateEvent>, auth_events: &StateMap<Arc<StateEvent>>) -> Result<bool> {
let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into())));
let event_type_power_level = get_send_level(event.kind(), event.state_key(), ple);
@ -481,8 +481,8 @@ pub fn can_send_event(event: &StateEvent, auth_events: &StateMap<StateEvent>) ->
/// Confirm that the event sender has the required power levels.
pub fn check_power_levels(
_: &RoomVersionId,
power_event: &StateEvent,
auth_events: &StateMap<StateEvent>,
power_event: &Arc<StateEvent>,
auth_events: &StateMap<Arc<StateEvent>>,
) -> Option<bool> {
let key = (power_event.kind(), power_event.state_key());
let current_state = if let Some(current_state) = auth_events.get(&key) {
@ -627,8 +627,8 @@ fn get_deserialize_levels(
/// Does the event redacting come from a user with enough power to redact the given event.
pub fn check_redaction(
room_version: &RoomVersionId,
redaction_event: &StateEvent,
auth_events: &StateMap<StateEvent>,
redaction_event: &Arc<StateEvent>,
auth_events: &StateMap<Arc<StateEvent>>,
) -> Result<bool> {
let user_level = get_user_power_level(redaction_event.sender(), auth_events);
let redact_level = get_named_level(auth_events, "redact", 50);
@ -662,7 +662,7 @@ pub fn check_redaction(
/// Check that the member event matches `state`.
///
/// This function returns false instead of failing when deserialization fails.
pub fn check_membership(member_event: Option<&StateEvent>, state: MembershipState) -> bool {
pub fn check_membership(member_event: Option<Arc<StateEvent>>, state: MembershipState) -> bool {
if let Some(event) = member_event {
if let Ok(content) =
serde_json::from_value::<room::member::MemberEventContent>(event.content().clone())
@ -677,7 +677,7 @@ pub fn check_membership(member_event: Option<&StateEvent>, state: MembershipStat
}
/// Can this room federate based on its m.room.create event.
pub fn can_federate(auth_events: &StateMap<StateEvent>) -> bool {
pub fn can_federate(auth_events: &StateMap<Arc<StateEvent>>) -> bool {
let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into())));
if let Some(ev) = creation_event {
if let Some(fed) = ev.content().get("m.federate") {
@ -692,7 +692,7 @@ pub fn can_federate(auth_events: &StateMap<StateEvent>) -> bool {
/// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content.
/// or return `default` if no power level event is found or zero if no field matches `name`.
pub fn get_named_level(auth_events: &StateMap<StateEvent>, name: &str, default: i64) -> i64 {
pub fn get_named_level(auth_events: &StateMap<Arc<StateEvent>>, name: &str, default: i64) -> i64 {
let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, Some("".into())));
if let Some(pl) = power_level_event {
// TODO do this the right way and deserialize
@ -708,7 +708,7 @@ pub fn get_named_level(auth_events: &StateMap<StateEvent>, name: &str, default:
/// Helper function to fetch a users default power level from a "m.room.power_level" event's `users`
/// object.
pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap<StateEvent>) -> i64 {
pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap<Arc<StateEvent>>) -> i64 {
if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))) {
if let Ok(content) = pl.deserialize_content::<room::power_levels::PowerLevelsEventContent>()
{
@ -744,7 +744,7 @@ pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap<StateEvent>
pub fn get_send_level(
e_type: EventType,
state_key: Option<String>,
power_lvl: Option<&StateEvent>,
power_lvl: Option<&Arc<StateEvent>>,
) -> i64 {
tracing::debug!("{:?} {:?}", e_type, state_key);
if let Some(ple) = power_lvl {
@ -772,7 +772,7 @@ pub fn get_send_level(
}
/// Check user can send invite.
pub fn can_send_invite(event: &Requester<'_>, auth_events: &StateMap<StateEvent>) -> Result<bool> {
pub fn can_send_invite(event: &Requester<'_>, auth_events: &StateMap<Arc<StateEvent>>) -> Result<bool> {
let user_level = get_user_power_level(event.sender, auth_events);
let key = (EventType::RoomPowerLevels, Some("".into()));
let invite_level = auth_events
@ -794,7 +794,7 @@ pub fn can_send_invite(event: &Requester<'_>, auth_events: &StateMap<StateEvent>
pub fn verify_third_party_invite(
event: &Requester<'_>,
tp_id: &member::ThirdPartyInvite,
current_third_party_invite: Option<&StateEvent>,
current_third_party_invite: Option<Arc<StateEvent>>,
) -> bool {
// 1. check for user being banned happens before this is called
// checking for mxid and token keys is done by ruma when deserializing

View File

@ -2,7 +2,7 @@ use std::{
cmp::Reverse,
collections::{BTreeMap, BTreeSet, BinaryHeap},
time::SystemTime,
};
sync::Arc};
use maplit::btreeset;
use ruma::{
@ -58,7 +58,7 @@ impl StateResolution {
room_id: &RoomId,
room_version: &RoomVersionId,
state_sets: &[StateMap<EventId>],
event_map: Option<EventMap<StateEvent>>,
event_map: Option<EventMap<Arc<StateEvent>>>,
store: &dyn StateStore,
) -> Result<StateMap<EventId>> {
tracing::info!("State resolution starting");
@ -292,7 +292,7 @@ impl StateResolution {
pub fn reverse_topological_power_sort(
room_id: &RoomId,
events_to_sort: &[EventId],
event_map: &mut EventMap<StateEvent>,
event_map: &mut EventMap<Arc<StateEvent>>,
store: &dyn StateStore,
auth_diff: &[EventId],
) -> Vec<EventId> {
@ -418,7 +418,7 @@ impl StateResolution {
fn get_power_level_for_sender(
room_id: &RoomId,
event_id: &EventId,
event_map: &mut EventMap<StateEvent>,
event_map: &mut EventMap<Arc<StateEvent>>,
store: &dyn StateStore,
) -> i64 {
tracing::info!("fetch event ({}) senders power level", event_id.to_string());
@ -476,7 +476,7 @@ impl StateResolution {
room_version: &RoomVersionId,
events_to_check: &[EventId],
unconflicted_state: &StateMap<EventId>,
event_map: &mut EventMap<StateEvent>,
event_map: &mut EventMap<Arc<StateEvent>>,
store: &dyn StateStore,
) -> Result<StateMap<EventId>> {
tracing::info!("starting iterative auth check");
@ -526,8 +526,8 @@ impl StateResolution {
tracing::debug!("event to check {:?}", event.event_id().as_str());
let most_recent_prev_event = event
.prev_event_ids()
let most_recent_prev_event = dbg!(event
.prev_event_ids())
.iter()
.filter_map(|id| StateResolution::get_or_load_event(room_id, id, event_map, store))
.next_back();
@ -545,9 +545,9 @@ impl StateResolution {
if event_auth::auth_check(
room_version,
&event,
most_recent_prev_event.as_ref(),
most_recent_prev_event,
auth_events,
current_third_party.as_ref(),
current_third_party,
)? {
// add event to resolved state map
resolved_state.insert((event.kind(), event.state_key()), event_id.clone());
@ -579,7 +579,7 @@ impl StateResolution {
room_id: &RoomId,
to_sort: &[EventId],
resolved_power_level: Option<&EventId>,
event_map: &mut EventMap<StateEvent>,
event_map: &mut EventMap<Arc<StateEvent>>,
store: &dyn StateStore,
) -> Vec<EventId> {
tracing::debug!("mainline sort of events");
@ -658,14 +658,13 @@ impl StateResolution {
sort_event_ids
}
// TODO make `event` not clone every loop
/// Get the mainline depth from the `mainline_map` or finds a power_level event
/// that has an associated mainline depth.
fn get_mainline_depth(
room_id: &RoomId,
mut event: Option<StateEvent>,
mut event: Option<Arc<StateEvent>>,
mainline_map: &EventMap<usize>,
event_map: &mut EventMap<StateEvent>,
event_map: &mut EventMap<Arc<StateEvent>>,
store: &dyn StateStore,
) -> usize {
while let Some(sort_ev) = event {
@ -681,7 +680,7 @@ impl StateResolution {
let aev =
StateResolution::get_or_load_event(room_id, &aid, event_map, store).unwrap();
if aev.is_type_and_key(EventType::RoomPowerLevels, "") {
event = Some(aev.clone());
event = Some(aev);
break;
}
}
@ -694,7 +693,7 @@ impl StateResolution {
room_id: &RoomId,
graph: &mut BTreeMap<EventId, Vec<EventId>>,
event_id: &EventId,
event_map: &mut EventMap<StateEvent>,
event_map: &mut EventMap<Arc<StateEvent>>,
store: &dyn StateStore,
auth_diff: &[EventId],
) {
@ -730,21 +729,22 @@ impl StateResolution {
fn get_or_load_event(
room_id: &RoomId,
ev_id: &EventId,
event_map: &mut EventMap<StateEvent>,
event_map: &mut EventMap<Arc<StateEvent>>,
store: &dyn StateStore,
) -> Option<StateEvent> {
// TODO can we cut down on the clones?
if !event_map.contains_key(ev_id) {
let event = store.get_event(room_id, ev_id).ok()?;
event_map.insert(ev_id.clone(), event.clone());
Some(event)
} else {
event_map.get(ev_id).cloned()
) -> Option<Arc<StateEvent>> {
if let Some(e) = event_map.get(ev_id) {
return Some(Arc::clone(e));
}
if let Ok(e) = store.get_event(room_id, ev_id) {
return Some(Arc::clone(event_map.entry(ev_id.clone()).or_insert(e)))
}
None
}
}
pub fn is_power_event(event_id: &EventId, event_map: &EventMap<StateEvent>) -> bool {
pub fn is_power_event(event_id: &EventId, event_map: &EventMap<Arc<StateEvent>>) -> bool {
match event_map.get(event_id) {
Some(state) => state.is_power_event(),
_ => false,

View File

@ -202,6 +202,8 @@ impl StateEvent {
},
}
}
#[cfg(not(feature = "unstable-pre-spec"))]
pub fn origin(&self) -> String {
match self {
Self::Full(ev) => match ev {

View File

@ -1,4 +1,4 @@
use std::collections::BTreeSet;
use std::{collections::BTreeSet, sync::Arc};
use ruma::identifiers::{EventId, RoomId};
@ -6,10 +6,10 @@ use crate::{Result, StateEvent};
pub trait StateStore {
/// Return a single event based on the EventId.
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent>;
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>>;
/// Returns the events that correspond to the `event_ids` sorted in the same order.
fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<StateEvent>> {
fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<Arc<StateEvent>>> {
let mut events = vec![];
for id in event_ids {
events.push(self.get_event(room_id, id)?);

View File

@ -1,4 +1,4 @@
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom};
use std::{collections::BTreeMap, convert::TryFrom, sync::Arc};
use ruma::{
events::{
@ -71,15 +71,14 @@ fn member_content_join() -> JsonValue {
.unwrap()
}
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
pub struct TestStore(BTreeMap<EventId, Arc<StateEvent>>);
#[allow(unused)]
impl StateStore for TestStore {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> {
self.0
.borrow()
.get(event_id)
.cloned()
.map(Arc::clone)
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
}
}
@ -92,7 +91,7 @@ fn to_pdu_event<S>(
content: JsonValue,
auth_events: &[S],
prev_events: &[S],
) -> StateEvent
) -> Arc<StateEvent>
where
S: AsRef<str>,
{
@ -166,12 +165,12 @@ where
"signatures": {},
})
};
serde_json::from_value(json).unwrap()
Arc::new(serde_json::from_value(json).unwrap())
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
// this is always called so we can init the logger here
let _ = LOGGER.call_once(|| {
tracer::fmt()
@ -246,11 +245,12 @@ fn test_ban_pass() {
let prev = events
.values()
.find(|ev| ev.event_id().as_str().contains("IMC"));
.find(|ev| ev.event_id().as_str().contains("IMC"))
.map(Arc::clone);
let auth_events = events
.values()
.map(|ev| ((ev.kind(), ev.state_key()), ev.clone()))
.map(|ev| ((ev.kind(), ev.state_key()), Arc::clone(ev)))
.collect::<StateMap<_>>();
let requester = Requester {
@ -270,11 +270,12 @@ fn test_ban_fail() {
let prev = events
.values()
.find(|ev| ev.event_id().as_str().contains("IMC"));
.find(|ev| ev.event_id().as_str().contains("IMC"))
.map(Arc::clone);
let auth_events = events
.values()
.map(|ev| ((ev.kind(), ev.state_key()), ev.clone()))
.map(|ev| ((ev.kind(), ev.state_key()), Arc::clone(ev)))
.collect::<StateMap<_>>();
let requester = Requester {

View File

@ -1,4 +1,4 @@
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom};
use std::{collections::BTreeMap, convert::TryFrom, sync::Arc};
use ruma::{
events::{
@ -53,15 +53,14 @@ fn member_content_join() -> JsonValue {
.unwrap()
}
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
pub struct TestStore(BTreeMap<EventId, Arc<StateEvent>>);
#[allow(unused)]
impl StateStore for TestStore {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> {
self.0
.borrow()
.get(event_id)
.cloned()
.map(Arc::clone)
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
}
}
@ -74,7 +73,7 @@ fn to_pdu_event<S>(
content: JsonValue,
auth_events: &[S],
prev_events: &[S],
) -> StateEvent
) -> Arc<StateEvent>
where
S: AsRef<str>,
{
@ -148,12 +147,12 @@ where
"signatures": {},
})
};
serde_json::from_value(json).unwrap()
Arc::new(serde_json::from_value(json).unwrap())
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
// this is always called so we can init the logger here
let _ = LOGGER.call_once(|| {
tracer::fmt()
@ -243,8 +242,7 @@ fn shuffle(list: &mut [EventId]) {
fn test_event_sort() {
let mut events = INITIAL_EVENTS();
let store = TestStore(RefCell::new(events.clone()));
let store = TestStore(events.clone());
let event_map = events
.values()

View File

@ -1,6 +1,6 @@
#![allow(clippy::or_fun_call, clippy::expect_fun_call)]
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, sync::Once, time::UNIX_EPOCH};
use std::{collections::BTreeMap, convert::TryFrom, sync::Once, time::UNIX_EPOCH, sync::Arc};
use ruma::{
events::{
@ -21,7 +21,7 @@ static LOGGER: Once = Once::new();
static mut SERVER_TIMESTAMP: i32 = 0;
fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids: Vec<EventId>) {
fn do_check(events: &[Arc<StateEvent>], edges: Vec<Vec<EventId>>, expected_state_ids: Vec<EventId>) {
// to activate logging use `RUST_LOG=debug cargo t`
let _ = LOGGER.call_once(|| {
tracer::fmt()
@ -29,13 +29,13 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
.init()
});
let store = TestStore(RefCell::new(
let mut store = TestStore(
INITIAL_EVENTS()
.values()
.chain(events)
.map(|ev| (ev.event_id(), ev.clone()))
.collect(),
));
);
// This will be lexi_topo_sorted for resolution
let mut graph = BTreeMap::new();
@ -64,7 +64,7 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
}
// event_id -> StateEvent
let mut event_map: BTreeMap<EventId, StateEvent> = BTreeMap::new();
let mut event_map: BTreeMap<EventId, Arc<StateEvent>> = BTreeMap::new();
// event_id -> StateMap<EventId>
let mut state_at_event: BTreeMap<EventId, StateMap<EventId>> = BTreeMap::new();
@ -152,10 +152,10 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
// we have to update our store, an actual user of this lib would
// be giving us state from a DB.
*store.0.borrow_mut().get_mut(&ev_id).unwrap() = event.clone();
store.0.insert(ev_id.clone(), event.clone());
state_at_event.insert(node, state_after);
event_map.insert(event_id.clone(), event);
event_map.insert(event_id.clone(), Arc::clone(store.0.get(&ev_id).unwrap()));
}
let mut expected_state = StateMap::new();
@ -186,15 +186,14 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
assert_eq!(expected_state, end_state);
}
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
pub struct TestStore(BTreeMap<EventId, Arc<StateEvent>>);
#[allow(unused)]
impl StateStore for TestStore {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> {
self.0
.borrow()
.get(event_id)
.cloned()
.map(Arc::clone)
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
}
}
@ -256,7 +255,7 @@ fn to_pdu_event<S>(
content: JsonValue,
auth_events: &[S],
prev_events: &[S],
) -> StateEvent
) -> Arc<StateEvent>
where
S: AsRef<str>,
{
@ -330,12 +329,12 @@ where
"signatures": {},
})
};
serde_json::from_value(json).unwrap()
Arc::new(serde_json::from_value(json).unwrap())
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
// this is always called so we can init the logger here
let _ = LOGGER.call_once(|| {
tracer::fmt()
@ -432,7 +431,7 @@ fn INITIAL_EDGES() -> Vec<EventId> {
// all graphs start with these input events
#[allow(non_snake_case)]
fn BAN_STATE_SET() -> BTreeMap<EventId, StateEvent> {
fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<StateEvent>> {
vec![
to_pdu_event(
"PA",
@ -499,7 +498,7 @@ fn ban_with_auth_chains() {
#[test]
fn base_with_auth_chains() {
let store = TestStore(RefCell::new(INITIAL_EVENTS()));
let store = TestStore(INITIAL_EVENTS());
let resolved: BTreeMap<_, EventId> =
match StateResolution::resolve(&room_id(), &RoomVersionId::Version2, &[], None, &store) {
@ -537,7 +536,7 @@ fn ban_with_auth_chains2() {
let mut inner = init.clone();
inner.extend(ban);
let store = TestStore(RefCell::new(inner.clone()));
let store = TestStore(inner.clone());
let state_set_a = [
inner.get(&event_id("CREATE")).unwrap(),
@ -607,7 +606,7 @@ fn ban_with_auth_chains2() {
// all graphs start with these input events
#[allow(non_snake_case)]
fn JOIN_RULE() -> BTreeMap<EventId, StateEvent> {
fn JOIN_RULE() -> BTreeMap<EventId, Arc<StateEvent>> {
vec![
to_pdu_event(
"JR",

View File

@ -1,4 +1,4 @@
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH};
use std::{collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH, sync::Arc};
use maplit::btreemap;
use ruma::{
@ -78,7 +78,7 @@ fn to_pdu_event<S>(
content: JsonValue,
auth_events: &[S],
prev_events: &[S],
) -> StateEvent
) -> Arc<StateEvent>
where
S: AsRef<str>,
{
@ -152,7 +152,7 @@ where
"signatures": {},
})
};
serde_json::from_value(json).unwrap()
Arc::new(serde_json::from_value(json).unwrap())
}
fn to_init_pdu_event(
@ -161,7 +161,7 @@ fn to_init_pdu_event(
ev_type: EventType,
state_key: Option<&str>,
content: JsonValue,
) -> StateEvent {
) -> Arc<StateEvent> {
let ts = unsafe {
let ts = SERVER_TIMESTAMP;
// increment the "origin_server_ts" value
@ -206,12 +206,12 @@ fn to_init_pdu_event(
"signatures": {},
})
};
serde_json::from_value(json).unwrap()
Arc::new(serde_json::from_value(json).unwrap())
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
vec![
to_init_pdu_event(
"CREATE",
@ -280,7 +280,7 @@ fn INITIAL_EDGES() -> Vec<EventId> {
.collect::<Vec<_>>()
}
fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids: Vec<EventId>) {
fn do_check(events: &[Arc<StateEvent>], edges: Vec<Vec<EventId>>, expected_state_ids: Vec<EventId>) {
// to activate logging use `RUST_LOG=debug cargo t one_test_only`
let _ = LOGGER.call_once(|| {
tracer::fmt()
@ -288,13 +288,13 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
.init()
});
let store = TestStore(RefCell::new(
let mut store = TestStore(
INITIAL_EVENTS()
.values()
.chain(events)
.map(|ev| (ev.event_id(), ev.clone()))
.collect(),
));
);
// This will be lexi_topo_sorted for resolution
let mut graph = BTreeMap::new();
@ -329,7 +329,7 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
}
// event_id -> StateEvent
let mut event_map: BTreeMap<EventId, StateEvent> = BTreeMap::new();
let mut event_map: BTreeMap<EventId, Arc<StateEvent>> = BTreeMap::new();
// event_id -> StateMap<EventId>
let mut state_at_event: BTreeMap<EventId, StateMap<EventId>> = BTreeMap::new();
@ -420,7 +420,7 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
// TODO
// TODO we need to convert the `StateResolution::resolve` to use the event_map
// because the user of this crate cannot update their DB's state.
*store.0.borrow_mut().get_mut(&ev_id).unwrap() = event.clone();
store.0.insert(ev_id.clone(), Arc::clone(&event));
state_at_event.insert(node, state_after);
event_map.insert(event_id.clone(), event);
@ -702,7 +702,7 @@ fn topic_setting() {
#[test]
fn test_event_map_none() {
let store = TestStore(RefCell::new(btreemap! {}));
let mut store = TestStore(btreemap! {});
// build up the DAG
let (state_at_bob, state_at_charlie, expected) = store.set_up();
@ -748,21 +748,20 @@ fn test_lexicographical_sort() {
//
/// The test state store.
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
pub struct TestStore(BTreeMap<EventId, Arc<StateEvent>>);
#[allow(unused)]
impl StateStore for TestStore {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> {
self.0
.borrow()
.get(event_id)
.cloned()
.map(Arc::clone)
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
}
}
impl TestStore {
pub fn set_up(&self) -> (StateMap<EventId>, StateMap<EventId>, StateMap<EventId>) {
pub fn set_up(&mut self) -> (StateMap<EventId>, StateMap<EventId>, StateMap<EventId>) {
// to activate logging use `RUST_LOG=debug cargo t one_test_only`
let _ = LOGGER.call_once(|| {
tracer::fmt()
@ -780,8 +779,7 @@ impl TestStore {
);
let cre = create_event.event_id();
self.0
.borrow_mut()
.insert(cre.clone(), create_event.clone());
.insert(cre.clone(), Arc::clone(&create_event));
let alice_mem = to_pdu_event(
"IMA",
@ -793,8 +791,7 @@ impl TestStore {
&[cre.clone()],
);
self.0
.borrow_mut()
.insert(alice_mem.event_id(), alice_mem.clone());
.insert(alice_mem.event_id(), Arc::clone(&alice_mem));
let join_rules = to_pdu_event(
"IJR",
@ -806,7 +803,6 @@ impl TestStore {
&[alice_mem.event_id()],
);
self.0
.borrow_mut()
.insert(join_rules.event_id(), join_rules.clone());
// Bob and Charlie join at the same time, so there is a fork
@ -821,7 +817,6 @@ impl TestStore {
&[join_rules.event_id()],
);
self.0
.borrow_mut()
.insert(bob_mem.event_id(), bob_mem.clone());
let charlie_mem = to_pdu_event(
@ -834,7 +829,6 @@ impl TestStore {
&[join_rules.event_id()],
);
self.0
.borrow_mut()
.insert(charlie_mem.event_id(), charlie_mem.clone());
let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem]