Remove StateStore trait and clean up imports in event_auth

This commit is contained in:
Devin Ragotzy 2021-04-29 12:07:35 -04:00 committed by Devin Ragotzy
parent 138ecd4f35
commit f62df4d9ae
6 changed files with 184 additions and 65 deletions

View File

@ -5,7 +5,7 @@
// To pass args to criterion, use this form // To pass args to criterion, use this form
// `cargo bench --bench <name of the bench> -- --save-baseline <name>`. // `cargo bench --bench <name of the bench> -- --save-baseline <name>`.
use std::{ use std::{
collections::BTreeMap, collections::{BTreeMap, BTreeSet},
convert::TryFrom, convert::TryFrom,
sync::Arc, sync::Arc,
time::{Duration, UNIX_EPOCH}, time::{Duration, UNIX_EPOCH},
@ -26,7 +26,7 @@ use ruma::{
EventId, RoomId, RoomVersionId, UserId, EventId, RoomId, RoomVersionId, UserId,
}; };
use serde_json::{json, Value as JsonValue}; use serde_json::{json, Value as JsonValue};
use state_res::{Error, Event, Result, StateMap, StateResolution, StateStore}; use state_res::{Error, Event, Result, StateMap, StateResolution};
static mut SERVER_TIMESTAMP: u64 = 0; static mut SERVER_TIMESTAMP: u64 = 0;
@ -153,13 +153,78 @@ criterion_main!(benches);
pub struct TestStore<E: Event>(pub BTreeMap<EventId, Arc<E>>); pub struct TestStore<E: Event>(pub BTreeMap<EventId, Arc<E>>);
#[allow(unused)] #[allow(unused)]
impl<E: Event> StateStore<E> for TestStore<E> { impl<E: Event> TestStore<E> {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<E>> { pub fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<E>> {
self.0 self.0
.get(event_id) .get(event_id)
.map(Arc::clone) .map(Arc::clone)
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
} }
/// Returns the events that correspond to the `event_ids` sorted in the same order.
pub fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<Arc<E>>> {
let mut events = vec![];
for id in event_ids {
events.push(self.get_event(room_id, id)?);
}
Ok(events)
}
/// Returns a Vec of the related auth events to the given `event`.
pub fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<EventId>> {
let mut result = vec![];
let mut stack = event_ids.to_vec();
// DFS for auth event chain
while !stack.is_empty() {
let ev_id = stack.pop().unwrap();
if result.contains(&ev_id) {
continue;
}
result.push(ev_id.clone());
let event = self.get_event(room_id, &ev_id)?;
stack.extend(event.auth_events().clone());
}
Ok(result)
}
/// Returns a Vec<EventId> representing the difference in auth chains of the given `events`.
pub fn auth_chain_diff(
&self,
room_id: &RoomId,
event_ids: Vec<Vec<EventId>>,
) -> Result<Vec<EventId>> {
let mut chains = vec![];
for ids in event_ids {
// TODO state store `auth_event_ids` returns self in the event ids list
// when an event returns `auth_event_ids` self is not contained
let chain = self
.auth_event_ids(room_id, &ids)?
.into_iter()
.collect::<BTreeSet<_>>();
chains.push(chain);
}
if let Some(chain) = chains.first() {
let rest = chains.iter().skip(1).flatten().cloned().collect();
let common = chain.intersection(&rest).collect::<Vec<_>>();
Ok(chains
.iter()
.flatten()
.filter(|id| !common.contains(&id))
.cloned()
.collect::<BTreeSet<_>>()
.into_iter()
.collect())
} else {
Ok(vec![])
}
}
} }
impl TestStore<event::StateEvent> { impl TestStore<event::StateEvent> {

View File

@ -5,10 +5,10 @@ use maplit::btreeset;
use ruma::{ use ruma::{
events::{ events::{
room::{ room::{
self, create::CreateEventContent,
join_rules::JoinRule, join_rules::{JoinRule, JoinRulesEventContent},
member::{self, MembershipState}, member::{MembershipState, ThirdPartyInvite},
power_levels::{self, PowerLevelsEventContent}, power_levels::PowerLevelsEventContent,
}, },
EventType, EventType,
}, },
@ -39,7 +39,7 @@ pub fn auth_types_for_event(
if let Some(state_key) = state_key { if let Some(state_key) = state_key {
if let Some(Ok(membership)) = content if let Some(Ok(membership)) = content
.get("membership") .get("membership")
.map(|m| serde_json::from_value::<room::member::MembershipState>(m.clone())) .map(|m| serde_json::from_value::<MembershipState>(m.clone()))
{ {
if [MembershipState::Join, MembershipState::Invite].contains(&membership) { if [MembershipState::Join, MembershipState::Invite].contains(&membership) {
let key = (EventType::RoomJoinRules, "".to_string()); let key = (EventType::RoomJoinRules, "".to_string());
@ -54,9 +54,10 @@ pub fn auth_types_for_event(
} }
if membership == MembershipState::Invite { if membership == MembershipState::Invite {
if let Some(Ok(t_id)) = content.get("third_party_invite").map(|t| { if let Some(Ok(t_id)) = content
serde_json::from_value::<room::member::ThirdPartyInvite>(t.clone()) .get("third_party_invite")
}) { .map(|t| serde_json::from_value::<ThirdPartyInvite>(t.clone()))
{
let key = (EventType::RoomThirdPartyInvite, t_id.signed.token); let key = (EventType::RoomThirdPartyInvite, t_id.signed.token);
if !auth_types.contains(&key) { if !auth_types.contains(&key) {
auth_types.push(key) auth_types.push(key)
@ -206,7 +207,7 @@ pub fn auth_check<E: Event>(
let membership = incoming_event let membership = incoming_event
.content() .content()
.get("membership") .get("membership")
.map(|m| serde_json::from_value::<room::member::MembershipState>(m.clone())); .map(|m| serde_json::from_value::<MembershipState>(m.clone()));
if !matches!(membership, Some(Ok(_))) { if !matches!(membership, Some(Ok(_))) {
log::warn!("no valid membership field found for m.room.member event content"); log::warn!("no valid membership field found for m.room.member event content");
@ -308,7 +309,7 @@ pub fn valid_membership_change<E: Event>(
current_third_party_invite: Option<Arc<E>>, current_third_party_invite: Option<Arc<E>>,
auth_events: &StateMap<Arc<E>>, auth_events: &StateMap<Arc<E>>,
) -> Result<bool> { ) -> Result<bool> {
let target_membership = serde_json::from_value::<room::member::MembershipState>( let target_membership = serde_json::from_value::<MembershipState>(
content content
.get("membership") .get("membership")
.expect("we should test before that this field exists") .expect("we should test before that this field exists")
@ -317,39 +318,37 @@ pub fn valid_membership_change<E: Event>(
let third_party_invite = content let third_party_invite = content
.get("third_party_invite") .get("third_party_invite")
.map(|t| serde_json::from_value::<room::member::ThirdPartyInvite>(t.clone())); .map(|t| serde_json::from_value::<ThirdPartyInvite>(t.clone()));
let target_user_id = let target_user_id =
UserId::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{}", e)))?; UserId::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{}", e)))?;
let key = (EventType::RoomMember, user_sender.to_string()); let key = (EventType::RoomMember, user_sender.to_string());
let sender = auth_events.get(&key); let sender = auth_events.get(&key);
let sender_membership = let sender_membership = sender.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { Ok(serde_json::from_value::<MembershipState>(
Ok(serde_json::from_value::<room::member::MembershipState>( pdu.content()
pdu.content() .get("membership")
.get("membership") .expect("we assume existing events are valid")
.expect("we assume existing events are valid") .clone(),
.clone(), )?)
)?) })?;
})?;
let key = (EventType::RoomMember, target_user_id.to_string()); let key = (EventType::RoomMember, target_user_id.to_string());
let current = auth_events.get(&key); let current = auth_events.get(&key);
let current_membership = let current_membership = current.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
current.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { Ok(serde_json::from_value::<MembershipState>(
Ok(serde_json::from_value::<room::member::MembershipState>( pdu.content()
pdu.content() .get("membership")
.get("membership") .expect("we assume existing events are valid")
.expect("we assume existing events are valid") .clone(),
.clone(), )?)
)?) })?;
})?;
let key = (EventType::RoomPowerLevels, "".into()); let key = (EventType::RoomPowerLevels, "".into());
let power_levels = auth_events.get(&key).map_or_else( let power_levels = auth_events.get(&key).map_or_else(
|| Ok::<_, Error>(power_levels::PowerLevelsEventContent::default()), || Ok::<_, Error>(PowerLevelsEventContent::default()),
|power_levels| { |power_levels| {
serde_json::from_value::<PowerLevelsEventContent>(power_levels.content()) serde_json::from_value::<PowerLevelsEventContent>(power_levels.content())
.map_err(Into::into) .map_err(Into::into)
@ -358,7 +357,7 @@ pub fn valid_membership_change<E: Event>(
let sender_power = power_levels.users.get(user_sender).map_or_else( let sender_power = power_levels.users.get(user_sender).map_or_else(
|| { || {
if sender_membership != member::MembershipState::Join { if sender_membership != MembershipState::Join {
None None
} else { } else {
Some(&power_levels.users_default) Some(&power_levels.users_default)
@ -369,7 +368,7 @@ pub fn valid_membership_change<E: Event>(
); );
let target_power = power_levels.users.get(&target_user_id).map_or_else( let target_power = power_levels.users.get(&target_user_id).map_or_else(
|| { || {
if target_membership != member::MembershipState::Join { if target_membership != MembershipState::Join {
None None
} else { } else {
Some(&power_levels.users_default) Some(&power_levels.users_default)
@ -383,9 +382,7 @@ pub fn valid_membership_change<E: Event>(
let join_rules_event = auth_events.get(&key); let join_rules_event = auth_events.get(&key);
let mut join_rules = JoinRule::Invite; let mut join_rules = JoinRule::Invite;
if let Some(jr) = join_rules_event { if let Some(jr) = join_rules_event {
join_rules = join_rules = serde_json::from_value::<JoinRulesEventContent>(jr.content())?.join_rule;
serde_json::from_value::<room::join_rules::JoinRulesEventContent>(jr.content())?
.join_rule;
} }
if let Some(prev) = prev_event { if let Some(prev) = prev_event {
@ -494,7 +491,7 @@ pub fn check_event_sender_in_room<E: Event>(
) -> Option<bool> { ) -> Option<bool> {
let mem = auth_events.get(&(EventType::RoomMember, sender.to_string()))?; let mem = auth_events.get(&(EventType::RoomMember, sender.to_string()))?;
let membership = serde_json::from_value::<room::member::MembershipState>( let membership = serde_json::from_value::<MembershipState>(
mem.content() mem.content()
.get("membership") .get("membership")
.expect("we should test before that this field exists") .expect("we should test before that this field exists")
@ -555,15 +552,11 @@ pub fn check_power_levels<E: Event>(
// If users key in content is not a dictionary with keys that are valid user IDs // If users key in content is not a dictionary with keys that are valid user IDs
// with values that are integers (or a string that is an integer), reject. // with values that are integers (or a string that is an integer), reject.
let user_content = serde_json::from_value::<room::power_levels::PowerLevelsEventContent>( let user_content =
power_event.content(), serde_json::from_value::<PowerLevelsEventContent>(power_event.content()).unwrap();
)
.unwrap();
let current_content = serde_json::from_value::<room::power_levels::PowerLevelsEventContent>( let current_content =
current_state.content(), serde_json::from_value::<PowerLevelsEventContent>(current_state.content()).unwrap();
)
.unwrap();
// validation of users is done in Ruma, synapse for loops validating user_ids and integers here // validation of users is done in Ruma, synapse for loops validating user_ids and integers here
log::info!("validation of power event finished"); log::info!("validation of power event finished");
@ -728,7 +721,7 @@ pub fn check_membership<E: Event>(member_event: Option<Arc<E>>, state: Membershi
if let Some(Ok(membership)) = event if let Some(Ok(membership)) = event
.content() .content()
.get("membership") .get("membership")
.map(|m| serde_json::from_value::<room::member::MembershipState>(m.clone())) .map(|m| serde_json::from_value::<MembershipState>(m.clone()))
{ {
membership == state membership == state
} else { } else {
@ -773,9 +766,7 @@ pub fn get_named_level<E: Event>(auth_events: &StateMap<Arc<E>>, name: &str, def
/// object. /// object.
pub fn get_user_power_level<E: Event>(user_id: &UserId, auth_events: &StateMap<Arc<E>>) -> i64 { pub fn get_user_power_level<E: Event>(user_id: &UserId, auth_events: &StateMap<Arc<E>>) -> i64 {
if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, "".into())) { if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, "".into())) {
if let Ok(content) = if let Ok(content) = serde_json::from_value::<PowerLevelsEventContent>(pl.content()) {
serde_json::from_value::<room::power_levels::PowerLevelsEventContent>(pl.content())
{
if let Some(level) = content.users.get(user_id) { if let Some(level) = content.users.get(user_id) {
(*level).into() (*level).into()
} else { } else {
@ -788,9 +779,7 @@ pub fn get_user_power_level<E: Event>(user_id: &UserId, auth_events: &StateMap<A
// if no power level event found the creator gets 100 everyone else gets 0 // if no power level event found the creator gets 100 everyone else gets 0
let key = (EventType::RoomCreate, "".into()); let key = (EventType::RoomCreate, "".into());
if let Some(create) = auth_events.get(&key) { if let Some(create) = auth_events.get(&key) {
if let Ok(c) = if let Ok(c) = serde_json::from_value::<CreateEventContent>(create.content()) {
serde_json::from_value::<room::create::CreateEventContent>(create.content())
{
if &c.creator == user_id { if &c.creator == user_id {
100 100
} else { } else {
@ -815,7 +804,7 @@ pub fn get_send_level<E: Event>(
log::debug!("{:?} {:?}", e_type, state_key); log::debug!("{:?} {:?}", e_type, state_key);
power_lvl power_lvl
.and_then(|ple| { .and_then(|ple| {
serde_json::from_value::<room::power_levels::PowerLevelsEventContent>(ple.content()) serde_json::from_value::<PowerLevelsEventContent>(ple.content())
.map(|content| { .map(|content| {
content.events.get(&e_type).cloned().unwrap_or_else(|| { content.events.get(&e_type).cloned().unwrap_or_else(|| {
if state_key.is_some() { if state_key.is_some() {
@ -853,7 +842,7 @@ pub fn can_send_invite<E: Event>(event: &Arc<E>, auth_events: &StateMap<Arc<E>>)
pub fn verify_third_party_invite<E: Event>( pub fn verify_third_party_invite<E: Event>(
user_state_key: Option<&str>, user_state_key: Option<&str>,
sender: &UserId, sender: &UserId,
tp_id: &member::ThirdPartyInvite, tp_id: &ThirdPartyInvite,
current_third_party_invite: Option<Arc<E>>, current_third_party_invite: Option<Arc<E>>,
) -> bool { ) -> bool {
// 1. check for user being banned happens before this is called // 1. check for user being banned happens before this is called

View File

@ -1,7 +1,7 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use ruma::{events::EventType, EventId, RoomVersionId}; use ruma::{events::EventType, EventId, RoomVersionId};
use state_res::{is_power_event, StateMap}; use state_res::{is_power_event, room_version::RoomVersion, StateMap};
mod utils; mod utils;
use utils::{room_id, INITIAL_EVENTS}; use utils::{room_id, INITIAL_EVENTS};
@ -46,7 +46,7 @@ fn test_event_sort() {
// TODO we may be able to skip this since they are resolved according to spec // TODO we may be able to skip this since they are resolved according to spec
let resolved_power = state_res::StateResolution::iterative_auth_check( let resolved_power = state_res::StateResolution::iterative_auth_check(
&room_id(), &room_id(),
&RoomVersionId::Version6, &RoomVersion::new(&RoomVersionId::Version6).unwrap(),
&sorted_power_events, &sorted_power_events,
&BTreeMap::new(), // unconflicted events &BTreeMap::new(), // unconflicted events
&mut events, &mut events,

View File

@ -4,7 +4,7 @@ use std::{collections::BTreeMap, sync::Arc};
use ruma::{events::EventType, EventId, RoomVersionId}; use ruma::{events::EventType, EventId, RoomVersionId};
use serde_json::json; use serde_json::json;
use state_res::{EventMap, StateMap, StateResolution, StateStore}; use state_res::{EventMap, StateMap, StateResolution};
mod utils; mod utils;
use utils::{ use utils::{

View File

@ -6,7 +6,7 @@ use ruma::{
EventId, RoomVersionId, EventId, RoomVersionId,
}; };
use serde_json::json; use serde_json::json;
use state_res::{StateMap, StateResolution, StateStore}; use state_res::{StateMap, StateResolution};
use tracing_subscriber as tracer; use tracing_subscriber as tracer;
mod utils; mod utils;

View File

@ -1,7 +1,7 @@
#![allow(clippy::or_fun_call, clippy::expect_fun_call, dead_code)] #![allow(clippy::or_fun_call, clippy::expect_fun_call, dead_code)]
use std::{ use std::{
collections::BTreeMap, collections::{BTreeMap, BTreeSet},
convert::TryFrom, convert::TryFrom,
sync::{Arc, Once}, sync::{Arc, Once},
time::{Duration, UNIX_EPOCH}, time::{Duration, UNIX_EPOCH},
@ -20,7 +20,7 @@ use ruma::{
EventId, RoomId, RoomVersionId, UserId, EventId, RoomId, RoomVersionId, UserId,
}; };
use serde_json::{json, Value as JsonValue}; use serde_json::{json, Value as JsonValue};
use state_res::{Error, Event, Result, StateMap, StateResolution, StateStore}; use state_res::{Error, Event, Result, StateMap, StateResolution};
use tracing_subscriber as tracer; use tracing_subscriber as tracer;
pub use event::StateEvent; pub use event::StateEvent;
@ -215,13 +215,78 @@ pub fn do_check(
pub struct TestStore<E: Event>(pub BTreeMap<EventId, Arc<E>>); pub struct TestStore<E: Event>(pub BTreeMap<EventId, Arc<E>>);
#[allow(unused)] #[allow(unused)]
impl<E: Event> StateStore<E> for TestStore<E> { impl<E: Event> TestStore<E> {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<E>> { pub fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<E>> {
self.0 self.0
.get(event_id) .get(event_id)
.map(Arc::clone) .map(Arc::clone)
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
} }
/// Returns the events that correspond to the `event_ids` sorted in the same order.
pub fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<Arc<E>>> {
let mut events = vec![];
for id in event_ids {
events.push(self.get_event(room_id, id)?);
}
Ok(events)
}
/// Returns a Vec of the related auth events to the given `event`.
pub fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<EventId>> {
let mut result = vec![];
let mut stack = event_ids.to_vec();
// DFS for auth event chain
while !stack.is_empty() {
let ev_id = stack.pop().unwrap();
if result.contains(&ev_id) {
continue;
}
result.push(ev_id.clone());
let event = self.get_event(room_id, &ev_id)?;
stack.extend(event.auth_events().clone());
}
Ok(result)
}
/// Returns a Vec<EventId> representing the difference in auth chains of the given `events`.
pub fn auth_chain_diff(
&self,
room_id: &RoomId,
event_ids: Vec<Vec<EventId>>,
) -> Result<Vec<EventId>> {
let mut chains = vec![];
for ids in event_ids {
// TODO state store `auth_event_ids` returns self in the event ids list
// when an event returns `auth_event_ids` self is not contained
let chain = self
.auth_event_ids(room_id, &ids)?
.into_iter()
.collect::<BTreeSet<_>>();
chains.push(chain);
}
if let Some(chain) = chains.first() {
let rest = chains.iter().skip(1).flatten().cloned().collect();
let common = chain.intersection(&rest).collect::<Vec<_>>();
Ok(chains
.iter()
.flatten()
.filter(|id| !common.contains(&id))
.cloned()
.collect::<BTreeSet<_>>()
.into_iter()
.collect())
} else {
Ok(vec![])
}
}
} }
pub fn event_id(id: &str) -> EventId { pub fn event_id(id: &str) -> EventId {