state-res: Make functions more general
Don't require different parameters to use the same impl Event.
This commit is contained in:
parent
0999e420ae
commit
a6a1224652
@ -84,17 +84,13 @@ pub fn auth_types_for_event(
|
||||
/// ## Returns
|
||||
///
|
||||
/// This returns an `Error` only when serialization fails or some other fatal outcome.
|
||||
pub fn auth_check<E, F>(
|
||||
pub fn auth_check<E: Event>(
|
||||
room_version: &RoomVersion,
|
||||
incoming_event: &E,
|
||||
prev_event: Option<&E>,
|
||||
current_third_party_invite: Option<&E>,
|
||||
fetch_state: F,
|
||||
) -> Result<bool>
|
||||
where
|
||||
E: Event,
|
||||
F: Fn(&EventType, &str) -> Option<E>,
|
||||
{
|
||||
incoming_event: impl Event,
|
||||
prev_event: Option<impl Event>,
|
||||
current_third_party_invite: Option<impl Event>,
|
||||
fetch_state: impl Fn(&EventType, &str) -> Option<E>,
|
||||
) -> Result<bool> {
|
||||
info!(
|
||||
"auth_check beginning for {} ({})",
|
||||
incoming_event.event_id(),
|
||||
@ -311,7 +307,7 @@ where
|
||||
|
||||
// If the event type's required power level is greater than the sender's power level, reject
|
||||
// If the event has a state_key that starts with an @ and does not match the sender, reject.
|
||||
if !can_send_event(incoming_event, power_levels_event.as_ref(), sender_power_level) {
|
||||
if !can_send_event(&incoming_event, power_levels_event.as_ref(), sender_power_level) {
|
||||
warn!("user cannot send event");
|
||||
return Ok(false);
|
||||
}
|
||||
@ -321,7 +317,7 @@ where
|
||||
|
||||
if let Some(required_pwr_lvl) = check_power_levels(
|
||||
room_version,
|
||||
incoming_event,
|
||||
&incoming_event,
|
||||
power_levels_event.as_ref(),
|
||||
sender_power_level,
|
||||
) {
|
||||
@ -378,16 +374,16 @@ where
|
||||
/// This is generated by calling `auth_types_for_event` with the membership event and the current
|
||||
/// State.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn valid_membership_change<E: Event>(
|
||||
fn valid_membership_change(
|
||||
target_user: &UserId,
|
||||
target_user_membership_event: Option<&E>,
|
||||
target_user_membership_event: Option<impl Event>,
|
||||
sender: &UserId,
|
||||
sender_membership_event: Option<&E>,
|
||||
sender_membership_event: Option<impl Event>,
|
||||
content: &serde_json::Value,
|
||||
prev_event: Option<&E>,
|
||||
current_third_party_invite: Option<&E>,
|
||||
power_levels_event: Option<&E>,
|
||||
join_rules_event: Option<&E>,
|
||||
prev_event: Option<impl Event>,
|
||||
current_third_party_invite: Option<impl Event>,
|
||||
power_levels_event: Option<impl Event>,
|
||||
join_rules_event: Option<impl Event>,
|
||||
) -> Result<bool> {
|
||||
let target_membership = serde_json::from_value::<MembershipState>(
|
||||
content.get("membership").expect("we test before that this field exists").clone(),
|
||||
@ -572,7 +568,7 @@ fn valid_membership_change<E: Event>(
|
||||
/// Is the user allowed to send a specific event based on the rooms power levels.
|
||||
///
|
||||
/// Does the event have the correct userId as its state_key if it's not the "" state_key.
|
||||
fn can_send_event<E: Event>(event: &E, ple: Option<&E>, user_level: Int) -> bool {
|
||||
fn can_send_event(event: impl Event, ple: Option<impl Event>, user_level: Int) -> bool {
|
||||
let event_type_power_level = get_send_level(event.event_type(), event.state_key(), ple);
|
||||
|
||||
debug!("{} ev_type {} usr {}", event.event_id(), event_type_power_level, user_level);
|
||||
@ -591,15 +587,12 @@ fn can_send_event<E: Event>(event: &E, ple: Option<&E>, user_level: Int) -> bool
|
||||
}
|
||||
|
||||
/// Confirm that the event sender has the required power levels.
|
||||
fn check_power_levels<E>(
|
||||
fn check_power_levels(
|
||||
room_version: &RoomVersion,
|
||||
power_event: &E,
|
||||
previous_power_event: Option<&E>,
|
||||
power_event: impl Event,
|
||||
previous_power_event: Option<impl Event>,
|
||||
user_level: Int,
|
||||
) -> Option<bool>
|
||||
where
|
||||
E: Event,
|
||||
{
|
||||
) -> Option<bool> {
|
||||
match power_event.state_key() {
|
||||
Some("") => {}
|
||||
Some(key) => {
|
||||
@ -746,9 +739,9 @@ fn get_deserialize_levels(
|
||||
}
|
||||
|
||||
/// Does the event redacting come from a user with enough power to redact the given event.
|
||||
fn check_redaction<E: Event>(
|
||||
fn check_redaction(
|
||||
_room_version: &RoomVersion,
|
||||
redaction_event: &E,
|
||||
redaction_event: impl Event,
|
||||
user_level: Int,
|
||||
redact_level: Int,
|
||||
) -> Result<bool> {
|
||||
@ -771,10 +764,10 @@ fn check_redaction<E: Event>(
|
||||
|
||||
/// Helper function to fetch the power level needed to send an event of type
|
||||
/// `e_type` based on the rooms "m.room.power_level" event.
|
||||
fn get_send_level<E: Event>(
|
||||
fn get_send_level(
|
||||
e_type: &EventType,
|
||||
state_key: Option<&str>,
|
||||
power_lvl: Option<&E>,
|
||||
power_lvl: Option<impl Event>,
|
||||
) -> Int {
|
||||
power_lvl
|
||||
.and_then(|ple| {
|
||||
@ -793,11 +786,11 @@ fn get_send_level<E: Event>(
|
||||
.unwrap_or_else(|| if state_key.is_some() { int!(50) } else { int!(0) })
|
||||
}
|
||||
|
||||
fn verify_third_party_invite<E: Event>(
|
||||
fn verify_third_party_invite(
|
||||
target_user: Option<&UserId>,
|
||||
sender: &UserId,
|
||||
tp_id: &ThirdPartyInvite,
|
||||
current_third_party_invite: Option<&E>,
|
||||
current_third_party_invite: Option<impl Event>,
|
||||
) -> bool {
|
||||
// 1. Check for user being banned happens before this is called
|
||||
// checking for mxid and token keys is done by ruma when deserializing
|
||||
@ -845,7 +838,9 @@ mod tests {
|
||||
|
||||
use crate::{
|
||||
event_auth::valid_membership_change,
|
||||
test_utils::{alice, charlie, event_id, member_content_ban, to_pdu_event, INITIAL_EVENTS},
|
||||
test_utils::{
|
||||
alice, charlie, event_id, member_content_ban, to_pdu_event, StateEvent, INITIAL_EVENTS,
|
||||
},
|
||||
Event, StateMap,
|
||||
};
|
||||
use ruma_events::EventType;
|
||||
@ -882,14 +877,14 @@ mod tests {
|
||||
|
||||
assert!(valid_membership_change(
|
||||
&target_user,
|
||||
fetch_state(EventType::RoomMember, target_user.to_string()).as_deref(),
|
||||
fetch_state(EventType::RoomMember, target_user.to_string()),
|
||||
&sender,
|
||||
fetch_state(EventType::RoomMember, sender.to_string()).as_deref(),
|
||||
fetch_state(EventType::RoomMember, sender.to_string()),
|
||||
requester.content(),
|
||||
prev_event.as_deref(),
|
||||
None,
|
||||
fetch_state(EventType::RoomPowerLevels, "".to_owned()).as_deref(),
|
||||
fetch_state(EventType::RoomJoinRules, "".to_owned()).as_deref(),
|
||||
prev_event,
|
||||
None::<StateEvent>,
|
||||
fetch_state(EventType::RoomPowerLevels, "".to_owned()),
|
||||
fetch_state(EventType::RoomJoinRules, "".to_owned()),
|
||||
)
|
||||
.unwrap());
|
||||
}
|
||||
@ -926,14 +921,14 @@ mod tests {
|
||||
|
||||
assert!(!valid_membership_change(
|
||||
&target_user,
|
||||
fetch_state(EventType::RoomMember, target_user.to_string()).as_deref(),
|
||||
fetch_state(EventType::RoomMember, target_user.to_string()),
|
||||
&sender,
|
||||
fetch_state(EventType::RoomMember, sender.to_string()).as_deref(),
|
||||
fetch_state(EventType::RoomMember, sender.to_string()),
|
||||
requester.content(),
|
||||
prev_event.as_deref(),
|
||||
None,
|
||||
fetch_state(EventType::RoomPowerLevels, "".to_owned()).as_deref(),
|
||||
fetch_state(EventType::RoomJoinRules, "".to_owned()).as_deref(),
|
||||
prev_event,
|
||||
None::<StateEvent>,
|
||||
fetch_state(EventType::RoomPowerLevels, "".to_owned()),
|
||||
fetch_state(EventType::RoomJoinRules, "".to_owned()),
|
||||
)
|
||||
.unwrap());
|
||||
}
|
||||
|
@ -51,15 +51,14 @@ type EventMap<T> = HashMap<EventId, T>;
|
||||
///
|
||||
/// The caller of `resolve` must ensure that all the events are from the same room. Although this
|
||||
/// function takes a `RoomId` it does not check that each event is part of the same room.
|
||||
pub fn resolve<'a, E, F, SSI>(
|
||||
pub fn resolve<'a, E, SSI>(
|
||||
room_version: &RoomVersionId,
|
||||
state_sets: impl IntoIterator<IntoIter = SSI>,
|
||||
auth_chain_sets: Vec<HashSet<EventId>>,
|
||||
fetch_event: F,
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) -> Result<StateMap<EventId>>
|
||||
where
|
||||
E: Event + Clone,
|
||||
F: Fn(&EventId) -> Option<E>,
|
||||
SSI: Iterator<Item = &'a StateMap<EventId>> + Clone,
|
||||
{
|
||||
info!("State resolution starting");
|
||||
@ -203,15 +202,11 @@ fn get_auth_chain_diff(auth_chain_sets: Vec<HashSet<EventId>>) -> impl Iterator<
|
||||
///
|
||||
/// The power level is negative because a higher power level is equated to an earlier (further back
|
||||
/// in time) origin server timestamp.
|
||||
fn reverse_topological_power_sort<E, F>(
|
||||
fn reverse_topological_power_sort<E: Event>(
|
||||
events_to_sort: Vec<EventId>,
|
||||
auth_diff: &HashSet<EventId>,
|
||||
fetch_event: F,
|
||||
) -> Result<Vec<EventId>>
|
||||
where
|
||||
E: Event,
|
||||
F: Fn(&EventId) -> Option<E>,
|
||||
{
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) -> Result<Vec<EventId>> {
|
||||
debug!("reverse topological sort of power events");
|
||||
|
||||
let mut graph = HashMap::new();
|
||||
@ -320,11 +315,10 @@ where
|
||||
}
|
||||
|
||||
/// Find the power level for the sender of `event_id` or return a default value of zero.
|
||||
fn get_power_level_for_sender<E, F>(event_id: &EventId, fetch_event: F) -> i64
|
||||
where
|
||||
E: Event,
|
||||
F: Fn(&EventId) -> Option<E>,
|
||||
{
|
||||
fn get_power_level_for_sender<E: Event>(
|
||||
event_id: &EventId,
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) -> i64 {
|
||||
info!("fetch event ({}) senders power level", event_id);
|
||||
|
||||
let event = fetch_event(event_id);
|
||||
@ -367,16 +361,12 @@ where
|
||||
///
|
||||
/// For each `events_to_check` event we gather the events needed to auth it from the the
|
||||
/// `fetch_event` closure and verify each event using the `event_auth::auth_check` function.
|
||||
fn iterative_auth_check<E, F>(
|
||||
fn iterative_auth_check<E: Event + Clone>(
|
||||
room_version: &RoomVersion,
|
||||
events_to_check: &[EventId],
|
||||
unconflicted_state: StateMap<EventId>,
|
||||
fetch_event: F,
|
||||
) -> Result<StateMap<EventId>>
|
||||
where
|
||||
E: Event + Clone,
|
||||
F: Fn(&EventId) -> Option<E>,
|
||||
{
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) -> Result<StateMap<EventId>> {
|
||||
info!("starting iterative auth check");
|
||||
|
||||
debug!("performing auth checks on {:?}", events_to_check);
|
||||
@ -468,15 +458,11 @@ where
|
||||
/// power_level event. If there have been two power events the after the most recent are depth 0,
|
||||
/// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is
|
||||
/// "older" than depth 0.
|
||||
fn mainline_sort<E, F>(
|
||||
fn mainline_sort<E: Event>(
|
||||
to_sort: &[EventId],
|
||||
resolved_power_level: Option<&EventId>,
|
||||
fetch_event: F,
|
||||
) -> Result<Vec<EventId>>
|
||||
where
|
||||
E: Event,
|
||||
F: Fn(&EventId) -> Option<E>,
|
||||
{
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) -> Result<Vec<EventId>> {
|
||||
debug!("mainline sort of events");
|
||||
|
||||
// There are no EventId's to sort, bail.
|
||||
@ -538,15 +524,11 @@ where
|
||||
|
||||
/// Get the mainline depth from the `mainline_map` or finds a power_level event that has an
|
||||
/// associated mainline depth.
|
||||
fn get_mainline_depth<E, F>(
|
||||
fn get_mainline_depth<E: Event>(
|
||||
mut event: Option<E>,
|
||||
mainline_map: &EventMap<usize>,
|
||||
fetch_event: F,
|
||||
) -> Result<usize>
|
||||
where
|
||||
E: Event,
|
||||
F: Fn(&EventId) -> Option<E>,
|
||||
{
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) -> Result<usize> {
|
||||
while let Some(sort_ev) = event {
|
||||
debug!("mainline event_id {}", sort_ev.event_id());
|
||||
let id = &sort_ev.event_id();
|
||||
@ -568,15 +550,12 @@ where
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
fn add_event_and_auth_chain_to_graph<E, F>(
|
||||
fn add_event_and_auth_chain_to_graph<E: Event>(
|
||||
graph: &mut HashMap<EventId, HashSet<EventId>>,
|
||||
event_id: EventId,
|
||||
auth_diff: &HashSet<EventId>,
|
||||
fetch_event: F,
|
||||
) where
|
||||
E: Event,
|
||||
F: Fn(&EventId) -> Option<E>,
|
||||
{
|
||||
fetch_event: impl Fn(&EventId) -> Option<E>,
|
||||
) {
|
||||
let mut state = vec![event_id];
|
||||
while let Some(eid) = state.pop() {
|
||||
graph.entry(eid.clone()).or_default();
|
||||
@ -594,22 +573,18 @@ fn add_event_and_auth_chain_to_graph<E, F>(
|
||||
}
|
||||
}
|
||||
|
||||
fn is_power_event_id<E, F>(event_id: &EventId, fetch: F) -> bool
|
||||
where
|
||||
E: Event,
|
||||
F: Fn(&EventId) -> Option<E>,
|
||||
{
|
||||
fn is_power_event_id<E: Event>(event_id: &EventId, fetch: impl Fn(&EventId) -> Option<E>) -> bool {
|
||||
match fetch(event_id).as_ref() {
|
||||
Some(state) => is_power_event(state),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_type_and_key<E: Event>(ev: &E, ev_type: &EventType, state_key: &str) -> bool {
|
||||
fn is_type_and_key(ev: impl Event, ev_type: &EventType, state_key: &str) -> bool {
|
||||
ev.event_type() == ev_type && ev.state_key() == Some(state_key)
|
||||
}
|
||||
|
||||
fn is_power_event<E: Event>(event: &E) -> bool {
|
||||
fn is_power_event(event: impl Event) -> bool {
|
||||
match event.event_type() {
|
||||
EventType::RoomPowerLevels | EventType::RoomJoinRules | EventType::RoomCreate => {
|
||||
event.state_key() == Some("")
|
||||
|
Loading…
x
Reference in New Issue
Block a user