diff --git a/crates/ruma-common/src/serde/cow.rs b/crates/ruma-common/src/serde/cow.rs index f4fba913..eb6a70d3 100644 --- a/crates/ruma-common/src/serde/cow.rs +++ b/crates/ruma-common/src/serde/cow.rs @@ -1,23 +1,6 @@ use std::{borrow::Cow, str}; -use serde::de::{self, Deserialize, Deserializer, Unexpected, Visitor}; - -pub(crate) struct MyCowStr<'a>(Cow<'a, str>); - -impl<'a> MyCowStr<'a> { - pub(crate) fn get(self) -> Cow<'a, str> { - self.0 - } -} - -impl<'de> Deserialize<'de> for MyCowStr<'de> { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - deserialize_cow_str(deserializer).map(Self) - } -} +use serde::de::{self, Deserializer, Unexpected, Visitor}; /// Deserialize a `Cow<'de, str>`. /// diff --git a/crates/ruma-common/src/serde/raw.rs b/crates/ruma-common/src/serde/raw.rs index 2143f2fd..0cb4b45c 100644 --- a/crates/ruma-common/src/serde/raw.rs +++ b/crates/ruma-common/src/serde/raw.rs @@ -6,13 +6,11 @@ use std::{ }; use serde::{ - de::{Deserialize, Deserializer, IgnoredAny, MapAccess, Visitor}, + de::{self, Deserialize, DeserializeSeed, Deserializer, IgnoredAny, MapAccess, Visitor}, ser::{Serialize, Serializer}, }; use serde_json::value::{to_raw_value as to_raw_json_value, RawValue as RawJsonValue}; -use super::cow::MyCowStr; - /// A wrapper around `Box`, to be used in place of any type in the Matrix endpoint /// definition to allow request and response types to contain that said type represented by /// the generic argument `Ev`. @@ -97,6 +95,36 @@ impl Raw { where U: Deserialize<'a>, { + struct FieldVisitor<'b>(&'b str); + + impl<'b, 'de> Visitor<'de> for FieldVisitor<'b> { + type Value = bool; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "`{}`", self.0) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(value == self.0) + } + } + + struct Field<'b>(&'b str); + + impl<'b, 'de> DeserializeSeed<'de> for Field<'b> { + type Value = bool; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_identifier(FieldVisitor(self.0)) + } + } + struct SingleFieldVisitor<'b, T> { field_name: &'b str, _phantom: PhantomData, @@ -123,8 +151,8 @@ impl Raw { A: MapAccess<'de>, { let mut res = None; - while let Some(key) = map.next_key::>()? { - if key.get() == self.field_name { + while let Some(is_right_field) = map.next_key_seed(Field(self.field_name))? { + if is_right_field { res = Some(map.next_value()?); } else { map.next_value::()?;