diff --git a/benches/state_res_bench.rs b/benches/state_res_bench.rs index ff5c37cc..86d55ff1 100644 --- a/benches/state_res_bench.rs +++ b/benches/state_res_bench.rs @@ -48,11 +48,12 @@ fn resolution_shallow_auth_chain(c: &mut Criterion) { let (state_at_bob, state_at_charlie, _) = store.set_up(); b.iter(|| { + let mut ev_map = state_res::EventMap::default(); let _resolved = match StateResolution::resolve( &room_id(), &RoomVersionId::Version6, &[state_at_bob.clone(), state_at_charlie.clone()], - None, + &mut ev_map, &store, ) { Ok(state) => state, @@ -102,7 +103,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) { &room_id(), &RoomVersionId::Version6, &[state_set_a.clone(), state_set_b.clone()], - Some(inner.clone()), + &mut inner, &store, ) { Ok(state) => state, diff --git a/src/event_auth.rs b/src/event_auth.rs index 74b363ce..eaf6ef57 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -193,7 +193,9 @@ pub fn auth_check( } if !valid_membership_change( - incoming_event, + incoming_event.state_key().as_deref(), + incoming_event.sender(), + incoming_event.content(), prev_event, current_third_party_invite, &auth_events, @@ -271,25 +273,27 @@ pub fn auth_check( /// this is generated by calling `auth_types_for_event` with the membership event and /// the current State. pub fn valid_membership_change( - user: &Arc, + user_state_key: Option<&str>, + user_sender: &UserId, + content: serde_json::Value, prev_event: Option>, current_third_party_invite: Option>, auth_events: &StateMap>, ) -> Result { - let state_key = if let Some(s) = user.state_key() { + let state_key = if let Some(s) = user_state_key { s } else { return Err(Error::InvalidPdu("State event requires state_key".into())); }; - let content = serde_json::from_value::(user.content())?; + let content = serde_json::from_value::(content)?; let target_membership = content.membership; - let target_user_id = UserId::try_from(state_key.as_str()) - .map_err(|e| Error::ConversionError(format!("{}", e)))?; + let target_user_id = + UserId::try_from(state_key).map_err(|e| Error::ConversionError(format!("{}", e)))?; - let key = (EventType::RoomMember, Some(user.sender().to_string())); + let key = (EventType::RoomMember, Some(user_sender.to_string())); let sender = auth_events.get(&key); let sender_membership = sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { @@ -333,7 +337,7 @@ pub fn valid_membership_change( }, )?; - let sender_power = power_levels.users.get(&user.sender()).map_or_else( + let sender_power = power_levels.users.get(user_sender).map_or_else( || { if sender_membership != member::MembershipState::Join { None @@ -372,7 +376,7 @@ pub fn valid_membership_change( } Ok(if target_membership == MembershipState::Join { - if user.sender() != &target_user_id { + if user_sender != &target_user_id { false } else if let MembershipState::Ban = current_membership { false @@ -388,7 +392,12 @@ pub fn valid_membership_change( if current_membership == MembershipState::Ban { false } else { - verify_third_party_invite(user, &tp_id, current_third_party_invite) + verify_third_party_invite( + Some(state_key), + user_sender, + &tp_id, + current_third_party_invite, + ) } } else if sender_membership != MembershipState::Join || current_membership == MembershipState::Join @@ -401,7 +410,7 @@ pub fn valid_membership_change( .is_some() } } else if target_membership == MembershipState::Leave { - if user.sender() == &target_user_id { + if user_sender == &target_user_id { current_membership == MembershipState::Join || current_membership == MembershipState::Invite } else if sender_membership != MembershipState::Join @@ -791,14 +800,15 @@ pub fn can_send_invite(event: &Arc, auth_events: &StateMap>) } pub fn verify_third_party_invite( - event: &Arc, + user_state_key: Option<&str>, + sender: &UserId, tp_id: &member::ThirdPartyInvite, current_third_party_invite: Option>, ) -> bool { // 1. check for user being banned happens before this is called // checking for mxid and token keys is done by ruma when deserializing - if event.state_key() != Some(tp_id.signed.mxid.to_string()) { + if user_state_key != Some(tp_id.signed.mxid.as_str()) { return false; } @@ -809,7 +819,7 @@ pub fn verify_third_party_invite( return false; } - if event.sender() != current_tpid.sender() { + if sender != current_tpid.sender() { return false; } diff --git a/src/lib.rs b/src/lib.rs index f2c92753..de68bac3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,17 +46,12 @@ impl StateResolution { room_version: &RoomVersionId, incoming_event: Arc, current_state: &StateMap, - event_map: Option>>, + event_map: &mut EventMap>, store: &dyn StateStore, ) -> Result { tracing::info!("Applying a single event, state resolution starting"); let ev = incoming_event; - let mut event_map = if let Some(ev_map) = event_map { - ev_map - } else { - EventMap::new() - }; let prev_event = if let Some(id) = ev.prev_events().first() { store.get_event(room_id, id).ok() } else { @@ -69,7 +64,7 @@ impl StateResolution { { if let Some(ev_id) = current_state.get(&key) { if let Some(event) = - StateResolution::get_or_load_event(room_id, ev_id, &mut event_map, store) + StateResolution::get_or_load_event(room_id, ev_id, event_map, store) { // TODO synapse checks `rejected_reason` is None here auth_events.insert(key.clone(), event); @@ -102,17 +97,11 @@ impl StateResolution { room_id: &RoomId, room_version: &RoomVersionId, state_sets: &[StateMap], - // TODO: make the `Option<&mut EventMap>>` - event_map: Option>>, + event_map: &mut EventMap>, store: &dyn StateStore, ) -> Result> { tracing::info!("State resolution starting"); - let mut event_map = if let Some(ev_map) = event_map { - ev_map - } else { - EventMap::new() - }; // split non-conflicting and conflicting state let (clean, conflicting) = StateResolution::separate(&state_sets); @@ -178,7 +167,7 @@ impl StateResolution { let mut sorted_control_levels = StateResolution::reverse_topological_power_sort( room_id, &control_events, - &mut event_map, + event_map, store, &all_conflicted, ); @@ -197,7 +186,7 @@ impl StateResolution { room_version, &sorted_control_levels, &clean, - &mut event_map, + event_map, store, )?; @@ -238,7 +227,7 @@ impl StateResolution { room_id, &events_to_resolve, power_event, - &mut event_map, + event_map, store, ); @@ -255,7 +244,7 @@ impl StateResolution { room_version, &sorted_left_events, &resolved_control, // The control events are added to the final resolved state - &mut event_map, + event_map, store, )?; diff --git a/tests/event_auth.rs b/tests/event_auth.rs index fb6989b7..46270ca7 100644 --- a/tests/event_auth.rs +++ b/tests/event_auth.rs @@ -36,7 +36,15 @@ fn test_ban_pass() { &vec![event_id("IMC")], ); - assert!(valid_membership_change(&requester, prev, None, &auth_events).unwrap()) + assert!(valid_membership_change( + requester.state_key().as_deref(), + requester.sender(), + requester.content(), + prev, + None, + &auth_events + ) + .unwrap()) } #[test] @@ -63,5 +71,13 @@ fn test_ban_fail() { &vec![event_id("IMC")], ); - assert!(!valid_membership_change(&requester, prev, None, &auth_events).unwrap()) + assert!(!valid_membership_change( + requester.state_key().as_deref(), + requester.sender(), + requester.content(), + prev, + None, + &auth_events + ) + .unwrap()) } diff --git a/tests/res_with_auth_ids.rs b/tests/res_with_auth_ids.rs index cc123208..627750e2 100644 --- a/tests/res_with_auth_ids.rs +++ b/tests/res_with_auth_ids.rs @@ -41,11 +41,17 @@ fn ban_with_auth_chains() { fn base_with_auth_chains() { let store = TestStore(INITIAL_EVENTS()); - let resolved: BTreeMap<_, EventId> = - match StateResolution::resolve(&room_id(), &RoomVersionId::Version6, &[], None, &store) { - Ok(state) => state, - Err(e) => panic!("{}", e), - }; + let mut ev_map = state_res::EventMap::default(); + let resolved: BTreeMap<_, EventId> = match StateResolution::resolve( + &room_id(), + &RoomVersionId::Version6, + &[], + &mut ev_map, + &store, + ) { + Ok(state) => state, + Err(e) => panic!("{}", e), + }; let resolved = resolved .values() @@ -105,11 +111,12 @@ fn ban_with_auth_chains2() { .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone())) .collect::>(); + let mut ev_map = state_res::EventMap::default(); let resolved: StateMap = match StateResolution::resolve( &room_id(), &RoomVersionId::Version6, &[state_set_a, state_set_b], - None, + &mut ev_map, &store, ) { Ok(state) => state, diff --git a/tests/state_res.rs b/tests/state_res.rs index 044d0811..f85f7eec 100644 --- a/tests/state_res.rs +++ b/tests/state_res.rs @@ -265,11 +265,12 @@ fn test_event_map_none() { // build up the DAG let (state_at_bob, state_at_charlie, expected) = store.set_up(); + let mut ev_map = state_res::EventMap::default(); let resolved = match StateResolution::resolve( &room_id(), &RoomVersionId::Version2, &[state_at_bob, state_at_charlie], - None, + &mut ev_map, &store, ) { Ok(state) => state, diff --git a/tests/utils.rs b/tests/utils.rs index bed80160..bec72356 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -114,7 +114,7 @@ pub fn do_check( &room_id(), &RoomVersionId::Version6, &state_sets, - Some(event_map.clone()), + &mut event_map, &store, ); match resolved {