state-res: Don't pass fetch_state to valid_membership_change
This commit is contained in:
parent
0a93780e83
commit
359a0cb125
@ -219,13 +219,20 @@ where
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let sender = incoming_event.sender();
|
||||
let target_user =
|
||||
UserId::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{}", e)))?;
|
||||
|
||||
if !valid_membership_change(
|
||||
&state_key,
|
||||
incoming_event.sender(),
|
||||
&target_user,
|
||||
fetch_state(&EventType::RoomMember, target_user.as_str()),
|
||||
sender,
|
||||
fetch_state(&EventType::RoomMember, sender.as_str()),
|
||||
incoming_event.content(),
|
||||
prev_event,
|
||||
current_third_party_invite,
|
||||
fetch_state,
|
||||
fetch_state(&EventType::RoomPowerLevels, ""),
|
||||
fetch_state(&EventType::RoomJoinRules, ""),
|
||||
)? {
|
||||
return Ok(false);
|
||||
}
|
||||
@ -308,18 +315,18 @@ where
|
||||
///
|
||||
/// This is generated by calling `auth_types_for_event` with the membership event and the current
|
||||
/// State.
|
||||
pub fn valid_membership_change<E, F>(
|
||||
state_key: &str,
|
||||
user_sender: &UserId,
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn valid_membership_change<E: Event>(
|
||||
target_user: &UserId,
|
||||
target_user_membership_event: Option<Arc<E>>,
|
||||
sender: &UserId,
|
||||
sender_membership_event: Option<Arc<E>>,
|
||||
content: serde_json::Value,
|
||||
prev_event: Option<Arc<E>>,
|
||||
current_third_party_invite: Option<Arc<E>>,
|
||||
fetch_state: F,
|
||||
) -> Result<bool>
|
||||
where
|
||||
E: Event,
|
||||
F: Fn(&EventType, &str) -> Option<Arc<E>>,
|
||||
{
|
||||
power_levels_event: Option<Arc<E>>,
|
||||
join_rules_event: Option<Arc<E>>,
|
||||
) -> Result<bool> {
|
||||
let target_membership = serde_json::from_value::<MembershipState>(
|
||||
content.get("membership").expect("we test before that this field exists").clone(),
|
||||
)?;
|
||||
@ -328,29 +335,30 @@ where
|
||||
.get("third_party_invite")
|
||||
.map(|t| serde_json::from_value::<ThirdPartyInvite>(t.clone()));
|
||||
|
||||
let target_user_id =
|
||||
UserId::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{}", e)))?;
|
||||
|
||||
let sender = fetch_state(&EventType::RoomMember, user_sender.as_str());
|
||||
let sender = sender.as_ref();
|
||||
|
||||
let sender_membership = sender.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
|
||||
Ok(serde_json::from_value::<MembershipState>(
|
||||
pdu.content().get("membership").expect("we assume existing events are valid").clone(),
|
||||
)?)
|
||||
})?;
|
||||
|
||||
let current = fetch_state(&EventType::RoomMember, target_user_id.as_str());
|
||||
let current = current.as_ref();
|
||||
|
||||
let current_membership = current.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
|
||||
Ok(serde_json::from_value::<MembershipState>(
|
||||
pdu.content().get("membership").expect("we assume existing events are valid").clone(),
|
||||
)?)
|
||||
})?;
|
||||
|
||||
let power_levels_event = fetch_state(&EventType::RoomPowerLevels, "");
|
||||
let target_user_membership_event = target_user_membership_event.as_ref();
|
||||
let sender_membership_event = sender_membership_event.as_ref();
|
||||
let power_levels_event = power_levels_event.as_ref();
|
||||
let join_rules_event = join_rules_event.as_ref();
|
||||
|
||||
let sender_membership =
|
||||
sender_membership_event.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
|
||||
Ok(serde_json::from_value::<MembershipState>(
|
||||
pdu.content()
|
||||
.get("membership")
|
||||
.expect("we assume existing events are valid")
|
||||
.clone(),
|
||||
)?)
|
||||
})?;
|
||||
|
||||
let target_user_current_membership =
|
||||
target_user_membership_event.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
|
||||
Ok(serde_json::from_value::<MembershipState>(
|
||||
pdu.content()
|
||||
.get("membership")
|
||||
.expect("we assume existing events are valid")
|
||||
.clone(),
|
||||
)?)
|
||||
})?;
|
||||
|
||||
let power_levels = power_levels_event.map_or_else(
|
||||
|| Ok::<_, Error>(PowerLevelsEventContent::default()),
|
||||
@ -360,19 +368,17 @@ where
|
||||
},
|
||||
)?;
|
||||
|
||||
let sender_power = power_levels.users.get(user_sender).map_or_else(
|
||||
let sender_power = power_levels.users.get(sender).map_or_else(
|
||||
|| (sender_membership == MembershipState::Join).then(|| &power_levels.users_default),
|
||||
// If it's okay, wrap with Some(_)
|
||||
Some,
|
||||
);
|
||||
let target_power = power_levels.users.get(&target_user_id).map_or_else(
|
||||
let target_power = power_levels.users.get(target_user).map_or_else(
|
||||
|| (target_membership == MembershipState::Join).then(|| &power_levels.users_default),
|
||||
// If it's okay, wrap with Some(_)
|
||||
Some,
|
||||
);
|
||||
|
||||
let join_rules_event = fetch_state(&EventType::RoomJoinRules, "");
|
||||
let join_rules_event = join_rules_event.as_ref();
|
||||
let mut join_rules = JoinRule::Invite;
|
||||
if let Some(jr) = join_rules_event {
|
||||
join_rules = serde_json::from_value::<JoinRulesEventContent>(jr.content())?.join_rule;
|
||||
@ -385,41 +391,41 @@ where
|
||||
}
|
||||
|
||||
Ok(if target_membership == MembershipState::Join {
|
||||
if user_sender != &target_user_id {
|
||||
if sender != target_user {
|
||||
warn!("Can't make other user join");
|
||||
false
|
||||
} else if let MembershipState::Ban = current_membership {
|
||||
} else if let MembershipState::Ban = target_user_current_membership {
|
||||
warn!(
|
||||
"Banned user can't join\nCurrent user state: {:?}",
|
||||
current.map(|e| e.event_id())
|
||||
target_user_membership_event.map(|e| e.event_id())
|
||||
);
|
||||
false
|
||||
} else {
|
||||
let allow = join_rules == JoinRule::Invite
|
||||
&& (current_membership == MembershipState::Join
|
||||
|| current_membership == MembershipState::Invite)
|
||||
&& (target_user_current_membership == MembershipState::Join
|
||||
|| target_user_current_membership == MembershipState::Invite)
|
||||
|| join_rules == JoinRule::Public;
|
||||
|
||||
if !allow {
|
||||
warn!("Can't join if join rules is not public and user is not invited/joined\nJoin Rules: {:?}\nCurrent user state: {:?}",
|
||||
join_rules_event.map(|e| e.event_id()),
|
||||
current.map(|e| e.event_id()));
|
||||
target_user_membership_event.map(|e| e.event_id()));
|
||||
}
|
||||
allow
|
||||
}
|
||||
} else if target_membership == MembershipState::Invite {
|
||||
// If content has third_party_invite key
|
||||
if let Some(Ok(tp_id)) = third_party_invite {
|
||||
if current_membership == MembershipState::Ban {
|
||||
if target_user_current_membership == MembershipState::Ban {
|
||||
warn!(
|
||||
"Can't invite banned user\nCurrent user state: {:?}",
|
||||
current.map(|e| e.event_id())
|
||||
target_user_membership_event.map(|e| e.event_id())
|
||||
);
|
||||
false
|
||||
} else {
|
||||
let allow = verify_third_party_invite(
|
||||
Some(&target_user_id),
|
||||
user_sender,
|
||||
Some(target_user),
|
||||
sender,
|
||||
&tp_id,
|
||||
current_third_party_invite,
|
||||
);
|
||||
@ -429,43 +435,43 @@ where
|
||||
allow
|
||||
}
|
||||
} else if sender_membership != MembershipState::Join
|
||||
|| current_membership == MembershipState::Join
|
||||
|| current_membership == MembershipState::Ban
|
||||
|| target_user_current_membership == MembershipState::Join
|
||||
|| target_user_current_membership == MembershipState::Ban
|
||||
{
|
||||
warn!(
|
||||
"Can't invite user if sender not joined or the user is currently joined or banned\nCurrent user state: {:?}\nSender user state: {:?}",
|
||||
current.map(|e| e.event_id()),
|
||||
sender.map(|e| e.event_id())
|
||||
target_user_membership_event.map(|e| e.event_id()),
|
||||
sender_membership_event.map(|e| e.event_id())
|
||||
);
|
||||
false
|
||||
} else {
|
||||
let allow = sender_power.filter(|&p| p >= &power_levels.invite).is_some();
|
||||
if !allow {
|
||||
warn!("User does not have enough power to invite\nCurrent user state: {:?}\nPower levels: {:?}",
|
||||
current.map(|e| e.event_id()),
|
||||
target_user_membership_event.map(|e| e.event_id()),
|
||||
power_levels_event.map(|e| e.event_id())
|
||||
);
|
||||
}
|
||||
allow
|
||||
}
|
||||
} else if target_membership == MembershipState::Leave {
|
||||
if user_sender == &target_user_id {
|
||||
let allow = current_membership == MembershipState::Join
|
||||
|| current_membership == MembershipState::Invite;
|
||||
if sender == target_user {
|
||||
let allow = target_user_current_membership == MembershipState::Join
|
||||
|| target_user_current_membership == MembershipState::Invite;
|
||||
if !allow {
|
||||
warn!(
|
||||
"Can't leave if not invited or joined\nCurrent user state: {:?}",
|
||||
current.map(|e| e.event_id()),
|
||||
target_user_membership_event.map(|e| e.event_id()),
|
||||
);
|
||||
}
|
||||
allow
|
||||
} else if sender_membership != MembershipState::Join
|
||||
|| current_membership == MembershipState::Ban
|
||||
|| target_user_current_membership == MembershipState::Ban
|
||||
&& sender_power.filter(|&p| p < &power_levels.ban).is_some()
|
||||
{
|
||||
warn!("Can't kick if sender not joined or user is already banned\nCurrent user state: {:?}\nSender user state: {:?}",
|
||||
current.map(|e| e.event_id()),
|
||||
sender.map(|e| e.event_id())
|
||||
target_user_membership_event.map(|e| e.event_id()),
|
||||
sender_membership_event.map(|e| e.event_id())
|
||||
);
|
||||
false
|
||||
} else {
|
||||
@ -473,7 +479,7 @@ where
|
||||
&& target_power < sender_power;
|
||||
if !allow {
|
||||
warn!("User does not have enough power to kick\nCurrent user state: {:?}\nPower levels: {:?}",
|
||||
current.map(|e| e.event_id()),
|
||||
target_user_membership_event.map(|e| e.event_id()),
|
||||
power_levels_event.map(|e| e.event_id())
|
||||
);
|
||||
}
|
||||
@ -483,7 +489,7 @@ where
|
||||
if sender_membership != MembershipState::Join {
|
||||
warn!(
|
||||
"Can't ban user if sender is not joined\nSender user state: {:?}",
|
||||
sender.map(|e| e.event_id())
|
||||
sender_membership_event.map(|e| e.event_id())
|
||||
);
|
||||
false
|
||||
} else {
|
||||
@ -491,7 +497,7 @@ where
|
||||
&& target_power < sender_power;
|
||||
if !allow {
|
||||
warn!("User does not have enough power to ban\nCurrent user state: {:?}\nPower levels: {:?}",
|
||||
current.map(|e| e.event_id()),
|
||||
target_user_membership_event.map(|e| e.event_id()),
|
||||
power_levels_event.map(|e| e.event_id())
|
||||
);
|
||||
}
|
||||
@ -907,3 +913,99 @@ pub fn verify_third_party_invite<E: Event>(
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
event_auth::valid_membership_change,
|
||||
test_utils::{alice, charlie, event_id, member_content_ban, to_pdu_event, INITIAL_EVENTS},
|
||||
StateMap,
|
||||
};
|
||||
use ruma_events::EventType;
|
||||
|
||||
#[test]
|
||||
fn test_ban_pass() {
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
let events = INITIAL_EVENTS();
|
||||
|
||||
let prev_event =
|
||||
events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone);
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev)))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = to_pdu_event(
|
||||
"HELLO",
|
||||
alice(),
|
||||
EventType::RoomMember,
|
||||
Some(charlie().as_str()),
|
||||
member_content_ban(),
|
||||
&[],
|
||||
&[event_id("IMC")],
|
||||
);
|
||||
|
||||
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
|
||||
let target_user = charlie();
|
||||
let sender = alice();
|
||||
|
||||
assert!(valid_membership_change(
|
||||
&target_user,
|
||||
fetch_state(EventType::RoomMember, target_user.to_string()),
|
||||
&sender,
|
||||
fetch_state(EventType::RoomMember, sender.to_string()),
|
||||
requester.content(),
|
||||
prev_event,
|
||||
None,
|
||||
fetch_state(EventType::RoomPowerLevels, "".to_owned()),
|
||||
fetch_state(EventType::RoomJoinRules, "".to_owned()),
|
||||
)
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ban_fail() {
|
||||
let _ =
|
||||
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
let events = INITIAL_EVENTS();
|
||||
|
||||
let prev_event =
|
||||
events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone);
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev)))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = to_pdu_event(
|
||||
"HELLO",
|
||||
charlie(),
|
||||
EventType::RoomMember,
|
||||
Some(alice().as_str()),
|
||||
member_content_ban(),
|
||||
&[],
|
||||
&[event_id("IMC")],
|
||||
);
|
||||
|
||||
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
|
||||
let target_user = alice();
|
||||
let sender = charlie();
|
||||
|
||||
assert!(!valid_membership_change(
|
||||
&target_user,
|
||||
fetch_state(EventType::RoomMember, target_user.to_string()),
|
||||
&sender,
|
||||
fetch_state(EventType::RoomMember, sender.to_string()),
|
||||
requester.content(),
|
||||
prev_event,
|
||||
None,
|
||||
fetch_state(EventType::RoomPowerLevels, "".to_owned()),
|
||||
fetch_state(EventType::RoomJoinRules, "".to_owned()),
|
||||
)
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
|
@ -1,74 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use ruma_events::EventType;
|
||||
use ruma_state_res::{
|
||||
event_auth::valid_membership_change,
|
||||
test_utils::{alice, charlie, event_id, member_content_ban, to_pdu_event, INITIAL_EVENTS},
|
||||
StateMap,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_ban_pass() {
|
||||
let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
let events = INITIAL_EVENTS();
|
||||
|
||||
let prev = events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone);
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev)))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = to_pdu_event(
|
||||
"HELLO",
|
||||
alice(),
|
||||
EventType::RoomMember,
|
||||
Some(charlie().as_str()),
|
||||
member_content_ban(),
|
||||
&[],
|
||||
&[event_id("IMC")],
|
||||
);
|
||||
|
||||
assert!(valid_membership_change(
|
||||
&requester.state_key(),
|
||||
requester.sender(),
|
||||
requester.content(),
|
||||
prev,
|
||||
None,
|
||||
|ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(),
|
||||
)
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ban_fail() {
|
||||
let _ = tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
|
||||
let events = INITIAL_EVENTS();
|
||||
|
||||
let prev = events.values().find(|ev| ev.event_id().as_str().contains("IMC")).map(Arc::clone);
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| ((ev.event_type(), ev.state_key()), Arc::clone(ev)))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = to_pdu_event(
|
||||
"HELLO",
|
||||
charlie(),
|
||||
EventType::RoomMember,
|
||||
Some(alice().as_str()),
|
||||
member_content_ban(),
|
||||
&[],
|
||||
&[event_id("IMC")],
|
||||
);
|
||||
|
||||
assert!(!valid_membership_change(
|
||||
&requester.state_key(),
|
||||
requester.sender(),
|
||||
requester.content(),
|
||||
prev,
|
||||
None,
|
||||
|ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(),
|
||||
)
|
||||
.unwrap())
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user