diff --git a/benches/state_res_bench.rs b/benches/state_res_bench.rs index 555677d6..fb25ef0a 100644 --- a/benches/state_res_bench.rs +++ b/benches/state_res_bench.rs @@ -57,7 +57,7 @@ fn resolution_shallow_auth_chain(c: &mut Criterion) { b.iter(|| { let _resolved = match resolver.resolve( &room_id(), - &RoomVersionId::version_2(), + &RoomVersionId::Version2, &[state_at_bob.clone(), state_at_charlie.clone()], None, &store, @@ -91,13 +91,8 @@ fn resolve_deeper_event_set(c: &mut Criterion) { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| { - ( - (ev.kind(), ev.state_key().unwrap()), - ev.event_id().unwrap().clone(), - ) - }) - .collect::>(); + .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().unwrap().clone())) + .collect::>(); let state_set_b = [ inner.get(&event_id("CREATE")).unwrap(), @@ -109,18 +104,13 @@ fn resolve_deeper_event_set(c: &mut Criterion) { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| { - ( - (ev.kind(), ev.state_key().unwrap()), - ev.event_id().unwrap().clone(), - ) - }) - .collect::>(); + .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().unwrap().clone())) + .collect::>(); b.iter(|| { let _resolved = match resolver.resolve( &room_id(), - &RoomVersionId::version_2(), + &RoomVersionId::Version2, &[state_set_a.clone(), state_set_b.clone()], Some(inner.clone()), &store, @@ -302,23 +292,13 @@ impl TestStore { let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] .iter() - .map(|e| { - ( - (e.kind(), e.state_key().unwrap()), - e.event_id().unwrap().clone(), - ) - }) - .collect::>(); + .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .collect::>(); let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] .iter() - .map(|e| { - ( - (e.kind(), e.state_key().unwrap()), - e.event_id().unwrap().clone(), - ) - }) - .collect::>(); + .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .collect::>(); let expected = [ &create_event, @@ -328,13 +308,8 @@ impl TestStore { &charlie_mem, ] .iter() - .map(|e| { - ( - (e.kind(), e.state_key().unwrap()), - e.event_id().unwrap().clone(), - ) - }) - .collect::>(); + .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .collect::>(); (state_at_bob, state_at_charlie, expected) } diff --git a/src/event_auth.rs b/src/event_auth.rs index c5a630eb..af1e68e4 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -22,35 +22,35 @@ pub enum RedactAllowed { No, } -pub fn auth_types_for_event(event: &StateEvent) -> Vec<(EventType, String)> { +pub fn auth_types_for_event(event: &StateEvent) -> Vec<(EventType, Option)> { if event.kind() == EventType::RoomCreate { return vec![]; } let mut auth_types = vec![ - (EventType::RoomPowerLevels, "".to_string()), - (EventType::RoomMember, event.sender().to_string()), - (EventType::RoomCreate, "".to_string()), + (EventType::RoomPowerLevels, Some("".to_string())), + (EventType::RoomMember, Some(event.sender().to_string())), + (EventType::RoomCreate, Some("".to_string())), ]; if event.kind() == EventType::RoomMember { if let Ok(content) = event.deserialize_content::() { if [MembershipState::Join, MembershipState::Invite].contains(&content.membership) { - let key = (EventType::RoomJoinRules, "".into()); + let key = (EventType::RoomJoinRules, Some("".into())); if !auth_types.contains(&key) { auth_types.push(key) } } // TODO what when we don't find a state_key - let key = (EventType::RoomMember, event.state_key().unwrap()); + let key = (EventType::RoomMember, event.state_key()); if !auth_types.contains(&key) { auth_types.push(key) } if content.membership == MembershipState::Invite { if let Some(t_id) = content.third_party_invite { - let key = (EventType::RoomThirdPartyInvite, t_id.signed.token); + let key = (EventType::RoomThirdPartyInvite, Some(t_id.signed.token)); if !auth_types.contains(&key) { auth_types.push(key) } @@ -137,7 +137,7 @@ pub fn auth_check( // 3. If event does not have m.room.create in auth_events reject. if auth_events - .get(&(EventType::RoomCreate, "".into())) + .get(&(EventType::RoomCreate, Some("".into()))) .is_none() { tracing::warn!("no m.room.create event in auth chain"); @@ -238,7 +238,7 @@ pub fn auth_check( // synapse has an `event: &StateEvent` param but it's never used /// Can this room federate based on its m.room.create event. fn can_federate(auth_events: &StateMap) -> bool { - let creation_event = auth_events.get(&(EventType::RoomCreate, "".into())); + let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into()))); if let Some(ev) = creation_event { if let Some(fed) = ev.content().get("m.federate") { fed == "true" @@ -263,7 +263,7 @@ fn is_membership_change_allowed( // check if this is the room creator joining if event.prev_event_ids().len() == 1 && membership == MembershipState::Join { - if let Some(create) = auth_events.get(&(EventType::RoomCreate, "".into())) { + if let Some(create) = auth_events.get(&(EventType::RoomCreate, Some("".into()))) { if let Ok(create_ev) = create.deserialize_content::() { if event.state_key() == Some(create_ev.creator.to_string()) { @@ -283,19 +283,19 @@ fn is_membership_change_allowed( return Some(false); } - let key = (EventType::RoomMember, event.sender().to_string()); + let key = (EventType::RoomMember, Some(event.sender().to_string())); let caller = auth_events.get(&key); let caller_in_room = caller.is_some() && check_membership(caller, MembershipState::Join); let caller_invited = caller.is_some() && check_membership(caller, MembershipState::Invite); - let key = (EventType::RoomMember, target_user_id.to_string()); + let key = (EventType::RoomMember, Some(target_user_id.to_string())); let target = auth_events.get(&key); let target_in_room = target.is_some() && check_membership(target, MembershipState::Join); let target_banned = target.is_some() && check_membership(target, MembershipState::Ban); - let key = (EventType::RoomJoinRules, "".to_string()); + let key = (EventType::RoomJoinRules, Some("".to_string())); let join_rules_event = auth_events.get(&key); let mut join_rule = JoinRule::Invite; @@ -436,7 +436,7 @@ fn check_event_sender_in_room( event: &StateEvent, auth_events: &StateMap, ) -> Option { - let mem = auth_events.get(&(EventType::RoomMember, event.sender().to_string()))?; + let mem = auth_events.get(&(EventType::RoomMember, Some(event.sender().to_string())))?; // TODO this is check_membership a helper fn in synapse but it does this Some( mem.deserialize_content::() @@ -448,7 +448,7 @@ fn check_event_sender_in_room( /// Is the user allowed to send a specific event. fn can_send_event(event: &StateEvent, auth_events: &StateMap) -> Option { - let ple = auth_events.get(&(EventType::RoomPowerLevels, "".into())); + let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))); let send_level = get_send_level(event.kind(), event.state_key(), ple); let user_level = get_user_power_level(event.sender(), auth_events); @@ -480,7 +480,7 @@ fn check_power_levels( ) -> Option { use itertools::Itertools; - let key = (power_event.kind(), power_event.state_key().unwrap()); + let key = (power_event.kind(), power_event.state_key()); let current_state = if let Some(current_state) = auth_events.get(&key) { current_state @@ -629,7 +629,7 @@ fn check_redaction( return Some(RedactAllowed::CanRedact); } - if room_version.is_version_1() { + if let RoomVersionId::Version1 = room_version { if redaction_event.event_id() == redaction_event.redacts() { return Some(RedactAllowed::OwnEvent); } @@ -659,7 +659,7 @@ fn check_membership(member_event: Option<&StateEvent>, state: MembershipState) - } fn get_named_level(auth_events: &StateMap, name: &str, default: i64) -> i64 { - let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, "".into())); + let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))); if let Some(pl) = power_level_event { // TODO do this the right way and deserialize if let Some(level) = pl.content().get(name) { @@ -673,7 +673,7 @@ fn get_named_level(auth_events: &StateMap, name: &str, default: i64) } fn get_user_power_level(user_id: &UserId, auth_events: &StateMap) -> i64 { - if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, "".into())) { + if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))) { if let Ok(content) = pl.deserialize_content::() { if let Some(level) = content.users.get(user_id) { @@ -686,7 +686,7 @@ fn get_user_power_level(user_id: &UserId, auth_events: &StateMap) -> } } else { // if no power level event found the creator gets 100 everyone else gets 0 - let key = (EventType::RoomCreate, "".into()); + let key = (EventType::RoomCreate, Some("".into())); if let Some(create) = auth_events.get(&key) { if let Ok(c) = create.deserialize_content::() { if &c.creator == user_id { diff --git a/src/lib.rs b/src/lib.rs index bcf5ef67..e3085893 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,7 +34,7 @@ pub enum ResolutionResult { } /// A mapping of event type and state_key to some value `T`, usually an `EventId`. -pub type StateMap = BTreeMap<(EventType, String), T>; +pub type StateMap = BTreeMap<(EventType, Option), T>; /// A mapping of `EventId` to `T`, usually a `StateEvent`. pub type EventMap = BTreeMap; @@ -191,7 +191,7 @@ impl StateResolution { .collect::>() ); - let power_event = resolved.get(&(EventType::RoomPowerLevels, "".into())); + let power_event = resolved.get(&(EventType::RoomPowerLevels, Some("".into()))); tracing::debug!("PL {:?}", power_event); @@ -526,7 +526,7 @@ impl StateResolution { if let Some(ev) = self._get_event(room_id, &aid, event_map, store) { // TODO what to do when no state_key is found ?? // TODO synapse check "rejected_reason", I'm guessing this is redacted_because for ruma ?? - auth_events.insert((ev.kind(), ev.state_key().unwrap()), ev); + auth_events.insert((ev.kind(), ev.state_key()), ev); } else { tracing::warn!("auth event id for {} is missing {}", aid, event_id); } @@ -547,7 +547,7 @@ impl StateResolution { .map_err(Error::TempString)? { // add event to resolved state map - resolved_state.insert((event.kind(), event.state_key().unwrap()), event_id.clone()); + resolved_state.insert((event.kind(), event.state_key()), event_id.clone()); } else { // synapse passes here on AuthError. We do not add this event to resolved_state. tracing::warn!( diff --git a/src/room_version.rs b/src/room_version.rs index 4c527713..1cf86b92 100644 --- a/src/room_version.rs +++ b/src/room_version.rs @@ -50,26 +50,20 @@ pub struct RoomVersion { impl RoomVersion { pub fn new(version: &RoomVersionId) -> Self { - if version.is_version_1() { - Self::version_1() - } else if version.is_version_2() { - Self::version_2() - } else if version.is_version_3() { - Self::version_3() - } else if version.is_version_4() { - Self::version_4() - } else if version.is_version_5() { - Self::version_5() - } else if version.is_version_6() { - Self::version_6() - } else { - panic!("this crate needs to be updated with ruma") + match version { + RoomVersionId::Version1 => Self::version_1(), + RoomVersionId::Version2 => Self::version_2(), + RoomVersionId::Version3 => Self::version_3(), + RoomVersionId::Version4 => Self::version_4(), + RoomVersionId::Version5 => Self::version_5(), + RoomVersionId::Version6 => Self::version_6(), + _ => panic!("unspec'ed room version"), } } fn version_1() -> Self { Self { - version: RoomVersionId::version_1(), + version: RoomVersionId::Version1, disposition: RoomDisposition::Stable, event_format: EventFormatVersion::V1, state_res: StateResolutionVersion::V1, @@ -82,7 +76,7 @@ impl RoomVersion { fn version_2() -> Self { Self { - version: RoomVersionId::version_2(), + version: RoomVersionId::Version2, disposition: RoomDisposition::Stable, event_format: EventFormatVersion::V1, state_res: StateResolutionVersion::V2, @@ -95,7 +89,7 @@ impl RoomVersion { fn version_3() -> Self { Self { - version: RoomVersionId::version_3(), + version: RoomVersionId::Version3, disposition: RoomDisposition::Stable, event_format: EventFormatVersion::V2, state_res: StateResolutionVersion::V2, @@ -108,7 +102,7 @@ impl RoomVersion { fn version_4() -> Self { Self { - version: RoomVersionId::version_4(), + version: RoomVersionId::Version4, disposition: RoomDisposition::Stable, event_format: EventFormatVersion::V3, state_res: StateResolutionVersion::V2, @@ -121,7 +115,7 @@ impl RoomVersion { fn version_5() -> Self { Self { - version: RoomVersionId::version_5(), + version: RoomVersionId::Version5, disposition: RoomDisposition::Stable, event_format: EventFormatVersion::V3, state_res: StateResolutionVersion::V2, @@ -134,7 +128,7 @@ impl RoomVersion { fn version_6() -> Self { Self { - version: RoomVersionId::version_6(), + version: RoomVersionId::Version6, disposition: RoomDisposition::Stable, event_format: EventFormatVersion::V3, state_res: StateResolutionVersion::V2, diff --git a/src/state_event.rs b/src/state_event.rs index 0f8e7154..751098f8 100644 --- a/src/state_event.rs +++ b/src/state_event.rs @@ -101,6 +101,7 @@ impl StateEvent { } } pub fn event_id(&self) -> Option<&EventId> { + println!("{:?}", self); match self { Self::Full(ev) => match ev { Pdu::RoomV1Pdu(ev) => Some(&ev.event_id), diff --git a/tests/res_with_auth_ids.rs b/tests/res_with_auth_ids.rs index be1d8d7e..6d374a5b 100644 --- a/tests/res_with_auth_ids.rs +++ b/tests/res_with_auth_ids.rs @@ -102,14 +102,14 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: .iter() .map(|map| map .iter() - .map(|((ty, key), id)| format!("(({}{}), {})", ty, key, id)) + .map(|((ty, key), id)| format!("(({}{:?}), {})", ty, key, id)) .collect::>()) .collect::>() ); let resolved = resolver.resolve( &room_id(), - &RoomVersionId::version_1(), + &RoomVersionId::Version1, &state_sets, Some(event_map.clone()), &store, @@ -135,7 +135,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: if fake_event.state_key().is_some() { let ty = fake_event.kind().clone(); // we know there is a state_key unwrap OK - let key = fake_event.state_key().unwrap().clone(); + let key = fake_event.state_key().clone(); state_after.insert((ty, key), event_id.clone()); } @@ -173,7 +173,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: event_map.insert(event_id.clone(), event); } - let mut expected_state = BTreeMap::new(); + let mut expected_state = StateMap::new(); for node in expected_state_ids { let ev = event_map.get(&node).expect(&format!( "{} not found in {:?}", @@ -184,7 +184,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: .collect::>(), )); - let key = (ev.kind(), ev.state_key().unwrap()); + let key = (ev.kind(), ev.state_key()); expected_state.insert(key, node); } @@ -573,7 +573,7 @@ fn base_with_auth_chains() { let store = TestStore(RefCell::new(INITIAL_EVENTS())); let resolved: BTreeMap<_, EventId> = - match resolver.resolve(&room_id(), &RoomVersionId::version_2(), &[], None, &store) { + match resolver.resolve(&room_id(), &RoomVersionId::Version2, &[], None, &store) { Ok(ResolutionResult::Resolved(state)) => state, Err(e) => panic!("{}", e), _ => panic!("conflicted state left"), @@ -627,12 +627,7 @@ fn ban_with_auth_chains2() { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| { - ( - (ev.kind(), ev.state_key().unwrap()), - ev.event_id().unwrap().clone(), - ) - }) + .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().unwrap().clone())) .collect::>(); let state_set_b = [ @@ -645,17 +640,12 @@ fn ban_with_auth_chains2() { inner.get(&event_id("PA")).unwrap(), ] .iter() - .map(|ev| { - ( - (ev.kind(), ev.state_key().unwrap()), - ev.event_id().unwrap().clone(), - ) - }) - .collect::>(); + .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().unwrap().clone())) + .collect::>(); - let resolved: BTreeMap<_, EventId> = match resolver.resolve( + let resolved: StateMap = match resolver.resolve( &room_id(), - &RoomVersionId::version_2(), + &RoomVersionId::Version2, &[state_set_a, state_set_b], None, &store, @@ -669,7 +659,7 @@ fn ban_with_auth_chains2() { "{:#?}", resolved .iter() - .map(|((ty, key), id)| format!("(({}{}), {})", ty, key, id)) + .map(|((ty, key), id)| format!("(({}{:?}), {})", ty, key, id)) .collect::>() ); diff --git a/tests/state_res.rs b/tests/state_res.rs index ab08b0cc..d62bc132 100644 --- a/tests/state_res.rs +++ b/tests/state_res.rs @@ -346,14 +346,14 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: .iter() .map(|map| map .iter() - .map(|((ty, key), id)| format!("(({}{}), {})", ty, key, id)) + .map(|((ty, key), id)| format!("(({}{:?}), {})", ty, key, id)) .collect::>()) .collect::>() ); let resolved = resolver.resolve( &room_id(), - &RoomVersionId::version_1(), + &RoomVersionId::Version1, &state_sets, Some(event_map.clone()), &store, @@ -379,7 +379,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: if fake_event.state_key().is_some() { let ty = fake_event.kind().clone(); // we know there is a state_key unwrap OK - let key = fake_event.state_key().unwrap().clone(); + let key = fake_event.state_key().clone(); state_after.insert((ty, key), event_id.clone()); } @@ -417,7 +417,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: event_map.insert(event_id.clone(), event); } - let mut expected_state = BTreeMap::new(); + let mut expected_state = StateMap::new(); for node in expected_state_ids { let ev = event_map.get(&node).expect(&format!( "{} not found in {:?}", @@ -428,7 +428,7 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: .collect::>(), )); - let key = (ev.kind(), ev.state_key().unwrap()); + let key = (ev.kind(), ev.state_key()); expected_state.insert(key, node); } @@ -700,7 +700,7 @@ fn test_event_map_none() { let resolved = match resolver.resolve( &room_id(), - &RoomVersionId::version_2(), + &RoomVersionId::Version2, &[state_at_bob, state_at_charlie], None, &store, @@ -904,23 +904,13 @@ impl TestStore { let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] .iter() - .map(|e| { - ( - (e.kind(), e.state_key().unwrap()), - e.event_id().unwrap().clone(), - ) - }) - .collect::>(); + .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .collect::>(); let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] .iter() - .map(|e| { - ( - (e.kind(), e.state_key().unwrap()), - e.event_id().unwrap().clone(), - ) - }) - .collect::>(); + .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .collect::>(); let expected = [ &create_event, @@ -930,13 +920,8 @@ impl TestStore { &charlie_mem, ] .iter() - .map(|e| { - ( - (e.kind(), e.state_key().unwrap()), - e.event_id().unwrap().clone(), - ) - }) - .collect::>(); + .map(|e| ((e.kind(), e.state_key()), e.event_id().unwrap().clone())) + .collect::>(); (state_at_bob, state_at_charlie, expected) }