diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index a2bfcb90..f0d6f847 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -173,6 +173,7 @@ impl Request { ruma_api: &TokenStream, ) -> TokenStream { let http = quote! { #ruma_api::exports::http }; + let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let serde = quote! { #ruma_api::exports::serde }; let serde_json = quote! { #ruma_api::exports::serde_json }; @@ -205,11 +206,131 @@ impl Request { TokenStream::new() }; - let (request_path_string, parse_request_path) = - self.path_string_and_parse(metadata, &ruma_api); + let (request_path_string, parse_request_path) = if self.has_path_fields() { + let path_string = metadata.path.value(); - let request_query_string = self.build_query_string(&ruma_api); - let extract_request_query = self.extract_request_query(&ruma_api); + assert!(path_string.starts_with('/'), "path needs to start with '/'"); + assert!( + path_string.chars().filter(|c| *c == ':').count() == self.path_field_count(), + "number of declared path parameters needs to match amount of placeholders in path" + ); + + let format_call = { + let mut format_string = path_string.clone(); + let mut format_args = Vec::new(); + + while let Some(start_of_segment) = format_string.find(':') { + // ':' should only ever appear at the start of a segment + assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/"); + + let end_of_segment = match format_string[start_of_segment..].find('/') { + Some(rel_pos) => start_of_segment + rel_pos, + None => format_string.len(), + }; + + let path_var = Ident::new( + &format_string[start_of_segment + 1..end_of_segment], + Span::call_site(), + ); + format_args.push(quote! { + #percent_encoding::utf8_percent_encode( + &self.#path_var.to_string(), + #percent_encoding::NON_ALPHANUMERIC, + ) + }); + format_string.replace_range(start_of_segment..end_of_segment, "{}"); + } + + quote! { + format_args!(#format_string, #(#format_args),*) + } + }; + + let path_fields = + path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map( + |(i, segment)| { + let path_var = &segment[1..]; + let path_var_ident = Ident::new(path_var, Span::call_site()); + quote! { + #path_var_ident: { + let segment = path_segments[#i].as_bytes(); + let decoded = + #percent_encoding::percent_decode(segment).decode_utf8()?; + + ::std::convert::TryFrom::try_from(&*decoded)? + } + } + }, + ); + + (format_call, quote! { #(#path_fields,)* }) + } else { + (quote! { metadata.path.to_owned() }, TokenStream::new()) + }; + + let request_query_string = if let Some(field) = self.query_map_field() { + let field_name = field.ident.as_ref().expect("expected field to have identifier"); + + quote!({ + // This function exists so that the compiler will throw an error when the type of + // the field with the query_map attribute doesn't implement + // `IntoIterator`. + // + // This is necessary because the `ruma_serde::urlencoded::to_string` call will + // result in a runtime error when the type cannot be encoded as a list key-value + // pairs (?key1=value1&key2=value2). + // + // By asserting that it implements the iterator trait, we can ensure that it won't + // fail. + fn assert_trait_impl(_: &T) + where + T: ::std::iter::IntoIterator< + Item = (::std::string::String, ::std::string::String), + >, + {} + + let request_query = RequestQuery(self.#field_name); + assert_trait_impl(&request_query.0); + + format_args!( + "?{}", + #ruma_serde::urlencoded::to_string(request_query)? + ) + }) + } else if self.has_query_fields() { + let request_query_init_fields = + self.struct_init_fields(RequestFieldKind::Query, quote!(self)); + + quote!({ + let request_query = RequestQuery { + #request_query_init_fields + }; + + format_args!( + "?{}", + #ruma_serde::urlencoded::to_string(request_query)? + ) + }) + } else { + quote! { "" } + }; + + let extract_request_query = if self.query_map_field().is_some() { + quote! { + let request_query = #ruma_serde::urlencoded::from_str( + &request.uri().query().unwrap_or(""), + )?; + } + } else if self.has_query_fields() { + quote! { + let request_query: ::Incoming = + #ruma_serde::urlencoded::from_str( + &request.uri().query().unwrap_or("") + )?; + } + } else { + TokenStream::new() + }; let parse_request_query = if let Some(field) = self.query_map_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); @@ -593,158 +714,6 @@ impl Request { } } } - - /// Deserialize the query string. - fn extract_request_query(&self, ruma_api: &TokenStream) -> TokenStream { - let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; - - if self.query_map_field().is_some() { - quote! { - let request_query = #ruma_serde::urlencoded::from_str( - &request.uri().query().unwrap_or(""), - )?; - } - } else if self.has_query_fields() { - quote! { - let request_query: ::Incoming = - #ruma_serde::urlencoded::from_str( - &request.uri().query().unwrap_or("") - )?; - } - } else { - TokenStream::new() - } - } - - /// The function determines the type of query string that needs to be built - /// and then builds it using `ruma_serde::urlencoded::to_string`. - fn build_query_string(&self, ruma_api: &TokenStream) -> TokenStream { - let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; - - if let Some(field) = self.query_map_field() { - let field_name = field.ident.as_ref().expect("expected field to have identifier"); - - quote!({ - // This function exists so that the compiler will throw an error when the type of - // the field with the query_map attribute doesn't implement - // `IntoIterator`. - // - // This is necessary because the `ruma_serde::urlencoded::to_string` call will - // result in a runtime error when the type cannot be encoded as a list key-value - // pairs (?key1=value1&key2=value2). - // - // By asserting that it implements the iterator trait, we can ensure that it won't - // fail. - fn assert_trait_impl(_: &T) - where - T: ::std::iter::IntoIterator< - Item = (::std::string::String, ::std::string::String), - >, - {} - - let request_query = RequestQuery(self.#field_name); - assert_trait_impl(&request_query.0); - - format_args!( - "?{}", - #ruma_serde::urlencoded::to_string(request_query)? - ) - }) - } else if self.has_query_fields() { - let request_query_init_fields = - self.struct_init_fields(RequestFieldKind::Query, quote!(self)); - - quote!({ - let request_query = RequestQuery { - #request_query_init_fields - }; - - format_args!( - "?{}", - #ruma_serde::urlencoded::to_string(request_query)? - ) - }) - } else { - quote! { "" } - } - } - - /// The first item in the tuple generates code for the request path from the `Metadata` and - /// `Request` structs. The second item in the returned tuple is the code to generate a Request - /// struct field created from any segments of the path that start with ":". - /// - /// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is - /// used for implementing `TryFrom>>`, from path strings deserialized to - /// Ruma types. - fn path_string_and_parse( - &self, - metadata: &Metadata, - ruma_api: &TokenStream, - ) -> (TokenStream, TokenStream) { - let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; - - if self.has_path_fields() { - let path_string = metadata.path.value(); - - assert!(path_string.starts_with('/'), "path needs to start with '/'"); - assert!( - path_string.chars().filter(|c| *c == ':').count() == self.path_field_count(), - "number of declared path parameters needs to match amount of placeholders in path" - ); - - let format_call = { - let mut format_string = path_string.clone(); - let mut format_args = Vec::new(); - - while let Some(start_of_segment) = format_string.find(':') { - // ':' should only ever appear at the start of a segment - assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/"); - - let end_of_segment = match format_string[start_of_segment..].find('/') { - Some(rel_pos) => start_of_segment + rel_pos, - None => format_string.len(), - }; - - let path_var = Ident::new( - &format_string[start_of_segment + 1..end_of_segment], - Span::call_site(), - ); - format_args.push(quote! { - #percent_encoding::utf8_percent_encode( - &self.#path_var.to_string(), - #percent_encoding::NON_ALPHANUMERIC, - ) - }); - format_string.replace_range(start_of_segment..end_of_segment, "{}"); - } - - quote! { - format_args!(#format_string, #(#format_args),*) - } - }; - - let path_fields = - path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map( - |(i, segment)| { - let path_var = &segment[1..]; - let path_var_ident = Ident::new(path_var, Span::call_site()); - quote! { - #path_var_ident: { - let segment = path_segments[#i].as_bytes(); - let decoded = - #percent_encoding::percent_decode(segment).decode_utf8()?; - - ::std::convert::TryFrom::try_from(&*decoded)? - } - } - }, - ); - - (format_call, quote! { #(#path_fields,)* }) - } else { - (quote! { metadata.path.to_owned() }, TokenStream::new()) - } - } } /// The types of fields that a request can have.