state-res: Make usage of Arc optional

This commit is contained in:
Jonas Platte 2021-09-12 01:20:55 +02:00
parent a4b8f3bc90
commit 118aa8fc4a
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
2 changed files with 28 additions and 29 deletions

View File

@ -1,4 +1,4 @@
use std::{collections::BTreeSet, convert::TryFrom, sync::Arc}; use std::{collections::BTreeSet, convert::TryFrom};
use js_int::{int, Int}; use js_int::{int, Int};
use ruma_events::{ use ruma_events::{
@ -93,7 +93,7 @@ pub fn auth_check<E, F>(
) -> Result<bool> ) -> Result<bool>
where where
E: Event, E: Event,
F: Fn(&EventType, &str) -> Option<Arc<E>>, F: Fn(&EventType, &str) -> Option<E>,
{ {
info!( info!(
"auth_check beginning for {} ({})", "auth_check beginning for {} ({})",
@ -231,14 +231,14 @@ where
if !valid_membership_change( if !valid_membership_change(
&target_user, &target_user,
fetch_state(&EventType::RoomMember, target_user.as_str()).as_deref(), fetch_state(&EventType::RoomMember, target_user.as_str()).as_ref(),
sender, sender,
sender_member_event.as_deref(), sender_member_event.as_ref(),
incoming_event.content(), incoming_event.content(),
prev_event, prev_event,
current_third_party_invite, current_third_party_invite,
power_levels_event.as_deref(), power_levels_event.as_ref(),
fetch_state(&EventType::RoomJoinRules, "").as_deref(), fetch_state(&EventType::RoomJoinRules, "").as_ref(),
)? { )? {
return Ok(false); return Ok(false);
} }
@ -304,7 +304,7 @@ where
// If the event type's required power level is greater than the sender's power level, reject // If the event type's required power level is greater than the sender's power level, reject
// If the event has a state_key that starts with an @ and does not match the sender, reject. // If the event has a state_key that starts with an @ and does not match the sender, reject.
if !can_send_event(incoming_event, power_levels_event.as_deref(), sender_power_level) { if !can_send_event(incoming_event, power_levels_event.as_ref(), sender_power_level) {
warn!("user cannot send event"); warn!("user cannot send event");
return Ok(false); return Ok(false);
} }
@ -315,7 +315,7 @@ where
if let Some(required_pwr_lvl) = check_power_levels( if let Some(required_pwr_lvl) = check_power_levels(
room_version, room_version,
incoming_event, incoming_event,
power_levels_event.as_deref(), power_levels_event.as_ref(),
sender_power_level, sender_power_level,
) { ) {
if !required_pwr_lvl { if !required_pwr_lvl {

View File

@ -1,7 +1,6 @@
use std::{ use std::{
cmp::Reverse, cmp::Reverse,
collections::{BinaryHeap, HashMap, HashSet}, collections::{BinaryHeap, HashMap, HashSet},
sync::Arc,
}; };
use itertools::Itertools; use itertools::Itertools;
@ -59,8 +58,8 @@ pub fn resolve<E, F>(
fetch_event: F, fetch_event: F,
) -> Result<StateMap<EventId>> ) -> Result<StateMap<EventId>>
where where
E: Event, E: Event + Clone,
F: Fn(&EventId) -> Option<Arc<E>>, F: Fn(&EventId) -> Option<E>,
{ {
info!("State resolution starting"); info!("State resolution starting");
@ -210,7 +209,7 @@ fn reverse_topological_power_sort<E, F>(
) -> Result<Vec<EventId>> ) -> Result<Vec<EventId>>
where where
E: Event, E: Event,
F: Fn(&EventId) -> Option<Arc<E>>, F: Fn(&EventId) -> Option<E>,
{ {
debug!("reverse topological sort of power events"); debug!("reverse topological sort of power events");
@ -323,7 +322,7 @@ where
fn get_power_level_for_sender<E, F>(event_id: &EventId, fetch_event: F) -> i64 fn get_power_level_for_sender<E, F>(event_id: &EventId, fetch_event: F) -> i64
where where
E: Event, E: Event,
F: Fn(&EventId) -> Option<Arc<E>>, F: Fn(&EventId) -> Option<E>,
{ {
info!("fetch event ({}) senders power level", event_id); info!("fetch event ({}) senders power level", event_id);
@ -332,7 +331,7 @@ where
for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() { for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() {
if let Some(aev) = fetch_event(aid) { if let Some(aev) = fetch_event(aid) {
if is_type_and_key(&*aev, &EventType::RoomPowerLevels, "") { if is_type_and_key(&aev, &EventType::RoomPowerLevels, "") {
pl = Some(aev); pl = Some(aev);
break; break;
} }
@ -374,8 +373,8 @@ fn iterative_auth_check<E, F>(
fetch_event: F, fetch_event: F,
) -> Result<StateMap<EventId>> ) -> Result<StateMap<EventId>>
where where
E: Event, E: Event + Clone,
F: Fn(&EventId) -> Option<Arc<E>>, F: Fn(&EventId) -> Option<E>,
{ {
info!("starting iterative auth check"); info!("starting iterative auth check");
@ -441,9 +440,9 @@ where
if auth_check( if auth_check(
room_version, room_version,
&*event, &event,
most_recent_prev_event.as_deref(), most_recent_prev_event.as_ref(),
current_third_party.as_deref(), current_third_party.as_ref(),
|ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(), |ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(),
)? { )? {
// add event to resolved state map // add event to resolved state map
@ -475,7 +474,7 @@ fn mainline_sort<E, F>(
) -> Result<Vec<EventId>> ) -> Result<Vec<EventId>>
where where
E: Event, E: Event,
F: Fn(&EventId) -> Option<Arc<E>>, F: Fn(&EventId) -> Option<E>,
{ {
debug!("mainline sort of events"); debug!("mainline sort of events");
@ -495,7 +494,7 @@ where
for aid in event.auth_events() { for aid in event.auth_events() {
let ev = fetch_event(aid) let ev = fetch_event(aid)
.ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?; .ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?;
if is_type_and_key(&*ev, &EventType::RoomPowerLevels, "") { if is_type_and_key(&ev, &EventType::RoomPowerLevels, "") {
pl = Some(aid.clone()); pl = Some(aid.clone());
break; break;
} }
@ -539,13 +538,13 @@ where
/// Get the mainline depth from the `mainline_map` or finds a power_level event that has an /// Get the mainline depth from the `mainline_map` or finds a power_level event that has an
/// associated mainline depth. /// associated mainline depth.
fn get_mainline_depth<E, F>( fn get_mainline_depth<E, F>(
mut event: Option<Arc<E>>, mut event: Option<E>,
mainline_map: &EventMap<usize>, mainline_map: &EventMap<usize>,
fetch_event: F, fetch_event: F,
) -> Result<usize> ) -> Result<usize>
where where
E: Event, E: Event,
F: Fn(&EventId) -> Option<Arc<E>>, F: Fn(&EventId) -> Option<E>,
{ {
while let Some(sort_ev) = event { while let Some(sort_ev) = event {
debug!("mainline event_id {}", sort_ev.event_id()); debug!("mainline event_id {}", sort_ev.event_id());
@ -558,7 +557,7 @@ where
for aid in sort_ev.auth_events() { for aid in sort_ev.auth_events() {
let aev = fetch_event(aid) let aev = fetch_event(aid)
.ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?; .ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?;
if is_type_and_key(&*aev, &EventType::RoomPowerLevels, "") { if is_type_and_key(&aev, &EventType::RoomPowerLevels, "") {
event = Some(aev); event = Some(aev);
break; break;
} }
@ -575,7 +574,7 @@ fn add_event_and_auth_chain_to_graph<E, F>(
fetch_event: F, fetch_event: F,
) where ) where
E: Event, E: Event,
F: Fn(&EventId) -> Option<Arc<E>>, F: Fn(&EventId) -> Option<E>,
{ {
let mut state = vec![event_id]; let mut state = vec![event_id];
while let Some(eid) = state.pop() { while let Some(eid) = state.pop() {
@ -597,9 +596,9 @@ fn add_event_and_auth_chain_to_graph<E, F>(
fn is_power_event_id<E, F>(event_id: &EventId, fetch: F) -> bool fn is_power_event_id<E, F>(event_id: &EventId, fetch: F) -> bool
where where
E: Event, E: Event,
F: Fn(&EventId) -> Option<Arc<E>>, F: Fn(&EventId) -> Option<E>,
{ {
match fetch(event_id).as_deref() { match fetch(event_id).as_ref() {
Some(state) => is_power_event(state), Some(state) => is_power_event(state),
_ => false, _ => false,
} }
@ -977,7 +976,7 @@ mod tests {
let ev_map: EventMap<Arc<StateEvent>> = store.0.clone(); let ev_map: EventMap<Arc<StateEvent>> = store.0.clone();
let state_sets = vec![state_at_bob, state_at_charlie]; let state_sets = vec![state_at_bob, state_at_charlie];
let resolved = match crate::resolve::<StateEvent, _>( let resolved = match crate::resolve(
&RoomVersionId::Version2, &RoomVersionId::Version2,
&state_sets, &state_sets,
state_sets state_sets
@ -1083,7 +1082,7 @@ mod tests {
let ev_map: EventMap<Arc<StateEvent>> = store.0.clone(); let ev_map: EventMap<Arc<StateEvent>> = store.0.clone();
let state_sets = vec![state_set_a, state_set_b]; let state_sets = vec![state_set_a, state_set_b];
let resolved = match crate::resolve::<StateEvent, _>( let resolved = match crate::resolve(
&RoomVersionId::Version6, &RoomVersionId::Version6,
&state_sets, &state_sets,
state_sets state_sets