From 886c33eac33dc4c13d14f227b77256f9d59cc380 Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Wed, 18 Aug 2021 19:17:45 -0400 Subject: [PATCH] state-res: Use fetch state closure instead of auth_chain --- crates/ruma-state-res/src/event_auth.rs | 153 +++++++++++++--------- crates/ruma-state-res/src/lib.rs | 8 +- crates/ruma-state-res/tests/event_auth.rs | 4 +- 3 files changed, 98 insertions(+), 67 deletions(-) diff --git a/crates/ruma-state-res/src/event_auth.rs b/crates/ruma-state-res/src/event_auth.rs index 269f7c42..6314403a 100644 --- a/crates/ruma-state-res/src/event_auth.rs +++ b/crates/ruma-state-res/src/event_auth.rs @@ -76,19 +76,23 @@ pub fn auth_types_for_event( /// * check that the events signatures are valid /// * then there are checks for specific event types /// -/// The `auth_events` that are passed to this function should be a state snapshot. -/// We need to know if the event passes auth against some state not a recursive collection -/// of auth_events fields. +/// The `fetch_state` closure should gather state from a state snapshot. +/// We need to know if the event passes auth against some state not a recursive +/// collection of auth_events fields. /// /// ## Returns /// This returns an `Error` only when serialization fails or some other fatal outcome. -pub fn auth_check( +pub fn auth_check( room_version: &RoomVersion, incoming_event: &Arc, prev_event: Option>, - auth_events: &StateMap>, current_third_party_invite: Option>, -) -> Result { + fetch_state: F, +) -> Result +where + E: Event, + F: Fn(&EventType, &str) -> Option>, +{ info!("auth_check beginning for {} ({})", incoming_event.event_id(), incoming_event.kind()); // [synapse] check that all the events are in the same room as `incoming_event` @@ -165,7 +169,7 @@ pub fn auth_check( */ // 3. If event does not have m.room.create in auth_events reject - if auth_events.get(&(EventType::RoomCreate, "".to_owned())).is_none() { + if fetch_state(&EventType::RoomCreate, "").is_none() { warn!("no m.room.create event in auth chain"); return Ok(false); @@ -213,7 +217,7 @@ pub fn auth_check( incoming_event.content(), prev_event, current_third_party_invite, - auth_events, + fetch_state, )? { return Ok(false); } @@ -223,7 +227,7 @@ pub fn auth_check( } // If the sender's current membership state is not join, reject - match check_event_sender_in_room(incoming_event.sender(), auth_events) { + match check_event_sender_in_room(incoming_event.sender(), &fetch_state) { Some(true) => {} // sender in room Some(false) => { warn!("sender's membership is not join"); @@ -238,7 +242,7 @@ pub fn auth_check( // Allow if and only if sender's current power level is greater than // or equal to the invite level if incoming_event.kind() == EventType::RoomThirdPartyInvite - && !can_send_invite(incoming_event, auth_events)? + && !can_send_invite(incoming_event, &fetch_state)? { warn!("sender's cannot send invites in this room"); return Ok(false); @@ -246,7 +250,7 @@ pub fn auth_check( // If the event type's required power level is greater than the sender's power level, reject // If the event has a state_key that starts with an @ and does not match the sender, reject. - if !can_send_event(incoming_event, auth_events) { + if !can_send_event(incoming_event, &fetch_state) { warn!("user cannot send event"); return Ok(false); } @@ -255,7 +259,7 @@ pub fn auth_check( info!("starting m.room.power_levels check"); if let Some(required_pwr_lvl) = - check_power_levels(room_version, incoming_event, auth_events) + check_power_levels(room_version, incoming_event, &fetch_state) { if !required_pwr_lvl { warn!("power level was not allowed"); @@ -277,7 +281,7 @@ pub fn auth_check( if room_version.extra_redaction_checks && incoming_event.kind() == EventType::RoomRedaction - && !check_redaction(room_version, incoming_event, auth_events)? + && !check_redaction(room_version, incoming_event, &fetch_state)? { return Ok(false); } @@ -295,14 +299,18 @@ pub fn auth_check( /// * `auth_events` - The set of auth events that relate to a membership event. /// this is generated by calling `auth_types_for_event` with the membership event and /// the current State. -pub fn valid_membership_change( +pub fn valid_membership_change( state_key: &str, user_sender: &UserId, content: serde_json::Value, prev_event: Option>, current_third_party_invite: Option>, - auth_events: &StateMap>, -) -> Result { + fetch_state: F, +) -> Result +where + E: Event, + F: Fn(&EventType, &str) -> Option>, +{ let target_membership = serde_json::from_value::( content.get("membership").expect("we test before that this field exists").clone(), )?; @@ -314,16 +322,17 @@ pub fn valid_membership_change( let target_user_id = UserId::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{}", e)))?; - let key = (EventType::RoomMember, user_sender.to_string()); - let sender = auth_events.get(&key); + let sender = fetch_state(&EventType::RoomMember, user_sender.as_str()); + let sender = sender.as_ref(); + let sender_membership = sender.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| { Ok(serde_json::from_value::( pdu.content().get("membership").expect("we assume existing events are valid").clone(), )?) })?; - let key = (EventType::RoomMember, target_user_id.to_string()); - let current = auth_events.get(&key); + let current = fetch_state(&EventType::RoomMember, target_user_id.as_str()); + let current = current.as_ref(); let current_membership = current.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| { Ok(serde_json::from_value::( @@ -331,8 +340,9 @@ pub fn valid_membership_change( )?) })?; - let key = (EventType::RoomPowerLevels, "".into()); - let power_levels_event = auth_events.get(&key); + let power_levels_event = fetch_state(&EventType::RoomPowerLevels, ""); + let power_levels_event = power_levels_event.as_ref(); + let power_levels = power_levels_event.map_or_else( || Ok::<_, Error>(PowerLevelsEventContent::default()), |power_levels| { @@ -352,8 +362,8 @@ pub fn valid_membership_change( Some, ); - let key = (EventType::RoomJoinRules, "".into()); - let join_rules_event = auth_events.get(&key); + let join_rules_event = fetch_state(&EventType::RoomJoinRules, ""); + let join_rules_event = join_rules_event.as_ref(); let mut join_rules = JoinRule::Invite; if let Some(jr) = join_rules_event { join_rules = serde_json::from_value::(jr.content())?.join_rule; @@ -485,11 +495,12 @@ pub fn valid_membership_change( } /// Is the event's sender in the room that they sent the event to. -pub fn check_event_sender_in_room( - sender: &UserId, - auth_events: &StateMap>, -) -> Option { - let mem = auth_events.get(&(EventType::RoomMember, sender.to_string()))?; +pub fn check_event_sender_in_room(sender: &UserId, fetch_state: F) -> Option +where + E: Event, + F: Fn(&EventType, &str) -> Option>, +{ + let mem = fetch_state(&EventType::RoomMember, sender.as_str())?; let membership = serde_json::from_value::( mem.content() @@ -504,11 +515,15 @@ pub fn check_event_sender_in_room( /// Is the user allowed to send a specific event based on the rooms power levels. Does the event /// have the correct userId as it's state_key if it's not the "" state_key. -pub fn can_send_event(event: &Arc, auth_events: &StateMap>) -> bool { - let ple = auth_events.get(&(EventType::RoomPowerLevels, "".into())); +pub fn can_send_event(event: &Arc, fetch_state: F) -> bool +where + E: Event, + F: Fn(&EventType, &str) -> Option>, +{ + let ple = fetch_state(&EventType::RoomPowerLevels, ""); - let event_type_power_level = get_send_level(&event.kind(), event.state_key(), ple); - let user_level = get_user_power_level(event.sender(), auth_events); + let event_type_power_level = get_send_level(&event.kind(), event.state_key(), ple.as_ref()); + let user_level = get_user_power_level(event.sender(), fetch_state); debug!("{} ev_type {} usr {}", event.event_id(), event_type_power_level, user_level); @@ -526,19 +541,23 @@ pub fn can_send_event(event: &Arc, auth_events: &StateMap>) } /// Confirm that the event sender has the required power levels. -pub fn check_power_levels( +pub fn check_power_levels( room_version: &RoomVersion, power_event: &Arc, - auth_events: &StateMap>, -) -> Option { + fetch_state: F, +) -> Option +where + E: Event, + F: Fn(&EventType, &str) -> Option>, +{ let power_event_state_key = power_event.state_key().expect("power events have state keys"); - let key = (power_event.kind(), power_event_state_key); - let current_state = if let Some(current_state) = auth_events.get(&key) { - current_state - } else { - // If there is no previous m.room.power_levels event in the room, allow - return Some(true); - }; + let current_state = + if let Some(current_state) = fetch_state(&power_event.kind(), &power_event_state_key) { + current_state + } else { + // If there is no previous m.room.power_levels event in the room, allow + return Some(true); + }; // If users key in content is not a dictionary with keys that are valid user IDs // with values that are integers (or a string that is an integer), reject. @@ -551,7 +570,7 @@ pub fn check_power_levels( // Validation of users is done in Ruma, synapse for loops validating user_ids and integers here info!("validation of power event finished"); - let user_level = get_user_power_level(power_event.sender(), auth_events); + let user_level = get_user_power_level(power_event.sender(), fetch_state); let mut user_levels_to_check = btreeset![]; let old_list = ¤t_content.users; @@ -668,13 +687,17 @@ fn get_deserialize_levels( } /// Does the event redacting come from a user with enough power to redact the given event. -pub fn check_redaction( +pub fn check_redaction( _room_version: &RoomVersion, redaction_event: &Arc, - auth_events: &StateMap>, -) -> Result { - let user_level = get_user_power_level(redaction_event.sender(), auth_events); - let redact_level = get_named_level(auth_events, "redact", 50); + fetch_state: F, +) -> Result +where + E: Event, + F: Fn(&EventType, &str) -> Option>, +{ + let user_level = get_user_power_level(redaction_event.sender(), &fetch_state); + let redact_level = get_named_level(fetch_state, "redact", 50); if user_level >= redact_level { info!("redaction allowed via power levels"); @@ -728,8 +751,12 @@ pub fn can_federate(auth_events: &StateMap>) -> bool { /// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content. /// or return `default` if no power level event is found or zero if no field matches `name`. -pub fn get_named_level(auth_events: &StateMap>, name: &str, default: i64) -> i64 { - let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, "".into())); +pub fn get_named_level(fetch_state: F, name: &str, default: i64) -> i64 +where + E: Event, + F: Fn(&EventType, &str) -> Option>, +{ + let power_level_event = fetch_state(&EventType::RoomPowerLevels, ""); if let Some(pl) = power_level_event { // TODO do this the right way and deserialize if let Some(level) = pl.content().get(name) { @@ -744,8 +771,12 @@ pub fn get_named_level(auth_events: &StateMap>, name: &str, def /// Helper function to fetch a users default power level from a "m.room.power_level" event's `users` /// object. -pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap>) -> i64 { - if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, "".into())) { +pub fn get_user_power_level(user_id: &UserId, fetch_state: F) -> i64 +where + E: Event, + F: Fn(&EventType, &str) -> Option>, +{ + if let Some(pl) = fetch_state(&EventType::RoomPowerLevels, "") { if let Ok(content) = serde_json::from_value::(pl.content()) { if let Some(level) = content.users.get(user_id) { (*level).into() @@ -757,9 +788,7 @@ pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap(create.content()).ok()) .and_then(|create| (create.creator == *user_id).then(|| 100)) .unwrap_or_default() @@ -792,11 +821,13 @@ pub fn get_send_level( } /// Check user can send invite. -pub fn can_send_invite(event: &Arc, auth_events: &StateMap>) -> Result { - let user_level = get_user_power_level(event.sender(), auth_events); - let key = (EventType::RoomPowerLevels, "".into()); - let invite_level = auth_events - .get(&key) +pub fn can_send_invite(event: &Arc, fetch_state: F) -> Result +where + E: Event, + F: Fn(&EventType, &str) -> Option>, +{ + let user_level = get_user_power_level(event.sender(), &fetch_state); + let invite_level = fetch_state(&EventType::RoomPowerLevels, "") .map_or_else( || Ok::<_, Error>(int!(50)), |power_levels| { diff --git a/crates/ruma-state-res/src/lib.rs b/crates/ruma-state-res/src/lib.rs index 193575f5..7cc03ead 100644 --- a/crates/ruma-state-res/src/lib.rs +++ b/crates/ruma-state-res/src/lib.rs @@ -396,7 +396,7 @@ impl StateResolution { /// fails the `event_auth::auth_check` will be excluded from the returned `StateMap`. /// /// For each `events_to_check` event we gather the events needed to auth it from the - /// `event_map` or `store` and verify each event using the `event_auth::auth_check` + /// the `fetch_event` closure and verify each event using the `event_auth::auth_check` /// function. pub fn iterative_auth_check( room_version: &RoomVersion, @@ -424,8 +424,8 @@ impl StateResolution { let mut auth_events = HashMap::new(); for aid in &event.auth_events() { if let Some(ev) = fetch_event(aid) { - // TODO synapse check "rejected_reason", I'm guessing this is redacted_because - // in ruma ?? + // TODO synapse check "rejected_reason" which is most likely + // related to soft-failing auth_events.insert( ( ev.kind(), @@ -472,8 +472,8 @@ impl StateResolution { room_version, &event, most_recent_prev_event, - &auth_events, current_third_party, + |ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(), )? { // add event to resolved state map resolved_state.insert((event.kind(), state_key), event_id.clone()); diff --git a/crates/ruma-state-res/tests/event_auth.rs b/crates/ruma-state-res/tests/event_auth.rs index 84186e5a..e613f230 100644 --- a/crates/ruma-state-res/tests/event_auth.rs +++ b/crates/ruma-state-res/tests/event_auth.rs @@ -33,7 +33,7 @@ fn test_ban_pass() { requester.content(), prev, None, - &auth_events + |ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(), ) .unwrap()) } @@ -65,7 +65,7 @@ fn test_ban_fail() { requester.content(), prev, None, - &auth_events + |ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(), ) .unwrap()) }