events: Don't filter out any responses if there is no end timestamp

This commit is contained in:
Kévin Commaille 2023-08-23 17:55:56 +02:00 committed by GitHub
parent c652461ae7
commit d809e6e365
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 61 additions and 56 deletions

View File

@ -52,29 +52,9 @@ pub fn compile_poll_results<'a>(
responses: impl IntoIterator<Item = PollResponseData<'a>>,
end_timestamp: Option<MilliSecondsSinceUnixEpoch>,
) -> IndexMap<&'a str, BTreeSet<&'a UserId>> {
let end_ts = end_timestamp.unwrap_or_else(MilliSecondsSinceUnixEpoch::now);
let users_selections = responses
.into_iter()
.filter(|ev| {
// Filter out responses after the end_timestamp.
ev.origin_server_ts <= end_ts
})
.fold(BTreeMap::new(), |mut acc, data| {
let response =
acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None));
// Only keep the latest selections for each user.
if response.0 < data.origin_server_ts {
let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
*response = (
data.origin_server_ts,
validate_selections(answer_ids, poll.max_selections, data.selections),
);
}
acc
});
let users_selections =
filter_selections(answer_ids, poll.max_selections, responses, end_timestamp);
aggregate_results(poll.answers.iter().map(|a| a.id.as_str()), users_selections)
}
@ -95,36 +75,16 @@ pub fn compile_unstable_poll_results<'a>(
responses: impl IntoIterator<Item = PollResponseData<'a>>,
end_timestamp: Option<MilliSecondsSinceUnixEpoch>,
) -> IndexMap<&'a str, BTreeSet<&'a UserId>> {
let end_ts = end_timestamp.unwrap_or_else(MilliSecondsSinceUnixEpoch::now);
let users_selections = responses
.into_iter()
.filter(|ev| {
// Filter out responses after the end_timestamp.
ev.origin_server_ts <= end_ts
})
.fold(BTreeMap::new(), |mut acc, data| {
let response =
acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None));
// Only keep the latest selections for each user.
if response.0 < data.origin_server_ts {
let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
*response = (
data.origin_server_ts,
validate_selections(answer_ids, poll.max_selections, data.selections),
);
}
acc
});
let users_selections =
filter_selections(answer_ids, poll.max_selections, responses, end_timestamp);
aggregate_results(poll.answers.iter().map(|a| a.id.as_str()), users_selections)
}
/// Validate the selections of a response.
fn validate_selections<'a>(
answer_ids: BTreeSet<&str>,
answer_ids: &BTreeSet<&str>,
max_selections: UInt,
selections: &'a [String],
) -> Option<impl Iterator<Item = &'a str>> {
@ -140,6 +100,34 @@ fn validate_selections<'a>(
Some(selections.iter().take(max_selections).map(Deref::deref))
}
fn filter_selections<'a>(
answer_ids: BTreeSet<&str>,
max_selections: UInt,
responses: impl IntoIterator<Item = PollResponseData<'a>>,
end_timestamp: Option<MilliSecondsSinceUnixEpoch>,
) -> BTreeMap<&'a UserId, (MilliSecondsSinceUnixEpoch, Option<impl Iterator<Item = &'a str>>)> {
responses
.into_iter()
.filter(|ev| {
// Filter out responses after the end_timestamp.
end_timestamp.map_or(true, |end_ts| ev.origin_server_ts <= end_ts)
})
.fold(BTreeMap::new(), |mut acc, data| {
let response =
acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None));
// Only keep the latest selections for each user.
if response.0 < data.origin_server_ts {
*response = (
data.origin_server_ts,
validate_selections(&answer_ids, max_selections, data.selections),
);
}
acc
})
}
/// Aggregate the given selections by answer.
fn aggregate_results<'a>(
answers: impl Iterator<Item = &'a str>,

View File

