diff --git a/Cargo.toml b/Cargo.toml index 90085ecd..6aef52eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,19 +14,25 @@ repository = "https://github.com/ruma/state-res" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] itertools = "0.9.0" -js_int = "0.1.8" -serde = { version = "1.0.114", features = ["derive"] } -serde_json = "1.0.56" -tracing = "0.1.16" +js_int = "0.1.9" +serde = { version = "1.0.115", features = ["derive"] } +serde_json = "1.0.57" +tracing = "0.1.19" maplit = "1.0.2" thiserror = "1.0.20" -tracing-subscriber = "0.2.8" +tracing-subscriber = "0.2.11" + +# [dependencies.ruma] +# git = "https://github.com/ruma/ruma" +# rev = "d5d2d1d893fa12d27960e4c58d6c09b215d06e95" +# features = ["client-api", "federation-api", "appservice-api"] [dependencies.ruma] -git = "https://github.com/ruma/ruma" -rev = "d5d2d1d893fa12d27960e4c58d6c09b215d06e95" +git = "https://github.com/timokoesters/ruma" +branch = "timo-fixes" features = ["client-api", "federation-api", "appservice-api"] + [dev-dependencies] criterion = "0.3.3" diff --git a/README.md b/README.md index f751df51..decc9685 100644 --- a/README.md +++ b/README.md @@ -47,4 +47,4 @@ trait StateStore { The `StateStore` trait is an abstraction around what ever database your server (or maybe even client) uses to store __P__[]()ersistant __D__[]()ata __U__[]()nits. -We use `ruma`s types when deserializing any PDU or it's contents which helps avoid a lot of type checking logic [synapse](https://github.com/matrix-org/synapse) must do while authenticating event chains. \ No newline at end of file +We use `ruma`s types when deserializing any PDU or it's contents which helps avoid a lot of type checking logic [synapse](https://github.com/matrix-org/synapse) must do while authenticating event chains. diff --git a/benches/state_res_bench.rs b/benches/state_res_bench.rs index 67e48450..7f501d47 100644 --- a/benches/state_res_bench.rs +++ b/benches/state_res_bench.rs @@ -40,8 +40,8 @@ fn lexico_topo_sort(c: &mut Criterion) { b.iter(|| { let resolver = StateResolution::default(); - let _ = resolver - .lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone()))); + let _ = + resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, id.clone())); }) }); } @@ -92,7 +92,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().unwrap().clone())) + .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id())) .collect::>(); let state_set_b = [ @@ -105,7 +105,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().unwrap().clone())) + .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id())) .collect::>(); b.iter(|| { @@ -142,7 +142,7 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_events(&self, events: &[EventId]) -> Result, String> { + fn get_events(&self, room_id: &RoomId, events: &[EventId]) -> Result, String> { Ok(self .0 .borrow() @@ -153,7 +153,7 @@ impl StateStore for TestStore { .collect()) } - fn get_event(&self, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) @@ -178,7 +178,7 @@ impl StateStore for TestStore { result.push(ev_id.clone()); - let event = self.get_event(&ev_id).unwrap(); + let event = self.get_event(room_id, &ev_id).unwrap(); stack.extend(event.auth_events()); } @@ -232,7 +232,7 @@ impl TestStore { &[], &[], ); - let cre = create_event.event_id().unwrap().clone(); + let cre = create_event.event_id(); self.0 .borrow_mut() .insert(cre.clone(), create_event.clone()); @@ -248,7 +248,7 @@ impl TestStore { ); self.0 .borrow_mut() - .insert(alice_mem.event_id().unwrap().clone(), alice_mem.clone()); + .insert(alice_mem.event_id(), alice_mem.clone()); let join_rules = to_pdu_event( "IJR", @@ -256,12 +256,12 @@ impl TestStore { EventType::RoomJoinRules, Some(""), json!({ "join_rule": JoinRule::Public }), - &[cre.clone(), alice_mem.event_id().unwrap().clone()], - &[alice_mem.event_id().unwrap().clone()], + &[cre.clone(), alice_mem.event_id()], + &[alice_mem.event_id()], ); self.0 .borrow_mut() - .insert(join_rules.event_id().unwrap().clone(), join_rules.clone()); + .insert(join_rules.event_id(), join_rules.clone()); // Bob and Charlie join at the same time, so there is a fork // this will be represented in the state_sets when we resolve @@ -271,12 +271,12 @@ impl TestStore { EventType::RoomMember, Some(bob().to_string().as_str()), member_content_join(), - &[cre.clone(), join_rules.event_id().unwrap().clone()], - &[join_rules.event_id().unwrap().clone()], + &[cre.clone(), join_rules.event_id()], + &[join_rules.event_id()], ); self.0 .borrow_mut() - .insert(bob_mem.event_id().unwrap().clone(), bob_mem.clone()); + .insert(bob_mem.event_id(), bob_mem.clone()); let charlie_mem = to_pdu_event( "IMC", @@ -284,21 +284,21 @@ impl TestStore { EventType::RoomMember, Some(charlie().to_string().as_str()), member_content_join(), - &[cre, join_rules.event_id().unwrap().clone()], - &[join_rules.event_id().unwrap().clone()], + &[cre, join_rules.event_id()], + &[join_rules.event_id()], ); self.0 .borrow_mut() - .insert(charlie_mem.event_id().unwrap().clone(), charlie_mem.clone()); + .insert(charlie_mem.event_id(), charlie_mem.clone()); let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] .iter() - .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .map(|e| ((e.kind(), e.state_key()), e.event_id())) .collect::>(); let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] .iter() - .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .map(|e| ((e.kind(), e.state_key()), e.event_id())) .collect::>(); let expected = [ @@ -309,7 +309,7 @@ impl TestStore { &charlie_mem, ] .iter() - .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .map(|e| ((e.kind(), e.state_key()), e.event_id())) .collect::>(); (state_at_bob, state_at_charlie, expected) @@ -525,7 +525,7 @@ fn INITIAL_EVENTS() -> BTreeMap { ), ] .into_iter() - .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .map(|ev| (ev.event_id(), ev)) .collect() } @@ -571,6 +571,6 @@ fn BAN_STATE_SET() -> BTreeMap { ), ] .into_iter() - .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .map(|ev| (ev.event_id(), ev)) .collect() } diff --git a/src/event_auth.rs b/src/event_auth.rs index af1e68e4..747e4289 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -68,10 +68,7 @@ pub fn auth_check( auth_events: StateMap, do_sig_check: bool, ) -> Option { - tracing::info!( - "auth_check beginning for {}", - event.event_id().unwrap().as_str() - ); + tracing::info!("auth_check beginning for {}", event.event_id().as_str()); // don't let power from other rooms be used for auth_event in auth_events.values() { @@ -455,7 +452,7 @@ fn can_send_event(event: &StateEvent, auth_events: &StateMap) -> Opt tracing::debug!( "{} snd {} usr {}", - event.event_id().unwrap().to_string(), + event.event_id().to_string(), send_level, user_level ); @@ -630,7 +627,10 @@ fn check_redaction( } if let RoomVersionId::Version1 = room_version { - if redaction_event.event_id() == redaction_event.redacts() { + // are the redacter and redactee in the same domain + if Some(redaction_event.event_id().server_name()) + == redaction_event.redacts().map(|id| id.server_name()) + { return Some(RedactAllowed::OwnEvent); } } else { diff --git a/src/lib.rs b/src/lib.rs index e3085893..2a1e1c6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,6 +91,7 @@ impl StateResolution { // gather missing events for the event_map let events = store .get_events( + room_id, &all_conflicted .iter() // we only want the events we don't know about yet @@ -101,11 +102,7 @@ impl StateResolution { .unwrap(); // update event_map to include the fetched events - event_map.extend( - events - .into_iter() - .flat_map(|ev| Some((ev.event_id()?.clone(), ev))), - ); + event_map.extend(events.into_iter().map(|ev| (ev.event_id(), ev))); // at this point our event_map == store there should be no missing events tracing::debug!("event map size: {}", event_map.len()); @@ -114,10 +111,7 @@ impl StateResolution { if event.room_id() != Some(room_id) { return Err(Error::TempString(format!( "resolving event {} in room {}, when correct room is {}", - event - .event_id() - .map(|id| id.as_str()) - .unwrap_or("`unknown`"), + event.event_id(), event.room_id().map(|id| id.as_str()).unwrap_or("`unknown`"), room_id.as_str() ))); @@ -307,7 +301,7 @@ impl StateResolution { pub fn reverse_topological_power_sort( &self, room_id: &RoomId, - power_events: &[EventId], + events_to_sort: &[EventId], event_map: &mut EventMap, store: &dyn StateStore, auth_diff: &[EventId], @@ -315,7 +309,7 @@ impl StateResolution { tracing::debug!("reverse topological sort of power events"); let mut graph = BTreeMap::new(); - for (idx, event_id) in power_events.iter().enumerate() { + for (idx, event_id) in events_to_sort.iter().enumerate() { self.add_event_and_auth_chain_to_graph( room_id, &mut graph, event_id, event_map, store, auth_diff, ); @@ -347,10 +341,7 @@ impl StateResolution { let ev = event_map.get(event_id).unwrap(); let pl = event_to_pl.get(event_id).unwrap(); - tracing::debug!( - "{:?}", - (-*pl, *ev.origin_server_ts(), ev.event_id().cloned()) - ); + tracing::debug!("{:?}", (-*pl, *ev.origin_server_ts(), ev.event_id())); // count_0.sort_by(|(x, _), (y, _)| { // x.power_level @@ -361,7 +352,7 @@ impl StateResolution { // This return value is the key used for sorting events, // events are then sorted by power level, time, // and lexically by event_id. - (-*pl, *ev.origin_server_ts(), ev.event_id().cloned()) + (-*pl, *ev.origin_server_ts(), ev.event_id()) }) } @@ -374,7 +365,7 @@ impl StateResolution { key_fn: F, ) -> Vec where - F: Fn(&EventId) -> (i64, SystemTime, Option), + F: Fn(&EventId) -> (i64, SystemTime, EventId), { tracing::info!("starting lexicographical topological sort"); // NOTE: an event that has no incoming edges happened most recently, @@ -458,8 +449,8 @@ impl StateResolution { } if pl.is_none() { - for aid in store.get_event(event_id).unwrap().auth_events() { - if let Ok(aev) = store.get_event(&aid) { + for aid in store.get_event(room_id, event_id).unwrap().auth_events() { + if let Ok(aev) = store.get_event(room_id, &aid) { if aev.is_type_and_key(EventType::RoomCreate, "") { if let Ok(content) = aev .deserialize_content::() @@ -541,7 +532,8 @@ impl StateResolution { } } - tracing::debug!("event to check {:?}", event.event_id().unwrap().to_string()); + tracing::debug!("event to check {:?}", event.event_id().to_string()); + if event_auth::auth_check(room_version, &event, auth_events, false) .ok_or("Auth check failed due to deserialization most likely".to_string()) .map_err(Error::TempString)? @@ -657,14 +649,10 @@ impl StateResolution { store: &dyn StateStore, ) -> usize { while let Some(sort_ev) = event { - tracing::debug!( - "mainline event_id {}", - sort_ev.event_id().unwrap().to_string() - ); - if let Some(id) = sort_ev.event_id() { - if let Some(depth) = mainline_map.get(id) { - return *depth; - } + tracing::debug!("mainline event_id {}", sort_ev.event_id().to_string()); + let id = sort_ev.event_id(); + if let Some(depth) = mainline_map.get(&id) { + return *depth; } let auth_events = sort_ev.auth_events(); @@ -717,14 +705,14 @@ impl StateResolution { /// TODO update self if we go that route just as event_map will be updated fn _get_event( &self, - _room_id: &RoomId, + room_id: &RoomId, ev_id: &EventId, event_map: &mut EventMap, store: &dyn StateStore, ) -> Option { // TODO can we cut down on the clones? if !event_map.contains_key(ev_id) { - let event = store.get_event(ev_id).ok()?; + let event = store.get_event(room_id, ev_id).ok()?; event_map.insert(ev_id.clone(), event.clone()); Some(event) } else { diff --git a/src/state_event.rs b/src/state_event.rs index abcd2fbd..34465d14 100644 --- a/src/state_event.rs +++ b/src/state_event.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap, convert::TryFrom}; use ruma::{ events::{ @@ -100,14 +100,27 @@ impl StateEvent { }, } } - pub fn event_id(&self) -> Option<&EventId> { - println!("{:?}", self); + pub fn event_id(&self) -> EventId { match self { Self::Full(ev) => match ev { - Pdu::RoomV1Pdu(ev) => Some(&ev.event_id), - Pdu::RoomV3Pdu(_) => None, + Pdu::RoomV1Pdu(ev) => ev.event_id.clone(), + Pdu::RoomV3Pdu(_) => EventId::try_from(&*format!( + "${}", + ruma::signatures::reference_hash( + &serde_json::to_value(&ev).expect("event is valid, we just created it") + ) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"), }, - Self::Sync(_) => None, + Self::Sync(ev) => EventId::try_from(&*format!( + "${}", + ruma::signatures::reference_hash( + &serde_json::to_value(&ev).expect("event is valid, we just created it") + ) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"), } } @@ -214,6 +227,21 @@ impl StateEvent { } } + pub fn unsigned(&self) -> &BTreeMap { + // CONFIRM: The only way this would fail is if we got bad json, it should fail in ruma + // before it fails here. + match self { + Self::Full(ev) => match ev { + Pdu::RoomV1Pdu(ev) => &ev.unsigned, + Pdu::RoomV3Pdu(ev) => &ev.unsigned, + }, + Self::Sync(ev) => match ev { + PduStub::RoomV1PduStub(ev) => &ev.unsigned, + PduStub::RoomV3PduStub(ev) => &ev.unsigned, + }, + } + } + pub fn signatures(&self) -> BTreeMap, BTreeMap> { match self { Self::Full(ev) => match ev { diff --git a/src/state_store.rs b/src/state_store.rs index 7881883a..80621bf0 100644 --- a/src/state_store.rs +++ b/src/state_store.rs @@ -1,25 +1,83 @@ +use std::collections::BTreeSet; + use ruma::identifiers::{EventId, RoomId}; use crate::StateEvent; pub trait StateStore { /// Return a single event based on the EventId. - fn get_event(&self, event_id: &EventId) -> Result; + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result; /// Returns the events that correspond to the `event_ids` sorted in the same order. - fn get_events(&self, event_ids: &[EventId]) -> Result, String>; + fn get_events( + &self, + room_id: &RoomId, + event_ids: &[EventId], + ) -> Result, String> { + let mut events = vec![]; + for id in event_ids { + events.push(self.get_event(room_id, id)?); + } + Ok(events) + } /// Returns a Vec of the related auth events to the given `event`. fn auth_event_ids( &self, room_id: &RoomId, event_ids: &[EventId], - ) -> Result, String>; + ) -> Result, String> { + let mut result = vec![]; + let mut stack = event_ids.to_vec(); + + // DFS for auth event chain + while !stack.is_empty() { + let ev_id = stack.pop().unwrap(); + if result.contains(&ev_id) { + continue; + } + + result.push(ev_id.clone()); + + let event = self.get_event(room_id, &ev_id).unwrap(); + + stack.extend(event.auth_events()); + } + + Ok(result) + } /// Returns a Vec representing the difference in auth chains of the given `events`. fn auth_chain_diff( &self, room_id: &RoomId, - event_id: Vec>, - ) -> Result, String>; + event_ids: Vec>, + ) -> Result, String> { + let mut chains = vec![]; + for ids in event_ids { + // TODO state store `auth_event_ids` returns self in the event ids list + // when an event returns `auth_event_ids` self is not contained + let chain = self + .auth_event_ids(room_id, &ids)? + .into_iter() + .collect::>(); + chains.push(chain); + } + + if let Some(chain) = chains.first() { + let rest = chains.iter().skip(1).flatten().cloned().collect(); + let common = chain.intersection(&rest).collect::>(); + + Ok(chains + .iter() + .flatten() + .filter(|id| !common.contains(&id)) + .cloned() + .collect::>() + .into_iter() + .collect()) + } else { + Ok(vec![]) + } + } } diff --git a/tests/res_with_auth_ids.rs b/tests/res_with_auth_ids.rs index 71bdfdd6..94e35ba3 100644 --- a/tests/res_with_auth_ids.rs +++ b/tests/res_with_auth_ids.rs @@ -41,7 +41,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: INITIAL_EVENTS() .values() .chain(events) - .map(|ev| (ev.event_id().unwrap().clone(), ev.clone())) + .map(|ev| (ev.event_id(), ev.clone())) .collect(), )); @@ -53,8 +53,8 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: // create the DB of events that led up to this point // TODO maybe clean up some of these clones it is just tests but... for ev in INITIAL_EVENTS().values().chain(events) { - graph.insert(ev.event_id().unwrap().clone(), vec![]); - fake_event_map.insert(ev.event_id().unwrap().clone(), ev.clone()); + graph.insert(ev.event_id().clone(), vec![]); + fake_event_map.insert(ev.event_id().clone(), ev.clone()); } for pair in INITIAL_EDGES().windows(2) { @@ -78,11 +78,10 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: // resolve the current state and add it to the state_at_event map then continue // on in "time" - for node in - resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone()))) + for node in resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, id.clone())) { let fake_event = fake_event_map.get(&node).unwrap(); - let event_id = fake_event.event_id().unwrap(); + let event_id = fake_event.event_id(); let prev_events = graph.get(&node).unwrap(); @@ -152,9 +151,9 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: // TODO The event is just remade, adding the auth_events and prev_events here // UPDATE: the `to_pdu_event` was split into `init` and the fn below, could be better let e = fake_event; - let ev_id = e.event_id().unwrap(); + let ev_id = e.event_id(); let event = to_pdu_event( - &e.event_id().unwrap().to_string(), + &e.event_id().to_string(), e.sender().clone(), e.kind(), e.state_key().as_deref(), @@ -168,7 +167,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: // TODO // TODO we need to convert the `StateResolution::resolve` to use the event_map // because the user of this crate cannot update their DB's state. - *store.0.borrow_mut().get_mut(ev_id).unwrap() = event.clone(); + *store.0.borrow_mut().get_mut(&ev_id).unwrap() = event.clone(); state_at_event.insert(node, state_after); event_map.insert(event_id.clone(), event); @@ -206,7 +205,7 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_events(&self, events: &[EventId]) -> Result, String> { + fn get_events(&self, room_id: &RoomId, events: &[EventId]) -> Result, String> { Ok(self .0 .borrow() @@ -217,7 +216,7 @@ impl StateStore for TestStore { .collect()) } - fn get_event(&self, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) @@ -242,7 +241,7 @@ impl StateStore for TestStore { result.push(ev_id.clone()); - let event = self.get_event(&ev_id).unwrap(); + let event = self.get_event(room_id, &ev_id).unwrap(); stack.extend(event.auth_events()); } @@ -504,7 +503,7 @@ fn INITIAL_EVENTS() -> BTreeMap { ), ] .into_iter() - .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .map(|ev| (ev.event_id(), ev)) .collect() } @@ -558,7 +557,7 @@ fn BAN_STATE_SET() -> BTreeMap { ), ] .into_iter() - .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .map(|ev| (ev.event_id(), ev)) .collect() } @@ -599,11 +598,7 @@ fn base_with_auth_chains() { let resolved = resolved .values() .cloned() - .chain( - INITIAL_EVENTS() - .values() - .map(|e| e.event_id().unwrap().clone()), - ) + .chain(INITIAL_EVENTS().values().map(|e| e.event_id())) .collect::>(); let expected = vec![ @@ -644,7 +639,7 @@ fn ban_with_auth_chains2() { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().unwrap().clone())) + .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id())) .collect::>(); let state_set_b = [ @@ -657,7 +652,7 @@ fn ban_with_auth_chains2() { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().unwrap().clone())) + .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id())) .collect::>(); let resolved: StateMap = match resolver.resolve( @@ -725,7 +720,7 @@ fn JOIN_RULE() -> BTreeMap { ), ] .into_iter() - .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .map(|ev| (ev.event_id(), ev)) .collect() } diff --git a/tests/state_res.rs b/tests/state_res.rs index 80bf3500..399133d6 100644 --- a/tests/state_res.rs +++ b/tests/state_res.rs @@ -273,7 +273,7 @@ fn INITIAL_EVENTS() -> BTreeMap { to_init_pdu_event("END", zera(), EventType::RoomMessage, None, json!({})), ] .into_iter() - .map(|ev| (ev.event_id().unwrap().clone(), ev)) + .map(|ev| (ev.event_id(), ev)) .collect() } @@ -301,7 +301,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: INITIAL_EVENTS() .values() .chain(events) - .map(|ev| (ev.event_id().unwrap().clone(), ev.clone())) + .map(|ev| (ev.event_id(), ev.clone())) .collect(), )); @@ -313,8 +313,8 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: // create the DB of events that led up to this point // TODO maybe clean up some of these clones it is just tests but... for ev in INITIAL_EVENTS().values().chain(events) { - graph.insert(ev.event_id().unwrap().clone(), vec![]); - fake_event_map.insert(ev.event_id().unwrap().clone(), ev.clone()); + graph.insert(ev.event_id(), vec![]); + fake_event_map.insert(ev.event_id(), ev.clone()); } for pair in INITIAL_EDGES().windows(2) { @@ -338,11 +338,10 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: // resolve the current state and add it to the state_at_event map then continue // on in "time" - for node in - resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone()))) + for node in resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, id.clone())) { let fake_event = fake_event_map.get(&node).unwrap(); - let event_id = fake_event.event_id().unwrap(); + let event_id = fake_event.event_id(); let prev_events = graph.get(&node).unwrap(); @@ -412,9 +411,9 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: // TODO The event is just remade, adding the auth_events and prev_events here // UPDATE: the `to_pdu_event` was split into `init` and the fn below, could be better let e = fake_event; - let ev_id = e.event_id().unwrap(); + let ev_id = e.event_id(); let event = to_pdu_event( - &e.event_id().unwrap().to_string(), + &e.event_id().to_string(), e.sender().clone(), e.kind(), e.state_key().as_deref(), @@ -428,7 +427,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: // TODO // TODO we need to convert the `StateResolution::resolve` to use the event_map // because the user of this crate cannot update their DB's state. - *store.0.borrow_mut().get_mut(ev_id).unwrap() = event.clone(); + *store.0.borrow_mut().get_mut(&ev_id).unwrap() = event.clone(); state_at_event.insert(node, state_after); event_map.insert(event_id.clone(), event); @@ -742,8 +741,7 @@ fn test_lexicographical_sort() { event_id("p") => vec![event_id("o")], }; - let res = - resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, Some(id.clone()))); + let res = resolver.lexicographical_topological_sort(&graph, |id| (0, UNIX_EPOCH, id.clone())); assert_eq!( vec!["o", "l", "n", "m", "p"], @@ -763,7 +761,7 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_events(&self, events: &[EventId]) -> Result, String> { + fn get_events(&self, room_id: &RoomId, events: &[EventId]) -> Result, String> { Ok(self .0 .borrow() @@ -774,7 +772,7 @@ impl StateStore for TestStore { .collect()) } - fn get_event(&self, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) @@ -799,7 +797,7 @@ impl StateStore for TestStore { result.push(ev_id.clone()); - let event = self.get_event(&ev_id).unwrap(); + let event = self.get_event(room_id, &ev_id).unwrap(); stack.extend(event.auth_events()); } @@ -860,7 +858,7 @@ impl TestStore { &[], &[], ); - let cre = create_event.event_id().unwrap().clone(); + let cre = create_event.event_id(); self.0 .borrow_mut() .insert(cre.clone(), create_event.clone()); @@ -876,7 +874,7 @@ impl TestStore { ); self.0 .borrow_mut() - .insert(alice_mem.event_id().unwrap().clone(), alice_mem.clone()); + .insert(alice_mem.event_id(), alice_mem.clone()); let join_rules = to_pdu_event( "IJR", @@ -884,12 +882,12 @@ impl TestStore { EventType::RoomJoinRules, Some(""), json!({ "join_rule": JoinRule::Public }), - &[cre.clone(), alice_mem.event_id().unwrap().clone()], - &[alice_mem.event_id().unwrap().clone()], + &[cre.clone(), alice_mem.event_id()], + &[alice_mem.event_id()], ); self.0 .borrow_mut() - .insert(join_rules.event_id().unwrap().clone(), join_rules.clone()); + .insert(join_rules.event_id(), join_rules.clone()); // Bob and Charlie join at the same time, so there is a fork // this will be represented in the state_sets when we resolve @@ -899,12 +897,12 @@ impl TestStore { EventType::RoomMember, Some(bob().to_string().as_str()), member_content_join(), - &[cre.clone(), join_rules.event_id().unwrap().clone()], - &[join_rules.event_id().unwrap().clone()], + &[cre.clone(), join_rules.event_id()], + &[join_rules.event_id()], ); self.0 .borrow_mut() - .insert(bob_mem.event_id().unwrap().clone(), bob_mem.clone()); + .insert(bob_mem.event_id(), bob_mem.clone()); let charlie_mem = to_pdu_event( "IMC", @@ -912,21 +910,21 @@ impl TestStore { EventType::RoomMember, Some(charlie().to_string().as_str()), member_content_join(), - &[cre, join_rules.event_id().unwrap().clone()], - &[join_rules.event_id().unwrap().clone()], + &[cre, join_rules.event_id()], + &[join_rules.event_id()], ); self.0 .borrow_mut() - .insert(charlie_mem.event_id().unwrap().clone(), charlie_mem.clone()); + .insert(charlie_mem.event_id(), charlie_mem.clone()); let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] .iter() - .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .map(|e| ((e.kind(), e.state_key()), e.event_id())) .collect::>(); let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] .iter() - .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .map(|e| ((e.kind(), e.state_key()), e.event_id())) .collect::>(); let expected = [ @@ -937,7 +935,7 @@ impl TestStore { &charlie_mem, ] .iter() - .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .map(|e| ((e.kind(), e.state_key()), e.event_id())) .collect::>(); (state_at_bob, state_at_charlie, expected)