events: Don't require whole poll response events to compute results

Co-authored-by: Jonas Platte <jplatte@matrix.org>
This commit is contained in:
Kévin Commaille 2023-08-23 09:18:37 +02:00 committed by GitHub
parent 533da2aded
commit f540004a0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 156 additions and 76 deletions

View File

@ -4,16 +4,15 @@
//!
//! [MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381
use std::collections::{BTreeMap, BTreeSet};
use std::{
collections::{BTreeMap, BTreeSet},
ops::Deref,
};
use indexmap::IndexMap;
use js_int::uint;
use js_int::{uint, UInt};
use self::{
response::OriginalSyncPollResponseEvent, start::PollContentBlock,
unstable_response::OriginalSyncUnstablePollResponseEvent,
unstable_start::UnstablePollStartContentBlock,
};
use self::{start::PollContentBlock, unstable_start::UnstablePollStartContentBlock};
use crate::{MilliSecondsSinceUnixEpoch, UserId};
pub mod end;
@ -23,6 +22,20 @@ pub mod unstable_end;
pub mod unstable_response;
pub mod unstable_start;
/// The data from a poll response necessary to compile poll results.
#[derive(Debug, Clone, Copy)]
#[allow(clippy::exhaustive_structs)]
pub struct PollResponseData<'a> {
/// The sender of the response.
pub sender: &'a UserId,
/// The time of creation of the response on the originating server.
pub origin_server_ts: MilliSecondsSinceUnixEpoch,
/// The selections/answers of the response.
pub selections: &'a [String],
}
/// Generate the current results with the given poll and responses.
///
/// If the `end_timestamp` is provided, any response with an `origin_server_ts` after that timestamp
@ -36,7 +49,7 @@ pub mod unstable_start;
/// lowest.
pub fn compile_poll_results<'a>(
poll: &'a PollContentBlock,
responses: impl IntoIterator<Item = &'a OriginalSyncPollResponseEvent>,
responses: impl IntoIterator<Item = PollResponseData<'a>>,
end_timestamp: Option<MilliSecondsSinceUnixEpoch>,
) -> IndexMap<&'a str, BTreeSet<&'a UserId>> {
let end_ts = end_timestamp.unwrap_or_else(MilliSecondsSinceUnixEpoch::now);
@ -47,13 +60,17 @@ pub fn compile_poll_results<'a>(
// Filter out responses after the end_timestamp.
ev.origin_server_ts <= end_ts
})
.fold(BTreeMap::new(), |mut acc, ev| {
.fold(BTreeMap::new(), |mut acc, data| {
let response =
acc.entry(&*ev.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None));
acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None));
// Only keep the latest selections for each user.
if response.0 < ev.origin_server_ts {
*response = (ev.origin_server_ts, ev.content.selections.validate(poll));
if response.0 < data.origin_server_ts {
let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
*response = (
data.origin_server_ts,
validate_selections(answer_ids, poll.max_selections, data.selections),
);
}
acc
@ -75,7 +92,7 @@ pub fn compile_poll_results<'a>(
/// lowest.
pub fn compile_unstable_poll_results<'a>(
poll: &'a UnstablePollStartContentBlock,
responses: impl IntoIterator<Item = &'a OriginalSyncUnstablePollResponseEvent>,
responses: impl IntoIterator<Item = PollResponseData<'a>>,
end_timestamp: Option<MilliSecondsSinceUnixEpoch>,
) -> IndexMap<&'a str, BTreeSet<&'a UserId>> {
let end_ts = end_timestamp.unwrap_or_else(MilliSecondsSinceUnixEpoch::now);
@ -86,13 +103,17 @@ pub fn compile_unstable_poll_results<'a>(
// Filter out responses after the end_timestamp.
ev.origin_server_ts <= end_ts
})
.fold(BTreeMap::new(), |mut acc, ev| {
.fold(BTreeMap::new(), |mut acc, data| {
let response =
acc.entry(&*ev.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None));
acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None));
// Only keep the latest selections for each user.
if response.0 < ev.origin_server_ts {
*response = (ev.origin_server_ts, ev.content.poll_response.validate(poll));
if response.0 < data.origin_server_ts {
let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
*response = (
data.origin_server_ts,
validate_selections(answer_ids, poll.max_selections, data.selections),
);
}
acc
@ -101,7 +122,25 @@ pub fn compile_unstable_poll_results<'a>(
aggregate_results(poll.answers.iter().map(|a| a.id.as_str()), users_selections)
}
// Aggregate the given selections by answer.
/// Validate the selections of a response.
fn validate_selections<'a>(
answer_ids: BTreeSet<&str>,
max_selections: UInt,
selections: &'a [String],
) -> Option<impl Iterator<Item = &'a str>> {
// Vote is spoiled if any answer is unknown.
if selections.iter().any(|s| !answer_ids.contains(s.as_str())) {
return None;
}
// Fallback to the maximum value for usize because we can't have more selections than that
// in memory.
let max_selections: usize = max_selections.try_into().unwrap_or(usize::MAX);
Some(selections.iter().take(max_selections).map(Deref::deref))
}
/// Aggregate the given selections by answer.
fn aggregate_results<'a>(
answers: impl Iterator<Item = &'a str>,
users_selections: BTreeMap<
@ -164,13 +203,13 @@ fn generate_poll_end_fallback_text<'a>(
// Construct the plain text representation.
match top_answers_text.len() {
l if l > 1 => {
0 => "The poll has closed with no top answer".to_owned(),
1 => {
format!("The poll has closed. Top answer: {}", top_answers_text[0])
}
_ => {
let answers = top_answers_text.join(", ");
format!("The poll has closed. Top answers: {answers}")
}
l if l == 1 => {
format!("The poll has closed. Top answer: {}", top_answers_text[0])
}
_ => "The poll has closed with no top answer".to_owned(),
}
}

