From b3cea6b998757f98bb27523ba8033339a4d7f8c6 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 9 Apr 2022 00:55:14 +0200 Subject: [PATCH] state-res: Use StateEventType over RoomEventType where applicable --- .../ruma-state-res/benches/state_res_bench.rs | 38 +++--- crates/ruma-state-res/src/event_auth.rs | 116 ++++++++---------- crates/ruma-state-res/src/lib.rs | 68 ++++++---- crates/ruma-state-res/src/test_utils.rs | 27 ++-- 4 files changed, 119 insertions(+), 130 deletions(-) diff --git a/crates/ruma-state-res/benches/state_res_bench.rs b/crates/ruma-state-res/benches/state_res_bench.rs index 86cf0240..47110b0a 100644 --- a/crates/ruma-state-res/benches/state_res_bench.rs +++ b/crates/ruma-state-res/benches/state_res_bench.rs @@ -28,7 +28,7 @@ use ruma_common::{ join_rules::{JoinRule, RoomJoinRulesEventContent}, member::{MembershipState, RoomMemberEventContent}, }, - RoomEventType, + RoomEventType, StateEventType, }, room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, UserId, }; @@ -104,10 +104,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) { ] .iter() .map(|ev| { - ( - (ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()), - ev.event_id().to_owned(), - ) + (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id().to_owned()) }) .collect::>(); @@ -122,10 +119,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) { ] .iter() .map(|ev| { - ( - (ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()), - ev.event_id().to_owned(), - ) + (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id().to_owned()) }) .collect::>(); @@ -298,30 +292,21 @@ impl TestStore { let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] .iter() .map(|e| { - ( - (e.event_type().to_owned(), e.state_key().unwrap().to_owned()), - e.event_id().to_owned(), - ) + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) }) .collect::>(); let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] .iter() .map(|e| { - ( - (e.event_type().to_owned(), e.state_key().unwrap().to_owned()), - e.event_id().to_owned(), - ) + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) }) .collect::>(); let expected = [&create_event, &alice_mem, &join_rules, &bob_mem, &charlie_mem] .iter() .map(|e| { - ( - (e.event_type().to_owned(), e.state_key().unwrap().to_owned()), - e.event_id().to_owned(), - ) + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) }) .collect::>(); @@ -532,6 +517,17 @@ fn BAN_STATE_SET() -> HashMap, Arc> { .collect() } +/// Convenience trait for adding event type plus state key to state maps. +trait EventTypeExt { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String); +} + +impl EventTypeExt for &RoomEventType { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { + (self.to_string().into(), state_key.into()) + } +} + mod event { use ruma_common::{ events::{pdu::Pdu, RoomEventType}, diff --git a/crates/ruma-state-res/src/event_auth.rs b/crates/ruma-state-res/src/event_auth.rs index 68339975..475e717a 100644 --- a/crates/ruma-state-res/src/event_auth.rs +++ b/crates/ruma-state-res/src/event_auth.rs @@ -10,7 +10,7 @@ use ruma_common::{ power_levels::RoomPowerLevelsEventContent, third_party_invite::RoomThirdPartyInviteEventContent, }, - RoomEventType, + RoomEventType, StateEventType, }, serde::{Base64, Raw}, RoomVersionId, UserId, @@ -49,15 +49,15 @@ pub fn auth_types_for_event( sender: &UserId, state_key: Option<&str>, content: &RawJsonValue, -) -> serde_json::Result> { +) -> serde_json::Result> { if kind == &RoomEventType::RoomCreate { return Ok(vec![]); } let mut auth_types = vec![ - (RoomEventType::RoomPowerLevels, "".to_owned()), - (RoomEventType::RoomMember, sender.to_string()), - (RoomEventType::RoomCreate, "".to_owned()), + (StateEventType::RoomPowerLevels, "".to_owned()), + (StateEventType::RoomMember, sender.to_string()), + (StateEventType::RoomCreate, "".to_owned()), ]; if kind == &RoomEventType::RoomMember { @@ -75,7 +75,7 @@ pub fn auth_types_for_event( if [MembershipState::Join, MembershipState::Invite, MembershipState::Knock] .contains(&membership) { - let key = (RoomEventType::RoomJoinRules, "".to_owned()); + let key = (StateEventType::RoomJoinRules, "".to_owned()); if !auth_types.contains(&key) { auth_types.push(key); } @@ -83,21 +83,21 @@ pub fn auth_types_for_event( if let Some(Ok(u)) = content.join_authorised_via_users_server.map(|m| m.deserialize()) { - let key = (RoomEventType::RoomMember, u.to_string()); + let key = (StateEventType::RoomMember, u.to_string()); if !auth_types.contains(&key) { auth_types.push(key); } } } - let key = (RoomEventType::RoomMember, state_key.to_owned()); + let key = (StateEventType::RoomMember, state_key.to_owned()); if !auth_types.contains(&key) { auth_types.push(key); } if membership == MembershipState::Invite { if let Some(Ok(t_id)) = content.third_party_invite.map(|t| t.deserialize()) { - let key = (RoomEventType::RoomThirdPartyInvite, t_id.signed.token); + let key = (StateEventType::RoomThirdPartyInvite, t_id.signed.token); if !auth_types.contains(&key) { auth_types.push(key); } @@ -974,7 +974,7 @@ mod tests { }, member::{MembershipState, RoomMemberEventContent}, }, - RoomEventType, + RoomEventType, StateEventType, }; use serde_json::value::to_raw_value as to_raw_json_value; @@ -984,7 +984,7 @@ mod tests { alice, charlie, ella, event_id, member_content_ban, member_content_join, room_id, to_pdu_event, PduEvent, INITIAL_EVENTS, INITIAL_EVENTS_CREATE_ROOM, }, - Event, RoomVersion, StateMap, + Event, EventTypeExt, RoomVersion, StateMap, }; #[test] @@ -995,9 +995,7 @@ mod tests { let auth_events = events .values() - .map(|ev| { - ((ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()), Arc::clone(ev)) - }) + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) .collect::>(); let requester = to_pdu_event( @@ -1017,16 +1015,16 @@ mod tests { assert!(valid_membership_change( &RoomVersion::V6, &target_user, - fetch_state(RoomEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.to_string()), &sender, - fetch_state(RoomEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.to_string()), &requester, None::, - fetch_state(RoomEventType::RoomPowerLevels, "".to_owned()), - fetch_state(RoomEventType::RoomJoinRules, "".to_owned()), + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), None, &MembershipState::Leave, - fetch_state(RoomEventType::RoomCreate, "".to_owned()).unwrap(), + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), ) .unwrap()); } @@ -1039,9 +1037,7 @@ mod tests { let auth_events = events .values() - .map(|ev| { - ((ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()), Arc::clone(ev)) - }) + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) .collect::>(); let requester = to_pdu_event( @@ -1061,16 +1057,16 @@ mod tests { assert!(!valid_membership_change( &RoomVersion::V6, &target_user, - fetch_state(RoomEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.to_string()), &sender, - fetch_state(RoomEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.to_string()), &requester, None::, - fetch_state(RoomEventType::RoomPowerLevels, "".to_owned()), - fetch_state(RoomEventType::RoomJoinRules, "".to_owned()), + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), None, &MembershipState::Leave, - fetch_state(RoomEventType::RoomCreate, "".to_owned()).unwrap(), + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), ) .unwrap()); } @@ -1083,9 +1079,7 @@ mod tests { let auth_events = events .values() - .map(|ev| { - ((ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()), Arc::clone(ev)) - }) + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) .collect::>(); let requester = to_pdu_event( @@ -1105,16 +1099,16 @@ mod tests { assert!(valid_membership_change( &RoomVersion::V6, &target_user, - fetch_state(RoomEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.to_string()), &sender, - fetch_state(RoomEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.to_string()), &requester, None::, - fetch_state(RoomEventType::RoomPowerLevels, "".to_owned()), - fetch_state(RoomEventType::RoomJoinRules, "".to_owned()), + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), None, &MembershipState::Leave, - fetch_state(RoomEventType::RoomCreate, "".to_owned()).unwrap(), + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), ) .unwrap()); } @@ -1127,9 +1121,7 @@ mod tests { let auth_events = events .values() - .map(|ev| { - ((ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()), Arc::clone(ev)) - }) + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) .collect::>(); let requester = to_pdu_event( @@ -1149,16 +1141,16 @@ mod tests { assert!(!valid_membership_change( &RoomVersion::V6, &target_user, - fetch_state(RoomEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.to_string()), &sender, - fetch_state(RoomEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.to_string()), &requester, None::, - fetch_state(RoomEventType::RoomPowerLevels, "".to_owned()), - fetch_state(RoomEventType::RoomJoinRules, "".to_owned()), + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), None, &MembershipState::Leave, - fetch_state(RoomEventType::RoomCreate, "".to_owned()).unwrap(), + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), ) .unwrap()); } @@ -1188,9 +1180,7 @@ mod tests { let auth_events = events .values() - .map(|ev| { - ((ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()), Arc::clone(ev)) - }) + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) .collect::>(); let requester = to_pdu_event( @@ -1210,32 +1200,32 @@ mod tests { assert!(valid_membership_change( &RoomVersion::V9, &target_user, - fetch_state(RoomEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.to_string()), &sender, - fetch_state(RoomEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.to_string()), &requester, None::, - fetch_state(RoomEventType::RoomPowerLevels, "".to_owned()), - fetch_state(RoomEventType::RoomJoinRules, "".to_owned()), + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), Some(&alice()), &MembershipState::Join, - fetch_state(RoomEventType::RoomCreate, "".to_owned()).unwrap(), + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), ) .unwrap()); assert!(!valid_membership_change( &RoomVersion::V9, &target_user, - fetch_state(RoomEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.to_string()), &sender, - fetch_state(RoomEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.to_string()), &requester, None::, - fetch_state(RoomEventType::RoomPowerLevels, "".to_owned()), - fetch_state(RoomEventType::RoomJoinRules, "".to_owned()), + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), Some(&ella()), &MembershipState::Leave, - fetch_state(RoomEventType::RoomCreate, "".to_owned()).unwrap(), + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), ) .unwrap()); } @@ -1257,9 +1247,7 @@ mod tests { let auth_events = events .values() - .map(|ev| { - ((ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()), Arc::clone(ev)) - }) + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) .collect::>(); let requester = to_pdu_event( @@ -1279,16 +1267,16 @@ mod tests { assert!(valid_membership_change( &RoomVersion::V7, &target_user, - fetch_state(RoomEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.to_string()), &sender, - fetch_state(RoomEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.to_string()), &requester, None::, - fetch_state(RoomEventType::RoomPowerLevels, "".to_owned()), - fetch_state(RoomEventType::RoomJoinRules, "".to_owned()), + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), None, &MembershipState::Leave, - fetch_state(RoomEventType::RoomCreate, "".to_owned()).unwrap(), + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), ) .unwrap()); } diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 5f0a5d1f..72162897 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -10,7 +10,7 @@ use js_int::{int, Int}; use ruma_common::{ events::{ room::member::{MembershipState, RoomMemberEventContent}, - RoomEventType, + RoomEventType, StateEventType, }, EventId, MilliSecondsSinceUnixEpoch, RoomVersionId, UserId, }; @@ -31,7 +31,7 @@ pub use room_version::RoomVersion; pub use state_event::Event; /// A mapping of event type and state_key to some value `T`, usually an `EventId`. -pub type StateMap = HashMap<(RoomEventType, String), T>; +pub type StateMap = HashMap<(StateEventType, String), T>; /// Resolve sets of state events as they come in. /// @@ -132,7 +132,7 @@ where trace!("{:?}", events_to_resolve); // This "epochs" power level event - let power_event = resolved_control.get(&(RoomEventType::RoomPowerLevels, "".into())); + let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, "".into())); debug!("power event: {:?}", power_event); @@ -418,20 +418,15 @@ fn iterative_auth_check( .state_key() .ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; - let mut auth_events = HashMap::new(); + let mut auth_events = StateMap::new(); for aid in event.auth_events() { if let Some(ev) = fetch_event(aid.borrow()) { // TODO synapse check "rejected_reason" which is most likely // related to soft-failing auth_events.insert( - ( - ev.event_type().to_owned(), - ev.state_key() - .ok_or_else(|| { - Error::InvalidPdu("State event had no state key".to_owned()) - })? - .to_owned(), - ), + ev.event_type().with_state_key(ev.state_key().ok_or_else(|| { + Error::InvalidPdu("State event had no state key".to_owned()) + })?), ev, ); } else { @@ -462,11 +457,10 @@ fn iterative_auth_check( }); if auth_check(room_version, &event, current_third_party, |ty, key| { - auth_events.get(&(ty.clone(), key.to_owned())) + auth_events.get(&ty.with_state_key(key)) })? { // add event to resolved state map - resolved_state - .insert((event.event_type().to_owned(), state_key.to_owned()), event_id.clone()); + resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone()); } else { // synapse passes here on AuthError. We do not add this event to resolved_state. warn!("event {} failed the authentication check", event_id); @@ -632,6 +626,32 @@ fn is_power_event(event: impl Event) -> bool { } } +/// Convenience trait for adding event type plus state key to state maps. +trait EventTypeExt { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String); +} + +impl EventTypeExt for StateEventType { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { + (self, state_key.into()) + } +} + +impl EventTypeExt for RoomEventType { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { + (self.to_string().into(), state_key.into()) + } +} + +impl EventTypeExt for &T +where + T: EventTypeExt + Clone, +{ + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { + self.to_owned().with_state_key(state_key) + } +} + #[cfg(test)] mod tests { use std::{ @@ -645,7 +665,7 @@ mod tests { use ruma_common::{ events::{ room::join_rules::{JoinRule, RoomJoinRulesEventContent}, - RoomEventType, + RoomEventType, StateEventType, }, EventId, MilliSecondsSinceUnixEpoch, RoomVersionId, }; @@ -659,7 +679,7 @@ mod tests { alice, bob, charlie, do_check, ella, event_id, member_content_ban, member_content_join, room_id, to_init_pdu_event, to_pdu_event, zara, PduEvent, TestStore, INITIAL_EVENTS, }, - Event, StateMap, + Event, EventTypeExt, StateMap, }; fn test_event_sort() { @@ -669,9 +689,7 @@ mod tests { let event_map = events .values() - .map(|ev| { - ((ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()), ev.clone()) - }) + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone())) .collect::>(); let auth_chain: HashSet> = HashSet::new(); @@ -702,7 +720,7 @@ mod tests { events_to_sort.shuffle(&mut rand::thread_rng()); let power_level = - resolved_power.get(&(RoomEventType::RoomPowerLevels, "".to_owned())).cloned(); + resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned(); let sorted_event_ids = crate::mainline_sort(&events_to_sort, power_level, |id| events.get(id).map(Arc::clone)) @@ -1131,9 +1149,7 @@ mod tests { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| { - ((ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()), ev.event_id.clone()) - }) + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone())) .collect::>(); let state_set_b = [ @@ -1146,9 +1162,7 @@ mod tests { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| { - ((ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()), ev.event_id.clone()) - }) + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone())) .collect::>(); let ev_map = store.0.clone(); diff --git a/crates/ruma-state-res/src/test_utils.rs b/crates/ruma-state-res/src/test_utils.rs index 2cb5b592..cbcc772c 100644 --- a/crates/ruma-state-res/src/test_utils.rs +++ b/crates/ruma-state-res/src/test_utils.rs @@ -27,7 +27,7 @@ use serde_json::{ }; use tracing::info; -use crate::{auth_types_for_event, Error, Event, Result, StateMap}; +use crate::{auth_types_for_event, Error, Event, EventTypeExt, Result, StateMap}; pub use event::PduEvent; @@ -130,9 +130,9 @@ pub fn do_check( let mut state_after = state_before.clone(); - let ty = fake_event.event_type().to_owned(); - let key = fake_event.state_key().unwrap().to_owned(); - state_after.insert((ty, key), event_id.to_owned()); + let ty = fake_event.event_type(); + let key = fake_event.state_key().unwrap(); + state_after.insert(ty.with_state_key(key), event_id.to_owned()); let auth_types = auth_types_for_event( fake_event.event_type(), @@ -181,7 +181,7 @@ pub fn do_check( ) }); - let key = (ev.event_type().to_owned(), ev.state_key().unwrap().to_owned()); + let key = ev.event_type().with_state_key(ev.state_key().unwrap()); expected_state.insert(key, node); } @@ -198,7 +198,7 @@ pub fn do_check( // Filter out the dummy messages events. // These act as points in time where there should be a known state to // test against. - && **k != (RoomEventType::RoomMessage, "dummy".to_owned()) + && **k != ("m.room.message".into(), "dummy".to_owned()) }) .map(|(k, v)| (k.clone(), v.clone())) .collect::>>(); @@ -310,30 +310,21 @@ impl TestStore { let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] .iter() .map(|e| { - ( - (e.event_type().to_owned(), e.state_key().unwrap().to_owned()), - e.event_id().to_owned(), - ) + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) }) .collect::>(); let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] .iter() .map(|e| { - ( - (e.event_type().to_owned(), e.state_key().unwrap().to_owned()), - e.event_id().to_owned(), - ) + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) }) .collect::>(); let expected = [&create_event, &alice_mem, &join_rules, &bob_mem, &charlie_mem] .iter() .map(|e| { - ( - (e.event_type().to_owned(), e.state_key().unwrap().to_owned()), - e.event_id().to_owned(), - ) + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) }) .collect::>();