From d809e6e3659940f25446b9c7808685fa59d6e9aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= <76261501+zecakeh@users.noreply.github.com> Date: Wed, 23 Aug 2023 17:55:56 +0200 Subject: [PATCH] events: Don't filter out any responses if there is no end timestamp --- crates/ruma-common/src/events/poll.rs | 82 ++++++++----------- .../ruma-common/src/events/poll/response.rs | 2 +- crates/ruma-common/src/events/poll/start.rs | 8 +- .../src/events/poll/unstable_response.rs | 2 +- .../src/events/poll/unstable_start.rs | 8 +- crates/ruma-common/tests/events/poll.rs | 15 +++- 6 files changed, 61 insertions(+), 56 deletions(-) diff --git a/crates/ruma-common/src/events/poll.rs b/crates/ruma-common/src/events/poll.rs index 1e74e708..5d5407e8 100644 --- a/crates/ruma-common/src/events/poll.rs +++ b/crates/ruma-common/src/events/poll.rs @@ -52,29 +52,9 @@ pub fn compile_poll_results<'a>( responses: impl IntoIterator>, end_timestamp: Option, ) -> IndexMap<&'a str, BTreeSet<&'a UserId>> { - let end_ts = end_timestamp.unwrap_or_else(MilliSecondsSinceUnixEpoch::now); - - let users_selections = responses - .into_iter() - .filter(|ev| { - // Filter out responses after the end_timestamp. - ev.origin_server_ts <= end_ts - }) - .fold(BTreeMap::new(), |mut acc, data| { - let response = - acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None)); - - // Only keep the latest selections for each user. - if response.0 < data.origin_server_ts { - let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect(); - *response = ( - data.origin_server_ts, - validate_selections(answer_ids, poll.max_selections, data.selections), - ); - } - - acc - }); + let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect(); + let users_selections = + filter_selections(answer_ids, poll.max_selections, responses, end_timestamp); aggregate_results(poll.answers.iter().map(|a| a.id.as_str()), users_selections) } @@ -95,36 +75,16 @@ pub fn compile_unstable_poll_results<'a>( responses: impl IntoIterator>, end_timestamp: Option, ) -> IndexMap<&'a str, BTreeSet<&'a UserId>> { - let end_ts = end_timestamp.unwrap_or_else(MilliSecondsSinceUnixEpoch::now); - - let users_selections = responses - .into_iter() - .filter(|ev| { - // Filter out responses after the end_timestamp. - ev.origin_server_ts <= end_ts - }) - .fold(BTreeMap::new(), |mut acc, data| { - let response = - acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None)); - - // Only keep the latest selections for each user. - if response.0 < data.origin_server_ts { - let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect(); - *response = ( - data.origin_server_ts, - validate_selections(answer_ids, poll.max_selections, data.selections), - ); - } - - acc - }); + let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect(); + let users_selections = + filter_selections(answer_ids, poll.max_selections, responses, end_timestamp); aggregate_results(poll.answers.iter().map(|a| a.id.as_str()), users_selections) } /// Validate the selections of a response. fn validate_selections<'a>( - answer_ids: BTreeSet<&str>, + answer_ids: &BTreeSet<&str>, max_selections: UInt, selections: &'a [String], ) -> Option> { @@ -140,6 +100,34 @@ fn validate_selections<'a>( Some(selections.iter().take(max_selections).map(Deref::deref)) } +fn filter_selections<'a>( + answer_ids: BTreeSet<&str>, + max_selections: UInt, + responses: impl IntoIterator>, + end_timestamp: Option, +) -> BTreeMap<&'a UserId, (MilliSecondsSinceUnixEpoch, Option>)> { + responses + .into_iter() + .filter(|ev| { + // Filter out responses after the end_timestamp. + end_timestamp.map_or(true, |end_ts| ev.origin_server_ts <= end_ts) + }) + .fold(BTreeMap::new(), |mut acc, data| { + let response = + acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None)); + + // Only keep the latest selections for each user. + if response.0 < data.origin_server_ts { + *response = ( + data.origin_server_ts, + validate_selections(&answer_ids, max_selections, data.selections), + ); + } + + acc + }) +} + /// Aggregate the given selections by answer. fn aggregate_results<'a>( answers: impl Iterator, diff --git a/crates/ruma-common/src/events/poll/response.rs b/crates/ruma-common/src/events/poll/response.rs index d9649ce3..4204cf9c 100644 --- a/crates/ruma-common/src/events/poll/response.rs +++ b/crates/ruma-common/src/events/poll/response.rs @@ -94,7 +94,7 @@ impl SelectionsContentBlock { poll: &PollContentBlock, ) -> Option> { let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect(); - validate_selections(answer_ids, poll.max_selections, &self.0) + validate_selections(&answer_ids, poll.max_selections, &self.0) } } diff --git a/crates/ruma-common/src/events/poll/start.rs b/crates/ruma-common/src/events/poll/start.rs index 273d6c02..450edde5 100644 --- a/crates/ruma-common/src/events/poll/start.rs +++ b/crates/ruma-common/src/events/poll/start.rs @@ -18,7 +18,7 @@ use super::{ use crate::{ events::{message::TextContentBlock, room::message::Relation}, serde::StringEnum, - PrivOwnedStr, + MilliSecondsSinceUnixEpoch, PrivOwnedStr, }; /// The payload for a poll start event. @@ -90,7 +90,11 @@ impl OriginalSyncPollStartEvent { &'a self, responses: impl IntoIterator>, ) -> PollEndEventContent { - let full_results = compile_poll_results(&self.content.poll, responses, None); + let full_results = compile_poll_results( + &self.content.poll, + responses, + Some(MilliSecondsSinceUnixEpoch::now()), + ); let results = full_results.into_iter().map(|(id, users)| (id, users.len())).collect::>(); diff --git a/crates/ruma-common/src/events/poll/unstable_response.rs b/crates/ruma-common/src/events/poll/unstable_response.rs index d2d6df0a..e8ac5eaf 100644 --- a/crates/ruma-common/src/events/poll/unstable_response.rs +++ b/crates/ruma-common/src/events/poll/unstable_response.rs @@ -86,7 +86,7 @@ impl UnstablePollResponseContentBlock { poll: &UnstablePollStartContentBlock, ) -> Option> { let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect(); - validate_selections(answer_ids, poll.max_selections, &self.answers) + validate_selections(&answer_ids, poll.max_selections, &self.answers) } } diff --git a/crates/ruma-common/src/events/poll/unstable_start.rs b/crates/ruma-common/src/events/poll/unstable_start.rs index 53e3a538..f48ebf7f 100644 --- a/crates/ruma-common/src/events/poll/unstable_start.rs +++ b/crates/ruma-common/src/events/poll/unstable_start.rs @@ -16,7 +16,7 @@ use super::{ unstable_end::UnstablePollEndEventContent, PollResponseData, }; -use crate::events::room::message::Relation; +use crate::{events::room::message::Relation, MilliSecondsSinceUnixEpoch}; /// The payload for an unstable poll start event. /// @@ -73,7 +73,11 @@ impl OriginalSyncUnstablePollStartEvent { &'a self, responses: impl IntoIterator>, ) -> UnstablePollEndEventContent { - let full_results = compile_unstable_poll_results(&self.content.poll_start, responses, None); + let full_results = compile_unstable_poll_results( + &self.content.poll_start, + responses, + Some(MilliSecondsSinceUnixEpoch::now()), + ); let results = full_results.into_iter().map(|(id, users)| (id, users.len())).collect::>(); diff --git a/crates/ruma-common/tests/events/poll.rs b/crates/ruma-common/tests/events/poll.rs index f0728b34..92eb3161 100644 --- a/crates/ruma-common/tests/events/poll.rs +++ b/crates/ruma-common/tests/events/poll.rs @@ -886,12 +886,13 @@ fn compute_results() { assert_eq!(*results.get("italian").unwrap(), uint!(6)); assert_eq!(*results.get("wings").unwrap(), uint!(7)); - // Response in the future is ignored. - let future_ts = MilliSecondsSinceUnixEpoch::now().0 + uint!(100_000); + // Response later than end_timestamp is ignored. + let now = MilliSecondsSinceUnixEpoch::now(); + let future_ts = now.0 + uint!(100_000); responses.push(new_poll_response("$future_event", changing_user_3, future_ts, &["pizza"])); let counted = - compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None); + compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), Some(now)); assert_eq!(counted.get("pizza").unwrap().len(), 6); assert_eq!(counted.get("poutine").unwrap().len(), 8); assert_eq!(counted.get("italian").unwrap().len(), 6); @@ -903,6 +904,14 @@ fn compute_results() { assert_eq!(*results.get("poutine").unwrap(), uint!(8)); assert_eq!(*results.get("italian").unwrap(), uint!(6)); assert_eq!(*results.get("wings").unwrap(), uint!(7)); + + // Response in the future is not ignored if there is no end_timestamp. + let counted = + compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None); + assert_eq!(counted.get("pizza").unwrap().len(), 7); + assert_eq!(counted.get("poutine").unwrap().len(), 8); + assert_eq!(counted.get("italian").unwrap().len(), 6); + assert_eq!(counted.get("wings").unwrap().len(), 6); } fn new_unstable_poll_response(