Add convertion to/from Request/Response from/to http::Request/Response

This commit is contained in:
Jonas Herzig 2018-09-08 11:06:20 +02:00
parent 116a6f44bc
commit e23eff151b
4 changed files with 262 additions and 24 deletions

View File

@ -64,7 +64,15 @@ impl ToTokens for Api {
let response = &self.response; let response = &self.response;
let response_types = quote! { #response }; let response_types = quote! { #response };
let set_request_path = if self.request.has_path_fields() { let extract_request_path = if self.request.has_path_fields() {
quote! {
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
}
} else {
TokenStream::new()
};
let (set_request_path, parse_request_path) = if self.request.has_path_fields() {
let path_str = path.as_str(); let path_str = path.as_str();
assert!(path_str.starts_with('/'), "path needs to start with '/'"); assert!(path_str.starts_with('/'), "path needs to start with '/'");
@ -75,7 +83,7 @@ impl ToTokens for Api {
let request_path_init_fields = self.request.request_path_init_fields(); let request_path_init_fields = self.request.request_path_init_fields();
let mut tokens = quote! { let mut set_tokens = quote! {
let request_path = RequestPath { let request_path = RequestPath {
#request_path_init_fields #request_path_init_fields
}; };
@ -86,8 +94,10 @@ impl ToTokens for Api {
let mut path_segments = url.path_segments_mut().unwrap(); let mut path_segments = url.path_segments_mut().unwrap();
}; };
for segment in path_str[1..].split('/') { let mut parse_tokens = TokenStream::new();
tokens.append_all(quote! {
for (i, segment) in path_str[1..].split('/').into_iter().enumerate() {
set_tokens.append_all(quote! {
path_segments.push path_segments.push
}); });
@ -95,21 +105,38 @@ impl ToTokens for Api {
let path_var = &segment[1..]; let path_var = &segment[1..];
let path_var_ident = Ident::new(path_var, Span::call_site()); let path_var_ident = Ident::new(path_var, Span::call_site());
tokens.append_all(quote! { set_tokens.append_all(quote! {
(&request_path.#path_var_ident.to_string()); (&request_path.#path_var_ident.to_string());
}); });
let path_field = self.request.path_field(path_var)
.expect("expected request to have path field");
let ty = &path_field.ty;
parse_tokens.append_all(quote! {
#path_var_ident: {
let segment = path_segments.get(#i).unwrap().as_bytes();
let decoded =
::url::percent_encoding::percent_decode(segment)
.decode_utf8_lossy();
#ty::deserialize(decoded.into_deserializer())
.map_err(|e: ::serde_json::error::Error| e)?
},
});
} else { } else {
tokens.append_all(quote! { set_tokens.append_all(quote! {
(#segment); (#segment);
}); });
} }
} }
tokens (set_tokens, parse_tokens)
} else { } else {
quote! { let set_tokens = quote! {
url.set_path(metadata.path); url.set_path(metadata.path);
} };
let parse_tokens = TokenStream::new();
(set_tokens, parse_tokens)
}; };
let set_request_query = if self.request.has_query_fields() { let set_request_query = if self.request.has_query_fields() {
@ -126,6 +153,21 @@ impl ToTokens for Api {
TokenStream::new() TokenStream::new()
}; };
let extract_request_query = if self.request.has_query_fields() {
quote! {
let request_query: RequestQuery =
::serde_urlencoded::from_str(&request.uri().query().unwrap_or(""))?;
}
} else {
TokenStream::new()
};
let parse_request_query = if self.request.has_query_fields() {
self.request.request_init_query_fields()
} else {
TokenStream::new()
};
let add_headers_to_request = if self.request.has_header_fields() { let add_headers_to_request = if self.request.has_header_fields() {
let mut header_tokens = quote! { let mut header_tokens = quote! {
let headers = http_request.headers_mut(); let headers = http_request.headers_mut();
@ -138,6 +180,20 @@ impl ToTokens for Api {
TokenStream::new() TokenStream::new()
}; };
let extract_request_headers = if self.request.has_header_fields() {
quote! {
let headers = request.headers();
}
} else {
TokenStream::new()
};
let parse_request_headers = if self.request.has_header_fields() {
self.request.parse_headers_from_request()
} else {
TokenStream::new()
};
let create_http_request = if let Some(field) = self.request.newtype_body_field() { let create_http_request = if let Some(field) = self.request.newtype_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let field_name = field.ident.as_ref().expect("expected field to have an identifier");
@ -162,6 +218,33 @@ impl ToTokens for Api {
} }
}; };
let extract_request_body = if let Some(field) = self.request.newtype_body_field() {
let ty = &field.ty;
quote! {
let request_body: #ty =
::serde_json::from_slice(request.body().as_slice())?;
}
} else if self.request.has_body_fields() {
quote! {
let request_body: RequestBody =
::serde_json::from_slice(request.body().as_slice())?;
}
} else {
TokenStream::new()
};
let parse_request_body = if let Some(field) = self.request.newtype_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request_body,
}
} else if self.request.has_body_fields() {
self.request.request_init_body_fields()
} else {
TokenStream::new()
};
let deserialize_response_body = if let Some(field) = self.response.newtype_body_field() { let deserialize_response_body = if let Some(field) = self.response.newtype_body_field() {
let field_type = &field.ty; let field_type = &field.ty;
@ -198,7 +281,7 @@ impl ToTokens for Api {
} }
}; };
let extract_headers = if self.response.has_header_fields() { let extract_response_headers = if self.response.has_header_fields() {
quote! { quote! {
let mut headers = http_response.headers().clone(); let mut headers = http_response.headers().clone();
} }
@ -212,10 +295,27 @@ impl ToTokens for Api {
TokenStream::new() TokenStream::new()
}; };
let serialize_response_headers = self.response.apply_header_fields();
let serialize_response_body = if self.response.has_body() {
let body = self.response.to_body();
quote! {
.body(::hyper::Body::from(::serde_json::to_vec(&#body)?))
}
} else {
quote! {
.body(::hyper::Body::from("{}".as_bytes().to_vec()))
}
};
tokens.append_all(quote! { tokens.append_all(quote! {
#[allow(unused_imports)] #[allow(unused_imports)]
use ::futures::{Future as _Future, IntoFuture as _IntoFuture, Stream as _Stream}; use ::futures::{Future as _Future, IntoFuture as _IntoFuture, Stream as _Stream};
use ::ruma_api::Endpoint as _RumaApiEndpoint; use ::ruma_api::Endpoint as _RumaApiEndpoint;
use ::serde::Deserialize;
use ::serde::de::{Error as _SerdeError, IntoDeserializer};
use ::std::convert::{TryInto as _TryInto};
/// The API endpoint. /// The API endpoint.
#[derive(Debug)] #[derive(Debug)]
@ -223,6 +323,45 @@ impl ToTokens for Api {
#request_types #request_types
impl ::std::convert::TryFrom<::http::Request<Vec<u8>>> for Request {
type Error = ::ruma_api::Error;
#[allow(unused_variables)]
fn try_from(request: ::http::Request<Vec<u8>>) -> Result<Self, Self::Error> {
#extract_request_path
#extract_request_query
#extract_request_headers
#extract_request_body
Ok(Request {
#parse_request_path
#parse_request_query
#parse_request_headers
#parse_request_body
})
}
}
impl ::futures::future::FutureFrom<::http::Request<::hyper::Body>> for Request {
type Future = Box<_Future<Item = Self, Error = Self::Error>>;
type Error = ::ruma_api::Error;
#[allow(unused_variables)]
fn future_from(request: ::http::Request<::hyper::Body>) -> Self::Future {
let (parts, body) = request.into_parts();
let future = body.from_err().fold(Vec::new(), |mut vec, chunk| {
vec.extend(chunk.iter());
::futures::future::ok::<_, Self::Error>(vec)
}).and_then(|body| {
::http::Request::from_parts(parts, body)
.try_into()
.into_future()
.from_err()
});
Box::new(future)
}
}
impl ::std::convert::TryFrom<Request> for ::http::Request<::hyper::Body> { impl ::std::convert::TryFrom<Request> for ::http::Request<::hyper::Body> {
type Error = ::ruma_api::Error; type Error = ::ruma_api::Error;
@ -251,6 +390,20 @@ impl ToTokens for Api {
#response_types #response_types
impl ::std::convert::TryFrom<Response> for ::http::Response<::hyper::Body> {
type Error = ::ruma_api::Error;
#[allow(unused_variables)]
fn try_from(response: Response) -> Result<Self, Self::Error> {
let response = ::http::Response::builder()
.header(::http::header::CONTENT_TYPE, "application/json")
#serialize_response_headers
#serialize_response_body
.unwrap();
Ok(response)
}
}
impl ::futures::future::FutureFrom<::http::Response<::hyper::Body>> for Response { impl ::futures::future::FutureFrom<::http::Response<::hyper::Body>> for Response {
type Future = Box<_Future<Item = Self, Error = Self::Error>>; type Future = Box<_Future<Item = Self, Error = Self::Error>>;
type Error = ::ruma_api::Error; type Error = ::ruma_api::Error;
@ -259,7 +412,7 @@ impl ToTokens for Api {
fn future_from(http_response: ::http::Response<::hyper::Body>) fn future_from(http_response: ::http::Response<::hyper::Body>)
-> Box<_Future<Item = Self, Error = Self::Error>> { -> Box<_Future<Item = Self, Error = Self::Error>> {
if http_response.status().is_success() { if http_response.status().is_success() {
#extract_headers #extract_response_headers
#deserialize_response_body #deserialize_response_body
.and_then(move |response_body| { .and_then(move |response_body| {

View File

@ -32,6 +32,27 @@ impl Request {
}) })
} }
pub fn parse_headers_from_request(&self) -> TokenStream {
self.header_fields().fold(TokenStream::new(), |mut header_tokens, request_field| {
let (field, header_name_string) = match request_field {
RequestField::Header(field, header_name_string) => (field, header_name_string),
_ => panic!("expected request field to be header variant"),
};
let field_name = &field.ident;
let header_name = Ident::new(header_name_string.as_ref(), Span::call_site());
header_tokens.append_all(quote! {
#field_name: headers.get(::http::header::#header_name)
.and_then(|v| v.to_str().ok())
.ok_or(::serde_json::Error::missing_field(#header_name_string))?
.to_owned(),
});
header_tokens
})
}
pub fn has_body_fields(&self) -> bool { pub fn has_body_fields(&self) -> bool {
self.fields.iter().any(|field| field.is_body()) self.fields.iter().any(|field| field.is_body())
} }
@ -55,6 +76,17 @@ impl Request {
self.fields.iter().filter(|field| field.is_path()).count() self.fields.iter().filter(|field| field.is_path()).count()
} }
pub fn path_field(&self, name: &str) -> Option<&Field> {
self.fields.iter()
.flat_map(|f| f.field_(RequestFieldKind::Path))
.find(|field| {
field.ident.as_ref()
.expect("expected field to have an identifier")
.to_string()
== name
})
}
pub fn newtype_body_field(&self) -> Option<&Field> { pub fn newtype_body_field(&self) -> Option<&Field> {
for request_field in self.fields.iter() { for request_field in self.fields.iter() {
match *request_field { match *request_field {
@ -69,18 +101,26 @@ impl Request {
} }
pub fn request_body_init_fields(&self) -> TokenStream { pub fn request_body_init_fields(&self) -> TokenStream {
self.struct_init_fields(RequestFieldKind::Body) self.struct_init_fields(RequestFieldKind::Body, quote!(request))
} }
pub fn request_path_init_fields(&self) -> TokenStream { pub fn request_path_init_fields(&self) -> TokenStream {
self.struct_init_fields(RequestFieldKind::Path) self.struct_init_fields(RequestFieldKind::Path, quote!(request))
} }
pub fn request_query_init_fields(&self) -> TokenStream { pub fn request_query_init_fields(&self) -> TokenStream {
self.struct_init_fields(RequestFieldKind::Query) self.struct_init_fields(RequestFieldKind::Query, quote!(request))
} }
fn struct_init_fields(&self, request_field_kind: RequestFieldKind) -> TokenStream { pub fn request_init_body_fields(&self) -> TokenStream {
self.struct_init_fields(RequestFieldKind::Body, quote!(request_body))
}
pub fn request_init_query_fields(&self) -> TokenStream {
self.struct_init_fields(RequestFieldKind::Query, quote!(request_query))
}
fn struct_init_fields(&self, request_field_kind: RequestFieldKind, src: TokenStream) -> TokenStream {
let mut tokens = TokenStream::new(); let mut tokens = TokenStream::new();
for field in self.fields.iter().flat_map(|f| f.field_(request_field_kind)) { for field in self.fields.iter().flat_map(|f| f.field_(request_field_kind)) {
@ -88,7 +128,7 @@ impl Request {
let span = field.span(); let span = field.span();
tokens.append_all(quote_spanned! {span=> tokens.append_all(quote_spanned! {span=>
#field_name: request.#field_name, #field_name: #src.#field_name,
}); });
} }
@ -211,7 +251,7 @@ impl ToTokens for Request {
request_body_struct = quote_spanned! {span=> request_body_struct = quote_spanned! {span=>
/// Data in the request body. /// Data in the request body.
#[derive(Debug, Serialize)] #[derive(Debug, Deserialize, Serialize)]
struct RequestBody(#ty); struct RequestBody(#ty);
}; };
} else if self.has_body_fields() { } else if self.has_body_fields() {
@ -230,7 +270,7 @@ impl ToTokens for Request {
request_body_struct = quote! { request_body_struct = quote! {
/// Data in the request body. /// Data in the request body.
#[derive(Debug, Serialize)] #[derive(Debug, Deserialize, Serialize)]
struct RequestBody { struct RequestBody {
#fields #fields
} }
@ -257,7 +297,7 @@ impl ToTokens for Request {
request_path_struct = quote! { request_path_struct = quote! {
/// Data in the request path. /// Data in the request path.
#[derive(Debug, Serialize)] #[derive(Debug, Deserialize, Serialize)]
struct RequestPath { struct RequestPath {
#fields #fields
} }
@ -284,7 +324,7 @@ impl ToTokens for Request {
request_query_struct = quote! { request_query_struct = quote! {
/// Data in the request's query string. /// Data in the request's query string.
#[derive(Debug, Serialize)] #[derive(Debug, Deserialize, Serialize)]
struct RequestQuery { struct RequestQuery {
#fields #fields
} }

View File

@ -22,6 +22,10 @@ impl Response {
self.fields.iter().any(|field| field.is_header()) self.fields.iter().any(|field| field.is_header())
} }
pub fn has_body(&self) -> bool {
self.fields.iter().any(|field| !field.is_header())
}
pub fn init_fields(&self) -> TokenStream { pub fn init_fields(&self) -> TokenStream {
let mut tokens = TokenStream::new(); let mut tokens = TokenStream::new();
@ -62,6 +66,47 @@ impl Response {
tokens tokens
} }
pub fn apply_header_fields(&self) -> TokenStream {
let mut tokens = TokenStream::new();
for response_field in self.fields.iter() {
if let ResponseField::Header(ref field, ref header) = *response_field {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
let header_name = Ident::new(header.as_ref(), Span::call_site());
let span = field.span();
tokens.append_all(quote_spanned! {span=>
.header(::http::header::#header_name, response.#field_name)
});
}
}
tokens
}
pub fn to_body(&self) -> TokenStream {
if let Some(ref field) = self.newtype_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
let span = field.span();
quote_spanned!(span=> response.#field_name)
} else {
let fields = self.fields.iter().filter_map(|response_field| {
if let ResponseField::Body(ref field) = *response_field {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
let span = field.span();
Some(quote_spanned! {span=>
#field_name: response.#field_name
})
} else {
None
}
});
quote!(ResponseBody{#(#fields),*})
}
}
pub fn newtype_body_field(&self) -> Option<&Field> { pub fn newtype_body_field(&self) -> Option<&Field> {
for response_field in self.fields.iter() { for response_field in self.fields.iter() {
match *response_field { match *response_field {
@ -194,7 +239,7 @@ impl ToTokens for Response {
response_body_struct = quote_spanned! {span=> response_body_struct = quote_spanned! {span=>
/// Data in the response body. /// Data in the response body.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, Serialize)]
struct ResponseBody(#ty); struct ResponseBody(#ty);
}; };
} else if self.has_body_fields() { } else if self.has_body_fields() {
@ -213,7 +258,7 @@ impl ToTokens for Response {
response_body_struct = quote! { response_body_struct = quote! {
/// Data in the response body. /// Data in the response body.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, Serialize)]
struct ResponseBody { struct ResponseBody {
#fields #fields
} }

View File

@ -51,7 +51,7 @@ mod api;
/// This will generate a `ruma_api::Metadata` value to be used for the `ruma_api::Endpoint`'s /// This will generate a `ruma_api::Metadata` value to be used for the `ruma_api::Endpoint`'s
/// associated constant, single `Request` and `Response` structs, and the necessary trait /// associated constant, single `Request` and `Response` structs, and the necessary trait
/// implementations to convert the request into a `http::Request` and to create a response from a /// implementations to convert the request into a `http::Request` and to create a response from a
/// `http::Response`. /// `http::Response` and vice versa.
/// ///
/// The details of each of the three sections of the macros are documented below. /// The details of each of the three sections of the macros are documented below.
/// ///
@ -173,7 +173,7 @@ mod api;
/// pub mod newtype_body_endpoint { /// pub mod newtype_body_endpoint {
/// use ruma_api_macros::ruma_api; /// use ruma_api_macros::ruma_api;
/// ///
/// #[derive(Debug, Deserialize)] /// #[derive(Debug, Deserialize, Serialize)]
/// pub struct MyCustomType { /// pub struct MyCustomType {
/// pub foo: String, /// pub foo: String,
/// } /// }