diff --git a/crates/ruma-state-res/src/event_auth.rs b/crates/ruma-state-res/src/event_auth.rs index 60d45709..2c7005f2 100644 --- a/crates/ruma-state-res/src/event_auth.rs +++ b/crates/ruma-state-res/src/event_auth.rs @@ -1,6 +1,9 @@ use std::{borrow::Borrow, collections::BTreeSet}; -use futures_util::Future; +use futures_util::{ + future::{join3, OptionFuture}, + Future, +}; use js_int::{int, Int}; use ruma_common::{ serde::{Base64, Raw}, @@ -224,7 +227,14 @@ where } */ - let room_create_event = match fetch_state(&StateEventType::RoomCreate, "").await { + let (room_create_event, power_levels_event, sender_member_event) = join3( + fetch_state(&StateEventType::RoomCreate, ""), + fetch_state(&StateEventType::RoomPowerLevels, ""), + fetch_state(&StateEventType::RoomMember, sender.as_str()), + ) + .await; + + let room_create_event = match room_create_event { None => { warn!("no m.room.create event in auth chain"); return Ok(false); @@ -273,9 +283,6 @@ where } // If type is m.room.member - let power_levels_event = fetch_state(&StateEventType::RoomPowerLevels, "").await; - let sender_member_event = fetch_state(&StateEventType::RoomMember, sender.as_str()).await; - if *incoming_event.event_type() == TimelineEventType::RoomMember { debug!("starting m.room.member check"); let state_key = match incoming_event.state_key() { @@ -298,27 +305,34 @@ where let user_for_join_auth = content.join_authorised_via_users_server.as_ref().and_then(|u| u.deserialize().ok()); - let user_for_join_auth_event = if let Some(auth_user) = user_for_join_auth.as_ref() { - fetch_state(&StateEventType::RoomMember, auth_user.as_str()).await - } else { - None - }; + let user_for_join_auth_event: OptionFuture<_> = user_for_join_auth + .as_ref() + .map(|auth_user| fetch_state(&StateEventType::RoomMember, auth_user.as_str())) + .into(); + + let target_user_member_event = + fetch_state(&StateEventType::RoomMember, target_user.as_str()); + + let join_rules_event = fetch_state(&StateEventType::RoomJoinRules, ""); + + let (join_rules_event, target_user_member_event, user_for_join_auth_event) = + join3(join_rules_event, target_user_member_event, user_for_join_auth_event).await; let user_for_join_auth_membership = user_for_join_auth_event - .and_then(|mem| from_json_str::(mem.content().get()).ok()) + .and_then(|mem| from_json_str::(mem?.content().get()).ok()) .map(|mem| mem.membership) .unwrap_or(MembershipState::Leave); if !valid_membership_change( room_version, target_user, - fetch_state(&StateEventType::RoomMember, target_user.as_str()).await.as_ref(), + target_user_member_event.as_ref(), sender, sender_member_event.as_ref(), &incoming_event, current_third_party_invite, power_levels_event.as_ref(), - fetch_state(&StateEventType::RoomJoinRules, "").await.as_ref(), + join_rules_event.as_ref(), user_for_join_auth.as_deref(), &user_for_join_auth_membership, room_create_event,