View File

@ -5,7 +5,7 @@ use std::{ops::Deref, vec};
use ruma_macros::EventContent;
use serde::{Deserialize, Serialize};
use super::start::PollContentBlock;
use super::{start::PollContentBlock, validate_selections, PollResponseData};
use crate::{events::relation::Reference, OwnedEventId};
/// The payload for a poll response event.
@ -52,6 +52,28 @@ impl PollResponseEventContent {
}
}
impl OriginalSyncPollResponseEvent {
/// Get the data from this response necessary to compile poll results.
pub fn data(&self) -> PollResponseData<'_> {
PollResponseData {
sender: &self.sender,
origin_server_ts: self.origin_server_ts,
selections: &self.content.selections,
}
}
}
impl OriginalPollResponseEvent {
/// Get the data from this response necessary to compile poll results.
pub fn data(&self) -> PollResponseData<'_> {
PollResponseData {
sender: &self.sender,
origin_server_ts: self.origin_server_ts,
selections: &self.content.selections,
}
}
}
/// A block for selections content.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
@ -67,17 +89,12 @@ impl SelectionsContentBlock {
///
/// Returns the list of valid selections in this `SelectionsContentBlock`, or `None` if there is
/// no valid selection.
pub fn validate(&self, poll: &PollContentBlock) -> Option<impl Iterator<Item = &str>> {
// Vote is spoiled if any answer is unknown.
if self.0.iter().any(|s| !poll.answers.iter().any(|a| a.id == *s)) {
return None;
}
// Fallback to the maximum value for usize because we can't have more selections than that
// in memory.
let max_selections: usize = poll.max_selections.try_into().unwrap_or(usize::MAX);
Some(self.0.iter().take(max_selections).map(Deref::deref))
pub fn validate<'a>(
&'a self,
poll: &PollContentBlock,
) -> Option<impl Iterator<Item = &'a str>> {
let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
validate_selections(answer_ids, poll.max_selections, &self.0)
}
}

View File

