diff --git a/Cargo.toml b/Cargo.toml index 8b0239a7..2e823d57 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ name = "ruma-api-macros" version = "0.1.0" [dependencies] -hyper = "0.11" quote = "0.3.15" synom = "0.11.3" @@ -18,9 +17,12 @@ version = "0.11.11" [dev-dependencies] futures = "0.1.14" +hyper = "0.11" serde = "1.0.8" serde_derive = "1.0.8" serde_json = "1.0.2" +serde_urlencoded = "0.5.1" +url = "1.5.1" [lib] doctest = false diff --git a/src/api/mod.rs b/src/api/mod.rs index 64ee55ff..f206a51e 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -36,6 +36,20 @@ impl ToTokens for Api { tokens }; + let set_request_query = if self.request.has_query_fields() { + let request_query_init_fields = self.request.request_query_init_fields(); + + quote! { + let request_query = RequestQuery { + #request_query_init_fields + }; + + url.set_query(Some(&::serde_urlencoded::to_string(request_query)?)); + } + } else { + Tokens::new() + }; + let add_body_to_request = if let Some(field) = self.request.newtype_body_field() { let field_name = field.ident.as_ref().expect("expected body field to have a name"); @@ -141,9 +155,15 @@ impl ToTokens for Api { fn try_from(request: Request) -> Result { let metadata = Endpoint::METADATA; + // The homeserver url has to be overwritten in the calling code. + let mut url = ::url::Url::parse("http://invalid-host-please-change/").unwrap(); + url.set_path(metadata.path); + #set_request_query + let mut hyper_request = ::hyper::Request::new( metadata.method, - metadata.path.parse()?, + // Every valid URL is a valid URI + url.into_string().parse().unwrap(), ); #add_body_to_request diff --git a/src/api/request.rs b/src/api/request.rs index 43f7fcd6..6bce3eb7 100644 --- a/src/api/request.rs +++ b/src/api/request.rs @@ -11,6 +11,10 @@ impl Request { self.fields.iter().any(|field| field.is_body()) } + pub fn has_query_fields(&self) -> bool { + self.fields.iter().any(|field| field.is_query()) + } + pub fn newtype_body_field(&self) -> Option<&Field> { for request_field in self.fields.iter() { match *request_field { @@ -44,9 +48,32 @@ impl Request { tokens } + pub fn request_query_init_fields(&self) -> Tokens { + let mut tokens = Tokens::new(); + + for query_field in self.query_fields() { + let field = match *query_field { + RequestField::Query(ref field) => field, + _ => panic!("expected query field"), + }; + + let field_name = field.ident.as_ref().expect("expected query field to have a name"); + + tokens.append(quote! { + #field_name: request.#field_name, + }); + } + + tokens + } + fn body_fields(&self) -> RequestBodyFields { RequestBodyFields::new(&self.fields) } + + fn query_fields(&self) -> RequestQueryFields { + RequestQueryFields::new(&self.fields) + } } impl From> for Request { @@ -187,6 +214,29 @@ impl ToTokens for Request { tokens.append("}"); } + + if self.has_query_fields() { + tokens.append(quote! { + /// Data in the request url's query parameters + #[derive(Debug, Serialize)] + struct RequestQuery + }); + + tokens.append("{"); + + for request_field in self.fields.iter() { + match *request_field { + RequestField::Query(ref field) => { + field.to_tokens(&mut tokens); + + tokens.append(","); + } + _ => {} + } + } + + tokens.append("}"); + } } } @@ -206,6 +256,13 @@ impl RequestField { _ => false, } } + + fn is_query(&self) -> bool { + match *self { + RequestField::Query(_) => true, + _ => false, + } + } } enum RequestFieldKind { @@ -246,3 +303,34 @@ impl<'a> Iterator for RequestBodyFields<'a> { None } } + +#[derive(Debug)] +pub struct RequestQueryFields<'a> { + fields: &'a [RequestField], + index: usize, +} + +impl<'a> RequestQueryFields<'a> { + pub fn new(fields: &'a [RequestField]) -> Self { + RequestQueryFields { + fields, + index: 0, + } + } +} + +impl<'a> Iterator for RequestQueryFields<'a> { + type Item = &'a RequestField; + + fn next(&mut self) -> Option<&'a RequestField> { + while let Some(value) = self.fields.get(self.index) { + self.index += 1; + + if value.is_query() { + return Some(value); + } + } + + None + } +} diff --git a/tests/ruma_api_macros.rs b/tests/ruma_api_macros.rs index 9b477386..fc7c4913 100644 --- a/tests/ruma_api_macros.rs +++ b/tests/ruma_api_macros.rs @@ -7,6 +7,8 @@ extern crate ruma_api_macros; extern crate serde; #[macro_use] extern crate serde_derive; extern crate serde_json; +extern crate serde_urlencoded; +extern crate url; pub mod get_supported_versions { use ruma_api_macros::ruma_api;