Use ruma::ServerPdu instead of local type

This commit is contained in:
Devin Ragotzy 2020-12-22 14:28:48 -05:00
parent 282270ed4f
commit 5299679c21
4 changed files with 218 additions and 175 deletions

View File

@ -20,10 +20,10 @@ maplit = "1.0.2"
thiserror = "1.0.22" thiserror = "1.0.22"
[dependencies.ruma] [dependencies.ruma]
git = "https://github.com/ruma/ruma" git = "https://github.com/DevinR528/ruma"
# branch = "verified-export" branch = "server-pdu"
# path = "../__forks__/ruma/ruma" # path = "../__forks__/ruma/ruma"
rev = "45d01011554f9d07739e9a5edf5498d8ac16f273" # rev = "45d01011554f9d07739e9a5edf5498d8ac16f273"
features = ["client-api", "federation-api", "appservice-api", "unstable-pre-spec", "unstable-synapse-quirks"] features = ["client-api", "federation-api", "appservice-api", "unstable-pre-spec", "unstable-synapse-quirks"]
#[dependencies.ruma] #[dependencies.ruma]

View File

@ -3,6 +3,7 @@ use std::{collections::BTreeMap, convert::TryFrom, sync::Arc};
use maplit::btreeset; use maplit::btreeset;
use ruma::{ use ruma::{
events::{ events::{
pdu::ServerPdu,
room::{ room::{
self, self,
join_rules::JoinRule, join_rules::JoinRule,
@ -14,49 +15,44 @@ use ruma::{
identifiers::{RoomVersionId, UserId}, identifiers::{RoomVersionId, UserId},
}; };
use crate::{ use crate::{state_event::Requester, to_requester, Error, Result, StateMap};
state_event::{Requester, StateEvent},
Error, Result, StateMap,
};
/// For the given event `kind` what are the relevant auth events /// For the given event `kind` what are the relevant auth events
/// that are needed to authenticate this `content`. /// that are needed to authenticate this `content`.
pub fn auth_types_for_event( pub fn auth_types_for_event(
kind: EventType, kind: &EventType,
sender: &UserId, sender: &UserId,
state_key: Option<String>, state_key: Option<String>,
content: serde_json::Value, content: serde_json::Value,
) -> Vec<(EventType, String)> { ) -> Vec<(EventType, Option<String>)> {
if kind == EventType::RoomCreate { if kind == EventType::RoomCreate {
return vec![]; return vec![];
} }
let mut auth_types = vec![ let mut auth_types = vec![
(EventType::RoomPowerLevels, "".to_string()), (EventType::RoomPowerLevels, Some("".to_string())),
(EventType::RoomMember, sender.to_string()), (EventType::RoomMember, Some(sender.to_string())),
(EventType::RoomCreate, "".to_string()), (EventType::RoomCreate, Some("".to_string())),
]; ];
if kind == EventType::RoomMember { if kind == EventType::RoomMember {
if let Ok(content) = serde_json::from_value::<room::member::MemberEventContent>(content) { if let Ok(content) = serde_json::from_value::<room::member::MemberEventContent>(content) {
if [MembershipState::Join, MembershipState::Invite].contains(&content.membership) { if [MembershipState::Join, MembershipState::Invite].contains(&content.membership) {
let key = (EventType::RoomJoinRules, "".into()); let key = (EventType::RoomJoinRules, Some("".into()));
if !auth_types.contains(&key) { if !auth_types.contains(&key) {
auth_types.push(key) auth_types.push(key)
} }
} }
// TODO what when we don't find a state_key // TODO what when we don't find a state_key
if let Some(state_key) = state_key { let key = (EventType::RoomMember, state_key);
let key = (EventType::RoomMember, state_key); if !auth_types.contains(&key) {
if !auth_types.contains(&key) { auth_types.push(key)
auth_types.push(key)
}
} }
if content.membership == MembershipState::Invite { if content.membership == MembershipState::Invite {
if let Some(t_id) = content.third_party_invite { if let Some(t_id) = content.third_party_invite {
let key = (EventType::RoomThirdPartyInvite, t_id.signed.token); let key = (EventType::RoomThirdPartyInvite, Some(t_id.signed.token));
if !auth_types.contains(&key) { if !auth_types.contains(&key) {
auth_types.push(key) auth_types.push(key)
} }
@ -74,12 +70,12 @@ pub fn auth_types_for_event(
/// * then there are checks for specific event types /// * then there are checks for specific event types
pub fn auth_check( pub fn auth_check(
room_version: &RoomVersionId, room_version: &RoomVersionId,
incoming_event: &Arc<StateEvent>, incoming_event: &Arc<ServerPdu>,
prev_event: Option<Arc<StateEvent>>, prev_event: Option<Arc<ServerPdu>>,
auth_events: StateMap<Arc<StateEvent>>, auth_events: StateMap<Arc<ServerPdu>>,
current_third_party_invite: Option<Arc<StateEvent>>, current_third_party_invite: Option<Arc<ServerPdu>>,
) -> Result<bool> { ) -> Result<bool> {
tracing::info!("auth_check beginning for {}", incoming_event.kind()); tracing::info!("auth_check beginning for {}", incoming_event.kind);
// [synapse] check that all the events are in the same room as `incoming_event` // [synapse] check that all the events are in the same room as `incoming_event`
@ -92,17 +88,17 @@ pub fn auth_check(
// Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules // Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
// //
// 1. If type is m.room.create: // 1. If type is m.room.create:
if incoming_event.kind() == EventType::RoomCreate { if incoming_event.kind == EventType::RoomCreate {
tracing::info!("start m.room.create check"); tracing::info!("start m.room.create check");
// If it has any previous events, reject // If it has any previous events, reject
if !incoming_event.prev_event_ids().is_empty() { if !incoming_event.prev_events.is_empty() {
tracing::warn!("the room creation event had previous events"); tracing::warn!("the room creation event had previous events");
return Ok(false); return Ok(false);
} }
// If the domain of the room_id does not match the domain of the sender, reject // If the domain of the room_id does not match the domain of the sender, reject
if incoming_event.room_id().server_name() != incoming_event.sender().server_name() { if incoming_event.room_id.server_name() != incoming_event.sender.server_name() {
tracing::warn!("creation events server does not match sender"); tracing::warn!("creation events server does not match sender");
return Ok(false); // creation events room id does not match senders return Ok(false); // creation events room id does not match senders
} }
@ -110,7 +106,7 @@ pub fn auth_check(
// If content.room_version is present and is not a recognized version, reject // If content.room_version is present and is not a recognized version, reject
if serde_json::from_value::<RoomVersionId>( if serde_json::from_value::<RoomVersionId>(
incoming_event incoming_event
.content() .content
.get("room_version") .get("room_version")
.cloned() .cloned()
// TODO synapse defaults to version 1 // TODO synapse defaults to version 1
@ -123,7 +119,7 @@ pub fn auth_check(
} }
// If content has no creator field, reject // If content has no creator field, reject
if incoming_event.content().get("creator").is_none() { if incoming_event.content.get("creator").is_none() {
tracing::warn!("no creator field found in room create content"); tracing::warn!("no creator field found in room create content");
return Ok(false); return Ok(false);
} }
@ -137,9 +133,9 @@ pub fn auth_check(
// a. auth_events cannot have duplicate keys since it's a BTree // a. auth_events cannot have duplicate keys since it's a BTree
// b. All entries are valid auth events according to spec // b. All entries are valid auth events according to spec
let expected_auth = auth_types_for_event( let expected_auth = auth_types_for_event(
incoming_event.kind(), incoming_event.kind,
incoming_event.sender(), incoming_event.sender(),
incoming_event.state_key(), incoming_event.state_key,
incoming_event.content().clone(), incoming_event.content().clone(),
); );
@ -156,7 +152,7 @@ pub fn auth_check(
// 3. If event does not have m.room.create in auth_events reject // 3. If event does not have m.room.create in auth_events reject
if auth_events if auth_events
.get(&(EventType::RoomCreate, "".into())) .get(&(EventType::RoomCreate, Some("".into())))
.is_none() .is_none()
{ {
tracing::warn!("no m.room.create event in auth chain"); tracing::warn!("no m.room.create event in auth chain");
@ -167,18 +163,18 @@ pub fn auth_check(
// [synapse] checks for federation here // [synapse] checks for federation here
// 4. if type is m.room.aliases // 4. if type is m.room.aliases
if incoming_event.kind() == EventType::RoomAliases { if incoming_event.kind == EventType::RoomAliases {
tracing::info!("starting m.room.aliases check"); tracing::info!("starting m.room.aliases check");
// [synapse] adds `&& room_version` "special case aliases auth" // [synapse] adds `&& room_version` "special case aliases auth"
// [synapse] // [synapse]
// if event.state_key().unwrap().is_empty() { // if event.state_key.unwrap().is_empty() {
// tracing::warn!("state_key must be non-empty"); // tracing::warn!("state_key must be non-empty");
// return Ok(false); // and be non-empty state_key (point to a user_id) // return Ok(false); // and be non-empty state_key (point to a user_id)
// } // }
// If sender's domain doesn't matches state_key, reject // If sender's domain doesn't matches state_key, reject
if incoming_event.state_key() != incoming_event.sender().server_name().as_str() { if incoming_event.state_key != Some(incoming_event.sender.server_name().to_string()) {
tracing::warn!("state_key does not match sender"); tracing::warn!("state_key does not match sender");
return Ok(false); return Ok(false);
} }
@ -187,19 +183,20 @@ pub fn auth_check(
return Ok(true); return Ok(true);
} }
if incoming_event.kind() == EventType::RoomMember { if incoming_event.kind == EventType::RoomMember {
tracing::info!("starting m.room.member check"); tracing::info!("starting m.room.member check");
if incoming_event if serde_json::from_value::<room::member::MemberEventContent>(
.deserialize_content::<room::member::MemberEventContent>() incoming_event.content.clone(),
.is_err() )
.is_err()
{ {
tracing::warn!("no membership filed found for m.room.member event content"); tracing::warn!("no membership filed found for m.room.member event content");
return Ok(false); return Ok(false);
} }
if !valid_membership_change( if !valid_membership_change(
incoming_event.to_requester(), to_requester(incoming_event),
prev_event, prev_event,
current_third_party_invite, current_third_party_invite,
&auth_events, &auth_events,
@ -212,7 +209,7 @@ pub fn auth_check(
} }
// If the sender's current membership state is not join, reject // If the sender's current membership state is not join, reject
match check_event_sender_in_room(incoming_event.sender(), &auth_events) { match check_event_sender_in_room(&incoming_event.sender, &auth_events) {
Some(true) => {} // sender in room Some(true) => {} // sender in room
Some(false) => { Some(false) => {
tracing::warn!("sender's membership is not join"); tracing::warn!("sender's membership is not join");
@ -226,8 +223,8 @@ pub fn auth_check(
// Allow if and only if sender's current power level is greater than // Allow if and only if sender's current power level is greater than
// or equal to the invite level // or equal to the invite level
if incoming_event.kind() == EventType::RoomThirdPartyInvite if incoming_event.kind == EventType::RoomThirdPartyInvite
&& !can_send_invite(&incoming_event.to_requester(), &auth_events)? && !can_send_invite(&to_requester(incoming_event), &auth_events)?
{ {
tracing::warn!("sender's cannot send invites in this room"); tracing::warn!("sender's cannot send invites in this room");
return Ok(false); return Ok(false);
@ -240,7 +237,7 @@ pub fn auth_check(
return Ok(false); return Ok(false);
} }
if incoming_event.kind() == EventType::RoomPowerLevels { if incoming_event.kind == EventType::RoomPowerLevels {
tracing::info!("starting m.room.power_levels check"); tracing::info!("starting m.room.power_levels check");
if let Some(required_pwr_lvl) = if let Some(required_pwr_lvl) =
@ -257,7 +254,7 @@ pub fn auth_check(
tracing::info!("power levels event allowed"); tracing::info!("power levels event allowed");
} }
if incoming_event.kind() == EventType::RoomRedaction if incoming_event.kind == EventType::RoomRedaction
&& !check_redaction(room_version, incoming_event, &auth_events)? && !check_redaction(room_version, incoming_event, &auth_events)?
{ {
return Ok(false); return Ok(false);
@ -278,9 +275,9 @@ pub fn auth_check(
/// the current State. /// the current State.
pub fn valid_membership_change( pub fn valid_membership_change(
user: Requester<'_>, user: Requester<'_>,
prev_event: Option<Arc<StateEvent>>, prev_event: Option<Arc<ServerPdu>>,
current_third_party_invite: Option<Arc<StateEvent>>, current_third_party_invite: Option<Arc<ServerPdu>>,
auth_events: &StateMap<Arc<StateEvent>>, auth_events: &StateMap<Arc<ServerPdu>>,
) -> Result<bool> { ) -> Result<bool> {
let state_key = if let Some(s) = user.state_key.as_ref() { let state_key = if let Some(s) = user.state_key.as_ref() {
s s
@ -296,25 +293,27 @@ pub fn valid_membership_change(
let target_user_id = UserId::try_from(state_key.as_str()) let target_user_id = UserId::try_from(state_key.as_str())
.map_err(|e| Error::ConversionError(format!("{}", e)))?; .map_err(|e| Error::ConversionError(format!("{}", e)))?;
let key = (EventType::RoomMember, user.sender.to_string()); let key = (EventType::RoomMember, Some(user.sender.to_string()));
let sender = auth_events.get(&key); let sender = auth_events.get(&key);
let sender_membership = let sender_membership =
sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
Ok(pdu Ok(
.deserialize_content::<room::member::MemberEventContent>()? serde_json::from_value::<room::member::MemberEventContent>(pdu.content.clone())?
.membership) .membership,
)
})?; })?;
let key = (EventType::RoomMember, target_user_id.to_string()); let key = (EventType::RoomMember, Some(target_user_id.to_string()));
let current = auth_events.get(&key); let current = auth_events.get(&key);
let current_membership = let current_membership =
current.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { current.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
Ok(pdu Ok(
.deserialize_content::<room::member::MemberEventContent>()? serde_json::from_value::<room::member::MemberEventContent>(pdu.content.clone())?
.membership) .membership,
)
})?; })?;
let key = (EventType::RoomPowerLevels, "".into()); let key = (EventType::RoomPowerLevels, Some("".into()));
let power_levels = auth_events.get(&key).map_or_else( let power_levels = auth_events.get(&key).map_or_else(
|| { || {
Ok::<_, Error>(power_levels::PowerLevelsEventContent { Ok::<_, Error>(power_levels::PowerLevelsEventContent {
@ -333,8 +332,7 @@ pub fn valid_membership_change(
}) })
}, },
|power_levels| { |power_levels| {
power_levels serde_json::from_value::<PowerLevelsEventContent>(power_levels.content)
.deserialize_content::<PowerLevelsEventContent>()
.map_err(Into::into) .map_err(Into::into)
}, },
)?; )?;
@ -362,17 +360,17 @@ pub fn valid_membership_change(
Some, Some,
); );
let key = (EventType::RoomJoinRules, "".to_string()); let key = (EventType::RoomJoinRules, Some("".into()));
let join_rules_event = auth_events.get(&key); let join_rules_event = auth_events.get(&key);
let mut join_rules = JoinRule::Invite; let mut join_rules = JoinRule::Invite;
if let Some(jr) = join_rules_event { if let Some(jr) = join_rules_event {
join_rules = jr join_rules =
.deserialize_content::<room::join_rules::JoinRulesEventContent>()? serde_json::from_value::<room::join_rules::JoinRulesEventContent>(jr.content.clone())?
.join_rule; .join_rule;
} }
if let Some(prev) = prev_event { if let Some(prev) = prev_event {
if prev.kind() == EventType::RoomCreate && prev.prev_event_ids().is_empty() { if prev.kind == EventType::RoomCreate && prev.prev_events.is_empty() {
return Ok(true); return Ok(true);
} }
} }
@ -434,11 +432,11 @@ pub fn valid_membership_change(
/// Is the event's sender in the room that they sent the event to. /// Is the event's sender in the room that they sent the event to.
pub fn check_event_sender_in_room( pub fn check_event_sender_in_room(
sender: &UserId, sender: &UserId,
auth_events: &StateMap<Arc<StateEvent>>, auth_events: &StateMap<Arc<ServerPdu>>,
) -> Option<bool> { ) -> Option<bool> {
let mem = auth_events.get(&(EventType::RoomMember, sender.to_string()))?; let mem = auth_events.get(&(EventType::RoomMember, Some(sender.to_string())))?;
Some( Some(
mem.deserialize_content::<room::member::MemberEventContent>() serde_json::from_value::<room::member::MemberEventContent>(mem.content.clone())
.ok()? .ok()?
.membership .membership
== MembershipState::Join, == MembershipState::Join,
@ -447,15 +445,15 @@ pub fn check_event_sender_in_room(
/// Is the user allowed to send a specific event based on the rooms power levels. Does the event /// Is the user allowed to send a specific event based on the rooms power levels. Does the event
/// have the correct userId as it's state_key if it's not the "" state_key. /// have the correct userId as it's state_key if it's not the "" state_key.
pub fn can_send_event(event: &Arc<StateEvent>, auth_events: &StateMap<Arc<StateEvent>>) -> bool { pub fn can_send_event(event: &Arc<ServerPdu>, auth_events: &StateMap<Arc<ServerPdu>>) -> bool {
let ple = auth_events.get(&(EventType::RoomPowerLevels, "".into())); let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into())));
let event_type_power_level = get_send_level(event.kind(), Some(event.state_key()), ple); let event_type_power_level = get_send_level(event.kind, event.state_key, ple);
let user_level = get_user_power_level(event.sender(), auth_events); let user_level = get_user_power_level(&event.sender, auth_events);
tracing::debug!( tracing::debug!(
"{} ev_type {} usr {}", "{} ev_type {} usr {}",
event.event_id().to_string(), event.event_id.to_string(),
event_type_power_level, event_type_power_level,
user_level user_level
); );
@ -464,7 +462,9 @@ pub fn can_send_event(event: &Arc<StateEvent>, auth_events: &StateMap<Arc<StateE
return false; return false;
} }
if event.state_key().starts_with('@') && event.state_key() != event.sender().as_str() { if event.state_key.map_or(false, |k| k.starts_with('@'))
&& event.state_key.as_deref() != Some(event.sender.as_str())
{
return false; // permission required to post in this room return false; // permission required to post in this room
} }
@ -474,10 +474,10 @@ pub fn can_send_event(event: &Arc<StateEvent>, auth_events: &StateMap<Arc<StateE
/// Confirm that the event sender has the required power levels. /// Confirm that the event sender has the required power levels.
pub fn check_power_levels( pub fn check_power_levels(
_: &RoomVersionId, _: &RoomVersionId,
power_event: &Arc<StateEvent>, power_event: &Arc<ServerPdu>,
auth_events: &StateMap<Arc<StateEvent>>, auth_events: &StateMap<Arc<ServerPdu>>,
) -> Option<bool> { ) -> Option<bool> {
let key = (power_event.kind(), power_event.state_key()); let key = (power_event.kind, power_event.state_key);
let current_state = if let Some(current_state) = auth_events.get(&key) { let current_state = if let Some(current_state) = auth_events.get(&key) {
current_state current_state
} else { } else {
@ -487,17 +487,19 @@ pub fn check_power_levels(
// If users key in content is not a dictionary with keys that are valid user IDs // If users key in content is not a dictionary with keys that are valid user IDs
// with values that are integers (or a string that is an integer), reject. // with values that are integers (or a string that is an integer), reject.
let user_content = power_event let user_content = serde_json::from_value::<room::power_levels::PowerLevelsEventContent>(
.deserialize_content::<room::power_levels::PowerLevelsEventContent>() power_event.content.clone(),
.unwrap(); )
let current_content = current_state .unwrap();
.deserialize_content::<room::power_levels::PowerLevelsEventContent>() let current_content = serde_json::from_value::<room::power_levels::PowerLevelsEventContent>(
.unwrap(); current_state.content.clone(),
)
.unwrap();
// validation of users is done in Ruma, synapse for loops validating user_ids and integers here // validation of users is done in Ruma, synapse for loops validating user_ids and integers here
tracing::info!("validation of power event finished"); tracing::info!("validation of power event finished");
let user_level = get_user_power_level(power_event.sender(), auth_events); let user_level = get_user_power_level(&power_event.sender, auth_events);
let mut user_levels_to_check = btreeset![]; let mut user_levels_to_check = btreeset![];
let old_list = &current_content.users; let old_list = &current_content.users;
@ -547,7 +549,7 @@ pub fn check_power_levels(
} }
// If the current value is equal to the sender's current power level, reject // If the current value is equal to the sender's current power level, reject
if user != power_event.sender() && old_level.map(|int| (*int).into()) == Some(user_level) { if user != &power_event.sender && old_level.map(|int| (*int).into()) == Some(user_level) {
tracing::warn!("m.room.power_level cannot remove ops == to own"); tracing::warn!("m.room.power_level cannot remove ops == to own");
return Some(false); // cannot remove ops level == to own return Some(false); // cannot remove ops level == to own
} }
@ -620,10 +622,10 @@ fn get_deserialize_levels(
/// Does the event redacting come from a user with enough power to redact the given event. /// Does the event redacting come from a user with enough power to redact the given event.
pub fn check_redaction( pub fn check_redaction(
room_version: &RoomVersionId, room_version: &RoomVersionId,
redaction_event: &Arc<StateEvent>, redaction_event: &Arc<ServerPdu>,
auth_events: &StateMap<Arc<StateEvent>>, auth_events: &StateMap<Arc<ServerPdu>>,
) -> Result<bool> { ) -> Result<bool> {
let user_level = get_user_power_level(redaction_event.sender(), auth_events); let user_level = get_user_power_level(&redaction_event.sender, auth_events);
let redact_level = get_named_level(auth_events, "redact", 50); let redact_level = get_named_level(auth_events, "redact", 50);
if user_level >= redact_level { if user_level >= redact_level {
@ -641,8 +643,8 @@ pub fn check_redaction(
// version 1 check // version 1 check
if let RoomVersionId::Version1 = room_version { if let RoomVersionId::Version1 = room_version {
// If the domain of the event_id of the event being redacted is the same as the domain of the event_id of the m.room.redaction, allow // If the domain of the event_id of the event being redacted is the same as the domain of the event_id of the m.room.redaction, allow
if redaction_event.event_id().server_name() if redaction_event.event_id.server_name()
== redaction_event.redacts().and_then(|id| id.server_name()) == redaction_event.redacts.and_then(|id| id.server_name())
{ {
tracing::info!("redaction event allowed via room version 1 rules"); tracing::info!("redaction event allowed via room version 1 rules");
return Ok(true); return Ok(true);
@ -655,10 +657,10 @@ pub fn check_redaction(
/// Check that the member event matches `state`. /// Check that the member event matches `state`.
/// ///
/// This function returns false instead of failing when deserialization fails. /// This function returns false instead of failing when deserialization fails.
pub fn check_membership(member_event: Option<Arc<StateEvent>>, state: MembershipState) -> bool { pub fn check_membership(member_event: Option<Arc<ServerPdu>>, state: MembershipState) -> bool {
if let Some(event) = member_event { if let Some(event) = member_event {
if let Ok(content) = if let Ok(content) =
serde_json::from_value::<room::member::MemberEventContent>(event.content().clone()) serde_json::from_value::<room::member::MemberEventContent>(event.content.clone())
{ {
content.membership == state content.membership == state
} else { } else {
@ -670,10 +672,10 @@ pub fn check_membership(member_event: Option<Arc<StateEvent>>, state: Membership
} }
/// Can this room federate based on its m.room.create event. /// Can this room federate based on its m.room.create event.
pub fn can_federate(auth_events: &StateMap<Arc<StateEvent>>) -> bool { pub fn can_federate(auth_events: &StateMap<Arc<ServerPdu>>) -> bool {
let creation_event = auth_events.get(&(EventType::RoomCreate, "".into())); let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into())));
if let Some(ev) = creation_event { if let Some(ev) = creation_event {
if let Some(fed) = ev.content().get("m.federate") { if let Some(fed) = ev.content.get("m.federate") {
fed == "true" fed == "true"
} else { } else {
false false
@ -685,11 +687,11 @@ pub fn can_federate(auth_events: &StateMap<Arc<StateEvent>>) -> bool {
/// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content. /// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content.
/// or return `default` if no power level event is found or zero if no field matches `name`. /// or return `default` if no power level event is found or zero if no field matches `name`.
pub fn get_named_level(auth_events: &StateMap<Arc<StateEvent>>, name: &str, default: i64) -> i64 { pub fn get_named_level(auth_events: &StateMap<Arc<ServerPdu>>, name: &str, default: i64) -> i64 {
let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, "".into())); let power_level_event = auth_events.get(&(EventType::RoomPowerLevels, Some("".into())));
if let Some(pl) = power_level_event { if let Some(pl) = power_level_event {
// TODO do this the right way and deserialize // TODO do this the right way and deserialize
if let Some(level) = pl.content().get(name) { if let Some(level) = pl.content.get(name) {
level.to_string().parse().unwrap_or(default) level.to_string().parse().unwrap_or(default)
} else { } else {
0 0
@ -701,10 +703,11 @@ pub fn get_named_level(auth_events: &StateMap<Arc<StateEvent>>, name: &str, defa
/// Helper function to fetch a users default power level from a "m.room.power_level" event's `users` /// Helper function to fetch a users default power level from a "m.room.power_level" event's `users`
/// object. /// object.
pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap<Arc<StateEvent>>) -> i64 { pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap<Arc<ServerPdu>>) -> i64 {
if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, "".into())) { if let Some(pl) = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))) {
if let Ok(content) = pl.deserialize_content::<room::power_levels::PowerLevelsEventContent>() if let Ok(content) = serde_json::from_value::<room::power_levels::PowerLevelsEventContent>(
{ pl.content.clone(),
) {
if let Some(level) = content.users.get(user_id) { if let Some(level) = content.users.get(user_id) {
(*level).into() (*level).into()
} else { } else {
@ -715,9 +718,11 @@ pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap<Arc<StateEv
} }
} else { } else {
// if no power level event found the creator gets 100 everyone else gets 0 // if no power level event found the creator gets 100 everyone else gets 0
let key = (EventType::RoomCreate, "".into()); let key = (EventType::RoomCreate, Some("".into()));
if let Some(create) = auth_events.get(&key) { if let Some(create) = auth_events.get(&key) {
if let Ok(c) = create.deserialize_content::<room::create::CreateEventContent>() { if let Ok(c) =
serde_json::from_value::<room::create::CreateEventContent>(create.content.clone())
{
if &c.creator == user_id { if &c.creator == user_id {
100 100
} else { } else {
@ -737,12 +742,12 @@ pub fn get_user_power_level(user_id: &UserId, auth_events: &StateMap<Arc<StateEv
pub fn get_send_level( pub fn get_send_level(
e_type: EventType, e_type: EventType,
state_key: Option<String>, state_key: Option<String>,
power_lvl: Option<&Arc<StateEvent>>, power_lvl: Option<&Arc<ServerPdu>>,
) -> i64 { ) -> i64 {
tracing::debug!("{:?} {:?}", e_type, state_key); tracing::debug!("{:?} {:?}", e_type, state_key);
if let Some(ple) = power_lvl { if let Some(ple) = power_lvl {
if let Ok(content) = serde_json::from_value::<room::power_levels::PowerLevelsEventContent>( if let Ok(content) = serde_json::from_value::<room::power_levels::PowerLevelsEventContent>(
ple.content().clone(), ple.content.clone(),
) { ) {
let mut lvl: i64 = content let mut lvl: i64 = content
.events .events
@ -767,17 +772,16 @@ pub fn get_send_level(
/// Check user can send invite. /// Check user can send invite.
pub fn can_send_invite( pub fn can_send_invite(
event: &Requester<'_>, event: &Requester<'_>,
auth_events: &StateMap<Arc<StateEvent>>, auth_events: &StateMap<Arc<ServerPdu>>,
) -> Result<bool> { ) -> Result<bool> {
let user_level = get_user_power_level(event.sender, auth_events); let user_level = get_user_power_level(event.sender, auth_events);
let key = (EventType::RoomPowerLevels, "".into()); let key = (EventType::RoomPowerLevels, Some("".into()));
let invite_level = auth_events let invite_level = auth_events
.get(&key) .get(&key)
.map_or_else( .map_or_else(
|| Ok::<_, Error>(ruma::int!(50)), || Ok::<_, Error>(ruma::int!(50)),
|power_levels| { |power_levels| {
power_levels serde_json::from_value::<PowerLevelsEventContent>(power_levels.content.clone())
.deserialize_content::<PowerLevelsEventContent>()
.map(|pl| pl.invite) .map(|pl| pl.invite)
.map_err(Into::into) .map_err(Into::into)
}, },
@ -790,7 +794,7 @@ pub fn can_send_invite(
pub fn verify_third_party_invite( pub fn verify_third_party_invite(
event: &Requester<'_>, event: &Requester<'_>,
tp_id: &member::ThirdPartyInvite, tp_id: &member::ThirdPartyInvite,
current_third_party_invite: Option<Arc<StateEvent>>, current_third_party_invite: Option<Arc<ServerPdu>>,
) -> bool { ) -> bool {
// 1. check for user being banned happens before this is called // 1. check for user being banned happens before this is called
// checking for mxid and token keys is done by ruma when deserializing // checking for mxid and token keys is done by ruma when deserializing
@ -802,18 +806,18 @@ pub fn verify_third_party_invite(
// If there is no m.room.third_party_invite event in the current room state // If there is no m.room.third_party_invite event in the current room state
// with state_key matching token, reject // with state_key matching token, reject
if let Some(current_tpid) = current_third_party_invite { if let Some(current_tpid) = current_third_party_invite {
if current_tpid.state_key() != tp_id.signed.token { if current_tpid.state_key != Some(tp_id.signed.token) {
return false; return false;
} }
if event.sender != current_tpid.sender() { if event.sender != &current_tpid.sender {
return false; return false;
} }
// If any signature in signed matches any public key in the m.room.third_party_invite event, allow // If any signature in signed matches any public key in the m.room.third_party_invite event, allow
if let Ok(tpid_ev) = serde_json::from_value::< if let Ok(tpid_ev) = serde_json::from_value::<
ruma::events::room::third_party_invite::ThirdPartyInviteEventContent, ruma::events::room::third_party_invite::ThirdPartyInviteEventContent,
>(current_tpid.content().clone()) >(current_tpid.content.clone())
{ {
// A list of public keys in the public_keys field // A list of public keys in the public_keys field
for key in tpid_ev.public_keys.unwrap_or_default() { for key in tpid_ev.public_keys.unwrap_or_default() {

View File

@ -7,7 +7,7 @@ use std::{
use maplit::btreeset; use maplit::btreeset;
use ruma::{ use ruma::{
events::EventType, events::{pdu::ServerPdu, EventType},
identifiers::{EventId, RoomId, RoomVersionId}, identifiers::{EventId, RoomId, RoomVersionId},
}; };
@ -19,7 +19,7 @@ mod state_store;
pub use error::{Error, Result}; pub use error::{Error, Result};
pub use event_auth::{auth_check, auth_types_for_event}; pub use event_auth::{auth_check, auth_types_for_event};
pub use state_event::{Requester, StateEvent}; pub use state_event::Requester;
pub use state_store::StateStore; pub use state_store::StateStore;
// We want to yield to the reactor occasionally during state res when dealing // We want to yield to the reactor occasionally during state res when dealing
@ -28,9 +28,9 @@ pub use state_store::StateStore;
const _YIELD_AFTER_ITERATIONS: usize = 100; const _YIELD_AFTER_ITERATIONS: usize = 100;
/// A mapping of event type and state_key to some value `T`, usually an `EventId`. /// A mapping of event type and state_key to some value `T`, usually an `EventId`.
pub type StateMap<T> = BTreeMap<(EventType, String), T>; pub type StateMap<T> = BTreeMap<(EventType, Option<String>), T>;
/// A mapping of `EventId` to `T`, usually a `StateEvent`. /// A mapping of `EventId` to `T`, usually a `ServerPdu`.
pub type EventMap<T> = BTreeMap<EventId, T>; pub type EventMap<T> = BTreeMap<EventId, T>;
#[derive(Default)] #[derive(Default)]
@ -44,9 +44,9 @@ impl StateResolution {
pub fn apply_event( pub fn apply_event(
room_id: &RoomId, room_id: &RoomId,
room_version: &RoomVersionId, room_version: &RoomVersionId,
incoming_event: Arc<StateEvent>, incoming_event: Arc<ServerPdu>,
current_state: &StateMap<EventId>, current_state: &StateMap<EventId>,
event_map: Option<EventMap<Arc<StateEvent>>>, event_map: Option<EventMap<Arc<ServerPdu>>>,
store: &dyn StateStore, store: &dyn StateStore,
) -> Result<bool> { ) -> Result<bool> {
tracing::info!("Applying a single event, state resolution starting"); tracing::info!("Applying a single event, state resolution starting");
@ -57,19 +57,16 @@ impl StateResolution {
} else { } else {
EventMap::new() EventMap::new()
}; };
let prev_event = if let Some(id) = ev.prev_event_ids().first() { let prev_event = if let Some(id) = ev.prev_events.first() {
store.get_event(room_id, id).ok() store.get_event(room_id, id).ok()
} else { } else {
None None
}; };
let mut auth_events = StateMap::new(); let mut auth_events = StateMap::new();
for key in event_auth::auth_types_for_event( for key in
ev.kind(), event_auth::auth_types_for_event(ev.kind, &ev.sender, ev.state_key, ev.content.clone())
ev.sender(), {
Some(ev.state_key()),
ev.content().clone(),
) {
if let Some(ev_id) = current_state.get(&key) { if let Some(ev_id) = current_state.get(&key) {
if let Some(event) = if let Some(event) =
StateResolution::get_or_load_event(room_id, ev_id, &mut event_map, store) StateResolution::get_or_load_event(room_id, ev_id, &mut event_map, store)
@ -105,8 +102,8 @@ impl StateResolution {
room_id: &RoomId, room_id: &RoomId,
room_version: &RoomVersionId, room_version: &RoomVersionId,
state_sets: &[StateMap<EventId>], state_sets: &[StateMap<EventId>],
// TODO: make the `Option<&mut EventMap<Arc<StateEvent>>>` // TODO: make the `Option<&mut EventMap<Arc<ServerPdu>>>`
event_map: Option<EventMap<Arc<StateEvent>>>, event_map: Option<EventMap<Arc<ServerPdu>>>,
store: &dyn StateStore, store: &dyn StateStore,
) -> Result<StateMap<EventId>> { ) -> Result<StateMap<EventId>> {
tracing::info!("State resolution starting"); tracing::info!("State resolution starting");
@ -157,7 +154,7 @@ impl StateResolution {
.unwrap(); .unwrap();
// update event_map to include the fetched events // update event_map to include the fetched events
event_map.extend(events.into_iter().map(|ev| (ev.event_id(), ev))); event_map.extend(events.into_iter().map(|ev| (ev.event_id.clone(), ev)));
// at this point our event_map == store there should be no missing events // at this point our event_map == store there should be no missing events
tracing::debug!("event map size: {}", event_map.len()); tracing::debug!("event map size: {}", event_map.len());
@ -233,7 +230,7 @@ impl StateResolution {
); );
// This "epochs" power level event // This "epochs" power level event
let power_event = resolved_control.get(&(EventType::RoomPowerLevels, "".into())); let power_event = resolved_control.get(&(EventType::RoomPowerLevels, Some("".into())));
tracing::debug!("PL {:?}", power_event); tracing::debug!("PL {:?}", power_event);
@ -341,7 +338,7 @@ impl StateResolution {
pub fn reverse_topological_power_sort( pub fn reverse_topological_power_sort(
room_id: &RoomId, room_id: &RoomId,
events_to_sort: &[EventId], events_to_sort: &[EventId],
event_map: &mut EventMap<Arc<StateEvent>>, event_map: &mut EventMap<Arc<ServerPdu>>,
store: &dyn StateStore, store: &dyn StateStore,
auth_diff: &[EventId], auth_diff: &[EventId],
) -> Vec<EventId> { ) -> Vec<EventId> {
@ -381,12 +378,12 @@ impl StateResolution {
let ev = event_map.get(event_id).unwrap(); let ev = event_map.get(event_id).unwrap();
let pl = event_to_pl.get(event_id).unwrap(); let pl = event_to_pl.get(event_id).unwrap();
tracing::debug!("{:?}", (-*pl, *ev.origin_server_ts(), ev.event_id())); tracing::debug!("{:?}", (-*pl, ev.origin_server_ts, ev.event_id));
// This return value is the key used for sorting events, // This return value is the key used for sorting events,
// events are then sorted by power level, time, // events are then sorted by power level, time,
// and lexically by event_id. // and lexically by event_id.
(-*pl, *ev.origin_server_ts(), ev.event_id()) (-*pl, ev.origin_server_ts, ev.event_id)
}) })
} }
@ -467,7 +464,7 @@ impl StateResolution {
fn get_power_level_for_sender( fn get_power_level_for_sender(
room_id: &RoomId, room_id: &RoomId,
event_id: &EventId, event_id: &EventId,
event_map: &mut EventMap<Arc<StateEvent>>, event_map: &mut EventMap<Arc<ServerPdu>>,
store: &dyn StateStore, store: &dyn StateStore,
) -> i64 { ) -> i64 {
tracing::info!("fetch event ({}) senders power level", event_id.to_string()); tracing::info!("fetch event ({}) senders power level", event_id.to_string());
@ -479,11 +476,11 @@ impl StateResolution {
// event.auth_event_ids does not include its own event id ? // event.auth_event_ids does not include its own event id ?
for aid in event for aid in event
.as_ref() .as_ref()
.map(|pdu| pdu.auth_events()) .map(|pdu| pdu.auth_events)
.unwrap_or_default() .unwrap_or_default()
{ {
if let Some(aev) = StateResolution::get_or_load_event(room_id, &aid, event_map, store) { if let Some(aev) = StateResolution::get_or_load_event(room_id, &aid, event_map, store) {
if aev.is_type_and_key(EventType::RoomPowerLevels, "") { if is_type_and_key(&aev, EventType::RoomPowerLevels, "") {
pl = Some(aev); pl = Some(aev);
break; break;
} }
@ -496,15 +493,16 @@ impl StateResolution {
if let Some(content) = pl if let Some(content) = pl
.map(|pl| { .map(|pl| {
pl.deserialize_content::<ruma::events::room::power_levels::PowerLevelsEventContent>( serde_json::from_value::<ruma::events::room::power_levels::PowerLevelsEventContent>(
pl.content.clone(),
) )
.ok() .ok()
}) })
.flatten() .flatten()
{ {
if let Some(ev) = event { if let Some(ev) = event {
if let Some(user) = content.users.get(ev.sender()) { if let Some(user) = content.users.get(&ev.sender) {
tracing::debug!("found {} at power_level {}", ev.sender().to_string(), user); tracing::debug!("found {} at power_level {}", ev.sender.to_string(), user);
return (*user).into(); return (*user).into();
} }
} }
@ -529,7 +527,7 @@ impl StateResolution {
room_version: &RoomVersionId, room_version: &RoomVersionId,
events_to_check: &[EventId], events_to_check: &[EventId],
unconflicted_state: &StateMap<EventId>, unconflicted_state: &StateMap<EventId>,
event_map: &mut EventMap<Arc<StateEvent>>, event_map: &mut EventMap<Arc<ServerPdu>>,
store: &dyn StateStore, store: &dyn StateStore,
) -> Result<StateMap<EventId>> { ) -> Result<StateMap<EventId>> {
tracing::info!("starting iterative auth check"); tracing::info!("starting iterative auth check");
@ -549,23 +547,23 @@ impl StateResolution {
StateResolution::get_or_load_event(room_id, event_id, event_map, store).unwrap(); StateResolution::get_or_load_event(room_id, event_id, event_map, store).unwrap();
let mut auth_events = BTreeMap::new(); let mut auth_events = BTreeMap::new();
for aid in event.auth_events() { for aid in &event.auth_events {
if let Some(ev) = if let Some(ev) =
StateResolution::get_or_load_event(room_id, &aid, event_map, store) StateResolution::get_or_load_event(room_id, &aid, event_map, store)
{ {
// TODO what to do when no state_key is found ?? // TODO what to do when no state_key is found ??
// TODO synapse check "rejected_reason", I'm guessing this is redacted_because in ruma ?? // TODO synapse check "rejected_reason", I'm guessing this is redacted_because in ruma ??
auth_events.insert((ev.kind(), ev.state_key()), ev); auth_events.insert((ev.kind.clone(), ev.state_key.clone()), ev);
} else { } else {
tracing::warn!("auth event id for {} is missing {}", aid, event_id); tracing::warn!("auth event id for {} is missing {}", aid, event_id);
} }
} }
for key in event_auth::auth_types_for_event( for key in event_auth::auth_types_for_event(
event.kind(), event.kind,
event.sender(), &event.sender,
Some(event.state_key()), event.state_key,
event.content().clone(), event.content.clone(),
) { ) {
if let Some(ev_id) = resolved_state.get(&key) { if let Some(ev_id) = resolved_state.get(&key) {
if let Some(event) = if let Some(event) =
@ -577,10 +575,10 @@ impl StateResolution {
} }
} }
tracing::debug!("event to check {:?}", event.event_id().as_str()); tracing::debug!("event to check {:?}", event.event_id.as_str());
let most_recent_prev_event = event let most_recent_prev_event = event
.prev_event_ids() .prev_events
.iter() .iter()
.filter_map(|id| StateResolution::get_or_load_event(room_id, id, event_map, store)) .filter_map(|id| StateResolution::get_or_load_event(room_id, id, event_map, store))
.next_back(); .next_back();
@ -588,7 +586,7 @@ impl StateResolution {
// The key for this is (eventType + a state_key of the signed token not sender) so search // The key for this is (eventType + a state_key of the signed token not sender) so search
// for it // for it
let current_third_party = auth_events.iter().find_map(|(_, pdu)| { let current_third_party = auth_events.iter().find_map(|(_, pdu)| {
if pdu.kind() == EventType::RoomThirdPartyInvite { if pdu.kind == EventType::RoomThirdPartyInvite {
Some(pdu.clone()) // TODO no clone, auth_events is borrowed while moved Some(pdu.clone()) // TODO no clone, auth_events is borrowed while moved
} else { } else {
None None
@ -603,7 +601,10 @@ impl StateResolution {
current_third_party, current_third_party,
)? { )? {
// add event to resolved state map // add event to resolved state map
resolved_state.insert((event.kind(), event.state_key()), event_id.clone()); resolved_state.insert(
(event.kind.clone(), event.state_key.clone()),
event_id.clone(),
);
} else { } else {
// synapse passes here on AuthError. We do not add this event to resolved_state. // synapse passes here on AuthError. We do not add this event to resolved_state.
tracing::warn!( tracing::warn!(
@ -632,7 +633,7 @@ impl StateResolution {
room_id: &RoomId, room_id: &RoomId,
to_sort: &[EventId], to_sort: &[EventId],
resolved_power_level: Option<&EventId>, resolved_power_level: Option<&EventId>,
event_map: &mut EventMap<Arc<StateEvent>>, event_map: &mut EventMap<Arc<ServerPdu>>,
store: &dyn StateStore, store: &dyn StateStore,
) -> Vec<EventId> { ) -> Vec<EventId> {
tracing::debug!("mainline sort of events"); tracing::debug!("mainline sort of events");
@ -649,12 +650,12 @@ impl StateResolution {
mainline.push(p.clone()); mainline.push(p.clone());
let event = StateResolution::get_or_load_event(room_id, &p, event_map, store).unwrap(); let event = StateResolution::get_or_load_event(room_id, &p, event_map, store).unwrap();
let auth_events = event.auth_events(); let auth_events = &event.auth_events;
pl = None; pl = None;
for aid in auth_events { for aid in auth_events {
let ev = let ev =
StateResolution::get_or_load_event(room_id, &aid, event_map, store).unwrap(); StateResolution::get_or_load_event(room_id, &aid, event_map, store).unwrap();
if ev.is_type_and_key(EventType::RoomPowerLevels, "") { if is_type_and_key(&ev, EventType::RoomPowerLevels, "") {
pl = Some(aid.clone()); pl = Some(aid.clone());
break; break;
} }
@ -690,10 +691,7 @@ impl StateResolution {
ev_id, ev_id,
( (
depth, depth,
event_map event_map.get(ev_id).map(|ev| ev.origin_server_ts),
.get(ev_id)
.map(|ev| ev.origin_server_ts())
.cloned(),
ev_id, // TODO should this be a &str to sort lexically?? ev_id, // TODO should this be a &str to sort lexically??
), ),
); );
@ -719,26 +717,26 @@ impl StateResolution {
/// that has an associated mainline depth. /// that has an associated mainline depth.
fn get_mainline_depth( fn get_mainline_depth(
room_id: &RoomId, room_id: &RoomId,
mut event: Option<Arc<StateEvent>>, mut event: Option<Arc<ServerPdu>>,
mainline_map: &EventMap<usize>, mainline_map: &EventMap<usize>,
event_map: &mut EventMap<Arc<StateEvent>>, event_map: &mut EventMap<Arc<ServerPdu>>,
store: &dyn StateStore, store: &dyn StateStore,
) -> Result<usize> { ) -> Result<usize> {
while let Some(sort_ev) = event { while let Some(sort_ev) = event {
tracing::debug!("mainline event_id {}", sort_ev.event_id().to_string()); tracing::debug!("mainline event_id {}", sort_ev.event_id.to_string());
let id = sort_ev.event_id(); let id = &sort_ev.event_id;
if let Some(depth) = mainline_map.get(&id) { if let Some(depth) = mainline_map.get(&id) {
return Ok(*depth); return Ok(*depth);
} }
// dbg!(&sort_ev); // dbg!(&sort_ev);
let auth_events = sort_ev.auth_events(); let auth_events = &sort_ev.auth_events;
event = None; event = None;
for aid in auth_events { for aid in auth_events {
// dbg!(&aid); // dbg!(&aid);
let aev = StateResolution::get_or_load_event(room_id, &aid, event_map, store) let aev = StateResolution::get_or_load_event(room_id, &aid, event_map, store)
.ok_or_else(|| Error::NotFound("Auth event not found".to_owned()))?; .ok_or_else(|| Error::NotFound("Auth event not found".to_owned()))?;
if aev.is_type_and_key(EventType::RoomPowerLevels, "") { if is_type_and_key(&aev, EventType::RoomPowerLevels, "") {
event = Some(aev); event = Some(aev);
break; break;
} }
@ -752,7 +750,7 @@ impl StateResolution {
room_id: &RoomId, room_id: &RoomId,
graph: &mut BTreeMap<EventId, Vec<EventId>>, graph: &mut BTreeMap<EventId, Vec<EventId>>,
event_id: &EventId, event_id: &EventId,
event_map: &mut EventMap<Arc<StateEvent>>, event_map: &mut EventMap<Arc<ServerPdu>>,
store: &dyn StateStore, store: &dyn StateStore,
auth_diff: &[EventId], auth_diff: &[EventId],
) { ) {
@ -763,9 +761,9 @@ impl StateResolution {
graph.entry(eid.clone()).or_insert_with(Vec::new); graph.entry(eid.clone()).or_insert_with(Vec::new);
// prefer the store to event as the store filters dedups the events // prefer the store to event as the store filters dedups the events
// otherwise it seems we can loop forever // otherwise it seems we can loop forever
for aid in StateResolution::get_or_load_event(room_id, &eid, event_map, store) for aid in &StateResolution::get_or_load_event(room_id, &eid, event_map, store)
.unwrap() .unwrap()
.auth_events() .auth_events
{ {
if auth_diff.contains(&aid) { if auth_diff.contains(&aid) {
if !graph.contains_key(&aid) { if !graph.contains_key(&aid) {
@ -788,9 +786,9 @@ impl StateResolution {
fn get_or_load_event( fn get_or_load_event(
room_id: &RoomId, room_id: &RoomId,
ev_id: &EventId, ev_id: &EventId,
event_map: &mut EventMap<Arc<StateEvent>>, event_map: &mut EventMap<Arc<ServerPdu>>,
store: &dyn StateStore, store: &dyn StateStore,
) -> Option<Arc<StateEvent>> { ) -> Option<Arc<ServerPdu>> {
if let Some(e) = event_map.get(ev_id) { if let Some(e) = event_map.get(ev_id) {
return Some(Arc::clone(e)); return Some(Arc::clone(e));
} }
@ -803,9 +801,47 @@ impl StateResolution {
} }
} }
pub fn is_power_event(event_id: &EventId, event_map: &EventMap<Arc<StateEvent>>) -> bool { pub fn is_power_event(event_id: &EventId, event_map: &EventMap<Arc<ServerPdu>>) -> bool {
match event_map.get(event_id) { match event_map.get(event_id) {
Some(state) => state.is_power_event(), Some(state) => _is_power_event(state),
_ => false, _ => false,
} }
} }
pub fn is_type_and_key(&ev: &Arc<ServerPdu>, ev_type: EventType, state_key: &str) -> bool {
ev.kind == ev_type && ev.state_key.as_deref() == Some(state_key)
}
fn _is_power_event(&event: &Arc<ServerPdu>) -> bool {
use ruma::events::room::member::{MemberEventContent, MembershipState};
match event.kind {
EventType::RoomPowerLevels | EventType::RoomJoinRules | EventType::RoomCreate => {
event.state_key == Some("".into())
}
EventType::RoomMember => {
if let Ok(content) =
// TODO fix clone
serde_json::from_value::<MemberEventContent>(event.content.clone())
{
if [MembershipState::Leave, MembershipState::Ban].contains(&content.membership) {
return event.sender.as_str()
// TODO is None here a failure
!= event.state_key.as_deref().unwrap_or("NOT A STATE KEY");
}
}
false
}
_ => false,
}
}
pub fn to_requester(event: &Arc<ServerPdu>) -> Requester<'_> {
Requester {
prev_event_ids: event.prev_events,
room_id: &event.room_id,
content: &event.content,
state_key: event.state_key.clone(),
sender: &event.sender,
}
}

View File

@ -1,15 +1,18 @@
use std::{collections::BTreeSet, sync::Arc}; use std::{collections::BTreeSet, sync::Arc};
use ruma::identifiers::{EventId, RoomId}; use ruma::{
events::pdu::ServerPdu,
identifiers::{EventId, RoomId},
};
use crate::{Result, StateEvent}; use crate::Result;
pub trait StateStore { pub trait StateStore {
/// Return a single event based on the EventId. /// Return a single event based on the EventId.
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<StateEvent>>; fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<Arc<ServerPdu>>;
/// Returns the events that correspond to the `event_ids` sorted in the same order. /// Returns the events that correspond to the `event_ids` sorted in the same order.
fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<Arc<StateEvent>>> { fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<Vec<Arc<ServerPdu>>> {
let mut events = vec![]; let mut events = vec![];
for id in event_ids { for id in event_ids {
events.push(self.get_event(room_id, id)?); events.push(self.get_event(room_id, id)?);
@ -33,7 +36,7 @@ pub trait StateStore {
let event = self.get_event(room_id, &ev_id)?; let event = self.get_event(room_id, &ev_id)?;
stack.extend(event.auth_events()); stack.extend(event.auth_events.clone());
} }
Ok(result) Ok(result)