@ -13,8 +13,7 @@ use poll_answers_serde::PollAnswersDeHelper;
use super::{
compile_poll_results,
end::{PollEndEventContent, PollResultsContentBlock},
generate_poll_end_fallback_text,
response::OriginalSyncPollResponseEvent,
generate_poll_end_fallback_text, PollResponseData,
};
use crate::{
events::{message::TextContentBlock, room::message::Relation},
@ -89,7 +88,7 @@ impl OriginalSyncPollStartEvent {
/// This uses [`compile_poll_results()`] internally.
pub fn compile_results<'a>(
&'a self,
responses: impl IntoIterator<Item = &'a OriginalSyncPollResponseEvent>,
responses: impl IntoIterator<Item = PollResponseData<'a>>,
) -> PollEndEventContent {
let full_results = compile_poll_results(&self.content.poll, responses, None);
let results =

View File

@ -1,12 +1,10 @@
//! Types for the `org.matrix.msc3381.poll.response` event, the unstable version of
//! `m.poll.response`.
use std::ops::Deref;
use ruma_macros::EventContent;
use serde::{Deserialize, Serialize};
use super::unstable_start::UnstablePollStartContentBlock;
use super::{unstable_start::UnstablePollStartContentBlock, validate_selections, PollResponseData};
use crate::{events::relation::Reference, OwnedEventId};
/// The payload for an unstable poll response event.
@ -43,6 +41,28 @@ impl UnstablePollResponseEventContent {
}
}
impl OriginalSyncUnstablePollResponseEvent {
/// Get the data from this response necessary to compile poll results.
pub fn data(&self) -> PollResponseData<'_> {
PollResponseData {
sender: &self.sender,
origin_server_ts: self.origin_server_ts,
selections: &self.content.poll_response.answers,
}
}
}
impl OriginalUnstablePollResponseEvent {
/// Get the data from this response necessary to compile poll results.
pub fn data(&self) -> PollResponseData<'_> {
PollResponseData {
sender: &self.sender,
origin_server_ts: self.origin_server_ts,
selections: &self.content.poll_response.answers,
}
}
}
/// An unstable block for poll response content.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
@ -61,20 +81,12 @@ impl UnstablePollResponseContentBlock {
///
/// Returns the list of valid selections in this `UnstablePollResponseContentBlock`, or `None`
/// if there is no valid selection.
pub fn validate(
&self,
pub fn validate<'a>(
&'a self,
poll: &UnstablePollStartContentBlock,
) -> Option<impl Iterator<Item = &str>> {
// Vote is spoiled if any answer is unknown.
if self.answers.iter().any(|s| !poll.answers.iter().any(|a| a.id == *s)) {
return None;
}
// Fallback to the maximum value for usize because we can't have more selections than that
// in memory.
let max_selections: usize = poll.max_selections.try_into().unwrap_or(usize::MAX);
Some(self.answers.iter().take(max_selections).map(Deref::deref))
) -> Option<impl Iterator<Item = &'a str>> {
let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
validate_selections(answer_ids, poll.max_selections, &self.answers)
}
}

View File

