From d22d83522bdb36abdf2969b52c271a8bbd96c237 Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Fri, 14 Aug 2020 07:39:30 -0400 Subject: [PATCH] Make auth_types_for_event take the ruma types instead of StateEvent --- src/event_auth.rs | 17 +++++++++++------ src/lib.rs | 7 ++++++- tests/res_with_auth_ids.rs | 7 ++++++- tests/state_res.rs | 7 ++++++- 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/event_auth.rs b/src/event_auth.rs index ace434de..8143fb4d 100644 --- a/src/event_auth.rs +++ b/src/event_auth.rs @@ -26,19 +26,24 @@ pub enum RedactAllowed { No, } -pub fn auth_types_for_event(event: &StateEvent) -> Vec<(EventType, Option)> { - if event.kind() == EventType::RoomCreate { +pub fn auth_types_for_event( + kind: EventType, + sender: &UserId, + state_key: Option, + content: serde_json::Value, +) -> Vec<(EventType, Option)> { + if kind == EventType::RoomCreate { return vec![]; } let mut auth_types = vec![ (EventType::RoomPowerLevels, Some("".to_string())), - (EventType::RoomMember, Some(event.sender().to_string())), + (EventType::RoomMember, Some(sender.to_string())), (EventType::RoomCreate, Some("".to_string())), ]; - if event.kind() == EventType::RoomMember { - if let Ok(content) = event.deserialize_content::() { + if kind == EventType::RoomMember { + if let Ok(content) = serde_json::from_value::(content) { if [MembershipState::Join, MembershipState::Invite].contains(&content.membership) { let key = (EventType::RoomJoinRules, Some("".into())); if !auth_types.contains(&key) { @@ -47,7 +52,7 @@ pub fn auth_types_for_event(event: &StateEvent) -> Vec<(EventType, Option>, expected_state_ids: state_after.insert((ty, key), event_id.clone()); } - let auth_types = state_res::auth_types_for_event(fake_event); + let auth_types = state_res::auth_types_for_event( + fake_event.kind(), + fake_event.sender(), + fake_event.state_key(), + fake_event.content().clone(), + ); let mut auth_events = vec![]; for key in auth_types { diff --git a/tests/state_res.rs b/tests/state_res.rs index 399133d6..006792dd 100644 --- a/tests/state_res.rs +++ b/tests/state_res.rs @@ -399,7 +399,12 @@ fn do_check(events: &[StateEvent], edges: Vec>, expected_state_ids: state_after.insert((ty, key), event_id.clone()); } - let auth_types = state_res::auth_types_for_event(fake_event); + let auth_types = state_res::auth_types_for_event( + fake_event.kind(), + fake_event.sender(), + fake_event.state_key(), + fake_event.content().clone(), + ); let mut auth_events = vec![]; for key in auth_types {