From 36cec22cf326f6b3e0d8e4283f929ffa964c3fd2 Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Wed, 26 Aug 2020 10:31:44 -0400 Subject: [PATCH 1/7] Follow spec for is_membership_change_allowed Add checks for caller in room and remove unspec'ed synapse check leave -> join with join_rule = invite --- src/error.rs | 3 + src/event_auth.rs | 167 +++++++++++++++++++++++----------------------- src/lib.rs | 5 +- 3 files changed, 87 insertions(+), 88 deletions(-) diff --git a/src/error.rs b/src/error.rs index 79f91750..4945f975 100644 --- a/src/error.rs +++ b/src/error.rs @@ -17,6 +17,9 @@ pub enum Error { #[error(transparent)] IntParseError(#[from] ParseIntError), + #[error("Not found error: {0}")] + NotFound(String), + // TODO remove once the correct errors are used #[error("an error occured {0}")] TempString(String), diff --git a/src/event_auth.rs b/src/event_auth.rs index ed9e182e..4b071599 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -13,7 +13,7 @@ use serde_json::json; use crate::{ room_version::RoomVersion, state_event::{Requester, StateEvent}, - StateMap, + Result, StateMap, }; /// Represents the 3 event redaction outcomes. @@ -82,14 +82,14 @@ pub fn auth_check( event: &StateEvent, auth_events: StateMap, do_sig_check: bool, -) -> Option { +) -> Result { tracing::info!("auth_check beginning for {}", event.event_id().as_str()); // don't let power from other rooms be used for auth_event in auth_events.values() { if auth_event.room_id() != event.room_id() { tracing::warn!("found auth event that did not match event's room_id"); - return Some(false); + return Ok(false); } } @@ -108,7 +108,7 @@ pub fn auth_check( // check the event has been signed by the domain of the sender if event.signatures().get(sender_domain).is_none() && !is_invite_via_3pid { tracing::warn!("event not signed by sender's server"); - return Some(false); + return Ok(false); } if event.room_version() == RoomVersionId::Version1 @@ -118,7 +118,7 @@ pub fn auth_check( .is_none() { tracing::warn!("event not signed by event_id's server"); - return Some(false); + return Ok(false); } } @@ -135,7 +135,7 @@ pub fn auth_check( // domain of room_id must match domain of sender. if event.room_id().map(|id| id.server_name()) != Some(event.sender().server_name()) { tracing::warn!("creation events server does not match sender"); - return Some(false); // creation events room id does not match senders + return Ok(false); // creation events room id does not match senders } // if content.room_version is present and is not a valid version @@ -150,11 +150,11 @@ pub fn auth_check( .is_err() { tracing::warn!("invalid room version found in m.room.create event"); - return Some(false); + return Ok(false); } tracing::info!("m.room.create event was allowed"); - return Some(true); + return Ok(true); } // 3. If event does not have m.room.create in auth_events reject. @@ -164,7 +164,7 @@ pub fn auth_check( { tracing::warn!("no m.room.create event in auth chain"); - return Some(false); + return Ok(false); } // check for m.federate @@ -174,7 +174,7 @@ pub fn auth_check( if !can_federate(&auth_events) { tracing::warn!("federation not allowed"); - return Some(false); + return Ok(false); } } @@ -184,41 +184,41 @@ pub fn auth_check( // TODO && room_version "special case aliases auth" ?? if event.state_key().is_none() { tracing::warn!("no state_key field found for event"); - return Some(false); // must have state_key + return Ok(false); // must have state_key } if event.state_key().unwrap().is_empty() { tracing::warn!("state_key must be non-empty"); - return Some(false); // and be non-empty state_key (point to a user_id) + return Ok(false); // and be non-empty state_key (point to a user_id) } if event.state_key() != Some(event.sender().to_string()) { tracing::warn!("no state_key field found for event"); - return Some(false); + return Ok(false); } tracing::info!("m.room.aliases event was allowed"); - return Some(true); + return Ok(true); } if event.kind() == EventType::RoomMember { tracing::info!("starting m.room.member check"); if !is_membership_change_allowed(event.to_requester(), &auth_events)? { - return Some(false); + return Ok(false); } tracing::info!("m.room.member event was allowed"); - return Some(true); + return Ok(true); } - if let Some(in_room) = check_event_sender_in_room(event, &auth_events) { + if let Ok(in_room) = check_event_sender_in_room(event, &auth_events) { if !in_room { tracing::warn!("sender not in room"); - return Some(false); + return Ok(false); } } else { tracing::warn!("sender not in room"); - return Some(false); + return Ok(false); } // Special case to allow m.room.third_party_invite events where ever @@ -230,7 +230,7 @@ pub fn auth_check( if !can_send_event(event, &auth_events)? { tracing::warn!("user cannot send event"); - return Some(false); + return Ok(false); } if event.kind() == EventType::RoomPowerLevels { @@ -238,23 +238,23 @@ pub fn auth_check( if let Some(required_pwr_lvl) = check_power_levels(room_version, event, &auth_events) { if !required_pwr_lvl { tracing::warn!("power level was not allowed"); - return Some(false); + return Ok(false); } } else { tracing::warn!("power level was not allowed"); - return Some(false); + return Ok(false); } tracing::info!("power levels event allowed"); } if event.kind() == EventType::RoomRedaction { if let RedactAllowed::No = check_redaction(room_version, event, &auth_events)? { - return Some(false); + return Ok(false); } } tracing::info!("allowing event passed all checks"); - Some(true) + Ok(true) } // synapse has an `event: &StateEvent` param but it's never used @@ -272,38 +272,35 @@ pub fn can_federate(auth_events: &StateMap) -> bool { } } -/// Dose the user who sent this member event have required power levels to do so. +/// Does the user who sent this member event have required power levels to do so. +/// +/// If called on it's own the following must be true: +/// - there must be a valid state_key in `user` +/// - there must be a membership key in `user.content` i.e. the event is of type "m.room.member" pub fn is_membership_change_allowed( user: Requester<'_>, auth_events: &StateMap, -) -> Option { +) -> Result { let content = // TODO return error - serde_json::from_str::(&user.content.to_string()).ok()?; + serde_json::from_str::(&user.content.to_string())?; let membership = content.membership; - // check if this is the room creator joining + // If the only previous event is an m.room.create and the state_key is the creator, allow if user.prev_event_ids.len() == 1 && membership == MembershipState::Join { if let Some(create) = auth_events.get(&(EventType::RoomCreate, Some("".into()))) { if let Ok(create_ev) = create.deserialize_content::() { if user.state_key == Some(create_ev.creator.to_string()) { tracing::debug!("m.room.member event allowed via m.room.create"); - return Some(true); + return Ok(true); } } } } - let target_user_id = UserId::try_from(user.state_key.as_deref().unwrap()) - .ok() - .unwrap(); - // if the server_names are different and federation is NOT allowed - if user.room_id.server_name() != target_user_id.server_name() && !can_federate(auth_events) { - tracing::warn!("server cannot federate"); - return Some(false); - } + let target_user_id = UserId::try_from(user.state_key.as_deref().unwrap()).unwrap(); let key = (EventType::RoomMember, Some(user.sender.to_string())); let caller = auth_events.get(&key); @@ -323,8 +320,7 @@ pub fn is_membership_change_allowed( let mut join_rule = JoinRule::Invite; if let Some(jr) = join_rules_event { join_rule = jr - .deserialize_content::() - .ok()? // TODO these are errors? and should be treated as a DB failure? + .deserialize_content::()? .join_rule; } @@ -354,77 +350,79 @@ pub fn is_membership_change_allowed( // TODO this is unimpled if !verify_third_party_invite(&user, auth_events) { tracing::warn!("not invited to this room",); - return Some(false); + return Ok(false); } if target_banned { tracing::warn!("banned from this room",); - return Some(false); + return Ok(false); } tracing::info!("invite succeded"); - return Some(true); - } - - if membership != MembershipState::Join { - if caller_invited && membership == MembershipState::Leave && &target_user_id == user.sender - { - tracing::warn!("join event succeded"); - return Some(true); - } - - if !caller_in_room { - tracing::warn!("user is not in this room {}", user.room_id.as_str(),); - return Some(false); // caller is not joined - } + return Ok(true); } if membership == MembershipState::Invite { + if !caller_in_room { + tracing::warn!("invite sender not in room they are inviting user to"); + return Ok(false); + } + if target_banned { tracing::warn!("target has been banned"); - return Some(false); + return Ok(false); } else if target_in_room { tracing::warn!("already in room"); - return Some(false); // already in room + return Ok(false); // already in room } else { let invite_level = get_named_level(auth_events, "invite", 0); if user_level < invite_level { - return Some(false); + return Ok(false); } } } else if membership == MembershipState::Join { if user.sender != &target_user_id { tracing::warn!("cannot force another user to join"); - return Some(false); // cannot force another user to join + return Ok(false); // cannot force another user to join } else if target_banned { tracing::warn!("cannot join when banned"); - return Some(false); // cannot joined when banned + return Ok(false); // cannot joined when banned } else if join_rule == JoinRule::Public { tracing::info!("join rule public") // pass } else if join_rule == JoinRule::Invite { if !caller_in_room && !caller_invited { tracing::warn!("user has not been invited to this room"); - return Some(false); // you are not invited to this room + return Ok(false); // you are not invited to this room } } else { tracing::warn!("the join rule is Private or yet to be spec'ed by Matrix"); // synapse has 2 TODO's may_join list and private rooms // the join_rule is Private or Knock which means it is not yet spec'ed - return Some(false); + return Ok(false); } } else if membership == MembershipState::Leave { + if !caller_in_room { + tracing::warn!("sender not in room they are leaving"); + return Ok(false); + } + if target_banned && user_level < ban_level { tracing::warn!("not enough power to unban"); - return Some(false); // you cannot unban this user + return Ok(false); // you cannot unban this user } else if &target_user_id != user.sender { let kick_level = get_named_level(auth_events, "kick", 50); if user_level < kick_level || user_level <= target_level { tracing::warn!("not enough power to kick user"); - return Some(false); // you do not have the power to kick user + return Ok(false); // you do not have the power to kick user } } } else if membership == MembershipState::Ban { + if !caller_in_room { + tracing::warn!("ban sender not in room they are banning user from"); + return Ok(false); + } + tracing::debug!( "{} < {} || {} <= {}", user_level, @@ -432,17 +430,18 @@ pub fn is_membership_change_allowed( user_level, target_level ); + if user_level < ban_level || user_level <= target_level { tracing::warn!("not enough power to ban"); - return Some(false); + return Ok(false); } } else { tracing::warn!("unknown membership status"); // Unknown membership status - return Some(false); + return Ok(false); } - Some(true) + Ok(true) } /// Is the event's sender in the room that they sent the event to. @@ -451,19 +450,19 @@ pub fn is_membership_change_allowed( pub fn check_event_sender_in_room( event: &StateEvent, auth_events: &StateMap, -) -> Option { - let mem = auth_events.get(&(EventType::RoomMember, Some(event.sender().to_string())))?; +) -> Result { + let mem = auth_events + .get(&(EventType::RoomMember, Some(event.sender().to_string()))) + .ok_or_else(|| crate::Error::NotFound("Authe event was not found".into()))?; // TODO this is check_membership a helper fn in synapse but it does this - Some( - mem.deserialize_content::() - .ok()? - .membership - == MembershipState::Join, - ) + Ok(mem + .deserialize_content::()? + .membership + == MembershipState::Join) } /// Is the user allowed to send a specific event based on the rooms power levels. -pub fn can_send_event(event: &StateEvent, auth_events: &StateMap) -> Option { +pub fn can_send_event(event: &StateEvent, auth_events: &StateMap) -> Result { let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))); let send_level = get_send_level(event.kind(), event.state_key(), ple); @@ -477,15 +476,15 @@ pub fn can_send_event(event: &StateEvent, auth_events: &StateMap) -> ); if user_level < send_level { - return Some(false); + return Ok(false); } if let Some(sk) = event.state_key() { if sk.starts_with('@') && sk != event.sender().as_str() { - return Some(false); // permission required to post in this room + return Ok(false); // permission required to post in this room } } - Some(true) + Ok(true) } /// Confirm that the event sender has the required power levels. @@ -637,12 +636,12 @@ pub fn check_redaction( room_version: &RoomVersionId, redaction_event: &StateEvent, auth_events: &StateMap, -) -> Option { +) -> Result { let user_level = get_user_power_level(redaction_event.sender(), auth_events); let redact_level = get_named_level(auth_events, "redact", 50); if user_level >= redact_level { - return Some(RedactAllowed::CanRedact); + return Ok(RedactAllowed::CanRedact); } if let RoomVersionId::Version1 = room_version { @@ -650,14 +649,14 @@ pub fn check_redaction( if Some(redaction_event.event_id().server_name()) == redaction_event.redacts().map(|id| id.server_name()) { - return Some(RedactAllowed::OwnEvent); + return Ok(RedactAllowed::OwnEvent); } } else { // TODO synapse has this line also // event.internal_metadata.recheck_redaction = True - return Some(RedactAllowed::OwnEvent); + return Ok(RedactAllowed::OwnEvent); } - Some(RedactAllowed::No) + Ok(RedactAllowed::No) } /// Check that the member event matches `state`. diff --git a/src/lib.rs b/src/lib.rs index 0467ffa6..fe757aed 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -553,10 +553,7 @@ impl StateResolution { tracing::debug!("event to check {:?}", event.event_id().to_string()); - if event_auth::auth_check(room_version, &event, auth_events, false) - .ok_or_else(|| "Auth check failed due to deserialization most likely".to_string()) - .map_err(Error::TempString)? - { + if event_auth::auth_check(room_version, &event, auth_events, false)? { // add event to resolved state map resolved_state.insert((event.kind(), event.state_key()), event_id.clone()); } else { From 025c2df752de4ef7f5e672630dff332c40235ba6 Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Wed, 26 Aug 2020 11:04:30 -0400 Subject: [PATCH 2/7] Allow join room creator only if create event has no prev_events --- src/event_auth.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/event_auth.rs b/src/event_auth.rs index 4b071599..9e780860 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -292,7 +292,9 @@ pub fn is_membership_change_allowed( if let Some(create) = auth_events.get(&(EventType::RoomCreate, Some("".into()))) { if let Ok(create_ev) = create.deserialize_content::() { - if user.state_key == Some(create_ev.creator.to_string()) { + if user.state_key == Some(create_ev.creator.to_string()) + && create.prev_event_ids().is_empty() + { tracing::debug!("m.room.member event allowed via m.room.create"); return Ok(true); } From fbcd26c6d28dd28072003fc03644c55a0e3312d6 Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Wed, 26 Aug 2020 20:08:48 -0400 Subject: [PATCH 3/7] All of event_auth follows the spec strictly, all the synapse-isms removed --- src/event_auth.rs | 170 +++++++++++++++++++++++++++++----------------- src/lib.rs | 12 +++- 2 files changed, 117 insertions(+), 65 deletions(-) diff --git a/src/event_auth.rs b/src/event_auth.rs index 9e780860..56172480 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -80,6 +80,7 @@ pub fn auth_types_for_event( pub fn auth_check( room_version: &RoomVersionId, event: &StateEvent, + redacted_event: Option<&StateEvent>, auth_events: StateMap, do_sig_check: bool, ) -> Result { @@ -132,19 +133,25 @@ pub fn auth_check( if event.kind() == EventType::RoomCreate { tracing::info!("start m.room.create check"); - // domain of room_id must match domain of sender. + // If it has any previous events, reject + if !event.prev_event_ids().is_empty() { + tracing::warn!("the room creation event had previous events"); + return Ok(false); + } + + // If the domain of the room_id does not match the domain of the sender, reject if event.room_id().map(|id| id.server_name()) != Some(event.sender().server_name()) { tracing::warn!("creation events server does not match sender"); return Ok(false); // creation events room id does not match senders } - // if content.room_version is present and is not a valid version + // If content.room_version is present and is not a recognized version, reject if serde_json::from_value::( event .content() .get("room_version") .cloned() - // synapse defaults to version 1 + // TODO synapse defaults to version 1 .unwrap_or_else(|| serde_json::json!("1")), ) .is_err() @@ -153,11 +160,17 @@ pub fn auth_check( return Ok(false); } + // If content has no creator field, reject + if event.content().get("creator").is_none() { + tracing::warn!("no creator field found in room create content"); + return Ok(false); + } + tracing::info!("m.room.create event was allowed"); return Ok(true); } - // 3. If event does not have m.room.create in auth_events reject. + // 3. If event does not have m.room.create in auth_events reject if auth_events .get(&(EventType::RoomCreate, Some("".into()))) .is_none() @@ -167,16 +180,7 @@ pub fn auth_check( return Ok(false); } - // check for m.federate - if event.room_id().map(|id| id.server_name()) != Some(event.sender().server_name()) { - tracing::info!("checking federation"); - - if !can_federate(&auth_events) { - tracing::warn!("federation not allowed"); - - return Ok(false); - } - } + // [synapse] checks for federation here // 4. if type is m.room.aliases if event.kind() == EventType::RoomAliases { @@ -186,13 +190,17 @@ pub fn auth_check( tracing::warn!("no state_key field found for event"); return Ok(false); // must have state_key } - if event.state_key().unwrap().is_empty() { - tracing::warn!("state_key must be non-empty"); - return Ok(false); // and be non-empty state_key (point to a user_id) - } + // TODO this is not part of the spec + // if event.state_key().unwrap().is_empty() { + // tracing::warn!("state_key must be non-empty"); + // return Ok(false); // and be non-empty state_key (point to a user_id) + // } + + // TODO what? "sender's domain doesn't matches" + // If sender's domain doesn't matches state_key, reject if event.state_key() != Some(event.sender().to_string()) { - tracing::warn!("no state_key field found for event"); + tracing::warn!("state_key does not match sender"); return Ok(false); } @@ -203,6 +211,19 @@ pub fn auth_check( if event.kind() == EventType::RoomMember { tracing::info!("starting m.room.member check"); + if event.state_key().is_none() { + tracing::warn!("no state_key found for m.room.member event"); + return Ok(false); + } + + if event + .deserialize_content::() + .is_err() + { + tracing::warn!("no membership filed found for m.room.member event content"); + return Ok(false); + } + if !is_membership_change_allowed(event.to_requester(), &auth_events)? { return Ok(false); } @@ -211,14 +232,17 @@ pub fn auth_check( return Ok(true); } - if let Ok(in_room) = check_event_sender_in_room(event, &auth_events) { - if !in_room { - tracing::warn!("sender not in room"); + // If the sender's current membership state is not join, reject + match check_event_sender_in_room(event, &auth_events) { + Some(true) => {} // sender in room + Some(false) => { + tracing::warn!("sender's membership is not join"); + return Ok(false); + } + None => { + tracing::warn!("sender not found in room"); return Ok(false); } - } else { - tracing::warn!("sender not in room"); - return Ok(false); } // Special case to allow m.room.third_party_invite events where ever @@ -228,6 +252,8 @@ pub fn auth_check( unimplemented!("third party invite") } + // If the event type's required power level is greater than the sender's power level, reject + // If the event has a state_key that starts with an @ and does not match the sender, reject. if !can_send_event(event, &auth_events)? { tracing::warn!("user cannot send event"); return Ok(false); @@ -235,6 +261,7 @@ pub fn auth_check( if event.kind() == EventType::RoomPowerLevels { tracing::info!("starting m.room.power_levels check"); + if let Some(required_pwr_lvl) = check_power_levels(room_version, event, &auth_events) { if !required_pwr_lvl { tracing::warn!("power level was not allowed"); @@ -248,7 +275,9 @@ pub fn auth_check( } if event.kind() == EventType::RoomRedaction { - if let RedactAllowed::No = check_redaction(room_version, event, &auth_events)? { + if let RedactAllowed::No = + check_redaction(room_version, event, redacted_event, &auth_events)? + { return Ok(false); } } @@ -282,7 +311,6 @@ pub fn is_membership_change_allowed( auth_events: &StateMap, ) -> Result { let content = - // TODO return error serde_json::from_str::(&user.content.to_string())?; let membership = content.membership; @@ -326,7 +354,7 @@ pub fn is_membership_change_allowed( .join_rule; } - let user_level = get_user_power_level(user.sender, auth_events); + let senders_level = get_user_power_level(user.sender, auth_events); let target_level = get_user_power_level(&target_user_id, auth_events); // synapse has a not "what to do for default here 50" @@ -363,11 +391,13 @@ pub fn is_membership_change_allowed( } if membership == MembershipState::Invite { + // if senders current membership is not join reject if !caller_in_room { tracing::warn!("invite sender not in room they are inviting user to"); return Ok(false); } + // If the targets current membership is ban or join if target_banned { tracing::warn!("target has been banned"); return Ok(false); @@ -376,11 +406,15 @@ pub fn is_membership_change_allowed( return Ok(false); // already in room } else { let invite_level = get_named_level(auth_events, "invite", 0); - if user_level < invite_level { + // If the sender's power level is greater than or equal to the invite level, allow. + if senders_level < invite_level { + tracing::warn!("invite sender does not have power to invite"); return Ok(false); } } + // we already check if the join event was the room creator } else if membership == MembershipState::Join { + // If the sender does not match state_key, reject. if user.sender != &target_user_id { tracing::warn!("cannot force another user to join"); return Ok(false); // cannot force another user to join @@ -403,23 +437,28 @@ pub fn is_membership_change_allowed( return Ok(false); } } else if membership == MembershipState::Leave { + if user.sender == &target_user_id && !(caller_in_room || caller_invited) {} + // if senders current membership is not join reject if !caller_in_room { tracing::warn!("sender not in room they are leaving"); return Ok(false); } - if target_banned && user_level < ban_level { + // If the target user's current membership state is ban, and the sender's power level is less than the ban level, reject + if target_banned && senders_level < ban_level { tracing::warn!("not enough power to unban"); return Ok(false); // you cannot unban this user - } else if &target_user_id != user.sender { - let kick_level = get_named_level(auth_events, "kick", 50); - if user_level < kick_level || user_level <= target_level { - tracing::warn!("not enough power to kick user"); - return Ok(false); // you do not have the power to kick user - } + // If the sender's power level is greater than or equal to the kick level, + // and the target user's power level is less than the sender's power level, allow + } else if senders_level <= get_named_level(auth_events, "kick", 50) + || target_level < senders_level + { + tracing::warn!("not enough power to kick user"); + return Ok(false); // you do not have the power to kick user } } else if membership == MembershipState::Ban { + // if senders current membership is not join reject if !caller_in_room { tracing::warn!("ban sender not in room they are banning user from"); return Ok(false); @@ -427,13 +466,13 @@ pub fn is_membership_change_allowed( tracing::debug!( "{} < {} || {} <= {}", - user_level, + senders_level, ban_level, - user_level, + senders_level, target_level ); - if user_level < ban_level || user_level <= target_level { + if senders_level < ban_level || senders_level <= target_level { tracing::warn!("not enough power to ban"); return Ok(false); } @@ -447,37 +486,36 @@ pub fn is_membership_change_allowed( } /// Is the event's sender in the room that they sent the event to. -/// -/// A return value of None is not a failure pub fn check_event_sender_in_room( event: &StateEvent, auth_events: &StateMap, -) -> Result { - let mem = auth_events - .get(&(EventType::RoomMember, Some(event.sender().to_string()))) - .ok_or_else(|| crate::Error::NotFound("Authe event was not found".into()))?; +) -> Option { + let mem = auth_events.get(&(EventType::RoomMember, Some(event.sender().to_string())))?; // TODO this is check_membership a helper fn in synapse but it does this - Ok(mem - .deserialize_content::()? - .membership - == MembershipState::Join) + Some( + mem.deserialize_content::() + .ok()? + .membership + == MembershipState::Join, + ) } -/// Is the user allowed to send a specific event based on the rooms power levels. +/// Is the user allowed to send a specific event based on the rooms power levels. Does the event +/// have the correct userId as it's state_key if it's not the "" state_key. pub fn can_send_event(event: &StateEvent, auth_events: &StateMap) -> Result { let ple = auth_events.get(&(EventType::RoomPowerLevels, Some("".into()))); - let send_level = get_send_level(event.kind(), event.state_key(), ple); + let event_type_power_level = get_send_level(event.kind(), event.state_key(), ple); let user_level = get_user_power_level(event.sender(), auth_events); tracing::debug!( - "{} snd {} usr {}", + "{} ev_type {} usr {}", event.event_id().to_string(), - send_level, + event_type_power_level, user_level ); - if user_level < send_level { + if user_level < event_type_power_level { return Ok(false); } @@ -495,17 +533,17 @@ pub fn check_power_levels( power_event: &StateEvent, auth_events: &StateMap, ) -> Option { - use itertools::Itertools; - let key = (power_event.kind(), power_event.state_key()); let current_state = if let Some(current_state) = auth_events.get(&key) { current_state } else { - // TODO synapse returns here, shouldn't this be an error ?? + // If there is no previous m.room.power_levels event in the room, allow return Some(true); }; + // If users key in content is not a dictionary with keys that are valid user IDs + // with values that are integers (or a string that is an integer), reject. let user_content = power_event .deserialize_content::() .unwrap(); @@ -521,7 +559,7 @@ pub fn check_power_levels( let mut user_levels_to_check = btreeset![]; let old_list = ¤t_content.users; let user_list = &user_content.users; - for user in old_list.keys().chain(user_list.keys()).dedup() { + for user in old_list.keys().chain(user_list.keys()) { let user: &UserId = user; user_levels_to_check.insert(user); } @@ -531,7 +569,7 @@ pub fn check_power_levels( let mut event_levels_to_check = btreeset![]; let old_list = ¤t_content.events; let new_list = &user_content.events; - for ev_id in old_list.keys().chain(new_list.keys()).dedup() { + for ev_id in old_list.keys().chain(new_list.keys()) { let ev_id: &EventType = ev_id; event_levels_to_check.insert(ev_id); } @@ -637,27 +675,31 @@ fn get_deserialize_levels( pub fn check_redaction( room_version: &RoomVersionId, redaction_event: &StateEvent, + redacted_event: Option<&StateEvent>, auth_events: &StateMap, ) -> Result { let user_level = get_user_power_level(redaction_event.sender(), auth_events); let redact_level = get_named_level(auth_events, "redact", 50); if user_level >= redact_level { + tracing::info!("redaction allowed via power levels"); return Ok(RedactAllowed::CanRedact); } if let RoomVersionId::Version1 = room_version { // are the redacter and redactee in the same domain - if Some(redaction_event.event_id().server_name()) - == redaction_event.redacts().map(|id| id.server_name()) + if Some(redaction_event.sender().server_name()) + == redaction_event.redacts().and_then(|id| id.server_name()) { + tracing::info!("redaction event allowed via room version 1 rules"); return Ok(RedactAllowed::OwnEvent); } - } else { - // TODO synapse has this line also - // event.internal_metadata.recheck_redaction = True + // redactions to events where the sender's domains match, allow + } else if redacted_event.map(|ev| ev.sender()) == Some(redaction_event.sender()) { + tracing::info!("redaction allowed via own redaction"); return Ok(RedactAllowed::OwnEvent); } + Ok(RedactAllowed::No) } diff --git a/src/lib.rs b/src/lib.rs index fe757aed..a7b09469 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -553,7 +553,17 @@ impl StateResolution { tracing::debug!("event to check {:?}", event.event_id().to_string()); - if event_auth::auth_check(room_version, &event, auth_events, false)? { + let redacted_event = event + .redacts() + .and_then(|id| StateResolution::get_or_load_event(room_id, id, event_map, store)); + + if event_auth::auth_check( + room_version, + &event, + redacted_event.as_ref(), + auth_events, + false, + )? { // add event to resolved state map resolved_state.insert((event.kind(), event.state_key()), event_id.clone()); } else { From 17958665f6592af3ef478024fd1d75c384a30e7f Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Wed, 26 Aug 2020 20:51:39 -0400 Subject: [PATCH 4/7] Update docs in event_auth and add first few event_auth tests --- src/event_auth.rs | 5 - tests/event_auth.rs | 280 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 280 insertions(+), 5 deletions(-) diff --git a/src/event_auth.rs b/src/event_auth.rs index 56172480..e0fa6d0c 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -286,7 +286,6 @@ pub fn auth_check( Ok(true) } -// synapse has an `event: &StateEvent` param but it's never used /// Can this room federate based on its m.room.create event. pub fn can_federate(auth_events: &StateMap) -> bool { let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into()))); @@ -302,10 +301,6 @@ pub fn can_federate(auth_events: &StateMap) -> bool { } /// Does the user who sent this member event have required power levels to do so. -/// -/// If called on it's own the following must be true: -/// - there must be a valid state_key in `user` -/// - there must be a membership key in `user.content` i.e. the event is of type "m.room.member" pub fn is_membership_change_allowed( user: Requester<'_>, auth_events: &StateMap, diff --git a/tests/event_auth.rs b/tests/event_auth.rs index e69de29b..f9d0c752 100644 --- a/tests/event_auth.rs +++ b/tests/event_auth.rs @@ -0,0 +1,280 @@ +use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom}; + +use ruma::{ + events::{ + pdu::EventHash, + room::{ + join_rules::JoinRule, + member::{MemberEventContent, MembershipState}, + }, + EventType, + }, + identifiers::{EventId, RoomId, RoomVersionId, UserId}, +}; +use serde_json::{json, Value as JsonValue}; +use state_res::{ + event_auth::{ + auth_check, auth_types_for_event, can_federate, check_power_levels, check_redaction, + is_membership_change_allowed, + }, + Requester, StateEvent, StateMap, StateStore, +}; +use tracing_subscriber as tracer; + +use std::sync::Once; + +static LOGGER: Once = Once::new(); + +static mut SERVER_TIMESTAMP: i32 = 0; + +fn event_id(id: &str) -> EventId { + if id.contains('$') { + return EventId::try_from(id).unwrap(); + } + EventId::try_from(format!("${}:foo", id)).unwrap() +} + +fn alice() -> UserId { + UserId::try_from("@alice:foo").unwrap() +} +fn bob() -> UserId { + UserId::try_from("@bob:foo").unwrap() +} +fn charlie() -> UserId { + UserId::try_from("@charlie:foo").unwrap() +} + +fn room_id() -> RoomId { + RoomId::try_from("!test:foo").unwrap() +} + +fn member_content_ban() -> JsonValue { + serde_json::to_value(MemberEventContent { + membership: MembershipState::Ban, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + }) + .unwrap() +} + +fn member_content_join() -> JsonValue { + serde_json::to_value(MemberEventContent { + membership: MembershipState::Join, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + }) + .unwrap() +} + +pub struct TestStore(RefCell>); + +#[allow(unused)] +impl StateStore for TestStore { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + self.0 + .borrow() + .get(event_id) + .cloned() + .ok_or(format!("{} not found", event_id.to_string())) + } +} + +fn to_pdu_event( + id: &str, + sender: UserId, + ev_type: EventType, + state_key: Option<&str>, + content: JsonValue, + auth_events: &[S], + prev_events: &[S], +) -> StateEvent +where + S: AsRef, +{ + let ts = unsafe { + let ts = SERVER_TIMESTAMP; + // increment the "origin_server_ts" value + SERVER_TIMESTAMP += 1; + ts + }; + let id = if id.contains('$') { + id.to_string() + } else { + format!("${}:foo", id) + }; + let auth_events = auth_events + .iter() + .map(AsRef::as_ref) + .map(event_id) + .map(|id| { + ( + id, + EventHash { + sha256: "hello".into(), + }, + ) + }) + .collect::>(); + let prev_events = prev_events + .iter() + .map(AsRef::as_ref) + .map(event_id) + .map(|id| { + ( + id, + EventHash { + sha256: "hello".into(), + }, + ) + }) + .collect::>(); + + let json = if let Some(state_key) = state_key { + json!({ + "auth_events": auth_events, + "prev_events": prev_events, + "event_id": id, + "sender": sender, + "type": ev_type, + "state_key": state_key, + "content": content, + "origin_server_ts": ts, + "room_id": room_id(), + "origin": "foo", + "depth": 0, + "hashes": { "sha256": "hello" }, + "signatures": {}, + }) + } else { + json!({ + "auth_events": auth_events, + "prev_events": prev_events, + "event_id": id, + "sender": sender, + "type": ev_type, + "content": content, + "origin_server_ts": ts, + "room_id": room_id(), + "origin": "foo", + "depth": 0, + "hashes": { "sha256": "hello" }, + "signatures": {}, + }) + }; + serde_json::from_value(json).unwrap() +} + +// all graphs start with these input events +#[allow(non_snake_case)] +fn INITIAL_EVENTS() -> BTreeMap { + // this is always called so we can init the logger here + let _ = LOGGER.call_once(|| { + tracer::fmt() + .with_env_filter(tracer::EnvFilter::from_default_env()) + .init() + }); + + vec![ + to_pdu_event::( + "CREATE", + alice(), + EventType::RoomCreate, + Some(""), + json!({ "creator": alice() }), + &[], + &[], + ), + to_pdu_event( + "IMA", + alice(), + EventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + &["CREATE"], + &["CREATE"], + ), + to_pdu_event( + "IPOWER", + alice(), + EventType::RoomPowerLevels, + Some(""), + json!({"users": {alice().to_string(): 100}}), + &["CREATE", "IMA"], + &["IMA"], + ), + to_pdu_event( + "IJR", + alice(), + EventType::RoomJoinRules, + Some(""), + json!({ "join_rule": JoinRule::Public }), + &["CREATE", "IMA", "IPOWER"], + &["IPOWER"], + ), + to_pdu_event( + "IMB", + bob(), + EventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IJR"], + ), + to_pdu_event( + "IMC", + charlie(), + EventType::RoomMember, + Some(charlie().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IMB"], + ), + ] + .into_iter() + .map(|ev| (ev.event_id(), ev)) + .collect() +} + +#[test] +fn test_ban_pass() { + let events = INITIAL_EVENTS(); + + let auth_events = events + .values() + .map(|ev| ((ev.kind(), ev.state_key()), ev.clone())) + .collect::>(); + + let requester = Requester { + prev_event_ids: vec![event_id("IMC")], + room_id: &room_id(), + content: &member_content_ban(), + state_key: Some(charlie().to_string()), + sender: &alice(), + }; + + assert!(is_membership_change_allowed(requester, &auth_events).unwrap()) +} + +#[test] +fn test_ban_fail() { + let events = INITIAL_EVENTS(); + + let auth_events = events + .values() + .map(|ev| ((ev.kind(), ev.state_key()), ev.clone())) + .collect::>(); + + let requester = Requester { + prev_event_ids: vec![event_id("IMC")], + room_id: &room_id(), + content: &member_content_ban(), + state_key: Some(alice().to_string()), + sender: &charlie(), + }; + + assert!(!is_membership_change_allowed(requester, &auth_events).unwrap()) +} From aadccdee645d40610d443c2bf5098ad19bdabe63 Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Thu, 27 Aug 2020 09:08:52 -0400 Subject: [PATCH 5/7] Fix DM room creator rejoining Check only the previous event is a RoomCreate event not that one exists --- src/event_auth.rs | 69 ++++++++++++++++++++------------------------- src/lib.rs | 10 ++++--- tests/event_auth.rs | 17 ++++++++--- 3 files changed, 49 insertions(+), 47 deletions(-) diff --git a/src/event_auth.rs b/src/event_auth.rs index e0fa6d0c..0f5f2cb3 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -80,7 +80,7 @@ pub fn auth_types_for_event( pub fn auth_check( room_version: &RoomVersionId, event: &StateEvent, - redacted_event: Option<&StateEvent>, + prev_event: Option<&StateEvent>, auth_events: StateMap, do_sig_check: bool, ) -> Result { @@ -197,9 +197,8 @@ pub fn auth_check( // return Ok(false); // and be non-empty state_key (point to a user_id) // } - // TODO what? "sender's domain doesn't matches" // If sender's domain doesn't matches state_key, reject - if event.state_key() != Some(event.sender().to_string()) { + if event.state_key().as_deref() != Some(event.sender().server_name().as_str()) { tracing::warn!("state_key does not match sender"); return Ok(false); } @@ -224,7 +223,7 @@ pub fn auth_check( return Ok(false); } - if !is_membership_change_allowed(event.to_requester(), &auth_events)? { + if !is_membership_change_allowed(event.to_requester(), prev_event, &auth_events)? { return Ok(false); } @@ -275,9 +274,7 @@ pub fn auth_check( } if event.kind() == EventType::RoomRedaction { - if let RedactAllowed::No = - check_redaction(room_version, event, redacted_event, &auth_events)? - { + if let RedactAllowed::No = check_redaction(room_version, event, &auth_events)? { return Ok(false); } } @@ -303,6 +300,7 @@ pub fn can_federate(auth_events: &StateMap) -> bool { /// Does the user who sent this member event have required power levels to do so. pub fn is_membership_change_allowed( user: Requester<'_>, + prev_event: Option<&StateEvent>, auth_events: &StateMap, ) -> Result { let content = @@ -312,7 +310,7 @@ pub fn is_membership_change_allowed( // If the only previous event is an m.room.create and the state_key is the creator, allow if user.prev_event_ids.len() == 1 && membership == MembershipState::Join { - if let Some(create) = auth_events.get(&(EventType::RoomCreate, Some("".into()))) { + if let Some(create) = prev_event { if let Ok(create_ev) = create.deserialize_content::() { if user.state_key == Some(create_ev.creator.to_string()) @@ -371,6 +369,7 @@ pub fn is_membership_change_allowed( .unwrap(), ); + // we already check if the join event was the room creator if membership == MembershipState::Invite && content.third_party_invite.is_some() { // TODO this is unimpled if !verify_third_party_invite(&user, auth_events) { @@ -383,9 +382,30 @@ pub fn is_membership_change_allowed( } tracing::info!("invite succeded"); return Ok(true); - } + } else if membership == MembershipState::Join { + // If the sender does not match state_key, reject. + if user.sender != &target_user_id { + tracing::warn!("cannot force another user to join"); + return Ok(false); // cannot force another user to join + } else if target_banned { + tracing::warn!("cannot join when banned"); + return Ok(false); // cannot joined when banned + } else if join_rule == JoinRule::Invite { + if !caller_in_room && !caller_invited { + tracing::warn!("user has not been invited to this room"); + return Ok(false); // you are not invited to this room + } + } else if join_rule == JoinRule::Public { + tracing::info!("join rule public") + // pass + } else { + tracing::warn!("the join rule is Private or yet to be spec'ed by Matrix"); + // synapse has 2 TODO's may_join list and private rooms - if membership == MembershipState::Invite { + // the join_rule is Private or Knock which means it is not yet spec'ed + return Ok(false); + } + } else if membership == MembershipState::Invite { // if senders current membership is not join reject if !caller_in_room { tracing::warn!("invite sender not in room they are inviting user to"); @@ -407,30 +427,6 @@ pub fn is_membership_change_allowed( return Ok(false); } } - // we already check if the join event was the room creator - } else if membership == MembershipState::Join { - // If the sender does not match state_key, reject. - if user.sender != &target_user_id { - tracing::warn!("cannot force another user to join"); - return Ok(false); // cannot force another user to join - } else if target_banned { - tracing::warn!("cannot join when banned"); - return Ok(false); // cannot joined when banned - } else if join_rule == JoinRule::Public { - tracing::info!("join rule public") - // pass - } else if join_rule == JoinRule::Invite { - if !caller_in_room && !caller_invited { - tracing::warn!("user has not been invited to this room"); - return Ok(false); // you are not invited to this room - } - } else { - tracing::warn!("the join rule is Private or yet to be spec'ed by Matrix"); - // synapse has 2 TODO's may_join list and private rooms - - // the join_rule is Private or Knock which means it is not yet spec'ed - return Ok(false); - } } else if membership == MembershipState::Leave { if user.sender == &target_user_id && !(caller_in_room || caller_invited) {} // if senders current membership is not join reject @@ -670,7 +666,6 @@ fn get_deserialize_levels( pub fn check_redaction( room_version: &RoomVersionId, redaction_event: &StateEvent, - redacted_event: Option<&StateEvent>, auth_events: &StateMap, ) -> Result { let user_level = get_user_power_level(redaction_event.sender(), auth_events); @@ -689,10 +684,6 @@ pub fn check_redaction( tracing::info!("redaction event allowed via room version 1 rules"); return Ok(RedactAllowed::OwnEvent); } - // redactions to events where the sender's domains match, allow - } else if redacted_event.map(|ev| ev.sender()) == Some(redaction_event.sender()) { - tracing::info!("redaction allowed via own redaction"); - return Ok(RedactAllowed::OwnEvent); } Ok(RedactAllowed::No) diff --git a/src/lib.rs b/src/lib.rs index a7b09469..42c9d197 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -553,14 +553,16 @@ impl StateResolution { tracing::debug!("event to check {:?}", event.event_id().to_string()); - let redacted_event = event - .redacts() - .and_then(|id| StateResolution::get_or_load_event(room_id, id, event_map, store)); + let most_recent_prev_event = event + .prev_event_ids() + .iter() + .filter_map(|id| StateResolution::get_or_load_event(room_id, id, event_map, store)) + .next_back(); if event_auth::auth_check( room_version, &event, - redacted_event.as_ref(), + most_recent_prev_event.as_ref(), auth_events, false, )? { diff --git a/tests/event_auth.rs b/tests/event_auth.rs index f9d0c752..7ffa58e7 100644 --- a/tests/event_auth.rs +++ b/tests/event_auth.rs @@ -9,12 +9,13 @@ use ruma::{ }, EventType, }, - identifiers::{EventId, RoomId, RoomVersionId, UserId}, + identifiers::{EventId, RoomId, UserId}, }; use serde_json::{json, Value as JsonValue}; +#[rustfmt::skip] // this deletes the comments for some reason yay! use state_res::{ event_auth::{ - auth_check, auth_types_for_event, can_federate, check_power_levels, check_redaction, + // auth_check, auth_types_for_event, can_federate, check_power_levels, check_redaction, is_membership_change_allowed, }, Requester, StateEvent, StateMap, StateStore, @@ -243,6 +244,10 @@ fn INITIAL_EVENTS() -> BTreeMap { fn test_ban_pass() { let events = INITIAL_EVENTS(); + let prev = events + .values() + .find(|ev| ev.event_id().as_str().contains("IMC")); + let auth_events = events .values() .map(|ev| ((ev.kind(), ev.state_key()), ev.clone())) @@ -256,13 +261,17 @@ fn test_ban_pass() { sender: &alice(), }; - assert!(is_membership_change_allowed(requester, &auth_events).unwrap()) + assert!(is_membership_change_allowed(requester, prev, &auth_events).unwrap()) } #[test] fn test_ban_fail() { let events = INITIAL_EVENTS(); + let prev = events + .values() + .find(|ev| ev.event_id().as_str().contains("IMC")); + let auth_events = events .values() .map(|ev| ((ev.kind(), ev.state_key()), ev.clone())) @@ -276,5 +285,5 @@ fn test_ban_fail() { sender: &charlie(), }; - assert!(!is_membership_change_allowed(requester, &auth_events).unwrap()) + assert!(!is_membership_change_allowed(requester, prev, &auth_events).unwrap()) } From b846aec94a0efd22dc25e2cb88e35d8efa330892 Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Thu, 27 Aug 2020 15:46:36 -0400 Subject: [PATCH 6/7] Replace membership auth with timo's logic --- src/event_auth.rs | 300 +++++++++++++++++++++----------------------- tests/event_auth.rs | 6 +- 2 files changed, 146 insertions(+), 160 deletions(-) diff --git a/src/event_auth.rs b/src/event_auth.rs index 0f5f2cb3..a904fb4d 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -1,19 +1,23 @@ -use std::convert::TryFrom; +use std::{collections::BTreeMap, convert::TryFrom}; use maplit::btreeset; use ruma::{ events::{ - room::{self, join_rules::JoinRule, member::MembershipState}, + room::{ + self, + join_rules::JoinRule, + member::{self, MembershipState}, + power_levels::{self, PowerLevelsEventContent}, + }, EventType, }, identifiers::{RoomVersionId, UserId}, }; -use serde_json::json; use crate::{ room_version::RoomVersion, state_event::{Requester, StateEvent}, - Result, StateMap, + Error, Result, StateMap, }; /// Represents the 3 event redaction outcomes. @@ -223,7 +227,7 @@ pub fn auth_check( return Ok(false); } - if !is_membership_change_allowed(event.to_requester(), prev_event, &auth_events)? { + if !valid_membership_change(event.to_requester(), prev_event, &auth_events)? { return Ok(false); } @@ -298,182 +302,164 @@ pub fn can_federate(auth_events: &StateMap) -> bool { } /// Does the user who sent this member event have required power levels to do so. -pub fn is_membership_change_allowed( +/// +/// * `user` - Information about the membership event and user making the request. +/// * `prev_event` - The event that occurred immediately before the `user` event or None. +/// * `auth_events` - The set of auth events that relate to a membership event. +/// this is generated by calling `auth_types_for_event` with the membership event and +/// the current State. +pub fn valid_membership_change( user: Requester<'_>, prev_event: Option<&StateEvent>, auth_events: &StateMap, ) -> Result { + let state_key = if let Some(s) = user.state_key.as_ref() { + s + } else { + return Err(Error::TempString("State event requires state_key".into())); + }; + let content = serde_json::from_str::(&user.content.to_string())?; - let membership = content.membership; + let target_membership = content.membership; - // If the only previous event is an m.room.create and the state_key is the creator, allow - if user.prev_event_ids.len() == 1 && membership == MembershipState::Join { - if let Some(create) = prev_event { - if let Ok(create_ev) = create.deserialize_content::() - { - if user.state_key == Some(create_ev.creator.to_string()) - && create.prev_event_ids().is_empty() - { - tracing::debug!("m.room.member event allowed via m.room.create"); - return Ok(true); - } - } - } - } - - let target_user_id = UserId::try_from(user.state_key.as_deref().unwrap()).unwrap(); + let target_user_id = + UserId::try_from(state_key.as_str()).map_err(|e| Error::TempString(format!("{}", e)))?; let key = (EventType::RoomMember, Some(user.sender.to_string())); - let caller = auth_events.get(&key); - - let caller_in_room = caller.is_some() && check_membership(caller, MembershipState::Join); - let caller_invited = caller.is_some() && check_membership(caller, MembershipState::Invite); + let sender = auth_events.get(&key); + let sender_membership = + sender.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { + Ok(pdu + .deserialize_content::()? + .membership) + })?; let key = (EventType::RoomMember, Some(target_user_id.to_string())); - let target = auth_events.get(&key); + let current = auth_events.get(&key); + let current_membership = + current.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { + Ok(pdu + .deserialize_content::()? + .membership) + })?; - let target_in_room = target.is_some() && check_membership(target, MembershipState::Join); - let target_banned = target.is_some() && check_membership(target, MembershipState::Ban); + let key = (EventType::RoomPowerLevels, Some("".into())); + let power_levels = auth_events.get(&key).map_or_else( + || { + Ok::<_, Error>(power_levels::PowerLevelsEventContent { + ban: 50.into(), + events: BTreeMap::new(), + events_default: 0.into(), + invite: 50.into(), + kick: 50.into(), + redact: 50.into(), + state_default: 0.into(), + users: BTreeMap::new(), + users_default: 0.into(), + notifications: ruma::events::room::power_levels::NotificationPowerLevels { + room: 50.into(), + }, + }) + }, + |power_levels| { + power_levels + .deserialize_content::() + .map_err(Into::into) + }, + )?; + + let sender_power = power_levels.users.get(&user.sender).map_or_else( + || { + if sender_membership != member::MembershipState::Join { + None + } else { + Some(&power_levels.users_default) + } + }, + // If it's okay, wrap with Some(_) + Some, + ); + let target_power = power_levels.users.get(&target_user_id).map_or_else( + || { + if target_membership != member::MembershipState::Join { + None + } else { + Some(&power_levels.users_default) + } + }, + // If it's okay, wrap with Some(_) + Some, + ); let key = (EventType::RoomJoinRules, Some("".to_string())); let join_rules_event = auth_events.get(&key); - - let mut join_rule = JoinRule::Invite; + let mut join_rules = JoinRule::Invite; if let Some(jr) = join_rules_event { - join_rule = jr + join_rules = jr .deserialize_content::()? .join_rule; } - let senders_level = get_user_power_level(user.sender, auth_events); - let target_level = get_user_power_level(&target_user_id, auth_events); - - // synapse has a not "what to do for default here 50" - let ban_level = get_named_level(auth_events, "ban", 50); - - // TODO clean this up - tracing::debug!( - "_is_membership_change_allowed: {}", - serde_json::to_string_pretty(&json!({ - "caller_in_room": caller_in_room, - "caller_invited": caller_invited, - "target_banned": target_banned, - "target_in_room": target_in_room, - "membership": membership, - "join_rule": join_rule, - "target_user_id": target_user_id, - "event.user_id": user.sender, - })) - .unwrap(), - ); - - // we already check if the join event was the room creator - if membership == MembershipState::Invite && content.third_party_invite.is_some() { - // TODO this is unimpled - if !verify_third_party_invite(&user, auth_events) { - tracing::warn!("not invited to this room",); - return Ok(false); + if let Some(prev) = prev_event { + if prev.kind() == EventType::RoomCreate && prev.prev_event_ids().is_empty() { + return Ok(true); } - if target_banned { - tracing::warn!("banned from this room",); - return Ok(false); - } - tracing::info!("invite succeded"); - return Ok(true); - } else if membership == MembershipState::Join { - // If the sender does not match state_key, reject. - if user.sender != &target_user_id { - tracing::warn!("cannot force another user to join"); - return Ok(false); // cannot force another user to join - } else if target_banned { - tracing::warn!("cannot join when banned"); - return Ok(false); // cannot joined when banned - } else if join_rule == JoinRule::Invite { - if !caller_in_room && !caller_invited { - tracing::warn!("user has not been invited to this room"); - return Ok(false); // you are not invited to this room - } - } else if join_rule == JoinRule::Public { - tracing::info!("join rule public") - // pass - } else { - tracing::warn!("the join rule is Private or yet to be spec'ed by Matrix"); - // synapse has 2 TODO's may_join list and private rooms - - // the join_rule is Private or Knock which means it is not yet spec'ed - return Ok(false); - } - } else if membership == MembershipState::Invite { - // if senders current membership is not join reject - if !caller_in_room { - tracing::warn!("invite sender not in room they are inviting user to"); - return Ok(false); - } - - // If the targets current membership is ban or join - if target_banned { - tracing::warn!("target has been banned"); - return Ok(false); - } else if target_in_room { - tracing::warn!("already in room"); - return Ok(false); // already in room - } else { - let invite_level = get_named_level(auth_events, "invite", 0); - // If the sender's power level is greater than or equal to the invite level, allow. - if senders_level < invite_level { - tracing::warn!("invite sender does not have power to invite"); - return Ok(false); - } - } - } else if membership == MembershipState::Leave { - if user.sender == &target_user_id && !(caller_in_room || caller_invited) {} - // if senders current membership is not join reject - if !caller_in_room { - tracing::warn!("sender not in room they are leaving"); - return Ok(false); - } - - // If the target user's current membership state is ban, and the sender's power level is less than the ban level, reject - if target_banned && senders_level < ban_level { - tracing::warn!("not enough power to unban"); - return Ok(false); // you cannot unban this user - - // If the sender's power level is greater than or equal to the kick level, - // and the target user's power level is less than the sender's power level, allow - } else if senders_level <= get_named_level(auth_events, "kick", 50) - || target_level < senders_level - { - tracing::warn!("not enough power to kick user"); - return Ok(false); // you do not have the power to kick user - } - } else if membership == MembershipState::Ban { - // if senders current membership is not join reject - if !caller_in_room { - tracing::warn!("ban sender not in room they are banning user from"); - return Ok(false); - } - - tracing::debug!( - "{} < {} || {} <= {}", - senders_level, - ban_level, - senders_level, - target_level - ); - - if senders_level < ban_level || senders_level <= target_level { - tracing::warn!("not enough power to ban"); - return Ok(false); - } - } else { - tracing::warn!("unknown membership status"); - // Unknown membership status - return Ok(false); } - Ok(true) + Ok(if target_membership == MembershipState::Join { + if user.sender != &target_user_id { + false + } else if let MembershipState::Ban = current_membership { + false + } else { + join_rules == JoinRule::Invite + && (current_membership == MembershipState::Join + || current_membership == MembershipState::Invite) + || join_rules == JoinRule::Public + } + } else if target_membership == MembershipState::Invite { + if let Some(_tp_id) = content.third_party_invite { + if current_membership == MembershipState::Ban { + false + } else { + // TODO this is not filled out + verify_third_party_invite(&user, auth_events) + } + } else if sender_membership != MembershipState::Join + || current_membership == MembershipState::Join + || current_membership == MembershipState::Ban + { + false + } else { + sender_power + .filter(|&p| p >= &power_levels.invite) + .is_some() + } + } else if target_membership == MembershipState::Leave { + if user.sender == &target_user_id { + current_membership == MembershipState::Join + || current_membership == MembershipState::Invite + } else if sender_membership != MembershipState::Join + || current_membership == MembershipState::Ban + && sender_power.filter(|&p| p < &power_levels.ban).is_some() + { + false + } else { + sender_power.filter(|&p| p >= &power_levels.kick).is_some() + && target_power < sender_power + } + } else if target_membership == MembershipState::Ban { + if sender_membership != MembershipState::Join { + false + } else { + sender_power.filter(|&p| p >= &power_levels.ban).is_some() + && target_power < sender_power + } + } else { + false + }) } /// Is the event's sender in the room that they sent the event to. diff --git a/tests/event_auth.rs b/tests/event_auth.rs index 7ffa58e7..6ae41778 100644 --- a/tests/event_auth.rs +++ b/tests/event_auth.rs @@ -16,7 +16,7 @@ use serde_json::{json, Value as JsonValue}; use state_res::{ event_auth::{ // auth_check, auth_types_for_event, can_federate, check_power_levels, check_redaction, - is_membership_change_allowed, + valid_membership_change, }, Requester, StateEvent, StateMap, StateStore, }; @@ -261,7 +261,7 @@ fn test_ban_pass() { sender: &alice(), }; - assert!(is_membership_change_allowed(requester, prev, &auth_events).unwrap()) + assert!(valid_membership_change(requester, prev, &auth_events).unwrap()) } #[test] @@ -285,5 +285,5 @@ fn test_ban_fail() { sender: &charlie(), }; - assert!(!is_membership_change_allowed(requester, prev, &auth_events).unwrap()) + assert!(!valid_membership_change(requester, prev, &auth_events).unwrap()) } From 394d26744a6586ccdc01838964bb27dab289eee5 Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Thu, 27 Aug 2020 19:32:32 -0400 Subject: [PATCH 7/7] Use own Error type for all errors --- benches/state_res_bench.rs | 85 +++---------------------------- src/error.rs | 17 +++++-- src/event_auth.rs | 102 ++++++++++++++++++++----------------- src/lib.rs | 20 ++++---- src/state_store.rs | 18 ++----- tests/event_auth.rs | 6 +-- tests/event_sorting.rs | 6 +-- tests/res_with_auth_ids.rs | 8 +-- tests/state_res.rs | 86 +++---------------------------- 9 files changed, 105 insertions(+), 243 deletions(-) diff --git a/benches/state_res_bench.rs b/benches/state_res_bench.rs index d9cbb3ab..369750c6 100644 --- a/benches/state_res_bench.rs +++ b/benches/state_res_bench.rs @@ -3,12 +3,7 @@ // `cargo bench unknown option --save-baseline`. // To pass args to criterion, use this form // `cargo bench --bench -- --save-baseline `. -use std::{ - cell::RefCell, - collections::{BTreeMap, BTreeSet}, - convert::TryFrom, - time::UNIX_EPOCH, -}; +use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH}; use criterion::{criterion_group, criterion_main, Criterion}; use maplit::btreemap; @@ -24,7 +19,9 @@ use ruma::{ identifiers::{EventId, RoomId, RoomVersionId, UserId}, }; use serde_json::{json, Value as JsonValue}; -use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore}; +use state_res::{ + Error, ResolutionResult, Result, StateEvent, StateMap, StateResolution, StateStore, +}; static mut SERVER_TIMESTAMP: i32 = 0; @@ -137,82 +134,12 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_events(&self, room_id: &RoomId, events: &[EventId]) -> Result, String> { - Ok(self - .0 - .borrow() - .iter() - .filter(|e| events.contains(e.0)) - .map(|(_, s)| s) - .cloned() - .collect()) - } - - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) .cloned() - .ok_or(format!("{} not found", event_id.to_string())) - } - - fn auth_event_ids( - &self, - room_id: &RoomId, - event_ids: &[EventId], - ) -> Result, String> { - let mut result = vec![]; - let mut stack = event_ids.to_vec(); - - // DFS for auth event chain - while !stack.is_empty() { - let ev_id = stack.pop().unwrap(); - if result.contains(&ev_id) { - continue; - } - - result.push(ev_id.clone()); - - let event = self.get_event(room_id, &ev_id).unwrap(); - stack.extend(event.auth_events()); - } - - Ok(result) - } - - fn auth_chain_diff( - &self, - room_id: &RoomId, - event_ids: Vec>, - ) -> Result, String> { - use itertools::Itertools; - - let mut chains = vec![]; - for ids in event_ids { - // TODO state store `auth_event_ids` returns self in the event ids list - // when an event returns `auth_event_ids` self is not contained - let chain = self - .auth_event_ids(room_id, &ids)? - .into_iter() - .collect::>(); - chains.push(chain); - } - - if let Some(chain) = chains.first() { - let rest = chains.iter().skip(1).flatten().cloned().collect(); - let common = chain.intersection(&rest).collect::>(); - - Ok(chains - .iter() - .flatten() - .filter(|id| !common.contains(&id)) - .cloned() - .collect::>() - .into_iter() - .collect()) - } else { - Ok(vec![]) - } + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } diff --git a/src/error.rs b/src/error.rs index 4945f975..51ca8fcc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -20,7 +20,18 @@ pub enum Error { #[error("Not found error: {0}")] NotFound(String), - // TODO remove once the correct errors are used - #[error("an error occured {0}")] - TempString(String), + #[error("Invalid PDU: {0}")] + InvalidPdu(String), + + #[error("Conversion failed: {0}")] + ConversionError(String), + + #[error("{0}")] + Custom(Box), +} + +impl Error { + pub fn custom(e: E) -> Self { + Self::Custom(Box::new(e)) + } } diff --git a/src/event_auth.rs b/src/event_auth.rs index a904fb4d..54771408 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -83,26 +83,26 @@ pub fn auth_types_for_event( /// * then there are checks for specific event types pub fn auth_check( room_version: &RoomVersionId, - event: &StateEvent, + incoming_event: &StateEvent, prev_event: Option<&StateEvent>, auth_events: StateMap, do_sig_check: bool, ) -> Result { - tracing::info!("auth_check beginning for {}", event.event_id().as_str()); + tracing::info!("auth_check beginning for {}", incoming_event.kind()); // don't let power from other rooms be used for auth_event in auth_events.values() { - if auth_event.room_id() != event.room_id() { + if auth_event.room_id() != incoming_event.room_id() { tracing::warn!("found auth event that did not match event's room_id"); return Ok(false); } } if do_sig_check { - let sender_domain = event.sender().server_name(); + let sender_domain = incoming_event.sender().server_name(); - let is_invite_via_3pid = if event.kind() == EventType::RoomMember { - event + let is_invite_via_3pid = if incoming_event.kind() == EventType::RoomMember { + incoming_event .deserialize_content::() .map(|c| c.membership == MembershipState::Invite && c.third_party_invite.is_some()) .unwrap_or_default() @@ -111,15 +111,15 @@ pub fn auth_check( }; // check the event has been signed by the domain of the sender - if event.signatures().get(sender_domain).is_none() && !is_invite_via_3pid { + if incoming_event.signatures().get(sender_domain).is_none() && !is_invite_via_3pid { tracing::warn!("event not signed by sender's server"); return Ok(false); } - if event.room_version() == RoomVersionId::Version1 - && event + if incoming_event.room_version() == RoomVersionId::Version1 + && incoming_event .signatures() - .get(event.event_id().server_name().unwrap()) + .get(incoming_event.event_id().server_name().unwrap()) .is_none() { tracing::warn!("event not signed by event_id's server"); @@ -134,24 +134,26 @@ pub fn auth_check( // Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules // // 1. If type is m.room.create: - if event.kind() == EventType::RoomCreate { + if incoming_event.kind() == EventType::RoomCreate { tracing::info!("start m.room.create check"); // If it has any previous events, reject - if !event.prev_event_ids().is_empty() { + if !incoming_event.prev_event_ids().is_empty() { tracing::warn!("the room creation event had previous events"); return Ok(false); } // If the domain of the room_id does not match the domain of the sender, reject - if event.room_id().map(|id| id.server_name()) != Some(event.sender().server_name()) { + if incoming_event.room_id().map(|id| id.server_name()) + != Some(incoming_event.sender().server_name()) + { tracing::warn!("creation events server does not match sender"); return Ok(false); // creation events room id does not match senders } // If content.room_version is present and is not a recognized version, reject if serde_json::from_value::( - event + incoming_event .content() .get("room_version") .cloned() @@ -165,7 +167,7 @@ pub fn auth_check( } // If content has no creator field, reject - if event.content().get("creator").is_none() { + if incoming_event.content().get("creator").is_none() { tracing::warn!("no creator field found in room create content"); return Ok(false); } @@ -187,10 +189,10 @@ pub fn auth_check( // [synapse] checks for federation here // 4. if type is m.room.aliases - if event.kind() == EventType::RoomAliases { + if incoming_event.kind() == EventType::RoomAliases { tracing::info!("starting m.room.aliases check"); // TODO && room_version "special case aliases auth" ?? - if event.state_key().is_none() { + if incoming_event.state_key().is_none() { tracing::warn!("no state_key field found for event"); return Ok(false); // must have state_key } @@ -202,7 +204,9 @@ pub fn auth_check( // } // If sender's domain doesn't matches state_key, reject - if event.state_key().as_deref() != Some(event.sender().server_name().as_str()) { + if incoming_event.state_key().as_deref() + != Some(incoming_event.sender().server_name().as_str()) + { tracing::warn!("state_key does not match sender"); return Ok(false); } @@ -211,15 +215,15 @@ pub fn auth_check( return Ok(true); } - if event.kind() == EventType::RoomMember { + if incoming_event.kind() == EventType::RoomMember { tracing::info!("starting m.room.member check"); - if event.state_key().is_none() { + if incoming_event.state_key().is_none() { tracing::warn!("no state_key found for m.room.member event"); return Ok(false); } - if event + if incoming_event .deserialize_content::() .is_err() { @@ -227,7 +231,7 @@ pub fn auth_check( return Ok(false); } - if !valid_membership_change(event.to_requester(), prev_event, &auth_events)? { + if !valid_membership_change(incoming_event.to_requester(), prev_event, &auth_events)? { return Ok(false); } @@ -236,7 +240,7 @@ pub fn auth_check( } // If the sender's current membership state is not join, reject - match check_event_sender_in_room(event, &auth_events) { + match check_event_sender_in_room(incoming_event.sender(), &auth_events) { Some(true) => {} // sender in room Some(false) => { tracing::warn!("sender's membership is not join"); @@ -250,22 +254,24 @@ pub fn auth_check( // Special case to allow m.room.third_party_invite events where ever // a user is allowed to issue invites - if event.kind() == EventType::RoomThirdPartyInvite { + if incoming_event.kind() == EventType::RoomThirdPartyInvite { // TODO impl this unimplemented!("third party invite") } // If the event type's required power level is greater than the sender's power level, reject // If the event has a state_key that starts with an @ and does not match the sender, reject. - if !can_send_event(event, &auth_events)? { + if !can_send_event(incoming_event, &auth_events)? { tracing::warn!("user cannot send event"); return Ok(false); } - if event.kind() == EventType::RoomPowerLevels { + if incoming_event.kind() == EventType::RoomPowerLevels { tracing::info!("starting m.room.power_levels check"); - if let Some(required_pwr_lvl) = check_power_levels(room_version, event, &auth_events) { + if let Some(required_pwr_lvl) = + check_power_levels(room_version, incoming_event, &auth_events) + { if !required_pwr_lvl { tracing::warn!("power level was not allowed"); return Ok(false); @@ -277,8 +283,8 @@ pub fn auth_check( tracing::info!("power levels event allowed"); } - if event.kind() == EventType::RoomRedaction { - if let RedactAllowed::No = check_redaction(room_version, event, &auth_events)? { + if incoming_event.kind() == EventType::RoomRedaction { + if let RedactAllowed::No = check_redaction(room_version, incoming_event, &auth_events)? { return Ok(false); } } @@ -287,20 +293,6 @@ pub fn auth_check( Ok(true) } -/// Can this room federate based on its m.room.create event. -pub fn can_federate(auth_events: &StateMap) -> bool { - let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into()))); - if let Some(ev) = creation_event { - if let Some(fed) = ev.content().get("m.federate") { - fed == "true" - } else { - false - } - } else { - false - } -} - /// Does the user who sent this member event have required power levels to do so. /// /// * `user` - Information about the membership event and user making the request. @@ -316,7 +308,7 @@ pub fn valid_membership_change( let state_key = if let Some(s) = user.state_key.as_ref() { s } else { - return Err(Error::TempString("State event requires state_key".into())); + return Err(Error::InvalidPdu("State event requires state_key".into())); }; let content = @@ -324,8 +316,8 @@ pub fn valid_membership_change( let target_membership = content.membership; - let target_user_id = - UserId::try_from(state_key.as_str()).map_err(|e| Error::TempString(format!("{}", e)))?; + let target_user_id = UserId::try_from(state_key.as_str()) + .map_err(|e| Error::ConversionError(format!("{}", e)))?; let key = (EventType::RoomMember, Some(user.sender.to_string())); let sender = auth_events.get(&key); @@ -464,10 +456,10 @@ pub fn valid_membership_change( /// Is the event's sender in the room that they sent the event to. pub fn check_event_sender_in_room( - event: &StateEvent, + sender: &UserId, auth_events: &StateMap, ) -> Option { - let mem = auth_events.get(&(EventType::RoomMember, Some(event.sender().to_string())))?; + let mem = auth_events.get(&(EventType::RoomMember, Some(sender.to_string())))?; // TODO this is check_membership a helper fn in synapse but it does this Some( mem.deserialize_content::() @@ -692,6 +684,20 @@ pub fn check_membership(member_event: Option<&StateEvent>, state: MembershipStat } } +/// Can this room federate based on its m.room.create event. +pub fn can_federate(auth_events: &StateMap) -> bool { + let creation_event = auth_events.get(&(EventType::RoomCreate, Some("".into()))); + if let Some(ev) = creation_event { + if let Some(fed) = ev.content().get("m.federate") { + fed == "true" + } else { + false + } + } else { + false + } +} + /// Helper function to fetch a field, `name`, from a "m.room.power_level" event's content. /// or return `default` if no power level event is found or zero if no field matches `name`. pub fn get_named_level(auth_events: &StateMap, name: &str, default: i64) -> i64 { diff --git a/src/lib.rs b/src/lib.rs index 42c9d197..375f22bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,7 +119,7 @@ impl StateResolution { for event in event_map.values() { if event.room_id() != Some(room_id) { - return Err(Error::TempString(format!( + return Err(Error::InvalidPdu(format!( "resolving event {} in room {}, when correct room is {}", event.event_id(), event.room_id().map(|id| id.as_str()).unwrap_or("`unknown`"), @@ -288,16 +288,14 @@ impl StateResolution { tracing::debug!("calculating auth chain difference"); - store - .auth_chain_diff( - room_id, - state_sets - .iter() - .map(|map| map.values().cloned().collect()) - .dedup() - .collect::>(), - ) - .map_err(Error::TempString) + store.auth_chain_diff( + room_id, + state_sets + .iter() + .map(|map| map.values().cloned().collect()) + .dedup() + .collect::>(), + ) } /// Events are sorted from "earliest" to "latest". They are compared using diff --git a/src/state_store.rs b/src/state_store.rs index 80621bf0..777913ed 100644 --- a/src/state_store.rs +++ b/src/state_store.rs @@ -2,18 +2,14 @@ use std::collections::BTreeSet; use ruma::identifiers::{EventId, RoomId}; -use crate::StateEvent; +use crate::{Result, StateEvent}; pub trait StateStore { /// Return a single event based on the EventId. - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result; + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result; /// Returns the events that correspond to the `event_ids` sorted in the same order. - fn get_events( - &self, - room_id: &RoomId, - event_ids: &[EventId], - ) -> Result, String> { + fn get_events(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result> { let mut events = vec![]; for id in event_ids { events.push(self.get_event(room_id, id)?); @@ -22,11 +18,7 @@ pub trait StateStore { } /// Returns a Vec of the related auth events to the given `event`. - fn auth_event_ids( - &self, - room_id: &RoomId, - event_ids: &[EventId], - ) -> Result, String> { + fn auth_event_ids(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result> { let mut result = vec![]; let mut stack = event_ids.to_vec(); @@ -52,7 +44,7 @@ pub trait StateStore { &self, room_id: &RoomId, event_ids: Vec>, - ) -> Result, String> { + ) -> Result> { let mut chains = vec![]; for ids in event_ids { // TODO state store `auth_event_ids` returns self in the event ids list diff --git a/tests/event_auth.rs b/tests/event_auth.rs index 6ae41778..4fb91005 100644 --- a/tests/event_auth.rs +++ b/tests/event_auth.rs @@ -18,7 +18,7 @@ use state_res::{ // auth_check, auth_types_for_event, can_federate, check_power_levels, check_redaction, valid_membership_change, }, - Requester, StateEvent, StateMap, StateStore, + Requester, StateEvent, StateMap, StateStore, Result, Error }; use tracing_subscriber as tracer; @@ -75,12 +75,12 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) .cloned() - .ok_or(format!("{} not found", event_id.to_string())) + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } diff --git a/tests/event_sorting.rs b/tests/event_sorting.rs index 4cd0317d..c9d9248a 100644 --- a/tests/event_sorting.rs +++ b/tests/event_sorting.rs @@ -12,7 +12,7 @@ use ruma::{ identifiers::{EventId, RoomId, RoomVersionId, UserId}, }; use serde_json::{json, Value as JsonValue}; -use state_res::{StateEvent, StateMap, StateStore}; +use state_res::{Error, Result, StateEvent, StateMap, StateStore}; use tracing_subscriber as tracer; use std::sync::Once; @@ -57,12 +57,12 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) .cloned() - .ok_or(format!("{} not found", event_id.to_string())) + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } diff --git a/tests/res_with_auth_ids.rs b/tests/res_with_auth_ids.rs index b4dbab52..7b2d9d35 100644 --- a/tests/res_with_auth_ids.rs +++ b/tests/res_with_auth_ids.rs @@ -14,7 +14,9 @@ use ruma::{ identifiers::{EventId, RoomId, RoomVersionId, UserId}, }; use serde_json::{json, Value as JsonValue}; -use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore}; +use state_res::{ + Error, ResolutionResult, Result, StateEvent, StateMap, StateResolution, StateStore, +}; use tracing_subscriber as tracer; static LOGGER: Once = Once::new(); @@ -200,12 +202,12 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) .cloned() - .ok_or(format!("{} not found", event_id.to_string())) + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } } diff --git a/tests/state_res.rs b/tests/state_res.rs index 40a89e4b..aaed38ea 100644 --- a/tests/state_res.rs +++ b/tests/state_res.rs @@ -1,9 +1,4 @@ -use std::{ - cell::RefCell, - collections::{BTreeMap, BTreeSet}, - convert::TryFrom, - time::UNIX_EPOCH, -}; +use std::{cell::RefCell, collections::BTreeMap, convert::TryFrom, time::UNIX_EPOCH}; use maplit::btreemap; use ruma::{ @@ -18,7 +13,9 @@ use ruma::{ identifiers::{EventId, RoomId, RoomVersionId, UserId}, }; use serde_json::{json, Value as JsonValue}; -use state_res::{ResolutionResult, StateEvent, StateMap, StateResolution, StateStore}; +use state_res::{ + Error, ResolutionResult, Result, StateEvent, StateMap, StateResolution, StateStore, +}; use tracing_subscriber as tracer; use std::sync::Once; @@ -768,83 +765,12 @@ pub struct TestStore(RefCell>); #[allow(unused)] impl StateStore for TestStore { - fn get_events(&self, room_id: &RoomId, events: &[EventId]) -> Result, String> { - Ok(self - .0 - .borrow() - .iter() - .filter(|e| events.contains(e.0)) - .map(|(_, s)| s) - .cloned() - .collect()) - } - - fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.0 .borrow() .get(event_id) .cloned() - .ok_or(format!("{} not found", event_id.to_string())) - } - - fn auth_event_ids( - &self, - room_id: &RoomId, - event_ids: &[EventId], - ) -> Result, String> { - let mut result = vec![]; - let mut stack = event_ids.to_vec(); - - // DFS for auth event chain - while !stack.is_empty() { - let ev_id = stack.pop().unwrap(); - if result.contains(&ev_id) { - continue; - } - - result.push(ev_id.clone()); - - let event = self.get_event(room_id, &ev_id).unwrap(); - - stack.extend(event.auth_events()); - } - - Ok(result) - } - - fn auth_chain_diff( - &self, - room_id: &RoomId, - event_ids: Vec>, - ) -> Result, String> { - use itertools::Itertools; - - let mut chains = vec![]; - for ids in event_ids { - // TODO state store `auth_event_ids` returns self in the event ids list - // when an event returns `auth_event_ids` self is not contained - let chain = self - .auth_event_ids(room_id, &ids)? - .into_iter() - .collect::>(); - chains.push(chain); - } - - if let Some(chain) = chains.first() { - let rest = chains.iter().skip(1).flatten().cloned().collect(); - let common = chain.intersection(&rest).collect::>(); - - Ok(chains - .iter() - .flatten() - .filter(|id| !common.contains(&id)) - .cloned() - .collect::>() - .into_iter() - .collect()) - } else { - Ok(vec![]) - } + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id.to_string()))) } }