state-res: Make usage of Arc optional

This commit is contained in:
Jonas Platte 2021-09-12 01:20:55 +02:00
parent a4b8f3bc90
commit 118aa8fc4a
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
2 changed files with 28 additions and 29 deletions

View File

@ -1,4 +1,4 @@
use std::{collections::BTreeSet, convert::TryFrom, sync::Arc};
use std::{collections::BTreeSet, convert::TryFrom};
use js_int::{int, Int};
use ruma_events::{
@ -93,7 +93,7 @@ pub fn auth_check<E, F>(
) -> Result<bool>
where
E: Event,
F: Fn(&EventType, &str) -> Option<Arc<E>>,
F: Fn(&EventType, &str) -> Option<E>,
{
info!(
"auth_check beginning for {} ({})",
@ -231,14 +231,14 @@ where
if !valid_membership_change(
&target_user,
fetch_state(&EventType::RoomMember, target_user.as_str()).as_deref(),
fetch_state(&EventType::RoomMember, target_user.as_str()).as_ref(),
sender,
sender_member_event.as_deref(),
sender_member_event.as_ref(),
incoming_event.content(),
prev_event,
current_third_party_invite,
power_levels_event.as_deref(),
fetch_state(&EventType::RoomJoinRules, "").as_deref(),
power_levels_event.as_ref(),
fetch_state(&EventType::RoomJoinRules, "").as_ref(),
)? {
return Ok(false);
}
@ -304,7 +304,7 @@ where
// If the event type's required power level is greater than the sender's power level, reject
// If the event has a state_key that starts with an @ and does not match the sender, reject.
if !can_send_event(incoming_event, power_levels_event.as_deref(), sender_power_level) {
if !can_send_event(incoming_event, power_levels_event.as_ref(), sender_power_level) {
warn!("user cannot send event");
return Ok(false);
}
@ -315,7 +315,7 @@ where
if let Some(required_pwr_lvl) = check_power_levels(
room_version,
incoming_event,
power_levels_event.as_deref(),
power_levels_event.as_ref(),
sender_power_level,
) {
if !required_pwr_lvl {

View File

@ -1,7 +1,6 @@
use std::{
cmp::Reverse,
collections::{BinaryHeap, HashMap, HashSet},
sync::Arc,
};
use itertools::Itertools;
@ -59,8 +58,8 @@ pub fn resolve<E, F>(
fetch_event: F,
) -> Result<StateMap<EventId>>
where
E: Event,
F: Fn(&EventId) -> Option<Arc<E>>,
E: Event + Clone,
F: Fn(&EventId) -> Option<E>,
{
info!("State resolution starting");
@ -210,7 +209,7 @@ fn reverse_topological_power_sort<E, F>(
) -> Result<Vec<EventId>>
where
E: Event,
F: Fn(&EventId) -> Option<Arc<E>>,
F: Fn(&EventId) -> Option<E>,
{
debug!("reverse topological sort of power events");
@ -323,7 +322,7 @@ where
fn get_power_level_for_sender<E, F>(event_id: &EventId, fetch_event: F) -> i64
where
E: Event,
F: Fn(&EventId) -> Option<Arc<E>>,
F: Fn(&EventId) -> Option<E>,
{
info!("fetch event ({}) senders power level", event_id);
@ -332,7 +331,7 @@ where
for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() {
if let Some(aev) = fetch_event(aid) {
if is_type_and_key(&*aev, &EventType::RoomPowerLevels, "") {
if is_type_and_key(&aev, &EventType::RoomPowerLevels, "") {
pl = Some(aev);
break;
}
@ -374,8 +373,8 @@ fn iterative_auth_check<E, F>(
fetch_event: F,
) -> Result<StateMap<EventId>>
where
E: Event,
F: Fn(&EventId) -> Option<Arc<E>>,
E: Event + Clone,
F: Fn(&EventId) -> Option<E>,
{
info!("starting iterative auth check");
@ -441,9 +440,9 @@ where
if auth_check(
room_version,
&*event,
most_recent_prev_event.as_deref(),
current_third_party.as_deref(),
&event,
most_recent_prev_event.as_ref(),
current_third_party.as_ref(),
|ty, key| auth_events.get(&(ty.clone(), key.to_owned())).cloned(),
)? {
// add event to resolved state map
@ -475,7 +474,7 @@ fn mainline_sort<E, F>(
) -> Result<Vec<EventId>>
where
E: Event,
F: Fn(&EventId) -> Option<Arc<E>>,
F: Fn(&EventId) -> Option<E>,
{
debug!("mainline sort of events");
@ -495,7 +494,7 @@ where
for aid in event.auth_events() {
let ev = fetch_event(aid)
.ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?;
if is_type_and_key(&*ev, &EventType::RoomPowerLevels, "") {
if is_type_and_key(&ev, &EventType::RoomPowerLevels, "") {
pl = Some(aid.clone());
break;
}
@ -539,13 +538,13 @@ where
/// Get the mainline depth from the `mainline_map` or finds a power_level event that has an
/// associated mainline depth.
fn get_mainline_depth<E, F>(
mut event: Option<Arc<E>>,
mut event: Option<E>,
mainline_map: &EventMap<usize>,
fetch_event: F,
) -> Result<usize>
where
E: Event,
F: Fn(&EventId) -> Option<Arc<E>>,
F: Fn(&EventId) -> Option<E>,
{
while let Some(sort_ev) = event {
debug!("mainline event_id {}", sort_ev.event_id());
@ -558,7 +557,7 @@ where
for aid in sort_ev.auth_events() {
let aev = fetch_event(aid)
.ok_or_else(|| Error::NotFound(format!("Failed to find {}", aid)))?;
if is_type_and_key(&*aev, &EventType::RoomPowerLevels, "") {
if is_type_and_key(&aev, &EventType::RoomPowerLevels, "") {
event = Some(aev);
break;
}
@ -575,7 +574,7 @@ fn add_event_and_auth_chain_to_graph<E, F>(
fetch_event: F,
) where
E: Event,
F: Fn(&EventId) -> Option<Arc<E>>,
F: Fn(&EventId) -> Option<E>,
{
let mut state = vec![event_id];
while let Some(eid) = state.pop() {
@ -597,9 +596,9 @@ fn add_event_and_auth_chain_to_graph<E, F>(
fn is_power_event_id<E, F>(event_id: &EventId, fetch: F) -> bool
where
E: Event,
F: Fn(&EventId) -> Option<Arc<E>>,
F: Fn(&EventId) -> Option<E>,
{
match fetch(event_id).as_deref() {
match fetch(event_id).as_ref() {
Some(state) => is_power_event(state),
_ => false,
}
@ -977,7 +976,7 @@ mod tests {
let ev_map: EventMap<Arc<StateEvent>> = store.0.clone();
let state_sets = vec![state_at_bob, state_at_charlie];
let resolved = match crate::resolve::<StateEvent, _>(
let resolved = match crate::resolve(
&RoomVersionId::Version2,
&state_sets,
state_sets
@ -1083,7 +1082,7 @@ mod tests {
let ev_map: EventMap<Arc<StateEvent>> = store.0.clone();
let state_sets = vec![state_set_a, state_set_b];
let resolved = match crate::resolve::<StateEvent, _>(
let resolved = match crate::resolve(
&RoomVersionId::Version6,
&state_sets,
state_sets