From 359a0cb125f8272a4d45f0a19d79252755814aff Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Fri, 3 Sep 2021 22:20:13 +0200 Subject: [PATCH] state-res: Don't pass fetch_state to valid_membership_change --- crates/ruma-state-res/src/event_auth.rs | 228 ++++++++++++++++------ crates/ruma-state-res/tests/event_auth.rs | 74 ------- 2 files changed, 165 insertions(+), 137 deletions(-) delete mode 100644 crates/ruma-state-res/tests/event_auth.rs diff --git a/crates/ruma-state-res/src/event_auth.rs b/crates/ruma-state-res/src/event_auth.rs index db294ddc..f8f1bc3b 100644 --- a/crates/ruma-state-res/src/event_auth.rs +++ b/crates/ruma-state-res/src/event_auth.rs @@ -219,13 +219,20 @@ where return Ok(false); } + let sender = incoming_event.sender(); + let target_user = + UserId::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{}", e)))?; + if !valid_membership_change( - &state_key, - incoming_event.sender(), + &target_user, + fetch_state(&EventType::RoomMember, target_user.as_str()), + sender, + fetch_state(&EventType::RoomMember, sender.as_str()), incoming_event.content(), prev_event, current_third_party_invite, - fetch_state, + fetch_state(&EventType::RoomPowerLevels, ""), + fetch_state(&EventType::RoomJoinRules, ""), )? { return Ok(false); } @@ -308,18 +315,18 @@ where /// /// This is generated by calling `auth_types_for_event` with the membership event and the current /// State. -pub fn valid_membership_change( - state_key: &str, - user_sender: &UserId, +#[allow(clippy::too_many_arguments)] +fn valid_membership_change( + target_user: &UserId, + target_user_membership_event: Option>, + sender: &UserId, + sender_membership_event: Option>, content: serde_json::Value, prev_event: Option>, current_third_party_invite: Option>, - fetch_state: F, -) -> Result -where - E: Event, - F: Fn(&EventType, &str) -> Option>, -{ + power_levels_event: Option>, + join_rules_event: Option>, +) -> Result { let target_membership = serde_json::from_value::( content.get("membership").expect("we test before that this field exists").clone(), )?; @@ -328,29 +335,30 @@ where .get("third_party_invite") .map(|t| serde_json::from_value::(t.clone())); - let target_user_id = - UserId::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{}", e)))?; - - let sender = fetch_state(&EventType::RoomMember, user_sender.as_str()); - let sender = sender.as_ref(); - - let sender_membership = sender.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| { - Ok(serde_json::from_value::( - pdu.content().get("membership").expect("we assume existing events are valid").clone(), - )?) - })?; - - let current = fetch_state(&EventType::RoomMember, target_user_id.as_str()); - let current = current.as_ref(); - - let current_membership = current.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| { - Ok(serde_json::from_value::( - pdu.content().get("membership").expect("we assume existing events are valid").clone(), - )?) - })?; - - let power_levels_event = fetch_state(&EventType::RoomPowerLevels, ""); + let target_user_membership_event = target_user_membership_event.as_ref(); + let sender_membership_event = sender_membership_event.as_ref(); let power_levels_event = power_levels_event.as_ref(); + let join_rules_event = join_rules_event.as_ref(); + + let sender_membership = + sender_membership_event.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| { + Ok(serde_json::from_value::( + pdu.content() + .get("membership") + .expect("we assume existing events are valid") + .clone(), + )?) + })?; + + let target_user_current_membership = + target_user_membership_event.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| { + Ok(serde_json::from_value::( + pdu.content() + .get("membership") + .expect("we assume existing events are valid") + .clone(), + )?) + })?; let power_levels = power_levels_event.map_or_else( || Ok::<_, Error>(PowerLevelsEventContent::default()), @@ -360,19 +368,17 @@ where }, )?; - let sender_power = power_levels.users.get(user_sender).map_or_else( + let sender_power = power_levels.users.get(sender).map_or_else( || (sender_membership == MembershipState::Join).then(|| &power_levels.users_default), // If it's okay, wrap with Some(_) Some, ); - let target_power = power_levels.users.get(&target_user_id).map_or_else( + let target_power = power_levels.users.get(target_user).map_or_else( || (target_membership == MembershipState::Join).then(|| &power_levels.users_default), // If it's okay, wrap with Some(_) Some, ); - let join_rules_event = fetch_state(&EventType::RoomJoinRules, ""); - let join_rules_event = join_rules_event.as_ref(); let mut join_rules = JoinRule::Invite; if let Some(jr) = join_rules_event { join_rules = serde_json::from_value::(jr.content())?.join_rule; @@ -385,41 +391,41 @@ where } Ok(if target_membership == MembershipState::Join { - if user_sender != &target_user_id { + if sender != target_user { warn!("Can't make other user join"); false - } else if let MembershipState::Ban = current_membership { + } else if let MembershipState::Ban = target_user_current_membership { warn!( "Banned user can't join\nCurrent user state: {:?}", - current.map(|e| e.event_id()) + target_user_membership_event.map(|e| e.event_id()) ); false } else { let allow = join_rules == JoinRule::Invite - && (current_membership == MembershipState::Join - || current_membership == MembershipState::Invite) + && (target_user_current_membership == MembershipState::Join + || target_user_current_membership == MembershipState::Invite) || join_rules == JoinRule::Public; if !allow { warn!("Can't join if join rules is not public and user is not invited/joined\nJoin Rules: {:?}\nCurrent user state: {:?}", join_rules_event.map(|e| e.event_id()), - current.map(|e| e.event_id())); + target_user_membership_event.map(|e| e.event_id())); } allow } } else if target_membership == MembershipState::Invite { // If content has third_party_invite key if let Some(Ok(tp_id)) = third_party_invite { - if current_membership == MembershipState::Ban { + if target_user_current_membership == MembershipState::Ban { warn!( "Can't invite banned user\nCurrent user state: {:?}", - current.map(|e| e.event_id()) + target_user_membership_event.map(|e| e.event_id()) ); false } else { let allow = verify_third_party_invite( - Some(&target_user_id), - user_sender, + Some(target_user), + sender, &tp_id, current_third_party_invite, ); @@ -429,43 +435,43 @@ where allow } } else if sender_membership != MembershipState::Join - || current_membership == MembershipState::Join - || current_membership == MembershipState::Ban + || target_user_current_membership == MembershipState::Join + || target_user_current_membership == MembershipState::Ban { warn!( "Can't invite user if sender not joined or the user is currently joined or banned\nCurrent user state: {:?}\nSender user state: {:?}", - current.map(|e| e.event_id()), - sender.map(|e| e.event_id()) + target_user_membership_event.map(|e| e.event_id()), + sender_membership_event.map(|e| e.event_id()) ); false } else { let allow = sender_power.filter(|&p| p >= &power_levels.invite).is_some(); if !allow { warn!("User does not have enough power to invite\nCurrent user state: {:?}\nPower levels: {:?}", - current.map(|e| e.event_id()), + target_user_membership_event.map(|e| e.event_id()), power_levels_event.map(|e| e.event_id()) ); } allow } } else if target_membership == MembershipState::Leave { - if user_sender == &target_user_id { - let allow = current_membership == MembershipState::Join - || current_membership == MembershipState::Invite; + if sender == target_user { + let allow = target_user_current_membership == MembershipState::Join + || target_user_current_membership == MembershipState::Invite; if !allow { warn!( "Can't leave if not invited or joined\nCurrent user state: {:?}", - current.map(|e| e.event_id()), + target_user_membership_event.map(|e| e.event_id()), ); } allow } else if sender_membership != MembershipState::Join - || current_membership == MembershipState::Ban + || target_user_current_membership == MembershipState::Ban && sender_power.filter(|&p| p < &power_levels.ban).is_some() { warn!("Can't kick if sender not joined or user is already banned\nCurrent user state: {:?}\nSender user state: {:?}", - current.map(|e| e.event_id()), - sender.map(|e| e.event_id()) + target_user_membership_event.map(|e| e.event_id()), + sender_membership_event.map(|e| e.event_id()) ); false } else { @@ -473,7 +479,7 @@ where && target_power < sender_power; if !allow { warn!("User does not have enough power to kick\nCurrent user state: {:?}\nPower levels: {:?}", - current.map(|e| e.event_id()), + target_user_membership_event.map(|e| e.event_id()), power_levels_event.map(|e| e.event_id()) ); } @@ -483,7 +489,7 @@ where if sender_membership != MembershipState::Join { warn!( "Can't ban user if sender is not joined\nSender user state: {:?}", - sender.map(|e| e.event_id()) + sender_membership_event.map(|e| e.event_id()) ); false } else { @@ -491,7 +497,7 @@ where && target_power < sender_power; if !allow { warn!("User does not have enough power to ban\nCurrent user state: {:?}\nPower levels: {:?}", - current.map(|e| e.event_id()), + target_user_membership_event.map(|e| e.event_id()), power_levels_event.map(|e| e.event_id()) ); } @@ -907,3 +913,99 @@ pub fn verify_third_party_invite( false } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{ + event_auth::valid_membership_change, + test_utils::{alice, charlie, event_id, member_content_ban, to_pdu_event, INITIAL_EVENTS}, + StateMap, + }; + use ruma_events::EventType; + + #[test] + fn test_ban_pass() { + let _ = + tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); + let events = INITIAL_EVENTS(); + + let prev_event = + events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone); + + let auth_events = events + .values() + .map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev))) + .collect::>(); + + let requester = to_pdu_event( + "HELLO", + alice(), + EventType::RoomMember, + Some(charlie().as_str()), + member_content_ban(), + &[], + &[event_id("IMC")], + ); + + let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned(); + let target_user = charlie(); + let sender = alice(); + + assert!(valid_membership_change( + &target_user, + fetch_state(EventType::RoomMember, target_user.to_string()), + &sender, + fetch_state(EventType::RoomMember, sender.to_string()), + requester.content(), + prev_event, + None, + fetch_state(EventType::RoomPowerLevels, "".to_owned()), + fetch_state(EventType::RoomJoinRules, "".to_owned()), + ) + .unwrap()); + } + + #[test] + fn test_ban_fail() { + let _ = + tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); + let events = INITIAL_EVENTS(); + + let prev_event = + events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone); + + let auth_events = events + .values() + .map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev))) + .collect::>(); + + let requester = to_pdu_event( + "HELLO", + charlie(), + EventType::RoomMember, + Some(alice().as_str()), + member_content_ban(), + &[], + &[event_id("IMC")], + ); + + let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned(); + let target_user = alice(); + let sender = charlie(); + + assert!(!valid_membership_change( + &target_user, + fetch_state(EventType::RoomMember, target_user.to_string()), + &sender, + fetch_state(EventType::RoomMember, sender.to_string()), + requester.content(), + prev_event, + None, + fetch_state(EventType::RoomPowerLevels, "".to_owned()), + fetch_state(EventType::RoomJoinRules, "".to_owned()), + ) + .unwrap()); + } +} diff --git a/crates/ruma-state-res/tests/event_auth.rs b/crates/ruma-state-res/tests/event_auth.rs deleted file mode 100644 index 4d1b44c6..00000000 --- a/crates/ruma-state-res/tests/event_auth.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::sync::Arc; - -use ruma_events::EventType; -use ruma_state_res::{ - event_auth::valid_membership_change, - test_utils::{alice, charlie, event_id, member_content_ban, to_pdu_event, INITIAL_EVENTS}, - StateMap, -}; - -#[test] -fn test_ban_pass() { - let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); - let events = INITIAL_EVENTS(); - - let prev = events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone); - - let auth_events = events - .values() - .map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev))) - .collect::>(); - - let requester = to_pdu_event( - "HELLO", - alice(), - EventType::RoomMember, - Some(charlie().as_str()), - member_content_ban(), - &[], - &[event_id("IMC")], - ); - - assert!(valid_membership_change( - &requester.state_key(), - requester.sender(), - requester.content(), - prev, - None, - |ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(), - ) - .unwrap()) -} - -#[test] -fn test_ban_fail() { - let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish()); - let events = INITIAL_EVENTS(); - - let prev = events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone); - - let auth_events = events - .values() - .map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev))) - .collect::>(); - - let requester = to_pdu_event( - "HELLO", - charlie(), - EventType::RoomMember, - Some(alice().as_str()), - member_content_ban(), - &[], - &[event_id("IMC")], - ); - - assert!(!valid_membership_change( - &requester.state_key(), - requester.sender(), - requester.content(), - prev, - None, - |ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(), - ) - .unwrap()) -}