Make event_map &mut and take fields in membership_change

This commit is contained in:
Devin Ragotzy 2020-12-31 15:53:08 -05:00
parent 94be5b0fef
commit b0ee71e401
7 changed files with 68 additions and 44 deletions

View File

@ -48,11 +48,12 @@ fn resolution_shallow_auth_chain(c: &mut Criterion) {
let (state_at_bob, state_at_charlie, _) = store.set_up(); let (state_at_bob, state_at_charlie, _) = store.set_up();
b.iter(|| { b.iter(|| {
let mut ev_map = state_res::EventMap::default();
let _resolved = match StateResolution::resolve( let _resolved = match StateResolution::resolve(
&room_id(), &room_id(),
&RoomVersionId::Version6, &RoomVersionId::Version6,
&[state_at_bob.clone(), state_at_charlie.clone()], &[state_at_bob.clone(), state_at_charlie.clone()],
None, &mut ev_map,
&store, &store,
) { ) {
Ok(state) => state, Ok(state) => state,
@ -102,7 +103,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) {
&room_id(), &room_id(),
&RoomVersionId::Version6, &RoomVersionId::Version6,
&[state_set_a.clone(), state_set_b.clone()], &[state_set_a.clone(), state_set_b.clone()],
Some(inner.clone()), &mut inner,
&store, &store,
) { ) {
Ok(state) => state, Ok(state) => state,

View File

@ -193,7 +193,9 @@ pub fn auth_check<E: Event>(
} }
if !valid_membership_change( if !valid_membership_change(
incoming_event, incoming_event.state_key().as_deref(),
incoming_event.sender(),
incoming_event.content(),
prev_event, prev_event,
current_third_party_invite, current_third_party_invite,
&auth_events, &auth_events,
@ -271,25 +273,27 @@ pub fn auth_check<E: Event>(
/// this is generated by calling `auth_types_for_event` with the membership event and /// this is generated by calling `auth_types_for_event` with the membership event and
/// the current State. /// the current State.
pub fn valid_membership_change<E: Event>( pub fn valid_membership_change<E: Event>(
user: &Arc<E>, user_state_key: Option<&str>,
user_sender: &UserId,
content: serde_json::Value,
prev_event: Option<Arc<E>>, prev_event: Option<Arc<E>>,
current_third_party_invite: Option<Arc<E>>, current_third_party_invite: Option<Arc<E>>,
auth_events: &StateMap<Arc<E>>, auth_events: &StateMap<Arc<E>>,
) -> Result<bool> { ) -> Result<bool> {
let state_key = if let Some(s) = user.state_key() { let state_key = if let Some(s) = user_state_key {
s s
} else { } else {
return Err(Error::InvalidPdu("State event requires state_key".into())); return Err(Error::InvalidPdu("State event requires state_key".into()));
}; };
let content = serde_json::from_value::<room::member::MemberEventContent>(user.content())?; let content = serde_json::from_value::<room::member::MemberEventContent>(content)?;
let target_membership = content.membership; let target_membership = content.membership;
let target_user_id = UserId::try_from(state_key.as_str()) let target_user_id =
.map_err(|e| Error::ConversionError(format!("{}", e)))?; UserId::try_from(state_key).map_err(|e| Error::ConversionError(format!("{}", e)))?;
let key = (EventType::RoomMember, Some(user.sender().to_string())); let key = (EventType::RoomMember, Some(user_sender.to_string()));
let sender = auth_events.get(&key); let sender = auth_events.get(&key);
let sender_membership = let sender_membership =
sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
@ -333,7 +337,7 @@ pub fn valid_membership_change<E: Event>(
}, },
)?; )?;
let sender_power = power_levels.users.get(&user.sender()).map_or_else( let sender_power = power_levels.users.get(user_sender).map_or_else(
|| { || {
if sender_membership != member::MembershipState::Join { if sender_membership != member::MembershipState::Join {
None None
@ -372,7 +376,7 @@ pub fn valid_membership_change<E: Event>(
} }
Ok(if target_membership == MembershipState::Join { Ok(if target_membership == MembershipState::Join {
if user.sender() != &target_user_id { if user_sender != &target_user_id {
false false
} else if let MembershipState::Ban = current_membership { } else if let MembershipState::Ban = current_membership {
false false
@ -388,7 +392,12 @@ pub fn valid_membership_change<E: Event>(
if current_membership == MembershipState::Ban { if current_membership == MembershipState::Ban {
false false
} else { } else {
verify_third_party_invite(user, &tp_id, current_third_party_invite) verify_third_party_invite(
Some(state_key),
user_sender,
&tp_id,
current_third_party_invite,
)
} }
} else if sender_membership != MembershipState::Join } else if sender_membership != MembershipState::Join
|| current_membership == MembershipState::Join || current_membership == MembershipState::Join
@ -401,7 +410,7 @@ pub fn valid_membership_change<E: Event>(
.is_some() .is_some()
} }
} else if target_membership == MembershipState::Leave { } else if target_membership == MembershipState::Leave {
if user.sender() == &target_user_id { if user_sender == &target_user_id {
current_membership == MembershipState::Join current_membership == MembershipState::Join
|| current_membership == MembershipState::Invite || current_membership == MembershipState::Invite
} else if sender_membership != MembershipState::Join } else if sender_membership != MembershipState::Join
@ -791,14 +800,15 @@ pub fn can_send_invite<E: Event>(event: &Arc<E>, auth_events: &StateMap<Arc<E>>)
} }
pub fn verify_third_party_invite<E: Event>( pub fn verify_third_party_invite<E: Event>(
event: &Arc<E>, user_state_key: Option<&str>,
sender: &UserId,
tp_id: &member::ThirdPartyInvite, tp_id: &member::ThirdPartyInvite,
current_third_party_invite: Option<Arc<E>>, current_third_party_invite: Option<Arc<E>>,
) -> bool { ) -> bool {
// 1. check for user being banned happens before this is called // 1. check for user being banned happens before this is called
// checking for mxid and token keys is done by ruma when deserializing // checking for mxid and token keys is done by ruma when deserializing
if event.state_key() != Some(tp_id.signed.mxid.to_string()) { if user_state_key != Some(tp_id.signed.mxid.as_str()) {
return false; return false;
} }
@ -809,7 +819,7 @@ pub fn verify_third_party_invite<E: Event>(
return false; return false;
} }
if event.sender() != current_tpid.sender() { if sender != current_tpid.sender() {
return false; return false;
} }

View File

@ -46,17 +46,12 @@ impl StateResolution {
room_version: &RoomVersionId, room_version: &RoomVersionId,
incoming_event: Arc<E>, incoming_event: Arc<E>,
current_state: &StateMap<EventId>, current_state: &StateMap<EventId>,
event_map: Option<EventMap<Arc<E>>>, event_map: &mut EventMap<Arc<E>>,
store: &dyn StateStore<E>, store: &dyn StateStore<E>,
) -> Result<bool> { ) -> Result<bool> {
tracing::info!("Applying a single event, state resolution starting"); tracing::info!("Applying a single event, state resolution starting");
let ev = incoming_event; let ev = incoming_event;
let mut event_map = if let Some(ev_map) = event_map {
ev_map
} else {
EventMap::new()
};
let prev_event = if let Some(id) = ev.prev_events().first() { let prev_event = if let Some(id) = ev.prev_events().first() {
store.get_event(room_id, id).ok() store.get_event(room_id, id).ok()
} else { } else {
@ -69,7 +64,7 @@ impl StateResolution {
{ {
if let Some(ev_id) = current_state.get(&key) { if let Some(ev_id) = current_state.get(&key) {
if let Some(event) = if let Some(event) =
StateResolution::get_or_load_event(room_id, ev_id, &mut event_map, store) StateResolution::get_or_load_event(room_id, ev_id, event_map, store)
{ {
// TODO synapse checks `rejected_reason` is None here // TODO synapse checks `rejected_reason` is None here
auth_events.insert(key.clone(), event); auth_events.insert(key.clone(), event);
@ -102,17 +97,11 @@ impl StateResolution {
room_id: &RoomId, room_id: &RoomId,
room_version: &RoomVersionId, room_version: &RoomVersionId,
state_sets: &[StateMap<EventId>], state_sets: &[StateMap<EventId>],
// TODO: make the `Option<&mut EventMap<Arc<ServerPdu>>>` event_map: &mut EventMap<Arc<E>>,
event_map: Option<EventMap<Arc<E>>>,
store: &dyn StateStore<E>, store: &dyn StateStore<E>,
) -> Result<StateMap<EventId>> { ) -> Result<StateMap<EventId>> {
tracing::info!("State resolution starting"); tracing::info!("State resolution starting");
let mut event_map = if let Some(ev_map) = event_map {
ev_map
} else {
EventMap::new()
};
// split non-conflicting and conflicting state // split non-conflicting and conflicting state
let (clean, conflicting) = StateResolution::separate(&state_sets); let (clean, conflicting) = StateResolution::separate(&state_sets);
@ -178,7 +167,7 @@ impl StateResolution {
let mut sorted_control_levels = StateResolution::reverse_topological_power_sort( let mut sorted_control_levels = StateResolution::reverse_topological_power_sort(
room_id, room_id,
&control_events, &control_events,
&mut event_map, event_map,
store, store,
&all_conflicted, &all_conflicted,
); );
@ -197,7 +186,7 @@ impl StateResolution {
room_version, room_version,
&sorted_control_levels, &sorted_control_levels,
&clean, &clean,
&mut event_map, event_map,
store, store,
)?; )?;
@ -238,7 +227,7 @@ impl StateResolution {
room_id, room_id,
&events_to_resolve, &events_to_resolve,
power_event, power_event,
&mut event_map, event_map,
store, store,
); );
@ -255,7 +244,7 @@ impl StateResolution {
room_version, room_version,
&sorted_left_events, &sorted_left_events,
&resolved_control, // The control events are added to the final resolved state &resolved_control, // The control events are added to the final resolved state
&mut event_map, event_map,
store, store,
)?; )?;

View File

@ -36,7 +36,15 @@ fn test_ban_pass() {
&vec![event_id("IMC")], &vec![event_id("IMC")],
); );
assert!(valid_membership_change(&requester, prev, None, &auth_events).unwrap()) assert!(valid_membership_change(
requester.state_key().as_deref(),
requester.sender(),
requester.content(),
prev,
None,
&auth_events
)
.unwrap())
} }
#[test] #[test]
@ -63,5 +71,13 @@ fn test_ban_fail() {
&vec![event_id("IMC")], &vec![event_id("IMC")],
); );
assert!(!valid_membership_change(&requester, prev, None, &auth_events).unwrap()) assert!(!valid_membership_change(
requester.state_key().as_deref(),
requester.sender(),
requester.content(),
prev,
None,
&auth_events
)
.unwrap())
} }

View File

@ -41,11 +41,17 @@ fn ban_with_auth_chains() {
fn base_with_auth_chains() { fn base_with_auth_chains() {
let store = TestStore(INITIAL_EVENTS()); let store = TestStore(INITIAL_EVENTS());
let resolved: BTreeMap<_, EventId> = let mut ev_map = state_res::EventMap::default();
match StateResolution::resolve(&room_id(), &RoomVersionId::Version6, &[], None, &store) { let resolved: BTreeMap<_, EventId> = match StateResolution::resolve(
Ok(state) => state, &room_id(),
Err(e) => panic!("{}", e), &RoomVersionId::Version6,
}; &[],
&mut ev_map,
&store,
) {
Ok(state) => state,
Err(e) => panic!("{}", e),
};
let resolved = resolved let resolved = resolved
.values() .values()
@ -105,11 +111,12 @@ fn ban_with_auth_chains2() {
.map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone())) .map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone()))
.collect::<StateMap<_>>(); .collect::<StateMap<_>>();
let mut ev_map = state_res::EventMap::default();
let resolved: StateMap<EventId> = match StateResolution::resolve( let resolved: StateMap<EventId> = match StateResolution::resolve(
&room_id(), &room_id(),
&RoomVersionId::Version6, &RoomVersionId::Version6,
&[state_set_a, state_set_b], &[state_set_a, state_set_b],
None, &mut ev_map,
&store, &store,
) { ) {
Ok(state) => state, Ok(state) => state,

View File

@ -265,11 +265,12 @@ fn test_event_map_none() {
// build up the DAG // build up the DAG
let (state_at_bob, state_at_charlie, expected) = store.set_up(); let (state_at_bob, state_at_charlie, expected) = store.set_up();
let mut ev_map = state_res::EventMap::default();
let resolved = match StateResolution::resolve( let resolved = match StateResolution::resolve(
&room_id(), &room_id(),
&RoomVersionId::Version2, &RoomVersionId::Version2,
&[state_at_bob, state_at_charlie], &[state_at_bob, state_at_charlie],
None, &mut ev_map,
&store, &store,
) { ) {
Ok(state) => state, Ok(state) => state,

View File

@ -114,7 +114,7 @@ pub fn do_check(
&room_id(), &room_id(),
&RoomVersionId::Version6, &RoomVersionId::Version6,
&state_sets, &state_sets,
Some(event_map.clone()), &mut event_map,
&store, &store,
); );
match resolved { match resolved {