diff --git a/benches/state_res_bench.rs b/benches/state_res_bench.rs index b543de02..dc8696c6 100644 --- a/benches/state_res_bench.rs +++ b/benches/state_res_bench.rs @@ -90,7 +90,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone())) + .map(|ev| ((ev.kind(), ev.state_key().unwrap()), ev.event_id().clone())) .collect::>(); let state_set_b = [ @@ -103,7 +103,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone())) + .map(|ev| ((ev.kind(), ev.state_key().unwrap()), ev.event_id().clone())) .collect::>(); b.iter(|| { @@ -220,12 +220,12 @@ impl TestStore { let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] .iter() - .map(|e| ((e.kind(), e.state_key()), e.event_id().clone())) + .map(|e| ((e.kind(), e.state_key().unwrap()), e.event_id().clone())) .collect::>(); let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] .iter() - .map(|e| ((e.kind(), e.state_key()), e.event_id().clone())) + .map(|e| ((e.kind(), e.state_key().unwrap()), e.event_id().clone())) .collect::>(); let expected = [ @@ -236,7 +236,7 @@ impl TestStore { &charlie_mem, ] .iter() - .map(|e| ((e.kind(), e.state_key()), e.event_id().clone())) + .map(|e| ((e.kind(), e.state_key().unwrap()), e.event_id().clone())) .collect::>(); (state_at_bob, state_at_charlie, expected) diff --git a/src/event_auth.rs b/src/event_auth.rs index e3b6df2e..4f2d95bf 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -23,36 +23,39 @@ pub fn auth_types_for_event( sender: &UserId, state_key: Option, content: serde_json::Value, -) -> Vec<(EventType, Option)> { +) -> Vec<(EventType, String)> { if kind == &EventType::RoomCreate { return vec![]; } let mut auth_types = vec![ - (EventType::RoomPowerLevels, Some("".to_string())), - (EventType::RoomMember, Some(sender.to_string())), - (EventType::RoomCreate, Some("".to_string())), + (EventType::RoomPowerLevels, "".to_string()), + (EventType::RoomMember, sender.to_string()), + (EventType::RoomCreate, "".to_string()), ]; if kind == &EventType::RoomMember { - if let Ok(content) = serde_json::from_value::(content) { - if [MembershipState::Join, MembershipState::Invite].contains(&content.membership) { - let key = (EventType::RoomJoinRules, Some("".into())); + if let Some(state_key) = state_key { + if let Ok(content) = serde_json::from_value::(content) + { + if [MembershipState::Join, MembershipState::Invite].contains(&content.membership) { + let key = (EventType::RoomJoinRules, "".to_string()); + if !auth_types.contains(&key) { + auth_types.push(key) + } + } + + let key = (EventType::RoomMember, state_key); if !auth_types.contains(&key) { auth_types.push(key) } - } - let key = (EventType::RoomMember, state_key); - if !auth_types.contains(&key) { - auth_types.push(key) - } - - if content.membership == MembershipState::Invite { - if let Some(t_id) = content.third_party_invite { - let key = (EventType::RoomThirdPartyInvite, Some(t_id.signed.token)); - if !auth_types.contains(&key) { - auth_types.push(key) + if content.membership == MembershipState::Invite { + if let Some(t_id) = content.third_party_invite { + let key = (EventType::RoomThirdPartyInvite, t_id.signed.token); + if !auth_types.contains(&key) { + auth_types.push(key) + } } } } @@ -157,7 +160,7 @@ pub fn auth_check( // 3. If event does not have m.room.create in auth_events reject if auth_events - .get(&(EventType::RoomCreate, Some("".into()))) + .get(&(EventType::RoomCreate, "".to_string())) .is_none() { log::warn!("no m.room.create event in auth chain"); @@ -172,12 +175,6 @@ pub fn auth_check( log::info!("starting m.room.aliases check"); // [synapse] adds `&& room_version` "special case aliases auth" - // [synapse] - // if event.state_key.unwrap().is_empty() { - // log::warn!("state_key must be non-empty"); - // return Ok(false); // and be non-empty state_key (point to a user_id) - // } - // If sender's domain doesn't matches state_key, reject if incoming_event.state_key() != Some(incoming_event.sender().server_name().to_string()) { log::warn!("state_key does not match sender"); @@ -190,6 +187,13 @@ pub fn auth_check( if incoming_event.kind() == EventType::RoomMember { log::info!("starting m.room.member check"); + let state_key = match incoming_event.state_key() { + None => { + log::warn!("no statekey in member event"); + return Ok(false); + } + Some(s) => s, + }; if serde_json::from_value::(incoming_event.content()) .is_err() @@ -199,7 +203,7 @@ pub fn auth_check( } if !valid_membership_change( - incoming_event.state_key().as_deref(), + &state_key, incoming_event.sender(), incoming_event.content(), prev_event, @@ -279,19 +283,13 @@ pub fn auth_check( /// this is generated by calling `auth_types_for_event` with the membership event and /// the current State. pub fn valid_membership_change( - user_state_key: Option<&str>, + state_key: &str, user_sender: &UserId, content: serde_json::Value, prev_event: Option>, current_third_party_invite: Option>, auth_events: &StateMap>, ) -> Result { - let state_key = if let Some(s) = user_state_key { - s - } else { - return Err(Error::InvalidPdu("State event requires state_key".into())); - }; - let content = serde_json::from_value::(content)?; let target_membership = content.membership; @@ -299,7 +297,7 @@ pub fn valid_membership_change( let target_user_id = UserId::try_from(state_key).map_err(|e| Error::ConversionError(format!("{}", e)))?; - let key = (EventType::RoomMember, Some(user_sender.to_string())); + let key = (EventType::RoomMember, user_sender.to_string()); let sender = auth_events.get(&key); let sender_membership = sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { @@ -309,7 +307,7 @@ pub fn valid_membership_change( ) })?; - let key = (EventType::RoomMember, Some(target_user_id.to_string())); + let key = (EventType::RoomMember, target_user_id.to_string()); let current = auth_events.get(&key); let current_membership = current.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { @@ -319,7 +317,7 @@ pub fn valid_membership_change( ) })?; - let key = (EventType::RoomPowerLevels, Some("".into())); + let key = (EventType::RoomPowerLevels, "".into()); let power_levels = auth_events.get(&key).map_or_else( || Ok::<_, Error>(power_levels::PowerLevelsEventContent::default()), |power_levels| { @@ -351,7 +349,7 @@ pub fn valid_membership_change( Some, ); - let key = (EventType::RoomJoinRules, Some("".into())); + let key = (EventType::RoomJoinRules, "".into()); let join_rules_event = auth_events.get(&key); let mut join_rules = JoinRule::Invite; if let Some(jr) = join_rules_event { @@ -430,7 +428,7 @@ pub fn check_event_sender_in_room( sender: &UserId, auth_events: &StateMap>, ) -> Option { - let mem = auth_events.get(&(EventType::RoomMember, Some(sender.to_string())))?; + let mem = auth_events.get(&(EventType::RoomMember, sender.to_string()))?; Some( serde_json::from_value::(mem.content()) .ok()? @@ -442,7 +440,7 @@ pub fn check_event_sender_in_room( /// Is the user allowed to send a specific event based on the rooms power levels. Does the event /// have the correct userId as it's state_key if it's not the "" state_key. pub fn can_send_event(event: &Arc, auth_events: &StateMap>) -> bool { - let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))); + let ple = auth_events.get(&(EventType::RoomPowerLevels, "".into())); let event_type_power_level = get_send_level(&event.kind(), event.state_key(), ple); let user_level = get_user_power_level(&event.sender(), auth_events); @@ -476,7 +474,10 @@ pub fn check_power_levels( power_event: &Arc, auth_events: &StateMap>, ) -> Option { - let key = (power_event.kind(), power_event.state_key()); + let power_event_state_key = power_event + .state_key() + .expect("power events have state keys"); + let key = (power_event.kind(), power_event_state_key); let current_state = if let Some(current_state) = auth_events.get(&key) { current_state } else { @@ -620,7 +621,7 @@ fn get_deserialize_levels( /// Does the event redacting come from a user with enough power to redact the given event. pub fn check_redaction( - room_version: &RoomVersionId, + _room_version: &RoomVersionId, redaction_event: &Arc, auth_events: &StateMap>, ) -> Result { @@ -639,18 +640,15 @@ pub fn check_redaction( // Servers should only apply redaction's to events where the sender's domains match, // or the sender of the redaction has the appropriate permissions per the power levels. - // version 1 check - if let RoomVersionId::Version1 = room_version { - // If the domain of the event_id of the event being redacted is the same as the domain of the event_id of the m.room.redaction, allow - if redaction_event.event_id().server_name() - == redaction_event - .redacts() - .as_ref() - .and_then(|id| id.server_name()) - { - log::info!("redaction event allowed via room version 1 rules"); - return Ok(true); - } + // If the domain of the event_id of the event being redacted is the same as the domain of the event_id of the m.room.redaction, allow + if redaction_event.event_id().server_name() + == redaction_event + .redacts() + .as_ref() + .and_then(|id| id.server_name()) + { + log::info!("redaction event allowed via room version 1 rules"); + return Ok(true); } Ok(false) @@ -675,7 +673,7 @@ pub fn check_membership(member_event: Option>, state: Membershi /// Can this room federate based on its m.room.create event. pub fn can_federate(auth_events: &StateMap>) -> bool { - let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into()))); + let creation_event = auth_events.get(&(EventType::RoomCreate, "".into())); if let Some(ev) = creation_event { if let Some(fed) = ev.content().get("m.federate") { fed == "true" @@ -690,7 +688,7 @@ pub fn can_federate(auth_events: &StateMap>) -> bool { /// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content. /// or return `default` if no power level event is found or zero if no field matches `name`. pub fn get_named_level(auth_events: &StateMap>, name: &str, default: i64) -> i64 { - let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))); + let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, "".into())); if let Some(pl) = power_level_event { // TODO do this the right way and deserialize if let Some(level) = pl.content().get(name) { @@ -706,7 +704,7 @@ pub fn get_named_level(auth_events: &StateMap>, name: &str, def /// Helper function to fetch a users default power level from a "m.room.power_level" event's `users` /// object. pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap>) -> i64 { - if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))) { + if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, "".into())) { if let Ok(content) = serde_json::from_value::(pl.content()) { @@ -720,7 +718,7 @@ pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap(create.content()) @@ -758,7 +756,8 @@ pub fn get_send_level( content.events_default } }) - }).ok() + }) + .ok() }) .map(|int| i64::from(int)) .unwrap_or_else(|| if state_key.is_some() { 50 } else { 0 }) @@ -767,7 +766,7 @@ pub fn get_send_level( /// Check user can send invite. pub fn can_send_invite(event: &Arc, auth_events: &StateMap>) -> Result { let user_level = get_user_power_level(&event.sender(), auth_events); - let key = (EventType::RoomPowerLevels, Some("".into())); + let key = (EventType::RoomPowerLevels, "".into()); let invite_level = auth_events .get(&key) .map_or_else( diff --git a/src/lib.rs b/src/lib.rs index 99080d75..f41a0de7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,7 @@ pub use state_store::StateStore; const _YIELD_AFTER_ITERATIONS: usize = 100; /// A mapping of event type and state_key to some value `T`, usually an `EventId`. -pub type StateMap = BTreeMap<(EventType, Option), T>; +pub type StateMap = BTreeMap<(EventType, String), T>; /// A mapping of `EventId` to `T`, usually a `ServerPdu`. pub type EventMap = BTreeMap; @@ -48,6 +48,8 @@ impl StateResolution { current_state: &StateMap, event_map: &EventMap>, ) -> Result { + let state_key = incoming_event.state_key().ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; + log::info!("Applying a single event, state resolution starting"); let ev = incoming_event; @@ -59,7 +61,7 @@ impl StateResolution { let mut auth_events = StateMap::new(); for key in - event_auth::auth_types_for_event(&ev.kind(), &ev.sender(), ev.state_key(), ev.content()) + event_auth::auth_types_for_event(&ev.kind(), &ev.sender(), Some(state_key), ev.content()) { if let Some(ev_id) = current_state.get(&key) { if let Ok(event) = StateResolution::get_or_load_event(room_id, ev_id, event_map) { @@ -183,7 +185,7 @@ impl StateResolution { ); // This "epochs" power level event - let power_event = resolved_control.get(&(EventType::RoomPowerLevels, Some("".into()))); + let power_event = resolved_control.get(&(EventType::RoomPowerLevels, "".into())); log::debug!("PL {:?}", power_event); @@ -498,12 +500,13 @@ impl StateResolution { for (idx, event_id) in events_to_check.iter().enumerate() { let event = StateResolution::get_or_load_event(room_id, event_id, event_map)?; + let state_key = event.state_key().ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; let mut auth_events = BTreeMap::new(); for aid in &event.auth_events() { if let Ok(ev) = StateResolution::get_or_load_event(room_id, &aid, event_map) { // TODO synapse check "rejected_reason", I'm guessing this is redacted_because in ruma ?? - auth_events.insert((ev.kind(), ev.state_key()), ev); + auth_events.insert((ev.kind(), state_key.clone()), ev); } else { log::warn!("auth event id for {} is missing {}", aid, event_id); } @@ -512,7 +515,7 @@ impl StateResolution { for key in event_auth::auth_types_for_event( &event.kind(), &event.sender(), - event.state_key(), + Some(state_key.clone()), event.content(), ) { if let Some(ev_id) = resolved_state.get(&key) { @@ -550,7 +553,7 @@ impl StateResolution { current_third_party, )? { // add event to resolved state map - resolved_state.insert((event.kind(), event.state_key()), event_id.clone()); + resolved_state.insert((event.kind(), state_key), event_id.clone()); } else { // synapse passes here on AuthError. We do not add this event to resolved_state. log::warn!( diff --git a/tests/event_auth.rs b/tests/event_auth.rs index 4d9c6fe7..46e14758 100644 --- a/tests/event_auth.rs +++ b/tests/event_auth.rs @@ -37,7 +37,7 @@ fn test_ban_pass() { ); assert!(valid_membership_change( - requester.state_key().as_deref(), + &requester.state_key(), requester.sender(), requester.content(), prev, @@ -72,7 +72,7 @@ fn test_ban_fail() { ); assert!(!valid_membership_change( - requester.state_key().as_deref(), + &requester.state_key(), requester.sender(), requester.content(), prev, diff --git a/tests/event_sorting.rs b/tests/event_sorting.rs index 78688322..a1a56904 100644 --- a/tests/event_sorting.rs +++ b/tests/event_sorting.rs @@ -61,7 +61,7 @@ fn test_event_sort() { shuffle(&mut events_to_sort); - let power_level = resolved_power.get(&(EventType::RoomPowerLevels, Some("".to_string()))); + let power_level = resolved_power.get(&(EventType::RoomPowerLevels, "".to_string())); let sorted_event_ids = state_res::StateResolution::mainline_sort( &room_id(), diff --git a/tests/utils.rs b/tests/utils.rs index 7423eb0e..3d365092 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -132,16 +132,14 @@ pub fn do_check( let mut state_after = state_before.clone(); - if fake_event.state_key().is_some() { - let ty = fake_event.kind(); - let key = fake_event.state_key(); - state_after.insert((ty, key), event_id.clone()); - } + let ty = fake_event.kind(); + let key = fake_event.state_key(); + state_after.insert((ty, key), event_id.clone()); let auth_types = state_res::auth_types_for_event( &fake_event.kind(), fake_event.sender(), - fake_event.state_key(), + Some(fake_event.state_key()), fake_event.content(), ); @@ -160,7 +158,7 @@ pub fn do_check( e.event_id().as_str(), e.sender().clone(), e.kind().clone(), - e.state_key().as_deref(), + Some(&e.state_key()), e.content(), &auth_events, prev_events, @@ -555,7 +553,7 @@ pub mod event { } fn state_key(&self) -> Option { - self.state_key() + Some(self.state_key()) } fn prev_events(&self) -> Vec { self.prev_event_ids() @@ -796,11 +794,11 @@ pub mod event { }, } } - pub fn state_key(&self) -> Option { + pub fn state_key(&self) -> String { match self { Self::Full(_, ev) => match ev { - Pdu::RoomV1Pdu(ev) => ev.state_key.clone(), - Pdu::RoomV3Pdu(ev) => ev.state_key.clone(), + Pdu::RoomV1Pdu(ev) => ev.state_key.clone().unwrap(), + Pdu::RoomV3Pdu(ev) => ev.state_key.clone().unwrap(), }, } }