state-res: Don't pass fetch_state to valid_membership_change

This commit is contained in:
Jonas Platte 2021-09-03 22:20:13 +02:00
parent 0a93780e83
commit 359a0cb125
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
2 changed files with 165 additions and 137 deletions

View File

@ -219,13 +219,20 @@ where
return Ok(false);
}
let sender = incoming_event.sender();
let target_user =
UserId::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{}", e)))?;
if !valid_membership_change(
&state_key,
incoming_event.sender(),
&target_user,
fetch_state(&EventType::RoomMember, target_user.as_str()),
sender,
fetch_state(&EventType::RoomMember, sender.as_str()),
incoming_event.content(),
prev_event,
current_third_party_invite,
fetch_state,
fetch_state(&EventType::RoomPowerLevels, ""),
fetch_state(&EventType::RoomJoinRules, ""),
)? {
return Ok(false);
}
@ -308,18 +315,18 @@ where
///
/// This is generated by calling `auth_types_for_event` with the membership event and the current
/// State.
pub fn valid_membership_change<E, F>(
state_key: &str,
user_sender: &UserId,
#[allow(clippy::too_many_arguments)]
fn valid_membership_change<E: Event>(
target_user: &UserId,
target_user_membership_event: Option<Arc<E>>,
sender: &UserId,
sender_membership_event: Option<Arc<E>>,
content: serde_json::Value,
prev_event: Option<Arc<E>>,
current_third_party_invite: Option<Arc<E>>,
fetch_state: F,
) -> Result<bool>
where
E: Event,
F: Fn(&EventType, &str) -> Option<Arc<E>>,
{
power_levels_event: Option<Arc<E>>,
join_rules_event: Option<Arc<E>>,
) -> Result<bool> {
let target_membership = serde_json::from_value::<MembershipState>(
content.get("membership").expect("we test before that this field exists").clone(),
)?;
@ -328,29 +335,30 @@ where
.get("third_party_invite")
.map(|t| serde_json::from_value::<ThirdPartyInvite>(t.clone()));
let target_user_id =
UserId::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{}", e)))?;
let sender = fetch_state(&EventType::RoomMember, user_sender.as_str());
let sender = sender.as_ref();
let sender_membership = sender.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
Ok(serde_json::from_value::<MembershipState>(
pdu.content().get("membership").expect("we assume existing events are valid").clone(),
)?)
})?;
let current = fetch_state(&EventType::RoomMember, target_user_id.as_str());
let current = current.as_ref();
let current_membership = current.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
Ok(serde_json::from_value::<MembershipState>(
pdu.content().get("membership").expect("we assume existing events are valid").clone(),
)?)
})?;
let power_levels_event = fetch_state(&EventType::RoomPowerLevels, "");
let target_user_membership_event = target_user_membership_event.as_ref();
let sender_membership_event = sender_membership_event.as_ref();
let power_levels_event = power_levels_event.as_ref();
let join_rules_event = join_rules_event.as_ref();
let sender_membership =
sender_membership_event.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
Ok(serde_json::from_value::<MembershipState>(
pdu.content()
.get("membership")
.expect("we assume existing events are valid")
.clone(),
)?)
})?;
let target_user_current_membership =
target_user_membership_event.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
Ok(serde_json::from_value::<MembershipState>(
pdu.content()
.get("membership")
.expect("we assume existing events are valid")
.clone(),
)?)
})?;
let power_levels = power_levels_event.map_or_else(
|| Ok::<_, Error>(PowerLevelsEventContent::default()),
@ -360,19 +368,17 @@ where
},
)?;
let sender_power = power_levels.users.get(user_sender).map_or_else(
let sender_power = power_levels.users.get(sender).map_or_else(
|| (sender_membership == MembershipState::Join).then(|| &power_levels.users_default),
// If it's okay, wrap with Some(_)
Some,
);
let target_power = power_levels.users.get(&target_user_id).map_or_else(
let target_power = power_levels.users.get(target_user).map_or_else(
|| (target_membership == MembershipState::Join).then(|| &power_levels.users_default),
// If it's okay, wrap with Some(_)
Some,
);
let join_rules_event = fetch_state(&EventType::RoomJoinRules, "");
let join_rules_event = join_rules_event.as_ref();
let mut join_rules = JoinRule::Invite;
if let Some(jr) = join_rules_event {
join_rules = serde_json::from_value::<JoinRulesEventContent>(jr.content())?.join_rule;
@ -385,41 +391,41 @@ where
}
Ok(if target_membership == MembershipState::Join {
if user_sender != &target_user_id {
if sender != target_user {
warn!("Can't make other user join");
false
} else if let MembershipState::Ban = current_membership {
} else if let MembershipState::Ban = target_user_current_membership {
warn!(
"Banned user can't join\nCurrent user state: {:?}",
current.map(|e| e.event_id())
target_user_membership_event.map(|e| e.event_id())
);
false
} else {
let allow = join_rules == JoinRule::Invite
&& (current_membership == MembershipState::Join
|| current_membership == MembershipState::Invite)
&& (target_user_current_membership == MembershipState::Join
|| target_user_current_membership == MembershipState::Invite)
|| join_rules == JoinRule::Public;
if !allow {
warn!("Can't join if join rules is not public and user is not invited/joined\nJoin Rules: {:?}\nCurrent user state: {:?}",
join_rules_event.map(|e| e.event_id()),
current.map(|e| e.event_id()));
target_user_membership_event.map(|e| e.event_id()));
}
allow
}
} else if target_membership == MembershipState::Invite {
// If content has third_party_invite key
if let Some(Ok(tp_id)) = third_party_invite {
if current_membership == MembershipState::Ban {
if target_user_current_membership == MembershipState::Ban {
warn!(
"Can't invite banned user\nCurrent user state: {:?}",
current.map(|e| e.event_id())
target_user_membership_event.map(|e| e.event_id())
);
false
} else {
let allow = verify_third_party_invite(
Some(&target_user_id),
user_sender,
Some(target_user),
sender,
&tp_id,
current_third_party_invite,
);
@ -429,43 +435,43 @@ where
allow
}
} else if sender_membership != MembershipState::Join
|| current_membership == MembershipState::Join
|| current_membership == MembershipState::Ban
|| target_user_current_membership == MembershipState::Join
|| target_user_current_membership == MembershipState::Ban
{
warn!(
"Can't invite user if sender not joined or the user is currently joined or banned\nCurrent user state: {:?}\nSender user state: {:?}",
current.map(|e| e.event_id()),
sender.map(|e| e.event_id())
target_user_membership_event.map(|e| e.event_id()),
sender_membership_event.map(|e| e.event_id())
);
false
} else {
let allow = sender_power.filter(|&p| p >= &power_levels.invite).is_some();
if !allow {
warn!("User does not have enough power to invite\nCurrent user state: {:?}\nPower levels: {:?}",
current.map(|e| e.event_id()),
target_user_membership_event.map(|e| e.event_id()),
power_levels_event.map(|e| e.event_id())
);
}
allow
}
} else if target_membership == MembershipState::Leave {
if user_sender == &target_user_id {
let allow = current_membership == MembershipState::Join
|| current_membership == MembershipState::Invite;
if sender == target_user {
let allow = target_user_current_membership == MembershipState::Join
|| target_user_current_membership == MembershipState::Invite;
if !allow {
warn!(
"Can't leave if not invited or joined\nCurrent user state: {:?}",
current.map(|e| e.event_id()),
target_user_membership_event.map(|e| e.event_id()),
);
}
allow
} else if sender_membership != MembershipState::Join
|| current_membership == MembershipState::Ban
|| target_user_current_membership == MembershipState::Ban
&& sender_power.filter(|&p| p < &power_levels.ban).is_some()
{
warn!("Can't kick if sender not joined or user is already banned\nCurrent user state: {:?}\nSender user state: {:?}",
current.map(|e| e.event_id()),
sender.map(|e| e.event_id())
target_user_membership_event.map(|e| e.event_id()),
sender_membership_event.map(|e| e.event_id())
);
false
} else {
@ -473,7 +479,7 @@ where
&& target_power < sender_power;
if !allow {
warn!("User does not have enough power to kick\nCurrent user state: {:?}\nPower levels: {:?}",
current.map(|e| e.event_id()),
target_user_membership_event.map(|e| e.event_id()),
power_levels_event.map(|e| e.event_id())
);
}
@ -483,7 +489,7 @@ where
if sender_membership != MembershipState::Join {
warn!(
"Can't ban user if sender is not joined\nSender user state: {:?}",
sender.map(|e| e.event_id())
sender_membership_event.map(|e| e.event_id())
);
false
} else {
@ -491,7 +497,7 @@ where
&& target_power < sender_power;
if !allow {
warn!("User does not have enough power to ban\nCurrent user state: {:?}\nPower levels: {:?}",
current.map(|e| e.event_id()),
target_user_membership_event.map(|e| e.event_id()),
power_levels_event.map(|e| e.event_id())
);
}
@ -907,3 +913,99 @@ pub fn verify_third_party_invite<E: Event>(
false
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::{
event_auth::valid_membership_change,
test_utils::{alice, charlie, event_id, member_content_ban, to_pdu_event, INITIAL_EVENTS},
StateMap,
};
use ruma_events::EventType;
#[test]
fn test_ban_pass() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = INITIAL_EVENTS();
let prev_event =
events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone);
let auth_events = events
.values()
.map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev)))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
alice(),
EventType::RoomMember,
Some(charlie().as_str()),
member_content_ban(),
&[],
&[event_id("IMC")],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = charlie();
let sender = alice();
assert!(valid_membership_change(
&target_user,
fetch_state(EventType::RoomMember, target_user.to_string()),
&sender,
fetch_state(EventType::RoomMember, sender.to_string()),
requester.content(),
prev_event,
None,
fetch_state(EventType::RoomPowerLevels, "".to_owned()),
fetch_state(EventType::RoomJoinRules, "".to_owned()),
)
.unwrap());
}
#[test]
fn test_ban_fail() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = INITIAL_EVENTS();
let prev_event =
events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone);
let auth_events = events
.values()
.map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev)))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
charlie(),
EventType::RoomMember,
Some(alice().as_str()),
member_content_ban(),
&[],
&[event_id("IMC")],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = alice();
let sender = charlie();
assert!(!valid_membership_change(
&target_user,
fetch_state(EventType::RoomMember, target_user.to_string()),
&sender,
fetch_state(EventType::RoomMember, sender.to_string()),
requester.content(),
prev_event,
None,
fetch_state(EventType::RoomPowerLevels, "".to_owned()),
fetch_state(EventType::RoomJoinRules, "".to_owned()),
)
.unwrap());
}
}

View File

@ -1,74 +0,0 @@
use std::sync::Arc;
use ruma_events::EventType;
use ruma_state_res::{
event_auth::valid_membership_change,
test_utils::{alice, charlie, event_id, member_content_ban, to_pdu_event, INITIAL_EVENTS},
StateMap,
};
#[test]
fn test_ban_pass() {
let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = INITIAL_EVENTS();
let prev = events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone);
let auth_events = events
.values()
.map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev)))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
alice(),
EventType::RoomMember,
Some(charlie().as_str()),
member_content_ban(),
&[],
&[event_id("IMC")],
);
assert!(valid_membership_change(
&requester.state_key(),
requester.sender(),
requester.content(),
prev,
None,
|ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(),
)
.unwrap())
}
#[test]
fn test_ban_fail() {
let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = INITIAL_EVENTS();
let prev = events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone);
let auth_events = events
.values()
.map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev)))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
charlie(),
EventType::RoomMember,
Some(alice().as_str()),
member_content_ban(),
&[],
&[event_id("IMC")],
);
assert!(!valid_membership_change(
&requester.state_key(),
requester.sender(),
requester.content(),
prev,
None,
|ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(),
)
.unwrap())
}