use proc_macro2::{Span, TokenStream}; use quote::{ToTokens, TokenStreamExt}; use syn::{braced, Field, FieldValue, Ident, Meta, Token}; use syn::parse::{Parse, ParseStream, Result}; mod metadata; mod request; mod response; use self::metadata::Metadata; use self::request::Request; use self::response::Response; pub fn strip_serde_attrs(field: &Field) -> Field { let mut field = field.clone(); field.attrs = field.attrs.into_iter().filter(|attr| { let meta = attr.interpret_meta() .expect("ruma_api! could not parse field attributes"); let meta_list = match meta { Meta::List(meta_list) => meta_list, _ => return true, }; if &meta_list.ident.to_string() == "serde" { return false; } true }).collect(); field } pub struct Api { metadata: Metadata, request: Request, response: Response, } impl From for Api { fn from(raw_api: RawApi) -> Self { Api { metadata: raw_api.metadata.into(), request: raw_api.request.into(), response: raw_api.response.into(), } } } impl ToTokens for Api { fn to_tokens(&self, tokens: &mut TokenStream) { let description = &self.metadata.description; let method = Ident::new(self.metadata.method.as_ref(), Span::call_site()); let name = &self.metadata.name; let path = &self.metadata.path; let rate_limited = &self.metadata.rate_limited; let requires_authentication = &self.metadata.requires_authentication; let request = &self.request; let request_types = quote! { #request }; let response = &self.response; let response_types = quote! { #response }; let extract_request_path = if self.request.has_path_fields() { quote! { let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); } } else { TokenStream::new() }; let (set_request_path, parse_request_path) = if self.request.has_path_fields() { let path_str = path.as_str(); assert!(path_str.starts_with('/'), "path needs to start with '/'"); assert!( path_str.chars().filter(|c| *c == ':').count() == self.request.path_field_count(), "number of declared path parameters needs to match amount of placeholders in path" ); let request_path_init_fields = self.request.request_path_init_fields(); let mut set_tokens = quote! { let request_path = RequestPath { #request_path_init_fields }; // This `unwrap()` can only fail when the url is a // cannot-be-base url like `mailto:` or `data:`, which is not // the case for our placeholder url. let mut path_segments = url.path_segments_mut().unwrap(); }; let mut parse_tokens = TokenStream::new(); for (i, segment) in path_str[1..].split('/').into_iter().enumerate() { set_tokens.append_all(quote! { path_segments.push }); if segment.starts_with(':') { let path_var = &segment[1..]; let path_var_ident = Ident::new(path_var, Span::call_site()); set_tokens.append_all(quote! { (&request_path.#path_var_ident.to_string()); }); let path_field = self.request.path_field(path_var) .expect("expected request to have path field"); let ty = &path_field.ty; parse_tokens.append_all(quote! { #path_var_ident: { let segment = path_segments.get(#i).unwrap().as_bytes(); let decoded = ::url::percent_encoding::percent_decode(segment) .decode_utf8_lossy(); #ty::deserialize(decoded.into_deserializer()) .map_err(|e: ::serde_json::error::Error| e)? }, }); } else { set_tokens.append_all(quote! { (#segment); }); } } (set_tokens, parse_tokens) } else { let set_tokens = quote! { url.set_path(metadata.path); }; let parse_tokens = TokenStream::new(); (set_tokens, parse_tokens) }; let set_request_query = if self.request.has_query_fields() { let request_query_init_fields = self.request.request_query_init_fields(); quote! { let request_query = RequestQuery { #request_query_init_fields }; url.set_query(Some(&::serde_urlencoded::to_string(request_query)?)); } } else { TokenStream::new() }; let extract_request_query = if self.request.has_query_fields() { quote! { let request_query: RequestQuery = ::serde_urlencoded::from_str(&request.uri().query().unwrap_or(""))?; } } else { TokenStream::new() }; let parse_request_query = if self.request.has_query_fields() { self.request.request_init_query_fields() } else { TokenStream::new() }; let add_headers_to_request = if self.request.has_header_fields() { let mut header_tokens = quote! { let headers = http_request.headers_mut(); }; header_tokens.append_all(self.request.add_headers_to_request()); header_tokens } else { TokenStream::new() }; let extract_request_headers = if self.request.has_header_fields() { quote! { let headers = request.headers(); } } else { TokenStream::new() }; let parse_request_headers = if self.request.has_header_fields() { self.request.parse_headers_from_request() } else { TokenStream::new() }; let create_http_request = if let Some(field) = self.request.newtype_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); quote! { let request_body = RequestBody(request.#field_name); let mut http_request = ::http::Request::new(::serde_json::to_vec(&request_body)?.into()); } } else if self.request.has_body_fields() { let request_body_init_fields = self.request.request_body_init_fields(); quote! { let request_body = RequestBody { #request_body_init_fields }; let mut http_request = ::http::Request::new(::serde_json::to_vec(&request_body)?.into()); } } else { quote! { let mut http_request = ::http::Request::new(::hyper::Body::empty()); } }; let extract_request_body = if let Some(field) = self.request.newtype_body_field() { let ty = &field.ty; quote! { let request_body: #ty = ::serde_json::from_slice(request.body().as_slice())?; } } else if self.request.has_body_fields() { quote! { let request_body: RequestBody = ::serde_json::from_slice(request.body().as_slice())?; } } else { TokenStream::new() }; let parse_request_body = if let Some(field) = self.request.newtype_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); quote! { #field_name: request_body, } } else if self.request.has_body_fields() { self.request.request_init_body_fields() } else { TokenStream::new() }; let deserialize_response_body = if let Some(field) = self.response.newtype_body_field() { let field_type = &field.ty; quote! { let future_response = http_response.into_body() .fold(Vec::new(), |mut vec, chunk| { vec.extend(chunk.iter()); ::futures::future::ok::<_, ::hyper::Error>(vec) }) .map_err(::ruma_api::Error::from) .and_then(|data| ::serde_json::from_slice::<#field_type>(data.as_slice()) .map_err(::ruma_api::Error::from) .into_future() ) } } else if self.response.has_body_fields() { quote! { let future_response = http_response.into_body() .fold(Vec::new(), |mut vec, chunk| { vec.extend(chunk.iter()); ::futures::future::ok::<_, ::hyper::Error>(vec) }) .map_err(::ruma_api::Error::from) .and_then(|data| ::serde_json::from_slice::(data.as_slice()) .map_err(::ruma_api::Error::from) .into_future() ) } } else { quote! { let future_response = ::futures::future::ok(()) } }; let extract_response_headers = if self.response.has_header_fields() { quote! { let mut headers = http_response.headers().clone(); } } else { TokenStream::new() }; let response_init_fields = if self.response.has_fields() { self.response.init_fields() } else { TokenStream::new() }; let serialize_response_headers = self.response.apply_header_fields(); let serialize_response_body = if self.response.has_body() { let body = self.response.to_body(); quote! { .body(::hyper::Body::from(::serde_json::to_vec(&#body)?)) } } else { quote! { .body(::hyper::Body::from("{}".as_bytes().to_vec())) } }; tokens.append_all(quote! { #[allow(unused_imports)] use ::futures::{Future as _Future, IntoFuture as _IntoFuture, Stream as _Stream}; use ::ruma_api::Endpoint as _RumaApiEndpoint; use ::serde::Deserialize; use ::serde::de::{Error as _SerdeError, IntoDeserializer}; use ::std::convert::{TryInto as _TryInto}; /// The API endpoint. #[derive(Debug)] pub struct Endpoint; #request_types impl ::std::convert::TryFrom<::http::Request>> for Request { type Error = ::ruma_api::Error; #[allow(unused_variables)] fn try_from(request: ::http::Request>) -> Result { #extract_request_path #extract_request_query #extract_request_headers #extract_request_body Ok(Request { #parse_request_path #parse_request_query #parse_request_headers #parse_request_body }) } } impl ::futures::future::FutureFrom<::http::Request<::hyper::Body>> for Request { type Future = Box<_Future + Send>; type Error = ::ruma_api::Error; #[allow(unused_variables)] fn future_from(request: ::http::Request<::hyper::Body>) -> Self::Future { let (parts, body) = request.into_parts(); let future = body.from_err().fold(Vec::new(), |mut vec, chunk| { vec.extend(chunk.iter()); ::futures::future::ok::<_, Self::Error>(vec) }).and_then(|body| { ::http::Request::from_parts(parts, body) .try_into() .into_future() .from_err() }); Box::new(future) } } impl ::std::convert::TryFrom for ::http::Request<::hyper::Body> { type Error = ::ruma_api::Error; #[allow(unused_mut, unused_variables)] fn try_from(request: Request) -> Result { let metadata = Endpoint::METADATA; // Use dummy homeserver url which has to be overwritten in // the calling code. Previously (with http::Uri) this was // not required, but Url::parse only accepts absolute urls. let mut url = ::url::Url::parse("http://invalid-host-please-change/").unwrap(); { #set_request_path } { #set_request_query } #create_http_request *http_request.method_mut() = ::http::Method::#method; *http_request.uri_mut() = url.into_string().parse().unwrap(); { #add_headers_to_request } Ok(http_request) } } #response_types impl ::std::convert::TryFrom for ::http::Response<::hyper::Body> { type Error = ::ruma_api::Error; #[allow(unused_variables)] fn try_from(response: Response) -> Result { let response = ::http::Response::builder() .header(::http::header::CONTENT_TYPE, "application/json") #serialize_response_headers #serialize_response_body .unwrap(); Ok(response) } } impl ::futures::future::FutureFrom<::http::Response<::hyper::Body>> for Response { type Future = Box<_Future + Send>; type Error = ::ruma_api::Error; #[allow(unused_variables)] fn future_from(http_response: ::http::Response<::hyper::Body>) -> Self::Future { if http_response.status().is_success() { #extract_response_headers #deserialize_response_body .and_then(move |response_body| { let response = Response { #response_init_fields }; Ok(response) }); Box::new(future_response) } else { Box::new(::futures::future::err(::ruma_api::Error::StatusCode(http_response.status().clone()))) } } } impl ::ruma_api::Endpoint for Endpoint { type Request = Request; type Response = Response; const METADATA: ::ruma_api::Metadata = ::ruma_api::Metadata { description: #description, method: ::http::Method::#method, name: #name, path: #path, rate_limited: #rate_limited, requires_authentication: #requires_authentication, }; } }); } } mod kw { use syn::custom_keyword; custom_keyword!(metadata); custom_keyword!(request); custom_keyword!(response); } pub struct RawApi { pub metadata: Vec, pub request: Vec, pub response: Vec, } impl Parse for RawApi { fn parse(input: ParseStream) -> Result { input.parse::()?; let metadata; braced!(metadata in input); input.parse::()?; let request; braced!(request in input); input.parse::()?; let response; braced!(response in input); Ok(RawApi { metadata: metadata .parse_terminated::(FieldValue::parse)? .into_iter() .collect(), request: request .parse_terminated::(Field::parse_named)? .into_iter() .collect(), response: response .parse_terminated::(Field::parse_named)? .into_iter() .collect(), }) } }