Make event_map &mut and take fields in membership_change
This commit is contained in:
parent
94be5b0fef
commit
b0ee71e401
@ -48,11 +48,12 @@ fn resolution_shallow_auth_chain(c: &mut Criterion) {
|
|||||||
let (state_at_bob, state_at_charlie, _) = store.set_up();
|
let (state_at_bob, state_at_charlie, _) = store.set_up();
|
||||||
|
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
|
let mut ev_map = state_res::EventMap::default();
|
||||||
let _resolved = match StateResolution::resolve(
|
let _resolved = match StateResolution::resolve(
|
||||||
&room_id(),
|
&room_id(),
|
||||||
&RoomVersionId::Version6,
|
&RoomVersionId::Version6,
|
||||||
&[state_at_bob.clone(), state_at_charlie.clone()],
|
&[state_at_bob.clone(), state_at_charlie.clone()],
|
||||||
None,
|
&mut ev_map,
|
||||||
&store,
|
&store,
|
||||||
) {
|
) {
|
||||||
Ok(state) => state,
|
Ok(state) => state,
|
||||||
@ -102,7 +103,7 @@ fn resolve_deeper_event_set(c: &mut Criterion) {
|
|||||||
&room_id(),
|
&room_id(),
|
||||||
&RoomVersionId::Version6,
|
&RoomVersionId::Version6,
|
||||||
&[state_set_a.clone(), state_set_b.clone()],
|
&[state_set_a.clone(), state_set_b.clone()],
|
||||||
Some(inner.clone()),
|
&mut inner,
|
||||||
&store,
|
&store,
|
||||||
) {
|
) {
|
||||||
Ok(state) => state,
|
Ok(state) => state,
|
||||||
|
@ -193,7 +193,9 @@ pub fn auth_check<E: Event>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !valid_membership_change(
|
if !valid_membership_change(
|
||||||
incoming_event,
|
incoming_event.state_key().as_deref(),
|
||||||
|
incoming_event.sender(),
|
||||||
|
incoming_event.content(),
|
||||||
prev_event,
|
prev_event,
|
||||||
current_third_party_invite,
|
current_third_party_invite,
|
||||||
&auth_events,
|
&auth_events,
|
||||||
@ -271,25 +273,27 @@ pub fn auth_check<E: Event>(
|
|||||||
/// this is generated by calling `auth_types_for_event` with the membership event and
|
/// this is generated by calling `auth_types_for_event` with the membership event and
|
||||||
/// the current State.
|
/// the current State.
|
||||||
pub fn valid_membership_change<E: Event>(
|
pub fn valid_membership_change<E: Event>(
|
||||||
user: &Arc<E>,
|
user_state_key: Option<&str>,
|
||||||
|
user_sender: &UserId,
|
||||||
|
content: serde_json::Value,
|
||||||
prev_event: Option<Arc<E>>,
|
prev_event: Option<Arc<E>>,
|
||||||
current_third_party_invite: Option<Arc<E>>,
|
current_third_party_invite: Option<Arc<E>>,
|
||||||
auth_events: &StateMap<Arc<E>>,
|
auth_events: &StateMap<Arc<E>>,
|
||||||
) -> Result<bool> {
|
) -> Result<bool> {
|
||||||
let state_key = if let Some(s) = user.state_key() {
|
let state_key = if let Some(s) = user_state_key {
|
||||||
s
|
s
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::InvalidPdu("State event requires state_key".into()));
|
return Err(Error::InvalidPdu("State event requires state_key".into()));
|
||||||
};
|
};
|
||||||
|
|
||||||
let content = serde_json::from_value::<room::member::MemberEventContent>(user.content())?;
|
let content = serde_json::from_value::<room::member::MemberEventContent>(content)?;
|
||||||
|
|
||||||
let target_membership = content.membership;
|
let target_membership = content.membership;
|
||||||
|
|
||||||
let target_user_id = UserId::try_from(state_key.as_str())
|
let target_user_id =
|
||||||
.map_err(|e| Error::ConversionError(format!("{}", e)))?;
|
UserId::try_from(state_key).map_err(|e| Error::ConversionError(format!("{}", e)))?;
|
||||||
|
|
||||||
let key = (EventType::RoomMember, Some(user.sender().to_string()));
|
let key = (EventType::RoomMember, Some(user_sender.to_string()));
|
||||||
let sender = auth_events.get(&key);
|
let sender = auth_events.get(&key);
|
||||||
let sender_membership =
|
let sender_membership =
|
||||||
sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
|
sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
|
||||||
@ -333,7 +337,7 @@ pub fn valid_membership_change<E: Event>(
|
|||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let sender_power = power_levels.users.get(&user.sender()).map_or_else(
|
let sender_power = power_levels.users.get(user_sender).map_or_else(
|
||||||
|| {
|
|| {
|
||||||
if sender_membership != member::MembershipState::Join {
|
if sender_membership != member::MembershipState::Join {
|
||||||
None
|
None
|
||||||
@ -372,7 +376,7 @@ pub fn valid_membership_change<E: Event>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Ok(if target_membership == MembershipState::Join {
|
Ok(if target_membership == MembershipState::Join {
|
||||||
if user.sender() != &target_user_id {
|
if user_sender != &target_user_id {
|
||||||
false
|
false
|
||||||
} else if let MembershipState::Ban = current_membership {
|
} else if let MembershipState::Ban = current_membership {
|
||||||
false
|
false
|
||||||
@ -388,7 +392,12 @@ pub fn valid_membership_change<E: Event>(
|
|||||||
if current_membership == MembershipState::Ban {
|
if current_membership == MembershipState::Ban {
|
||||||
false
|
false
|
||||||
} else {
|
} else {
|
||||||
verify_third_party_invite(user, &tp_id, current_third_party_invite)
|
verify_third_party_invite(
|
||||||
|
Some(state_key),
|
||||||
|
user_sender,
|
||||||
|
&tp_id,
|
||||||
|
current_third_party_invite,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else if sender_membership != MembershipState::Join
|
} else if sender_membership != MembershipState::Join
|
||||||
|| current_membership == MembershipState::Join
|
|| current_membership == MembershipState::Join
|
||||||
@ -401,7 +410,7 @@ pub fn valid_membership_change<E: Event>(
|
|||||||
.is_some()
|
.is_some()
|
||||||
}
|
}
|
||||||
} else if target_membership == MembershipState::Leave {
|
} else if target_membership == MembershipState::Leave {
|
||||||
if user.sender() == &target_user_id {
|
if user_sender == &target_user_id {
|
||||||
current_membership == MembershipState::Join
|
current_membership == MembershipState::Join
|
||||||
|| current_membership == MembershipState::Invite
|
|| current_membership == MembershipState::Invite
|
||||||
} else if sender_membership != MembershipState::Join
|
} else if sender_membership != MembershipState::Join
|
||||||
@ -791,14 +800,15 @@ pub fn can_send_invite<E: Event>(event: &Arc<E>, auth_events: &StateMap<Arc<E>>)
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn verify_third_party_invite<E: Event>(
|
pub fn verify_third_party_invite<E: Event>(
|
||||||
event: &Arc<E>,
|
user_state_key: Option<&str>,
|
||||||
|
sender: &UserId,
|
||||||
tp_id: &member::ThirdPartyInvite,
|
tp_id: &member::ThirdPartyInvite,
|
||||||
current_third_party_invite: Option<Arc<E>>,
|
current_third_party_invite: Option<Arc<E>>,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
// 1. check for user being banned happens before this is called
|
// 1. check for user being banned happens before this is called
|
||||||
// checking for mxid and token keys is done by ruma when deserializing
|
// checking for mxid and token keys is done by ruma when deserializing
|
||||||
|
|
||||||
if event.state_key() != Some(tp_id.signed.mxid.to_string()) {
|
if user_state_key != Some(tp_id.signed.mxid.as_str()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -809,7 +819,7 @@ pub fn verify_third_party_invite<E: Event>(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.sender() != current_tpid.sender() {
|
if sender != current_tpid.sender() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
25
src/lib.rs
25
src/lib.rs
@ -46,17 +46,12 @@ impl StateResolution {
|
|||||||
room_version: &RoomVersionId,
|
room_version: &RoomVersionId,
|
||||||
incoming_event: Arc<E>,
|
incoming_event: Arc<E>,
|
||||||
current_state: &StateMap<EventId>,
|
current_state: &StateMap<EventId>,
|
||||||
event_map: Option<EventMap<Arc<E>>>,
|
event_map: &mut EventMap<Arc<E>>,
|
||||||
store: &dyn StateStore<E>,
|
store: &dyn StateStore<E>,
|
||||||
) -> Result<bool> {
|
) -> Result<bool> {
|
||||||
tracing::info!("Applying a single event, state resolution starting");
|
tracing::info!("Applying a single event, state resolution starting");
|
||||||
let ev = incoming_event;
|
let ev = incoming_event;
|
||||||
|
|
||||||
let mut event_map = if let Some(ev_map) = event_map {
|
|
||||||
ev_map
|
|
||||||
} else {
|
|
||||||
EventMap::new()
|
|
||||||
};
|
|
||||||
let prev_event = if let Some(id) = ev.prev_events().first() {
|
let prev_event = if let Some(id) = ev.prev_events().first() {
|
||||||
store.get_event(room_id, id).ok()
|
store.get_event(room_id, id).ok()
|
||||||
} else {
|
} else {
|
||||||
@ -69,7 +64,7 @@ impl StateResolution {
|
|||||||
{
|
{
|
||||||
if let Some(ev_id) = current_state.get(&key) {
|
if let Some(ev_id) = current_state.get(&key) {
|
||||||
if let Some(event) =
|
if let Some(event) =
|
||||||
StateResolution::get_or_load_event(room_id, ev_id, &mut event_map, store)
|
StateResolution::get_or_load_event(room_id, ev_id, event_map, store)
|
||||||
{
|
{
|
||||||
// TODO synapse checks `rejected_reason` is None here
|
// TODO synapse checks `rejected_reason` is None here
|
||||||
auth_events.insert(key.clone(), event);
|
auth_events.insert(key.clone(), event);
|
||||||
@ -102,17 +97,11 @@ impl StateResolution {
|
|||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
room_version: &RoomVersionId,
|
room_version: &RoomVersionId,
|
||||||
state_sets: &[StateMap<EventId>],
|
state_sets: &[StateMap<EventId>],
|
||||||
// TODO: make the `Option<&mut EventMap<Arc<ServerPdu>>>`
|
event_map: &mut EventMap<Arc<E>>,
|
||||||
event_map: Option<EventMap<Arc<E>>>,
|
|
||||||
store: &dyn StateStore<E>,
|
store: &dyn StateStore<E>,
|
||||||
) -> Result<StateMap<EventId>> {
|
) -> Result<StateMap<EventId>> {
|
||||||
tracing::info!("State resolution starting");
|
tracing::info!("State resolution starting");
|
||||||
|
|
||||||
let mut event_map = if let Some(ev_map) = event_map {
|
|
||||||
ev_map
|
|
||||||
} else {
|
|
||||||
EventMap::new()
|
|
||||||
};
|
|
||||||
// split non-conflicting and conflicting state
|
// split non-conflicting and conflicting state
|
||||||
let (clean, conflicting) = StateResolution::separate(&state_sets);
|
let (clean, conflicting) = StateResolution::separate(&state_sets);
|
||||||
|
|
||||||
@ -178,7 +167,7 @@ impl StateResolution {
|
|||||||
let mut sorted_control_levels = StateResolution::reverse_topological_power_sort(
|
let mut sorted_control_levels = StateResolution::reverse_topological_power_sort(
|
||||||
room_id,
|
room_id,
|
||||||
&control_events,
|
&control_events,
|
||||||
&mut event_map,
|
event_map,
|
||||||
store,
|
store,
|
||||||
&all_conflicted,
|
&all_conflicted,
|
||||||
);
|
);
|
||||||
@ -197,7 +186,7 @@ impl StateResolution {
|
|||||||
room_version,
|
room_version,
|
||||||
&sorted_control_levels,
|
&sorted_control_levels,
|
||||||
&clean,
|
&clean,
|
||||||
&mut event_map,
|
event_map,
|
||||||
store,
|
store,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
@ -238,7 +227,7 @@ impl StateResolution {
|
|||||||
room_id,
|
room_id,
|
||||||
&events_to_resolve,
|
&events_to_resolve,
|
||||||
power_event,
|
power_event,
|
||||||
&mut event_map,
|
event_map,
|
||||||
store,
|
store,
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -255,7 +244,7 @@ impl StateResolution {
|
|||||||
room_version,
|
room_version,
|
||||||
&sorted_left_events,
|
&sorted_left_events,
|
||||||
&resolved_control, // The control events are added to the final resolved state
|
&resolved_control, // The control events are added to the final resolved state
|
||||||
&mut event_map,
|
event_map,
|
||||||
store,
|
store,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -36,7 +36,15 @@ fn test_ban_pass() {
|
|||||||
&vec![event_id("IMC")],
|
&vec![event_id("IMC")],
|
||||||
);
|
);
|
||||||
|
|
||||||
assert!(valid_membership_change(&requester, prev, None, &auth_events).unwrap())
|
assert!(valid_membership_change(
|
||||||
|
requester.state_key().as_deref(),
|
||||||
|
requester.sender(),
|
||||||
|
requester.content(),
|
||||||
|
prev,
|
||||||
|
None,
|
||||||
|
&auth_events
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -63,5 +71,13 @@ fn test_ban_fail() {
|
|||||||
&vec![event_id("IMC")],
|
&vec![event_id("IMC")],
|
||||||
);
|
);
|
||||||
|
|
||||||
assert!(!valid_membership_change(&requester, prev, None, &auth_events).unwrap())
|
assert!(!valid_membership_change(
|
||||||
|
requester.state_key().as_deref(),
|
||||||
|
requester.sender(),
|
||||||
|
requester.content(),
|
||||||
|
prev,
|
||||||
|
None,
|
||||||
|
&auth_events
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
}
|
}
|
||||||
|
@ -41,11 +41,17 @@ fn ban_with_auth_chains() {
|
|||||||
fn base_with_auth_chains() {
|
fn base_with_auth_chains() {
|
||||||
let store = TestStore(INITIAL_EVENTS());
|
let store = TestStore(INITIAL_EVENTS());
|
||||||
|
|
||||||
let resolved: BTreeMap<_, EventId> =
|
let mut ev_map = state_res::EventMap::default();
|
||||||
match StateResolution::resolve(&room_id(), &RoomVersionId::Version6, &[], None, &store) {
|
let resolved: BTreeMap<_, EventId> = match StateResolution::resolve(
|
||||||
Ok(state) => state,
|
&room_id(),
|
||||||
Err(e) => panic!("{}", e),
|
&RoomVersionId::Version6,
|
||||||
};
|
&[],
|
||||||
|
&mut ev_map,
|
||||||
|
&store,
|
||||||
|
) {
|
||||||
|
Ok(state) => state,
|
||||||
|
Err(e) => panic!("{}", e),
|
||||||
|
};
|
||||||
|
|
||||||
let resolved = resolved
|
let resolved = resolved
|
||||||
.values()
|
.values()
|
||||||
@ -105,11 +111,12 @@ fn ban_with_auth_chains2() {
|
|||||||
.map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone()))
|
.map(|ev| ((ev.kind(), ev.state_key()), ev.event_id().clone()))
|
||||||
.collect::<StateMap<_>>();
|
.collect::<StateMap<_>>();
|
||||||
|
|
||||||
|
let mut ev_map = state_res::EventMap::default();
|
||||||
let resolved: StateMap<EventId> = match StateResolution::resolve(
|
let resolved: StateMap<EventId> = match StateResolution::resolve(
|
||||||
&room_id(),
|
&room_id(),
|
||||||
&RoomVersionId::Version6,
|
&RoomVersionId::Version6,
|
||||||
&[state_set_a, state_set_b],
|
&[state_set_a, state_set_b],
|
||||||
None,
|
&mut ev_map,
|
||||||
&store,
|
&store,
|
||||||
) {
|
) {
|
||||||
Ok(state) => state,
|
Ok(state) => state,
|
||||||
|
@ -265,11 +265,12 @@ fn test_event_map_none() {
|
|||||||
// build up the DAG
|
// build up the DAG
|
||||||
let (state_at_bob, state_at_charlie, expected) = store.set_up();
|
let (state_at_bob, state_at_charlie, expected) = store.set_up();
|
||||||
|
|
||||||
|
let mut ev_map = state_res::EventMap::default();
|
||||||
let resolved = match StateResolution::resolve(
|
let resolved = match StateResolution::resolve(
|
||||||
&room_id(),
|
&room_id(),
|
||||||
&RoomVersionId::Version2,
|
&RoomVersionId::Version2,
|
||||||
&[state_at_bob, state_at_charlie],
|
&[state_at_bob, state_at_charlie],
|
||||||
None,
|
&mut ev_map,
|
||||||
&store,
|
&store,
|
||||||
) {
|
) {
|
||||||
Ok(state) => state,
|
Ok(state) => state,
|
||||||
|
@ -114,7 +114,7 @@ pub fn do_check(
|
|||||||
&room_id(),
|
&room_id(),
|
||||||
&RoomVersionId::Version6,
|
&RoomVersionId::Version6,
|
||||||
&state_sets,
|
&state_sets,
|
||||||
Some(event_map.clone()),
|
&mut event_map,
|
||||||
&store,
|
&store,
|
||||||
);
|
);
|
||||||
match resolved {
|
match resolved {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user