From e383ae98eae8a3477110b94d718fca8e1583d96f Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Fri, 15 Nov 2019 20:42:01 +0100 Subject: [PATCH] Revert "Remove server-side functionality" This reverts commit 958a0a01c47c051eebf493234c314bc101609f63. --- ruma-api-macros/src/api.rs | 155 +++++++++++++++++++++++++++- ruma-api-macros/src/api/request.rs | 54 +++++++++- ruma-api-macros/src/api/response.rs | 66 +++++++++++- src/lib.rs | 45 +++++++- 4 files changed, 306 insertions(+), 14 deletions(-) diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index e854d332..9fbe21aa 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -90,7 +90,15 @@ impl ToTokens for Api { let response = &self.response; let response_types = quote! { #response }; - let set_request_path = if self.request.has_path_fields() { + let extract_request_path = if self.request.has_path_fields() { + quote! { + let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); + } + } else { + TokenStream::new() + }; + + let (set_request_path, parse_request_path) = if self.request.has_path_fields() { let path_str = path.value(); assert!(path_str.starts_with('/'), "path needs to start with '/'"); @@ -116,7 +124,7 @@ impl ToTokens for Api { } }); - quote! { + let set_tokens = quote! { let request_path = RequestPath { #request_path_init_fields }; @@ -126,11 +134,43 @@ impl ToTokens for Api { // the case for our placeholder url. let mut path_segments = url.path_segments_mut().unwrap(); #(#path_segment_push)* - } + }; + + let path_fields = path_segments + .enumerate() + .filter(|(_, s)| s.starts_with(':')) + .map(|(i, segment)| { + let path_var = &segment[1..]; + let path_var_ident = Ident::new(path_var, Span::call_site()); + let path_field = self + .request + .path_field(path_var) + .expect("expected request to have path field"); + let ty = &path_field.ty; + + quote! { + #path_var_ident: { + let segment = path_segments.get(#i).unwrap().as_bytes(); + let decoded = + ruma_api::exports::percent_encoding::percent_decode(segment) + .decode_utf8_lossy(); + #ty::deserialize(decoded.into_deserializer()) + .map_err(|e: ruma_api::exports::serde_json::error::Error| e)? + } + } + }); + + let parse_tokens = quote! { + #(#path_fields,)* + }; + + (set_tokens, parse_tokens) } else { - quote! { + let set_tokens = quote! { url.set_path(metadata.path); - } + }; + let parse_tokens = TokenStream::new(); + (set_tokens, parse_tokens) }; let set_request_query = if let Some(field) = self.request.query_map_field() { @@ -183,6 +223,21 @@ impl ToTokens for Api { TokenStream::new() }; + let extract_request_query = if self.request.has_query_fields() { + quote! { + let request_query: RequestQuery = + ruma_api::exports::serde_urlencoded::from_str(&request.uri().query().unwrap_or(""))?; + } + } else { + TokenStream::new() + }; + + let parse_request_query = if self.request.has_query_fields() { + self.request.request_init_query_fields() + } else { + TokenStream::new() + }; + let add_headers_to_request = if self.request.has_header_fields() { let add_headers = self.request.add_headers_to_request(); quote! { @@ -193,6 +248,20 @@ impl ToTokens for Api { TokenStream::new() }; + let extract_request_headers = if self.request.has_header_fields() { + quote! { + let headers = request.headers(); + } + } else { + TokenStream::new() + }; + + let parse_request_headers = if self.request.has_header_fields() { + self.request.parse_headers_from_request() + } else { + TokenStream::new() + }; + let create_http_request = if let Some(field) = self.request.newtype_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); @@ -221,6 +290,36 @@ impl ToTokens for Api { } }; + let extract_request_body = if let Some(field) = self.request.newtype_body_field() { + let ty = &field.ty; + quote! { + let request_body: #ty = + ruma_api::exports::serde_json::from_slice(request.body().as_slice())?; + } + } else if self.request.has_body_fields() { + quote! { + let request_body: RequestBody = + ruma_api::exports::serde_json::from_slice(request.body().as_slice())?; + } + } else { + TokenStream::new() + }; + + let parse_request_body = if let Some(field) = self.request.newtype_body_field() { + let field_name = field + .ident + .as_ref() + .expect("expected field to have an identifier"); + + quote! { + #field_name: request_body, + } + } else if self.request.has_body_fields() { + self.request.request_init_body_fields() + } else { + TokenStream::new() + }; + let try_deserialize_response_body = if let Some(field) = self.response.newtype_body_field() { let field_type = &field.ty; @@ -256,6 +355,19 @@ impl ToTokens for Api { TokenStream::new() }; + let serialize_response_headers = self.response.apply_header_fields(); + + let try_serialize_response_body = if self.response.has_body() { + let body = self.response.to_body(); + quote! { + ruma_api::exports::serde_json::to_vec(&#body)? + } + } else { + quote! { + "{}".as_bytes().to_vec() + } + }; + let request_doc = format!( "Data for a request to the `{}` API endpoint.\n\n{}", name, @@ -273,6 +385,25 @@ impl ToTokens for Api { #[doc = #request_doc] #request_types + impl std::convert::TryFrom>> for Request { + type Error = ruma_api::Error; + + #[allow(unused_variables)] + fn try_from(request: ruma_api::exports::http::Request>) -> Result { + #extract_request_path + #extract_request_query + #extract_request_headers + #extract_request_body + + Ok(Request { + #parse_request_path + #parse_request_query + #parse_request_headers + #parse_request_body + }) + } + } + impl std::convert::TryFrom for ruma_api::exports::http::Request> { type Error = ruma_api::Error; @@ -304,6 +435,20 @@ impl ToTokens for Api { #[doc = #response_doc] #response_types + impl std::convert::TryFrom for ruma_api::exports::http::Response> { + type Error = ruma_api::Error; + + #[allow(unused_variables)] + fn try_from(response: Response) -> Result { + let response = ruma_api::exports::http::Response::builder() + .header(ruma_api::exports::http::header::CONTENT_TYPE, "application/json") + #serialize_response_headers + .body(#try_serialize_response_body) + .unwrap(); + Ok(response) + } + } + impl std::convert::TryFrom>> for Response { type Error = ruma_api::Error; diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index 3a741289..1e8ff1d9 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -42,6 +42,30 @@ impl Request { } } + /// Produces code to extract fields from the HTTP headers in an `http::Request`. + pub fn parse_headers_from_request(&self) -> TokenStream { + let fields = self.header_fields().map(|request_field| { + let (field, header_name) = match request_field { + RequestField::Header(field, header_name) => (field, header_name), + _ => panic!("expected request field to be header variant"), + }; + + let field_name = &field.ident; + let header_name_string = header_name.to_string(); + + quote! { + #field_name: headers.get(ruma_api::exports::http::header::#header_name) + .and_then(|v| v.to_str().ok()) + .ok_or(ruma_api::exports::serde_json::Error::missing_field(#header_name_string))? + .to_owned() + } + }); + + quote! { + #(#fields,)* + } + } + /// Whether or not this request has any data in the HTTP body. pub fn has_body_fields(&self) -> bool { self.fields.iter().any(|field| field.is_body()) @@ -51,7 +75,6 @@ impl Request { pub fn has_header_fields(&self) -> bool { self.fields.iter().any(|field| field.is_header()) } - /// Whether or not this request has any data in the URL path. pub fn has_path_fields(&self) -> bool { self.fields.iter().any(|field| field.is_path()) @@ -77,6 +100,20 @@ impl Request { self.fields.iter().filter(|field| field.is_path()).count() } + /// Gets the path field with the given name. + pub fn path_field(&self, name: &str) -> Option<&Field> { + self.fields + .iter() + .flat_map(|f| f.field_of_kind(RequestFieldKind::Path)) + .find(|field| { + field + .ident + .as_ref() + .expect("expected field to have an identifier") + == name + }) + } + /// Returns the body field. pub fn newtype_body_field(&self) -> Option<&Field> { self.fields.iter().find_map(RequestField::as_newtype_body_field) @@ -102,6 +139,17 @@ impl Request { self.struct_init_fields(RequestFieldKind::Query, quote!(request)) } + /// Produces code for a struct initializer for body fields on a variable named `request_body`. + pub fn request_init_body_fields(&self) -> TokenStream { + self.struct_init_fields(RequestFieldKind::Body, quote!(request_body)) + } + + /// Produces code for a struct initializer for query string fields on a variable named + /// `request_query`. + pub fn request_init_query_fields(&self) -> TokenStream { + self.struct_init_fields(RequestFieldKind::Query, quote!(request_query)) + } + /// Produces code for a struct initializer for the given field kind to be accessed through the /// given variable name. fn struct_init_fields( @@ -270,7 +318,7 @@ impl ToTokens for Request { quote_spanned! {span=> /// Data in the request body. - #[derive(Debug, ruma_api::exports::serde::Serialize)] + #[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)] struct RequestBody(#ty); } } else if self.has_body_fields() { @@ -278,7 +326,7 @@ impl ToTokens for Request { quote! { /// Data in the request body. - #[derive(Debug, ruma_api::exports::serde::Serialize)] + #[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)] struct RequestBody { #(#fields),* } diff --git a/ruma-api-macros/src/api/response.rs b/ruma-api-macros/src/api/response.rs index d1881297..c413d6c2 100644 --- a/ruma-api-macros/src/api/response.rs +++ b/ruma-api-macros/src/api/response.rs @@ -33,6 +33,11 @@ impl Response { self.fields.iter().any(|field| field.is_header()) } + /// Whether or not this response has any data in the HTTP body. + pub fn has_body(&self) -> bool { + self.fields.iter().any(|field| !field.is_header()) + } + /// Produces code for a response struct initializer. pub fn init_fields(&self) -> TokenStream { let fields = self.fields.iter().map(|response_field| match response_field { @@ -74,6 +79,63 @@ impl Response { } } + /// Produces code to add necessary HTTP headers to an `http::Response`. + pub fn apply_header_fields(&self) -> TokenStream { + let header_calls = self.fields.iter().filter_map(|response_field| { + if let ResponseField::Header(ref field, ref header_name) = *response_field { + let field_name = field + .ident + .as_ref() + .expect("expected field to have an identifier"); + let span = field.span(); + + Some(quote_spanned! {span=> + .header(ruma_api::exports::http::header::#header_name, response.#field_name) + }) + } else { + None + } + }); + + quote! { + #(#header_calls)* + } + } + + /// Produces code to initialize the struct that will be used to create the response body. + pub fn to_body(&self) -> TokenStream { + if let Some(field) = self.newtype_body_field() { + let field_name = field + .ident + .as_ref() + .expect("expected field to have an identifier"); + let span = field.span(); + quote_spanned!(span=> response.#field_name) + } else { + let fields = self.fields.iter().filter_map(|response_field| { + if let ResponseField::Body(ref field) = *response_field { + let field_name = field + .ident + .as_ref() + .expect("expected field to have an identifier"); + let span = field.span(); + + Some(quote_spanned! {span=> + #field_name: response.#field_name + }) + } else { + None + } + }); + + quote! { + ResponseBody { + #(#fields),* + } + } + } + } + /// Gets the newtype body field, if this response has one. pub fn newtype_body_field(&self) -> Option<&Field> { self.fields.iter().find_map(ResponseField::as_newtype_body_field) @@ -195,7 +257,7 @@ impl ToTokens for Response { quote_spanned! {span=> /// Data in the response body. - #[derive(Debug, ruma_api::exports::serde::Deserialize)] + #[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)] struct ResponseBody(#ty); } } else if self.has_body_fields() { @@ -203,7 +265,7 @@ impl ToTokens for Response { quote! { /// Data in the response body. - #[derive(Debug, ruma_api::exports::serde::Deserialize)] + #[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)] struct ResponseBody { #(#fields),* } diff --git a/src/lib.rs b/src/lib.rs index 3f6867e1..b43ac198 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -206,6 +206,7 @@ pub use ruma_api_macros::ruma_api; /// It is not considered part of ruma-api's public API. pub mod exports { pub use http; + pub use percent_encoding; pub use serde; pub use serde_json; pub use serde_urlencoded; @@ -215,9 +216,12 @@ pub mod exports { /// A Matrix API endpoint. /// /// The type implementing this trait contains any data needed to make a request to the endpoint. -pub trait Endpoint: TryInto>, Error = Error> { +pub trait Endpoint: + TryFrom>, Error = Error> + TryInto>, Error = Error> +{ /// Data returned in a successful response from the endpoint. - type Response: TryFrom>, Error = Error>; + type Response: TryFrom>, Error = Error> + + TryInto>, Error = Error>; /// Metadata about the endpoint. const METADATA: Metadata; @@ -346,9 +350,10 @@ mod tests { pub mod create { use std::convert::TryFrom; - use http::{self, method::Method}; + use http::{self, header::CONTENT_TYPE, method::Method}; + use percent_encoding; use ruma_identifiers::{RoomAliasId, RoomId}; - use serde::{Deserialize, Serialize}; + use serde::{de::IntoDeserializer, Deserialize, Serialize}; use serde_json; use crate::{Endpoint, Error, Metadata}; @@ -395,6 +400,25 @@ mod tests { } } + impl TryFrom>> for Request { + type Error = Error; + + fn try_from(request: http::Request>) -> Result { + let request_body: RequestBody = + ::serde_json::from_slice(request.body().as_slice())?; + let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); + Ok(Request { + room_id: request_body.room_id, + room_alias: { + let segment = path_segments.get(5).unwrap().as_bytes(); + let decoded = percent_encoding::percent_decode(segment).decode_utf8_lossy(); + RoomAliasId::deserialize(decoded.into_deserializer()) + .map_err(|e: serde_json::error::Error| e)? + }, + }) + } + } + #[derive(Debug, Serialize, Deserialize)] struct RequestBody { room_id: RoomId, @@ -415,5 +439,18 @@ mod tests { } } } + + impl TryFrom for http::Response> { + type Error = Error; + + fn try_from(_: Response) -> Result>, Self::Error> { + let response = http::Response::builder() + .header(CONTENT_TYPE, "application/json") + .body(b"{}".to_vec()) + .unwrap(); + + Ok(response) + } + } } }