diff --git a/src/event_auth.rs b/src/event_auth.rs index e72c3c2d..7a62b6a2 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -26,35 +26,37 @@ pub fn auth_types_for_event( sender: &UserId, state_key: Option, content: serde_json::Value, -) -> Vec<(EventType, Option)> { +) -> Vec<(EventType, String)> { if kind == EventType::RoomCreate { return vec![]; } let mut auth_types = vec![ - (EventType::RoomPowerLevels, Some("".to_string())), - (EventType::RoomMember, Some(sender.to_string())), - (EventType::RoomCreate, Some("".to_string())), + (EventType::RoomPowerLevels, "".to_string()), + (EventType::RoomMember, sender.to_string()), + (EventType::RoomCreate, "".to_string()), ]; if kind == EventType::RoomMember { if let Ok(content) = serde_json::from_value::(content) { if [MembershipState::Join, MembershipState::Invite].contains(&content.membership) { - let key = (EventType::RoomJoinRules, Some("".into())); + let key = (EventType::RoomJoinRules, "".into()); if !auth_types.contains(&key) { auth_types.push(key) } } // TODO what when we don't find a state_key - let key = (EventType::RoomMember, state_key); - if !auth_types.contains(&key) { - auth_types.push(key) + if let Some(state_key) = state_key { + let key = (EventType::RoomMember, state_key); + if !auth_types.contains(&key) { + auth_types.push(key) + } } if content.membership == MembershipState::Invite { if let Some(t_id) = content.third_party_invite { - let key = (EventType::RoomThirdPartyInvite, Some(t_id.signed.token)); + let key = (EventType::RoomThirdPartyInvite, t_id.signed.token); if !auth_types.contains(&key) { auth_types.push(key) } @@ -156,7 +158,7 @@ pub fn auth_check( // 3. If event does not have m.room.create in auth_events reject if auth_events - .get(&(EventType::RoomCreate, Some("".into()))) + .get(&(EventType::RoomCreate, "".into())) .is_none() { tracing::warn!("no m.room.create event in auth chain"); @@ -170,10 +172,6 @@ pub fn auth_check( if incoming_event.kind() == EventType::RoomAliases { tracing::info!("starting m.room.aliases check"); // [synapse] adds `&& room_version` "special case aliases auth" - if incoming_event.state_key().is_none() { - tracing::warn!("no state_key field found for event"); - return Ok(false); // must have state_key - } // [synapse] // if event.state_key().unwrap().is_empty() { @@ -182,8 +180,8 @@ pub fn auth_check( // } // If sender's domain doesn't matches state_key, reject - if incoming_event.state_key().as_deref() - != Some(incoming_event.sender().server_name().as_str()) + if incoming_event.state_key() + != incoming_event.sender().server_name().as_str() { tracing::warn!("state_key does not match sender"); return Ok(false); @@ -196,11 +194,6 @@ pub fn auth_check( if incoming_event.kind() == EventType::RoomMember { tracing::info!("starting m.room.member check"); - if incoming_event.state_key().is_none() { - tracing::warn!("no state_key found for m.room.member event"); - return Ok(false); - } - if incoming_event .deserialize_content::() .is_err() @@ -246,7 +239,7 @@ pub fn auth_check( // If the event type's required power level is greater than the sender's power level, reject // If the event has a state_key that starts with an @ and does not match the sender, reject. - if !can_send_event(&incoming_event, &auth_events)? { + if !can_send_event(&incoming_event, &auth_events) { tracing::warn!("user cannot send event"); return Ok(false); } @@ -307,7 +300,7 @@ pub fn valid_membership_change( let target_user_id = UserId::try_from(state_key.as_str()) .map_err(|e| Error::ConversionError(format!("{}", e)))?; - let key = (EventType::RoomMember, Some(user.sender.to_string())); + let key = (EventType::RoomMember, user.sender.to_string()); let sender = auth_events.get(&key); let sender_membership = sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { @@ -316,7 +309,7 @@ pub fn valid_membership_change( .membership) })?; - let key = (EventType::RoomMember, Some(target_user_id.to_string())); + let key = (EventType::RoomMember, target_user_id.to_string()); let current = auth_events.get(&key); let current_membership = current.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { @@ -325,7 +318,7 @@ pub fn valid_membership_change( .membership) })?; - let key = (EventType::RoomPowerLevels, Some("".into())); + let key = (EventType::RoomPowerLevels, "".into()); let power_levels = auth_events.get(&key).map_or_else( || { Ok::<_, Error>(power_levels::PowerLevelsEventContent { @@ -373,7 +366,7 @@ pub fn valid_membership_change( Some, ); - let key = (EventType::RoomJoinRules, Some("".to_string())); + let key = (EventType::RoomJoinRules, "".to_string()); let join_rules_event = auth_events.get(&key); let mut join_rules = JoinRule::Invite; if let Some(jr) = join_rules_event { @@ -447,7 +440,7 @@ pub fn check_event_sender_in_room( sender: &UserId, auth_events: &StateMap>, ) -> Option { - let mem = auth_events.get(&(EventType::RoomMember, Some(sender.to_string())))?; + let mem = auth_events.get(&(EventType::RoomMember, sender.to_string()))?; Some( mem.deserialize_content::() .ok()? @@ -458,10 +451,10 @@ pub fn check_event_sender_in_room( /// Is the user allowed to send a specific event based on the rooms power levels. Does the event /// have the correct userId as it's state_key if it's not the "" state_key. -pub fn can_send_event(event: &Arc, auth_events: &StateMap>) -> Result { - let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))); +pub fn can_send_event(event: &Arc, auth_events: &StateMap>) -> bool { + let ple = auth_events.get(&(EventType::RoomPowerLevels, "".into())); - let event_type_power_level = get_send_level(event.kind(), event.state_key(), ple); + let event_type_power_level = get_send_level(event.kind(), Some(event.state_key()), ple); let user_level = get_user_power_level(event.sender(), auth_events); tracing::debug!( @@ -472,15 +465,14 @@ pub fn can_send_event(event: &Arc, auth_events: &StateMap>, state: Membership /// Can this room federate based on its m.room.create event. pub fn can_federate(auth_events: &StateMap>) -> bool { - let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into()))); + let creation_event = auth_events.get(&(EventType::RoomCreate, "".into())); if let Some(ev) = creation_event { if let Some(fed) = ev.content().get("m.federate") { fed == "true" @@ -698,7 +690,7 @@ pub fn can_federate(auth_events: &StateMap>) -> bool { /// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content. /// or return `default` if no power level event is found or zero if no field matches `name`. pub fn get_named_level(auth_events: &StateMap>, name: &str, default: i64) -> i64 { - let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))); + let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, "".into())); if let Some(pl) = power_level_event { // TODO do this the right way and deserialize if let Some(level) = pl.content().get(name) { @@ -714,7 +706,7 @@ pub fn get_named_level(auth_events: &StateMap>, name: &str, defa /// Helper function to fetch a users default power level from a "m.room.power_level" event's `users` /// object. pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap>) -> i64 { - if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))) { + if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, "".into())) { if let Ok(content) = pl.deserialize_content::() { if let Some(level) = content.users.get(user_id) { @@ -727,7 +719,7 @@ pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap() { if &c.creator == user_id { @@ -779,7 +771,7 @@ pub fn get_send_level( /// Check user can send invite. pub fn can_send_invite(event: &Requester<'_>, auth_events: &StateMap>) -> Result { let user_level = get_user_power_level(event.sender, auth_events); - let key = (EventType::RoomPowerLevels, Some("".into())); + let key = (EventType::RoomPowerLevels, "".into()); let invite_level = auth_events .get(&key) .map_or_else( @@ -811,7 +803,7 @@ pub fn verify_third_party_invite( // If there is no m.room.third_party_invite event in the current room state // with state_key matching token, reject if let Some(current_tpid) = current_third_party_invite { - if current_tpid.state_key() != Some(tp_id.signed.token.to_string()) { + if current_tpid.state_key() != tp_id.signed.token.to_string() { return false; } diff --git a/src/lib.rs b/src/lib.rs index 6004c3a8..3658640c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,7 @@ pub use state_store::StateStore; const _YIELD_AFTER_ITERATIONS: usize = 100; /// A mapping of event type and state_key to some value `T`, usually an `EventId`. -pub type StateMap = BTreeMap<(EventType, Option), T>; +pub type StateMap = BTreeMap<(EventType, String), T>; /// A mapping of `EventId` to `T`, usually a `StateEvent`. pub type EventMap = BTreeMap; @@ -185,7 +185,7 @@ impl StateResolution { .collect::>() ); - let power_event = resolved_control.get(&(EventType::RoomPowerLevels, Some("".into()))); + let power_event = resolved_control.get(&(EventType::RoomPowerLevels, "".into())); tracing::debug!("PL {:?}", power_event); @@ -512,7 +512,7 @@ impl StateResolution { for key in event_auth::auth_types_for_event( event.kind(), event.sender(), - event.state_key(), + Some(event.state_key()), event.content().clone(), ) { if let Some(ev_id) = resolved_state.get(&key) { diff --git a/src/state_event.rs b/src/state_event.rs index 71032b29..bfbf2c72 100644 --- a/src/state_event.rs +++ b/src/state_event.rs @@ -34,7 +34,7 @@ impl StateEvent { prev_event_ids: self.prev_event_ids(), room_id: self.room_id().unwrap(), content: self.content(), - state_key: self.state_key(), + state_key: Some(self.state_key()), sender: self.sender(), } } @@ -175,7 +175,7 @@ impl StateEvent { }, } } - pub fn state_key(&self) -> Option { + pub fn state_key(&self) -> String { match self { Self::Full(ev) => match ev { Pdu::RoomV1Pdu(ev) => ev.state_key.clone(), @@ -185,7 +185,7 @@ impl StateEvent { PduStub::RoomV1PduStub(ev) => ev.state_key.clone(), PduStub::RoomV3PduStub(ev) => ev.state_key.clone(), }, - } + }.expect("All state events have a state key") } #[cfg(not(feature = "unstable-pre-spec"))]