Remove StateStore trait, caller must collect all events needed

This commit is contained in:
Devin Ragotzy 2021-01-05 17:11:30 -05:00
parent f4772e0fcb
commit 47b19fdc15
9 changed files with 155 additions and 189 deletions

View File

@ -0,0 +1 @@

View File

@ -48,13 +48,21 @@ fn resolution_shallow_auth_chain(c: &mut Criterion) {
let (state_at_bob, state_at_charlie, _) = store.set_up(); let (state_at_bob, state_at_charlie, _) = store.set_up();
b.iter(|| { b.iter(|| {
let mut ev_map = state_res::EventMap::default(); let mut ev_map: state_res::EventMap<Arc<event::StateEvent>> = store.0.clone();
let _resolved = match StateResolution::resolve( let state_sets = vec![state_at_bob.clone(), state_at_charlie.clone()];
let _ = match StateResolution::resolve::<event::StateEvent>(
&room_id(), &room_id(),
&RoomVersionId::Version6, &RoomVersionId::Version2,
&[state_at_bob.clone(), state_at_charlie.clone()], &state_sets,
state_sets
.iter()
.map(|map| {
store
.auth_event_ids(&room_id(), &map.values().cloned().collect::<Vec<_>>())
.unwrap()
})
.collect(),
&mut ev_map, &mut ev_map,
&store,
) { ) {
Ok(state) => state, Ok(state) => state,
Err(e) => panic!("{}", e), Err(e) => panic!("{}", e),
@ -99,12 +107,20 @@ fn resolve_deeper_event_set(c: &mut Criterion) {
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
b.iter(|| { b.iter(|| {
let _resolved = match StateResolution::resolve( let state_sets = vec![state_set_a.clone(), state_set_b.clone()];
let _ = match StateResolution::resolve::<event::StateEvent>(
&room_id(), &room_id(),
&RoomVersionId::Version6, &RoomVersionId::Version2,
&[state_set_a.clone(), state_set_b.clone()], &state_sets,
state_sets
.iter()
.map(|map| {
store
.auth_event_ids(&room_id(), &map.values().cloned().collect::<Vec<_>>())
.unwrap()
})
.collect(),
&mut inner, &mut inner,
&store,
) { ) {
Ok(state) => state, Ok(state) => state,
Err(_) => panic!("resolution failed during benchmarking"), Err(_) => panic!("resolution failed during benchmarking"),
@ -530,7 +546,9 @@ pub mod event {
fn hashes(&self) -> &EventHash { fn hashes(&self) -> &EventHash {
self.hashes() self.hashes()
} }
fn signatures(&self) -> BTreeMap<Box<ServerName>, BTreeMap<ruma::ServerSigningKeyId, String>> { fn signatures(
&self,
) -> BTreeMap<Box<ServerName>, BTreeMap<ruma::ServerSigningKeyId, String>> {
self.signatures() self.signatures()
} }
fn unsigned(&self) -> &BTreeMap<String, JsonValue> { fn unsigned(&self) -> &BTreeMap<String, JsonValue> {
@ -643,7 +661,10 @@ pub mod event {
} }
impl StateEvent { impl StateEvent {
pub fn from_id_value(id: EventId, json: serde_json::Value) -> Result<Self, serde_json::Error> { pub fn from_id_value(
id: EventId,
json: serde_json::Value,
) -> Result<Self, serde_json::Error> {
Ok(Self::Full( Ok(Self::Full(
id, id,
Pdu::RoomV3Pdu(serde_json::from_value(json)?), Pdu::RoomV3Pdu(serde_json::from_value(json)?),
@ -671,6 +692,7 @@ pub mod event {
EventType::RoomMember => { EventType::RoomMember => {
if let Ok(content) = if let Ok(content) =
// TODO fix clone // TODO fix clone
serde_json::from_value::<MemberEventContent>(event.content.clone()) serde_json::from_value::<MemberEventContent>(event.content.clone())
{ {
if [MembershipState::Leave, MembershipState::Ban] if [MembershipState::Leave, MembershipState::Ban]
@ -771,7 +793,9 @@ pub mod event {
pub fn prev_event_ids(&self) -> Vec<EventId> { pub fn prev_event_ids(&self) -> Vec<EventId> {
match self { match self {
Self::Full(_, ev) => match ev { Self::Full(_, ev) => match ev {
Pdu::RoomV1Pdu(ev) => ev.prev_events.iter().map(|(id, _)| id).cloned().collect(), Pdu::RoomV1Pdu(ev) => {
ev.prev_events.iter().map(|(id, _)| id).cloned().collect()
}
Pdu::RoomV3Pdu(ev) => ev.prev_events.clone(), Pdu::RoomV3Pdu(ev) => ev.prev_events.clone(),
}, },
} }
@ -780,7 +804,9 @@ pub mod event {
pub fn auth_events(&self) -> Vec<EventId> { pub fn auth_events(&self) -> Vec<EventId> {
match self { match self {
Self::Full(_, ev) => match ev { Self::Full(_, ev) => match ev {
Pdu::RoomV1Pdu(ev) => ev.auth_events.iter().map(|(id, _)| id).cloned().collect(), Pdu::RoomV1Pdu(ev) => {
ev.auth_events.iter().map(|(id, _)| id).cloned().collect()
}
Pdu::RoomV3Pdu(ev) => ev.auth_events.to_vec(), Pdu::RoomV3Pdu(ev) => ev.auth_events.to_vec(),
}, },
} }
@ -860,4 +886,4 @@ pub mod event {
} }
} }
} }
} }

View File

@ -1,4 +1,4 @@
use std::{collections::BTreeMap, convert::TryFrom, sync::Arc}; use std::{convert::TryFrom, sync::Arc};
use maplit::btreeset; use maplit::btreeset;
use ruma::{ use ruma::{
@ -67,6 +67,10 @@ pub fn auth_types_for_event(
/// * check that the event is being authenticated for the correct room /// * check that the event is being authenticated for the correct room
/// * check that the events signatures are valid /// * check that the events signatures are valid
/// * then there are checks for specific event types /// * then there are checks for specific event types
///
/// The `auth_events` that are passed to this function should be a state snapshot.
/// We need to know if the event passes auth against some state not a recursive collection
/// of auth_events fields.
pub fn auth_check<E: Event>( pub fn auth_check<E: Event>(
room_version: &RoomVersionId, room_version: &RoomVersionId,
incoming_event: &Arc<E>, incoming_event: &Arc<E>,

View File

@ -46,14 +46,13 @@ impl StateResolution {
room_version: &RoomVersionId, room_version: &RoomVersionId,
incoming_event: Arc<E>, incoming_event: Arc<E>,
current_state: &StateMap<EventId>, current_state: &StateMap<EventId>,
event_map: &mut EventMap<Arc<E>>, event_map: &EventMap<Arc<E>>,
store: &dyn StateStore<E>,
) -> Result<bool> { ) -> Result<bool> {
tracing::info!("Applying a single event, state resolution starting"); tracing::info!("Applying a single event, state resolution starting");
let ev = incoming_event; let ev = incoming_event;
let prev_event = if let Some(id) = ev.prev_events().first() { let prev_event = if let Some(id) = ev.prev_events().first() {
store.get_event(room_id, id).ok() event_map.get(id).map(Arc::clone)
} else { } else {
None None
}; };
@ -63,9 +62,7 @@ impl StateResolution {
event_auth::auth_types_for_event(&ev.kind(), &ev.sender(), ev.state_key(), ev.content()) event_auth::auth_types_for_event(&ev.kind(), &ev.sender(), ev.state_key(), ev.content())
{ {
if let Some(ev_id) = current_state.get(&key) { if let Some(ev_id) = current_state.get(&key) {
if let Some(event) = if let Ok(event) = StateResolution::get_or_load_event(room_id, ev_id, event_map) {
StateResolution::get_or_load_event(room_id, ev_id, event_map, store)
{
// TODO synapse checks `rejected_reason` is None here // TODO synapse checks `rejected_reason` is None here
auth_events.insert(key.clone(), event); auth_events.insert(key.clone(), event);
} }
@ -83,22 +80,21 @@ impl StateResolution {
/// * `state_sets` - The incoming state to resolve. Each `StateMap` represents a possible fork /// * `state_sets` - The incoming state to resolve. Each `StateMap` represents a possible fork
/// in the state of a room. /// in the state of a room.
/// ///
/// * `auth_events` - The full recursive set of `auth_events` for each event in the `state_sets`.
///
/// * `event_map` - The `EventMap` acts as a local cache of state, any event that is not found /// * `event_map` - The `EventMap` acts as a local cache of state, any event that is not found
/// in the `event_map` will be fetched from the `StateStore` and cached in the `event_map`. There /// in the `event_map` will be fetched from the `StateStore` and cached in the `event_map`. There
/// is no state kept from separate `resolve` calls, although this could be a potential optimization /// is no state kept from separate `resolve` calls, although this could be a potential optimization
/// in the future. /// in the future.
/// ///
/// * `store` - Any type that implements `StateStore` acts as the database. When an event is not
/// found in the `event_map` it will be retrieved from the `store`.
///
/// It is up the the caller to check that the events returned from `StateStore::get_event` are /// It is up the the caller to check that the events returned from `StateStore::get_event` are
/// events for the correct room (synapse checks that all events are in the right room). /// events for the correct room (synapse checks that all events are in the right room).
pub fn resolve<E: Event>( pub fn resolve<E: Event>(
room_id: &RoomId, room_id: &RoomId,
room_version: &RoomVersionId, room_version: &RoomVersionId,
state_sets: &[StateMap<EventId>], state_sets: &[StateMap<EventId>],
auth_events: Vec<Vec<EventId>>,
event_map: &mut EventMap<Arc<E>>, event_map: &mut EventMap<Arc<E>>,
store: &dyn StateStore<E>,
) -> Result<StateMap<EventId>> { ) -> Result<StateMap<EventId>> {
tracing::info!("State resolution starting"); tracing::info!("State resolution starting");
@ -115,9 +111,9 @@ impl StateResolution {
tracing::info!("{} conflicting events", conflicting.len()); tracing::info!("{} conflicting events", conflicting.len());
// the set of auth events that are not common across server forks // the set of auth events that are not common across server forks
let mut auth_diff = StateResolution::get_auth_chain_diff(room_id, &state_sets, store)?; let mut auth_diff = StateResolution::get_auth_chain_diff(room_id, &auth_events)?;
tracing::debug!("auth diff size {}", auth_diff.len()); tracing::debug!("auth diff size {:?}", auth_diff);
// add the auth_diff to conflicting now we have a full set of conflicting events // add the auth_diff to conflicting now we have a full set of conflicting events
auth_diff.extend(conflicting.values().cloned().flatten()); auth_diff.extend(conflicting.values().cloned().flatten());
@ -129,25 +125,6 @@ impl StateResolution {
tracing::info!("full conflicted set is {} events", all_conflicted.len()); tracing::info!("full conflicted set is {} events", all_conflicted.len());
// gather missing events for the event_map
let events = store
.get_events(
room_id,
&all_conflicted
.iter()
// we only want the events we don't know about yet
.filter(|id| !event_map.contains_key(id))
.cloned()
.collect::<Vec<_>>(),
)
.unwrap();
// update event_map to include the fetched events
event_map.extend(events.into_iter().map(|ev| (ev.event_id().clone(), ev)));
// at this point our event_map == store there should be no missing events
tracing::debug!("event map size: {}", event_map.len());
// we used to check that all events are events from the correct room // we used to check that all events are events from the correct room
// this is now a check the caller of `resolve` must make. // this is now a check the caller of `resolve` must make.
@ -168,17 +145,10 @@ impl StateResolution {
room_id, room_id,
&control_events, &control_events,
event_map, event_map,
store,
&all_conflicted, &all_conflicted,
); );
tracing::debug!( tracing::debug!("SRTD {:?}", sorted_control_levels);
"SRTD {:?}",
sorted_control_levels
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
);
// sequentially auth check each control event. // sequentially auth check each control event.
let resolved_control = StateResolution::iterative_auth_check( let resolved_control = StateResolution::iterative_auth_check(
@ -187,7 +157,6 @@ impl StateResolution {
&sorted_control_levels, &sorted_control_levels,
&clean, &clean,
event_map, event_map,
store,
)?; )?;
tracing::debug!( tracing::debug!(
@ -223,13 +192,8 @@ impl StateResolution {
tracing::debug!("PL {:?}", power_event); tracing::debug!("PL {:?}", power_event);
let sorted_left_events = StateResolution::mainline_sort( let sorted_left_events =
room_id, StateResolution::mainline_sort(room_id, &events_to_resolve, power_event, event_map);
&events_to_resolve,
power_event,
event_map,
store,
);
tracing::debug!( tracing::debug!(
"SORTED LEFT {:?}", "SORTED LEFT {:?}",
@ -245,7 +209,6 @@ impl StateResolution {
&sorted_left_events, &sorted_left_events,
&resolved_control, // The control events are added to the final resolved state &resolved_control, // The control events are added to the final resolved state
event_map, event_map,
store,
)?; )?;
// add unconflicted state to the resolved state // add unconflicted state to the resolved state
@ -298,23 +261,34 @@ impl StateResolution {
} }
/// Returns a Vec of deduped EventIds that appear in some chains but not others. /// Returns a Vec of deduped EventIds that appear in some chains but not others.
pub fn get_auth_chain_diff<E: Event>( pub fn get_auth_chain_diff(
room_id: &RoomId, _room_id: &RoomId,
state_sets: &[StateMap<EventId>], auth_event_ids: &[Vec<EventId>],
store: &dyn StateStore<E>,
) -> Result<Vec<EventId>> { ) -> Result<Vec<EventId>> {
use itertools::Itertools; let mut chains = vec![];
tracing::debug!("calculating auth chain difference"); for ids in auth_event_ids {
// TODO state store `auth_event_ids` returns self in the event ids list
// when an event returns `auth_event_ids` self is not contained
let chain = ids.iter().cloned().collect::<BTreeSet<_>>();
chains.push(chain);
}
store.auth_chain_diff( if let Some(chain) = chains.first() {
room_id, let rest = chains.iter().skip(1).flatten().cloned().collect();
state_sets let common = chain.intersection(&rest).collect::<Vec<_>>();
Ok(chains
.iter() .iter()
.map(|map| map.values().cloned().collect()) .flatten()
.dedup() .filter(|id| !common.contains(&id))
.collect::<Vec<_>>(), .cloned()
) .collect::<BTreeSet<_>>()
.into_iter()
.collect())
} else {
Ok(vec![])
}
} }
/// Events are sorted from "earliest" to "latest". They are compared using /// Events are sorted from "earliest" to "latest". They are compared using
@ -328,7 +302,6 @@ impl StateResolution {
room_id: &RoomId, room_id: &RoomId,
events_to_sort: &[EventId], events_to_sort: &[EventId],
event_map: &mut EventMap<Arc<E>>, event_map: &mut EventMap<Arc<E>>,
store: &dyn StateStore<E>,
auth_diff: &[EventId], auth_diff: &[EventId],
) -> Vec<EventId> { ) -> Vec<EventId> {
tracing::debug!("reverse topological sort of power events"); tracing::debug!("reverse topological sort of power events");
@ -336,7 +309,7 @@ impl StateResolution {
let mut graph = BTreeMap::new(); let mut graph = BTreeMap::new();
for (idx, event_id) in events_to_sort.iter().enumerate() { for (idx, event_id) in events_to_sort.iter().enumerate() {
StateResolution::add_event_and_auth_chain_to_graph( StateResolution::add_event_and_auth_chain_to_graph(
room_id, &mut graph, event_id, event_map, store, auth_diff, room_id, &mut graph, event_id, event_map, auth_diff,
); );
// We yield occasionally when we're working with large data sets to // We yield occasionally when we're working with large data sets to
@ -349,8 +322,7 @@ impl StateResolution {
// this is used in the `key_fn` passed to the lexico_topo_sort fn // this is used in the `key_fn` passed to the lexico_topo_sort fn
let mut event_to_pl = BTreeMap::new(); let mut event_to_pl = BTreeMap::new();
for (idx, event_id) in graph.keys().enumerate() { for (idx, event_id) in graph.keys().enumerate() {
let pl = let pl = StateResolution::get_power_level_for_sender(room_id, &event_id, event_map);
StateResolution::get_power_level_for_sender(room_id, &event_id, event_map, store);
tracing::info!("{} power level {}", event_id.to_string(), pl); tracing::info!("{} power level {}", event_id.to_string(), pl);
event_to_pl.insert(event_id.clone(), pl); event_to_pl.insert(event_id.clone(), pl);
@ -454,11 +426,10 @@ impl StateResolution {
room_id: &RoomId, room_id: &RoomId,
event_id: &EventId, event_id: &EventId,
event_map: &mut EventMap<Arc<E>>, event_map: &mut EventMap<Arc<E>>,
store: &dyn StateStore<E>,
) -> i64 { ) -> i64 {
tracing::info!("fetch event ({}) senders power level", event_id.to_string()); tracing::info!("fetch event ({}) senders power level", event_id.to_string());
let event = StateResolution::get_or_load_event(room_id, event_id, event_map, store); let event = StateResolution::get_or_load_event(room_id, event_id, event_map);
let mut pl = None; let mut pl = None;
// TODO store.auth_event_ids returns "self" with the event ids is this ok // TODO store.auth_event_ids returns "self" with the event ids is this ok
@ -468,7 +439,7 @@ impl StateResolution {
.map(|pdu| pdu.auth_events()) .map(|pdu| pdu.auth_events())
.unwrap_or_default() .unwrap_or_default()
{ {
if let Some(aev) = StateResolution::get_or_load_event(room_id, &aid, event_map, store) { if let Ok(aev) = StateResolution::get_or_load_event(room_id, &aid, event_map) {
if is_type_and_key(&aev, EventType::RoomPowerLevels, "") { if is_type_and_key(&aev, EventType::RoomPowerLevels, "") {
pl = Some(aev); pl = Some(aev);
break; break;
@ -489,9 +460,9 @@ impl StateResolution {
}) })
.flatten() .flatten()
{ {
if let Some(ev) = event { if let Ok(ev) = event {
if let Some(user) = content.users.get(&ev.sender()) { if let Some(user) = content.users.get(&ev.sender()) {
tracing::debug!("found {} at power_level {}", ev.sender().to_string(), user); tracing::debug!("found {} at power_level {}", ev.sender().as_str(), user);
return (*user).into(); return (*user).into();
} }
} }
@ -517,7 +488,6 @@ impl StateResolution {
events_to_check: &[EventId], events_to_check: &[EventId],
unconflicted_state: &StateMap<EventId>, unconflicted_state: &StateMap<EventId>,
event_map: &mut EventMap<Arc<E>>, event_map: &mut EventMap<Arc<E>>,
store: &dyn StateStore<E>,
) -> Result<StateMap<EventId>> { ) -> Result<StateMap<EventId>> {
tracing::info!("starting iterative auth check"); tracing::info!("starting iterative auth check");
@ -532,14 +502,11 @@ impl StateResolution {
let mut resolved_state = unconflicted_state.clone(); let mut resolved_state = unconflicted_state.clone();
for (idx, event_id) in events_to_check.iter().enumerate() { for (idx, event_id) in events_to_check.iter().enumerate() {
let event = let event = StateResolution::get_or_load_event(room_id, event_id, event_map)?;
StateResolution::get_or_load_event(room_id, event_id, event_map, store).unwrap();
let mut auth_events = BTreeMap::new(); let mut auth_events = BTreeMap::new();
for aid in &event.auth_events() { for aid in &event.auth_events() {
if let Some(ev) = if let Ok(ev) = StateResolution::get_or_load_event(room_id, &aid, event_map) {
StateResolution::get_or_load_event(room_id, &aid, event_map, store)
{
// TODO what to do when no state_key is found ?? // TODO what to do when no state_key is found ??
// TODO synapse check "rejected_reason", I'm guessing this is redacted_because in ruma ?? // TODO synapse check "rejected_reason", I'm guessing this is redacted_because in ruma ??
auth_events.insert((ev.kind(), ev.state_key()), ev); auth_events.insert((ev.kind(), ev.state_key()), ev);
@ -555,8 +522,7 @@ impl StateResolution {
event.content(), event.content(),
) { ) {
if let Some(ev_id) = resolved_state.get(&key) { if let Some(ev_id) = resolved_state.get(&key) {
if let Some(event) = if let Ok(event) = StateResolution::get_or_load_event(room_id, ev_id, event_map)
StateResolution::get_or_load_event(room_id, ev_id, event_map, store)
{ {
// TODO synapse checks `rejected_reason` is None here // TODO synapse checks `rejected_reason` is None here
auth_events.insert(key.clone(), event); auth_events.insert(key.clone(), event);
@ -569,7 +535,7 @@ impl StateResolution {
let most_recent_prev_event = event let most_recent_prev_event = event
.prev_events() .prev_events()
.iter() .iter()
.filter_map(|id| StateResolution::get_or_load_event(room_id, id, event_map, store)) .filter_map(|id| StateResolution::get_or_load_event(room_id, id, event_map).ok())
.next_back(); .next_back();
// The key for this is (eventType + a state_key of the signed token not sender) so search // The key for this is (eventType + a state_key of the signed token not sender) so search
@ -620,7 +586,6 @@ impl StateResolution {
to_sort: &[EventId], to_sort: &[EventId],
resolved_power_level: Option<&EventId>, resolved_power_level: Option<&EventId>,
event_map: &mut EventMap<Arc<E>>, event_map: &mut EventMap<Arc<E>>,
store: &dyn StateStore<E>,
) -> Vec<EventId> { ) -> Vec<EventId> {
tracing::debug!("mainline sort of events"); tracing::debug!("mainline sort of events");
@ -635,12 +600,11 @@ impl StateResolution {
while let Some(p) = pl { while let Some(p) = pl {
mainline.push(p.clone()); mainline.push(p.clone());
let event = StateResolution::get_or_load_event(room_id, &p, event_map, store).unwrap(); let event = StateResolution::get_or_load_event(room_id, &p, event_map).unwrap();
let auth_events = &event.auth_events(); let auth_events = &event.auth_events();
pl = None; pl = None;
for aid in auth_events { for aid in auth_events {
let ev = let ev = StateResolution::get_or_load_event(room_id, &aid, event_map).unwrap();
StateResolution::get_or_load_event(room_id, &aid, event_map, store).unwrap();
if is_type_and_key(&ev, EventType::RoomPowerLevels, "") { if is_type_and_key(&ev, EventType::RoomPowerLevels, "") {
pl = Some(aid.clone()); pl = Some(aid.clone());
break; break;
@ -663,15 +627,12 @@ impl StateResolution {
let mut order_map = BTreeMap::new(); let mut order_map = BTreeMap::new();
for (idx, ev_id) in to_sort.iter().enumerate() { for (idx, ev_id) in to_sort.iter().enumerate() {
if let Some(event) = if let Ok(event) = StateResolution::get_or_load_event(room_id, ev_id, event_map) {
StateResolution::get_or_load_event(room_id, ev_id, event_map, store)
{
if let Ok(depth) = StateResolution::get_mainline_depth( if let Ok(depth) = StateResolution::get_mainline_depth(
room_id, room_id,
Some(event), Some(event),
&mainline_map, &mainline_map,
event_map, event_map,
store,
) { ) {
order_map.insert( order_map.insert(
ev_id, ev_id,
@ -706,7 +667,6 @@ impl StateResolution {
mut event: Option<Arc<E>>, mut event: Option<Arc<E>>,
mainline_map: &EventMap<usize>, mainline_map: &EventMap<usize>,
event_map: &mut EventMap<Arc<E>>, event_map: &mut EventMap<Arc<E>>,
store: &dyn StateStore<E>,
) -> Result<usize> { ) -> Result<usize> {
while let Some(sort_ev) = event { while let Some(sort_ev) = event {
tracing::debug!("mainline event_id {}", sort_ev.event_id().to_string()); tracing::debug!("mainline event_id {}", sort_ev.event_id().to_string());
@ -720,8 +680,7 @@ impl StateResolution {
event = None; event = None;
for aid in auth_events { for aid in auth_events {
// dbg!(&aid); // dbg!(&aid);
let aev = StateResolution::get_or_load_event(room_id, &aid, event_map, store) let aev = StateResolution::get_or_load_event(room_id, &aid, event_map)?;
.ok_or_else(|| Error::NotFound("Auth event not found".to_owned()))?;
if is_type_and_key(&aev, EventType::RoomPowerLevels, "") { if is_type_and_key(&aev, EventType::RoomPowerLevels, "") {
event = Some(aev); event = Some(aev);
break; break;
@ -737,7 +696,6 @@ impl StateResolution {
graph: &mut BTreeMap<EventId, Vec<EventId>>, graph: &mut BTreeMap<EventId, Vec<EventId>>,
event_id: &EventId, event_id: &EventId,
event_map: &mut EventMap<Arc<E>>, event_map: &mut EventMap<Arc<E>>,
store: &dyn StateStore<E>,
auth_diff: &[EventId], auth_diff: &[EventId],
) { ) {
let mut state = vec![event_id.clone()]; let mut state = vec![event_id.clone()];
@ -747,7 +705,7 @@ impl StateResolution {
graph.entry(eid.clone()).or_insert_with(Vec::new); graph.entry(eid.clone()).or_insert_with(Vec::new);
// prefer the store to event as the store filters dedups the events // prefer the store to event as the store filters dedups the events
// otherwise it seems we can loop forever // otherwise it seems we can loop forever
for aid in &StateResolution::get_or_load_event(room_id, &eid, event_map, store) for aid in &StateResolution::get_or_load_event(room_id, &eid, event_map)
.unwrap() .unwrap()
.auth_events() .auth_events()
{ {
@ -763,27 +721,16 @@ impl StateResolution {
} }
} }
// TODO having the event_map as a field of self would allow us to keep /// Uses the `event_map` to return the full PDU or fails.
// cached state from `resolve` to `resolve` calls, good idea or not?
/// Uses the `event_map` to return the full PDU or fetches from the `StateStore` implementation
/// if the event_map does not have the PDU.
///
/// If the PDU is missing from the `event_map` it is added.
fn get_or_load_event<E: Event>( fn get_or_load_event<E: Event>(
room_id: &RoomId, _room_id: &RoomId,
ev_id: &EventId, ev_id: &EventId,
event_map: &mut EventMap<Arc<E>>, event_map: &EventMap<Arc<E>>,
store: &dyn StateStore<E>, ) -> Result<Arc<E>> {
) -> Option<Arc<E>> { event_map.get(ev_id).map_or_else(
if let Some(e) = event_map.get(ev_id) { || Err(Error::NotFound(format!("EventId: {:?} not found", ev_id))),
return Some(Arc::clone(e)); |e| Ok(Arc::clone(e)),
} )
if let Ok(e) = store.get_event(room_id, ev_id) {
return Some(Arc::clone(event_map.entry(ev_id.clone()).or_insert(e)));
}
None
} }
} }

View File

@ -4,6 +4,7 @@ use ruma::identifiers::{EventId, RoomId};
use crate::{Event, Result}; use crate::{Event, Result};
/// TODO: this is only used in testing on this branch now REMOVE
pub trait StateStore<E: Event> { pub trait StateStore<E: Event> {
/// Return a single event based on the EventId. /// Return a single event based on the EventId.
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<E>>; fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<E>>;

View File

@ -7,7 +7,7 @@ use ruma::{
use state_res::{is_power_event, StateMap}; use state_res::{is_power_event, StateMap};
mod utils; mod utils;
use utils::{room_id, TestStore, INITIAL_EVENTS}; use utils::{room_id, INITIAL_EVENTS};
fn shuffle(list: &mut [EventId]) { fn shuffle(list: &mut [EventId]) {
use rand::Rng; use rand::Rng;
@ -21,7 +21,6 @@ fn shuffle(list: &mut [EventId]) {
fn test_event_sort() { fn test_event_sort() {
let mut events = INITIAL_EVENTS(); let mut events = INITIAL_EVENTS();
let store = TestStore(events.clone());
let event_map = events let event_map = events
.values() .values()
@ -43,7 +42,6 @@ fn test_event_sort() {
&room_id(), &room_id(),
&power_events, &power_events,
&mut events, &mut events,
&store,
&auth_chain, &auth_chain,
); );
@ -55,7 +53,6 @@ fn test_event_sort() {
&sorted_power_events, &sorted_power_events,
&BTreeMap::new(), // unconflicted events &BTreeMap::new(), // unconflicted events
&mut events, &mut events,
&store,
) )
.expect("iterative auth check failed on resolved events"); .expect("iterative auth check failed on resolved events");
@ -71,7 +68,6 @@ fn test_event_sort() {
&events_to_sort, &events_to_sort,
power_level, power_level,
&mut events, &mut events,
&store,
); );
assert_eq!( assert_eq!(

View File

@ -7,7 +7,7 @@ use ruma::{
identifiers::{EventId, RoomVersionId}, identifiers::{EventId, RoomVersionId},
}; };
use serde_json::json; use serde_json::json;
use state_res::{StateMap, StateResolution}; use state_res::{EventMap, StateMap, StateResolution, StateStore};
mod utils; mod utils;
use utils::{ use utils::{
@ -36,46 +36,6 @@ fn ban_with_auth_chains() {
); );
} }
// Sanity check that the store is able to fetch auth chain and such
#[test]
fn base_with_auth_chains() {
let store = TestStore(INITIAL_EVENTS());
let mut ev_map = state_res::EventMap::default();
let resolved: BTreeMap<_, EventId> = match StateResolution::resolve(
&room_id(),
&RoomVersionId::Version6,
&[],
&mut ev_map,
&store,
) {
Ok(state) => state,
Err(e) => panic!("{}", e),
};
let resolved = resolved
.values()
.cloned()
.chain(INITIAL_EVENTS().values().map(|e| e.event_id().clone()))
.collect::<Vec<_>>();
let expected = vec![
"$CREATE:foo",
"$IJR:foo",
"$IPOWER:foo",
"$IMA:foo",
"$IMB:foo",
"$IMC:foo",
"START",
"END",
];
for id in expected.iter().map(|i| event_id(i)) {
// make sure our resolved events are equal to the expected list
assert!(resolved.iter().any(|eid| eid == &id), "{}", id)
}
assert_eq!(expected.len(), resolved.len())
}
#[test] #[test]
fn ban_with_auth_chains2() { fn ban_with_auth_chains2() {
let init = INITIAL_EVENTS(); let init = INITIAL_EVENTS();
@ -111,13 +71,21 @@ fn ban_with_auth_chains2() {
.map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone())) .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let mut ev_map = state_res::EventMap::default(); let mut ev_map: EventMap<Arc<StateEvent>> = store.0.clone();
let resolved: StateMap<EventId> = match StateResolution::resolve( let state_sets = vec![state_set_a, state_set_b];
let resolved = match StateResolution::resolve::<StateEvent>(
&room_id(), &room_id(),
&RoomVersionId::Version6, &RoomVersionId::Version2,
&[state_set_a, state_set_b], &state_sets,
state_sets
.iter()
.map(|map| {
store
.auth_event_ids(&room_id(), &map.values().cloned().collect::<Vec<_>>())
.unwrap()
})
.collect(),
&mut ev_map, &mut ev_map,
&store,
) { ) {
Ok(state) => state, Ok(state) => state,
Err(e) => panic!("{}", e), Err(e) => panic!("{}", e),

View File

@ -6,7 +6,7 @@ use ruma::{
identifiers::{EventId, RoomVersionId}, identifiers::{EventId, RoomVersionId},
}; };
use serde_json::json; use serde_json::json;
use state_res::{StateMap, StateResolution}; use state_res::{StateMap, StateResolution, StateStore};
use tracing_subscriber as tracer; use tracing_subscriber as tracer;
mod utils; mod utils;
@ -265,13 +265,21 @@ fn test_event_map_none() {
// build up the DAG // build up the DAG
let (state_at_bob, state_at_charlie, expected) = store.set_up(); let (state_at_bob, state_at_charlie, expected) = store.set_up();
let mut ev_map = state_res::EventMap::default(); let mut ev_map: state_res::EventMap<Arc<StateEvent>> = store.0.clone();
let resolved = match StateResolution::resolve( let state_sets = vec![state_at_bob, state_at_charlie];
let resolved = match StateResolution::resolve::<StateEvent>(
&room_id(), &room_id(),
&RoomVersionId::Version2, &RoomVersionId::Version2,
&[state_at_bob, state_at_charlie], &state_sets,
state_sets
.iter()
.map(|map| {
store
.auth_event_ids(&room_id(), &map.values().cloned().collect::<Vec<_>>())
.unwrap()
})
.collect(),
&mut ev_map, &mut ev_map,
&store,
) { ) {
Ok(state) => state, Ok(state) => state,
Err(e) => panic!("{}", e), Err(e) => panic!("{}", e),

View File

@ -114,8 +114,15 @@ pub fn do_check(
&room_id(), &room_id(),
&RoomVersionId::Version6, &RoomVersionId::Version6,
&state_sets, &state_sets,
state_sets
.iter()
.map(|map| {
store
.auth_event_ids(&room_id(), &map.values().cloned().collect::<Vec<_>>())
.unwrap()
})
.collect(),
&mut event_map, &mut event_map,
&store,
); );
match resolved { match resolved {
Ok(state) => state, Ok(state) => state,
@ -565,7 +572,9 @@ pub mod event {
fn hashes(&self) -> &EventHash { fn hashes(&self) -> &EventHash {
self.hashes() self.hashes()
} }
fn signatures(&self) -> BTreeMap<Box<ServerName>, BTreeMap<ruma::ServerSigningKeyId, String>> { fn signatures(
&self,
) -> BTreeMap<Box<ServerName>, BTreeMap<ruma::ServerSigningKeyId, String>> {
self.signatures() self.signatures()
} }
fn unsigned(&self) -> &BTreeMap<String, JsonValue> { fn unsigned(&self) -> &BTreeMap<String, JsonValue> {
@ -678,7 +687,10 @@ pub mod event {
} }
impl StateEvent { impl StateEvent {
pub fn from_id_value(id: EventId, json: serde_json::Value) -> Result<Self, serde_json::Error> { pub fn from_id_value(
id: EventId,
json: serde_json::Value,
) -> Result<Self, serde_json::Error> {
Ok(Self::Full( Ok(Self::Full(
id, id,
Pdu::RoomV3Pdu(serde_json::from_value(json)?), Pdu::RoomV3Pdu(serde_json::from_value(json)?),
@ -806,7 +818,9 @@ pub mod event {
pub fn prev_event_ids(&self) -> Vec<EventId> { pub fn prev_event_ids(&self) -> Vec<EventId> {
match self { match self {
Self::Full(_, ev) => match ev { Self::Full(_, ev) => match ev {
Pdu::RoomV1Pdu(ev) => ev.prev_events.iter().map(|(id, _)| id).cloned().collect(), Pdu::RoomV1Pdu(ev) => {
ev.prev_events.iter().map(|(id, _)| id).cloned().collect()
}
Pdu::RoomV3Pdu(ev) => ev.prev_events.clone(), Pdu::RoomV3Pdu(ev) => ev.prev_events.clone(),
}, },
} }
@ -815,7 +829,9 @@ pub mod event {
pub fn auth_events(&self) -> Vec<EventId> { pub fn auth_events(&self) -> Vec<EventId> {
match self { match self {
Self::Full(_, ev) => match ev { Self::Full(_, ev) => match ev {
Pdu::RoomV1Pdu(ev) => ev.auth_events.iter().map(|(id, _)| id).cloned().collect(), Pdu::RoomV1Pdu(ev) => {
ev.auth_events.iter().map(|(id, _)| id).cloned().collect()
}
Pdu::RoomV3Pdu(ev) => ev.auth_events.to_vec(), Pdu::RoomV3Pdu(ev) => ev.auth_events.to_vec(),
}, },
} }
@ -936,5 +952,4 @@ pub mod event {
) )
} }
} }
}
}