Fixes
This commit is contained in:
parent
acd829336e
commit
f587b88a60
17
Cargo.toml
17
Cargo.toml
@ -22,15 +22,17 @@ maplit = "1.0.2"
|
||||
thiserror = "1.0.20"
|
||||
tracing-subscriber = "0.2.11"
|
||||
|
||||
# [dependencies.ruma]
|
||||
# path = "../__forks__/ruma/ruma"
|
||||
# features = ["client-api", "federation-api", "appservice-api"]
|
||||
|
||||
[dependencies.ruma]
|
||||
git = "https://github.com/ruma/ruma"
|
||||
rev = "aff914050eb297bd82b8aafb12158c88a9e480e1"
|
||||
path = "../ruma/ruma"
|
||||
features = ["client-api", "federation-api", "appservice-api"]
|
||||
|
||||
#[dependencies.ruma]
|
||||
#git = "https://github.com/ruma/ruma"
|
||||
#rev = "aff914050eb297bd82b8aafb12158c88a9e480e1"
|
||||
#features = ["client-api", "federation-api", "appservice-api"]
|
||||
|
||||
[features]
|
||||
unstable-pre-spec = ["ruma/unstable-pre-spec"]
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3.3"
|
||||
@ -38,4 +40,5 @@ rand = "0.7.3"
|
||||
|
||||
[[bench]]
|
||||
name = "state_res_bench"
|
||||
harness = false
|
||||
harness = false
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
// `cargo bench unknown option --save-baseline`.
|
||||
// To pass args to criterion, use this form
|
||||
// `cargo bench --bench <name of the bench> -- --save-baseline <name>`.
|
||||
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH};
|
||||
use std::{collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH, sync::Arc};
|
||||
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use maplit::btreemap;
|
||||
@ -42,7 +42,7 @@ fn lexico_topo_sort(c: &mut Criterion) {
|
||||
|
||||
fn resolution_shallow_auth_chain(c: &mut Criterion) {
|
||||
c.bench_function("resolve state of 5 events one fork", |b| {
|
||||
let store = TestStore(RefCell::new(btreemap! {}));
|
||||
let mut store = TestStore(btreemap! {});
|
||||
|
||||
// build up the DAG
|
||||
let (state_at_bob, state_at_charlie, _) = store.set_up();
|
||||
@ -69,7 +69,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) {
|
||||
|
||||
let mut inner = init;
|
||||
inner.extend(ban);
|
||||
let store = TestStore(RefCell::new(inner.clone()));
|
||||
let store = TestStore(inner.clone());
|
||||
|
||||
let state_set_a = [
|
||||
inner.get(&event_id("CREATE")).unwrap(),
|
||||
@ -126,22 +126,21 @@ criterion_main!(benches);
|
||||
// IMPLEMENTATION DETAILS AHEAD
|
||||
//
|
||||
/////////////////////////////////////////////////////////////////////*/
|
||||
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
||||
pub struct TestStore(BTreeMap<EventId, Arc<StateEvent>>);
|
||||
|
||||
#[allow(unused)]
|
||||
impl StateStore for TestStore {
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> {
|
||||
self.0
|
||||
.borrow()
|
||||
.get(event_id)
|
||||
.cloned()
|
||||
.map(Arc::clone)
|
||||
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||
}
|
||||
}
|
||||
|
||||
impl TestStore {
|
||||
pub fn set_up(&self) -> (StateMap<EventId>, StateMap<EventId>, StateMap<EventId>) {
|
||||
let create_event = to_pdu_event::<EventId>(
|
||||
pub fn set_up(&mut self) -> (StateMap<EventId>, StateMap<EventId>, StateMap<EventId>) {
|
||||
let create_event = Arc::new(to_pdu_event::<EventId>(
|
||||
"CREATE",
|
||||
alice(),
|
||||
EventType::RoomCreate,
|
||||
@ -149,11 +148,10 @@ impl TestStore {
|
||||
json!({ "creator": alice() }),
|
||||
&[],
|
||||
&[],
|
||||
);
|
||||
));
|
||||
let cre = create_event.event_id();
|
||||
self.0
|
||||
.borrow_mut()
|
||||
.insert(cre.clone(), create_event.clone());
|
||||
.insert(cre.clone(), Arc::clone(&create_event));
|
||||
|
||||
let alice_mem = to_pdu_event(
|
||||
"IMA",
|
||||
@ -165,8 +163,7 @@ impl TestStore {
|
||||
&[cre.clone()],
|
||||
);
|
||||
self.0
|
||||
.borrow_mut()
|
||||
.insert(alice_mem.event_id(), alice_mem.clone());
|
||||
.insert(alice_mem.event_id(), Arc::clone(&alice_mem));
|
||||
|
||||
let join_rules = to_pdu_event(
|
||||
"IJR",
|
||||
@ -178,8 +175,7 @@ impl TestStore {
|
||||
&[alice_mem.event_id()],
|
||||
);
|
||||
self.0
|
||||
.borrow_mut()
|
||||
.insert(join_rules.event_id(), join_rules.clone());
|
||||
.insert(join_rules.event_id(), Arc::clone(&join_rules));
|
||||
|
||||
// Bob and Charlie join at the same time, so there is a fork
|
||||
// this will be represented in the state_sets when we resolve
|
||||
@ -193,8 +189,7 @@ impl TestStore {
|
||||
&[join_rules.event_id()],
|
||||
);
|
||||
self.0
|
||||
.borrow_mut()
|
||||
.insert(bob_mem.event_id(), bob_mem.clone());
|
||||
.insert(bob_mem.event_id(), Arc::clone(&bob_mem));
|
||||
|
||||
let charlie_mem = to_pdu_event(
|
||||
"IMC",
|
||||
@ -206,8 +201,7 @@ impl TestStore {
|
||||
&[join_rules.event_id()],
|
||||
);
|
||||
self.0
|
||||
.borrow_mut()
|
||||
.insert(charlie_mem.event_id(), charlie_mem.clone());
|
||||
.insert(charlie_mem.event_id(), Arc::clone(&charlie_mem));
|
||||
|
||||
let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem]
|
||||
.iter()
|
||||
@ -288,7 +282,7 @@ fn to_pdu_event<S>(
|
||||
content: JsonValue,
|
||||
auth_events: &[S],
|
||||
prev_events: &[S],
|
||||
) -> StateEvent
|
||||
) -> Arc<StateEvent>
|
||||
where
|
||||
S: AsRef<str>,
|
||||
{
|
||||
@ -362,12 +356,12 @@ where
|
||||
"signatures": {},
|
||||
})
|
||||
};
|
||||
serde_json::from_value(json).unwrap()
|
||||
Arc::new(serde_json::from_value(json).unwrap())
|
||||
}
|
||||
|
||||
// all graphs start with these input events
|
||||
#[allow(non_snake_case)]
|
||||
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
|
||||
fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
|
||||
vec![
|
||||
to_pdu_event::<EventId>(
|
||||
"CREATE",
|
||||
@ -449,7 +443,7 @@ fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
|
||||
|
||||
// all graphs start with these input events
|
||||
#[allow(non_snake_case)]
|
||||
fn BAN_STATE_SET() -> BTreeMap<EventId, StateEvent> {
|
||||
fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<StateEvent>> {
|
||||
vec![
|
||||
to_pdu_event(
|
||||
"PA",
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::{collections::BTreeMap, convert::TryFrom};
|
||||
use std::{collections::BTreeMap, convert::TryFrom, sync::Arc};
|
||||
|
||||
use maplit::btreeset;
|
||||
use ruma::{
|
||||
@ -72,10 +72,10 @@ pub fn auth_types_for_event(
|
||||
/// * then there are checks for specific event types
|
||||
pub fn auth_check(
|
||||
room_version: &RoomVersionId,
|
||||
incoming_event: &StateEvent,
|
||||
prev_event: Option<&StateEvent>,
|
||||
auth_events: StateMap<StateEvent>,
|
||||
current_third_party_invite: Option<&StateEvent>,
|
||||
incoming_event: &Arc<StateEvent>,
|
||||
prev_event: Option<Arc<StateEvent>>,
|
||||
auth_events: StateMap<Arc<StateEvent>>,
|
||||
current_third_party_invite: Option<Arc<StateEvent>>,
|
||||
) -> Result<bool> {
|
||||
tracing::info!("auth_check beginning for {}", incoming_event.kind());
|
||||
|
||||
@ -206,8 +206,8 @@ pub fn auth_check(
|
||||
|
||||
if !valid_membership_change(
|
||||
incoming_event.to_requester(),
|
||||
current_third_party_invite,
|
||||
prev_event,
|
||||
current_third_party_invite,
|
||||
&auth_events,
|
||||
)? {
|
||||
return Ok(false);
|
||||
@ -241,7 +241,7 @@ pub fn auth_check(
|
||||
|
||||
// If the event type's required power level is greater than the sender's power level, reject
|
||||
// If the event has a state_key that starts with an @ and does not match the sender, reject.
|
||||
if !can_send_event(incoming_event, &auth_events)? {
|
||||
if !can_send_event(&incoming_event, &auth_events)? {
|
||||
tracing::warn!("user cannot send event");
|
||||
return Ok(false);
|
||||
}
|
||||
@ -250,7 +250,7 @@ pub fn auth_check(
|
||||
tracing::info!("starting m.room.power_levels check");
|
||||
|
||||
if let Some(required_pwr_lvl) =
|
||||
check_power_levels(room_version, incoming_event, &auth_events)
|
||||
check_power_levels(room_version, &incoming_event, &auth_events)
|
||||
{
|
||||
if !required_pwr_lvl {
|
||||
tracing::warn!("power level was not allowed");
|
||||
@ -284,9 +284,9 @@ pub fn auth_check(
|
||||
/// the current State.
|
||||
pub fn valid_membership_change(
|
||||
user: Requester<'_>,
|
||||
prev_event: Option<&StateEvent>,
|
||||
current_third_party_invite: Option<&StateEvent>,
|
||||
auth_events: &StateMap<StateEvent>,
|
||||
prev_event: Option<Arc<StateEvent>>,
|
||||
current_third_party_invite: Option<Arc<StateEvent>>,
|
||||
auth_events: &StateMap<Arc<StateEvent>>,
|
||||
) -> Result<bool> {
|
||||
let state_key = if let Some(s) = user.state_key.as_ref() {
|
||||
s
|
||||
@ -377,7 +377,7 @@ pub fn valid_membership_change(
|
||||
.join_rule;
|
||||
}
|
||||
|
||||
if let Some(prev) = prev_event {
|
||||
if let Some(prev) = dbg!(prev_event) {
|
||||
if prev.kind() == EventType::RoomCreate && prev.prev_event_ids().is_empty() {
|
||||
return Ok(true);
|
||||
}
|
||||
@ -440,7 +440,7 @@ pub fn valid_membership_change(
|
||||
/// Is the event's sender in the room that they sent the event to.
|
||||
pub fn check_event_sender_in_room(
|
||||
sender: &UserId,
|
||||
auth_events: &StateMap<StateEvent>,
|
||||
auth_events: &StateMap<Arc<StateEvent>>,
|
||||
) -> Option<bool> {
|
||||
let mem = auth_events.get(&(EventType::RoomMember, Some(sender.to_string())))?;
|
||||
Some(
|
||||
@ -453,7 +453,7 @@ pub fn check_event_sender_in_room(
|
||||
|
||||
/// Is the user allowed to send a specific event based on the rooms power levels. Does the event
|
||||
/// have the correct userId as it's state_key if it's not the "" state_key.
|
||||
pub fn can_send_event(event: &StateEvent, auth_events: &StateMap<StateEvent>) -> Result<bool> {
|
||||
pub fn can_send_event(event: &Arc<StateEvent>, auth_events: &StateMap<Arc<StateEvent>>) -> Result<bool> {
|
||||
let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into())));
|
||||
|
||||
let event_type_power_level = get_send_level(event.kind(), event.state_key(), ple);
|
||||
@ -481,8 +481,8 @@ pub fn can_send_event(event: &StateEvent, auth_events: &StateMap<StateEvent>) ->
|
||||
/// Confirm that the event sender has the required power levels.
|
||||
pub fn check_power_levels(
|
||||
_: &RoomVersionId,
|
||||
power_event: &StateEvent,
|
||||
auth_events: &StateMap<StateEvent>,
|
||||
power_event: &Arc<StateEvent>,
|
||||
auth_events: &StateMap<Arc<StateEvent>>,
|
||||
) -> Option<bool> {
|
||||
let key = (power_event.kind(), power_event.state_key());
|
||||
let current_state = if let Some(current_state) = auth_events.get(&key) {
|
||||
@ -627,8 +627,8 @@ fn get_deserialize_levels(
|
||||
/// Does the event redacting come from a user with enough power to redact the given event.
|
||||
pub fn check_redaction(
|
||||
room_version: &RoomVersionId,
|
||||
redaction_event: &StateEvent,
|
||||
auth_events: &StateMap<StateEvent>,
|
||||
redaction_event: &Arc<StateEvent>,
|
||||
auth_events: &StateMap<Arc<StateEvent>>,
|
||||
) -> Result<bool> {
|
||||
let user_level = get_user_power_level(redaction_event.sender(), auth_events);
|
||||
let redact_level = get_named_level(auth_events, "redact", 50);
|
||||
@ -662,7 +662,7 @@ pub fn check_redaction(
|
||||
/// Check that the member event matches `state`.
|
||||
///
|
||||
/// This function returns false instead of failing when deserialization fails.
|
||||
pub fn check_membership(member_event: Option<&StateEvent>, state: MembershipState) -> bool {
|
||||
pub fn check_membership(member_event: Option<Arc<StateEvent>>, state: MembershipState) -> bool {
|
||||
if let Some(event) = member_event {
|
||||
if let Ok(content) =
|
||||
serde_json::from_value::<room::member::MemberEventContent>(event.content().clone())
|
||||
@ -677,7 +677,7 @@ pub fn check_membership(member_event: Option<&StateEvent>, state: MembershipStat
|
||||
}
|
||||
|
||||
/// Can this room federate based on its m.room.create event.
|
||||
pub fn can_federate(auth_events: &StateMap<StateEvent>) -> bool {
|
||||
pub fn can_federate(auth_events: &StateMap<Arc<StateEvent>>) -> bool {
|
||||
let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into())));
|
||||
if let Some(ev) = creation_event {
|
||||
if let Some(fed) = ev.content().get("m.federate") {
|
||||
@ -692,7 +692,7 @@ pub fn can_federate(auth_events: &StateMap<StateEvent>) -> bool {
|
||||
|
||||
/// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content.
|
||||
/// or return `default` if no power level event is found or zero if no field matches `name`.
|
||||
pub fn get_named_level(auth_events: &StateMap<StateEvent>, name: &str, default: i64) -> i64 {
|
||||
pub fn get_named_level(auth_events: &StateMap<Arc<StateEvent>>, name: &str, default: i64) -> i64 {
|
||||
let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, Some("".into())));
|
||||
if let Some(pl) = power_level_event {
|
||||
// TODO do this the right way and deserialize
|
||||
@ -708,7 +708,7 @@ pub fn get_named_level(auth_events: &StateMap<StateEvent>, name: &str, default:
|
||||
|
||||
/// Helper function to fetch a users default power level from a "m.room.power_level" event's `users`
|
||||
/// object.
|
||||
pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap<StateEvent>) -> i64 {
|
||||
pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap<Arc<StateEvent>>) -> i64 {
|
||||
if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))) {
|
||||
if let Ok(content) = pl.deserialize_content::<room::power_levels::PowerLevelsEventContent>()
|
||||
{
|
||||
@ -744,7 +744,7 @@ pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap<StateEvent>
|
||||
pub fn get_send_level(
|
||||
e_type: EventType,
|
||||
state_key: Option<String>,
|
||||
power_lvl: Option<&StateEvent>,
|
||||
power_lvl: Option<&Arc<StateEvent>>,
|
||||
) -> i64 {
|
||||
tracing::debug!("{:?} {:?}", e_type, state_key);
|
||||
if let Some(ple) = power_lvl {
|
||||
@ -772,7 +772,7 @@ pub fn get_send_level(
|
||||
}
|
||||
|
||||
/// Check user can send invite.
|
||||
pub fn can_send_invite(event: &Requester<'_>, auth_events: &StateMap<StateEvent>) -> Result<bool> {
|
||||
pub fn can_send_invite(event: &Requester<'_>, auth_events: &StateMap<Arc<StateEvent>>) -> Result<bool> {
|
||||
let user_level = get_user_power_level(event.sender, auth_events);
|
||||
let key = (EventType::RoomPowerLevels, Some("".into()));
|
||||
let invite_level = auth_events
|
||||
@ -794,7 +794,7 @@ pub fn can_send_invite(event: &Requester<'_>, auth_events: &StateMap<StateEvent>
|
||||
pub fn verify_third_party_invite(
|
||||
event: &Requester<'_>,
|
||||
tp_id: &member::ThirdPartyInvite,
|
||||
current_third_party_invite: Option<&StateEvent>,
|
||||
current_third_party_invite: Option<Arc<StateEvent>>,
|
||||
) -> bool {
|
||||
// 1. check for user being banned happens before this is called
|
||||
// checking for mxid and token keys is done by ruma when deserializing
|
||||
|
50
src/lib.rs
50
src/lib.rs
@ -2,7 +2,7 @@ use std::{
|
||||
cmp::Reverse,
|
||||
collections::{BTreeMap, BTreeSet, BinaryHeap},
|
||||
time::SystemTime,
|
||||
};
|
||||
sync::Arc};
|
||||
|
||||
use maplit::btreeset;
|
||||
use ruma::{
|
||||
@ -58,7 +58,7 @@ impl StateResolution {
|
||||
room_id: &RoomId,
|
||||
room_version: &RoomVersionId,
|
||||
state_sets: &[StateMap<EventId>],
|
||||
event_map: Option<EventMap<StateEvent>>,
|
||||
event_map: Option<EventMap<Arc<StateEvent>>>,
|
||||
store: &dyn StateStore,
|
||||
) -> Result<StateMap<EventId>> {
|
||||
tracing::info!("State resolution starting");
|
||||
@ -292,7 +292,7 @@ impl StateResolution {
|
||||
pub fn reverse_topological_power_sort(
|
||||
room_id: &RoomId,
|
||||
events_to_sort: &[EventId],
|
||||
event_map: &mut EventMap<StateEvent>,
|
||||
event_map: &mut EventMap<Arc<StateEvent>>,
|
||||
store: &dyn StateStore,
|
||||
auth_diff: &[EventId],
|
||||
) -> Vec<EventId> {
|
||||
@ -418,7 +418,7 @@ impl StateResolution {
|
||||
fn get_power_level_for_sender(
|
||||
room_id: &RoomId,
|
||||
event_id: &EventId,
|
||||
event_map: &mut EventMap<StateEvent>,
|
||||
event_map: &mut EventMap<Arc<StateEvent>>,
|
||||
store: &dyn StateStore,
|
||||
) -> i64 {
|
||||
tracing::info!("fetch event ({}) senders power level", event_id.to_string());
|
||||
@ -476,7 +476,7 @@ impl StateResolution {
|
||||
room_version: &RoomVersionId,
|
||||
events_to_check: &[EventId],
|
||||
unconflicted_state: &StateMap<EventId>,
|
||||
event_map: &mut EventMap<StateEvent>,
|
||||
event_map: &mut EventMap<Arc<StateEvent>>,
|
||||
store: &dyn StateStore,
|
||||
) -> Result<StateMap<EventId>> {
|
||||
tracing::info!("starting iterative auth check");
|
||||
@ -526,8 +526,8 @@ impl StateResolution {
|
||||
|
||||
tracing::debug!("event to check {:?}", event.event_id().as_str());
|
||||
|
||||
let most_recent_prev_event = event
|
||||
.prev_event_ids()
|
||||
let most_recent_prev_event = dbg!(event
|
||||
.prev_event_ids())
|
||||
.iter()
|
||||
.filter_map(|id| StateResolution::get_or_load_event(room_id, id, event_map, store))
|
||||
.next_back();
|
||||
@ -545,9 +545,9 @@ impl StateResolution {
|
||||
if event_auth::auth_check(
|
||||
room_version,
|
||||
&event,
|
||||
most_recent_prev_event.as_ref(),
|
||||
most_recent_prev_event,
|
||||
auth_events,
|
||||
current_third_party.as_ref(),
|
||||
current_third_party,
|
||||
)? {
|
||||
// add event to resolved state map
|
||||
resolved_state.insert((event.kind(), event.state_key()), event_id.clone());
|
||||
@ -579,7 +579,7 @@ impl StateResolution {
|
||||
room_id: &RoomId,
|
||||
to_sort: &[EventId],
|
||||
resolved_power_level: Option<&EventId>,
|
||||
event_map: &mut EventMap<StateEvent>,
|
||||
event_map: &mut EventMap<Arc<StateEvent>>,
|
||||
store: &dyn StateStore,
|
||||
) -> Vec<EventId> {
|
||||
tracing::debug!("mainline sort of events");
|
||||
@ -658,14 +658,13 @@ impl StateResolution {
|
||||
sort_event_ids
|
||||
}
|
||||
|
||||
// TODO make `event` not clone every loop
|
||||
/// Get the mainline depth from the `mainline_map` or finds a power_level event
|
||||
/// that has an associated mainline depth.
|
||||
fn get_mainline_depth(
|
||||
room_id: &RoomId,
|
||||
mut event: Option<StateEvent>,
|
||||
mut event: Option<Arc<StateEvent>>,
|
||||
mainline_map: &EventMap<usize>,
|
||||
event_map: &mut EventMap<StateEvent>,
|
||||
event_map: &mut EventMap<Arc<StateEvent>>,
|
||||
store: &dyn StateStore,
|
||||
) -> usize {
|
||||
while let Some(sort_ev) = event {
|
||||
@ -681,7 +680,7 @@ impl StateResolution {
|
||||
let aev =
|
||||
StateResolution::get_or_load_event(room_id, &aid, event_map, store).unwrap();
|
||||
if aev.is_type_and_key(EventType::RoomPowerLevels, "") {
|
||||
event = Some(aev.clone());
|
||||
event = Some(aev);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -694,7 +693,7 @@ impl StateResolution {
|
||||
room_id: &RoomId,
|
||||
graph: &mut BTreeMap<EventId, Vec<EventId>>,
|
||||
event_id: &EventId,
|
||||
event_map: &mut EventMap<StateEvent>,
|
||||
event_map: &mut EventMap<Arc<StateEvent>>,
|
||||
store: &dyn StateStore,
|
||||
auth_diff: &[EventId],
|
||||
) {
|
||||
@ -730,21 +729,22 @@ impl StateResolution {
|
||||
fn get_or_load_event(
|
||||
room_id: &RoomId,
|
||||
ev_id: &EventId,
|
||||
event_map: &mut EventMap<StateEvent>,
|
||||
event_map: &mut EventMap<Arc<StateEvent>>,
|
||||
store: &dyn StateStore,
|
||||
) -> Option<StateEvent> {
|
||||
// TODO can we cut down on the clones?
|
||||
if !event_map.contains_key(ev_id) {
|
||||
let event = store.get_event(room_id, ev_id).ok()?;
|
||||
event_map.insert(ev_id.clone(), event.clone());
|
||||
Some(event)
|
||||
} else {
|
||||
event_map.get(ev_id).cloned()
|
||||
) -> Option<Arc<StateEvent>> {
|
||||
if let Some(e) = event_map.get(ev_id) {
|
||||
return Some(Arc::clone(e));
|
||||
}
|
||||
|
||||
if let Ok(e) = store.get_event(room_id, ev_id) {
|
||||
return Some(Arc::clone(event_map.entry(ev_id.clone()).or_insert(e)))
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_power_event(event_id: &EventId, event_map: &EventMap<StateEvent>) -> bool {
|
||||
pub fn is_power_event(event_id: &EventId, event_map: &EventMap<Arc<StateEvent>>) -> bool {
|
||||
match event_map.get(event_id) {
|
||||
Some(state) => state.is_power_event(),
|
||||
_ => false,
|
||||
|
@ -202,6 +202,8 @@ impl StateEvent {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "unstable-pre-spec"))]
|
||||
pub fn origin(&self) -> String {
|
||||
match self {
|
||||
Self::Full(ev) => match ev {
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::collections::BTreeSet;
|
||||
use std::{collections::BTreeSet, sync::Arc};
|
||||
|
||||
use ruma::identifiers::{EventId, RoomId};
|
||||
|
||||
@ -6,10 +6,10 @@ use crate::{Result, StateEvent};
|
||||
|
||||
pub trait StateStore {
|
||||
/// Return a single event based on the EventId.
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent>;
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>>;
|
||||
|
||||
/// Returns the events that correspond to the `event_ids` sorted in the same order.
|
||||
fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<StateEvent>> {
|
||||
fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<Arc<StateEvent>>> {
|
||||
let mut events = vec![];
|
||||
for id in event_ids {
|
||||
events.push(self.get_event(room_id, id)?);
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom};
|
||||
use std::{collections::BTreeMap, convert::TryFrom, sync::Arc};
|
||||
|
||||
use ruma::{
|
||||
events::{
|
||||
@ -71,15 +71,14 @@ fn member_content_join() -> JsonValue {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
||||
pub struct TestStore(BTreeMap<EventId, Arc<StateEvent>>);
|
||||
|
||||
#[allow(unused)]
|
||||
impl StateStore for TestStore {
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> {
|
||||
self.0
|
||||
.borrow()
|
||||
.get(event_id)
|
||||
.cloned()
|
||||
.map(Arc::clone)
|
||||
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||
}
|
||||
}
|
||||
@ -92,7 +91,7 @@ fn to_pdu_event<S>(
|
||||
content: JsonValue,
|
||||
auth_events: &[S],
|
||||
prev_events: &[S],
|
||||
) -> StateEvent
|
||||
) -> Arc<StateEvent>
|
||||
where
|
||||
S: AsRef<str>,
|
||||
{
|
||||
@ -166,12 +165,12 @@ where
|
||||
"signatures": {},
|
||||
})
|
||||
};
|
||||
serde_json::from_value(json).unwrap()
|
||||
Arc::new(serde_json::from_value(json).unwrap())
|
||||
}
|
||||
|
||||
// all graphs start with these input events
|
||||
#[allow(non_snake_case)]
|
||||
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
|
||||
fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
|
||||
// this is always called so we can init the logger here
|
||||
let _ = LOGGER.call_once(|| {
|
||||
tracer::fmt()
|
||||
@ -246,11 +245,12 @@ fn test_ban_pass() {
|
||||
|
||||
let prev = events
|
||||
.values()
|
||||
.find(|ev| ev.event_id().as_str().contains("IMC"));
|
||||
.find(|ev| ev.event_id().as_str().contains("IMC"))
|
||||
.map(Arc::clone);
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| ((ev.kind(), ev.state_key()), ev.clone()))
|
||||
.map(|ev| ((ev.kind(), ev.state_key()), Arc::clone(ev)))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = Requester {
|
||||
@ -270,11 +270,12 @@ fn test_ban_fail() {
|
||||
|
||||
let prev = events
|
||||
.values()
|
||||
.find(|ev| ev.event_id().as_str().contains("IMC"));
|
||||
.find(|ev| ev.event_id().as_str().contains("IMC"))
|
||||
.map(Arc::clone);
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| ((ev.kind(), ev.state_key()), ev.clone()))
|
||||
.map(|ev| ((ev.kind(), ev.state_key()), Arc::clone(ev)))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = Requester {
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom};
|
||||
use std::{collections::BTreeMap, convert::TryFrom, sync::Arc};
|
||||
|
||||
use ruma::{
|
||||
events::{
|
||||
@ -53,15 +53,14 @@ fn member_content_join() -> JsonValue {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
||||
pub struct TestStore(BTreeMap<EventId, Arc<StateEvent>>);
|
||||
|
||||
#[allow(unused)]
|
||||
impl StateStore for TestStore {
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> {
|
||||
self.0
|
||||
.borrow()
|
||||
.get(event_id)
|
||||
.cloned()
|
||||
.map(Arc::clone)
|
||||
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||
}
|
||||
}
|
||||
@ -74,7 +73,7 @@ fn to_pdu_event<S>(
|
||||
content: JsonValue,
|
||||
auth_events: &[S],
|
||||
prev_events: &[S],
|
||||
) -> StateEvent
|
||||
) -> Arc<StateEvent>
|
||||
where
|
||||
S: AsRef<str>,
|
||||
{
|
||||
@ -148,12 +147,12 @@ where
|
||||
"signatures": {},
|
||||
})
|
||||
};
|
||||
serde_json::from_value(json).unwrap()
|
||||
Arc::new(serde_json::from_value(json).unwrap())
|
||||
}
|
||||
|
||||
// all graphs start with these input events
|
||||
#[allow(non_snake_case)]
|
||||
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
|
||||
fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
|
||||
// this is always called so we can init the logger here
|
||||
let _ = LOGGER.call_once(|| {
|
||||
tracer::fmt()
|
||||
@ -243,8 +242,7 @@ fn shuffle(list: &mut [EventId]) {
|
||||
|
||||
fn test_event_sort() {
|
||||
let mut events = INITIAL_EVENTS();
|
||||
|
||||
let store = TestStore(RefCell::new(events.clone()));
|
||||
let store = TestStore(events.clone());
|
||||
|
||||
let event_map = events
|
||||
.values()
|
||||
|
@ -1,6 +1,6 @@
|
||||
#![allow(clippy::or_fun_call, clippy::expect_fun_call)]
|
||||
|
||||
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, sync::Once, time::UNIX_EPOCH};
|
||||
use std::{collections::BTreeMap, convert::TryFrom, sync::Once, time::UNIX_EPOCH, sync::Arc};
|
||||
|
||||
use ruma::{
|
||||
events::{
|
||||
@ -21,7 +21,7 @@ static LOGGER: Once = Once::new();
|
||||
|
||||
static mut SERVER_TIMESTAMP: i32 = 0;
|
||||
|
||||
fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids: Vec<EventId>) {
|
||||
fn do_check(events: &[Arc<StateEvent>], edges: Vec<Vec<EventId>>, expected_state_ids: Vec<EventId>) {
|
||||
// to activate logging use `RUST_LOG=debug cargo t`
|
||||
let _ = LOGGER.call_once(|| {
|
||||
tracer::fmt()
|
||||
@ -29,13 +29,13 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
|
||||
.init()
|
||||
});
|
||||
|
||||
let store = TestStore(RefCell::new(
|
||||
let mut store = TestStore(
|
||||
INITIAL_EVENTS()
|
||||
.values()
|
||||
.chain(events)
|
||||
.map(|ev| (ev.event_id(), ev.clone()))
|
||||
.collect(),
|
||||
));
|
||||
);
|
||||
|
||||
// This will be lexi_topo_sorted for resolution
|
||||
let mut graph = BTreeMap::new();
|
||||
@ -64,7 +64,7 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
|
||||
}
|
||||
|
||||
// event_id -> StateEvent
|
||||
let mut event_map: BTreeMap<EventId, StateEvent> = BTreeMap::new();
|
||||
let mut event_map: BTreeMap<EventId, Arc<StateEvent>> = BTreeMap::new();
|
||||
// event_id -> StateMap<EventId>
|
||||
let mut state_at_event: BTreeMap<EventId, StateMap<EventId>> = BTreeMap::new();
|
||||
|
||||
@ -152,10 +152,10 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
|
||||
|
||||
// we have to update our store, an actual user of this lib would
|
||||
// be giving us state from a DB.
|
||||
*store.0.borrow_mut().get_mut(&ev_id).unwrap() = event.clone();
|
||||
store.0.insert(ev_id.clone(), event.clone());
|
||||
|
||||
state_at_event.insert(node, state_after);
|
||||
event_map.insert(event_id.clone(), event);
|
||||
event_map.insert(event_id.clone(), Arc::clone(store.0.get(&ev_id).unwrap()));
|
||||
}
|
||||
|
||||
let mut expected_state = StateMap::new();
|
||||
@ -186,15 +186,14 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
|
||||
|
||||
assert_eq!(expected_state, end_state);
|
||||
}
|
||||
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
||||
pub struct TestStore(BTreeMap<EventId, Arc<StateEvent>>);
|
||||
|
||||
#[allow(unused)]
|
||||
impl StateStore for TestStore {
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> {
|
||||
self.0
|
||||
.borrow()
|
||||
.get(event_id)
|
||||
.cloned()
|
||||
.map(Arc::clone)
|
||||
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||
}
|
||||
}
|
||||
@ -256,7 +255,7 @@ fn to_pdu_event<S>(
|
||||
content: JsonValue,
|
||||
auth_events: &[S],
|
||||
prev_events: &[S],
|
||||
) -> StateEvent
|
||||
) -> Arc<StateEvent>
|
||||
where
|
||||
S: AsRef<str>,
|
||||
{
|
||||
@ -330,12 +329,12 @@ where
|
||||
"signatures": {},
|
||||
})
|
||||
};
|
||||
serde_json::from_value(json).unwrap()
|
||||
Arc::new(serde_json::from_value(json).unwrap())
|
||||
}
|
||||
|
||||
// all graphs start with these input events
|
||||
#[allow(non_snake_case)]
|
||||
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
|
||||
fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
|
||||
// this is always called so we can init the logger here
|
||||
let _ = LOGGER.call_once(|| {
|
||||
tracer::fmt()
|
||||
@ -432,7 +431,7 @@ fn INITIAL_EDGES() -> Vec<EventId> {
|
||||
|
||||
// all graphs start with these input events
|
||||
#[allow(non_snake_case)]
|
||||
fn BAN_STATE_SET() -> BTreeMap<EventId, StateEvent> {
|
||||
fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<StateEvent>> {
|
||||
vec![
|
||||
to_pdu_event(
|
||||
"PA",
|
||||
@ -499,7 +498,7 @@ fn ban_with_auth_chains() {
|
||||
|
||||
#[test]
|
||||
fn base_with_auth_chains() {
|
||||
let store = TestStore(RefCell::new(INITIAL_EVENTS()));
|
||||
let store = TestStore(INITIAL_EVENTS());
|
||||
|
||||
let resolved: BTreeMap<_, EventId> =
|
||||
match StateResolution::resolve(&room_id(), &RoomVersionId::Version2, &[], None, &store) {
|
||||
@ -537,7 +536,7 @@ fn ban_with_auth_chains2() {
|
||||
|
||||
let mut inner = init.clone();
|
||||
inner.extend(ban);
|
||||
let store = TestStore(RefCell::new(inner.clone()));
|
||||
let store = TestStore(inner.clone());
|
||||
|
||||
let state_set_a = [
|
||||
inner.get(&event_id("CREATE")).unwrap(),
|
||||
@ -607,7 +606,7 @@ fn ban_with_auth_chains2() {
|
||||
|
||||
// all graphs start with these input events
|
||||
#[allow(non_snake_case)]
|
||||
fn JOIN_RULE() -> BTreeMap<EventId, StateEvent> {
|
||||
fn JOIN_RULE() -> BTreeMap<EventId, Arc<StateEvent>> {
|
||||
vec![
|
||||
to_pdu_event(
|
||||
"JR",
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH};
|
||||
use std::{collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH, sync::Arc};
|
||||
|
||||
use maplit::btreemap;
|
||||
use ruma::{
|
||||
@ -78,7 +78,7 @@ fn to_pdu_event<S>(
|
||||
content: JsonValue,
|
||||
auth_events: &[S],
|
||||
prev_events: &[S],
|
||||
) -> StateEvent
|
||||
) -> Arc<StateEvent>
|
||||
where
|
||||
S: AsRef<str>,
|
||||
{
|
||||
@ -152,7 +152,7 @@ where
|
||||
"signatures": {},
|
||||
})
|
||||
};
|
||||
serde_json::from_value(json).unwrap()
|
||||
Arc::new(serde_json::from_value(json).unwrap())
|
||||
}
|
||||
|
||||
fn to_init_pdu_event(
|
||||
@ -161,7 +161,7 @@ fn to_init_pdu_event(
|
||||
ev_type: EventType,
|
||||
state_key: Option<&str>,
|
||||
content: JsonValue,
|
||||
) -> StateEvent {
|
||||
) -> Arc<StateEvent> {
|
||||
let ts = unsafe {
|
||||
let ts = SERVER_TIMESTAMP;
|
||||
// increment the "origin_server_ts" value
|
||||
@ -206,12 +206,12 @@ fn to_init_pdu_event(
|
||||
"signatures": {},
|
||||
})
|
||||
};
|
||||
serde_json::from_value(json).unwrap()
|
||||
Arc::new(serde_json::from_value(json).unwrap())
|
||||
}
|
||||
|
||||
// all graphs start with these input events
|
||||
#[allow(non_snake_case)]
|
||||
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
|
||||
fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
|
||||
vec![
|
||||
to_init_pdu_event(
|
||||
"CREATE",
|
||||
@ -280,7 +280,7 @@ fn INITIAL_EDGES() -> Vec<EventId> {
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids: Vec<EventId>) {
|
||||
fn do_check(events: &[Arc<StateEvent>], edges: Vec<Vec<EventId>>, expected_state_ids: Vec<EventId>) {
|
||||
// to activate logging use `RUST_LOG=debug cargo t one_test_only`
|
||||
let _ = LOGGER.call_once(|| {
|
||||
tracer::fmt()
|
||||
@ -288,13 +288,13 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
|
||||
.init()
|
||||
});
|
||||
|
||||
let store = TestStore(RefCell::new(
|
||||
let mut store = TestStore(
|
||||
INITIAL_EVENTS()
|
||||
.values()
|
||||
.chain(events)
|
||||
.map(|ev| (ev.event_id(), ev.clone()))
|
||||
.collect(),
|
||||
));
|
||||
);
|
||||
|
||||
// This will be lexi_topo_sorted for resolution
|
||||
let mut graph = BTreeMap::new();
|
||||
@ -329,7 +329,7 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
|
||||
}
|
||||
|
||||
// event_id -> StateEvent
|
||||
let mut event_map: BTreeMap<EventId, StateEvent> = BTreeMap::new();
|
||||
let mut event_map: BTreeMap<EventId, Arc<StateEvent>> = BTreeMap::new();
|
||||
// event_id -> StateMap<EventId>
|
||||
let mut state_at_event: BTreeMap<EventId, StateMap<EventId>> = BTreeMap::new();
|
||||
|
||||
@ -420,7 +420,7 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
|
||||
// TODO
|
||||
// TODO we need to convert the `StateResolution::resolve` to use the event_map
|
||||
// because the user of this crate cannot update their DB's state.
|
||||
*store.0.borrow_mut().get_mut(&ev_id).unwrap() = event.clone();
|
||||
store.0.insert(ev_id.clone(), Arc::clone(&event));
|
||||
|
||||
state_at_event.insert(node, state_after);
|
||||
event_map.insert(event_id.clone(), event);
|
||||
@ -702,7 +702,7 @@ fn topic_setting() {
|
||||
|
||||
#[test]
|
||||
fn test_event_map_none() {
|
||||
let store = TestStore(RefCell::new(btreemap! {}));
|
||||
let mut store = TestStore(btreemap! {});
|
||||
|
||||
// build up the DAG
|
||||
let (state_at_bob, state_at_charlie, expected) = store.set_up();
|
||||
@ -748,21 +748,20 @@ fn test_lexicographical_sort() {
|
||||
//
|
||||
|
||||
/// The test state store.
|
||||
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
||||
pub struct TestStore(BTreeMap<EventId, Arc<StateEvent>>);
|
||||
|
||||
#[allow(unused)]
|
||||
impl StateStore for TestStore {
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> {
|
||||
self.0
|
||||
.borrow()
|
||||
.get(event_id)
|
||||
.cloned()
|
||||
.map(Arc::clone)
|
||||
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||
}
|
||||
}
|
||||
|
||||
impl TestStore {
|
||||
pub fn set_up(&self) -> (StateMap<EventId>, StateMap<EventId>, StateMap<EventId>) {
|
||||
pub fn set_up(&mut self) -> (StateMap<EventId>, StateMap<EventId>, StateMap<EventId>) {
|
||||
// to activate logging use `RUST_LOG=debug cargo t one_test_only`
|
||||
let _ = LOGGER.call_once(|| {
|
||||
tracer::fmt()
|
||||
@ -780,8 +779,7 @@ impl TestStore {
|
||||
);
|
||||
let cre = create_event.event_id();
|
||||
self.0
|
||||
.borrow_mut()
|
||||
.insert(cre.clone(), create_event.clone());
|
||||
.insert(cre.clone(), Arc::clone(&create_event));
|
||||
|
||||
let alice_mem = to_pdu_event(
|
||||
"IMA",
|
||||
@ -793,8 +791,7 @@ impl TestStore {
|
||||
&[cre.clone()],
|
||||
);
|
||||
self.0
|
||||
.borrow_mut()
|
||||
.insert(alice_mem.event_id(), alice_mem.clone());
|
||||
.insert(alice_mem.event_id(), Arc::clone(&alice_mem));
|
||||
|
||||
let join_rules = to_pdu_event(
|
||||
"IJR",
|
||||
@ -806,7 +803,6 @@ impl TestStore {
|
||||
&[alice_mem.event_id()],
|
||||
);
|
||||
self.0
|
||||
.borrow_mut()
|
||||
.insert(join_rules.event_id(), join_rules.clone());
|
||||
|
||||
// Bob and Charlie join at the same time, so there is a fork
|
||||
@ -821,7 +817,6 @@ impl TestStore {
|
||||
&[join_rules.event_id()],
|
||||
);
|
||||
self.0
|
||||
.borrow_mut()
|
||||
.insert(bob_mem.event_id(), bob_mem.clone());
|
||||
|
||||
let charlie_mem = to_pdu_event(
|
||||
@ -834,7 +829,6 @@ impl TestStore {
|
||||
&[join_rules.event_id()],
|
||||
);
|
||||
self.0
|
||||
.borrow_mut()
|
||||
.insert(charlie_mem.event_id(), charlie_mem.clone());
|
||||
|
||||
let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem]
|
||||
|
Loading…
x
Reference in New Issue
Block a user