state-res: Make functions more general

Don't require different parameters to use the same impl Event.
This commit is contained in:
Jonas Platte 2021-09-13 18:34:08 +02:00
parent 0999e420ae
commit a6a1224652
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
2 changed files with 65 additions and 95 deletions

View File

@ -84,17 +84,13 @@ pub fn auth_types_for_event(
/// ## Returns /// ## Returns
/// ///
/// This returns an `Error` only when serialization fails or some other fatal outcome. /// This returns an `Error` only when serialization fails or some other fatal outcome.
pub fn auth_check<E, F>( pub fn auth_check<E: Event>(
room_version: &RoomVersion, room_version: &RoomVersion,
incoming_event: &E, incoming_event: impl Event,
prev_event: Option<&E>, prev_event: Option<impl Event>,
current_third_party_invite: Option<&E>, current_third_party_invite: Option<impl Event>,
fetch_state: F, fetch_state: impl Fn(&EventType, &str) -> Option<E>,
) -> Result<bool> ) -> Result<bool> {
where
E: Event,
F: Fn(&EventType, &str) -> Option<E>,
{
info!( info!(
"auth_check beginning for {} ({})", "auth_check beginning for {} ({})",
incoming_event.event_id(), incoming_event.event_id(),
@ -311,7 +307,7 @@ where
// If the event type's required power level is greater than the sender's power level, reject // If the event type's required power level is greater than the sender's power level, reject
// If the event has a state_key that starts with an @ and does not match the sender, reject. // If the event has a state_key that starts with an @ and does not match the sender, reject.
if !can_send_event(incoming_event, power_levels_event.as_ref(), sender_power_level) { if !can_send_event(&incoming_event, power_levels_event.as_ref(), sender_power_level) {
warn!("user cannot send event"); warn!("user cannot send event");
return Ok(false); return Ok(false);
} }
@ -321,7 +317,7 @@ where
if let Some(required_pwr_lvl) = check_power_levels( if let Some(required_pwr_lvl) = check_power_levels(
room_version, room_version,
incoming_event, &incoming_event,
power_levels_event.as_ref(), power_levels_event.as_ref(),
sender_power_level, sender_power_level,
) { ) {
@ -378,16 +374,16 @@ where
/// This is generated by calling `auth_types_for_event` with the membership event and the current /// This is generated by calling `auth_types_for_event` with the membership event and the current
/// State. /// State.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn valid_membership_change<E: Event>( fn valid_membership_change(
target_user: &UserId, target_user: &UserId,
target_user_membership_event: Option<&E>, target_user_membership_event: Option<impl Event>,
sender: &UserId, sender: &UserId,
sender_membership_event: Option<&E>, sender_membership_event: Option<impl Event>,
content: &serde_json::Value, content: &serde_json::Value,
prev_event: Option<&E>, prev_event: Option<impl Event>,
current_third_party_invite: Option<&E>, current_third_party_invite: Option<impl Event>,
power_levels_event: Option<&E>, power_levels_event: Option<impl Event>,
join_rules_event: Option<&E>, join_rules_event: Option<impl Event>,
) -> Result<bool> { ) -> Result<bool> {
let target_membership = serde_json::from_value::<MembershipState>( let target_membership = serde_json::from_value::<MembershipState>(
content.get("membership").expect("we test before that this field exists").clone(), content.get("membership").expect("we test before that this field exists").clone(),
@ -572,7 +568,7 @@ fn valid_membership_change<E: Event>(
/// Is the user allowed to send a specific event based on the rooms power levels. /// Is the user allowed to send a specific event based on the rooms power levels.
/// ///
/// Does the event have the correct userId as its state_key if it's not the "" state_key. /// Does the event have the correct userId as its state_key if it's not the "" state_key.
fn can_send_event<E: Event>(event: &E, ple: Option<&E>, user_level: Int) -> bool { fn can_send_event(event: impl Event, ple: Option<impl Event>, user_level: Int) -> bool {
let event_type_power_level = get_send_level(event.event_type(), event.state_key(), ple); let event_type_power_level = get_send_level(event.event_type(), event.state_key(), ple);
debug!("{} ev_type {} usr {}", event.event_id(), event_type_power_level, user_level); debug!("{} ev_type {} usr {}", event.event_id(), event_type_power_level, user_level);
@ -591,15 +587,12 @@ fn can_send_event<E: Event>(event: &E, ple: Option<&E>, user_level: Int) -> bool
} }
/// Confirm that the event sender has the required power levels. /// Confirm that the event sender has the required power levels.
fn check_power_levels<E>( fn check_power_levels(
room_version: &RoomVersion, room_version: &RoomVersion,
power_event: &E, power_event: impl Event,
previous_power_event: Option<&E>, previous_power_event: Option<impl Event>,
user_level: Int, user_level: Int,
) -> Option<bool> ) -> Option<bool> {
where
E: Event,
{
match power_event.state_key() { match power_event.state_key() {
Some("") => {} Some("") => {}
Some(key) => { Some(key) => {
@ -746,9 +739,9 @@ fn get_deserialize_levels(
} }
/// Does the event redacting come from a user with enough power to redact the given event. /// Does the event redacting come from a user with enough power to redact the given event.
fn check_redaction<E: Event>( fn check_redaction(
_room_version: &RoomVersion, _room_version: &RoomVersion,
redaction_event: &E, redaction_event: impl Event,
user_level: Int, user_level: Int,
redact_level: Int, redact_level: Int,
) -> Result<bool> { ) -> Result<bool> {
@ -771,10 +764,10 @@ fn check_redaction<E: Event>(
/// Helper function to fetch the power level needed to send an event of type /// Helper function to fetch the power level needed to send an event of type
/// `e_type` based on the rooms "m.room.power_level" event. /// `e_type` based on the rooms "m.room.power_level" event.
fn get_send_level<E: Event>( fn get_send_level(
e_type: &EventType, e_type: &EventType,
state_key: Option<&str>, state_key: Option<&str>,
power_lvl: Option<&E>, power_lvl: Option<impl Event>,
) -> Int { ) -> Int {
power_lvl power_lvl
.and_then(|ple| { .and_then(|ple| {
@ -793,11 +786,11 @@ fn get_send_level<E: Event>(
.unwrap_or_else(|| if state_key.is_some() { int!(50) } else { int!(0) }) .unwrap_or_else(|| if state_key.is_some() { int!(50) } else { int!(0) })
} }
fn verify_third_party_invite<E: Event>( fn verify_third_party_invite(
target_user: Option<&UserId>, target_user: Option<&UserId>,
sender: &UserId, sender: &UserId,
tp_id: &ThirdPartyInvite, tp_id: &ThirdPartyInvite,
current_third_party_invite: Option<&E>, current_third_party_invite: Option<impl Event>,
) -> bool { ) -> bool {
// 1. Check for user being banned happens before this is called // 1. Check for user being banned happens before this is called
// checking for mxid and token keys is done by ruma when deserializing // checking for mxid and token keys is done by ruma when deserializing
@ -845,7 +838,9 @@ mod tests {
use crate::{ use crate::{
event_auth::valid_membership_change, event_auth::valid_membership_change,
test_utils::{alice, charlie, event_id, member_content_ban, to_pdu_event, INITIAL_EVENTS}, test_utils::{
alice, charlie, event_id, member_content_ban, to_pdu_event, StateEvent, INITIAL_EVENTS,
},
Event, StateMap, Event, StateMap,
}; };
use ruma_events::EventType; use ruma_events::EventType;
@ -882,14 +877,14 @@ mod tests {
assert!(valid_membership_change( assert!(valid_membership_change(
&target_user, &target_user,
fetch_state(EventType::RoomMember, target_user.to_string()).as_deref(), fetch_state(EventType::RoomMember, target_user.to_string()),
&sender, &sender,
fetch_state(EventType::RoomMember, sender.to_string()).as_deref(), fetch_state(EventType::RoomMember, sender.to_string()),
requester.content(), requester.content(),
prev_event.as_deref(), prev_event,
None, None::<StateEvent>,
fetch_state(EventType::RoomPowerLevels, "".to_owned()).as_deref(), fetch_state(EventType::RoomPowerLevels, "".to_owned()),
fetch_state(EventType::RoomJoinRules, "".to_owned()).as_deref(), fetch_state(EventType::RoomJoinRules, "".to_owned()),
) )
.unwrap()); .unwrap());
} }
@ -926,14 +921,14 @@ mod tests {
assert!(!valid_membership_change( assert!(!valid_membership_change(
&target_user, &target_user,
fetch_state(EventType::RoomMember, target_user.to_string()).as_deref(), fetch_state(EventType::RoomMember, target_user.to_string()),
&sender, &sender,
fetch_state(EventType::RoomMember, sender.to_string()).as_deref(), fetch_state(EventType::RoomMember, sender.to_string()),
requester.content(), requester.content(),
prev_event.as_deref(), prev_event,
None, None::<StateEvent>,
fetch_state(EventType::RoomPowerLevels, "".to_owned()).as_deref(), fetch_state(EventType::RoomPowerLevels, "".to_owned()),
fetch_state(EventType::RoomJoinRules, "".to_owned()).as_deref(), fetch_state(EventType::RoomJoinRules, "".to_owned()),
) )
.unwrap()); .unwrap());
} }

View File

@ -51,15 +51,14 @@ type EventMap<T> = HashMap<EventId, T>;
/// ///
/// The caller of `resolve` must ensure that all the events are from the same room. Although this /// The caller of `resolve` must ensure that all the events are from the same room. Although this
/// function takes a `RoomId` it does not check that each event is part of the same room. /// function takes a `RoomId` it does not check that each event is part of the same room.
pub fn resolve<'a, E, F, SSI>( pub fn resolve<'a, E, SSI>(
room_version: &RoomVersionId, room_version: &RoomVersionId,
state_sets: impl IntoIterator<IntoIter = SSI>, state_sets: impl IntoIterator<IntoIter = SSI>,
auth_chain_sets: Vec<HashSet<EventId>>, auth_chain_sets: Vec<HashSet<EventId>>,
fetch_event: F, fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<EventId>> ) -> Result<StateMap<EventId>>
where where
E: Event + Clone, E: Event + Clone,
F: Fn(&EventId) -> Option<E>,
SSI: Iterator<Item = &'a StateMap<EventId>> + Clone, SSI: Iterator<Item = &'a StateMap<EventId>> + Clone,
{ {
info!("State resolution starting"); info!("State resolution starting");
@ -203,15 +202,11 @@ fn get_auth_chain_diff(auth_chain_sets: Vec<HashSet<EventId>>) -> impl Iterator<
/// ///
/// The power level is negative because a higher power level is equated to an earlier (further back /// The power level is negative because a higher power level is equated to an earlier (further back
/// in time) origin server timestamp. /// in time) origin server timestamp.
fn reverse_topological_power_sort<E, F>( fn reverse_topological_power_sort<E: Event>(
events_to_sort: Vec<EventId>, events_to_sort: Vec<EventId>,
auth_diff: &HashSet<EventId>, auth_diff: &HashSet<EventId>,
fetch_event: F, fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<EventId>> ) -> Result<Vec<EventId>> {
where
E: Event,
F: Fn(&EventId) -> Option<E>,
{
debug!("reverse topological sort of power events"); debug!("reverse topological sort of power events");
let mut graph = HashMap::new(); let mut graph = HashMap::new();
@ -320,11 +315,10 @@ where
} }
/// Find the power level for the sender of `event_id` or return a default value of zero. /// Find the power level for the sender of `event_id` or return a default value of zero.
fn get_power_level_for_sender<E, F>(event_id: &EventId, fetch_event: F) -> i64 fn get_power_level_for_sender<E: Event>(
where event_id: &EventId,
E: Event, fetch_event: impl Fn(&EventId) -> Option<E>,
F: Fn(&EventId) -> Option<E>, ) -> i64 {
{
info!("fetch event ({}) senders power level", event_id); info!("fetch event ({}) senders power level", event_id);
let event = fetch_event(event_id); let event = fetch_event(event_id);
@ -367,16 +361,12 @@ where
/// ///
/// For each `events_to_check` event we gather the events needed to auth it from the the /// For each `events_to_check` event we gather the events needed to auth it from the the
/// `fetch_event` closure and verify each event using the `event_auth::auth_check` function. /// `fetch_event` closure and verify each event using the `event_auth::auth_check` function.
fn iterative_auth_check<E, F>( fn iterative_auth_check<E: Event + Clone>(
room_version: &RoomVersion, room_version: &RoomVersion,
events_to_check: &[EventId], events_to_check: &[EventId],
unconflicted_state: StateMap<EventId>, unconflicted_state: StateMap<EventId>,
fetch_event: F, fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<EventId>> ) -> Result<StateMap<EventId>> {
where
E: Event + Clone,
F: Fn(&EventId) -> Option<E>,
{
info!("starting iterative auth check"); info!("starting iterative auth check");
debug!("performing auth checks on {:?}", events_to_check); debug!("performing auth checks on {:?}", events_to_check);
@ -468,15 +458,11 @@ where
/// power_level event. If there have been two power events the after the most recent are depth 0, /// power_level event. If there have been two power events the after the most recent are depth 0,
/// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is /// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is
/// "older" than depth 0. /// "older" than depth 0.
fn mainline_sort<E, F>( fn mainline_sort<E: Event>(
to_sort: &[EventId], to_sort: &[EventId],
resolved_power_level: Option<&EventId>, resolved_power_level: Option<&EventId>,
fetch_event: F, fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<EventId>> ) -> Result<Vec<EventId>> {
where
E: Event,
F: Fn(&EventId) -> Option<E>,
{
debug!("mainline sort of events"); debug!("mainline sort of events");
// There are no EventId's to sort, bail. // There are no EventId's to sort, bail.
@ -538,15 +524,11 @@ where
/// Get the mainline depth from the `mainline_map` or finds a power_level event that has an /// Get the mainline depth from the `mainline_map` or finds a power_level event that has an
/// associated mainline depth. /// associated mainline depth.
fn get_mainline_depth<E, F>( fn get_mainline_depth<E: Event>(
mut event: Option<E>, mut event: Option<E>,
mainline_map: &EventMap<usize>, mainline_map: &EventMap<usize>,
fetch_event: F, fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<usize> ) -> Result<usize> {
where
E: Event,
F: Fn(&EventId) -> Option<E>,
{
while let Some(sort_ev) = event { while let Some(sort_ev) = event {
debug!("mainline event_id {}", sort_ev.event_id()); debug!("mainline event_id {}", sort_ev.event_id());
let id = &sort_ev.event_id(); let id = &sort_ev.event_id();
@ -568,15 +550,12 @@ where
Ok(0) Ok(0)
} }
fn add_event_and_auth_chain_to_graph<E, F>( fn add_event_and_auth_chain_to_graph<E: Event>(
graph: &mut HashMap<EventId, HashSet<EventId>>, graph: &mut HashMap<EventId, HashSet<EventId>>,
event_id: EventId, event_id: EventId,
auth_diff: &HashSet<EventId>, auth_diff: &HashSet<EventId>,
fetch_event: F, fetch_event: impl Fn(&EventId) -> Option<E>,
) where ) {
E: Event,
F: Fn(&EventId) -> Option<E>,
{
let mut state = vec![event_id]; let mut state = vec![event_id];
while let Some(eid) = state.pop() { while let Some(eid) = state.pop() {
graph.entry(eid.clone()).or_default(); graph.entry(eid.clone()).or_default();
@ -594,22 +573,18 @@ fn add_event_and_auth_chain_to_graph<E, F>(
} }
} }
fn is_power_event_id<E, F>(event_id: &EventId, fetch: F) -> bool fn is_power_event_id<E: Event>(event_id: &EventId, fetch: impl Fn(&EventId) -> Option<E>) -> bool {
where
E: Event,
F: Fn(&EventId) -> Option<E>,
{
match fetch(event_id).as_ref() { match fetch(event_id).as_ref() {
Some(state) => is_power_event(state), Some(state) => is_power_event(state),
_ => false, _ => false,
} }
} }
fn is_type_and_key<E: Event>(ev: &E, ev_type: &EventType, state_key: &str) -> bool { fn is_type_and_key(ev: impl Event, ev_type: &EventType, state_key: &str) -> bool {
ev.event_type() == ev_type && ev.state_key() == Some(state_key) ev.event_type() == ev_type && ev.state_key() == Some(state_key)
} }
fn is_power_event<E: Event>(event: &E) -> bool { fn is_power_event(event: impl Event) -> bool {
match event.event_type() { match event.event_type() {
EventType::RoomPowerLevels | EventType::RoomJoinRules | EventType::RoomCreate => { EventType::RoomPowerLevels | EventType::RoomJoinRules | EventType::RoomCreate => {
event.state_key() == Some("") event.state_key() == Some("")