Convert state-res to use possible ruma::ServerPdu

This commit is contained in:
Devin Ragotzy 2020-12-28 17:46:41 -05:00
parent 5299679c21
commit 05a4dd1bf0
9 changed files with 141 additions and 143 deletions

View File

@ -26,16 +26,6 @@ branch = "server-pdu"
# rev = "45d01011554f9d07739e9a5edf5498d8ac16f273" # rev = "45d01011554f9d07739e9a5edf5498d8ac16f273"
features = ["client-api", "federation-api", "appservice-api", "unstable-pre-spec", "unstable-synapse-quirks"] features = ["client-api", "federation-api", "appservice-api", "unstable-pre-spec", "unstable-synapse-quirks"]
#[dependencies.ruma]
#path = "../ruma/ruma"
#features = ["client-api", "federation-api", "appservice-api"]
# [dependencies.ruma]
# git = "https://github.com/timokoesters/ruma"
# branch = "timo-fed-fixes"
# #rev = "aff914050eb297bd82b8aafb12158c88a9e480e1"
# features = ["client-api", "federation-api", "appservice-api"]
[features] [features]
default = ["unstable-pre-spec"] default = ["unstable-pre-spec"]
gen-eventid = [] gen-eventid = []

View File

@ -10,6 +10,7 @@ use criterion::{criterion_group, criterion_main, Criterion};
use maplit::btreemap; use maplit::btreemap;
use ruma::{ use ruma::{
events::{ events::{
pdu::ServerPdu,
room::{ room::{
join_rules::JoinRule, join_rules::JoinRule,
member::{MemberEventContent, MembershipState}, member::{MemberEventContent, MembershipState},
@ -19,7 +20,7 @@ use ruma::{
identifiers::{EventId, RoomId, RoomVersionId, UserId}, identifiers::{EventId, RoomId, RoomVersionId, UserId},
}; };
use serde_json::{json, Value as JsonValue}; use serde_json::{json, Value as JsonValue};
use state_res::{Error, Result, StateEvent, StateMap, StateResolution, StateStore}; use state_res::{Error, Result, StateMap, StateResolution, StateStore};
static mut SERVER_TIMESTAMP: i32 = 0; static mut SERVER_TIMESTAMP: i32 = 0;
@ -81,7 +82,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) {
inner.get(&event_id("PA")).unwrap(), inner.get(&event_id("PA")).unwrap(),
] ]
.iter() .iter()
.map(|ev| ((ev.kind(), ev.state_key()), ev.event_id())) .map(|ev| ((ev.kind.clone(), ev.state_key.clone()), ev.event_id.clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let state_set_b = [ let state_set_b = [
@ -94,7 +95,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) {
inner.get(&event_id("PA")).unwrap(), inner.get(&event_id("PA")).unwrap(),
] ]
.iter() .iter()
.map(|ev| ((ev.kind(), ev.state_key()), ev.event_id())) .map(|ev| ((ev.kind.clone(), ev.state_key.clone()), ev.event_id.clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
b.iter(|| { b.iter(|| {
@ -114,7 +115,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) {
criterion_group!( criterion_group!(
benches, benches,
lexico_topo_sort, // lexico_topo_sort,
resolution_shallow_auth_chain, resolution_shallow_auth_chain,
resolve_deeper_event_set resolve_deeper_event_set
); );
@ -126,11 +127,11 @@ criterion_main!(benches);
// IMPLEMENTATION DETAILS AHEAD // IMPLEMENTATION DETAILS AHEAD
// //
/////////////////////////////////////////////////////////////////////*/ /////////////////////////////////////////////////////////////////////*/
pub struct TestStore(BTreeMap<EventId, Arc<StateEvent>>); pub struct TestStore(BTreeMap<EventId, Arc<ServerPdu>>);
#[allow(unused)] #[allow(unused)]
impl StateStore for TestStore { impl StateStore for TestStore {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> { fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<ServerPdu>> {
self.0 self.0
.get(event_id) .get(event_id)
.map(Arc::clone) .map(Arc::clone)
@ -149,7 +150,7 @@ impl TestStore {
&[], &[],
&[], &[],
)); ));
let cre = create_event.event_id(); let cre = create_event.event_id.clone();
self.0.insert(cre.clone(), Arc::clone(&create_event)); self.0.insert(cre.clone(), Arc::clone(&create_event));
let alice_mem = to_pdu_event( let alice_mem = to_pdu_event(
@ -161,7 +162,8 @@ impl TestStore {
&[cre.clone()], &[cre.clone()],
&[cre.clone()], &[cre.clone()],
); );
self.0.insert(alice_mem.event_id(), Arc::clone(&alice_mem)); self.0
.insert(alice_mem.event_id.clone(), Arc::clone(&alice_mem));
let join_rules = to_pdu_event( let join_rules = to_pdu_event(
"IJR", "IJR",
@ -169,11 +171,11 @@ impl TestStore {
EventType::RoomJoinRules, EventType::RoomJoinRules,
Some(""), Some(""),
json!({ "join_rule": JoinRule::Public }), json!({ "join_rule": JoinRule::Public }),
&[cre.clone(), alice_mem.event_id()], &[cre.clone(), alice_mem.event_id.clone()],
&[alice_mem.event_id()], &[alice_mem.event_id.clone()],
); );
self.0 self.0
.insert(join_rules.event_id(), Arc::clone(&join_rules)); .insert(join_rules.event_id.clone(), Arc::clone(&join_rules));
// Bob and Charlie join at the same time, so there is a fork // Bob and Charlie join at the same time, so there is a fork
// this will be represented in the state_sets when we resolve // this will be represented in the state_sets when we resolve
@ -183,10 +185,11 @@ impl TestStore {
EventType::RoomMember, EventType::RoomMember,
Some(bob().to_string().as_str()), Some(bob().to_string().as_str()),
member_content_join(), member_content_join(),
&[cre.clone(), join_rules.event_id()], &[cre.clone(), join_rules.event_id.clone()],
&[join_rules.event_id()], &[join_rules.event_id.clone()],
); );
self.0.insert(bob_mem.event_id(), Arc::clone(&bob_mem)); self.0
.insert(bob_mem.event_id.clone(), Arc::clone(&bob_mem));
let charlie_mem = to_pdu_event( let charlie_mem = to_pdu_event(
"IMC", "IMC",
@ -194,20 +197,20 @@ impl TestStore {
EventType::RoomMember, EventType::RoomMember,
Some(charlie().to_string().as_str()), Some(charlie().to_string().as_str()),
member_content_join(), member_content_join(),
&[cre, join_rules.event_id()], &[cre, join_rules.event_id.clone()],
&[join_rules.event_id()], &[join_rules.event_id.clone()],
); );
self.0 self.0
.insert(charlie_mem.event_id(), Arc::clone(&charlie_mem)); .insert(charlie_mem.event_id.clone(), Arc::clone(&charlie_mem));
let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem]
.iter() .iter()
.map(|e| ((e.kind(), e.state_key()), e.event_id())) .map(|e| ((e.kind.clone(), e.state_key.clone()), e.event_id.clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem]
.iter() .iter()
.map(|e| ((e.kind(), e.state_key()), e.event_id())) .map(|e| ((e.kind.clone(), e.state_key.clone()), e.event_id.clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let expected = [ let expected = [
@ -218,7 +221,7 @@ impl TestStore {
&charlie_mem, &charlie_mem,
] ]
.iter() .iter()
.map(|e| ((e.kind(), e.state_key()), e.event_id())) .map(|e| ((e.kind.clone(), e.state_key.clone()), e.event_id.clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
(state_at_bob, state_at_charlie, expected) (state_at_bob, state_at_charlie, expected)
@ -279,7 +282,7 @@ fn to_pdu_event<S>(
content: JsonValue, content: JsonValue,
auth_events: &[S], auth_events: &[S],
prev_events: &[S], prev_events: &[S],
) -> Arc<StateEvent> ) -> Arc<ServerPdu>
where where
S: AsRef<str>, S: AsRef<str>,
{ {
@ -342,7 +345,7 @@ where
// all graphs start with these input events // all graphs start with these input events
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> { fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<ServerPdu>> {
vec![ vec![
to_pdu_event::<EventId>( to_pdu_event::<EventId>(
"CREATE", "CREATE",
@ -418,13 +421,13 @@ fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
), ),
] ]
.into_iter() .into_iter()
.map(|ev| (ev.event_id(), ev)) .map(|ev| (ev.event_id.clone(), ev))
.collect() .collect()
} }
// all graphs start with these input events // all graphs start with these input events
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<StateEvent>> { fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<ServerPdu>> {
vec![ vec![
to_pdu_event( to_pdu_event(
"PA", "PA",
@ -464,6 +467,6 @@ fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<StateEvent>> {
), ),
] ]
.into_iter() .into_iter()
.map(|ev| (ev.event_id(), ev)) .map(|ev| (ev.event_id.clone(), ev))
.collect() .collect()
} }

View File

@ -25,7 +25,7 @@ pub fn auth_types_for_event(
state_key: Option<String>, state_key: Option<String>,
content: serde_json::Value, content: serde_json::Value,
) -> Vec<(EventType, Option<String>)> { ) -> Vec<(EventType, Option<String>)> {
if kind == EventType::RoomCreate { if kind == &EventType::RoomCreate {
return vec![]; return vec![];
} }
@ -35,7 +35,7 @@ pub fn auth_types_for_event(
(EventType::RoomCreate, Some("".to_string())), (EventType::RoomCreate, Some("".to_string())),
]; ];
if kind == EventType::RoomMember { if kind == &EventType::RoomMember {
if let Ok(content) = serde_json::from_value::<room::member::MemberEventContent>(content) { if let Ok(content) = serde_json::from_value::<room::member::MemberEventContent>(content) {
if [MembershipState::Join, MembershipState::Invite].contains(&content.membership) { if [MembershipState::Join, MembershipState::Invite].contains(&content.membership) {
let key = (EventType::RoomJoinRules, Some("".into())); let key = (EventType::RoomJoinRules, Some("".into()));
@ -332,7 +332,7 @@ pub fn valid_membership_change(
}) })
}, },
|power_levels| { |power_levels| {
serde_json::from_value::<PowerLevelsEventContent>(power_levels.content) serde_json::from_value::<PowerLevelsEventContent>(power_levels.content.clone())
.map_err(Into::into) .map_err(Into::into)
}, },
)?; )?;
@ -448,7 +448,7 @@ pub fn check_event_sender_in_room(
pub fn can_send_event(event: &Arc<ServerPdu>, auth_events: &StateMap<Arc<ServerPdu>>) -> bool { pub fn can_send_event(event: &Arc<ServerPdu>, auth_events: &StateMap<Arc<ServerPdu>>) -> bool {
let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))); let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into())));
let event_type_power_level = get_send_level(event.kind, event.state_key, ple); let event_type_power_level = get_send_level(&event.kind, event.state_key.clone(), ple);
let user_level = get_user_power_level(&event.sender, auth_events); let user_level = get_user_power_level(&event.sender, auth_events);
tracing::debug!( tracing::debug!(
@ -462,7 +462,10 @@ pub fn can_send_event(event: &Arc<ServerPdu>, auth_events: &StateMap<Arc<ServerP
return false; return false;
} }
if event.state_key.map_or(false, |k| k.starts_with('@')) if event
.state_key
.as_ref()
.map_or(false, |k| k.starts_with('@'))
&& event.state_key.as_deref() != Some(event.sender.as_str()) && event.state_key.as_deref() != Some(event.sender.as_str())
{ {
return false; // permission required to post in this room return false; // permission required to post in this room
@ -477,7 +480,7 @@ pub fn check_power_levels(
power_event: &Arc<ServerPdu>, power_event: &Arc<ServerPdu>,
auth_events: &StateMap<Arc<ServerPdu>>, auth_events: &StateMap<Arc<ServerPdu>>,
) -> Option<bool> { ) -> Option<bool> {
let key = (power_event.kind, power_event.state_key); let key = (power_event.kind.clone(), power_event.state_key.clone());
let current_state = if let Some(current_state) = auth_events.get(&key) { let current_state = if let Some(current_state) = auth_events.get(&key) {
current_state current_state
} else { } else {
@ -644,7 +647,10 @@ pub fn check_redaction(
if let RoomVersionId::Version1 = room_version { if let RoomVersionId::Version1 = room_version {
// If the domain of the event_id of the event being redacted is the same as the domain of the event_id of the m.room.redaction, allow // If the domain of the event_id of the event being redacted is the same as the domain of the event_id of the m.room.redaction, allow
if redaction_event.event_id.server_name() if redaction_event.event_id.server_name()
== redaction_event.redacts.and_then(|id| id.server_name()) == redaction_event
.redacts
.as_ref()
.and_then(|id| id.server_name())
{ {
tracing::info!("redaction event allowed via room version 1 rules"); tracing::info!("redaction event allowed via room version 1 rules");
return Ok(true); return Ok(true);
@ -740,7 +746,7 @@ pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap<Arc<ServerP
/// Helper function to fetch the power level needed to send an event of type /// Helper function to fetch the power level needed to send an event of type
/// `e_type` based on the rooms "m.room.power_level" event. /// `e_type` based on the rooms "m.room.power_level" event.
pub fn get_send_level( pub fn get_send_level(
e_type: EventType, e_type: &EventType,
state_key: Option<String>, state_key: Option<String>,
power_lvl: Option<&Arc<ServerPdu>>, power_lvl: Option<&Arc<ServerPdu>>,
) -> i64 { ) -> i64 {
@ -806,7 +812,7 @@ pub fn verify_third_party_invite(
// If there is no m.room.third_party_invite event in the current room state // If there is no m.room.third_party_invite event in the current room state
// with state_key matching token, reject // with state_key matching token, reject
if let Some(current_tpid) = current_third_party_invite { if let Some(current_tpid) = current_third_party_invite {
if current_tpid.state_key != Some(tp_id.signed.token) { if current_tpid.state_key.as_ref() != Some(&tp_id.signed.token) {
return false; return false;
} }

View File

@ -64,9 +64,12 @@ impl StateResolution {
}; };
let mut auth_events = StateMap::new(); let mut auth_events = StateMap::new();
for key in for key in event_auth::auth_types_for_event(
event_auth::auth_types_for_event(ev.kind, &ev.sender, ev.state_key, ev.content.clone()) &ev.kind,
{ &ev.sender,
ev.state_key.clone(),
ev.content.clone(),
) {
if let Some(ev_id) = current_state.get(&key) { if let Some(ev_id) = current_state.get(&key) {
if let Some(event) = if let Some(event) =
StateResolution::get_or_load_event(room_id, ev_id, &mut event_map, store) StateResolution::get_or_load_event(room_id, ev_id, &mut event_map, store)
@ -170,7 +173,7 @@ impl StateResolution {
// get only the control events with a state_key: "" or ban/kick event (sender != state_key) // get only the control events with a state_key: "" or ban/kick event (sender != state_key)
let control_events = all_conflicted let control_events = all_conflicted
.iter() .iter()
.filter(|id| is_power_event(id, &event_map)) .filter(|id| is_power_event_id(id, &event_map))
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -378,12 +381,12 @@ impl StateResolution {
let ev = event_map.get(event_id).unwrap(); let ev = event_map.get(event_id).unwrap();
let pl = event_to_pl.get(event_id).unwrap(); let pl = event_to_pl.get(event_id).unwrap();
tracing::debug!("{:?}", (-*pl, ev.origin_server_ts, ev.event_id)); tracing::debug!("{:?}", (-*pl, ev.origin_server_ts, &ev.event_id));
// This return value is the key used for sorting events, // This return value is the key used for sorting events,
// events are then sorted by power level, time, // events are then sorted by power level, time,
// and lexically by event_id. // and lexically by event_id.
(-*pl, ev.origin_server_ts, ev.event_id) (-*pl, ev.origin_server_ts, ev.event_id.clone())
}) })
} }
@ -476,7 +479,7 @@ impl StateResolution {
// event.auth_event_ids does not include its own event id ? // event.auth_event_ids does not include its own event id ?
for aid in event for aid in event
.as_ref() .as_ref()
.map(|pdu| pdu.auth_events) .map(|pdu| pdu.auth_events.to_vec())
.unwrap_or_default() .unwrap_or_default()
{ {
if let Some(aev) = StateResolution::get_or_load_event(room_id, &aid, event_map, store) { if let Some(aev) = StateResolution::get_or_load_event(room_id, &aid, event_map, store) {
@ -560,9 +563,9 @@ impl StateResolution {
} }
for key in event_auth::auth_types_for_event( for key in event_auth::auth_types_for_event(
event.kind, &event.kind,
&event.sender, &event.sender,
event.state_key, event.state_key.clone(),
event.content.clone(), event.content.clone(),
) { ) {
if let Some(ev_id) = resolved_state.get(&key) { if let Some(ev_id) = resolved_state.get(&key) {
@ -801,18 +804,18 @@ impl StateResolution {
} }
} }
pub fn is_power_event(event_id: &EventId, event_map: &EventMap<Arc<ServerPdu>>) -> bool { pub fn is_power_event_id(event_id: &EventId, event_map: &EventMap<Arc<ServerPdu>>) -> bool {
match event_map.get(event_id) { match event_map.get(event_id) {
Some(state) => _is_power_event(state), Some(state) => is_power_event(state),
_ => false, _ => false,
} }
} }
pub fn is_type_and_key(&ev: &Arc<ServerPdu>, ev_type: EventType, state_key: &str) -> bool { pub fn is_type_and_key(ev: &Arc<ServerPdu>, ev_type: EventType, state_key: &str) -> bool {
ev.kind == ev_type && ev.state_key.as_deref() == Some(state_key) ev.kind == ev_type && ev.state_key.as_deref() == Some(state_key)
} }
fn _is_power_event(&event: &Arc<ServerPdu>) -> bool { pub fn is_power_event(event: &Arc<ServerPdu>) -> bool {
use ruma::events::room::member::{MemberEventContent, MembershipState}; use ruma::events::room::member::{MemberEventContent, MembershipState};
match event.kind { match event.kind {
EventType::RoomPowerLevels | EventType::RoomJoinRules | EventType::RoomCreate => { EventType::RoomPowerLevels | EventType::RoomJoinRules | EventType::RoomCreate => {
@ -838,7 +841,7 @@ fn _is_power_event(&event: &Arc<ServerPdu>) -> bool {
pub fn to_requester(event: &Arc<ServerPdu>) -> Requester<'_> { pub fn to_requester(event: &Arc<ServerPdu>) -> Requester<'_> {
Requester { Requester {
prev_event_ids: event.prev_events, prev_event_ids: event.prev_events.to_vec(),
room_id: &event.room_id, room_id: &event.room_id,
content: &event.content, content: &event.content,
state_key: event.state_key.clone(), state_key: event.state_key.clone(),

View File

@ -18,12 +18,12 @@ fn test_ban_pass() {
let prev = events let prev = events
.values() .values()
.find(|ev| ev.event_id().as_str().contains("IMC")) .find(|ev| ev.event_id.as_str().contains("IMC"))
.map(Arc::clone); .map(Arc::clone);
let auth_events = events let auth_events = events
.values() .values()
.map(|ev| ((ev.kind(), ev.state_key()), Arc::clone(ev))) .map(|ev| ((ev.kind.clone(), ev.state_key.clone()), Arc::clone(ev)))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let requester = Requester { let requester = Requester {
@ -43,12 +43,12 @@ fn test_ban_fail() {
let prev = events let prev = events
.values() .values()
.find(|ev| ev.event_id().as_str().contains("IMC")) .find(|ev| ev.event_id.as_str().contains("IMC"))
.map(Arc::clone); .map(Arc::clone);
let auth_events = events let auth_events = events
.values() .values()
.map(|ev| ((ev.kind(), ev.state_key()), Arc::clone(ev))) .map(|ev| ((ev.kind.clone(), ev.state_key.clone()), Arc::clone(ev)))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let requester = Requester { let requester = Requester {

View File

@ -4,7 +4,7 @@ use ruma::{
events::EventType, events::EventType,
identifiers::{EventId, RoomVersionId}, identifiers::{EventId, RoomVersionId},
}; };
use state_res::StateMap; use state_res::{is_power_event, StateMap};
mod utils; mod utils;
use utils::{room_id, TestStore, INITIAL_EVENTS}; use utils::{room_id, TestStore, INITIAL_EVENTS};
@ -25,15 +25,15 @@ fn test_event_sort() {
let event_map = events let event_map = events
.values() .values()
.map(|ev| ((ev.kind(), ev.state_key()), ev.clone())) .map(|ev| ((ev.kind.clone(), ev.state_key.clone()), ev.clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let auth_chain = &[] as &[_]; let auth_chain = &[] as &[_];
let power_events = event_map let power_events = event_map
.values() .values()
.filter(|pdu| pdu.is_power_event()) .filter(|pdu| is_power_event(&pdu))
.map(|pdu| pdu.event_id()) .map(|pdu| pdu.event_id.clone())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// This is a TODO in conduit // This is a TODO in conduit
@ -64,7 +64,7 @@ fn test_event_sort() {
shuffle(&mut events_to_sort); shuffle(&mut events_to_sort);
let power_level = resolved_power.get(&(EventType::RoomPowerLevels, "".into())); let power_level = resolved_power.get(&(EventType::RoomPowerLevels, Some("".to_string())));
let sorted_event_ids = state_res::StateResolution::mainline_sort( let sorted_event_ids = state_res::StateResolution::mainline_sort(
&room_id(), &room_id(),

View File

@ -3,11 +3,11 @@
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
use ruma::{ use ruma::{
events::EventType, events::{pdu::ServerPdu, EventType},
identifiers::{EventId, RoomVersionId}, identifiers::{EventId, RoomVersionId},
}; };
use serde_json::json; use serde_json::json;
use state_res::{StateEvent, StateMap, StateResolution}; use state_res::{StateMap, StateResolution};
mod utils; mod utils;
use utils::{ use utils::{
@ -24,7 +24,7 @@ fn ban_with_auth_chains() {
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>()) .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let expected_state_ids = vec!["PA", "MB", "END"] let expected_state_ids = vec!["PA", "MB"]
.into_iter() .into_iter()
.map(event_id) .map(event_id)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -50,7 +50,7 @@ fn base_with_auth_chains() {
let resolved = resolved let resolved = resolved
.values() .values()
.cloned() .cloned()
.chain(INITIAL_EVENTS().values().map(|e| e.event_id())) .chain(INITIAL_EVENTS().values().map(|e| e.event_id.clone()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let expected = vec![ let expected = vec![
@ -89,7 +89,7 @@ fn ban_with_auth_chains2() {
inner.get(&event_id("PA")).unwrap(), inner.get(&event_id("PA")).unwrap(),
] ]
.iter() .iter()
.map(|ev| ((ev.kind(), ev.state_key()), ev.event_id())) .map(|ev| ((ev.kind.clone(), ev.state_key.clone()), ev.event_id.clone()))
.collect::<BTreeMap<_, _>>(); .collect::<BTreeMap<_, _>>();
let state_set_b = [ let state_set_b = [
@ -102,7 +102,7 @@ fn ban_with_auth_chains2() {
inner.get(&event_id("PA")).unwrap(), inner.get(&event_id("PA")).unwrap(),
] ]
.iter() .iter()
.map(|ev| ((ev.kind(), ev.state_key()), ev.event_id())) .map(|ev| ((ev.kind.clone(), ev.state_key.clone()), ev.event_id.clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let resolved: StateMap<EventId> = match StateResolution::resolve( let resolved: StateMap<EventId> = match StateResolution::resolve(
@ -154,10 +154,7 @@ fn join_rule_with_auth_chain() {
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>()) .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let expected_state_ids = vec!["JR", "END"] let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::<Vec<_>>();
.into_iter()
.map(event_id)
.collect::<Vec<_>>();
do_check( do_check(
&join_rule.values().cloned().collect::<Vec<_>>(), &join_rule.values().cloned().collect::<Vec<_>>(),
@ -167,7 +164,7 @@ fn join_rule_with_auth_chain() {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<StateEvent>> { fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<ServerPdu>> {
vec![ vec![
to_pdu_event( to_pdu_event(
"PA", "PA",
@ -207,12 +204,12 @@ fn BAN_STATE_SET() -> BTreeMap<EventId, Arc<StateEvent>> {
), ),
] ]
.into_iter() .into_iter()
.map(|ev| (ev.event_id(), ev)) .map(|ev| (ev.event_id.clone(), ev))
.collect() .collect()
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn JOIN_RULE() -> BTreeMap<EventId, Arc<StateEvent>> { fn JOIN_RULE() -> BTreeMap<EventId, Arc<ServerPdu>> {
vec![ vec![
to_pdu_event( to_pdu_event(
"JR", "JR",
@ -234,6 +231,6 @@ fn JOIN_RULE() -> BTreeMap<EventId, Arc<StateEvent>> {
), ),
] ]
.into_iter() .into_iter()
.map(|ev| (ev.event_id(), ev)) .map(|ev| (ev.event_id.clone(), ev))
.collect() .collect()
} }

View File

@ -56,7 +56,7 @@ fn ban_vs_power_level() {
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>()) .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let expected_state_ids = vec!["PA", "MA", "MB", "END"] let expected_state_ids = vec!["PA", "MA", "MB"]
.into_iter() .into_iter()
.map(event_id) .map(event_id)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -101,7 +101,7 @@ fn topic_basic() {
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>()) .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let expected_state_ids = vec!["PA2", "T2", "END"] let expected_state_ids = vec!["PA2", "T2"]
.into_iter() .into_iter()
.map(event_id) .map(event_id)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -138,7 +138,7 @@ fn topic_reset() {
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>()) .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let expected_state_ids = vec!["T1", "MB", "PA", "END"] let expected_state_ids = vec!["T1", "MB", "PA"]
.into_iter() .into_iter()
.map(event_id) .map(event_id)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -170,7 +170,7 @@ fn join_rule_evasion() {
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>()) .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let expected_state_ids = vec![event_id("JR"), event_id("END")]; let expected_state_ids = vec![event_id("JR")];
do_check(events, edges, expected_state_ids) do_check(events, edges, expected_state_ids)
} }
@ -206,10 +206,7 @@ fn offtopic_power_level() {
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>()) .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let expected_state_ids = vec!["PC", "END"] let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::<Vec<_>>();
.into_iter()
.map(event_id)
.collect::<Vec<_>>();
do_check(events, edges, expected_state_ids) do_check(events, edges, expected_state_ids)
} }
@ -253,7 +250,7 @@ fn topic_setting() {
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>()) .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let expected_state_ids = vec!["T4", "PA2", "END"] let expected_state_ids = vec!["T4", "PA2"]
.into_iter() .into_iter()
.map(event_id) .map(event_id)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -324,7 +321,7 @@ impl TestStore {
&[], &[],
&[], &[],
); );
let cre = create_event.event_id(); let cre = create_event.event_id.clone();
self.0.insert(cre.clone(), Arc::clone(&create_event)); self.0.insert(cre.clone(), Arc::clone(&create_event));
let alice_mem = to_pdu_event( let alice_mem = to_pdu_event(
@ -336,7 +333,8 @@ impl TestStore {
&[cre.clone()], &[cre.clone()],
&[cre.clone()], &[cre.clone()],
); );
self.0.insert(alice_mem.event_id(), Arc::clone(&alice_mem)); self.0
.insert(alice_mem.event_id.clone(), Arc::clone(&alice_mem));
let join_rules = to_pdu_event( let join_rules = to_pdu_event(
"IJR", "IJR",
@ -344,10 +342,11 @@ impl TestStore {
EventType::RoomJoinRules, EventType::RoomJoinRules,
Some(""), Some(""),
json!({ "join_rule": JoinRule::Public }), json!({ "join_rule": JoinRule::Public }),
&[cre.clone(), alice_mem.event_id()], &[cre.clone(), alice_mem.event_id.clone()],
&[alice_mem.event_id()], &[alice_mem.event_id.clone()],
); );
self.0.insert(join_rules.event_id(), join_rules.clone()); self.0
.insert(join_rules.event_id.clone(), join_rules.clone());
// Bob and Charlie join at the same time, so there is a fork // Bob and Charlie join at the same time, so there is a fork
// this will be represented in the state_sets when we resolve // this will be represented in the state_sets when we resolve
@ -357,10 +356,10 @@ impl TestStore {
EventType::RoomMember, EventType::RoomMember,
Some(bob().to_string().as_str()), Some(bob().to_string().as_str()),
member_content_join(), member_content_join(),
&[cre.clone(), join_rules.event_id()], &[cre.clone(), join_rules.event_id.clone()],
&[join_rules.event_id()], &[join_rules.event_id.clone()],
); );
self.0.insert(bob_mem.event_id(), bob_mem.clone()); self.0.insert(bob_mem.event_id.clone(), bob_mem.clone());
let charlie_mem = to_pdu_event( let charlie_mem = to_pdu_event(
"IMC", "IMC",
@ -368,19 +367,20 @@ impl TestStore {
EventType::RoomMember, EventType::RoomMember,
Some(charlie().to_string().as_str()), Some(charlie().to_string().as_str()),
member_content_join(), member_content_join(),
&[cre, join_rules.event_id()], &[cre, join_rules.event_id.clone()],
&[join_rules.event_id()], &[join_rules.event_id.clone()],
); );
self.0.insert(charlie_mem.event_id(), charlie_mem.clone()); self.0
.insert(charlie_mem.event_id.clone(), charlie_mem.clone());
let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem]
.iter() .iter()
.map(|e| ((e.kind(), e.state_key()), e.event_id())) .map(|e| ((e.kind.clone(), e.state_key.clone()), e.event_id.clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem]
.iter() .iter()
.map(|e| ((e.kind(), e.state_key()), e.event_id())) .map(|e| ((e.kind.clone(), e.state_key.clone()), e.event_id.clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let expected = [ let expected = [
@ -391,7 +391,7 @@ impl TestStore {
&charlie_mem, &charlie_mem,
] ]
.iter() .iter()
.map(|e| ((e.kind(), e.state_key()), e.event_id())) .map(|e| ((e.kind.clone(), e.state_key.clone()), e.event_id.clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
(state_at_bob, state_at_charlie, expected) (state_at_bob, state_at_charlie, expected)

View File

@ -9,6 +9,7 @@ use std::{
use ruma::{ use ruma::{
events::{ events::{
pdu::ServerPdu,
room::{ room::{
join_rules::JoinRule, join_rules::JoinRule,
member::{MemberEventContent, MembershipState}, member::{MemberEventContent, MembershipState},
@ -18,7 +19,7 @@ use ruma::{
identifiers::{EventId, RoomId, RoomVersionId, UserId}, identifiers::{EventId, RoomId, RoomVersionId, UserId},
}; };
use serde_json::{json, Value as JsonValue}; use serde_json::{json, Value as JsonValue};
use state_res::{Error, Result, StateEvent, StateMap, StateResolution, StateStore}; use state_res::{Error, Result, StateMap, StateResolution, StateStore};
use tracing_subscriber as tracer; use tracing_subscriber as tracer;
pub static LOGGER: Once = Once::new(); pub static LOGGER: Once = Once::new();
@ -26,7 +27,7 @@ pub static LOGGER: Once = Once::new();
static mut SERVER_TIMESTAMP: i32 = 0; static mut SERVER_TIMESTAMP: i32 = 0;
pub fn do_check( pub fn do_check(
events: &[Arc<StateEvent>], events: &[Arc<ServerPdu>],
edges: Vec<Vec<EventId>>, edges: Vec<Vec<EventId>>,
expected_state_ids: Vec<EventId>, expected_state_ids: Vec<EventId>,
) { ) {
@ -41,20 +42,20 @@ pub fn do_check(
INITIAL_EVENTS() INITIAL_EVENTS()
.values() .values()
.chain(events) .chain(events)
.map(|ev| (ev.event_id(), ev.clone())) .map(|ev| (ev.event_id.clone(), ev.clone()))
.collect(), .collect(),
); );
// This will be lexi_topo_sorted for resolution // This will be lexi_topo_sorted for resolution
let mut graph = BTreeMap::new(); let mut graph = BTreeMap::new();
// this is the same as in `resolve` event_id -> StateEvent // this is the same as in `resolve` event_id -> ServerPdu
let mut fake_event_map = BTreeMap::new(); let mut fake_event_map = BTreeMap::new();
// create the DB of events that led up to this point // create the DB of events that led up to this point
// TODO maybe clean up some of these clones it is just tests but... // TODO maybe clean up some of these clones it is just tests but...
for ev in INITIAL_EVENTS().values().chain(events) { for ev in INITIAL_EVENTS().values().chain(events) {
graph.insert(ev.event_id().clone(), vec![]); graph.insert(ev.event_id.clone(), vec![]);
fake_event_map.insert(ev.event_id().clone(), ev.clone()); fake_event_map.insert(ev.event_id.clone(), ev.clone());
} }
for pair in INITIAL_EDGES().windows(2) { for pair in INITIAL_EDGES().windows(2) {
@ -71,10 +72,8 @@ pub fn do_check(
} }
} }
panic!("{}", serde_json::to_string_pretty(&graph).unwrap()); // event_id -> ServerPdu
let mut event_map: BTreeMap<EventId, Arc<ServerPdu>> = BTreeMap::new();
// event_id -> StateEvent
let mut event_map: BTreeMap<EventId, Arc<StateEvent>> = BTreeMap::new();
// event_id -> StateMap<EventId> // event_id -> StateMap<EventId>
let mut state_at_event: BTreeMap<EventId, StateMap<EventId>> = BTreeMap::new(); let mut state_at_event: BTreeMap<EventId, StateMap<EventId>> = BTreeMap::new();
@ -84,7 +83,7 @@ pub fn do_check(
StateResolution::lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, id.clone())) StateResolution::lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, id.clone()))
{ {
let fake_event = fake_event_map.get(&node).unwrap(); let fake_event = fake_event_map.get(&node).unwrap();
let event_id = fake_event.event_id(); let event_id = fake_event.event_id.clone();
let prev_events = graph.get(&node).unwrap(); let prev_events = graph.get(&node).unwrap();
@ -125,17 +124,17 @@ pub fn do_check(
let mut state_after = state_before.clone(); let mut state_after = state_before.clone();
// if fake_event.state_key().is_some() { if fake_event.state_key.is_some() {
let ty = fake_event.kind().clone(); let ty = fake_event.kind.clone();
let key = fake_event.state_key().clone(); let key = fake_event.state_key.clone();
state_after.insert((ty, key), event_id.clone()); state_after.insert((ty, key), event_id.clone());
// } }
let auth_types = state_res::auth_types_for_event( let auth_types = state_res::auth_types_for_event(
fake_event.kind(), &fake_event.kind,
fake_event.sender(), &fake_event.sender,
Some(fake_event.state_key()), fake_event.state_key.clone(),
fake_event.content().clone(), fake_event.content.clone(),
); );
let mut auth_events = vec![]; let mut auth_events = vec![];
@ -148,13 +147,13 @@ pub fn do_check(
// TODO The event is just remade, adding the auth_events and prev_events here // TODO The event is just remade, adding the auth_events and prev_events here
// the `to_pdu_event` was split into `init` and the fn below, could be better // the `to_pdu_event` was split into `init` and the fn below, could be better
let e = fake_event; let e = fake_event;
let ev_id = e.event_id(); let ev_id = e.event_id.clone();
let event = to_pdu_event( let event = to_pdu_event(
&e.event_id().to_string(), &e.event_id.clone().to_string(),
e.sender().clone(), e.sender.clone(),
e.kind(), e.kind.clone(),
Some(e.state_key()).as_deref(), e.state_key.as_deref(),
e.content().clone(), e.content.clone(),
&auth_events, &auth_events,
prev_events, prev_events,
); );
@ -172,11 +171,11 @@ pub fn do_check(
// println!( // println!(
// "res contains: {} passed: {} for {}\n{:?}", // "res contains: {} passed: {} for {}\n{:?}",
// state_after // state_after
// .get(&(event.kind(), event.state_key())) // .get(&(event.kind, event.state_key()))
// .map(|id| id == &ev_id) // .map(|id| id == &ev_id)
// .unwrap_or_default(), // .unwrap_or_default(),
// res, // res,
// event.event_id().as_str(), // event.event_id.clone().as_str(),
// event // event
// .prev_event_ids() // .prev_event_ids()
// .iter() // .iter()
@ -206,7 +205,7 @@ pub fn do_check(
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
)); ));
let key = (ev.kind(), ev.state_key()); let key = (ev.kind.clone(), ev.state_key.clone());
expected_state.insert(key, node); expected_state.insert(key, node);
} }
@ -224,11 +223,11 @@ pub fn do_check(
assert_eq!(expected_state, end_state); assert_eq!(expected_state, end_state);
} }
pub struct TestStore(pub BTreeMap<EventId, Arc<StateEvent>>); pub struct TestStore(pub BTreeMap<EventId, Arc<ServerPdu>>);
#[allow(unused)] #[allow(unused)]
impl StateStore for TestStore { impl StateStore for TestStore {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>> { fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<ServerPdu>> {
self.0 self.0
.get(event_id) .get(event_id)
.map(Arc::clone) .map(Arc::clone)
@ -291,7 +290,7 @@ pub fn to_init_pdu_event(
ev_type: EventType, ev_type: EventType,
state_key: Option<&str>, state_key: Option<&str>,
content: JsonValue, content: JsonValue,
) -> Arc<StateEvent> { ) -> Arc<ServerPdu> {
let ts = unsafe { let ts = unsafe {
let ts = SERVER_TIMESTAMP; let ts = SERVER_TIMESTAMP;
// increment the "origin_server_ts" value // increment the "origin_server_ts" value
@ -347,7 +346,7 @@ pub fn to_pdu_event<S>(
content: JsonValue, content: JsonValue,
auth_events: &[S], auth_events: &[S],
prev_events: &[S], prev_events: &[S],
) -> Arc<StateEvent> ) -> Arc<ServerPdu>
where where
S: AsRef<str>, S: AsRef<str>,
{ {
@ -410,7 +409,7 @@ where
// all graphs start with these input events // all graphs start with these input events
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> { pub fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<ServerPdu>> {
// this is always called so we can init the logger here // this is always called so we can init the logger here
let _ = LOGGER.call_once(|| { let _ = LOGGER.call_once(|| {
tracer::fmt() tracer::fmt()
@ -476,8 +475,8 @@ pub fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
to_pdu_event::<EventId>( to_pdu_event::<EventId>(
"START", "START",
charlie(), charlie(),
EventType::RoomTopic, EventType::RoomMessage,
Some(""), None,
json!({}), json!({}),
&[], &[],
&[], &[],
@ -485,15 +484,15 @@ pub fn INITIAL_EVENTS() -> BTreeMap<EventId, Arc<StateEvent>> {
to_pdu_event::<EventId>( to_pdu_event::<EventId>(
"END", "END",
charlie(), charlie(),
EventType::RoomTopic, EventType::RoomMessage,
Some(""), None,
json!({}), json!({}),
&[], &[],
&[], &[],
), ),
] ]
.into_iter() .into_iter()
.map(|ev| (ev.event_id(), ev)) .map(|ev| (ev.event_id.clone(), ev))
.collect() .collect()
} }