diff --git a/crates/ruma-common/src/events/poll.rs b/crates/ruma-common/src/events/poll.rs index af388019..1e74e708 100644 --- a/crates/ruma-common/src/events/poll.rs +++ b/crates/ruma-common/src/events/poll.rs @@ -4,16 +4,15 @@ //! //! [MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381 -use std::collections::{BTreeMap, BTreeSet}; +use std::{ + collections::{BTreeMap, BTreeSet}, + ops::Deref, +}; use indexmap::IndexMap; -use js_int::uint; +use js_int::{uint, UInt}; -use self::{ - response::OriginalSyncPollResponseEvent, start::PollContentBlock, - unstable_response::OriginalSyncUnstablePollResponseEvent, - unstable_start::UnstablePollStartContentBlock, -}; +use self::{start::PollContentBlock, unstable_start::UnstablePollStartContentBlock}; use crate::{MilliSecondsSinceUnixEpoch, UserId}; pub mod end; @@ -23,6 +22,20 @@ pub mod unstable_end; pub mod unstable_response; pub mod unstable_start; +/// The data from a poll response necessary to compile poll results. +#[derive(Debug, Clone, Copy)] +#[allow(clippy::exhaustive_structs)] +pub struct PollResponseData<'a> { + /// The sender of the response. + pub sender: &'a UserId, + + /// The time of creation of the response on the originating server. + pub origin_server_ts: MilliSecondsSinceUnixEpoch, + + /// The selections/answers of the response. + pub selections: &'a [String], +} + /// Generate the current results with the given poll and responses. /// /// If the `end_timestamp` is provided, any response with an `origin_server_ts` after that timestamp @@ -36,7 +49,7 @@ pub mod unstable_start; /// lowest. pub fn compile_poll_results<'a>( poll: &'a PollContentBlock, - responses: impl IntoIterator, + responses: impl IntoIterator>, end_timestamp: Option, ) -> IndexMap<&'a str, BTreeSet<&'a UserId>> { let end_ts = end_timestamp.unwrap_or_else(MilliSecondsSinceUnixEpoch::now); @@ -47,13 +60,17 @@ pub fn compile_poll_results<'a>( // Filter out responses after the end_timestamp. ev.origin_server_ts <= end_ts }) - .fold(BTreeMap::new(), |mut acc, ev| { + .fold(BTreeMap::new(), |mut acc, data| { let response = - acc.entry(&*ev.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None)); + acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None)); // Only keep the latest selections for each user. - if response.0 < ev.origin_server_ts { - *response = (ev.origin_server_ts, ev.content.selections.validate(poll)); + if response.0 < data.origin_server_ts { + let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect(); + *response = ( + data.origin_server_ts, + validate_selections(answer_ids, poll.max_selections, data.selections), + ); } acc @@ -75,7 +92,7 @@ pub fn compile_poll_results<'a>( /// lowest. pub fn compile_unstable_poll_results<'a>( poll: &'a UnstablePollStartContentBlock, - responses: impl IntoIterator, + responses: impl IntoIterator>, end_timestamp: Option, ) -> IndexMap<&'a str, BTreeSet<&'a UserId>> { let end_ts = end_timestamp.unwrap_or_else(MilliSecondsSinceUnixEpoch::now); @@ -86,13 +103,17 @@ pub fn compile_unstable_poll_results<'a>( // Filter out responses after the end_timestamp. ev.origin_server_ts <= end_ts }) - .fold(BTreeMap::new(), |mut acc, ev| { + .fold(BTreeMap::new(), |mut acc, data| { let response = - acc.entry(&*ev.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None)); + acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None)); // Only keep the latest selections for each user. - if response.0 < ev.origin_server_ts { - *response = (ev.origin_server_ts, ev.content.poll_response.validate(poll)); + if response.0 < data.origin_server_ts { + let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect(); + *response = ( + data.origin_server_ts, + validate_selections(answer_ids, poll.max_selections, data.selections), + ); } acc @@ -101,7 +122,25 @@ pub fn compile_unstable_poll_results<'a>( aggregate_results(poll.answers.iter().map(|a| a.id.as_str()), users_selections) } -// Aggregate the given selections by answer. +/// Validate the selections of a response. +fn validate_selections<'a>( + answer_ids: BTreeSet<&str>, + max_selections: UInt, + selections: &'a [String], +) -> Option> { + // Vote is spoiled if any answer is unknown. + if selections.iter().any(|s| !answer_ids.contains(s.as_str())) { + return None; + } + + // Fallback to the maximum value for usize because we can't have more selections than that + // in memory. + let max_selections: usize = max_selections.try_into().unwrap_or(usize::MAX); + + Some(selections.iter().take(max_selections).map(Deref::deref)) +} + +/// Aggregate the given selections by answer. fn aggregate_results<'a>( answers: impl Iterator, users_selections: BTreeMap< @@ -164,13 +203,13 @@ fn generate_poll_end_fallback_text<'a>( // Construct the plain text representation. match top_answers_text.len() { - l if l > 1 => { + 0 => "The poll has closed with no top answer".to_owned(), + 1 => { + format!("The poll has closed. Top answer: {}", top_answers_text[0]) + } + _ => { let answers = top_answers_text.join(", "); format!("The poll has closed. Top answers: {answers}") } - l if l == 1 => { - format!("The poll has closed. Top answer: {}", top_answers_text[0]) - } - _ => "The poll has closed with no top answer".to_owned(), } } diff --git a/crates/ruma-common/src/events/poll/response.rs b/crates/ruma-common/src/events/poll/response.rs index e944e49c..d9649ce3 100644 --- a/crates/ruma-common/src/events/poll/response.rs +++ b/crates/ruma-common/src/events/poll/response.rs @@ -5,7 +5,7 @@ use std::{ops::Deref, vec}; use ruma_macros::EventContent; use serde::{Deserialize, Serialize}; -use super::start::PollContentBlock; +use super::{start::PollContentBlock, validate_selections, PollResponseData}; use crate::{events::relation::Reference, OwnedEventId}; /// The payload for a poll response event. @@ -52,6 +52,28 @@ impl PollResponseEventContent { } } +impl OriginalSyncPollResponseEvent { + /// Get the data from this response necessary to compile poll results. + pub fn data(&self) -> PollResponseData<'_> { + PollResponseData { + sender: &self.sender, + origin_server_ts: self.origin_server_ts, + selections: &self.content.selections, + } + } +} + +impl OriginalPollResponseEvent { + /// Get the data from this response necessary to compile poll results. + pub fn data(&self) -> PollResponseData<'_> { + PollResponseData { + sender: &self.sender, + origin_server_ts: self.origin_server_ts, + selections: &self.content.selections, + } + } +} + /// A block for selections content. #[derive(Clone, Debug, Serialize, Deserialize)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] @@ -67,17 +89,12 @@ impl SelectionsContentBlock { /// /// Returns the list of valid selections in this `SelectionsContentBlock`, or `None` if there is /// no valid selection. - pub fn validate(&self, poll: &PollContentBlock) -> Option> { - // Vote is spoiled if any answer is unknown. - if self.0.iter().any(|s| !poll.answers.iter().any(|a| a.id == *s)) { - return None; - } - - // Fallback to the maximum value for usize because we can't have more selections than that - // in memory. - let max_selections: usize = poll.max_selections.try_into().unwrap_or(usize::MAX); - - Some(self.0.iter().take(max_selections).map(Deref::deref)) + pub fn validate<'a>( + &'a self, + poll: &PollContentBlock, + ) -> Option> { + let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect(); + validate_selections(answer_ids, poll.max_selections, &self.0) } } diff --git a/crates/ruma-common/src/events/poll/start.rs b/crates/ruma-common/src/events/poll/start.rs index de052d77..273d6c02 100644 --- a/crates/ruma-common/src/events/poll/start.rs +++ b/crates/ruma-common/src/events/poll/start.rs @@ -13,8 +13,7 @@ use poll_answers_serde::PollAnswersDeHelper; use super::{ compile_poll_results, end::{PollEndEventContent, PollResultsContentBlock}, - generate_poll_end_fallback_text, - response::OriginalSyncPollResponseEvent, + generate_poll_end_fallback_text, PollResponseData, }; use crate::{ events::{message::TextContentBlock, room::message::Relation}, @@ -89,7 +88,7 @@ impl OriginalSyncPollStartEvent { /// This uses [`compile_poll_results()`] internally. pub fn compile_results<'a>( &'a self, - responses: impl IntoIterator, + responses: impl IntoIterator>, ) -> PollEndEventContent { let full_results = compile_poll_results(&self.content.poll, responses, None); let results = diff --git a/crates/ruma-common/src/events/poll/unstable_response.rs b/crates/ruma-common/src/events/poll/unstable_response.rs index 7a2f7580..d2d6df0a 100644 --- a/crates/ruma-common/src/events/poll/unstable_response.rs +++ b/crates/ruma-common/src/events/poll/unstable_response.rs @@ -1,12 +1,10 @@ //! Types for the `org.matrix.msc3381.poll.response` event, the unstable version of //! `m.poll.response`. -use std::ops::Deref; - use ruma_macros::EventContent; use serde::{Deserialize, Serialize}; -use super::unstable_start::UnstablePollStartContentBlock; +use super::{unstable_start::UnstablePollStartContentBlock, validate_selections, PollResponseData}; use crate::{events::relation::Reference, OwnedEventId}; /// The payload for an unstable poll response event. @@ -43,6 +41,28 @@ impl UnstablePollResponseEventContent { } } +impl OriginalSyncUnstablePollResponseEvent { + /// Get the data from this response necessary to compile poll results. + pub fn data(&self) -> PollResponseData<'_> { + PollResponseData { + sender: &self.sender, + origin_server_ts: self.origin_server_ts, + selections: &self.content.poll_response.answers, + } + } +} + +impl OriginalUnstablePollResponseEvent { + /// Get the data from this response necessary to compile poll results. + pub fn data(&self) -> PollResponseData<'_> { + PollResponseData { + sender: &self.sender, + origin_server_ts: self.origin_server_ts, + selections: &self.content.poll_response.answers, + } + } +} + /// An unstable block for poll response content. #[derive(Clone, Debug, Serialize, Deserialize)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] @@ -61,20 +81,12 @@ impl UnstablePollResponseContentBlock { /// /// Returns the list of valid selections in this `UnstablePollResponseContentBlock`, or `None` /// if there is no valid selection. - pub fn validate( - &self, + pub fn validate<'a>( + &'a self, poll: &UnstablePollStartContentBlock, - ) -> Option> { - // Vote is spoiled if any answer is unknown. - if self.answers.iter().any(|s| !poll.answers.iter().any(|a| a.id == *s)) { - return None; - } - - // Fallback to the maximum value for usize because we can't have more selections than that - // in memory. - let max_selections: usize = poll.max_selections.try_into().unwrap_or(usize::MAX); - - Some(self.answers.iter().take(max_selections).map(Deref::deref)) + ) -> Option> { + let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect(); + validate_selections(answer_ids, poll.max_selections, &self.answers) } } diff --git a/crates/ruma-common/src/events/poll/unstable_start.rs b/crates/ruma-common/src/events/poll/unstable_start.rs index e53d2359..53e3a538 100644 --- a/crates/ruma-common/src/events/poll/unstable_start.rs +++ b/crates/ruma-common/src/events/poll/unstable_start.rs @@ -14,7 +14,7 @@ use super::{ compile_unstable_poll_results, generate_poll_end_fallback_text, start::{PollAnswers, PollAnswersError, PollContentBlock, PollKind}, unstable_end::UnstablePollEndEventContent, - unstable_response::OriginalSyncUnstablePollResponseEvent, + PollResponseData, }; use crate::events::room::message::Relation; @@ -71,7 +71,7 @@ impl OriginalSyncUnstablePollStartEvent { /// This uses [`compile_unstable_poll_results()`] internally. pub fn compile_results<'a>( &'a self, - responses: impl IntoIterator, + responses: impl IntoIterator>, ) -> UnstablePollEndEventContent { let full_results = compile_unstable_poll_results(&self.content.poll_start, responses, None); let results = diff --git a/crates/ruma-common/tests/events/poll.rs b/crates/ruma-common/tests/events/poll.rs index ef4ff7e5..f0728b34 100644 --- a/crates/ruma-common/tests/events/poll.rs +++ b/crates/ruma-common/tests/events/poll.rs @@ -671,7 +671,8 @@ fn compute_results() { responses.extend(generate_poll_responses(10..15, &["italian"])); responses.extend(generate_poll_responses(15..20, &["wings"])); - let counted = compile_poll_results(&poll.content.poll, &responses, None); + let counted = + compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None); assert_eq!(counted.get("pizza").unwrap().len(), 5); assert_eq!(counted.get("poutine").unwrap().len(), 5); assert_eq!(counted.get("italian").unwrap().len(), 5); @@ -683,7 +684,7 @@ fn compute_results() { assert_eq!(iter.next(), Some(&"wings")); assert_eq!(iter.next(), None); - let poll_end = poll.compile_results(&responses); + let poll_end = poll.compile_results(responses.iter().map(|r| r.data())); let results = poll_end.poll_results.unwrap(); assert_eq!(*results.get("pizza").unwrap(), uint!(5)); assert_eq!(*results.get("poutine").unwrap(), uint!(5)); @@ -713,7 +714,8 @@ fn compute_results() { ), ]); - let counted = compile_poll_results(&poll.content.poll, &responses, None); + let counted = + compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None); assert_eq!(counted.get("pizza").unwrap().len(), 5); assert_eq!(counted.get("poutine").unwrap().len(), 7); assert_eq!(counted.get("italian").unwrap().len(), 6); @@ -725,7 +727,7 @@ fn compute_results() { assert_eq!(iter.next(), Some(&"pizza")); assert_eq!(iter.next(), None); - let poll_end = poll.compile_results(&responses); + let poll_end = poll.compile_results(responses.iter().map(|r| r.data())); let results = poll_end.poll_results.unwrap(); assert_eq!(*results.get("pizza").unwrap(), uint!(5)); assert_eq!(*results.get("poutine").unwrap(), uint!(7)); @@ -752,13 +754,14 @@ fn compute_results() { ), ]); - let counted = compile_poll_results(&poll.content.poll, &responses, None); + let counted = + compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None); assert_eq!(counted.get("pizza").unwrap().len(), 6); assert_eq!(counted.get("poutine").unwrap().len(), 8); assert_eq!(counted.get("italian").unwrap().len(), 6); assert_eq!(counted.get("wings").unwrap().len(), 6); - let poll_end = poll.compile_results(&responses); + let poll_end = poll.compile_results(responses.iter().map(|r| r.data())); let results = poll_end.poll_results.unwrap(); assert_eq!(*results.get("pizza").unwrap(), uint!(6)); assert_eq!(*results.get("poutine").unwrap(), uint!(8)); @@ -775,7 +778,8 @@ fn compute_results() { new_poll_response("$valid_for_now_event_3", changing_user_3, uint!(4200), &["wings"]), ]); - let counted = compile_poll_results(&poll.content.poll, &responses, None); + let counted = + compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None); assert_eq!(counted.get("pizza").unwrap().len(), 6); assert_eq!(counted.get("poutine").unwrap().len(), 8); assert_eq!(counted.get("italian").unwrap().len(), 6); @@ -787,7 +791,7 @@ fn compute_results() { assert_eq!(iter.next(), Some(&"italian")); assert_eq!(iter.next(), None); - let poll_end = poll.compile_results(&responses); + let poll_end = poll.compile_results(responses.iter().map(|r| r.data())); let results = poll_end.poll_results.unwrap(); assert_eq!(*results.get("pizza").unwrap(), uint!(6)); assert_eq!(*results.get("poutine").unwrap(), uint!(8)); @@ -807,13 +811,14 @@ fn compute_results() { &["italian"], )); - let counted = compile_poll_results(&poll.content.poll, &responses, None); + let counted = + compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None); assert_eq!(counted.get("pizza").unwrap().len(), 6); assert_eq!(counted.get("poutine").unwrap().len(), 8); assert_eq!(counted.get("italian").unwrap().len(), 7); assert_eq!(counted.get("wings").unwrap().len(), 8); - let poll_end = poll.compile_results(&responses); + let poll_end = poll.compile_results(responses.iter().map(|r| r.data())); let results = poll_end.poll_results.unwrap(); assert_eq!(*results.get("pizza").unwrap(), uint!(6)); assert_eq!(*results.get("poutine").unwrap(), uint!(8)); @@ -828,13 +833,14 @@ fn compute_results() { &[], )); - let counted = compile_poll_results(&poll.content.poll, &responses, None); + let counted = + compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None); assert_eq!(counted.get("pizza").unwrap().len(), 6); assert_eq!(counted.get("poutine").unwrap().len(), 8); assert_eq!(counted.get("italian").unwrap().len(), 6); assert_eq!(counted.get("wings").unwrap().len(), 8); - let poll_end = poll.compile_results(&responses); + let poll_end = poll.compile_results(responses.iter().map(|r| r.data())); let results = poll_end.poll_results.unwrap(); assert_eq!(*results.get("pizza").unwrap(), uint!(6)); assert_eq!(*results.get("poutine").unwrap(), uint!(8)); @@ -849,13 +855,14 @@ fn compute_results() { &["indian"], )); - let counted = compile_poll_results(&poll.content.poll, &responses, None); + let counted = + compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None); assert_eq!(counted.get("pizza").unwrap().len(), 6); assert_eq!(counted.get("poutine").unwrap().len(), 8); assert_eq!(counted.get("italian").unwrap().len(), 6); assert_eq!(counted.get("wings").unwrap().len(), 7); - let poll_end = poll.compile_results(&responses); + let poll_end = poll.compile_results(responses.iter().map(|r| r.data())); let results = poll_end.poll_results.unwrap(); assert_eq!(*results.get("pizza").unwrap(), uint!(6)); assert_eq!(*results.get("poutine").unwrap(), uint!(8)); @@ -865,13 +872,14 @@ fn compute_results() { // Response older than most recent one is ignored. responses.push(new_poll_response("$past_event", changing_user_3, uint!(1), &["pizza"])); - let counted = compile_poll_results(&poll.content.poll, &responses, None); + let counted = + compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None); assert_eq!(counted.get("pizza").unwrap().len(), 6); assert_eq!(counted.get("poutine").unwrap().len(), 8); assert_eq!(counted.get("italian").unwrap().len(), 6); assert_eq!(counted.get("wings").unwrap().len(), 7); - let poll_end = poll.compile_results(&responses); + let poll_end = poll.compile_results(responses.iter().map(|r| r.data())); let results = poll_end.poll_results.unwrap(); assert_eq!(*results.get("pizza").unwrap(), uint!(6)); assert_eq!(*results.get("poutine").unwrap(), uint!(8)); @@ -882,13 +890,14 @@ fn compute_results() { let future_ts = MilliSecondsSinceUnixEpoch::now().0 + uint!(100_000); responses.push(new_poll_response("$future_event", changing_user_3, future_ts, &["pizza"])); - let counted = compile_poll_results(&poll.content.poll, &responses, None); + let counted = + compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None); assert_eq!(counted.get("pizza").unwrap().len(), 6); assert_eq!(counted.get("poutine").unwrap().len(), 8); assert_eq!(counted.get("italian").unwrap().len(), 6); assert_eq!(counted.get("wings").unwrap().len(), 7); - let poll_end = poll.compile_results(&responses); + let poll_end = poll.compile_results(responses.iter().map(|r| r.data())); let results = poll_end.poll_results.unwrap(); assert_eq!(*results.get("pizza").unwrap(), uint!(6)); assert_eq!(*results.get("poutine").unwrap(), uint!(8)); @@ -968,7 +977,11 @@ fn compute_unstable_results() { responses.extend(generate_unstable_poll_responses(6..8, &["italian"])); responses.extend(generate_unstable_poll_responses(8..11, &["wings"])); - let counted = compile_unstable_poll_results(&poll.content.poll_start, &responses, None); + let counted = compile_unstable_poll_results( + &poll.content.poll_start, + responses.iter().map(|r| r.data()), + None, + ); assert_eq!(counted.get("pizza").unwrap().len(), 5); assert_eq!(counted.get("poutine").unwrap().len(), 1); assert_eq!(counted.get("italian").unwrap().len(), 2); @@ -980,6 +993,6 @@ fn compute_unstable_results() { assert_eq!(iter.next(), Some(&"poutine")); assert_eq!(iter.next(), None); - let poll_end = poll.compile_results(&responses); + let poll_end = poll.compile_results(responses.iter().map(|r| r.data())); assert_eq!(poll_end.text, "The poll has closed. Top answer: Pizza 🍕"); }