diff --git a/crates/ruma-state-res/src/event_auth.rs b/crates/ruma-state-res/src/event_auth.rs index 5aa914ca..acb24f34 100644 --- a/crates/ruma-state-res/src/event_auth.rs +++ b/crates/ruma-state-res/src/event_auth.rs @@ -84,17 +84,13 @@ pub fn auth_types_for_event( /// ## Returns /// /// This returns an `Error` only when serialization fails or some other fatal outcome. -pub fn auth_check( +pub fn auth_check( room_version: &RoomVersion, - incoming_event: &E, - prev_event: Option<&E>, - current_third_party_invite: Option<&E>, - fetch_state: F, -) -> Result -where - E: Event, - F: Fn(&EventType, &str) -> Option, -{ + incoming_event: impl Event, + prev_event: Option, + current_third_party_invite: Option, + fetch_state: impl Fn(&EventType, &str) -> Option, +) -> Result { info!( "auth_check beginning for {} ({})", incoming_event.event_id(), @@ -311,7 +307,7 @@ where // If the event type's required power level is greater than the sender's power level, reject // If the event has a state_key that starts with an @ and does not match the sender, reject. - if !can_send_event(incoming_event, power_levels_event.as_ref(), sender_power_level) { + if !can_send_event(&incoming_event, power_levels_event.as_ref(), sender_power_level) { warn!("user cannot send event"); return Ok(false); } @@ -321,7 +317,7 @@ where if let Some(required_pwr_lvl) = check_power_levels( room_version, - incoming_event, + &incoming_event, power_levels_event.as_ref(), sender_power_level, ) { @@ -378,16 +374,16 @@ where /// This is generated by calling `auth_types_for_event` with the membership event and the current /// State. #[allow(clippy::too_many_arguments)] -fn valid_membership_change( +fn valid_membership_change( target_user: &UserId, - target_user_membership_event: Option<&E>, + target_user_membership_event: Option, sender: &UserId, - sender_membership_event: Option<&E>, + sender_membership_event: Option, content: &serde_json::Value, - prev_event: Option<&E>, - current_third_party_invite: Option<&E>, - power_levels_event: Option<&E>, - join_rules_event: Option<&E>, + prev_event: Option, + current_third_party_invite: Option, + power_levels_event: Option, + join_rules_event: Option, ) -> Result { let target_membership = serde_json::from_value::( content.get("membership").expect("we test before that this field exists").clone(), @@ -572,7 +568,7 @@ fn valid_membership_change( /// Is the user allowed to send a specific event based on the rooms power levels. /// /// Does the event have the correct userId as its state_key if it's not the "" state_key. -fn can_send_event(event: &E, ple: Option<&E>, user_level: Int) -> bool { +fn can_send_event(event: impl Event, ple: Option, user_level: Int) -> bool { let event_type_power_level = get_send_level(event.event_type(), event.state_key(), ple); debug!("{} ev_type {} usr {}", event.event_id(), event_type_power_level, user_level); @@ -591,15 +587,12 @@ fn can_send_event(event: &E, ple: Option<&E>, user_level: Int) -> bool } /// Confirm that the event sender has the required power levels. -fn check_power_levels( +fn check_power_levels( room_version: &RoomVersion, - power_event: &E, - previous_power_event: Option<&E>, + power_event: impl Event, + previous_power_event: Option, user_level: Int, -) -> Option -where - E: Event, -{ +) -> Option { match power_event.state_key() { Some("") => {} Some(key) => { @@ -746,9 +739,9 @@ fn get_deserialize_levels( } /// Does the event redacting come from a user with enough power to redact the given event. -fn check_redaction( +fn check_redaction( _room_version: &RoomVersion, - redaction_event: &E, + redaction_event: impl Event, user_level: Int, redact_level: Int, ) -> Result { @@ -771,10 +764,10 @@ fn check_redaction( /// Helper function to fetch the power level needed to send an event of type /// `e_type` based on the rooms "m.room.power_level" event. -fn get_send_level( +fn get_send_level( e_type: &EventType, state_key: Option<&str>, - power_lvl: Option<&E>, + power_lvl: Option, ) -> Int { power_lvl .and_then(|ple| { @@ -793,11 +786,11 @@ fn get_send_level( .unwrap_or_else(|| if state_key.is_some() { int!(50) } else { int!(0) }) } -fn verify_third_party_invite( +fn verify_third_party_invite( target_user: Option<&UserId>, sender: &UserId, tp_id: &ThirdPartyInvite, - current_third_party_invite: Option<&E>, + current_third_party_invite: Option, ) -> bool { // 1. Check for user being banned happens before this is called // checking for mxid and token keys is done by ruma when deserializing @@ -845,7 +838,9 @@ mod tests { use crate::{ event_auth::valid_membership_change, - test_utils::{alice, charlie, event_id, member_content_ban, to_pdu_event, INITIAL_EVENTS}, + test_utils::{ + alice, charlie, event_id, member_content_ban, to_pdu_event, StateEvent, INITIAL_EVENTS, + }, Event, StateMap, }; use ruma_events::EventType; @@ -882,14 +877,14 @@ mod tests { assert!(valid_membership_change( &target_user, - fetch_state(EventType::RoomMember, target_user.to_string()).as_deref(), + fetch_state(EventType::RoomMember, target_user.to_string()), &sender, - fetch_state(EventType::RoomMember, sender.to_string()).as_deref(), + fetch_state(EventType::RoomMember, sender.to_string()), requester.content(), - prev_event.as_deref(), - None, - fetch_state(EventType::RoomPowerLevels, "".to_owned()).as_deref(), - fetch_state(EventType::RoomJoinRules, "".to_owned()).as_deref(), + prev_event, + None::, + fetch_state(EventType::RoomPowerLevels, "".to_owned()), + fetch_state(EventType::RoomJoinRules, "".to_owned()), ) .unwrap()); } @@ -926,14 +921,14 @@ mod tests { assert!(!valid_membership_change( &target_user, - fetch_state(EventType::RoomMember, target_user.to_string()).as_deref(), + fetch_state(EventType::RoomMember, target_user.to_string()), &sender, - fetch_state(EventType::RoomMember, sender.to_string()).as_deref(), + fetch_state(EventType::RoomMember, sender.to_string()), requester.content(), - prev_event.as_deref(), - None, - fetch_state(EventType::RoomPowerLevels, "".to_owned()).as_deref(), - fetch_state(EventType::RoomJoinRules, "".to_owned()).as_deref(), + prev_event, + None::, + fetch_state(EventType::RoomPowerLevels, "".to_owned()), + fetch_state(EventType::RoomJoinRules, "".to_owned()), ) .unwrap()); } diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 081777e3..93b0ddc9 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -51,15 +51,14 @@ type EventMap = HashMap; /// /// The caller of `resolve` must ensure that all the events are from the same room. Although this /// function takes a `RoomId` it does not check that each event is part of the same room. -pub fn resolve<'a, E, F, SSI>( +pub fn resolve<'a, E, SSI>( room_version: &RoomVersionId, state_sets: impl IntoIterator, auth_chain_sets: Vec>, - fetch_event: F, + fetch_event: impl Fn(&EventId) -> Option, ) -> Result> where E: Event + Clone, - F: Fn(&EventId) -> Option, SSI: Iterator> + Clone, { info!("State resolution starting"); @@ -203,15 +202,11 @@ fn get_auth_chain_diff(auth_chain_sets: Vec>) -> impl Iterator< /// /// The power level is negative because a higher power level is equated to an earlier (further back /// in time) origin server timestamp. -fn reverse_topological_power_sort( +fn reverse_topological_power_sort( events_to_sort: Vec, auth_diff: &HashSet, - fetch_event: F, -) -> Result> -where - E: Event, - F: Fn(&EventId) -> Option, -{ + fetch_event: impl Fn(&EventId) -> Option, +) -> Result> { debug!("reverse topological sort of power events"); let mut graph = HashMap::new(); @@ -320,11 +315,10 @@ where } /// Find the power level for the sender of `event_id` or return a default value of zero. -fn get_power_level_for_sender(event_id: &EventId, fetch_event: F) -> i64 -where - E: Event, - F: Fn(&EventId) -> Option, -{ +fn get_power_level_for_sender( + event_id: &EventId, + fetch_event: impl Fn(&EventId) -> Option, +) -> i64 { info!("fetch event ({}) senders power level", event_id); let event = fetch_event(event_id); @@ -367,16 +361,12 @@ where /// /// For each `events_to_check` event we gather the events needed to auth it from the the /// `fetch_event` closure and verify each event using the `event_auth::auth_check` function. -fn iterative_auth_check( +fn iterative_auth_check( room_version: &RoomVersion, events_to_check: &[EventId], unconflicted_state: StateMap, - fetch_event: F, -) -> Result> -where - E: Event + Clone, - F: Fn(&EventId) -> Option, -{ + fetch_event: impl Fn(&EventId) -> Option, +) -> Result> { info!("starting iterative auth check"); debug!("performing auth checks on {:?}", events_to_check); @@ -468,15 +458,11 @@ where /// power_level event. If there have been two power events the after the most recent are depth 0, /// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is /// "older" than depth 0. -fn mainline_sort( +fn mainline_sort( to_sort: &[EventId], resolved_power_level: Option<&EventId>, - fetch_event: F, -) -> Result> -where - E: Event, - F: Fn(&EventId) -> Option, -{ + fetch_event: impl Fn(&EventId) -> Option, +) -> Result> { debug!("mainline sort of events"); // There are no EventId's to sort, bail. @@ -538,15 +524,11 @@ where /// Get the mainline depth from the `mainline_map` or finds a power_level event that has an /// associated mainline depth. -fn get_mainline_depth( +fn get_mainline_depth( mut event: Option, mainline_map: &EventMap, - fetch_event: F, -) -> Result -where - E: Event, - F: Fn(&EventId) -> Option, -{ + fetch_event: impl Fn(&EventId) -> Option, +) -> Result { while let Some(sort_ev) = event { debug!("mainline event_id {}", sort_ev.event_id()); let id = &sort_ev.event_id(); @@ -568,15 +550,12 @@ where Ok(0) } -fn add_event_and_auth_chain_to_graph( +fn add_event_and_auth_chain_to_graph( graph: &mut HashMap>, event_id: EventId, auth_diff: &HashSet, - fetch_event: F, -) where - E: Event, - F: Fn(&EventId) -> Option, -{ + fetch_event: impl Fn(&EventId) -> Option, +) { let mut state = vec![event_id]; while let Some(eid) = state.pop() { graph.entry(eid.clone()).or_default(); @@ -594,22 +573,18 @@ fn add_event_and_auth_chain_to_graph( } } -fn is_power_event_id(event_id: &EventId, fetch: F) -> bool -where - E: Event, - F: Fn(&EventId) -> Option, -{ +fn is_power_event_id(event_id: &EventId, fetch: impl Fn(&EventId) -> Option) -> bool { match fetch(event_id).as_ref() { Some(state) => is_power_event(state), _ => false, } } -fn is_type_and_key(ev: &E, ev_type: &EventType, state_key: &str) -> bool { +fn is_type_and_key(ev: impl Event, ev_type: &EventType, state_key: &str) -> bool { ev.event_type() == ev_type && ev.state_key() == Some(state_key) } -fn is_power_event(event: &E) -> bool { +fn is_power_event(event: impl Event) -> bool { match event.event_type() { EventType::RoomPowerLevels | EventType::RoomJoinRules | EventType::RoomCreate => { event.state_key() == Some("")