@ -94,7 +94,7 @@ impl SelectionsContentBlock {
poll: &PollContentBlock,
) -> Option<impl Iterator<Item = &'a str>> {
let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
validate_selections(answer_ids, poll.max_selections, &self.0)
validate_selections(&answer_ids, poll.max_selections, &self.0)
}
}

View File

@ -18,7 +18,7 @@ use super::{
use crate::{
events::{message::TextContentBlock, room::message::Relation},
serde::StringEnum,
PrivOwnedStr,
MilliSecondsSinceUnixEpoch, PrivOwnedStr,
};
/// The payload for a poll start event.
@ -90,7 +90,11 @@ impl OriginalSyncPollStartEvent {
&'a self,
responses: impl IntoIterator<Item = PollResponseData<'a>>,
) -> PollEndEventContent {
let full_results = compile_poll_results(&self.content.poll, responses, None);
let full_results = compile_poll_results(
&self.content.poll,
responses,
Some(MilliSecondsSinceUnixEpoch::now()),
);
let results =
full_results.into_iter().map(|(id, users)| (id, users.len())).collect::<Vec<_>>();

View File

@ -86,7 +86,7 @@ impl UnstablePollResponseContentBlock {
poll: &UnstablePollStartContentBlock,
) -> Option<impl Iterator<Item = &'a str>> {
let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
validate_selections(answer_ids, poll.max_selections, &self.answers)
validate_selections(&answer_ids, poll.max_selections, &self.answers)
}
}

View File

@ -16,7 +16,7 @@ use super::{
unstable_end::UnstablePollEndEventContent,
PollResponseData,
};
use crate::events::room::message::Relation;
use crate::{events::room::message::Relation, MilliSecondsSinceUnixEpoch};
/// The payload for an unstable poll start event.
///
@ -73,7 +73,11 @@ impl OriginalSyncUnstablePollStartEvent {
&'a self,
responses: impl IntoIterator<Item = PollResponseData<'a>>,
) -> UnstablePollEndEventContent {
let full_results = compile_unstable_poll_results(&self.content.poll_start, responses, None);
let full_results = compile_unstable_poll_results(
&self.content.poll_start,
responses,
Some(MilliSecondsSinceUnixEpoch::now()),
);
let results =
full_results.into_iter().map(|(id, users)| (id, users.len())).collect::<Vec<_>>();

View File

@ -886,12 +886,13 @@ fn compute_results() {
assert_eq!(*results.get("italian").unwrap(), uint!(6));
assert_eq!(*results.get("wings").unwrap(), uint!(7));
// Response in the future is ignored.
let future_ts = MilliSecondsSinceUnixEpoch::now().0 + uint!(100_000);
// Response later than end_timestamp is ignored.
let now = MilliSecondsSinceUnixEpoch::now();
let future_ts = now.0 + uint!(100_000);
responses.push(new_poll_response("$future_event", changing_user_3, future_ts, &["pizza"]));
let counted =
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None);
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), Some(now));
assert_eq!(counted.get("pizza").unwrap().len(), 6);
assert_eq!(counted.get("poutine").unwrap().len(), 8);
assert_eq!(counted.get("italian").unwrap().len(), 6);
@ -903,6 +904,14 @@ fn compute_results() {
assert_eq!(*results.get("poutine").unwrap(), uint!(8));
assert_eq!(*results.get("italian").unwrap(), uint!(6));
assert_eq!(*results.get("wings").unwrap(), uint!(7));
// Response in the future is not ignored if there is no end_timestamp.
let counted =
compile_poll_results(&poll.content.poll, responses.iter().map(|r| r.data()), None);
assert_eq!(counted.get("pizza").unwrap().len(), 7);
assert_eq!(counted.get("poutine").unwrap().len(), 8);
assert_eq!(counted.get("italian").unwrap().len(), 6);
assert_eq!(counted.get("wings").unwrap().len(), 6);
}
fn new_unstable_poll_response(