Add error type, more docs, and conduit sorting test
Not resolve sorting just topo/mainline
This commit is contained in:
commit
8dbd9aae0b
12
README.md
12
README.md
@ -4,20 +4,20 @@
|
|||||||
/// StateMap is just a wrapper/deserialize target for a PDU.
|
/// StateMap is just a wrapper/deserialize target for a PDU.
|
||||||
struct StateEvent {
|
struct StateEvent {
|
||||||
content: serde_json::Value,
|
content: serde_json::Value,
|
||||||
room_id: RoomId,
|
origin_server_ts: SystemTime,
|
||||||
event_id: EventId,
|
sender: UserId,
|
||||||
// ... and so on
|
// ... and so on
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A mapping of event type and state_key to some value `T`, usually an `EventId`.
|
/// A mapping of event type and state_key to some value `T`, usually an `EventId`.
|
||||||
pub type StateMap<T> = BTreeMap<(EventType, String), T>;
|
pub type StateMap<T> = BTreeMap<(EventType, Option<String>), T>;
|
||||||
|
|
||||||
/// A mapping of `EventId` to `T`, usually a `StateEvent`.
|
/// A mapping of `EventId` to `T`, usually a `StateEvent`.
|
||||||
pub type EventMap<T> = BTreeMap<EventId, T>;
|
pub type EventMap<T> = BTreeMap<EventId, T>;
|
||||||
|
|
||||||
struct StateResolution {
|
struct StateResolution {
|
||||||
// For now the StateResolution struct is empty. If "caching" `event_map`
|
// For now the StateResolution struct is empty. If "caching" `event_map`
|
||||||
// between `resolve` calls nds up being more efficient (probably not as this would eat memory)
|
// between `resolve` calls ends up being more efficient (probably not, as this would eat memory)
|
||||||
// it may have an `event_map` field. The `event_map` is all the event's
|
// it may have an `event_map` field. The `event_map` is all the event's
|
||||||
// `StateResolution` has to know about in order to resolve state.
|
// `StateResolution` has to know about in order to resolve state.
|
||||||
}
|
}
|
||||||
@ -30,14 +30,14 @@ impl StateResolution {
|
|||||||
state_sets: &[StateMap<EventId>],
|
state_sets: &[StateMap<EventId>],
|
||||||
event_map: Option<EventMap<StateEvent>>,
|
event_map: Option<EventMap<StateEvent>>,
|
||||||
store: &dyn StateStore,
|
store: &dyn StateStore,
|
||||||
) -> Result<ResolutionResult>;
|
) -> Result<StateMap<EventId>, Error>;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The tricky part, making a good abstraction...
|
// The tricky part, making a good abstraction...
|
||||||
trait StateStore {
|
trait StateStore {
|
||||||
/// Return a single event based on the EventId.
|
/// Return a single event based on the EventId.
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String>;
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, Error>;
|
||||||
|
|
||||||
// There are 3 methods that have default implementations `get_events`,
|
// There are 3 methods that have default implementations `get_events`,
|
||||||
// `auth_event_ids` and `auth_chain_diff`. Each could be overridden if
|
// `auth_event_ids` and `auth_chain_diff`. Each could be overridden if
|
||||||
|
@ -3,12 +3,7 @@
|
|||||||
// `cargo bench unknown option --save-baseline`.
|
// `cargo bench unknown option --save-baseline`.
|
||||||
// To pass args to criterion, use this form
|
// To pass args to criterion, use this form
|
||||||
// `cargo bench --bench <name of the bench> -- --save-baseline <name>`.
|
// `cargo bench --bench <name of the bench> -- --save-baseline <name>`.
|
||||||
use std::{
|
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH};
|
||||||
cell::RefCell,
|
|
||||||
collections::{BTreeMap, BTreeSet},
|
|
||||||
convert::TryFrom,
|
|
||||||
time::UNIX_EPOCH,
|
|
||||||
};
|
|
||||||
|
|
||||||
use criterion::{criterion_group, criterion_main, Criterion};
|
use criterion::{criterion_group, criterion_main, Criterion};
|
||||||
use maplit::btreemap;
|
use maplit::btreemap;
|
||||||
@ -24,7 +19,7 @@ use ruma::{
|
|||||||
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
||||||
};
|
};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore};
|
use state_res::{Error, Result, StateEvent, StateMap, StateResolution, StateStore};
|
||||||
|
|
||||||
static mut SERVER_TIMESTAMP: i32 = 0;
|
static mut SERVER_TIMESTAMP: i32 = 0;
|
||||||
|
|
||||||
@ -60,9 +55,8 @@ fn resolution_shallow_auth_chain(c: &mut Criterion) {
|
|||||||
None,
|
None,
|
||||||
&store,
|
&store,
|
||||||
) {
|
) {
|
||||||
Ok(ResolutionResult::Resolved(state)) => state,
|
Ok(state) => state,
|
||||||
Err(e) => panic!("{}", e),
|
Err(e) => panic!("{}", e),
|
||||||
_ => panic!("conflicted state left"),
|
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
@ -111,9 +105,8 @@ fn resolve_deeper_event_set(c: &mut Criterion) {
|
|||||||
Some(inner.clone()),
|
Some(inner.clone()),
|
||||||
&store,
|
&store,
|
||||||
) {
|
) {
|
||||||
Ok(ResolutionResult::Resolved(state)) => state,
|
Ok(state) => state,
|
||||||
Err(_) => panic!("resolution failed during benchmarking"),
|
Err(_) => panic!("resolution failed during benchmarking"),
|
||||||
_ => panic!("resolution failed during benchmarking"),
|
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
@ -137,82 +130,12 @@ pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
|||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
impl StateStore for TestStore {
|
impl StateStore for TestStore {
|
||||||
fn get_events(&self, room_id: &RoomId, events: &[EventId]) -> Result<Vec<StateEvent>, String> {
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||||
Ok(self
|
|
||||||
.0
|
|
||||||
.borrow()
|
|
||||||
.iter()
|
|
||||||
.filter(|e| events.contains(e.0))
|
|
||||||
.map(|(_, s)| s)
|
|
||||||
.cloned()
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String> {
|
|
||||||
self.0
|
self.0
|
||||||
.borrow()
|
.borrow()
|
||||||
.get(event_id)
|
.get(event_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or(format!("{} not found", event_id.to_string()))
|
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||||
}
|
|
||||||
|
|
||||||
fn auth_event_ids(
|
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: &[EventId],
|
|
||||||
) -> Result<Vec<EventId>, String> {
|
|
||||||
let mut result = vec![];
|
|
||||||
let mut stack = event_ids.to_vec();
|
|
||||||
|
|
||||||
// DFS for auth event chain
|
|
||||||
while !stack.is_empty() {
|
|
||||||
let ev_id = stack.pop().unwrap();
|
|
||||||
if result.contains(&ev_id) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.push(ev_id.clone());
|
|
||||||
|
|
||||||
let event = self.get_event(room_id, &ev_id).unwrap();
|
|
||||||
stack.extend(event.auth_events());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn auth_chain_diff(
|
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: Vec<Vec<EventId>>,
|
|
||||||
) -> Result<Vec<EventId>, String> {
|
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
let mut chains = vec![];
|
|
||||||
for ids in event_ids {
|
|
||||||
// TODO state store `auth_event_ids` returns self in the event ids list
|
|
||||||
// when an event returns `auth_event_ids` self is not contained
|
|
||||||
let chain = self
|
|
||||||
.auth_event_ids(room_id, &ids)?
|
|
||||||
.into_iter()
|
|
||||||
.collect::<BTreeSet<_>>();
|
|
||||||
chains.push(chain);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(chain) = chains.first() {
|
|
||||||
let rest = chains.iter().skip(1).flatten().cloned().collect();
|
|
||||||
let common = chain.intersection(&rest).collect::<Vec<_>>();
|
|
||||||
|
|
||||||
Ok(chains
|
|
||||||
.iter()
|
|
||||||
.flatten()
|
|
||||||
.filter(|id| !common.contains(&id))
|
|
||||||
.cloned()
|
|
||||||
.collect::<BTreeSet<_>>()
|
|
||||||
.into_iter()
|
|
||||||
.collect())
|
|
||||||
} else {
|
|
||||||
Ok(vec![])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
20
src/error.rs
20
src/error.rs
@ -17,7 +17,21 @@ pub enum Error {
|
|||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
IntParseError(#[from] ParseIntError),
|
IntParseError(#[from] ParseIntError),
|
||||||
|
|
||||||
// TODO remove once the correct errors are used
|
#[error("Not found error: {0}")]
|
||||||
#[error("an error occured {0}")]
|
NotFound(String),
|
||||||
TempString(String),
|
|
||||||
|
#[error("Invalid PDU: {0}")]
|
||||||
|
InvalidPdu(String),
|
||||||
|
|
||||||
|
#[error("Conversion failed: {0}")]
|
||||||
|
ConversionError(String),
|
||||||
|
|
||||||
|
#[error("{0}")]
|
||||||
|
Custom(Box<dyn std::error::Error>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error {
|
||||||
|
pub fn custom<E: std::error::Error + 'static>(e: E) -> Self {
|
||||||
|
Self::Custom(Box::new(e))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,19 +1,23 @@
|
|||||||
use std::convert::TryFrom;
|
use std::{collections::BTreeMap, convert::TryFrom};
|
||||||
|
|
||||||
use maplit::btreeset;
|
use maplit::btreeset;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
events::{
|
events::{
|
||||||
room::{self, join_rules::JoinRule, member::MembershipState},
|
room::{
|
||||||
|
self,
|
||||||
|
join_rules::JoinRule,
|
||||||
|
member::{self, MembershipState},
|
||||||
|
power_levels::{self, PowerLevelsEventContent},
|
||||||
|
},
|
||||||
EventType,
|
EventType,
|
||||||
},
|
},
|
||||||
identifiers::{RoomVersionId, UserId},
|
identifiers::{RoomVersionId, UserId},
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
room_version::RoomVersion,
|
room_version::RoomVersion,
|
||||||
state_event::{Requester, StateEvent},
|
state_event::{Requester, StateEvent},
|
||||||
StateMap,
|
Error, Result, StateMap,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Represents the 3 event redaction outcomes.
|
/// Represents the 3 event redaction outcomes.
|
||||||
@ -79,25 +83,26 @@ pub fn auth_types_for_event(
|
|||||||
/// * then there are checks for specific event types
|
/// * then there are checks for specific event types
|
||||||
pub fn auth_check(
|
pub fn auth_check(
|
||||||
room_version: &RoomVersionId,
|
room_version: &RoomVersionId,
|
||||||
event: &StateEvent,
|
incoming_event: &StateEvent,
|
||||||
|
prev_event: Option<&StateEvent>,
|
||||||
auth_events: StateMap<StateEvent>,
|
auth_events: StateMap<StateEvent>,
|
||||||
do_sig_check: bool,
|
do_sig_check: bool,
|
||||||
) -> Option<bool> {
|
) -> Result<bool> {
|
||||||
tracing::info!("auth_check beginning for {}", event.event_id().as_str());
|
tracing::info!("auth_check beginning for {}", incoming_event.kind());
|
||||||
|
|
||||||
// don't let power from other rooms be used
|
// don't let power from other rooms be used
|
||||||
for auth_event in auth_events.values() {
|
for auth_event in auth_events.values() {
|
||||||
if auth_event.room_id() != event.room_id() {
|
if auth_event.room_id() != incoming_event.room_id() {
|
||||||
tracing::warn!("found auth event that did not match event's room_id");
|
tracing::warn!("found auth event that did not match event's room_id");
|
||||||
return Some(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if do_sig_check {
|
if do_sig_check {
|
||||||
let sender_domain = event.sender().server_name();
|
let sender_domain = incoming_event.sender().server_name();
|
||||||
|
|
||||||
let is_invite_via_3pid = if event.kind() == EventType::RoomMember {
|
let is_invite_via_3pid = if incoming_event.kind() == EventType::RoomMember {
|
||||||
event
|
incoming_event
|
||||||
.deserialize_content::<room::member::MemberEventContent>()
|
.deserialize_content::<room::member::MemberEventContent>()
|
||||||
.map(|c| c.membership == MembershipState::Invite && c.third_party_invite.is_some())
|
.map(|c| c.membership == MembershipState::Invite && c.third_party_invite.is_some())
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
@ -106,19 +111,19 @@ pub fn auth_check(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// check the event has been signed by the domain of the sender
|
// check the event has been signed by the domain of the sender
|
||||||
if event.signatures().get(sender_domain).is_none() && !is_invite_via_3pid {
|
if incoming_event.signatures().get(sender_domain).is_none() && !is_invite_via_3pid {
|
||||||
tracing::warn!("event not signed by sender's server");
|
tracing::warn!("event not signed by sender's server");
|
||||||
return Some(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.room_version() == RoomVersionId::Version1
|
if incoming_event.room_version() == RoomVersionId::Version1
|
||||||
&& event
|
&& incoming_event
|
||||||
.signatures()
|
.signatures()
|
||||||
.get(event.event_id().server_name().unwrap())
|
.get(incoming_event.event_id().server_name().unwrap())
|
||||||
.is_none()
|
.is_none()
|
||||||
{
|
{
|
||||||
tracing::warn!("event not signed by event_id's server");
|
tracing::warn!("event not signed by event_id's server");
|
||||||
return Some(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,330 +134,334 @@ pub fn auth_check(
|
|||||||
// Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
|
// Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
|
||||||
//
|
//
|
||||||
// 1. If type is m.room.create:
|
// 1. If type is m.room.create:
|
||||||
if event.kind() == EventType::RoomCreate {
|
if incoming_event.kind() == EventType::RoomCreate {
|
||||||
tracing::info!("start m.room.create check");
|
tracing::info!("start m.room.create check");
|
||||||
|
|
||||||
// domain of room_id must match domain of sender.
|
// If it has any previous events, reject
|
||||||
if event.room_id().map(|id| id.server_name()) != Some(event.sender().server_name()) {
|
if !incoming_event.prev_event_ids().is_empty() {
|
||||||
tracing::warn!("creation events server does not match sender");
|
tracing::warn!("the room creation event had previous events");
|
||||||
return Some(false); // creation events room id does not match senders
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// if content.room_version is present and is not a valid version
|
// If the domain of the room_id does not match the domain of the sender, reject
|
||||||
|
if incoming_event.room_id().map(|id| id.server_name())
|
||||||
|
!= Some(incoming_event.sender().server_name())
|
||||||
|
{
|
||||||
|
tracing::warn!("creation events server does not match sender");
|
||||||
|
return Ok(false); // creation events room id does not match senders
|
||||||
|
}
|
||||||
|
|
||||||
|
// If content.room_version is present and is not a recognized version, reject
|
||||||
if serde_json::from_value::<RoomVersionId>(
|
if serde_json::from_value::<RoomVersionId>(
|
||||||
event
|
incoming_event
|
||||||
.content()
|
.content()
|
||||||
.get("room_version")
|
.get("room_version")
|
||||||
.cloned()
|
.cloned()
|
||||||
// synapse defaults to version 1
|
// TODO synapse defaults to version 1
|
||||||
.unwrap_or_else(|| serde_json::json!("1")),
|
.unwrap_or_else(|| serde_json::json!("1")),
|
||||||
)
|
)
|
||||||
.is_err()
|
.is_err()
|
||||||
{
|
{
|
||||||
tracing::warn!("invalid room version found in m.room.create event");
|
tracing::warn!("invalid room version found in m.room.create event");
|
||||||
return Some(false);
|
return Ok(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If content has no creator field, reject
|
||||||
|
if incoming_event.content().get("creator").is_none() {
|
||||||
|
tracing::warn!("no creator field found in room create content");
|
||||||
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!("m.room.create event was allowed");
|
tracing::info!("m.room.create event was allowed");
|
||||||
return Some(true);
|
return Ok(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. If event does not have m.room.create in auth_events reject.
|
// 3. If event does not have m.room.create in auth_events reject
|
||||||
if auth_events
|
if auth_events
|
||||||
.get(&(EventType::RoomCreate, Some("".into())))
|
.get(&(EventType::RoomCreate, Some("".into())))
|
||||||
.is_none()
|
.is_none()
|
||||||
{
|
{
|
||||||
tracing::warn!("no m.room.create event in auth chain");
|
tracing::warn!("no m.room.create event in auth chain");
|
||||||
|
|
||||||
return Some(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// check for m.federate
|
// [synapse] checks for federation here
|
||||||
if event.room_id().map(|id| id.server_name()) != Some(event.sender().server_name()) {
|
|
||||||
tracing::info!("checking federation");
|
|
||||||
|
|
||||||
if !can_federate(&auth_events) {
|
|
||||||
tracing::warn!("federation not allowed");
|
|
||||||
|
|
||||||
return Some(false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. if type is m.room.aliases
|
// 4. if type is m.room.aliases
|
||||||
if event.kind() == EventType::RoomAliases {
|
if incoming_event.kind() == EventType::RoomAliases {
|
||||||
tracing::info!("starting m.room.aliases check");
|
tracing::info!("starting m.room.aliases check");
|
||||||
// TODO && room_version "special case aliases auth" ??
|
// TODO && room_version "special case aliases auth" ??
|
||||||
if event.state_key().is_none() {
|
if incoming_event.state_key().is_none() {
|
||||||
tracing::warn!("no state_key field found for event");
|
tracing::warn!("no state_key field found for event");
|
||||||
return Some(false); // must have state_key
|
return Ok(false); // must have state_key
|
||||||
}
|
|
||||||
if event.state_key().unwrap().is_empty() {
|
|
||||||
tracing::warn!("state_key must be non-empty");
|
|
||||||
return Some(false); // and be non-empty state_key (point to a user_id)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.state_key() != Some(event.sender().to_string()) {
|
// TODO this is not part of the spec
|
||||||
tracing::warn!("no state_key field found for event");
|
// if event.state_key().unwrap().is_empty() {
|
||||||
return Some(false);
|
// tracing::warn!("state_key must be non-empty");
|
||||||
|
// return Ok(false); // and be non-empty state_key (point to a user_id)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// If sender's domain doesn't matches state_key, reject
|
||||||
|
if incoming_event.state_key().as_deref()
|
||||||
|
!= Some(incoming_event.sender().server_name().as_str())
|
||||||
|
{
|
||||||
|
tracing::warn!("state_key does not match sender");
|
||||||
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!("m.room.aliases event was allowed");
|
tracing::info!("m.room.aliases event was allowed");
|
||||||
return Some(true);
|
return Ok(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.kind() == EventType::RoomMember {
|
if incoming_event.kind() == EventType::RoomMember {
|
||||||
tracing::info!("starting m.room.member check");
|
tracing::info!("starting m.room.member check");
|
||||||
|
|
||||||
if !is_membership_change_allowed(event.to_requester(), &auth_events)? {
|
if incoming_event.state_key().is_none() {
|
||||||
return Some(false);
|
tracing::warn!("no state_key found for m.room.member event");
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
if incoming_event
|
||||||
|
.deserialize_content::<room::member::MemberEventContent>()
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
tracing::warn!("no membership filed found for m.room.member event content");
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid_membership_change(incoming_event.to_requester(), prev_event, &auth_events)? {
|
||||||
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!("m.room.member event was allowed");
|
tracing::info!("m.room.member event was allowed");
|
||||||
return Some(true);
|
return Ok(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(in_room) = check_event_sender_in_room(event, &auth_events) {
|
// If the sender's current membership state is not join, reject
|
||||||
if !in_room {
|
match check_event_sender_in_room(incoming_event.sender(), &auth_events) {
|
||||||
tracing::warn!("sender not in room");
|
Some(true) => {} // sender in room
|
||||||
return Some(false);
|
Some(false) => {
|
||||||
|
tracing::warn!("sender's membership is not join");
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
tracing::warn!("sender not found in room");
|
||||||
|
return Ok(false);
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
tracing::warn!("sender not in room");
|
|
||||||
return Some(false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special case to allow m.room.third_party_invite events where ever
|
// Special case to allow m.room.third_party_invite events where ever
|
||||||
// a user is allowed to issue invites
|
// a user is allowed to issue invites
|
||||||
if event.kind() == EventType::RoomThirdPartyInvite {
|
if incoming_event.kind() == EventType::RoomThirdPartyInvite {
|
||||||
// TODO impl this
|
// TODO impl this
|
||||||
unimplemented!("third party invite")
|
unimplemented!("third party invite")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !can_send_event(event, &auth_events)? {
|
// If the event type's required power level is greater than the sender's power level, reject
|
||||||
|
// If the event has a state_key that starts with an @ and does not match the sender, reject.
|
||||||
|
if !can_send_event(incoming_event, &auth_events)? {
|
||||||
tracing::warn!("user cannot send event");
|
tracing::warn!("user cannot send event");
|
||||||
return Some(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.kind() == EventType::RoomPowerLevels {
|
if incoming_event.kind() == EventType::RoomPowerLevels {
|
||||||
tracing::info!("starting m.room.power_levels check");
|
tracing::info!("starting m.room.power_levels check");
|
||||||
if let Some(required_pwr_lvl) = check_power_levels(room_version, event, &auth_events) {
|
|
||||||
|
if let Some(required_pwr_lvl) =
|
||||||
|
check_power_levels(room_version, incoming_event, &auth_events)
|
||||||
|
{
|
||||||
if !required_pwr_lvl {
|
if !required_pwr_lvl {
|
||||||
tracing::warn!("power level was not allowed");
|
tracing::warn!("power level was not allowed");
|
||||||
return Some(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tracing::warn!("power level was not allowed");
|
tracing::warn!("power level was not allowed");
|
||||||
return Some(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
tracing::info!("power levels event allowed");
|
tracing::info!("power levels event allowed");
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.kind() == EventType::RoomRedaction {
|
if incoming_event.kind() == EventType::RoomRedaction {
|
||||||
if let RedactAllowed::No = check_redaction(room_version, event, &auth_events)? {
|
if let RedactAllowed::No = check_redaction(room_version, incoming_event, &auth_events)? {
|
||||||
return Some(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!("allowing event passed all checks");
|
tracing::info!("allowing event passed all checks");
|
||||||
Some(true)
|
Ok(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// synapse has an `event: &StateEvent` param but it's never used
|
// TODO deserializing the member, power, join_rules event contents is done in conduit
|
||||||
/// Can this room federate based on its m.room.create event.
|
// just before this is called. Could they be passed in?
|
||||||
pub fn can_federate(auth_events: &StateMap<StateEvent>) -> bool {
|
/// Does the user who sent this member event have required power levels to do so.
|
||||||
let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into())));
|
///
|
||||||
if let Some(ev) = creation_event {
|
/// * `user` - Information about the membership event and user making the request.
|
||||||
if let Some(fed) = ev.content().get("m.federate") {
|
/// * `prev_event` - The event that occurred immediately before the `user` event or None.
|
||||||
fed == "true"
|
/// * `auth_events` - The set of auth events that relate to a membership event.
|
||||||
} else {
|
/// this is generated by calling `auth_types_for_event` with the membership event and
|
||||||
false
|
/// the current State.
|
||||||
}
|
pub fn valid_membership_change(
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Dose the user who sent this member event have required power levels to do so.
|
|
||||||
pub fn is_membership_change_allowed(
|
|
||||||
user: Requester<'_>,
|
user: Requester<'_>,
|
||||||
|
prev_event: Option<&StateEvent>,
|
||||||
auth_events: &StateMap<StateEvent>,
|
auth_events: &StateMap<StateEvent>,
|
||||||
) -> Option<bool> {
|
) -> Result<bool> {
|
||||||
|
let state_key = if let Some(s) = user.state_key.as_ref() {
|
||||||
|
s
|
||||||
|
} else {
|
||||||
|
return Err(Error::InvalidPdu("State event requires state_key".into()));
|
||||||
|
};
|
||||||
|
|
||||||
let content =
|
let content =
|
||||||
// TODO return error
|
serde_json::from_str::<room::member::MemberEventContent>(&user.content.to_string())?;
|
||||||
serde_json::from_str::<room::member::MemberEventContent>(&user.content.to_string()).ok()?;
|
|
||||||
|
|
||||||
let membership = content.membership;
|
let target_membership = content.membership;
|
||||||
|
|
||||||
// check if this is the room creator joining
|
let target_user_id = UserId::try_from(state_key.as_str())
|
||||||
if user.prev_event_ids.len() == 1 && membership == MembershipState::Join {
|
.map_err(|e| Error::ConversionError(format!("{}", e)))?;
|
||||||
if let Some(create) = auth_events.get(&(EventType::RoomCreate, Some("".into()))) {
|
|
||||||
if let Ok(create_ev) = create.deserialize_content::<room::create::CreateEventContent>()
|
|
||||||
{
|
|
||||||
if user.state_key == Some(create_ev.creator.to_string()) {
|
|
||||||
tracing::debug!("m.room.member event allowed via m.room.create");
|
|
||||||
return Some(true);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let target_user_id = UserId::try_from(user.state_key.as_deref().unwrap())
|
|
||||||
.ok()
|
|
||||||
.unwrap();
|
|
||||||
// if the server_names are different and federation is NOT allowed
|
|
||||||
if user.room_id.server_name() != target_user_id.server_name() && !can_federate(auth_events) {
|
|
||||||
tracing::warn!("server cannot federate");
|
|
||||||
return Some(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
let key = (EventType::RoomMember, Some(user.sender.to_string()));
|
let key = (EventType::RoomMember, Some(user.sender.to_string()));
|
||||||
let caller = auth_events.get(&key);
|
let sender = auth_events.get(&key);
|
||||||
|
let sender_membership =
|
||||||
let caller_in_room = caller.is_some() && check_membership(caller, MembershipState::Join);
|
sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
|
||||||
let caller_invited = caller.is_some() && check_membership(caller, MembershipState::Invite);
|
Ok(pdu
|
||||||
|
.deserialize_content::<room::member::MemberEventContent>()?
|
||||||
|
.membership)
|
||||||
|
})?;
|
||||||
|
|
||||||
let key = (EventType::RoomMember, Some(target_user_id.to_string()));
|
let key = (EventType::RoomMember, Some(target_user_id.to_string()));
|
||||||
let target = auth_events.get(&key);
|
let current = auth_events.get(&key);
|
||||||
|
let current_membership =
|
||||||
|
current.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
|
||||||
|
Ok(pdu
|
||||||
|
.deserialize_content::<room::member::MemberEventContent>()?
|
||||||
|
.membership)
|
||||||
|
})?;
|
||||||
|
|
||||||
let target_in_room = target.is_some() && check_membership(target, MembershipState::Join);
|
let key = (EventType::RoomPowerLevels, Some("".into()));
|
||||||
let target_banned = target.is_some() && check_membership(target, MembershipState::Ban);
|
let power_levels = auth_events.get(&key).map_or_else(
|
||||||
|
|| {
|
||||||
|
Ok::<_, Error>(power_levels::PowerLevelsEventContent {
|
||||||
|
ban: 50.into(),
|
||||||
|
events: BTreeMap::new(),
|
||||||
|
events_default: 0.into(),
|
||||||
|
invite: 50.into(),
|
||||||
|
kick: 50.into(),
|
||||||
|
redact: 50.into(),
|
||||||
|
state_default: 0.into(),
|
||||||
|
users: BTreeMap::new(),
|
||||||
|
users_default: 0.into(),
|
||||||
|
notifications: ruma::events::room::power_levels::NotificationPowerLevels {
|
||||||
|
room: 50.into(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|power_levels| {
|
||||||
|
power_levels
|
||||||
|
.deserialize_content::<PowerLevelsEventContent>()
|
||||||
|
.map_err(Into::into)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let sender_power = power_levels.users.get(&user.sender).map_or_else(
|
||||||
|
|| {
|
||||||
|
if sender_membership != member::MembershipState::Join {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(&power_levels.users_default)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// If it's okay, wrap with Some(_)
|
||||||
|
Some,
|
||||||
|
);
|
||||||
|
let target_power = power_levels.users.get(&target_user_id).map_or_else(
|
||||||
|
|| {
|
||||||
|
if target_membership != member::MembershipState::Join {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(&power_levels.users_default)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// If it's okay, wrap with Some(_)
|
||||||
|
Some,
|
||||||
|
);
|
||||||
|
|
||||||
let key = (EventType::RoomJoinRules, Some("".to_string()));
|
let key = (EventType::RoomJoinRules, Some("".to_string()));
|
||||||
let join_rules_event = auth_events.get(&key);
|
let join_rules_event = auth_events.get(&key);
|
||||||
|
let mut join_rules = JoinRule::Invite;
|
||||||
let mut join_rule = JoinRule::Invite;
|
|
||||||
if let Some(jr) = join_rules_event {
|
if let Some(jr) = join_rules_event {
|
||||||
join_rule = jr
|
join_rules = jr
|
||||||
.deserialize_content::<room::join_rules::JoinRulesEventContent>()
|
.deserialize_content::<room::join_rules::JoinRulesEventContent>()?
|
||||||
.ok()? // TODO these are errors? and should be treated as a DB failure?
|
|
||||||
.join_rule;
|
.join_rule;
|
||||||
}
|
}
|
||||||
|
|
||||||
let user_level = get_user_power_level(user.sender, auth_events);
|
if let Some(prev) = prev_event {
|
||||||
let target_level = get_user_power_level(&target_user_id, auth_events);
|
if prev.kind() == EventType::RoomCreate && prev.prev_event_ids().is_empty() {
|
||||||
|
return Ok(true);
|
||||||
// synapse has a not "what to do for default here 50"
|
|
||||||
let ban_level = get_named_level(auth_events, "ban", 50);
|
|
||||||
|
|
||||||
// TODO clean this up
|
|
||||||
tracing::debug!(
|
|
||||||
"_is_membership_change_allowed: {}",
|
|
||||||
serde_json::to_string_pretty(&json!({
|
|
||||||
"caller_in_room": caller_in_room,
|
|
||||||
"caller_invited": caller_invited,
|
|
||||||
"target_banned": target_banned,
|
|
||||||
"target_in_room": target_in_room,
|
|
||||||
"membership": membership,
|
|
||||||
"join_rule": join_rule,
|
|
||||||
"target_user_id": target_user_id,
|
|
||||||
"event.user_id": user.sender,
|
|
||||||
}))
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
if membership == MembershipState::Invite && content.third_party_invite.is_some() {
|
|
||||||
// TODO this is unimpled
|
|
||||||
if !verify_third_party_invite(&user, auth_events) {
|
|
||||||
tracing::warn!("not invited to this room",);
|
|
||||||
return Some(false);
|
|
||||||
}
|
|
||||||
if target_banned {
|
|
||||||
tracing::warn!("banned from this room",);
|
|
||||||
return Some(false);
|
|
||||||
}
|
|
||||||
tracing::info!("invite succeded");
|
|
||||||
return Some(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
if membership != MembershipState::Join {
|
|
||||||
if caller_invited && membership == MembershipState::Leave && &target_user_id == user.sender
|
|
||||||
{
|
|
||||||
tracing::warn!("join event succeded");
|
|
||||||
return Some(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
if !caller_in_room {
|
|
||||||
tracing::warn!("user is not in this room {}", user.room_id.as_str(),);
|
|
||||||
return Some(false); // caller is not joined
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if membership == MembershipState::Invite {
|
Ok(if target_membership == MembershipState::Join {
|
||||||
if target_banned {
|
|
||||||
tracing::warn!("target has been banned");
|
|
||||||
return Some(false);
|
|
||||||
} else if target_in_room {
|
|
||||||
tracing::warn!("already in room");
|
|
||||||
return Some(false); // already in room
|
|
||||||
} else {
|
|
||||||
let invite_level = get_named_level(auth_events, "invite", 0);
|
|
||||||
if user_level < invite_level {
|
|
||||||
return Some(false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if membership == MembershipState::Join {
|
|
||||||
if user.sender != &target_user_id {
|
if user.sender != &target_user_id {
|
||||||
tracing::warn!("cannot force another user to join");
|
false
|
||||||
return Some(false); // cannot force another user to join
|
} else if let MembershipState::Ban = current_membership {
|
||||||
} else if target_banned {
|
false
|
||||||
tracing::warn!("cannot join when banned");
|
|
||||||
return Some(false); // cannot joined when banned
|
|
||||||
} else if join_rule == JoinRule::Public {
|
|
||||||
tracing::info!("join rule public")
|
|
||||||
// pass
|
|
||||||
} else if join_rule == JoinRule::Invite {
|
|
||||||
if !caller_in_room && !caller_invited {
|
|
||||||
tracing::warn!("user has not been invited to this room");
|
|
||||||
return Some(false); // you are not invited to this room
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
tracing::warn!("the join rule is Private or yet to be spec'ed by Matrix");
|
join_rules == JoinRule::Invite
|
||||||
// synapse has 2 TODO's may_join list and private rooms
|
&& (current_membership == MembershipState::Join
|
||||||
|
|| current_membership == MembershipState::Invite)
|
||||||
// the join_rule is Private or Knock which means it is not yet spec'ed
|
|| join_rules == JoinRule::Public
|
||||||
return Some(false);
|
|
||||||
}
|
}
|
||||||
} else if membership == MembershipState::Leave {
|
} else if target_membership == MembershipState::Invite {
|
||||||
if target_banned && user_level < ban_level {
|
if let Some(_tp_id) = content.third_party_invite {
|
||||||
tracing::warn!("not enough power to unban");
|
if current_membership == MembershipState::Ban {
|
||||||
return Some(false); // you cannot unban this user
|
false
|
||||||
} else if &target_user_id != user.sender {
|
} else {
|
||||||
let kick_level = get_named_level(auth_events, "kick", 50);
|
// TODO this is not filled out
|
||||||
|
verify_third_party_invite(&user, auth_events)
|
||||||
if user_level < kick_level || user_level <= target_level {
|
|
||||||
tracing::warn!("not enough power to kick user");
|
|
||||||
return Some(false); // you do not have the power to kick user
|
|
||||||
}
|
}
|
||||||
|
} else if sender_membership != MembershipState::Join
|
||||||
|
|| current_membership == MembershipState::Join
|
||||||
|
|| current_membership == MembershipState::Ban
|
||||||
|
{
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
sender_power
|
||||||
|
.filter(|&p| p >= &power_levels.invite)
|
||||||
|
.is_some()
|
||||||
}
|
}
|
||||||
} else if membership == MembershipState::Ban {
|
} else if target_membership == MembershipState::Leave {
|
||||||
tracing::debug!(
|
if user.sender == &target_user_id {
|
||||||
"{} < {} || {} <= {}",
|
current_membership == MembershipState::Join
|
||||||
user_level,
|
|| current_membership == MembershipState::Invite
|
||||||
ban_level,
|
} else if sender_membership != MembershipState::Join
|
||||||
user_level,
|
|| current_membership == MembershipState::Ban
|
||||||
target_level
|
&& sender_power.filter(|&p| p < &power_levels.ban).is_some()
|
||||||
);
|
{
|
||||||
if user_level < ban_level || user_level <= target_level {
|
false
|
||||||
tracing::warn!("not enough power to ban");
|
} else {
|
||||||
return Some(false);
|
sender_power.filter(|&p| p >= &power_levels.kick).is_some()
|
||||||
|
&& target_power < sender_power
|
||||||
|
}
|
||||||
|
} else if target_membership == MembershipState::Ban {
|
||||||
|
if sender_membership != MembershipState::Join {
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
sender_power.filter(|&p| p >= &power_levels.ban).is_some()
|
||||||
|
&& target_power < sender_power
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tracing::warn!("unknown membership status");
|
false
|
||||||
// Unknown membership status
|
})
|
||||||
return Some(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(true)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Is the event's sender in the room that they sent the event to.
|
/// Is the event's sender in the room that they sent the event to.
|
||||||
///
|
|
||||||
/// A return value of None is not a failure
|
|
||||||
pub fn check_event_sender_in_room(
|
pub fn check_event_sender_in_room(
|
||||||
event: &StateEvent,
|
sender: &UserId,
|
||||||
auth_events: &StateMap<StateEvent>,
|
auth_events: &StateMap<StateEvent>,
|
||||||
) -> Option<bool> {
|
) -> Option<bool> {
|
||||||
let mem = auth_events.get(&(EventType::RoomMember, Some(event.sender().to_string())))?;
|
let mem = auth_events.get(&(EventType::RoomMember, Some(sender.to_string())))?;
|
||||||
// TODO this is check_membership a helper fn in synapse but it does this
|
// TODO this is check_membership a helper fn in synapse but it does this
|
||||||
Some(
|
Some(
|
||||||
mem.deserialize_content::<room::member::MemberEventContent>()
|
mem.deserialize_content::<room::member::MemberEventContent>()
|
||||||
@ -462,30 +471,31 @@ pub fn check_event_sender_in_room(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Is the user allowed to send a specific event based on the rooms power levels.
|
/// Is the user allowed to send a specific event based on the rooms power levels. Does the event
|
||||||
pub fn can_send_event(event: &StateEvent, auth_events: &StateMap<StateEvent>) -> Option<bool> {
|
/// have the correct userId as it's state_key if it's not the "" state_key.
|
||||||
|
pub fn can_send_event(event: &StateEvent, auth_events: &StateMap<StateEvent>) -> Result<bool> {
|
||||||
let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into())));
|
let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into())));
|
||||||
|
|
||||||
let send_level = get_send_level(event.kind(), event.state_key(), ple);
|
let event_type_power_level = get_send_level(event.kind(), event.state_key(), ple);
|
||||||
let user_level = get_user_power_level(event.sender(), auth_events);
|
let user_level = get_user_power_level(event.sender(), auth_events);
|
||||||
|
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
"{} snd {} usr {}",
|
"{} ev_type {} usr {}",
|
||||||
event.event_id().to_string(),
|
event.event_id().to_string(),
|
||||||
send_level,
|
event_type_power_level,
|
||||||
user_level
|
user_level
|
||||||
);
|
);
|
||||||
|
|
||||||
if user_level < send_level {
|
if user_level < event_type_power_level {
|
||||||
return Some(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(sk) = event.state_key() {
|
if let Some(sk) = event.state_key() {
|
||||||
if sk.starts_with('@') && sk != event.sender().as_str() {
|
if sk.starts_with('@') && sk != event.sender().as_str() {
|
||||||
return Some(false); // permission required to post in this room
|
return Ok(false); // permission required to post in this room
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(true)
|
Ok(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Confirm that the event sender has the required power levels.
|
/// Confirm that the event sender has the required power levels.
|
||||||
@ -494,17 +504,17 @@ pub fn check_power_levels(
|
|||||||
power_event: &StateEvent,
|
power_event: &StateEvent,
|
||||||
auth_events: &StateMap<StateEvent>,
|
auth_events: &StateMap<StateEvent>,
|
||||||
) -> Option<bool> {
|
) -> Option<bool> {
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
let key = (power_event.kind(), power_event.state_key());
|
let key = (power_event.kind(), power_event.state_key());
|
||||||
|
|
||||||
let current_state = if let Some(current_state) = auth_events.get(&key) {
|
let current_state = if let Some(current_state) = auth_events.get(&key) {
|
||||||
current_state
|
current_state
|
||||||
} else {
|
} else {
|
||||||
// TODO synapse returns here, shouldn't this be an error ??
|
// If there is no previous m.room.power_levels event in the room, allow
|
||||||
return Some(true);
|
return Some(true);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// If users key in content is not a dictionary with keys that are valid user IDs
|
||||||
|
// with values that are integers (or a string that is an integer), reject.
|
||||||
let user_content = power_event
|
let user_content = power_event
|
||||||
.deserialize_content::<room::power_levels::PowerLevelsEventContent>()
|
.deserialize_content::<room::power_levels::PowerLevelsEventContent>()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -520,7 +530,7 @@ pub fn check_power_levels(
|
|||||||
let mut user_levels_to_check = btreeset![];
|
let mut user_levels_to_check = btreeset![];
|
||||||
let old_list = ¤t_content.users;
|
let old_list = ¤t_content.users;
|
||||||
let user_list = &user_content.users;
|
let user_list = &user_content.users;
|
||||||
for user in old_list.keys().chain(user_list.keys()).dedup() {
|
for user in old_list.keys().chain(user_list.keys()) {
|
||||||
let user: &UserId = user;
|
let user: &UserId = user;
|
||||||
user_levels_to_check.insert(user);
|
user_levels_to_check.insert(user);
|
||||||
}
|
}
|
||||||
@ -530,7 +540,7 @@ pub fn check_power_levels(
|
|||||||
let mut event_levels_to_check = btreeset![];
|
let mut event_levels_to_check = btreeset![];
|
||||||
let old_list = ¤t_content.events;
|
let old_list = ¤t_content.events;
|
||||||
let new_list = &user_content.events;
|
let new_list = &user_content.events;
|
||||||
for ev_id in old_list.keys().chain(new_list.keys()).dedup() {
|
for ev_id in old_list.keys().chain(new_list.keys()) {
|
||||||
let ev_id: &EventType = ev_id;
|
let ev_id: &EventType = ev_id;
|
||||||
event_levels_to_check.insert(ev_id);
|
event_levels_to_check.insert(ev_id);
|
||||||
}
|
}
|
||||||
@ -637,27 +647,26 @@ pub fn check_redaction(
|
|||||||
room_version: &RoomVersionId,
|
room_version: &RoomVersionId,
|
||||||
redaction_event: &StateEvent,
|
redaction_event: &StateEvent,
|
||||||
auth_events: &StateMap<StateEvent>,
|
auth_events: &StateMap<StateEvent>,
|
||||||
) -> Option<RedactAllowed> {
|
) -> Result<RedactAllowed> {
|
||||||
let user_level = get_user_power_level(redaction_event.sender(), auth_events);
|
let user_level = get_user_power_level(redaction_event.sender(), auth_events);
|
||||||
let redact_level = get_named_level(auth_events, "redact", 50);
|
let redact_level = get_named_level(auth_events, "redact", 50);
|
||||||
|
|
||||||
if user_level >= redact_level {
|
if user_level >= redact_level {
|
||||||
return Some(RedactAllowed::CanRedact);
|
tracing::info!("redaction allowed via power levels");
|
||||||
|
return Ok(RedactAllowed::CanRedact);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let RoomVersionId::Version1 = room_version {
|
if let RoomVersionId::Version1 = room_version {
|
||||||
// are the redacter and redactee in the same domain
|
// are the redacter and redactee in the same domain
|
||||||
if Some(redaction_event.event_id().server_name())
|
if Some(redaction_event.sender().server_name())
|
||||||
== redaction_event.redacts().map(|id| id.server_name())
|
== redaction_event.redacts().and_then(|id| id.server_name())
|
||||||
{
|
{
|
||||||
return Some(RedactAllowed::OwnEvent);
|
tracing::info!("redaction event allowed via room version 1 rules");
|
||||||
|
return Ok(RedactAllowed::OwnEvent);
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// TODO synapse has this line also
|
|
||||||
// event.internal_metadata.recheck_redaction = True
|
|
||||||
return Some(RedactAllowed::OwnEvent);
|
|
||||||
}
|
}
|
||||||
Some(RedactAllowed::No)
|
|
||||||
|
Ok(RedactAllowed::No)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check that the member event matches `state`.
|
/// Check that the member event matches `state`.
|
||||||
@ -677,6 +686,20 @@ pub fn check_membership(member_event: Option<&StateEvent>, state: MembershipStat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Can this room federate based on its m.room.create event.
|
||||||
|
pub fn can_federate(auth_events: &StateMap<StateEvent>) -> bool {
|
||||||
|
let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into())));
|
||||||
|
if let Some(ev) = creation_event {
|
||||||
|
if let Some(fed) = ev.content().get("m.federate") {
|
||||||
|
fed == "true"
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content.
|
/// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content.
|
||||||
/// or return `default` if no power level event is found or zero if no field matches `name`.
|
/// or return `default` if no power level event is found or zero if no field matches `name`.
|
||||||
pub fn get_named_level(auth_events: &StateMap<StateEvent>, name: &str, default: i64) -> i64 {
|
pub fn get_named_level(auth_events: &StateMap<StateEvent>, name: &str, default: i64) -> i64 {
|
||||||
|
48
src/lib.rs
48
src/lib.rs
@ -26,11 +26,6 @@ pub use state_store::StateStore;
|
|||||||
// yielding to reactor during loops every N iterations.
|
// yielding to reactor during loops every N iterations.
|
||||||
const _YIELD_AFTER_ITERATIONS: usize = 100;
|
const _YIELD_AFTER_ITERATIONS: usize = 100;
|
||||||
|
|
||||||
pub enum ResolutionResult {
|
|
||||||
Conflicted(Vec<StateMap<EventId>>),
|
|
||||||
Resolved(StateMap<EventId>),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A mapping of event type and state_key to some value `T`, usually an `EventId`.
|
/// A mapping of event type and state_key to some value `T`, usually an `EventId`.
|
||||||
pub type StateMap<T> = BTreeMap<(EventType, Option<String>), T>;
|
pub type StateMap<T> = BTreeMap<(EventType, Option<String>), T>;
|
||||||
|
|
||||||
@ -63,7 +58,7 @@ impl StateResolution {
|
|||||||
event_map: Option<EventMap<StateEvent>>,
|
event_map: Option<EventMap<StateEvent>>,
|
||||||
store: &dyn StateStore,
|
store: &dyn StateStore,
|
||||||
// TODO actual error handling (`thiserror`??)
|
// TODO actual error handling (`thiserror`??)
|
||||||
) -> Result<ResolutionResult> {
|
) -> Result<StateMap<EventId>> {
|
||||||
tracing::info!("State resolution starting");
|
tracing::info!("State resolution starting");
|
||||||
|
|
||||||
let mut event_map = if let Some(ev_map) = event_map {
|
let mut event_map = if let Some(ev_map) = event_map {
|
||||||
@ -78,7 +73,7 @@ impl StateResolution {
|
|||||||
|
|
||||||
if conflicting.is_empty() {
|
if conflicting.is_empty() {
|
||||||
tracing::info!("no conflicting state found");
|
tracing::info!("no conflicting state found");
|
||||||
return Ok(ResolutionResult::Resolved(clean));
|
return Ok(clean);
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!("{} conflicting events", conflicting.len());
|
tracing::info!("{} conflicting events", conflicting.len());
|
||||||
@ -119,7 +114,7 @@ impl StateResolution {
|
|||||||
|
|
||||||
for event in event_map.values() {
|
for event in event_map.values() {
|
||||||
if event.room_id() != Some(room_id) {
|
if event.room_id() != Some(room_id) {
|
||||||
return Err(Error::TempString(format!(
|
return Err(Error::InvalidPdu(format!(
|
||||||
"resolving event {} in room {}, when correct room is {}",
|
"resolving event {} in room {}, when correct room is {}",
|
||||||
event.event_id(),
|
event.event_id(),
|
||||||
event.room_id().map(|id| id.as_str()).unwrap_or("`unknown`"),
|
event.room_id().map(|id| id.as_str()).unwrap_or("`unknown`"),
|
||||||
@ -227,7 +222,7 @@ impl StateResolution {
|
|||||||
// add unconflicted state to the resolved state
|
// add unconflicted state to the resolved state
|
||||||
resolved_state.extend(clean);
|
resolved_state.extend(clean);
|
||||||
|
|
||||||
Ok(ResolutionResult::Resolved(resolved_state))
|
Ok(resolved_state)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Split the events that have no conflicts from those that are conflicting.
|
/// Split the events that have no conflicts from those that are conflicting.
|
||||||
@ -288,16 +283,14 @@ impl StateResolution {
|
|||||||
|
|
||||||
tracing::debug!("calculating auth chain difference");
|
tracing::debug!("calculating auth chain difference");
|
||||||
|
|
||||||
store
|
store.auth_chain_diff(
|
||||||
.auth_chain_diff(
|
room_id,
|
||||||
room_id,
|
state_sets
|
||||||
state_sets
|
.iter()
|
||||||
.iter()
|
.map(|map| map.values().cloned().collect())
|
||||||
.map(|map| map.values().cloned().collect())
|
.dedup()
|
||||||
.dedup()
|
.collect::<Vec<_>>(),
|
||||||
.collect::<Vec<_>>(),
|
)
|
||||||
)
|
|
||||||
.map_err(Error::TempString)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Events are sorted from "earliest" to "latest". They are compared using
|
/// Events are sorted from "earliest" to "latest". They are compared using
|
||||||
@ -553,10 +546,19 @@ impl StateResolution {
|
|||||||
|
|
||||||
tracing::debug!("event to check {:?}", event.event_id().to_string());
|
tracing::debug!("event to check {:?}", event.event_id().to_string());
|
||||||
|
|
||||||
if event_auth::auth_check(room_version, &event, auth_events, false)
|
let most_recent_prev_event = event
|
||||||
.ok_or_else(|| "Auth check failed due to deserialization most likely".to_string())
|
.prev_event_ids()
|
||||||
.map_err(Error::TempString)?
|
.iter()
|
||||||
{
|
.filter_map(|id| StateResolution::get_or_load_event(room_id, id, event_map, store))
|
||||||
|
.next_back();
|
||||||
|
|
||||||
|
if event_auth::auth_check(
|
||||||
|
room_version,
|
||||||
|
&event,
|
||||||
|
most_recent_prev_event.as_ref(),
|
||||||
|
auth_events,
|
||||||
|
false,
|
||||||
|
)? {
|
||||||
// add event to resolved state map
|
// add event to resolved state map
|
||||||
resolved_state.insert((event.kind(), event.state_key()), event_id.clone());
|
resolved_state.insert((event.kind(), event.state_key()), event_id.clone());
|
||||||
} else {
|
} else {
|
||||||
|
@ -2,18 +2,14 @@ use std::collections::BTreeSet;
|
|||||||
|
|
||||||
use ruma::identifiers::{EventId, RoomId};
|
use ruma::identifiers::{EventId, RoomId};
|
||||||
|
|
||||||
use crate::StateEvent;
|
use crate::{Result, StateEvent};
|
||||||
|
|
||||||
pub trait StateStore {
|
pub trait StateStore {
|
||||||
/// Return a single event based on the EventId.
|
/// Return a single event based on the EventId.
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String>;
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent>;
|
||||||
|
|
||||||
/// Returns the events that correspond to the `event_ids` sorted in the same order.
|
/// Returns the events that correspond to the `event_ids` sorted in the same order.
|
||||||
fn get_events(
|
fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<StateEvent>> {
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: &[EventId],
|
|
||||||
) -> Result<Vec<StateEvent>, String> {
|
|
||||||
let mut events = vec![];
|
let mut events = vec![];
|
||||||
for id in event_ids {
|
for id in event_ids {
|
||||||
events.push(self.get_event(room_id, id)?);
|
events.push(self.get_event(room_id, id)?);
|
||||||
@ -22,11 +18,7 @@ pub trait StateStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Vec of the related auth events to the given `event`.
|
/// Returns a Vec of the related auth events to the given `event`.
|
||||||
fn auth_event_ids(
|
fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<EventId>> {
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: &[EventId],
|
|
||||||
) -> Result<Vec<EventId>, String> {
|
|
||||||
let mut result = vec![];
|
let mut result = vec![];
|
||||||
let mut stack = event_ids.to_vec();
|
let mut stack = event_ids.to_vec();
|
||||||
|
|
||||||
@ -52,7 +44,7 @@ pub trait StateStore {
|
|||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
event_ids: Vec<Vec<EventId>>,
|
event_ids: Vec<Vec<EventId>>,
|
||||||
) -> Result<Vec<EventId>, String> {
|
) -> Result<Vec<EventId>> {
|
||||||
let mut chains = vec![];
|
let mut chains = vec![];
|
||||||
for ids in event_ids {
|
for ids in event_ids {
|
||||||
// TODO state store `auth_event_ids` returns self in the event ids list
|
// TODO state store `auth_event_ids` returns self in the event ids list
|
||||||
|
@ -0,0 +1,289 @@
|
|||||||
|
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom};
|
||||||
|
|
||||||
|
use ruma::{
|
||||||
|
events::{
|
||||||
|
pdu::EventHash,
|
||||||
|
room::{
|
||||||
|
join_rules::JoinRule,
|
||||||
|
member::{MemberEventContent, MembershipState},
|
||||||
|
},
|
||||||
|
EventType,
|
||||||
|
},
|
||||||
|
identifiers::{EventId, RoomId, UserId},
|
||||||
|
};
|
||||||
|
use serde_json::{json, Value as JsonValue};
|
||||||
|
#[rustfmt::skip] // this deletes the comments for some reason yay!
|
||||||
|
use state_res::{
|
||||||
|
event_auth::{
|
||||||
|
// auth_check, auth_types_for_event, can_federate, check_power_levels, check_redaction,
|
||||||
|
valid_membership_change,
|
||||||
|
},
|
||||||
|
Requester, StateEvent, StateMap, StateStore, Result, Error
|
||||||
|
};
|
||||||
|
use tracing_subscriber as tracer;
|
||||||
|
|
||||||
|
use std::sync::Once;
|
||||||
|
|
||||||
|
static LOGGER: Once = Once::new();
|
||||||
|
|
||||||
|
static mut SERVER_TIMESTAMP: i32 = 0;
|
||||||
|
|
||||||
|
fn event_id(id: &str) -> EventId {
|
||||||
|
if id.contains('$') {
|
||||||
|
return EventId::try_from(id).unwrap();
|
||||||
|
}
|
||||||
|
EventId::try_from(format!("${}:foo", id)).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn alice() -> UserId {
|
||||||
|
UserId::try_from("@alice:foo").unwrap()
|
||||||
|
}
|
||||||
|
fn bob() -> UserId {
|
||||||
|
UserId::try_from("@bob:foo").unwrap()
|
||||||
|
}
|
||||||
|
fn charlie() -> UserId {
|
||||||
|
UserId::try_from("@charlie:foo").unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn room_id() -> RoomId {
|
||||||
|
RoomId::try_from("!test:foo").unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn member_content_ban() -> JsonValue {
|
||||||
|
serde_json::to_value(MemberEventContent {
|
||||||
|
membership: MembershipState::Ban,
|
||||||
|
displayname: None,
|
||||||
|
avatar_url: None,
|
||||||
|
is_direct: None,
|
||||||
|
third_party_invite: None,
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn member_content_join() -> JsonValue {
|
||||||
|
serde_json::to_value(MemberEventContent {
|
||||||
|
membership: MembershipState::Join,
|
||||||
|
displayname: None,
|
||||||
|
avatar_url: None,
|
||||||
|
is_direct: None,
|
||||||
|
third_party_invite: None,
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
impl StateStore for TestStore {
|
||||||
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||||
|
self.0
|
||||||
|
.borrow()
|
||||||
|
.get(event_id)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_pdu_event<S>(
|
||||||
|
id: &str,
|
||||||
|
sender: UserId,
|
||||||
|
ev_type: EventType,
|
||||||
|
state_key: Option<&str>,
|
||||||
|
content: JsonValue,
|
||||||
|
auth_events: &[S],
|
||||||
|
prev_events: &[S],
|
||||||
|
) -> StateEvent
|
||||||
|
where
|
||||||
|
S: AsRef<str>,
|
||||||
|
{
|
||||||
|
let ts = unsafe {
|
||||||
|
let ts = SERVER_TIMESTAMP;
|
||||||
|
// increment the "origin_server_ts" value
|
||||||
|
SERVER_TIMESTAMP += 1;
|
||||||
|
ts
|
||||||
|
};
|
||||||
|
let id = if id.contains('$') {
|
||||||
|
id.to_string()
|
||||||
|
} else {
|
||||||
|
format!("${}:foo", id)
|
||||||
|
};
|
||||||
|
let auth_events = auth_events
|
||||||
|
.iter()
|
||||||
|
.map(AsRef::as_ref)
|
||||||
|
.map(event_id)
|
||||||
|
.map(|id| {
|
||||||
|
(
|
||||||
|
id,
|
||||||
|
EventHash {
|
||||||
|
sha256: "hello".into(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let prev_events = prev_events
|
||||||
|
.iter()
|
||||||
|
.map(AsRef::as_ref)
|
||||||
|
.map(event_id)
|
||||||
|
.map(|id| {
|
||||||
|
(
|
||||||
|
id,
|
||||||
|
EventHash {
|
||||||
|
sha256: "hello".into(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let json = if let Some(state_key) = state_key {
|
||||||
|
json!({
|
||||||
|
"auth_events": auth_events,
|
||||||
|
"prev_events": prev_events,
|
||||||
|
"event_id": id,
|
||||||
|
"sender": sender,
|
||||||
|
"type": ev_type,
|
||||||
|
"state_key": state_key,
|
||||||
|
"content": content,
|
||||||
|
"origin_server_ts": ts,
|
||||||
|
"room_id": room_id(),
|
||||||
|
"origin": "foo",
|
||||||
|
"depth": 0,
|
||||||
|
"hashes": { "sha256": "hello" },
|
||||||
|
"signatures": {},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
json!({
|
||||||
|
"auth_events": auth_events,
|
||||||
|
"prev_events": prev_events,
|
||||||
|
"event_id": id,
|
||||||
|
"sender": sender,
|
||||||
|
"type": ev_type,
|
||||||
|
"content": content,
|
||||||
|
"origin_server_ts": ts,
|
||||||
|
"room_id": room_id(),
|
||||||
|
"origin": "foo",
|
||||||
|
"depth": 0,
|
||||||
|
"hashes": { "sha256": "hello" },
|
||||||
|
"signatures": {},
|
||||||
|
})
|
||||||
|
};
|
||||||
|
serde_json::from_value(json).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
// all graphs start with these input events
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
fn INITIAL_EVENTS() -> BTreeMap<EventId, StateEvent> {
|
||||||
|
// this is always called so we can init the logger here
|
||||||
|
let _ = LOGGER.call_once(|| {
|
||||||
|
tracer::fmt()
|
||||||
|
.with_env_filter(tracer::EnvFilter::from_default_env())
|
||||||
|
.init()
|
||||||
|
});
|
||||||
|
|
||||||
|
vec![
|
||||||
|
to_pdu_event::<EventId>(
|
||||||
|
"CREATE",
|
||||||
|
alice(),
|
||||||
|
EventType::RoomCreate,
|
||||||
|
Some(""),
|
||||||
|
json!({ "creator": alice() }),
|
||||||
|
&[],
|
||||||
|
&[],
|
||||||
|
),
|
||||||
|
to_pdu_event(
|
||||||
|
"IMA",
|
||||||
|
alice(),
|
||||||
|
EventType::RoomMember,
|
||||||
|
Some(alice().to_string().as_str()),
|
||||||
|
member_content_join(),
|
||||||
|
&["CREATE"],
|
||||||
|
&["CREATE"],
|
||||||
|
),
|
||||||
|
to_pdu_event(
|
||||||
|
"IPOWER",
|
||||||
|
alice(),
|
||||||
|
EventType::RoomPowerLevels,
|
||||||
|
Some(""),
|
||||||
|
json!({"users": {alice().to_string(): 100}}),
|
||||||
|
&["CREATE", "IMA"],
|
||||||
|
&["IMA"],
|
||||||
|
),
|
||||||
|
to_pdu_event(
|
||||||
|
"IJR",
|
||||||
|
alice(),
|
||||||
|
EventType::RoomJoinRules,
|
||||||
|
Some(""),
|
||||||
|
json!({ "join_rule": JoinRule::Public }),
|
||||||
|
&["CREATE", "IMA", "IPOWER"],
|
||||||
|
&["IPOWER"],
|
||||||
|
),
|
||||||
|
to_pdu_event(
|
||||||
|
"IMB",
|
||||||
|
bob(),
|
||||||
|
EventType::RoomMember,
|
||||||
|
Some(bob().to_string().as_str()),
|
||||||
|
member_content_join(),
|
||||||
|
&["CREATE", "IJR", "IPOWER"],
|
||||||
|
&["IJR"],
|
||||||
|
),
|
||||||
|
to_pdu_event(
|
||||||
|
"IMC",
|
||||||
|
charlie(),
|
||||||
|
EventType::RoomMember,
|
||||||
|
Some(charlie().to_string().as_str()),
|
||||||
|
member_content_join(),
|
||||||
|
&["CREATE", "IJR", "IPOWER"],
|
||||||
|
&["IMB"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.map(|ev| (ev.event_id(), ev))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ban_pass() {
|
||||||
|
let events = INITIAL_EVENTS();
|
||||||
|
|
||||||
|
let prev = events
|
||||||
|
.values()
|
||||||
|
.find(|ev| ev.event_id().as_str().contains("IMC"));
|
||||||
|
|
||||||
|
let auth_events = events
|
||||||
|
.values()
|
||||||
|
.map(|ev| ((ev.kind(), ev.state_key()), ev.clone()))
|
||||||
|
.collect::<StateMap<_>>();
|
||||||
|
|
||||||
|
let requester = Requester {
|
||||||
|
prev_event_ids: vec![event_id("IMC")],
|
||||||
|
room_id: &room_id(),
|
||||||
|
content: &member_content_ban(),
|
||||||
|
state_key: Some(charlie().to_string()),
|
||||||
|
sender: &alice(),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(valid_membership_change(requester, prev, &auth_events).unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ban_fail() {
|
||||||
|
let events = INITIAL_EVENTS();
|
||||||
|
|
||||||
|
let prev = events
|
||||||
|
.values()
|
||||||
|
.find(|ev| ev.event_id().as_str().contains("IMC"));
|
||||||
|
|
||||||
|
let auth_events = events
|
||||||
|
.values()
|
||||||
|
.map(|ev| ((ev.kind(), ev.state_key()), ev.clone()))
|
||||||
|
.collect::<StateMap<_>>();
|
||||||
|
|
||||||
|
let requester = Requester {
|
||||||
|
prev_event_ids: vec![event_id("IMC")],
|
||||||
|
room_id: &room_id(),
|
||||||
|
content: &member_content_ban(),
|
||||||
|
state_key: Some(alice().to_string()),
|
||||||
|
sender: &charlie(),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(!valid_membership_change(requester, prev, &auth_events).unwrap())
|
||||||
|
}
|
@ -12,7 +12,7 @@ use ruma::{
|
|||||||
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
||||||
};
|
};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
use state_res::{StateEvent, StateMap, StateStore};
|
use state_res::{Error, Result, StateEvent, StateMap, StateStore};
|
||||||
use tracing_subscriber as tracer;
|
use tracing_subscriber as tracer;
|
||||||
|
|
||||||
use std::sync::Once;
|
use std::sync::Once;
|
||||||
@ -57,12 +57,12 @@ pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
|||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
impl StateStore for TestStore {
|
impl StateStore for TestStore {
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String> {
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||||
self.0
|
self.0
|
||||||
.borrow()
|
.borrow()
|
||||||
.get(event_id)
|
.get(event_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or(format!("{} not found", event_id.to_string()))
|
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ use ruma::{
|
|||||||
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
||||||
};
|
};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore};
|
use state_res::{Error, Result, StateEvent, StateMap, StateResolution, StateStore};
|
||||||
use tracing_subscriber as tracer;
|
use tracing_subscriber as tracer;
|
||||||
|
|
||||||
static LOGGER: Once = Once::new();
|
static LOGGER: Once = Once::new();
|
||||||
@ -108,17 +108,7 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
|
|||||||
&store,
|
&store,
|
||||||
);
|
);
|
||||||
match resolved {
|
match resolved {
|
||||||
Ok(ResolutionResult::Resolved(state)) => state,
|
Ok(state) => state,
|
||||||
Ok(ResolutionResult::Conflicted(state)) => panic!(
|
|
||||||
"conflicted: {:?}",
|
|
||||||
state
|
|
||||||
.iter()
|
|
||||||
.map(|map| map
|
|
||||||
.iter()
|
|
||||||
.map(|(key, id)| (key, id.to_string()))
|
|
||||||
.collect::<Vec<_>>())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
),
|
|
||||||
Err(e) => panic!("resolution for {} failed: {}", node, e),
|
Err(e) => panic!("resolution for {} failed: {}", node, e),
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -200,12 +190,12 @@ pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
|||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
impl StateStore for TestStore {
|
impl StateStore for TestStore {
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String> {
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||||
self.0
|
self.0
|
||||||
.borrow()
|
.borrow()
|
||||||
.get(event_id)
|
.get(event_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or(format!("{} not found", event_id.to_string()))
|
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -513,9 +503,8 @@ fn base_with_auth_chains() {
|
|||||||
|
|
||||||
let resolved: BTreeMap<_, EventId> =
|
let resolved: BTreeMap<_, EventId> =
|
||||||
match StateResolution::resolve(&room_id(), &RoomVersionId::Version2, &[], None, &store) {
|
match StateResolution::resolve(&room_id(), &RoomVersionId::Version2, &[], None, &store) {
|
||||||
Ok(ResolutionResult::Resolved(state)) => state,
|
Ok(state) => state,
|
||||||
Err(e) => panic!("{}", e),
|
Err(e) => panic!("{}", e),
|
||||||
_ => panic!("conflicted state left"),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let resolved = resolved
|
let resolved = resolved
|
||||||
@ -583,9 +572,8 @@ fn ban_with_auth_chains2() {
|
|||||||
None,
|
None,
|
||||||
&store,
|
&store,
|
||||||
) {
|
) {
|
||||||
Ok(ResolutionResult::Resolved(state)) => state,
|
Ok(state) => state,
|
||||||
Err(e) => panic!("{}", e),
|
Err(e) => panic!("{}", e),
|
||||||
_ => panic!("conflicted state left"),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
|
@ -1,9 +1,4 @@
|
|||||||
use std::{
|
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH};
|
||||||
cell::RefCell,
|
|
||||||
collections::{BTreeMap, BTreeSet},
|
|
||||||
convert::TryFrom,
|
|
||||||
time::UNIX_EPOCH,
|
|
||||||
};
|
|
||||||
|
|
||||||
use maplit::btreemap;
|
use maplit::btreemap;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
@ -18,7 +13,7 @@ use ruma::{
|
|||||||
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
||||||
};
|
};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore};
|
use state_res::{Error, Result, StateEvent, StateMap, StateResolution, StateStore};
|
||||||
use tracing_subscriber as tracer;
|
use tracing_subscriber as tracer;
|
||||||
|
|
||||||
use std::sync::Once;
|
use std::sync::Once;
|
||||||
@ -378,17 +373,7 @@ fn do_check(events: &[StateEvent], edges: Vec<Vec<EventId>>, expected_state_ids:
|
|||||||
&store,
|
&store,
|
||||||
);
|
);
|
||||||
match resolved {
|
match resolved {
|
||||||
Ok(ResolutionResult::Resolved(state)) => state,
|
Ok(state) => state,
|
||||||
Ok(ResolutionResult::Conflicted(state)) => panic!(
|
|
||||||
"conflicted: {:?}",
|
|
||||||
state
|
|
||||||
.iter()
|
|
||||||
.map(|map| map
|
|
||||||
.iter()
|
|
||||||
.map(|(key, id)| (key, id.to_string()))
|
|
||||||
.collect::<Vec<_>>())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
),
|
|
||||||
Err(e) => panic!("resolution for {} failed: {}", node, e),
|
Err(e) => panic!("resolution for {} failed: {}", node, e),
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -729,9 +714,8 @@ fn test_event_map_none() {
|
|||||||
None,
|
None,
|
||||||
&store,
|
&store,
|
||||||
) {
|
) {
|
||||||
Ok(ResolutionResult::Resolved(state)) => state,
|
Ok(state) => state,
|
||||||
Err(e) => panic!("{}", e),
|
Err(e) => panic!("{}", e),
|
||||||
_ => panic!("conflicted state left"),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(expected, resolved)
|
assert_eq!(expected, resolved)
|
||||||
@ -768,83 +752,12 @@ pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
|||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
impl StateStore for TestStore {
|
impl StateStore for TestStore {
|
||||||
fn get_events(&self, room_id: &RoomId, events: &[EventId]) -> Result<Vec<StateEvent>, String> {
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||||
Ok(self
|
|
||||||
.0
|
|
||||||
.borrow()
|
|
||||||
.iter()
|
|
||||||
.filter(|e| events.contains(e.0))
|
|
||||||
.map(|(_, s)| s)
|
|
||||||
.cloned()
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String> {
|
|
||||||
self.0
|
self.0
|
||||||
.borrow()
|
.borrow()
|
||||||
.get(event_id)
|
.get(event_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or(format!("{} not found", event_id.to_string()))
|
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||||
}
|
|
||||||
|
|
||||||
fn auth_event_ids(
|
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: &[EventId],
|
|
||||||
) -> Result<Vec<EventId>, String> {
|
|
||||||
let mut result = vec![];
|
|
||||||
let mut stack = event_ids.to_vec();
|
|
||||||
|
|
||||||
// DFS for auth event chain
|
|
||||||
while !stack.is_empty() {
|
|
||||||
let ev_id = stack.pop().unwrap();
|
|
||||||
if result.contains(&ev_id) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.push(ev_id.clone());
|
|
||||||
|
|
||||||
let event = self.get_event(room_id, &ev_id).unwrap();
|
|
||||||
|
|
||||||
stack.extend(event.auth_events());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn auth_chain_diff(
|
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: Vec<Vec<EventId>>,
|
|
||||||
) -> Result<Vec<EventId>, String> {
|
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
let mut chains = vec![];
|
|
||||||
for ids in event_ids {
|
|
||||||
// TODO state store `auth_event_ids` returns self in the event ids list
|
|
||||||
// when an event returns `auth_event_ids` self is not contained
|
|
||||||
let chain = self
|
|
||||||
.auth_event_ids(room_id, &ids)?
|
|
||||||
.into_iter()
|
|
||||||
.collect::<BTreeSet<_>>();
|
|
||||||
chains.push(chain);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(chain) = chains.first() {
|
|
||||||
let rest = chains.iter().skip(1).flatten().cloned().collect();
|
|
||||||
let common = chain.intersection(&rest).collect::<Vec<_>>();
|
|
||||||
|
|
||||||
Ok(chains
|
|
||||||
.iter()
|
|
||||||
.flatten()
|
|
||||||
.filter(|id| !common.contains(&id))
|
|
||||||
.cloned()
|
|
||||||
.collect::<BTreeSet<_>>()
|
|
||||||
.into_iter()
|
|
||||||
.collect())
|
|
||||||
} else {
|
|
||||||
Ok(vec![])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user