diff --git a/src/urlencoded/de.rs b/src/urlencoded/de.rs index 0e09beb3..a20e0d9a 100644 --- a/src/urlencoded/de.rs +++ b/src/urlencoded/de.rs @@ -1,6 +1,10 @@ //! Deserialization support for the `application/x-www-form-urlencoded` format. -use std::{borrow::Cow, io::Read}; +use std::{ + borrow::Cow, + collections::btree_map::{self, BTreeMap}, + io::Read, +}; use serde::{ de::{self, value::MapDeserializer, Error as de_Error, IntoDeserializer}, @@ -11,14 +15,18 @@ use url::form_urlencoded::{parse, Parse as UrlEncodedParse}; #[doc(inline)] pub use serde::de::value::Error; +mod val_or_vec; + +use val_or_vec::ValOrVec; + /// Deserializes a `application/x-www-form-urlencoded` value from a `&[u8]`. /// /// ``` /// let meal = vec![ /// ("bread".to_owned(), "baguette".to_owned()), /// ("cheese".to_owned(), "comté".to_owned()), -/// ("meat".to_owned(), "ham".to_owned()), /// ("fat".to_owned(), "butter".to_owned()), +/// ("meat".to_owned(), "ham".to_owned()), /// ]; /// /// assert_eq!( @@ -39,8 +47,8 @@ where /// let meal = vec![ /// ("bread".to_owned(), "baguette".to_owned()), /// ("cheese".to_owned(), "comté".to_owned()), -/// ("meat".to_owned(), "ham".to_owned()), /// ("fat".to_owned(), "butter".to_owned()), +/// ("meat".to_owned(), "ham".to_owned()), /// ]; /// /// assert_eq!( @@ -79,14 +87,14 @@ where /// * Everything else but `deserialize_seq` and `deserialize_seq_fixed_size` /// defers to `deserialize`. pub struct Deserializer<'de> { - inner: MapDeserializer<'de, PartIterator<'de>, Error>, + inner: MapDeserializer<'de, EntryIterator<'de>, Error>, } impl<'de> Deserializer<'de> { /// Returns a new `Deserializer`. - pub fn new(parser: UrlEncodedParse<'de>) -> Self { + pub fn new(parse: UrlEncodedParse<'de>) -> Self { Deserializer { - inner: MapDeserializer::new(PartIterator(parser)), + inner: MapDeserializer::new(group_entries(parse).into_iter()), } } } @@ -152,16 +160,53 @@ impl<'de> de::Deserializer<'de> for Deserializer<'de> { } } -struct PartIterator<'de>(UrlEncodedParse<'de>); +fn group_entries<'de>( + parse: UrlEncodedParse<'de>, +) -> BTreeMap, ValOrVec>> { + use btree_map::Entry::*; -impl<'de> Iterator for PartIterator<'de> { - type Item = (Part<'de>, Part<'de>); + let mut res = BTreeMap::new(); - fn next(&mut self) -> Option { - self.0.next().map(|(k, v)| (Part(k), Part(v))) + for (key, value) in parse { + match res.entry(Part(key)) { + Vacant(v) => { + v.insert(ValOrVec::Val(Part(value))); + } + Occupied(mut o) => { + o.get_mut().push(Part(value)); + } + } } + + res } +/* +input: a=b&c=d&a=c + +vvvvv + +next(): a => Wrapper([b, c]) +next(): c => Wrapper([d]) + +struct Foo { + a: Vec, + c: Vec, +} + +struct Bar { + a: Vec, + c: String, +} + +struct Baz { + a: String, +} +*/ + +type EntryIterator<'de> = btree_map::IntoIter, ValOrVec>>; + +#[derive(PartialEq, PartialOrd, Eq, Ord)] struct Part<'de>(Cow<'de, str>); impl<'de> IntoDeserializer<'de> for Part<'de> { diff --git a/src/urlencoded/de/val_or_vec.rs b/src/urlencoded/de/val_or_vec.rs new file mode 100644 index 00000000..6214f4e0 --- /dev/null +++ b/src/urlencoded/de/val_or_vec.rs @@ -0,0 +1,229 @@ +use std::{iter, ptr, vec}; + +use serde::de::{ + self, + value::{Error, SeqDeserializer}, + Deserializer, IntoDeserializer, +}; + +#[derive(Debug)] +pub enum ValOrVec { + Val(T), + Vec(Vec), +} + +impl ValOrVec { + pub fn push(&mut self, new_val: T) { + match self { + Self::Val(val) => { + let mut vec = Vec::with_capacity(2); + // Safety: + // + // since the vec is pre-allocated, push can't panic, so there + // is no opportunity for a panic in the unsafe code. + unsafe { + let existing_val = ptr::read(val); + vec.push(existing_val); + vec.push(new_val); + ptr::write(self, Self::Vec(vec)) + } + } + Self::Vec(vec) => vec.push(new_val), + } + } +} + +impl IntoIterator for ValOrVec { + type Item = T; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self) + } +} + +pub enum IntoIter { + Val(iter::Once), + Vec(vec::IntoIter), +} + +impl IntoIter { + fn new(vv: ValOrVec) -> Self { + match vv { + ValOrVec::Val(val) => Self::Val(iter::once(val)), + ValOrVec::Vec(vec) => Self::Vec(vec.into_iter()), + } + } +} + +impl Iterator for IntoIter { + type Item = T; + + fn next(&mut self) -> Option { + match self { + Self::Val(iter) => iter.next(), + Self::Vec(iter) => iter.next(), + } + } +} + +impl<'de, T> IntoDeserializer<'de> for ValOrVec +where + T: IntoDeserializer<'de> + Deserializer<'de, Error = Error>, +{ + type Deserializer = Self; + + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +macro_rules! forward_to_part { + ($($method:ident,)*) => { + $( + fn $method(self, visitor: V) -> Result + where V: de::Visitor<'de> + { + match self { + Self::Val(val) => val.$method(visitor), + Self::Vec(_) => Err(de::Error::custom("TODO: Error message")), + } + } + )* + } +} + +impl<'de, T> Deserializer<'de> for ValOrVec +where + T: IntoDeserializer<'de> + Deserializer<'de, Error = Error>, +{ + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self { + Self::Val(val) => val.deserialize_any(visitor), + Self::Vec(_) => self.deserialize_seq(visitor), + } + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_seq(SeqDeserializer::new(self.into_iter())) + } + + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + match self { + Self::Val(val) => val.deserialize_enum(name, variants, visitor), + Self::Vec(_) => Err(de::Error::custom("TODO: Error message")), + } + } + + fn deserialize_tuple( + self, + len: usize, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + match self { + Self::Val(val) => val.deserialize_tuple(len, visitor), + Self::Vec(_) => Err(de::Error::custom("TODO: Error message")), + } + } + + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + match self { + Self::Val(val) => val.deserialize_struct(name, fields, visitor), + Self::Vec(_) => Err(de::Error::custom("TODO: Error message")), + } + } + + fn deserialize_unit_struct( + self, + name: &'static str, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + match self { + Self::Val(val) => val.deserialize_unit_struct(name, visitor), + Self::Vec(_) => Err(de::Error::custom("TODO: Error message")), + } + } + + fn deserialize_tuple_struct( + self, + name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + match self { + Self::Val(val) => val.deserialize_tuple_struct(name, len, visitor), + Self::Vec(_) => Err(de::Error::custom("TODO: Error message")), + } + } + + fn deserialize_newtype_struct( + self, + name: &'static str, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + match self { + Self::Val(val) => val.deserialize_newtype_struct(name, visitor), + Self::Vec(_) => Err(de::Error::custom("TODO: Error message")), + } + } + + forward_to_part! { + deserialize_bool, + deserialize_char, + deserialize_str, + deserialize_string, + deserialize_bytes, + deserialize_byte_buf, + deserialize_unit, + deserialize_u8, + deserialize_u16, + deserialize_u32, + deserialize_u64, + deserialize_i8, + deserialize_i16, + deserialize_i32, + deserialize_i64, + deserialize_f32, + deserialize_f64, + deserialize_option, + deserialize_identifier, + deserialize_ignored_any, + deserialize_map, + } +} diff --git a/tests/url_deserialize.rs b/tests/url_deserialize.rs index ee421f25..4986b1b1 100644 --- a/tests/url_deserialize.rs +++ b/tests/url_deserialize.rs @@ -69,13 +69,13 @@ enum X { #[test] fn deserialize_unit_enum() { - let result = vec![ - ("one".to_owned(), X::A), - ("two".to_owned(), X::B), - ("three".to_owned(), X::C), - ]; + let result: Vec<(String, X)> = + urlencoded::from_str("one=A&two=B&three=C").unwrap(); - assert_eq!(urlencoded::from_str("one=A&two=B&three=C"), Ok(result)); + assert_eq!(result.len(), 3); + assert!(result.contains(&("one".to_owned(), X::A))); + assert!(result.contains(&("two".to_owned(), X::B))); + assert!(result.contains(&("three".to_owned(), X::C))); } #[test] @@ -129,7 +129,6 @@ struct ListStruct { } #[test] -#[ignore] fn deserialize_newstruct() { let de = NewStruct { list: vec!["hello", "world"], @@ -138,7 +137,6 @@ fn deserialize_newstruct() { } #[test] -#[ignore] fn deserialize_numlist() { let de = NumList { list: vec![1, 2, 3, 4], @@ -147,7 +145,6 @@ fn deserialize_numlist() { } #[test] -#[ignore] fn deserialize_vec_bool() { assert_eq!( urlencoded::from_str("item=true&item=false&item=false"), @@ -158,7 +155,6 @@ fn deserialize_vec_bool() { } #[test] -#[ignore] fn deserialize_vec_string() { assert_eq!( urlencoded::from_str("item=hello&item=matrix&item=hello"), @@ -173,7 +169,6 @@ fn deserialize_vec_string() { } #[test] -#[ignore] fn deserialize_struct_unit_enum() { let result = Wrapper { item: vec![X::A, X::B, X::C],