diff --git a/src/api/mod.rs b/src/api/mod.rs index a7737f6d..09ff964c 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -64,7 +64,15 @@ impl ToTokens for Api { let response = &self.response; let response_types = quote! { #response }; - let set_request_path = if self.request.has_path_fields() { + let extract_request_path = if self.request.has_path_fields() { + quote! { + let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); + } + } else { + TokenStream::new() + }; + + let (set_request_path, parse_request_path) = if self.request.has_path_fields() { let path_str = path.as_str(); assert!(path_str.starts_with('/'), "path needs to start with '/'"); @@ -75,7 +83,7 @@ impl ToTokens for Api { let request_path_init_fields = self.request.request_path_init_fields(); - let mut tokens = quote! { + let mut set_tokens = quote! { let request_path = RequestPath { #request_path_init_fields }; @@ -86,8 +94,10 @@ impl ToTokens for Api { let mut path_segments = url.path_segments_mut().unwrap(); }; - for segment in path_str[1..].split('/') { - tokens.append_all(quote! { + let mut parse_tokens = TokenStream::new(); + + for (i, segment) in path_str[1..].split('/').into_iter().enumerate() { + set_tokens.append_all(quote! { path_segments.push }); @@ -95,21 +105,38 @@ impl ToTokens for Api { let path_var = &segment[1..]; let path_var_ident = Ident::new(path_var, Span::call_site()); - tokens.append_all(quote! { + set_tokens.append_all(quote! { (&request_path.#path_var_ident.to_string()); }); + + let path_field = self.request.path_field(path_var) + .expect("expected request to have path field"); + let ty = &path_field.ty; + + parse_tokens.append_all(quote! { + #path_var_ident: { + let segment = path_segments.get(#i).unwrap().as_bytes(); + let decoded = + ::url::percent_encoding::percent_decode(segment) + .decode_utf8_lossy(); + #ty::deserialize(decoded.into_deserializer()) + .map_err(|e: ::serde_json::error::Error| e)? + }, + }); } else { - tokens.append_all(quote! { + set_tokens.append_all(quote! { (#segment); }); } } - tokens + (set_tokens, parse_tokens) } else { - quote! { + let set_tokens = quote! { url.set_path(metadata.path); - } + }; + let parse_tokens = TokenStream::new(); + (set_tokens, parse_tokens) }; let set_request_query = if self.request.has_query_fields() { @@ -126,6 +153,21 @@ impl ToTokens for Api { TokenStream::new() }; + let extract_request_query = if self.request.has_query_fields() { + quote! { + let request_query: RequestQuery = + ::serde_urlencoded::from_str(&request.uri().query().unwrap_or(""))?; + } + } else { + TokenStream::new() + }; + + let parse_request_query = if self.request.has_query_fields() { + self.request.request_init_query_fields() + } else { + TokenStream::new() + }; + let add_headers_to_request = if self.request.has_header_fields() { let mut header_tokens = quote! { let headers = http_request.headers_mut(); @@ -138,6 +180,20 @@ impl ToTokens for Api { TokenStream::new() }; + let extract_request_headers = if self.request.has_header_fields() { + quote! { + let headers = request.headers(); + } + } else { + TokenStream::new() + }; + + let parse_request_headers = if self.request.has_header_fields() { + self.request.parse_headers_from_request() + } else { + TokenStream::new() + }; + let create_http_request = if let Some(field) = self.request.newtype_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); @@ -162,6 +218,33 @@ impl ToTokens for Api { } }; + let extract_request_body = if let Some(field) = self.request.newtype_body_field() { + let ty = &field.ty; + quote! { + let request_body: #ty = + ::serde_json::from_slice(request.body().as_slice())?; + } + } else if self.request.has_body_fields() { + quote! { + let request_body: RequestBody = + ::serde_json::from_slice(request.body().as_slice())?; + } + } else { + TokenStream::new() + }; + + let parse_request_body = if let Some(field) = self.request.newtype_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + + quote! { + #field_name: request_body, + } + } else if self.request.has_body_fields() { + self.request.request_init_body_fields() + } else { + TokenStream::new() + }; + let deserialize_response_body = if let Some(field) = self.response.newtype_body_field() { let field_type = &field.ty; @@ -198,7 +281,7 @@ impl ToTokens for Api { } }; - let extract_headers = if self.response.has_header_fields() { + let extract_response_headers = if self.response.has_header_fields() { quote! { let mut headers = http_response.headers().clone(); } @@ -212,10 +295,27 @@ impl ToTokens for Api { TokenStream::new() }; + let serialize_response_headers = self.response.apply_header_fields(); + + let serialize_response_body = if self.response.has_body() { + let body = self.response.to_body(); + quote! { + .body(::hyper::Body::from(::serde_json::to_vec(&#body)?)) + } + } else { + quote! { + .body(::hyper::Body::from("{}".as_bytes().to_vec())) + } + }; + tokens.append_all(quote! { #[allow(unused_imports)] use ::futures::{Future as _Future, IntoFuture as _IntoFuture, Stream as _Stream}; use ::ruma_api::Endpoint as _RumaApiEndpoint; + use ::serde::Deserialize; + use ::serde::de::{Error as _SerdeError, IntoDeserializer}; + + use ::std::convert::{TryInto as _TryInto}; /// The API endpoint. #[derive(Debug)] @@ -223,6 +323,45 @@ impl ToTokens for Api { #request_types + impl ::std::convert::TryFrom<::http::Request>> for Request { + type Error = ::ruma_api::Error; + + #[allow(unused_variables)] + fn try_from(request: ::http::Request>) -> Result { + #extract_request_path + #extract_request_query + #extract_request_headers + #extract_request_body + + Ok(Request { + #parse_request_path + #parse_request_query + #parse_request_headers + #parse_request_body + }) + } + } + + impl ::futures::future::FutureFrom<::http::Request<::hyper::Body>> for Request { + type Future = Box<_Future>; + type Error = ::ruma_api::Error; + + #[allow(unused_variables)] + fn future_from(request: ::http::Request<::hyper::Body>) -> Self::Future { + let (parts, body) = request.into_parts(); + let future = body.from_err().fold(Vec::new(), |mut vec, chunk| { + vec.extend(chunk.iter()); + ::futures::future::ok::<_, Self::Error>(vec) + }).and_then(|body| { + ::http::Request::from_parts(parts, body) + .try_into() + .into_future() + .from_err() + }); + Box::new(future) + } + } + impl ::std::convert::TryFrom for ::http::Request<::hyper::Body> { type Error = ::ruma_api::Error; @@ -251,6 +390,20 @@ impl ToTokens for Api { #response_types + impl ::std::convert::TryFrom for ::http::Response<::hyper::Body> { + type Error = ::ruma_api::Error; + + #[allow(unused_variables)] + fn try_from(response: Response) -> Result { + let response = ::http::Response::builder() + .header(::http::header::CONTENT_TYPE, "application/json") + #serialize_response_headers + #serialize_response_body + .unwrap(); + Ok(response) + } + } + impl ::futures::future::FutureFrom<::http::Response<::hyper::Body>> for Response { type Future = Box<_Future>; type Error = ::ruma_api::Error; @@ -259,7 +412,7 @@ impl ToTokens for Api { fn future_from(http_response: ::http::Response<::hyper::Body>) -> Box<_Future> { if http_response.status().is_success() { - #extract_headers + #extract_response_headers #deserialize_response_body .and_then(move |response_body| { diff --git a/src/api/request.rs b/src/api/request.rs index d86a2129..55cbacbe 100644 --- a/src/api/request.rs +++ b/src/api/request.rs @@ -32,6 +32,27 @@ impl Request { }) } + pub fn parse_headers_from_request(&self) -> TokenStream { + self.header_fields().fold(TokenStream::new(), |mut header_tokens, request_field| { + let (field, header_name_string) = match request_field { + RequestField::Header(field, header_name_string) => (field, header_name_string), + _ => panic!("expected request field to be header variant"), + }; + + let field_name = &field.ident; + let header_name = Ident::new(header_name_string.as_ref(), Span::call_site()); + + header_tokens.append_all(quote! { + #field_name: headers.get(::http::header::#header_name) + .and_then(|v| v.to_str().ok()) + .ok_or(::serde_json::Error::missing_field(#header_name_string))? + .to_owned(), + }); + + header_tokens + }) + } + pub fn has_body_fields(&self) -> bool { self.fields.iter().any(|field| field.is_body()) } @@ -55,6 +76,17 @@ impl Request { self.fields.iter().filter(|field| field.is_path()).count() } + pub fn path_field(&self, name: &str) -> Option<&Field> { + self.fields.iter() + .flat_map(|f| f.field_(RequestFieldKind::Path)) + .find(|field| { + field.ident.as_ref() + .expect("expected field to have an identifier") + .to_string() + == name + }) + } + pub fn newtype_body_field(&self) -> Option<&Field> { for request_field in self.fields.iter() { match *request_field { @@ -69,18 +101,26 @@ impl Request { } pub fn request_body_init_fields(&self) -> TokenStream { - self.struct_init_fields(RequestFieldKind::Body) + self.struct_init_fields(RequestFieldKind::Body, quote!(request)) } pub fn request_path_init_fields(&self) -> TokenStream { - self.struct_init_fields(RequestFieldKind::Path) + self.struct_init_fields(RequestFieldKind::Path, quote!(request)) } pub fn request_query_init_fields(&self) -> TokenStream { - self.struct_init_fields(RequestFieldKind::Query) + self.struct_init_fields(RequestFieldKind::Query, quote!(request)) } - fn struct_init_fields(&self, request_field_kind: RequestFieldKind) -> TokenStream { + pub fn request_init_body_fields(&self) -> TokenStream { + self.struct_init_fields(RequestFieldKind::Body, quote!(request_body)) + } + + pub fn request_init_query_fields(&self) -> TokenStream { + self.struct_init_fields(RequestFieldKind::Query, quote!(request_query)) + } + + fn struct_init_fields(&self, request_field_kind: RequestFieldKind, src: TokenStream) -> TokenStream { let mut tokens = TokenStream::new(); for field in self.fields.iter().flat_map(|f| f.field_(request_field_kind)) { @@ -88,7 +128,7 @@ impl Request { let span = field.span(); tokens.append_all(quote_spanned! {span=> - #field_name: request.#field_name, + #field_name: #src.#field_name, }); } @@ -211,7 +251,7 @@ impl ToTokens for Request { request_body_struct = quote_spanned! {span=> /// Data in the request body. - #[derive(Debug, Serialize)] + #[derive(Debug, Deserialize, Serialize)] struct RequestBody(#ty); }; } else if self.has_body_fields() { @@ -230,7 +270,7 @@ impl ToTokens for Request { request_body_struct = quote! { /// Data in the request body. - #[derive(Debug, Serialize)] + #[derive(Debug, Deserialize, Serialize)] struct RequestBody { #fields } @@ -257,7 +297,7 @@ impl ToTokens for Request { request_path_struct = quote! { /// Data in the request path. - #[derive(Debug, Serialize)] + #[derive(Debug, Deserialize, Serialize)] struct RequestPath { #fields } @@ -284,7 +324,7 @@ impl ToTokens for Request { request_query_struct = quote! { /// Data in the request's query string. - #[derive(Debug, Serialize)] + #[derive(Debug, Deserialize, Serialize)] struct RequestQuery { #fields } diff --git a/src/api/response.rs b/src/api/response.rs index 1974d6a9..3fdf380a 100644 --- a/src/api/response.rs +++ b/src/api/response.rs @@ -22,6 +22,10 @@ impl Response { self.fields.iter().any(|field| field.is_header()) } + pub fn has_body(&self) -> bool { + self.fields.iter().any(|field| !field.is_header()) + } + pub fn init_fields(&self) -> TokenStream { let mut tokens = TokenStream::new(); @@ -62,6 +66,47 @@ impl Response { tokens } + pub fn apply_header_fields(&self) -> TokenStream { + let mut tokens = TokenStream::new(); + + for response_field in self.fields.iter() { + if let ResponseField::Header(ref field, ref header) = *response_field { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + let header_name = Ident::new(header.as_ref(), Span::call_site()); + let span = field.span(); + + tokens.append_all(quote_spanned! {span=> + .header(::http::header::#header_name, response.#field_name) + }); + } + } + + tokens + } + + pub fn to_body(&self) -> TokenStream { + if let Some(ref field) = self.newtype_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + let span = field.span(); + quote_spanned!(span=> response.#field_name) + } else { + let fields = self.fields.iter().filter_map(|response_field| { + if let ResponseField::Body(ref field) = *response_field { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + let span = field.span(); + + Some(quote_spanned! {span=> + #field_name: response.#field_name + }) + } else { + None + } + }); + + quote!(ResponseBody{#(#fields),*}) + } + } + pub fn newtype_body_field(&self) -> Option<&Field> { for response_field in self.fields.iter() { match *response_field { @@ -194,7 +239,7 @@ impl ToTokens for Response { response_body_struct = quote_spanned! {span=> /// Data in the response body. - #[derive(Debug, Deserialize)] + #[derive(Debug, Deserialize, Serialize)] struct ResponseBody(#ty); }; } else if self.has_body_fields() { @@ -213,7 +258,7 @@ impl ToTokens for Response { response_body_struct = quote! { /// Data in the response body. - #[derive(Debug, Deserialize)] + #[derive(Debug, Deserialize, Serialize)] struct ResponseBody { #fields } diff --git a/src/lib.rs b/src/lib.rs index fa110b54..741e56fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,7 +51,7 @@ mod api; /// This will generate a `ruma_api::Metadata` value to be used for the `ruma_api::Endpoint`'s /// associated constant, single `Request` and `Response` structs, and the necessary trait /// implementations to convert the request into a `http::Request` and to create a response from a -/// `http::Response`. +/// `http::Response` and vice versa. /// /// The details of each of the three sections of the macros are documented below. /// @@ -173,7 +173,7 @@ mod api; /// pub mod newtype_body_endpoint { /// use ruma_api_macros::ruma_api; /// -/// #[derive(Debug, Deserialize)] +/// #[derive(Debug, Deserialize, Serialize)] /// pub struct MyCustomType { /// pub foo: String, /// }