Remove StateStore trait and clean up imports in event_auth

This commit is contained in:
Devin Ragotzy 2021-04-29 12:07:35 -04:00 committed by Devin Ragotzy
parent 138ecd4f35
commit f62df4d9ae
6 changed files with 184 additions and 65 deletions

View File

@ -5,7 +5,7 @@
// To pass args to criterion, use this form
// `cargo bench --bench <name of the bench> -- --save-baseline <name>`.
use std::{
collections::BTreeMap,
collections::{BTreeMap, BTreeSet},
convert::TryFrom,
sync::Arc,
time::{Duration, UNIX_EPOCH},
@ -26,7 +26,7 @@ use ruma::{
EventId, RoomId, RoomVersionId, UserId,
};
use serde_json::{json, Value as JsonValue};
use state_res::{Error, Event, Result, StateMap, StateResolution, StateStore};
use state_res::{Error, Event, Result, StateMap, StateResolution};
static mut SERVER_TIMESTAMP: u64 = 0;
@ -153,13 +153,78 @@ criterion_main!(benches);
pub struct TestStore<E: Event>(pub BTreeMap<EventId, Arc<E>>);
#[allow(unused)]
impl<E: Event> StateStore<E> for TestStore<E> {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<E>> {
impl<E: Event> TestStore<E> {
pub fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<E>> {
self.0
.get(event_id)
.map(Arc::clone)
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
}
/// Returns the events that correspond to the `event_ids` sorted in the same order.
pub fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<Arc<E>>> {
let mut events = vec![];
for id in event_ids {
events.push(self.get_event(room_id, id)?);
}
Ok(events)
}
/// Returns a Vec of the related auth events to the given `event`.
pub fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<EventId>> {
let mut result = vec![];
let mut stack = event_ids.to_vec();
// DFS for auth event chain
while !stack.is_empty() {
let ev_id = stack.pop().unwrap();
if result.contains(&ev_id) {
continue;
}
result.push(ev_id.clone());
let event = self.get_event(room_id, &ev_id)?;
stack.extend(event.auth_events().clone());
}
Ok(result)
}
/// Returns a Vec<EventId> representing the difference in auth chains of the given `events`.
pub fn auth_chain_diff(
&self,
room_id: &RoomId,
event_ids: Vec<Vec<EventId>>,
) -> Result<Vec<EventId>> {
let mut chains = vec![];
for ids in event_ids {
// TODO state store `auth_event_ids` returns self in the event ids list
// when an event returns `auth_event_ids` self is not contained
let chain = self
.auth_event_ids(room_id, &ids)?
.into_iter()
.collect::<BTreeSet<_>>();
chains.push(chain);
}
if let Some(chain) = chains.first() {
let rest = chains.iter().skip(1).flatten().cloned().collect();
let common = chain.intersection(&rest).collect::<Vec<_>>();
Ok(chains
.iter()
.flatten()
.filter(|id| !common.contains(&id))
.cloned()
.collect::<BTreeSet<_>>()
.into_iter()
.collect())
} else {
Ok(vec![])
}
}
}
impl TestStore<event::StateEvent> {

View File

@ -5,10 +5,10 @@ use maplit::btreeset;
use ruma::{
events::{
room::{
self,
join_rules::JoinRule,
member::{self, MembershipState},
power_levels::{self, PowerLevelsEventContent},
create::CreateEventContent,
join_rules::{JoinRule, JoinRulesEventContent},
member::{MembershipState, ThirdPartyInvite},
power_levels::PowerLevelsEventContent,
},
EventType,
},
@ -39,7 +39,7 @@ pub fn auth_types_for_event(
if let Some(state_key) = state_key {
if let Some(Ok(membership)) = content
.get("membership")
.map(|m| serde_json::from_value::<room::member::MembershipState>(m.clone()))
.map(|m| serde_json::from_value::<MembershipState>(m.clone()))
{
if [MembershipState::Join, MembershipState::Invite].contains(&membership) {
let key = (EventType::RoomJoinRules, "".to_string());
@ -54,9 +54,10 @@ pub fn auth_types_for_event(
}
if membership == MembershipState::Invite {
if let Some(Ok(t_id)) = content.get("third_party_invite").map(|t| {
serde_json::from_value::<room::member::ThirdPartyInvite>(t.clone())
}) {
if let Some(Ok(t_id)) = content
.get("third_party_invite")
.map(|t| serde_json::from_value::<ThirdPartyInvite>(t.clone()))
{
let key = (EventType::RoomThirdPartyInvite, t_id.signed.token);
if !auth_types.contains(&key) {
auth_types.push(key)
@ -206,7 +207,7 @@ pub fn auth_check<E: Event>(
let membership = incoming_event
.content()
.get("membership")
.map(|m| serde_json::from_value::<room::member::MembershipState>(m.clone()));
.map(|m| serde_json::from_value::<MembershipState>(m.clone()));
if !matches!(membership, Some(Ok(_))) {
log::warn!("no valid membership field found for m.room.member event content");
@ -308,7 +309,7 @@ pub fn valid_membership_change<E: Event>(
current_third_party_invite: Option<Arc<E>>,
auth_events: &StateMap<Arc<E>>,
) -> Result<bool> {
let target_membership = serde_json::from_value::<room::member::MembershipState>(
let target_membership = serde_json::from_value::<MembershipState>(
content
.get("membership")
.expect("we should test before that this field exists")
@ -317,16 +318,15 @@ pub fn valid_membership_change<E: Event>(
let third_party_invite = content
.get("third_party_invite")
.map(|t| serde_json::from_value::<room::member::ThirdPartyInvite>(t.clone()));
.map(|t| serde_json::from_value::<ThirdPartyInvite>(t.clone()));
let target_user_id =
UserId::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{}", e)))?;
let key = (EventType::RoomMember, user_sender.to_string());
let sender = auth_events.get(&key);
let sender_membership =
sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
Ok(serde_json::from_value::<room::member::MembershipState>(
let sender_membership = sender.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
Ok(serde_json::from_value::<MembershipState>(
pdu.content()
.get("membership")
.expect("we assume existing events are valid")
@ -337,9 +337,8 @@ pub fn valid_membership_change<E: Event>(
let key = (EventType::RoomMember, target_user_id.to_string());
let current = auth_events.get(&key);
let current_membership =
current.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
Ok(serde_json::from_value::<room::member::MembershipState>(
let current_membership = current.map_or(Ok::<_, Error>(MembershipState::Leave), |pdu| {
Ok(serde_json::from_value::<MembershipState>(
pdu.content()
.get("membership")
.expect("we assume existing events are valid")
@ -349,7 +348,7 @@ pub fn valid_membership_change<E: Event>(
let key = (EventType::RoomPowerLevels, "".into());
let power_levels = auth_events.get(&key).map_or_else(
|| Ok::<_, Error>(power_levels::PowerLevelsEventContent::default()),
|| Ok::<_, Error>(PowerLevelsEventContent::default()),
|power_levels| {
serde_json::from_value::<PowerLevelsEventContent>(power_levels.content())
.map_err(Into::into)
@ -358,7 +357,7 @@ pub fn valid_membership_change<E: Event>(
let sender_power = power_levels.users.get(user_sender).map_or_else(
|| {
if sender_membership != member::MembershipState::Join {
if sender_membership != MembershipState::Join {
None
} else {
Some(&power_levels.users_default)
@ -369,7 +368,7 @@ pub fn valid_membership_change<E: Event>(
);
let target_power = power_levels.users.get(&target_user_id).map_or_else(
|| {
if target_membership != member::MembershipState::Join {
if target_membership != MembershipState::Join {
None
} else {
Some(&power_levels.users_default)
@ -383,9 +382,7 @@ pub fn valid_membership_change<E: Event>(
let join_rules_event = auth_events.get(&key);
let mut join_rules = JoinRule::Invite;
if let Some(jr) = join_rules_event {
join_rules =
serde_json::from_value::<room::join_rules::JoinRulesEventContent>(jr.content())?
.join_rule;
join_rules = serde_json::from_value::<JoinRulesEventContent>(jr.content())?.join_rule;
}
if let Some(prev) = prev_event {
@ -494,7 +491,7 @@ pub fn check_event_sender_in_room<E: Event>(
) -> Option<bool> {
let mem = auth_events.get(&(EventType::RoomMember, sender.to_string()))?;
let membership = serde_json::from_value::<room::member::MembershipState>(
let membership = serde_json::from_value::<MembershipState>(
mem.content()
.get("membership")
.expect("we should test before that this field exists")
@ -555,15 +552,11 @@ pub fn check_power_levels<E: Event>(
// If users key in content is not a dictionary with keys that are valid user IDs
// with values that are integers (or a string that is an integer), reject.
let user_content = serde_json::from_value::<room::power_levels::PowerLevelsEventContent>(
power_event.content(),
)
.unwrap();
let user_content =
serde_json::from_value::<PowerLevelsEventContent>(power_event.content()).unwrap();
let current_content = serde_json::from_value::<room::power_levels::PowerLevelsEventContent>(
current_state.content(),
)
.unwrap();
let current_content =
serde_json::from_value::<PowerLevelsEventContent>(current_state.content()).unwrap();
// validation of users is done in Ruma, synapse for loops validating user_ids and integers here
log::info!("validation of power event finished");
@ -728,7 +721,7 @@ pub fn check_membership<E: Event>(member_event: Option<Arc<E>>, state: Membershi
if let Some(Ok(membership)) = event
.content()
.get("membership")
.map(|m| serde_json::from_value::<room::member::MembershipState>(m.clone()))
.map(|m| serde_json::from_value::<MembershipState>(m.clone()))
{
membership == state
} else {
@ -773,9 +766,7 @@ pub fn get_named_level<E: Event>(auth_events: &StateMap<Arc<E>>, name: &str, def
/// object.
pub fn get_user_power_level<E: Event>(user_id: &UserId, auth_events: &StateMap<Arc<E>>) -> i64 {
if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, "".into())) {
if let Ok(content) =
serde_json::from_value::<room::power_levels::PowerLevelsEventContent>(pl.content())
{
if let Ok(content) = serde_json::from_value::<PowerLevelsEventContent>(pl.content()) {
if let Some(level) = content.users.get(user_id) {
(*level).into()
} else {
@ -788,9 +779,7 @@ pub fn get_user_power_level<E: Event>(user_id: &UserId, auth_events: &StateMap<A
// if no power level event found the creator gets 100 everyone else gets 0
let key = (EventType::RoomCreate, "".into());
if let Some(create) = auth_events.get(&key) {
if let Ok(c) =
serde_json::from_value::<room::create::CreateEventContent>(create.content())
{
if let Ok(c) = serde_json::from_value::<CreateEventContent>(create.content()) {
if &c.creator == user_id {
100
} else {
@ -815,7 +804,7 @@ pub fn get_send_level<E: Event>(
log::debug!("{:?} {:?}", e_type, state_key);
power_lvl
.and_then(|ple| {
serde_json::from_value::<room::power_levels::PowerLevelsEventContent>(ple.content())
serde_json::from_value::<PowerLevelsEventContent>(ple.content())
.map(|content| {
content.events.get(&e_type).cloned().unwrap_or_else(|| {
if state_key.is_some() {
@ -853,7 +842,7 @@ pub fn can_send_invite<E: Event>(event: &Arc<E>, auth_events: &StateMap<Arc<E>>)
pub fn verify_third_party_invite<E: Event>(
user_state_key: Option<&str>,
sender: &UserId,
tp_id: &member::ThirdPartyInvite,
tp_id: &ThirdPartyInvite,
current_third_party_invite: Option<Arc<E>>,
) -> bool {
// 1. check for user being banned happens before this is called

View File

@ -1,7 +1,7 @@
use std::collections::BTreeMap;
use ruma::{events::EventType, EventId, RoomVersionId};
use state_res::{is_power_event, StateMap};
use state_res::{is_power_event, room_version::RoomVersion, StateMap};
mod utils;
use utils::{room_id, INITIAL_EVENTS};
@ -46,7 +46,7 @@ fn test_event_sort() {
// TODO we may be able to skip this since they are resolved according to spec
let resolved_power = state_res::StateResolution::iterative_auth_check(
&room_id(),
&RoomVersionId::Version6,
&RoomVersion::new(&RoomVersionId::Version6).unwrap(),
&sorted_power_events,
&BTreeMap::new(), // unconflicted events
&mut events,

View File

@ -4,7 +4,7 @@ use std::{collections::BTreeMap, sync::Arc};
use ruma::{events::EventType, EventId, RoomVersionId};
use serde_json::json;
use state_res::{EventMap, StateMap, StateResolution, StateStore};
use state_res::{EventMap, StateMap, StateResolution};
mod utils;
use utils::{

View File

@ -6,7 +6,7 @@ use ruma::{
EventId, RoomVersionId,
};
use serde_json::json;
use state_res::{StateMap, StateResolution, StateStore};
use state_res::{StateMap, StateResolution};
use tracing_subscriber as tracer;
mod utils;

View File

@ -1,7 +1,7 @@
#![allow(clippy::or_fun_call, clippy::expect_fun_call, dead_code)]
use std::{
collections::BTreeMap,
collections::{BTreeMap, BTreeSet},
convert::TryFrom,
sync::{Arc, Once},
time::{Duration, UNIX_EPOCH},
@ -20,7 +20,7 @@ use ruma::{
EventId, RoomId, RoomVersionId, UserId,
};
use serde_json::{json, Value as JsonValue};
use state_res::{Error, Event, Result, StateMap, StateResolution, StateStore};
use state_res::{Error, Event, Result, StateMap, StateResolution};
use tracing_subscriber as tracer;
pub use event::StateEvent;
@ -215,13 +215,78 @@ pub fn do_check(
pub struct TestStore<E: Event>(pub BTreeMap<EventId, Arc<E>>);
#[allow(unused)]
impl<E: Event> StateStore<E> for TestStore<E> {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<E>> {
impl<E: Event> TestStore<E> {
pub fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<E>> {
self.0
.get(event_id)
.map(Arc::clone)
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
}
/// Returns the events that correspond to the `event_ids` sorted in the same order.
pub fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<Arc<E>>> {
let mut events = vec![];
for id in event_ids {
events.push(self.get_event(room_id, id)?);
}
Ok(events)
}
/// Returns a Vec of the related auth events to the given `event`.
pub fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<EventId>> {
let mut result = vec![];
let mut stack = event_ids.to_vec();
// DFS for auth event chain
while !stack.is_empty() {
let ev_id = stack.pop().unwrap();
if result.contains(&ev_id) {
continue;
}
result.push(ev_id.clone());
let event = self.get_event(room_id, &ev_id)?;
stack.extend(event.auth_events().clone());
}
Ok(result)
}
/// Returns a Vec<EventId> representing the difference in auth chains of the given `events`.
pub fn auth_chain_diff(
&self,
room_id: &RoomId,
event_ids: Vec<Vec<EventId>>,
) -> Result<Vec<EventId>> {
let mut chains = vec![];
for ids in event_ids {
// TODO state store `auth_event_ids` returns self in the event ids list
// when an event returns `auth_event_ids` self is not contained
let chain = self
.auth_event_ids(room_id, &ids)?
.into_iter()
.collect::<BTreeSet<_>>();
chains.push(chain);
}
if let Some(chain) = chains.first() {
let rest = chains.iter().skip(1).flatten().cloned().collect();
let common = chain.intersection(&rest).collect::<Vec<_>>();
Ok(chains
.iter()
.flatten()
.filter(|id| !common.contains(&id))
.cloned()
.collect::<BTreeSet<_>>()
.into_iter()
.collect())
} else {
Ok(vec![])
}
}
}
pub fn event_id(id: &str) -> EventId {