diff --git a/src/event_auth.rs b/src/event_auth.rs index abf00573..ace434de 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -10,7 +10,11 @@ use ruma::{ }; use serde_json::json; -use crate::{room_version::RoomVersion, state_event::StateEvent, StateMap}; +use crate::{ + room_version::RoomVersion, + state_event::{Requester, StateEvent}, + StateMap, +}; /// Represents the 3 event redaction outcomes. pub enum RedactAllowed { @@ -178,7 +182,7 @@ pub fn auth_check( if event.kind() == EventType::RoomMember { tracing::info!("starting m.room.member check"); - if !is_membership_change_allowed(event, &auth_events)? { + if !is_membership_change_allowed(event.to_requester(), &auth_events)? { return Some(false); } @@ -249,21 +253,21 @@ pub fn can_federate(auth_events: &StateMap) -> bool { /// Dose the user who sent this member event have required power levels to do so. pub fn is_membership_change_allowed( - event: &StateEvent, + user: Requester<'_>, auth_events: &StateMap, ) -> Option { - let content = event - .deserialize_content::() - .ok() - .unwrap(); + let content = + // TODO return error + serde_json::from_str::(&user.content.to_string()).ok()?; + let membership = content.membership; // check if this is the room creator joining - if event.prev_event_ids().len() == 1 && membership == MembershipState::Join { + if user.prev_event_ids.len() == 1 && membership == MembershipState::Join { if let Some(create) = auth_events.get(&(EventType::RoomCreate, Some("".into()))) { if let Ok(create_ev) = create.deserialize_content::() { - if event.state_key() == Some(create_ev.creator.to_string()) { + if user.state_key == Some(create_ev.creator.to_string()) { tracing::debug!("m.room.member event allowed via m.room.create"); return Some(true); } @@ -271,16 +275,16 @@ pub fn is_membership_change_allowed( } } - let target_user_id = UserId::try_from(event.state_key().unwrap()).ok().unwrap(); + let target_user_id = UserId::try_from(user.state_key.as_deref().unwrap()) + .ok() + .unwrap(); // if the server_names are different and federation is NOT allowed - if event.room_id().unwrap().server_name() != target_user_id.server_name() - && !can_federate(auth_events) - { + if user.room_id.server_name() != target_user_id.server_name() && !can_federate(auth_events) { tracing::warn!("server cannot federate"); return Some(false); } - let key = (EventType::RoomMember, Some(event.sender().to_string())); + let key = (EventType::RoomMember, Some(user.sender.to_string())); let caller = auth_events.get(&key); let caller_in_room = caller.is_some() && check_membership(caller, MembershipState::Join); @@ -299,12 +303,11 @@ pub fn is_membership_change_allowed( if let Some(jr) = join_rules_event { join_rule = jr .deserialize_content::() - .ok() - .unwrap() // TODO these are errors? and should be treated as a DB failure? + .ok()? // TODO these are errors? and should be treated as a DB failure? .join_rule; } - let user_level = get_user_power_level(event.sender(), auth_events); + let user_level = get_user_power_level(user.sender, auth_events); let target_level = get_user_power_level(&target_user_id, auth_events); // synapse has a not "what to do for default here 50" @@ -321,14 +324,14 @@ pub fn is_membership_change_allowed( "membership": membership, "join_rule": join_rule, "target_user_id": target_user_id, - "event.user_id": event.sender(), + "event.user_id": user.sender, })) .unwrap(), ); if membership == MembershipState::Invite && content.third_party_invite.is_some() { // TODO this is unimpled - if !verify_third_party_invite(event, auth_events) { + if !verify_third_party_invite(&user, auth_events) { tracing::warn!("not invited to this room",); return Some(false); } @@ -341,19 +344,14 @@ pub fn is_membership_change_allowed( } if membership != MembershipState::Join { - if caller_invited - && membership == MembershipState::Leave - && &target_user_id == event.sender() + if caller_invited && membership == MembershipState::Leave && &target_user_id == user.sender { tracing::warn!("join event succeded"); return Some(true); } if !caller_in_room { - tracing::warn!( - "user is not in this room {}", - event.room_id().unwrap().as_str(), - ); + tracing::warn!("user is not in this room {}", user.room_id.as_str(),); return Some(false); // caller is not joined } } @@ -372,7 +370,7 @@ pub fn is_membership_change_allowed( } } } else if membership == MembershipState::Join { - if event.sender() != &target_user_id { + if user.sender != &target_user_id { tracing::warn!("cannot force another user to join"); return Some(false); // cannot force another user to join } else if target_banned { @@ -397,7 +395,7 @@ pub fn is_membership_change_allowed( if target_banned && user_level < ban_level { tracing::warn!("not enough power to unban"); return Some(false); // you cannot unban this user - } else if &target_user_id != event.sender() { + } else if &target_user_id != user.sender { let kick_level = get_named_level(auth_events, "kick", 50); if user_level < kick_level || user_level <= target_level { @@ -734,6 +732,9 @@ pub fn get_send_level( } /// TODO this is unimplemented -pub fn verify_third_party_invite(_event: &StateEvent, _auth_events: &StateMap) -> bool { +pub fn verify_third_party_invite( + _event: &Requester<'_>, + _auth_events: &StateMap, +) -> bool { unimplemented!("impl third party invites") } diff --git a/src/lib.rs b/src/lib.rs index 23571777..ad2f161f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,7 +20,7 @@ mod state_store; pub use error::{Error, Result}; pub use event_auth::{auth_check, auth_types_for_event}; -pub use state_event::StateEvent; +pub use state_event::{Requester, StateEvent}; pub use state_store::StateStore; // We want to yield to the reactor occasionally during state res when dealing diff --git a/src/state_event.rs b/src/state_event.rs index 34465d14..7b94ae84 100644 --- a/src/state_event.rs +++ b/src/state_event.rs @@ -13,6 +13,14 @@ use serde::{de, Serialize}; use serde_json::value::RawValue as RawJsonValue; use std::time::SystemTime; +pub struct Requester<'a> { + pub prev_event_ids: Vec, + pub room_id: &'a RoomId, + pub content: &'a serde_json::Value, + pub state_key: Option, + pub sender: &'a UserId, +} + #[derive(Clone, Debug, Serialize)] #[serde(untagged)] pub enum StateEvent { @@ -21,6 +29,16 @@ pub enum StateEvent { } impl StateEvent { + pub fn to_requester(&self) -> Requester<'_> { + Requester { + prev_event_ids: self.prev_event_ids(), + room_id: self.room_id().unwrap(), + content: self.content(), + state_key: self.state_key(), + sender: self.sender(), + } + } + pub fn is_power_event(&self) -> bool { match self { Self::Full(any_event) => match any_event { diff --git a/tests/res_with_auth_ids.rs b/tests/res_with_auth_ids.rs index 94e35ba3..f590199b 100644 --- a/tests/res_with_auth_ids.rs +++ b/tests/res_with_auth_ids.rs @@ -28,7 +28,7 @@ static LOGGER: Once = Once::new(); static mut SERVER_TIMESTAMP: i32 = 0; fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: Vec) { - // to activate logging use `RUST_LOG=debug cargo t one_test_only` + // to activate logging use `RUST_LOG=debug cargo t` let _ = LOGGER.call_once(|| { tracer::fmt() .with_env_filter(tracer::EnvFilter::from_default_env())