Use own Error type for all errors
This commit is contained in:
parent
b846aec94a
commit
394d26744a
@ -3,12 +3,7 @@
|
|||||||
// `cargo bench unknown option --save-baseline`.
|
// `cargo bench unknown option --save-baseline`.
|
||||||
// To pass args to criterion, use this form
|
// To pass args to criterion, use this form
|
||||||
// `cargo bench --bench <name of the bench> -- --save-baseline <name>`.
|
// `cargo bench --bench <name of the bench> -- --save-baseline <name>`.
|
||||||
use std::{
|
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH};
|
||||||
cell::RefCell,
|
|
||||||
collections::{BTreeMap, BTreeSet},
|
|
||||||
convert::TryFrom,
|
|
||||||
time::UNIX_EPOCH,
|
|
||||||
};
|
|
||||||
|
|
||||||
use criterion::{criterion_group, criterion_main, Criterion};
|
use criterion::{criterion_group, criterion_main, Criterion};
|
||||||
use maplit::btreemap;
|
use maplit::btreemap;
|
||||||
@ -24,7 +19,9 @@ use ruma::{
|
|||||||
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
||||||
};
|
};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore};
|
use state_res::{
|
||||||
|
Error, ResolutionResult, Result, StateEvent, StateMap, StateResolution, StateStore,
|
||||||
|
};
|
||||||
|
|
||||||
static mut SERVER_TIMESTAMP: i32 = 0;
|
static mut SERVER_TIMESTAMP: i32 = 0;
|
||||||
|
|
||||||
@ -137,82 +134,12 @@ pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
|||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
impl StateStore for TestStore {
|
impl StateStore for TestStore {
|
||||||
fn get_events(&self, room_id: &RoomId, events: &[EventId]) -> Result<Vec<StateEvent>, String> {
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||||
Ok(self
|
|
||||||
.0
|
|
||||||
.borrow()
|
|
||||||
.iter()
|
|
||||||
.filter(|e| events.contains(e.0))
|
|
||||||
.map(|(_, s)| s)
|
|
||||||
.cloned()
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String> {
|
|
||||||
self.0
|
self.0
|
||||||
.borrow()
|
.borrow()
|
||||||
.get(event_id)
|
.get(event_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or(format!("{} not found", event_id.to_string()))
|
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||||
}
|
|
||||||
|
|
||||||
fn auth_event_ids(
|
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: &[EventId],
|
|
||||||
) -> Result<Vec<EventId>, String> {
|
|
||||||
let mut result = vec![];
|
|
||||||
let mut stack = event_ids.to_vec();
|
|
||||||
|
|
||||||
// DFS for auth event chain
|
|
||||||
while !stack.is_empty() {
|
|
||||||
let ev_id = stack.pop().unwrap();
|
|
||||||
if result.contains(&ev_id) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.push(ev_id.clone());
|
|
||||||
|
|
||||||
let event = self.get_event(room_id, &ev_id).unwrap();
|
|
||||||
stack.extend(event.auth_events());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn auth_chain_diff(
|
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: Vec<Vec<EventId>>,
|
|
||||||
) -> Result<Vec<EventId>, String> {
|
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
let mut chains = vec![];
|
|
||||||
for ids in event_ids {
|
|
||||||
// TODO state store `auth_event_ids` returns self in the event ids list
|
|
||||||
// when an event returns `auth_event_ids` self is not contained
|
|
||||||
let chain = self
|
|
||||||
.auth_event_ids(room_id, &ids)?
|
|
||||||
.into_iter()
|
|
||||||
.collect::<BTreeSet<_>>();
|
|
||||||
chains.push(chain);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(chain) = chains.first() {
|
|
||||||
let rest = chains.iter().skip(1).flatten().cloned().collect();
|
|
||||||
let common = chain.intersection(&rest).collect::<Vec<_>>();
|
|
||||||
|
|
||||||
Ok(chains
|
|
||||||
.iter()
|
|
||||||
.flatten()
|
|
||||||
.filter(|id| !common.contains(&id))
|
|
||||||
.cloned()
|
|
||||||
.collect::<BTreeSet<_>>()
|
|
||||||
.into_iter()
|
|
||||||
.collect())
|
|
||||||
} else {
|
|
||||||
Ok(vec![])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
17
src/error.rs
17
src/error.rs
@ -20,7 +20,18 @@ pub enum Error {
|
|||||||
#[error("Not found error: {0}")]
|
#[error("Not found error: {0}")]
|
||||||
NotFound(String),
|
NotFound(String),
|
||||||
|
|
||||||
// TODO remove once the correct errors are used
|
#[error("Invalid PDU: {0}")]
|
||||||
#[error("an error occured {0}")]
|
InvalidPdu(String),
|
||||||
TempString(String),
|
|
||||||
|
#[error("Conversion failed: {0}")]
|
||||||
|
ConversionError(String),
|
||||||
|
|
||||||
|
#[error("{0}")]
|
||||||
|
Custom(Box<dyn std::error::Error>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error {
|
||||||
|
pub fn custom<E: std::error::Error + 'static>(e: E) -> Self {
|
||||||
|
Self::Custom(Box::new(e))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -83,26 +83,26 @@ pub fn auth_types_for_event(
|
|||||||
/// * then there are checks for specific event types
|
/// * then there are checks for specific event types
|
||||||
pub fn auth_check(
|
pub fn auth_check(
|
||||||
room_version: &RoomVersionId,
|
room_version: &RoomVersionId,
|
||||||
event: &StateEvent,
|
incoming_event: &StateEvent,
|
||||||
prev_event: Option<&StateEvent>,
|
prev_event: Option<&StateEvent>,
|
||||||
auth_events: StateMap<StateEvent>,
|
auth_events: StateMap<StateEvent>,
|
||||||
do_sig_check: bool,
|
do_sig_check: bool,
|
||||||
) -> Result<bool> {
|
) -> Result<bool> {
|
||||||
tracing::info!("auth_check beginning for {}", event.event_id().as_str());
|
tracing::info!("auth_check beginning for {}", incoming_event.kind());
|
||||||
|
|
||||||
// don't let power from other rooms be used
|
// don't let power from other rooms be used
|
||||||
for auth_event in auth_events.values() {
|
for auth_event in auth_events.values() {
|
||||||
if auth_event.room_id() != event.room_id() {
|
if auth_event.room_id() != incoming_event.room_id() {
|
||||||
tracing::warn!("found auth event that did not match event's room_id");
|
tracing::warn!("found auth event that did not match event's room_id");
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if do_sig_check {
|
if do_sig_check {
|
||||||
let sender_domain = event.sender().server_name();
|
let sender_domain = incoming_event.sender().server_name();
|
||||||
|
|
||||||
let is_invite_via_3pid = if event.kind() == EventType::RoomMember {
|
let is_invite_via_3pid = if incoming_event.kind() == EventType::RoomMember {
|
||||||
event
|
incoming_event
|
||||||
.deserialize_content::<room::member::MemberEventContent>()
|
.deserialize_content::<room::member::MemberEventContent>()
|
||||||
.map(|c| c.membership == MembershipState::Invite && c.third_party_invite.is_some())
|
.map(|c| c.membership == MembershipState::Invite && c.third_party_invite.is_some())
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
@ -111,15 +111,15 @@ pub fn auth_check(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// check the event has been signed by the domain of the sender
|
// check the event has been signed by the domain of the sender
|
||||||
if event.signatures().get(sender_domain).is_none() && !is_invite_via_3pid {
|
if incoming_event.signatures().get(sender_domain).is_none() && !is_invite_via_3pid {
|
||||||
tracing::warn!("event not signed by sender's server");
|
tracing::warn!("event not signed by sender's server");
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.room_version() == RoomVersionId::Version1
|
if incoming_event.room_version() == RoomVersionId::Version1
|
||||||
&& event
|
&& incoming_event
|
||||||
.signatures()
|
.signatures()
|
||||||
.get(event.event_id().server_name().unwrap())
|
.get(incoming_event.event_id().server_name().unwrap())
|
||||||
.is_none()
|
.is_none()
|
||||||
{
|
{
|
||||||
tracing::warn!("event not signed by event_id's server");
|
tracing::warn!("event not signed by event_id's server");
|
||||||
@ -134,24 +134,26 @@ pub fn auth_check(
|
|||||||
// Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
|
// Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
|
||||||
//
|
//
|
||||||
// 1. If type is m.room.create:
|
// 1. If type is m.room.create:
|
||||||
if event.kind() == EventType::RoomCreate {
|
if incoming_event.kind() == EventType::RoomCreate {
|
||||||
tracing::info!("start m.room.create check");
|
tracing::info!("start m.room.create check");
|
||||||
|
|
||||||
// If it has any previous events, reject
|
// If it has any previous events, reject
|
||||||
if !event.prev_event_ids().is_empty() {
|
if !incoming_event.prev_event_ids().is_empty() {
|
||||||
tracing::warn!("the room creation event had previous events");
|
tracing::warn!("the room creation event had previous events");
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the domain of the room_id does not match the domain of the sender, reject
|
// If the domain of the room_id does not match the domain of the sender, reject
|
||||||
if event.room_id().map(|id| id.server_name()) != Some(event.sender().server_name()) {
|
if incoming_event.room_id().map(|id| id.server_name())
|
||||||
|
!= Some(incoming_event.sender().server_name())
|
||||||
|
{
|
||||||
tracing::warn!("creation events server does not match sender");
|
tracing::warn!("creation events server does not match sender");
|
||||||
return Ok(false); // creation events room id does not match senders
|
return Ok(false); // creation events room id does not match senders
|
||||||
}
|
}
|
||||||
|
|
||||||
// If content.room_version is present and is not a recognized version, reject
|
// If content.room_version is present and is not a recognized version, reject
|
||||||
if serde_json::from_value::<RoomVersionId>(
|
if serde_json::from_value::<RoomVersionId>(
|
||||||
event
|
incoming_event
|
||||||
.content()
|
.content()
|
||||||
.get("room_version")
|
.get("room_version")
|
||||||
.cloned()
|
.cloned()
|
||||||
@ -165,7 +167,7 @@ pub fn auth_check(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If content has no creator field, reject
|
// If content has no creator field, reject
|
||||||
if event.content().get("creator").is_none() {
|
if incoming_event.content().get("creator").is_none() {
|
||||||
tracing::warn!("no creator field found in room create content");
|
tracing::warn!("no creator field found in room create content");
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
@ -187,10 +189,10 @@ pub fn auth_check(
|
|||||||
// [synapse] checks for federation here
|
// [synapse] checks for federation here
|
||||||
|
|
||||||
// 4. if type is m.room.aliases
|
// 4. if type is m.room.aliases
|
||||||
if event.kind() == EventType::RoomAliases {
|
if incoming_event.kind() == EventType::RoomAliases {
|
||||||
tracing::info!("starting m.room.aliases check");
|
tracing::info!("starting m.room.aliases check");
|
||||||
// TODO && room_version "special case aliases auth" ??
|
// TODO && room_version "special case aliases auth" ??
|
||||||
if event.state_key().is_none() {
|
if incoming_event.state_key().is_none() {
|
||||||
tracing::warn!("no state_key field found for event");
|
tracing::warn!("no state_key field found for event");
|
||||||
return Ok(false); // must have state_key
|
return Ok(false); // must have state_key
|
||||||
}
|
}
|
||||||
@ -202,7 +204,9 @@ pub fn auth_check(
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
// If sender's domain doesn't matches state_key, reject
|
// If sender's domain doesn't matches state_key, reject
|
||||||
if event.state_key().as_deref() != Some(event.sender().server_name().as_str()) {
|
if incoming_event.state_key().as_deref()
|
||||||
|
!= Some(incoming_event.sender().server_name().as_str())
|
||||||
|
{
|
||||||
tracing::warn!("state_key does not match sender");
|
tracing::warn!("state_key does not match sender");
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
@ -211,15 +215,15 @@ pub fn auth_check(
|
|||||||
return Ok(true);
|
return Ok(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.kind() == EventType::RoomMember {
|
if incoming_event.kind() == EventType::RoomMember {
|
||||||
tracing::info!("starting m.room.member check");
|
tracing::info!("starting m.room.member check");
|
||||||
|
|
||||||
if event.state_key().is_none() {
|
if incoming_event.state_key().is_none() {
|
||||||
tracing::warn!("no state_key found for m.room.member event");
|
tracing::warn!("no state_key found for m.room.member event");
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if event
|
if incoming_event
|
||||||
.deserialize_content::<room::member::MemberEventContent>()
|
.deserialize_content::<room::member::MemberEventContent>()
|
||||||
.is_err()
|
.is_err()
|
||||||
{
|
{
|
||||||
@ -227,7 +231,7 @@ pub fn auth_check(
|
|||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if !valid_membership_change(event.to_requester(), prev_event, &auth_events)? {
|
if !valid_membership_change(incoming_event.to_requester(), prev_event, &auth_events)? {
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,7 +240,7 @@ pub fn auth_check(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the sender's current membership state is not join, reject
|
// If the sender's current membership state is not join, reject
|
||||||
match check_event_sender_in_room(event, &auth_events) {
|
match check_event_sender_in_room(incoming_event.sender(), &auth_events) {
|
||||||
Some(true) => {} // sender in room
|
Some(true) => {} // sender in room
|
||||||
Some(false) => {
|
Some(false) => {
|
||||||
tracing::warn!("sender's membership is not join");
|
tracing::warn!("sender's membership is not join");
|
||||||
@ -250,22 +254,24 @@ pub fn auth_check(
|
|||||||
|
|
||||||
// Special case to allow m.room.third_party_invite events where ever
|
// Special case to allow m.room.third_party_invite events where ever
|
||||||
// a user is allowed to issue invites
|
// a user is allowed to issue invites
|
||||||
if event.kind() == EventType::RoomThirdPartyInvite {
|
if incoming_event.kind() == EventType::RoomThirdPartyInvite {
|
||||||
// TODO impl this
|
// TODO impl this
|
||||||
unimplemented!("third party invite")
|
unimplemented!("third party invite")
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the event type's required power level is greater than the sender's power level, reject
|
// If the event type's required power level is greater than the sender's power level, reject
|
||||||
// If the event has a state_key that starts with an @ and does not match the sender, reject.
|
// If the event has a state_key that starts with an @ and does not match the sender, reject.
|
||||||
if !can_send_event(event, &auth_events)? {
|
if !can_send_event(incoming_event, &auth_events)? {
|
||||||
tracing::warn!("user cannot send event");
|
tracing::warn!("user cannot send event");
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.kind() == EventType::RoomPowerLevels {
|
if incoming_event.kind() == EventType::RoomPowerLevels {
|
||||||
tracing::info!("starting m.room.power_levels check");
|
tracing::info!("starting m.room.power_levels check");
|
||||||
|
|
||||||
if let Some(required_pwr_lvl) = check_power_levels(room_version, event, &auth_events) {
|
if let Some(required_pwr_lvl) =
|
||||||
|
check_power_levels(room_version, incoming_event, &auth_events)
|
||||||
|
{
|
||||||
if !required_pwr_lvl {
|
if !required_pwr_lvl {
|
||||||
tracing::warn!("power level was not allowed");
|
tracing::warn!("power level was not allowed");
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
@ -277,8 +283,8 @@ pub fn auth_check(
|
|||||||
tracing::info!("power levels event allowed");
|
tracing::info!("power levels event allowed");
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.kind() == EventType::RoomRedaction {
|
if incoming_event.kind() == EventType::RoomRedaction {
|
||||||
if let RedactAllowed::No = check_redaction(room_version, event, &auth_events)? {
|
if let RedactAllowed::No = check_redaction(room_version, incoming_event, &auth_events)? {
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -287,20 +293,6 @@ pub fn auth_check(
|
|||||||
Ok(true)
|
Ok(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Can this room federate based on its m.room.create event.
|
|
||||||
pub fn can_federate(auth_events: &StateMap<StateEvent>) -> bool {
|
|
||||||
let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into())));
|
|
||||||
if let Some(ev) = creation_event {
|
|
||||||
if let Some(fed) = ev.content().get("m.federate") {
|
|
||||||
fed == "true"
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Does the user who sent this member event have required power levels to do so.
|
/// Does the user who sent this member event have required power levels to do so.
|
||||||
///
|
///
|
||||||
/// * `user` - Information about the membership event and user making the request.
|
/// * `user` - Information about the membership event and user making the request.
|
||||||
@ -316,7 +308,7 @@ pub fn valid_membership_change(
|
|||||||
let state_key = if let Some(s) = user.state_key.as_ref() {
|
let state_key = if let Some(s) = user.state_key.as_ref() {
|
||||||
s
|
s
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::TempString("State event requires state_key".into()));
|
return Err(Error::InvalidPdu("State event requires state_key".into()));
|
||||||
};
|
};
|
||||||
|
|
||||||
let content =
|
let content =
|
||||||
@ -324,8 +316,8 @@ pub fn valid_membership_change(
|
|||||||
|
|
||||||
let target_membership = content.membership;
|
let target_membership = content.membership;
|
||||||
|
|
||||||
let target_user_id =
|
let target_user_id = UserId::try_from(state_key.as_str())
|
||||||
UserId::try_from(state_key.as_str()).map_err(|e| Error::TempString(format!("{}", e)))?;
|
.map_err(|e| Error::ConversionError(format!("{}", e)))?;
|
||||||
|
|
||||||
let key = (EventType::RoomMember, Some(user.sender.to_string()));
|
let key = (EventType::RoomMember, Some(user.sender.to_string()));
|
||||||
let sender = auth_events.get(&key);
|
let sender = auth_events.get(&key);
|
||||||
@ -464,10 +456,10 @@ pub fn valid_membership_change(
|
|||||||
|
|
||||||
/// Is the event's sender in the room that they sent the event to.
|
/// Is the event's sender in the room that they sent the event to.
|
||||||
pub fn check_event_sender_in_room(
|
pub fn check_event_sender_in_room(
|
||||||
event: &StateEvent,
|
sender: &UserId,
|
||||||
auth_events: &StateMap<StateEvent>,
|
auth_events: &StateMap<StateEvent>,
|
||||||
) -> Option<bool> {
|
) -> Option<bool> {
|
||||||
let mem = auth_events.get(&(EventType::RoomMember, Some(event.sender().to_string())))?;
|
let mem = auth_events.get(&(EventType::RoomMember, Some(sender.to_string())))?;
|
||||||
// TODO this is check_membership a helper fn in synapse but it does this
|
// TODO this is check_membership a helper fn in synapse but it does this
|
||||||
Some(
|
Some(
|
||||||
mem.deserialize_content::<room::member::MemberEventContent>()
|
mem.deserialize_content::<room::member::MemberEventContent>()
|
||||||
@ -692,6 +684,20 @@ pub fn check_membership(member_event: Option<&StateEvent>, state: MembershipStat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Can this room federate based on its m.room.create event.
|
||||||
|
pub fn can_federate(auth_events: &StateMap<StateEvent>) -> bool {
|
||||||
|
let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into())));
|
||||||
|
if let Some(ev) = creation_event {
|
||||||
|
if let Some(fed) = ev.content().get("m.federate") {
|
||||||
|
fed == "true"
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content.
|
/// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content.
|
||||||
/// or return `default` if no power level event is found or zero if no field matches `name`.
|
/// or return `default` if no power level event is found or zero if no field matches `name`.
|
||||||
pub fn get_named_level(auth_events: &StateMap<StateEvent>, name: &str, default: i64) -> i64 {
|
pub fn get_named_level(auth_events: &StateMap<StateEvent>, name: &str, default: i64) -> i64 {
|
||||||
|
20
src/lib.rs
20
src/lib.rs
@ -119,7 +119,7 @@ impl StateResolution {
|
|||||||
|
|
||||||
for event in event_map.values() {
|
for event in event_map.values() {
|
||||||
if event.room_id() != Some(room_id) {
|
if event.room_id() != Some(room_id) {
|
||||||
return Err(Error::TempString(format!(
|
return Err(Error::InvalidPdu(format!(
|
||||||
"resolving event {} in room {}, when correct room is {}",
|
"resolving event {} in room {}, when correct room is {}",
|
||||||
event.event_id(),
|
event.event_id(),
|
||||||
event.room_id().map(|id| id.as_str()).unwrap_or("`unknown`"),
|
event.room_id().map(|id| id.as_str()).unwrap_or("`unknown`"),
|
||||||
@ -288,16 +288,14 @@ impl StateResolution {
|
|||||||
|
|
||||||
tracing::debug!("calculating auth chain difference");
|
tracing::debug!("calculating auth chain difference");
|
||||||
|
|
||||||
store
|
store.auth_chain_diff(
|
||||||
.auth_chain_diff(
|
room_id,
|
||||||
room_id,
|
state_sets
|
||||||
state_sets
|
.iter()
|
||||||
.iter()
|
.map(|map| map.values().cloned().collect())
|
||||||
.map(|map| map.values().cloned().collect())
|
.dedup()
|
||||||
.dedup()
|
.collect::<Vec<_>>(),
|
||||||
.collect::<Vec<_>>(),
|
)
|
||||||
)
|
|
||||||
.map_err(Error::TempString)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Events are sorted from "earliest" to "latest". They are compared using
|
/// Events are sorted from "earliest" to "latest". They are compared using
|
||||||
|
@ -2,18 +2,14 @@ use std::collections::BTreeSet;
|
|||||||
|
|
||||||
use ruma::identifiers::{EventId, RoomId};
|
use ruma::identifiers::{EventId, RoomId};
|
||||||
|
|
||||||
use crate::StateEvent;
|
use crate::{Result, StateEvent};
|
||||||
|
|
||||||
pub trait StateStore {
|
pub trait StateStore {
|
||||||
/// Return a single event based on the EventId.
|
/// Return a single event based on the EventId.
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String>;
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent>;
|
||||||
|
|
||||||
/// Returns the events that correspond to the `event_ids` sorted in the same order.
|
/// Returns the events that correspond to the `event_ids` sorted in the same order.
|
||||||
fn get_events(
|
fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<StateEvent>> {
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: &[EventId],
|
|
||||||
) -> Result<Vec<StateEvent>, String> {
|
|
||||||
let mut events = vec![];
|
let mut events = vec![];
|
||||||
for id in event_ids {
|
for id in event_ids {
|
||||||
events.push(self.get_event(room_id, id)?);
|
events.push(self.get_event(room_id, id)?);
|
||||||
@ -22,11 +18,7 @@ pub trait StateStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Vec of the related auth events to the given `event`.
|
/// Returns a Vec of the related auth events to the given `event`.
|
||||||
fn auth_event_ids(
|
fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<EventId>> {
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: &[EventId],
|
|
||||||
) -> Result<Vec<EventId>, String> {
|
|
||||||
let mut result = vec![];
|
let mut result = vec![];
|
||||||
let mut stack = event_ids.to_vec();
|
let mut stack = event_ids.to_vec();
|
||||||
|
|
||||||
@ -52,7 +44,7 @@ pub trait StateStore {
|
|||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
event_ids: Vec<Vec<EventId>>,
|
event_ids: Vec<Vec<EventId>>,
|
||||||
) -> Result<Vec<EventId>, String> {
|
) -> Result<Vec<EventId>> {
|
||||||
let mut chains = vec![];
|
let mut chains = vec![];
|
||||||
for ids in event_ids {
|
for ids in event_ids {
|
||||||
// TODO state store `auth_event_ids` returns self in the event ids list
|
// TODO state store `auth_event_ids` returns self in the event ids list
|
||||||
|
@ -18,7 +18,7 @@ use state_res::{
|
|||||||
// auth_check, auth_types_for_event, can_federate, check_power_levels, check_redaction,
|
// auth_check, auth_types_for_event, can_federate, check_power_levels, check_redaction,
|
||||||
valid_membership_change,
|
valid_membership_change,
|
||||||
},
|
},
|
||||||
Requester, StateEvent, StateMap, StateStore,
|
Requester, StateEvent, StateMap, StateStore, Result, Error
|
||||||
};
|
};
|
||||||
use tracing_subscriber as tracer;
|
use tracing_subscriber as tracer;
|
||||||
|
|
||||||
@ -75,12 +75,12 @@ pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
|||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
impl StateStore for TestStore {
|
impl StateStore for TestStore {
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String> {
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||||
self.0
|
self.0
|
||||||
.borrow()
|
.borrow()
|
||||||
.get(event_id)
|
.get(event_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or(format!("{} not found", event_id.to_string()))
|
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ use ruma::{
|
|||||||
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
||||||
};
|
};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
use state_res::{StateEvent, StateMap, StateStore};
|
use state_res::{Error, Result, StateEvent, StateMap, StateStore};
|
||||||
use tracing_subscriber as tracer;
|
use tracing_subscriber as tracer;
|
||||||
|
|
||||||
use std::sync::Once;
|
use std::sync::Once;
|
||||||
@ -57,12 +57,12 @@ pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
|||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
impl StateStore for TestStore {
|
impl StateStore for TestStore {
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String> {
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||||
self.0
|
self.0
|
||||||
.borrow()
|
.borrow()
|
||||||
.get(event_id)
|
.get(event_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or(format!("{} not found", event_id.to_string()))
|
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,7 +14,9 @@ use ruma::{
|
|||||||
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
||||||
};
|
};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore};
|
use state_res::{
|
||||||
|
Error, ResolutionResult, Result, StateEvent, StateMap, StateResolution, StateStore,
|
||||||
|
};
|
||||||
use tracing_subscriber as tracer;
|
use tracing_subscriber as tracer;
|
||||||
|
|
||||||
static LOGGER: Once = Once::new();
|
static LOGGER: Once = Once::new();
|
||||||
@ -200,12 +202,12 @@ pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
|||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
impl StateStore for TestStore {
|
impl StateStore for TestStore {
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String> {
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||||
self.0
|
self.0
|
||||||
.borrow()
|
.borrow()
|
||||||
.get(event_id)
|
.get(event_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or(format!("{} not found", event_id.to_string()))
|
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,9 +1,4 @@
|
|||||||
use std::{
|
use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH};
|
||||||
cell::RefCell,
|
|
||||||
collections::{BTreeMap, BTreeSet},
|
|
||||||
convert::TryFrom,
|
|
||||||
time::UNIX_EPOCH,
|
|
||||||
};
|
|
||||||
|
|
||||||
use maplit::btreemap;
|
use maplit::btreemap;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
@ -18,7 +13,9 @@ use ruma::{
|
|||||||
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
identifiers::{EventId, RoomId, RoomVersionId, UserId},
|
||||||
};
|
};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore};
|
use state_res::{
|
||||||
|
Error, ResolutionResult, Result, StateEvent, StateMap, StateResolution, StateStore,
|
||||||
|
};
|
||||||
use tracing_subscriber as tracer;
|
use tracing_subscriber as tracer;
|
||||||
|
|
||||||
use std::sync::Once;
|
use std::sync::Once;
|
||||||
@ -768,83 +765,12 @@ pub struct TestStore(RefCell<BTreeMap<EventId, StateEvent>>);
|
|||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
impl StateStore for TestStore {
|
impl StateStore for TestStore {
|
||||||
fn get_events(&self, room_id: &RoomId, events: &[EventId]) -> Result<Vec<StateEvent>, String> {
|
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent> {
|
||||||
Ok(self
|
|
||||||
.0
|
|
||||||
.borrow()
|
|
||||||
.iter()
|
|
||||||
.filter(|e| events.contains(e.0))
|
|
||||||
.map(|(_, s)| s)
|
|
||||||
.cloned()
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<StateEvent, String> {
|
|
||||||
self.0
|
self.0
|
||||||
.borrow()
|
.borrow()
|
||||||
.get(event_id)
|
.get(event_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or(format!("{} not found", event_id.to_string()))
|
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string())))
|
||||||
}
|
|
||||||
|
|
||||||
fn auth_event_ids(
|
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: &[EventId],
|
|
||||||
) -> Result<Vec<EventId>, String> {
|
|
||||||
let mut result = vec![];
|
|
||||||
let mut stack = event_ids.to_vec();
|
|
||||||
|
|
||||||
// DFS for auth event chain
|
|
||||||
while !stack.is_empty() {
|
|
||||||
let ev_id = stack.pop().unwrap();
|
|
||||||
if result.contains(&ev_id) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.push(ev_id.clone());
|
|
||||||
|
|
||||||
let event = self.get_event(room_id, &ev_id).unwrap();
|
|
||||||
|
|
||||||
stack.extend(event.auth_events());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn auth_chain_diff(
|
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event_ids: Vec<Vec<EventId>>,
|
|
||||||
) -> Result<Vec<EventId>, String> {
|
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
let mut chains = vec![];
|
|
||||||
for ids in event_ids {
|
|
||||||
// TODO state store `auth_event_ids` returns self in the event ids list
|
|
||||||
// when an event returns `auth_event_ids` self is not contained
|
|
||||||
let chain = self
|
|
||||||
.auth_event_ids(room_id, &ids)?
|
|
||||||
.into_iter()
|
|
||||||
.collect::<BTreeSet<_>>();
|
|
||||||
chains.push(chain);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(chain) = chains.first() {
|
|
||||||
let rest = chains.iter().skip(1).flatten().cloned().collect();
|
|
||||||
let common = chain.intersection(&rest).collect::<Vec<_>>();
|
|
||||||
|
|
||||||
Ok(chains
|
|
||||||
.iter()
|
|
||||||
.flatten()
|
|
||||||
.filter(|id| !common.contains(&id))
|
|
||||||
.cloned()
|
|
||||||
.collect::<BTreeSet<_>>()
|
|
||||||
.into_iter()
|
|
||||||
.collect())
|
|
||||||
} else {
|
|
||||||
Ok(vec![])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user