state-res: Make functions more general

Don't require different parameters to use the same impl Event.
This commit is contained in:
Jonas Platte 2021-09-13 18:34:08 +02:00
parent 0999e420ae
commit a6a1224652
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
2 changed files with 65 additions and 95 deletions

View File

@ -84,17 +84,13 @@ pub fn auth_types_for_event(
/// ## Returns
///
/// This returns an `Error` only when serialization fails or some other fatal outcome.
pub fn auth_check<E, F>(
pub fn auth_check<E: Event>(
room_version: &RoomVersion,
incoming_event: &E,
prev_event: Option<&E>,
current_third_party_invite: Option<&E>,
fetch_state: F,
) -> Result<bool>
where
E: Event,
F: Fn(&EventType, &str) -> Option<E>,
{
incoming_event: impl Event,
prev_event: Option<impl Event>,
current_third_party_invite: Option<impl Event>,
fetch_state: impl Fn(&EventType, &str) -> Option<E>,
) -> Result<bool> {
info!(
"auth_check beginning for {} ({})",
incoming_event.event_id(),
@ -311,7 +307,7 @@ where
// If the event type's required power level is greater than the sender's power level, reject
// If the event has a state_key that starts with an @ and does not match the sender, reject.
if !can_send_event(incoming_event, power_levels_event.as_ref(), sender_power_level) {
if !can_send_event(&incoming_event, power_levels_event.as_ref(), sender_power_level) {
warn!("user cannot send event");
return Ok(false);
}
@ -321,7 +317,7 @@ where
if let Some(required_pwr_lvl) = check_power_levels(
room_version,
incoming_event,
&incoming_event,
power_levels_event.as_ref(),
sender_power_level,
) {
@ -378,16 +374,16 @@ where
/// This is generated by calling `auth_types_for_event` with the membership event and the current
/// State.
#[allow(clippy::too_many_arguments)]
fn valid_membership_change<E: Event>(
fn valid_membership_change(
target_user: &UserId,
target_user_membership_event: Option<&E>,
target_user_membership_event: Option<impl Event>,
sender: &UserId,
sender_membership_event: Option<&E>,
sender_membership_event: Option<impl Event>,
content: &serde_json::Value,
prev_event: Option<&E>,
current_third_party_invite: Option<&E>,
power_levels_event: Option<&E>,
join_rules_event: Option<&E>,
prev_event: Option<impl Event>,
current_third_party_invite: Option<impl Event>,
power_levels_event: Option<impl Event>,
join_rules_event: Option<impl Event>,
) -> Result<bool> {
let target_membership = serde_json::from_value::<MembershipState>(
content.get("membership").expect("we test before that this field exists").clone(),
@ -572,7 +568,7 @@ fn valid_membership_change<E: Event>(
/// Is the user allowed to send a specific event based on the rooms power levels.
///
/// Does the event have the correct userId as its state_key if it's not the "" state_key.
fn can_send_event<E: Event>(event: &E, ple: Option<&E>, user_level: Int) -> bool {
fn can_send_event(event: impl Event, ple: Option<impl Event>, user_level: Int) -> bool {
let event_type_power_level = get_send_level(event.event_type(), event.state_key(), ple);
debug!("{} ev_type {} usr {}", event.event_id(), event_type_power_level, user_level);
@ -591,15 +587,12 @@ fn can_send_event<E: Event>(event: &E, ple: Option<&E>, user_level: Int) -> bool
}
/// Confirm that the event sender has the required power levels.
fn check_power_levels<E>(
fn check_power_levels(
room_version: &RoomVersion,
power_event: &E,
previous_power_event: Option<&E>,
power_event: impl Event,
previous_power_event: Option<impl Event>,
user_level: Int,
) -> Option<bool>
where
E: Event,
{
) -> Option<bool> {
match power_event.state_key() {
Some("") => {}
Some(key) => {
@ -746,9 +739,9 @@ fn get_deserialize_levels(
}
/// Does the event redacting come from a user with enough power to redact the given event.
fn check_redaction<E: Event>(
fn check_redaction(
_room_version: &RoomVersion,
redaction_event: &E,
redaction_event: impl Event,
user_level: Int,
redact_level: Int,
) -> Result<bool> {
@ -771,10 +764,10 @@ fn check_redaction<E: Event>(
/// Helper function to fetch the power level needed to send an event of type
/// `e_type` based on the rooms "m.room.power_level" event.
fn get_send_level<E: Event>(
fn get_send_level(
e_type: &EventType,
state_key: Option<&str>,
power_lvl: Option<&E>,
power_lvl: Option<impl Event>,
) -> Int {
power_lvl
.and_then(|ple| {
@ -793,11 +786,11 @@ fn get_send_level<E: Event>(
.unwrap_or_else(|| if state_key.is_some() { int!(50) } else { int!(0) })
}
fn verify_third_party_invite<E: Event>(
fn verify_third_party_invite(
target_user: Option<&UserId>,
sender: &UserId,
tp_id: &ThirdPartyInvite,
current_third_party_invite: Option<&E>,
current_third_party_invite: Option<impl Event>,
) -> bool {
// 1. Check for user being banned happens before this is called
// checking for mxid and token keys is done by ruma when deserializing
@ -845,7 +838,9 @@ mod tests {
use crate::{
event_auth::valid_membership_change,
test_utils::{alice, charlie, event_id, member_content_ban, to_pdu_event, INITIAL_EVENTS},
test_utils::{
alice, charlie, event_id, member_content_ban, to_pdu_event, StateEvent, INITIAL_EVENTS,
},
Event, StateMap,
};
use ruma_events::EventType;
@ -882,14 +877,14 @@ mod tests {
assert!(valid_membership_change(
&target_user,
fetch_state(EventType::RoomMember, target_user.to_string()).as_deref(),
fetch_state(EventType::RoomMember, target_user.to_string()),
&sender,
fetch_state(EventType::RoomMember, sender.to_string()).as_deref(),
fetch_state(EventType::RoomMember, sender.to_string()),
requester.content(),
prev_event.as_deref(),
None,
fetch_state(EventType::RoomPowerLevels, "".to_owned()).as_deref(),
fetch_state(EventType::RoomJoinRules, "".to_owned()).as_deref(),
prev_event,
None::<StateEvent>,
fetch_state(EventType::RoomPowerLevels, "".to_owned()),
fetch_state(EventType::RoomJoinRules, "".to_owned()),
)
.unwrap());
}
@ -926,14 +921,14 @@ mod tests {
assert!(!valid_membership_change(
&target_user,
fetch_state(EventType::RoomMember, target_user.to_string()).as_deref(),
fetch_state(EventType::RoomMember, target_user.to_string()),
&sender,
fetch_state(EventType::RoomMember, sender.to_string()).as_deref(),
fetch_state(EventType::RoomMember, sender.to_string()),
requester.content(),
prev_event.as_deref(),
None,
fetch_state(EventType::RoomPowerLevels, "".to_owned()).as_deref(),
fetch_state(EventType::RoomJoinRules, "".to_owned()).as_deref(),
prev_event,
None::<StateEvent>,
fetch_state(EventType::RoomPowerLevels, "".to_owned()),
fetch_state(EventType::RoomJoinRules, "".to_owned()),
)
.unwrap());
}

View File

@ -51,15 +51,14 @@ type EventMap<T> = HashMap<EventId, T>;
///
/// The caller of `resolve` must ensure that all the events are from the same room. Although this
/// function takes a `RoomId` it does not check that each event is part of the same room.
pub fn resolve<'a, E, F, SSI>(
pub fn resolve<'a, E, SSI>(
room_version: &RoomVersionId,
state_sets: impl IntoIterator<IntoIter = SSI>,
auth_chain_sets: Vec<HashSet<EventId>>,
fetch_event: F,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<EventId>>
where
E: Event + Clone,
F: Fn(&EventId) -> Option<E>,
SSI: Iterator<Item = &'a StateMap<EventId>> + Clone,
{
info!("State resolution starting");
@ -203,15 +202,11 @@ fn get_auth_chain_diff(auth_chain_sets: Vec<HashSet<EventId>>) -> impl Iterator<
///
/// The power level is negative because a higher power level is equated to an earlier (further back
/// in time) origin server timestamp.
fn reverse_topological_power_sort<E, F>(
fn reverse_topological_power_sort<E: Event>(
events_to_sort: Vec<EventId>,
auth_diff: &HashSet<EventId>,
fetch_event: F,
) -> Result<Vec<EventId>>
where
E: Event,
F: Fn(&EventId) -> Option<E>,
{
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<EventId>> {
debug!("reverse topological sort of power events");
let mut graph = HashMap::new();
@ -320,11 +315,10 @@ where
}
/// Find the power level for the sender of `event_id` or return a default value of zero.
fn get_power_level_for_sender<E, F>(event_id: &EventId, fetch_event: F) -> i64
where
E: Event,
F: Fn(&EventId) -> Option<E>,
{
fn get_power_level_for_sender<E: Event>(
event_id: &EventId,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> i64 {
info!("fetch event ({}) senders power level", event_id);
let event = fetch_event(event_id);
@ -367,16 +361,12 @@ where
///
/// For each `events_to_check` event we gather the events needed to auth it from the the
/// `fetch_event` closure and verify each event using the `event_auth::auth_check` function.
fn iterative_auth_check<E, F>(
fn iterative_auth_check<E: Event + Clone>(
room_version: &RoomVersion,
events_to_check: &[EventId],
unconflicted_state: StateMap<EventId>,
fetch_event: F,
) -> Result<StateMap<EventId>>
where
E: Event + Clone,
F: Fn(&EventId) -> Option<E>,
{
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<EventId>> {
info!("starting iterative auth check");
debug!("performing auth checks on {:?}", events_to_check);
@ -468,15 +458,11 @@ where
/// power_level event. If there have been two power events the after the most recent are depth 0,
/// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is
/// "older" than depth 0.
fn mainline_sort<E, F>(
fn mainline_sort<E: Event>(
to_sort: &[EventId],
resolved_power_level: Option<&EventId>,
fetch_event: F,
) -> Result<Vec<EventId>>
where
E: Event,
F: Fn(&EventId) -> Option<E>,
{
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<EventId>> {
debug!("mainline sort of events");
// There are no EventId's to sort, bail.
@ -538,15 +524,11 @@ where
/// Get the mainline depth from the `mainline_map` or finds a power_level event that has an
/// associated mainline depth.
fn get_mainline_depth<E, F>(
fn get_mainline_depth<E: Event>(
mut event: Option<E>,
mainline_map: &EventMap<usize>,
fetch_event: F,
) -> Result<usize>
where
E: Event,
F: Fn(&EventId) -> Option<E>,
{
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<usize> {
while let Some(sort_ev) = event {
debug!("mainline event_id {}", sort_ev.event_id());
let id = &sort_ev.event_id();
@ -568,15 +550,12 @@ where
Ok(0)
}
fn add_event_and_auth_chain_to_graph<E, F>(
fn add_event_and_auth_chain_to_graph<E: Event>(
graph: &mut HashMap<EventId, HashSet<EventId>>,
event_id: EventId,
auth_diff: &HashSet<EventId>,
fetch_event: F,
) where
E: Event,
F: Fn(&EventId) -> Option<E>,
{
fetch_event: impl Fn(&EventId) -> Option<E>,
) {
let mut state = vec![event_id];
while let Some(eid) = state.pop() {
graph.entry(eid.clone()).or_default();
@ -594,22 +573,18 @@ fn add_event_and_auth_chain_to_graph<E, F>(
}
}
fn is_power_event_id<E, F>(event_id: &EventId, fetch: F) -> bool
where
E: Event,
F: Fn(&EventId) -> Option<E>,
{
fn is_power_event_id<E: Event>(event_id: &EventId, fetch: impl Fn(&EventId) -> Option<E>) -> bool {
match fetch(event_id).as_ref() {
Some(state) => is_power_event(state),
_ => false,
}
}
fn is_type_and_key<E: Event>(ev: &E, ev_type: &EventType, state_key: &str) -> bool {
fn is_type_and_key(ev: impl Event, ev_type: &EventType, state_key: &str) -> bool {
ev.event_type() == ev_type && ev.state_key() == Some(state_key)
}
fn is_power_event<E: Event>(event: &E) -> bool {
fn is_power_event(event: impl Event) -> bool {
match event.event_type() {
EventType::RoomPowerLevels | EventType::RoomJoinRules | EventType::RoomCreate => {
event.state_key() == Some("")