diff --git a/ruma-serde/src/strings.rs b/ruma-serde/src/strings.rs index 30d6c17b..53f3728a 100644 --- a/ruma-serde/src/strings.rs +++ b/ruma-serde/src/strings.rs @@ -1,11 +1,8 @@ -use std::{ - collections::BTreeMap, - convert::{TryFrom, TryInto}, -}; +use std::{collections::BTreeMap, convert::TryInto, fmt}; use js_int::Int; use serde::{ - de::{Deserializer, Error as _, IntoDeserializer as _}, + de::{self, Deserializer, IntoDeserializer as _, Visitor}, Deserialize, }; @@ -33,25 +30,6 @@ where } } -// Helper type for deserialize_int_or_string_to_int -#[derive(Deserialize)] -#[serde(untagged)] -enum IntOrString<'a> { - Num(Int), - Str(&'a str), -} - -impl TryFrom> for Int { - type Error = js_int::ParseIntError; - - fn try_from(input: IntOrString) -> Result { - match input { - IntOrString::Num(n) => Ok(n), - IntOrString::Str(string) => string.parse(), - } - } -} - /// Take either an integer number or a string and deserialize to an integer number. /// /// To be used like this: @@ -60,7 +38,61 @@ pub fn int_or_string_to_int<'de, D>(de: D) -> Result where D: Deserializer<'de>, { - IntOrString::deserialize(de)?.try_into().map_err(D::Error::custom) + struct IntOrStringVisitor; + + impl<'de> Visitor<'de> for IntOrStringVisitor { + type Value = Int; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an integer or a string") + } + + fn visit_i8(self, v: i8) -> Result { + Ok(v.into()) + } + + fn visit_i16(self, v: i16) -> Result { + Ok(v.into()) + } + + fn visit_i32(self, v: i32) -> Result { + Ok(v.into()) + } + + fn visit_i64(self, v: i64) -> Result { + v.try_into().map_err(E::custom) + } + + fn visit_i128(self, v: i128) -> Result { + v.try_into().map_err(E::custom) + } + + fn visit_u8(self, v: u8) -> Result { + Ok(v.into()) + } + + fn visit_u16(self, v: u16) -> Result { + Ok(v.into()) + } + + fn visit_u32(self, v: u32) -> Result { + Ok(v.into()) + } + + fn visit_u64(self, v: u64) -> Result { + v.try_into().map_err(E::custom) + } + + fn visit_u128(self, v: u128) -> Result { + v.try_into().map_err(E::custom) + } + + fn visit_str(self, v: &str) -> Result { + v.parse().map_err(E::custom) + } + } + + de.deserialize_any(IntOrStringVisitor) } /// Take a BTreeMap with values of either an integer number or a string and deserialize @@ -73,8 +105,42 @@ where D: Deserializer<'de>, T: Deserialize<'de> + Ord, { - BTreeMap::::deserialize(de)? - .into_iter() - .map(|(k, v)| v.try_into().map(|n| (k, n)).map_err(D::Error::custom)) - .collect() + #[repr(transparent)] + struct IntWrap(Int); + + impl<'de> Deserialize<'de> for IntWrap { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + int_or_string_to_int(deserializer).map(IntWrap) + } + } + + Ok(BTreeMap::::deserialize(de)?.into_iter().map(|(k, IntWrap(v))| (k, v)).collect()) +} + +#[cfg(test)] +mod tests { + use js_int::{int, Int}; + use matches::assert_matches; + use serde::Deserialize; + + use super::int_or_string_to_int; + + #[test] + fn int_or_string() -> Result<(), serde_json::Error> { + #[derive(Debug, Deserialize)] + struct Test { + #[serde(deserialize_with = "int_or_string_to_int")] + num: Int, + } + + assert_matches!( + serde_json::from_value::(serde_json::json!({ "num": "0" }))?, + Test { num } if num == int!(0) + ); + + Ok(()) + } }