@ -14,7 +14,7 @@ use super::{
compile_unstable_poll_results, generate_poll_end_fallback_text,
start::{PollAnswers, PollAnswersError, PollContentBlock, PollKind},
unstable_end::UnstablePollEndEventContent,
unstable_response::OriginalSyncUnstablePollResponseEvent,
PollResponseData,
};
use crate::events::room::message::Relation;
@ -71,7 +71,7 @@ impl OriginalSyncUnstablePollStartEvent {
/// This uses [`compile_unstable_poll_results()`] internally.
pub fn compile_results<'a>(
&'a self,
responses: impl IntoIterator<Item = &'a OriginalSyncUnstablePollResponseEvent>,
responses: impl IntoIterator<Item = PollResponseData<'a>>,
) -> UnstablePollEndEventContent {
let full_results = compile_unstable_poll_results(&self.content.poll_start, responses, None);
let results =

View File

@ -671,7 +671,8 @@ fn compute_results() {
responses.extend(generate_poll_responses(10..15, &["italian"]));
responses.extend(generate_poll_responses(15..20, &["wings"]));
let counted = compile_poll_results(&poll.content.poll, &responses, None);
let counted =
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None);
assert_eq!(counted.get("pizza").unwrap().len(), 5);
assert_eq!(counted.get("poutine").unwrap().len(), 5);
assert_eq!(counted.get("italian").unwrap().len(), 5);
@ -683,7 +684,7 @@ fn compute_results() {
assert_eq!(iter.next(), Some(&"wings"));
assert_eq!(iter.next(), None);
let poll_end = poll.compile_results(&responses);
let poll_end = poll.compile_results(responses.iter().map(|r| r.data()));
let results = poll_end.poll_results.unwrap();
assert_eq!(*results.get("pizza").unwrap(), uint!(5));
assert_eq!(*results.get("poutine").unwrap(), uint!(5));
@ -713,7 +714,8 @@ fn compute_results() {
),
]);
let counted = compile_poll_results(&poll.content.poll, &responses, None);
let counted =
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None);
assert_eq!(counted.get("pizza").unwrap().len(), 5);
assert_eq!(counted.get("poutine").unwrap().len(), 7);
assert_eq!(counted.get("italian").unwrap().len(), 6);
@ -725,7 +727,7 @@ fn compute_results() {
assert_eq!(iter.next(), Some(&"pizza"));
assert_eq!(iter.next(), None);
let poll_end = poll.compile_results(&responses);
let poll_end = poll.compile_results(responses.iter().map(|r| r.data()));
let results = poll_end.poll_results.unwrap();
assert_eq!(*results.get("pizza").unwrap(), uint!(5));
assert_eq!(*results.get("poutine").unwrap(), uint!(7));
@ -752,13 +754,14 @@ fn compute_results() {
),
]);
let counted = compile_poll_results(&poll.content.poll, &responses, None);
let counted =
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None);
assert_eq!(counted.get("pizza").unwrap().len(), 6);
assert_eq!(counted.get("poutine").unwrap().len(), 8);
assert_eq!(counted.get("italian").unwrap().len(), 6);
assert_eq!(counted.get("wings").unwrap().len(), 6);
let poll_end = poll.compile_results(&responses);
let poll_end = poll.compile_results(responses.iter().map(|r| r.data()));
let results = poll_end.poll_results.unwrap();
assert_eq!(*results.get("pizza").unwrap(), uint!(6));
assert_eq!(*results.get("poutine").unwrap(), uint!(8));
@ -775,7 +778,8 @@ fn compute_results() {
new_poll_response("$valid_for_now_event_3", changing_user_3, uint!(4200), &["wings"]),
]);
let counted = compile_poll_results(&poll.content.poll, &responses, None);
let counted =
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None);
assert_eq!(counted.get("pizza").unwrap().len(), 6);
assert_eq!(counted.get("poutine").unwrap().len(), 8);
assert_eq!(counted.get("italian").unwrap().len(), 6);
@ -787,7 +791,7 @@ fn compute_results() {
assert_eq!(iter.next(), Some(&"italian"));
assert_eq!(iter.next(), None);
let poll_end = poll.compile_results(&responses);
let poll_end = poll.compile_results(responses.iter().map(|r| r.data()));
let results = poll_end.poll_results.unwrap();
assert_eq!(*results.get("pizza").unwrap(), uint!(6));
assert_eq!(*results.get("poutine").unwrap(), uint!(8));
@ -807,13 +811,14 @@ fn compute_results() {
&["italian"],
));
let counted = compile_poll_results(&poll.content.poll, &responses, None);
let counted =
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None);
assert_eq!(counted.get("pizza").unwrap().len(), 6);
assert_eq!(counted.get("poutine").unwrap().len(), 8);
assert_eq!(counted.get("italian").unwrap().len(), 7);
assert_eq!(counted.get("wings").unwrap().len(), 8);
let poll_end = poll.compile_results(&responses);
let poll_end = poll.compile_results(responses.iter().map(|r| r.data()));
let results = poll_end.poll_results.unwrap();
assert_eq!(*results.get("pizza").unwrap(), uint!(6));
assert_eq!(*results.get("poutine").unwrap(), uint!(8));
@ -828,13 +833,14 @@ fn compute_results() {
&[],
));
let counted = compile_poll_results(&poll.content.poll, &responses, None);
let counted =
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None);
assert_eq!(counted.get("pizza").unwrap().len(), 6);
assert_eq!(counted.get("poutine").unwrap().len(), 8);
assert_eq!(counted.get("italian").unwrap().len(), 6);
assert_eq!(counted.get("wings").unwrap().len(), 8);
let poll_end = poll.compile_results(&responses);
let poll_end = poll.compile_results(responses.iter().map(|r| r.data()));
let results = poll_end.poll_results.unwrap();
assert_eq!(*results.get("pizza").unwrap(), uint!(6));
assert_eq!(*results.get("poutine").unwrap(), uint!(8));
@ -849,13 +855,14 @@ fn compute_results() {
&["indian"],
));
let counted = compile_poll_results(&poll.content.poll, &responses, None);
let counted =
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None);
assert_eq!(counted.get("pizza").unwrap().len(), 6);
assert_eq!(counted.get("poutine").unwrap().len(), 8);
assert_eq!(counted.get("italian").unwrap().len(), 6);
assert_eq!(counted.get("wings").unwrap().len(), 7);
let poll_end = poll.compile_results(&responses);
let poll_end = poll.compile_results(responses.iter().map(|r| r.data()));
let results = poll_end.poll_results.unwrap();
assert_eq!(*results.get("pizza").unwrap(), uint!(6));
assert_eq!(*results.get("poutine").unwrap(), uint!(8));
@ -865,13 +872,14 @@ fn compute_results() {
// Response older than most recent one is ignored.
responses.push(new_poll_response("$past_event", changing_user_3, uint!(1), &["pizza"]));
let counted = compile_poll_results(&poll.content.poll, &responses, None);
let counted =
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None);
assert_eq!(counted.get("pizza").unwrap().len(), 6);
assert_eq!(counted.get("poutine").unwrap().len(), 8);
assert_eq!(counted.get("italian").unwrap().len(), 6);
assert_eq!(counted.get("wings").unwrap().len(), 7);
let poll_end = poll.compile_results(&responses);
let poll_end = poll.compile_results(responses.iter().map(|r| r.data()));
let results = poll_end.poll_results.unwrap();
assert_eq!(*results.get("pizza").unwrap(), uint!(6));
assert_eq!(*results.get("poutine").unwrap(), uint!(8));
@ -882,13 +890,14 @@ fn compute_results() {
let future_ts = MilliSecondsSinceUnixEpoch::now().0 + uint!(100_000);
responses.push(new_poll_response("$future_event", changing_user_3, future_ts, &["pizza"]));
let counted = compile_poll_results(&poll.content.poll, &responses, None);
let counted =
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None);
assert_eq!(counted.get("pizza").unwrap().len(), 6);
assert_eq!(counted.get("poutine").unwrap().len(), 8);
assert_eq!(counted.get("italian").unwrap().len(), 6);
assert_eq!(counted.get("wings").unwrap().len(), 7);
let poll_end = poll.compile_results(&responses);
let poll_end = poll.compile_results(responses.iter().map(|r| r.data()));
let results = poll_end.poll_results.unwrap();
assert_eq!(*results.get("pizza").unwrap(), uint!(6));
assert_eq!(*results.get("poutine").unwrap(), uint!(8));
@ -968,7 +977,11 @@ fn compute_unstable_results() {
responses.extend(generate_unstable_poll_responses(6..8, &["italian"]));
responses.extend(generate_unstable_poll_responses(8..11, &["wings"]));
let counted = compile_unstable_poll_results(&poll.content.poll_start, &responses, None);
let counted = compile_unstable_poll_results(
&poll.content.poll_start,
responses.iter().map(|r| r.data()),
None,
);
assert_eq!(counted.get("pizza").unwrap().len(), 5);
assert_eq!(counted.get("poutine").unwrap().len(), 1);
assert_eq!(counted.get("italian").unwrap().len(), 2);
@ -980,6 +993,6 @@ fn compute_unstable_results() {
assert_eq!(iter.next(), Some(&"poutine"));
assert_eq!(iter.next(), None);
let poll_end = poll.compile_results(&responses);
let poll_end = poll.compile_results(responses.iter().map(|r| r.data()));
assert_eq!(poll_end.text, "The poll has closed. Top answer: Pizza 🍕");
}