diff --git a/benches/state_res_bench.rs b/benches/state_res_bench.rs index d9cbb3ab..369750c6 100644 --- a/benches/state_res_bench.rs +++ b/benches/state_res_bench.rs @@ -3,12 +3,7 @@ // `cargo bench unknown option --save-baseline`. // To pass args to criterion, use this form // `cargo bench --bench -- --save-baseline `. -use std::{ - cell::RefCell, - collections::{BTreeMap, BTreeSet}, - convert::TryFrom, - time::UNIX_EPOCH, -}; +use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH}; use criterion::{criterion_group, criterion_main, Criterion}; use maplit::btreemap; @@ -24,7 +19,9 @@ use ruma::{ identifiers::{EventId, RoomId, RoomVersionId, UserId}, }; use serde_json::{json, Value as JsonValue}; -use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore}; +use state_res::{ + Error, ResolutionResult, Result, StateEvent, StateMap, StateResolution, StateStore, +}; static mut SERVER_TIMESTAMP: i32 = 0; @@ -137,82 +134,12 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_events(&self, room_id: &RoomId, events: &[EventId]) -> Result, String> { - Ok(self - .0 - .borrow() - .iter() - .filter(|e| events.contains(e.0)) - .map(|(_, s)| s) - .cloned() - .collect()) - } - - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) .cloned() - .ok_or(format!("{} not found", event_id.to_string())) - } - - fn auth_event_ids( - &self, - room_id: &RoomId, - event_ids: &[EventId], - ) -> Result, String> { - let mut result = vec![]; - let mut stack = event_ids.to_vec(); - - // DFS for auth event chain - while !stack.is_empty() { - let ev_id = stack.pop().unwrap(); - if result.contains(&ev_id) { - continue; - } - - result.push(ev_id.clone()); - - let event = self.get_event(room_id, &ev_id).unwrap(); - stack.extend(event.auth_events()); - } - - Ok(result) - } - - fn auth_chain_diff( - &self, - room_id: &RoomId, - event_ids: Vec>, - ) -> Result, String> { - use itertools::Itertools; - - let mut chains = vec![]; - for ids in event_ids { - // TODO state store `auth_event_ids` returns self in the event ids list - // when an event returns `auth_event_ids` self is not contained - let chain = self - .auth_event_ids(room_id, &ids)? - .into_iter() - .collect::>(); - chains.push(chain); - } - - if let Some(chain) = chains.first() { - let rest = chains.iter().skip(1).flatten().cloned().collect(); - let common = chain.intersection(&rest).collect::>(); - - Ok(chains - .iter() - .flatten() - .filter(|id| !common.contains(&id)) - .cloned() - .collect::>() - .into_iter() - .collect()) - } else { - Ok(vec![]) - } + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } diff --git a/src/error.rs b/src/error.rs index 4945f975..51ca8fcc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -20,7 +20,18 @@ pub enum Error { #[error("Not found error: {0}")] NotFound(String), - // TODO remove once the correct errors are used - #[error("an error occured {0}")] - TempString(String), + #[error("Invalid PDU: {0}")] + InvalidPdu(String), + + #[error("Conversion failed: {0}")] + ConversionError(String), + + #[error("{0}")] + Custom(Box), +} + +impl Error { + pub fn custom(e: E) -> Self { + Self::Custom(Box::new(e)) + } } diff --git a/src/event_auth.rs b/src/event_auth.rs index a904fb4d..54771408 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -83,26 +83,26 @@ pub fn auth_types_for_event( /// * then there are checks for specific event types pub fn auth_check( room_version: &RoomVersionId, - event: &StateEvent, + incoming_event: &StateEvent, prev_event: Option<&StateEvent>, auth_events: StateMap, do_sig_check: bool, ) -> Result { - tracing::info!("auth_check beginning for {}", event.event_id().as_str()); + tracing::info!("auth_check beginning for {}", incoming_event.kind()); // don't let power from other rooms be used for auth_event in auth_events.values() { - if auth_event.room_id() != event.room_id() { + if auth_event.room_id() != incoming_event.room_id() { tracing::warn!("found auth event that did not match event's room_id"); return Ok(false); } } if do_sig_check { - let sender_domain = event.sender().server_name(); + let sender_domain = incoming_event.sender().server_name(); - let is_invite_via_3pid = if event.kind() == EventType::RoomMember { - event + let is_invite_via_3pid = if incoming_event.kind() == EventType::RoomMember { + incoming_event .deserialize_content::() .map(|c| c.membership == MembershipState::Invite && c.third_party_invite.is_some()) .unwrap_or_default() @@ -111,15 +111,15 @@ pub fn auth_check( }; // check the event has been signed by the domain of the sender - if event.signatures().get(sender_domain).is_none() && !is_invite_via_3pid { + if incoming_event.signatures().get(sender_domain).is_none() && !is_invite_via_3pid { tracing::warn!("event not signed by sender's server"); return Ok(false); } - if event.room_version() == RoomVersionId::Version1 - && event + if incoming_event.room_version() == RoomVersionId::Version1 + && incoming_event .signatures() - .get(event.event_id().server_name().unwrap()) + .get(incoming_event.event_id().server_name().unwrap()) .is_none() { tracing::warn!("event not signed by event_id's server"); @@ -134,24 +134,26 @@ pub fn auth_check( // Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules // // 1. If type is m.room.create: - if event.kind() == EventType::RoomCreate { + if incoming_event.kind() == EventType::RoomCreate { tracing::info!("start m.room.create check"); // If it has any previous events, reject - if !event.prev_event_ids().is_empty() { + if !incoming_event.prev_event_ids().is_empty() { tracing::warn!("the room creation event had previous events"); return Ok(false); } // If the domain of the room_id does not match the domain of the sender, reject - if event.room_id().map(|id| id.server_name()) != Some(event.sender().server_name()) { + if incoming_event.room_id().map(|id| id.server_name()) + != Some(incoming_event.sender().server_name()) + { tracing::warn!("creation events server does not match sender"); return Ok(false); // creation events room id does not match senders } // If content.room_version is present and is not a recognized version, reject if serde_json::from_value::( - event + incoming_event .content() .get("room_version") .cloned() @@ -165,7 +167,7 @@ pub fn auth_check( } // If content has no creator field, reject - if event.content().get("creator").is_none() { + if incoming_event.content().get("creator").is_none() { tracing::warn!("no creator field found in room create content"); return Ok(false); } @@ -187,10 +189,10 @@ pub fn auth_check( // [synapse] checks for federation here // 4. if type is m.room.aliases - if event.kind() == EventType::RoomAliases { + if incoming_event.kind() == EventType::RoomAliases { tracing::info!("starting m.room.aliases check"); // TODO && room_version "special case aliases auth" ?? - if event.state_key().is_none() { + if incoming_event.state_key().is_none() { tracing::warn!("no state_key field found for event"); return Ok(false); // must have state_key } @@ -202,7 +204,9 @@ pub fn auth_check( // } // If sender's domain doesn't matches state_key, reject - if event.state_key().as_deref() != Some(event.sender().server_name().as_str()) { + if incoming_event.state_key().as_deref() + != Some(incoming_event.sender().server_name().as_str()) + { tracing::warn!("state_key does not match sender"); return Ok(false); } @@ -211,15 +215,15 @@ pub fn auth_check( return Ok(true); } - if event.kind() == EventType::RoomMember { + if incoming_event.kind() == EventType::RoomMember { tracing::info!("starting m.room.member check"); - if event.state_key().is_none() { + if incoming_event.state_key().is_none() { tracing::warn!("no state_key found for m.room.member event"); return Ok(false); } - if event + if incoming_event .deserialize_content::() .is_err() { @@ -227,7 +231,7 @@ pub fn auth_check( return Ok(false); } - if !valid_membership_change(event.to_requester(), prev_event, &auth_events)? { + if !valid_membership_change(incoming_event.to_requester(), prev_event, &auth_events)? { return Ok(false); } @@ -236,7 +240,7 @@ pub fn auth_check( } // If the sender's current membership state is not join, reject - match check_event_sender_in_room(event, &auth_events) { + match check_event_sender_in_room(incoming_event.sender(), &auth_events) { Some(true) => {} // sender in room Some(false) => { tracing::warn!("sender's membership is not join"); @@ -250,22 +254,24 @@ pub fn auth_check( // Special case to allow m.room.third_party_invite events where ever // a user is allowed to issue invites - if event.kind() == EventType::RoomThirdPartyInvite { + if incoming_event.kind() == EventType::RoomThirdPartyInvite { // TODO impl this unimplemented!("third party invite") } // If the event type's required power level is greater than the sender's power level, reject // If the event has a state_key that starts with an @ and does not match the sender, reject. - if !can_send_event(event, &auth_events)? { + if !can_send_event(incoming_event, &auth_events)? { tracing::warn!("user cannot send event"); return Ok(false); } - if event.kind() == EventType::RoomPowerLevels { + if incoming_event.kind() == EventType::RoomPowerLevels { tracing::info!("starting m.room.power_levels check"); - if let Some(required_pwr_lvl) = check_power_levels(room_version, event, &auth_events) { + if let Some(required_pwr_lvl) = + check_power_levels(room_version, incoming_event, &auth_events) + { if !required_pwr_lvl { tracing::warn!("power level was not allowed"); return Ok(false); @@ -277,8 +283,8 @@ pub fn auth_check( tracing::info!("power levels event allowed"); } - if event.kind() == EventType::RoomRedaction { - if let RedactAllowed::No = check_redaction(room_version, event, &auth_events)? { + if incoming_event.kind() == EventType::RoomRedaction { + if let RedactAllowed::No = check_redaction(room_version, incoming_event, &auth_events)? { return Ok(false); } } @@ -287,20 +293,6 @@ pub fn auth_check( Ok(true) } -/// Can this room federate based on its m.room.create event. -pub fn can_federate(auth_events: &StateMap) -> bool { - let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into()))); - if let Some(ev) = creation_event { - if let Some(fed) = ev.content().get("m.federate") { - fed == "true" - } else { - false - } - } else { - false - } -} - /// Does the user who sent this member event have required power levels to do so. /// /// * `user` - Information about the membership event and user making the request. @@ -316,7 +308,7 @@ pub fn valid_membership_change( let state_key = if let Some(s) = user.state_key.as_ref() { s } else { - return Err(Error::TempString("State event requires state_key".into())); + return Err(Error::InvalidPdu("State event requires state_key".into())); }; let content = @@ -324,8 +316,8 @@ pub fn valid_membership_change( let target_membership = content.membership; - let target_user_id = - UserId::try_from(state_key.as_str()).map_err(|e| Error::TempString(format!("{}", e)))?; + let target_user_id = UserId::try_from(state_key.as_str()) + .map_err(|e| Error::ConversionError(format!("{}", e)))?; let key = (EventType::RoomMember, Some(user.sender.to_string())); let sender = auth_events.get(&key); @@ -464,10 +456,10 @@ pub fn valid_membership_change( /// Is the event's sender in the room that they sent the event to. pub fn check_event_sender_in_room( - event: &StateEvent, + sender: &UserId, auth_events: &StateMap, ) -> Option { - let mem = auth_events.get(&(EventType::RoomMember, Some(event.sender().to_string())))?; + let mem = auth_events.get(&(EventType::RoomMember, Some(sender.to_string())))?; // TODO this is check_membership a helper fn in synapse but it does this Some( mem.deserialize_content::() @@ -692,6 +684,20 @@ pub fn check_membership(member_event: Option<&StateEvent>, state: MembershipStat } } +/// Can this room federate based on its m.room.create event. +pub fn can_federate(auth_events: &StateMap) -> bool { + let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into()))); + if let Some(ev) = creation_event { + if let Some(fed) = ev.content().get("m.federate") { + fed == "true" + } else { + false + } + } else { + false + } +} + /// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content. /// or return `default` if no power level event is found or zero if no field matches `name`. pub fn get_named_level(auth_events: &StateMap, name: &str, default: i64) -> i64 { diff --git a/src/lib.rs b/src/lib.rs index 42c9d197..375f22bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,7 +119,7 @@ impl StateResolution { for event in event_map.values() { if event.room_id() != Some(room_id) { - return Err(Error::TempString(format!( + return Err(Error::InvalidPdu(format!( "resolving event {} in room {}, when correct room is {}", event.event_id(), event.room_id().map(|id| id.as_str()).unwrap_or("`unknown`"), @@ -288,16 +288,14 @@ impl StateResolution { tracing::debug!("calculating auth chain difference"); - store - .auth_chain_diff( - room_id, - state_sets - .iter() - .map(|map| map.values().cloned().collect()) - .dedup() - .collect::>(), - ) - .map_err(Error::TempString) + store.auth_chain_diff( + room_id, + state_sets + .iter() + .map(|map| map.values().cloned().collect()) + .dedup() + .collect::>(), + ) } /// Events are sorted from "earliest" to "latest". They are compared using diff --git a/src/state_store.rs b/src/state_store.rs index 80621bf0..777913ed 100644 --- a/src/state_store.rs +++ b/src/state_store.rs @@ -2,18 +2,14 @@ use std::collections::BTreeSet; use ruma::identifiers::{EventId, RoomId}; -use crate::StateEvent; +use crate::{Result, StateEvent}; pub trait StateStore { /// Return a single event based on the EventId. - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result; + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result; /// Returns the events that correspond to the `event_ids` sorted in the same order. - fn get_events( - &self, - room_id: &RoomId, - event_ids: &[EventId], - ) -> Result, String> { + fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result> { let mut events = vec![]; for id in event_ids { events.push(self.get_event(room_id, id)?); @@ -22,11 +18,7 @@ pub trait StateStore { } /// Returns a Vec of the related auth events to the given `event`. - fn auth_event_ids( - &self, - room_id: &RoomId, - event_ids: &[EventId], - ) -> Result, String> { + fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result> { let mut result = vec![]; let mut stack = event_ids.to_vec(); @@ -52,7 +44,7 @@ pub trait StateStore { &self, room_id: &RoomId, event_ids: Vec>, - ) -> Result, String> { + ) -> Result> { let mut chains = vec![]; for ids in event_ids { // TODO state store `auth_event_ids` returns self in the event ids list diff --git a/tests/event_auth.rs b/tests/event_auth.rs index 6ae41778..4fb91005 100644 --- a/tests/event_auth.rs +++ b/tests/event_auth.rs @@ -18,7 +18,7 @@ use state_res::{ // auth_check, auth_types_for_event, can_federate, check_power_levels, check_redaction, valid_membership_change, }, - Requester, StateEvent, StateMap, StateStore, + Requester, StateEvent, StateMap, StateStore, Result, Error }; use tracing_subscriber as tracer; @@ -75,12 +75,12 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) .cloned() - .ok_or(format!("{} not found", event_id.to_string())) + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } diff --git a/tests/event_sorting.rs b/tests/event_sorting.rs index 4cd0317d..c9d9248a 100644 --- a/tests/event_sorting.rs +++ b/tests/event_sorting.rs @@ -12,7 +12,7 @@ use ruma::{ identifiers::{EventId, RoomId, RoomVersionId, UserId}, }; use serde_json::{json, Value as JsonValue}; -use state_res::{StateEvent, StateMap, StateStore}; +use state_res::{Error, Result, StateEvent, StateMap, StateStore}; use tracing_subscriber as tracer; use std::sync::Once; @@ -57,12 +57,12 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) .cloned() - .ok_or(format!("{} not found", event_id.to_string())) + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } diff --git a/tests/res_with_auth_ids.rs b/tests/res_with_auth_ids.rs index b4dbab52..7b2d9d35 100644 --- a/tests/res_with_auth_ids.rs +++ b/tests/res_with_auth_ids.rs @@ -14,7 +14,9 @@ use ruma::{ identifiers::{EventId, RoomId, RoomVersionId, UserId}, }; use serde_json::{json, Value as JsonValue}; -use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore}; +use state_res::{ + Error, ResolutionResult, Result, StateEvent, StateMap, StateResolution, StateStore, +}; use tracing_subscriber as tracer; static LOGGER: Once = Once::new(); @@ -200,12 +202,12 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) .cloned() - .ok_or(format!("{} not found", event_id.to_string())) + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } diff --git a/tests/state_res.rs b/tests/state_res.rs index 40a89e4b..aaed38ea 100644 --- a/tests/state_res.rs +++ b/tests/state_res.rs @@ -1,9 +1,4 @@ -use std::{ - cell::RefCell, - collections::{BTreeMap, BTreeSet}, - convert::TryFrom, - time::UNIX_EPOCH, -}; +use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH}; use maplit::btreemap; use ruma::{ @@ -18,7 +13,9 @@ use ruma::{ identifiers::{EventId, RoomId, RoomVersionId, UserId}, }; use serde_json::{json, Value as JsonValue}; -use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore}; +use state_res::{ + Error, ResolutionResult, Result, StateEvent, StateMap, StateResolution, StateStore, +}; use tracing_subscriber as tracer; use std::sync::Once; @@ -768,83 +765,12 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_events(&self, room_id: &RoomId, events: &[EventId]) -> Result, String> { - Ok(self - .0 - .borrow() - .iter() - .filter(|e| events.contains(e.0)) - .map(|(_, s)| s) - .cloned() - .collect()) - } - - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) .cloned() - .ok_or(format!("{} not found", event_id.to_string())) - } - - fn auth_event_ids( - &self, - room_id: &RoomId, - event_ids: &[EventId], - ) -> Result, String> { - let mut result = vec![]; - let mut stack = event_ids.to_vec(); - - // DFS for auth event chain - while !stack.is_empty() { - let ev_id = stack.pop().unwrap(); - if result.contains(&ev_id) { - continue; - } - - result.push(ev_id.clone()); - - let event = self.get_event(room_id, &ev_id).unwrap(); - - stack.extend(event.auth_events()); - } - - Ok(result) - } - - fn auth_chain_diff( - &self, - room_id: &RoomId, - event_ids: Vec>, - ) -> Result, String> { - use itertools::Itertools; - - let mut chains = vec![]; - for ids in event_ids { - // TODO state store `auth_event_ids` returns self in the event ids list - // when an event returns `auth_event_ids` self is not contained - let chain = self - .auth_event_ids(room_id, &ids)? - .into_iter() - .collect::>(); - chains.push(chain); - } - - if let Some(chain) = chains.first() { - let rest = chains.iter().skip(1).flatten().cloned().collect(); - let common = chain.intersection(&rest).collect::>(); - - Ok(chains - .iter() - .flatten() - .filter(|id| !common.contains(&id)) - .cloned() - .collect::>() - .into_iter() - .collect()) - } else { - Ok(vec![]) - } + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } }