From f558b5569260ece47d14f8ffc367dbfdfc3b6077 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Tue, 19 Nov 2019 21:11:17 +0100 Subject: [PATCH] Add SendRecv trait + derive macro to allow receiving requests, sending responses --- Cargo.toml | 3 + ruma-api-macros/Cargo.toml | 2 +- ruma-api-macros/src/api.rs | 89 ++++---- ruma-api-macros/src/api/request.rs | 98 ++++++--- ruma-api-macros/src/api/response.rs | 102 ++++++--- ruma-api-macros/src/lib.rs | 15 +- ruma-api-macros/src/send_recv.rs | 195 ++++++++++++++++++ .../src/send_recv/wrap_incoming.rs | 58 ++++++ src/lib.rs | 31 ++- tests/ruma_api_macros.rs | 5 + 10 files changed, 487 insertions(+), 111 deletions(-) create mode 100644 ruma-api-macros/src/send_recv.rs create mode 100644 ruma-api-macros/src/send_recv/wrap_incoming.rs diff --git a/Cargo.toml b/Cargo.toml index 7c656ed8..b77b7745 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,9 @@ serde_json = "1.0.41" serde_urlencoded = "0.6.1" url = { version = "2.1.0", optional = true } +[dev-dependencies] +ruma-events = "0.15.1" + [features] default = ["with-ruma-api-macros"] with-ruma-api-macros = ["percent-encoding", "ruma-api-macros", "serde", "url"] diff --git a/ruma-api-macros/Cargo.toml b/ruma-api-macros/Cargo.toml index 88b85b2d..3cfddf13 100644 --- a/ruma-api-macros/Cargo.toml +++ b/ruma-api-macros/Cargo.toml @@ -15,7 +15,7 @@ edition = "2018" [dependencies] proc-macro2 = "1.0.6" quote = "1.0.2" -syn = { version = "1.0.8", features = ["full"] } +syn = { version = "1.0.8", features = ["full", "extra-traits"] } [lib] proc-macro = true diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index 9fbe21aa..326576dd 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -85,10 +85,20 @@ impl ToTokens for Api { let rate_limited = &self.metadata.rate_limited; let requires_authentication = &self.metadata.requires_authentication; - let request = &self.request; - let request_types = quote! { #request }; - let response = &self.response; - let response_types = quote! { #response }; + let request_type = &self.request; + let response_type = &self.response; + + let request_try_from_type = if self.request.uses_wrap_incoming() { + quote!(IncomingRequest) + } else { + quote!(Request) + }; + + let response_try_from_type = if self.response.uses_wrap_incoming() { + quote!(IncomingResponse) + } else { + quote!(Response) + }; let extract_request_path = if self.request.has_path_fields() { quote! { @@ -110,7 +120,7 @@ impl ToTokens for Api { let request_path_init_fields = self.request.request_path_init_fields(); let path_segments = path_str[1..].split('/'); - let path_segment_push = path_segments.map(|segment| { + let path_segment_push = path_segments.clone().map(|segment| { let arg = if segment.starts_with(':') { let path_var = &segment[1..]; let path_var_ident = Ident::new(path_var, Span::call_site()); @@ -136,10 +146,8 @@ impl ToTokens for Api { #(#path_segment_push)* }; - let path_fields = path_segments - .enumerate() - .filter(|(_, s)| s.starts_with(':')) - .map(|(i, segment)| { + let path_fields = path_segments.enumerate().filter(|(_, s)| s.starts_with(':')).map( + |(i, segment)| { let path_var = &segment[1..]; let path_var_ident = Ident::new(path_var, Span::call_site()); let path_field = self @@ -158,7 +166,8 @@ impl ToTokens for Api { .map_err(|e: ruma_api::exports::serde_json::error::Error| e)? } } - }); + }, + ); let parse_tokens = quote! { #(#path_fields,)* @@ -223,7 +232,12 @@ impl ToTokens for Api { TokenStream::new() }; - let extract_request_query = if self.request.has_query_fields() { + let extract_request_query = if self.request.query_map_field().is_some() { + quote! { + let request_query = + ruma_api::exports::serde_urlencoded::from_str(&request.uri().query().unwrap_or(""))?; + } + } else if self.request.has_query_fields() { quote! { let request_query: RequestQuery = ruma_api::exports::serde_urlencoded::from_str(&request.uri().query().unwrap_or(""))?; @@ -232,7 +246,13 @@ impl ToTokens for Api { TokenStream::new() }; - let parse_request_query = if self.request.has_query_fields() { + let parse_request_query = if let Some(field) = self.request.query_map_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + + quote! { + #field_name: request_query + } + } else if self.request.has_query_fields() { self.request.request_init_query_fields() } else { TokenStream::new() @@ -290,15 +310,14 @@ impl ToTokens for Api { } }; - let extract_request_body = if let Some(field) = self.request.newtype_body_field() { - let ty = &field.ty; + let extract_request_body = if self.request.newtype_body_field().is_some() { quote! { - let request_body: #ty = + let request_body = ruma_api::exports::serde_json::from_slice(request.body().as_slice())?; } } else if self.request.has_body_fields() { quote! { - let request_body: RequestBody = + let request_body: ::Incoming = ruma_api::exports::serde_json::from_slice(request.body().as_slice())?; } } else { @@ -306,10 +325,7 @@ impl ToTokens for Api { }; let parse_request_body = if let Some(field) = self.request.newtype_body_field() { - let field_name = field - .ident - .as_ref() - .expect("expected field to have an identifier"); + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); quote! { #field_name: request_body, @@ -320,18 +336,15 @@ impl ToTokens for Api { TokenStream::new() }; - let try_deserialize_response_body = if let Some(field) = self.response.newtype_body_field() - { - let field_type = &field.ty; + let response_body_type_annotation = if self.response.has_body_fields() { + quote!(: ::Incoming) + } else { + TokenStream::new() + }; + let try_deserialize_response_body = if self.response.has_body() { quote! { - ruma_api::exports::serde_json::from_slice::<#field_type>( - http_response.into_body().as_slice(), - )? - } - } else if self.response.has_body_fields() { - quote! { - ruma_api::exports::serde_json::from_slice::( + ruma_api::exports::serde_json::from_slice( http_response.into_body().as_slice(), )? } @@ -383,9 +396,9 @@ impl ToTokens for Api { use std::convert::TryInto as _; #[doc = #request_doc] - #request_types + #request_type - impl std::convert::TryFrom>> for Request { + impl std::convert::TryFrom>> for #request_try_from_type { type Error = ruma_api::Error; #[allow(unused_variables)] @@ -395,7 +408,7 @@ impl ToTokens for Api { #extract_request_headers #extract_request_body - Ok(Request { + Ok(Self { #parse_request_path #parse_request_query #parse_request_headers @@ -433,7 +446,7 @@ impl ToTokens for Api { } #[doc = #response_doc] - #response_types + #response_type impl std::convert::TryFrom for ruma_api::exports::http::Response> { type Error = ruma_api::Error; @@ -449,7 +462,7 @@ impl ToTokens for Api { } } - impl std::convert::TryFrom>> for Response { + impl std::convert::TryFrom>> for #response_try_from_type { type Error = ruma_api::Error; #[allow(unused_variables)] @@ -459,8 +472,10 @@ impl ToTokens for Api { if http_response.status().is_success() { #extract_response_headers - let response_body = #try_deserialize_response_body; - Ok(Response { + let response_body #response_body_type_annotation = + #try_deserialize_response_body; + + Ok(Self { #response_init_fields }) } else { diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index 1e8ff1d9..e6261780 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -90,6 +90,11 @@ impl Request { self.fields.iter().filter_map(|field| field.as_body_field()) } + /// Whether any field has a #[wrap_incoming] attribute. + pub fn uses_wrap_incoming(&self) -> bool { + self.fields.iter().any(|f| f.has_wrap_incoming_attr()) + } + /// Produces an iterator over all the header fields. pub fn header_fields(&self) -> impl Iterator { self.fields.iter().filter(|field| field.is_header()) @@ -102,16 +107,9 @@ impl Request { /// Gets the path field with the given name. pub fn path_field(&self, name: &str) -> Option<&Field> { - self.fields - .iter() - .flat_map(|f| f.field_of_kind(RequestFieldKind::Path)) - .find(|field| { - field - .ident - .as_ref() - .expect("expected field to have an identifier") - == name - }) + self.fields.iter().flat_map(|f| f.field_of_kind(RequestFieldKind::Path)).find(|field| { + field.ident.as_ref().expect("expected field to have an identifier") == name + }) } /// Returns the body field. @@ -273,8 +271,8 @@ impl TryFrom for Request { .collect::>>()?; if newtype_body_field.is_some() && fields.iter().any(|f| f.is_body()) { + // TODO: highlight conflicting fields, return Err(syn::Error::new_spanned( - // TODO: raw, raw.request_kw, "Can't have both a newtype body field and regular body fields", )); @@ -295,7 +293,8 @@ impl TryFrom for Request { impl ToTokens for Request { fn to_tokens(&self, tokens: &mut TokenStream) { let request_struct_header = quote! { - #[derive(Debug, Clone)] + #[derive(Debug, Clone, ruma_api::SendRecv)] + #[incoming_no_deserialize] pub struct Request }; @@ -312,28 +311,51 @@ impl ToTokens for Request { } }; - let request_body_struct = if let Some(field) = self.newtype_body_field() { - let ty = &field.ty; - let span = field.span(); + let request_body_struct = + if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) { + let field = body_field.field(); + let ty = &field.ty; + let span = field.span(); + let derive_deserialize = if body_field.has_wrap_incoming_attr() { + TokenStream::new() + } else { + quote!(ruma_api::exports::serde::Deserialize) + }; - quote_spanned! {span=> - /// Data in the request body. - #[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)] - struct RequestBody(#ty); - } - } else if self.has_body_fields() { - let fields = self.fields.iter().filter_map(RequestField::as_body_field); - - quote! { - /// Data in the request body. - #[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)] - struct RequestBody { - #(#fields),* + quote_spanned! {span=> + /// Data in the request body. + #[derive( + Debug, + ruma_api::SendRecv, + ruma_api::exports::serde::Serialize, + #derive_deserialize + )] + struct RequestBody(#ty); } - } - } else { - TokenStream::new() - }; + } else if self.has_body_fields() { + let fields = self.fields.iter().filter(|f| f.is_body()); + let derive_deserialize = if fields.clone().any(|f| f.has_wrap_incoming_attr()) { + TokenStream::new() + } else { + quote!(ruma_api::exports::serde::Deserialize) + }; + let fields = fields.map(RequestField::field); + + quote! { + /// Data in the request body. + #[derive( + Debug, + ruma_api::SendRecv, + ruma_api::exports::serde::Serialize, + #derive_deserialize + )] + struct RequestBody { + #(#fields),* + } + } + } else { + TokenStream::new() + }; let request_path_struct = if self.has_path_fields() { let fields = self.fields.iter().filter_map(RequestField::as_path_field); @@ -449,6 +471,11 @@ impl RequestField { self.kind() == RequestFieldKind::Header } + /// Whether or not this request field is a newtype body kind. + fn is_newtype_body(&self) -> bool { + self.kind() == RequestFieldKind::NewtypeBody + } + /// Whether or not this request field is a path kind. fn is_path(&self) -> bool { self.kind() == RequestFieldKind::Path @@ -504,6 +531,13 @@ impl RequestField { None } } + + /// Whether or not the request field has a #[wrap_incoming] attribute. + fn has_wrap_incoming_attr(&self) -> bool { + self.field().attrs.iter().any(|attr| { + attr.path.segments.len() == 1 && attr.path.segments[0].ident == "wrap_incoming" + }) + } } /// The types of fields that a request can have, without their values. diff --git a/ruma-api-macros/src/api/response.rs b/ruma-api-macros/src/api/response.rs index c413d6c2..eb0abe8b 100644 --- a/ruma-api-macros/src/api/response.rs +++ b/ruma-api-macros/src/api/response.rs @@ -38,6 +38,11 @@ impl Response { self.fields.iter().any(|field| !field.is_header()) } + /// Whether any field has a #[wrap_incoming] attribute. + pub fn uses_wrap_incoming(&self) -> bool { + self.fields.iter().any(|f| f.has_wrap_incoming_attr()) + } + /// Produces code for a response struct initializer. pub fn init_fields(&self) -> TokenStream { let fields = self.fields.iter().map(|response_field| match response_field { @@ -83,10 +88,8 @@ impl Response { pub fn apply_header_fields(&self) -> TokenStream { let header_calls = self.fields.iter().filter_map(|response_field| { if let ResponseField::Header(ref field, ref header_name) = *response_field { - let field_name = field - .ident - .as_ref() - .expect("expected field to have an identifier"); + let field_name = + field.ident.as_ref().expect("expected field to have an identifier"); let span = field.span(); Some(quote_spanned! {span=> @@ -105,19 +108,14 @@ impl Response { /// Produces code to initialize the struct that will be used to create the response body. pub fn to_body(&self) -> TokenStream { if let Some(field) = self.newtype_body_field() { - let field_name = field - .ident - .as_ref() - .expect("expected field to have an identifier"); + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let span = field.span(); quote_spanned!(span=> response.#field_name) } else { let fields = self.fields.iter().filter_map(|response_field| { if let ResponseField::Body(ref field) = *response_field { - let field_name = field - .ident - .as_ref() - .expect("expected field to have an identifier"); + let field_name = + field.ident.as_ref().expect("expected field to have an identifier"); let span = field.span(); Some(quote_spanned! {span=> @@ -220,8 +218,8 @@ impl TryFrom for Response { .collect::>>()?; if newtype_body_field.is_some() && fields.iter().any(|f| f.is_body()) { + // TODO: highlight conflicting fields, return Err(syn::Error::new_spanned( - // TODO: raw, raw.response_kw, "Can't have both a newtype body field and regular body fields", )); @@ -234,7 +232,8 @@ impl TryFrom for Response { impl ToTokens for Response { fn to_tokens(&self, tokens: &mut TokenStream) { let response_struct_header = quote! { - #[derive(Debug, Clone)] + #[derive(Debug, Clone, ruma_api::SendRecv)] + #[incoming_no_deserialize] pub struct Response }; @@ -251,28 +250,51 @@ impl ToTokens for Response { } }; - let response_body_struct = if let Some(field) = self.newtype_body_field() { - let ty = &field.ty; - let span = field.span(); + let response_body_struct = + if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) { + let field = body_field.field(); + let ty = &field.ty; + let span = field.span(); + let derive_deserialize = if body_field.has_wrap_incoming_attr() { + TokenStream::new() + } else { + quote!(ruma_api::exports::serde::Deserialize) + }; - quote_spanned! {span=> - /// Data in the response body. - #[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)] - struct ResponseBody(#ty); - } - } else if self.has_body_fields() { - let fields = self.fields.iter().filter_map(ResponseField::as_body_field); - - quote! { - /// Data in the response body. - #[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)] - struct ResponseBody { - #(#fields),* + quote_spanned! {span=> + /// Data in the response body. + #[derive( + Debug, + ruma_api::SendRecv, + ruma_api::exports::serde::Serialize, + #derive_deserialize + )] + struct ResponseBody(#ty); } - } - } else { - TokenStream::new() - }; + } else if self.has_body_fields() { + let fields = self.fields.iter().filter(|f| f.is_body()); + let derive_deserialize = if fields.clone().any(|f| f.has_wrap_incoming_attr()) { + TokenStream::new() + } else { + quote!(ruma_api::exports::serde::Deserialize) + }; + let fields = fields.map(ResponseField::field); + + quote! { + /// Data in the response body. + #[derive( + Debug, + ruma_api::SendRecv, + ruma_api::exports::serde::Serialize, + #derive_deserialize + )] + struct ResponseBody { + #(#fields),* + } + } + } else { + TokenStream::new() + }; let response = quote! { #response_struct_header @@ -317,6 +339,11 @@ impl ResponseField { } } + /// Whether or not this response field is a newtype body kind. + fn is_newtype_body(&self) -> bool { + self.as_newtype_body_field().is_some() + } + /// Return the contained field if this response field is a body kind. fn as_body_field(&self) -> Option<&Field> { match self { @@ -332,6 +359,13 @@ impl ResponseField { _ => None, } } + + /// Whether or not the reponse field has a #[wrap_incoming] attribute. + fn has_wrap_incoming_attr(&self) -> bool { + self.field().attrs.iter().any(|attr| { + attr.path.segments.len() == 1 && attr.path.segments[0].ident == "wrap_incoming" + }) + } } /// The types of fields that a response can have, without their values. diff --git a/ruma-api-macros/src/lib.rs b/ruma-api-macros/src/lib.rs index 94b90380..0e5d4fb5 100644 --- a/ruma-api-macros/src/lib.rs +++ b/ruma-api-macros/src/lib.rs @@ -17,16 +17,27 @@ use std::convert::TryFrom as _; use proc_macro::TokenStream; use quote::ToTokens; +use syn::{parse_macro_input, DeriveInput}; -use crate::api::{Api, RawApi}; +use self::{ + api::{Api, RawApi}, + send_recv::expand_send_recv, +}; mod api; +mod send_recv; #[proc_macro] pub fn ruma_api(input: TokenStream) -> TokenStream { - let raw_api = syn::parse_macro_input!(input as RawApi); + let raw_api = parse_macro_input!(input as RawApi); match Api::try_from(raw_api) { Ok(api) => api.into_token_stream().into(), Err(err) => err.to_compile_error().into(), } } + +#[proc_macro_derive(SendRecv, attributes(wrap_incoming, incoming_no_deserialize))] +pub fn derive_send_recv(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + expand_send_recv(input).unwrap_or_else(|err| err.to_compile_error()).into() +} diff --git a/ruma-api-macros/src/send_recv.rs b/ruma-api-macros/src/send_recv.rs new file mode 100644 index 00000000..448e14e1 --- /dev/null +++ b/ruma-api-macros/src/send_recv.rs @@ -0,0 +1,195 @@ +use std::mem; + +use proc_macro2::{Ident, Span, TokenStream}; +use quote::{quote, ToTokens}; +use syn::{ + parse_quote, punctuated::Pair, spanned::Spanned, Attribute, Data, DeriveInput, Fields, + GenericArgument, Path, PathArguments, Type, TypePath, +}; + +mod wrap_incoming; + +use wrap_incoming::Meta; + +pub fn expand_send_recv(input: DeriveInput) -> syn::Result { + let derive_deserialize = if no_deserialize_in_attrs(&input.attrs) { + TokenStream::new() + } else { + quote!(#[derive(ruma_api::exports::serde::Deserialize)]) + }; + + let mut fields: Vec<_> = match input.data { + Data::Enum(_) | Data::Union(_) => { + panic!("#[derive(SendRecv)] is only supported for structs") + } + Data::Struct(s) => match s.fields { + Fields::Named(fs) => fs.named.into_pairs().map(Pair::into_value).collect(), + Fields::Unnamed(fs) => fs.unnamed.into_pairs().map(Pair::into_value).collect(), + Fields::Unit => return Ok(impl_send_recv_incoming_self(input.ident)), + }, + }; + + let mut any_attribute = false; + + for field in &mut fields { + let mut field_meta = None; + + let mut remaining_attrs = Vec::new(); + for attr in mem::replace(&mut field.attrs, Vec::new()) { + if let Some(meta) = Meta::from_attribute(&attr)? { + if field_meta.is_some() { + return Err(syn::Error::new_spanned( + attr, + "duplicate #[wrap_incoming] attribute", + )); + } + field_meta = Some(meta); + any_attribute = true; + } else { + remaining_attrs.push(attr); + } + } + field.attrs = remaining_attrs; + + if let Some(attr) = field_meta { + if let Some(type_to_wrap) = attr.type_to_wrap { + wrap_generic_arg(&type_to_wrap, &mut field.ty, attr.wrapper_type.as_ref())?; + } else { + wrap_ty(&mut field.ty, attr.wrapper_type)?; + } + } + } + + if !any_attribute { + return Ok(impl_send_recv_incoming_self(input.ident)); + } + + let vis = input.vis; + let doc = format!("\"Incoming\" variant of [{ty}](struct.{ty}.html).", ty = input.ident); + let original_ident = input.ident; + let incoming_ident = Ident::new(&format!("Incoming{}", original_ident), Span::call_site()); + + Ok(quote! { + #[doc = #doc] + #derive_deserialize + #vis struct #incoming_ident { + #(#fields,)* + } + + impl ruma_api::SendRecv for #original_ident { + type Incoming = #incoming_ident; + } + }) +} + +fn no_deserialize_in_attrs(attrs: &[Attribute]) -> bool { + for attr in attrs { + match &attr.path { + Path { leading_colon: None, segments } + if segments.len() == 1 && segments[0].ident == "incoming_no_deserialize" => + { + return true + } + _ => {} + } + } + + false +} + +fn impl_send_recv_incoming_self(ident: Ident) -> TokenStream { + quote! { + impl ruma_api::SendRecv for #ident { + type Incoming = Self; + } + } +} + +fn wrap_ty(ty: &mut Type, path: Option) -> syn::Result<()> { + if let Some(wrap_ty) = path { + *ty = parse_quote!(#wrap_ty<#ty>); + } else { + match ty { + Type::Path(TypePath { path, .. }) => { + let ty_ident = &mut path.segments.last_mut().unwrap().ident; + let ident = Ident::new(&format!("Incoming{}", ty_ident), Span::call_site()); + *ty_ident = parse_quote!(#ident); + } + _ => return Err(syn::Error::new_spanned(ty, "Can't wrap this type")), + } + } + + Ok(()) +} + +fn wrap_generic_arg(type_to_wrap: &Type, of: &mut Type, with: Option<&Path>) -> syn::Result<()> { + let mut span = None; + wrap_generic_arg_impl(type_to_wrap, of, with, &mut span)?; + + if span.is_some() { + Ok(()) + } else { + Err(syn::Error::new_spanned( + of, + format!( + "Couldn't find generic argument `{}` in this type", + type_to_wrap.to_token_stream() + ), + )) + } +} + +fn wrap_generic_arg_impl( + type_to_wrap: &Type, + of: &mut Type, + with: Option<&Path>, + span: &mut Option, +) -> syn::Result<()> { + // TODO: Support things like array types? + let ty_path = match of { + Type::Path(TypePath { path, .. }) => path, + _ => return Ok(()), + }; + + let args = match &mut ty_path.segments.last_mut().unwrap().arguments { + PathArguments::AngleBracketed(ab) => &mut ab.args, + _ => return Ok(()), + }; + + for arg in args.iter_mut() { + let ty = match arg { + GenericArgument::Type(ty) => ty, + _ => continue, + }; + + if ty == type_to_wrap { + if let Some(s) = span { + let mut error = syn::Error::new( + *s, + format!( + "`{}` found multiple times, this is not currently supported", + type_to_wrap.to_token_stream() + ), + ); + error.combine(syn::Error::new_spanned(ty, "second occurrence")); + return Err(error); + } + + *span = Some(ty.span()); + + if let Some(wrapper_type) = with { + *ty = parse_quote!(#wrapper_type<#ty>); + } else if let Type::Path(TypePath { path, .. }) = ty { + let ty_ident = &mut path.segments.last_mut().unwrap().ident; + let ident = Ident::new(&format!("Incoming{}", ty_ident), Span::call_site()); + *ty_ident = parse_quote!(#ident); + } else { + return Err(syn::Error::new_spanned(ty, "Can't wrap this type")); + } + } else { + wrap_generic_arg_impl(type_to_wrap, ty, with, span)?; + } + } + + Ok(()) +} diff --git a/ruma-api-macros/src/send_recv/wrap_incoming.rs b/ruma-api-macros/src/send_recv/wrap_incoming.rs new file mode 100644 index 00000000..f13c6f3c --- /dev/null +++ b/ruma-api-macros/src/send_recv/wrap_incoming.rs @@ -0,0 +1,58 @@ +use syn::{ + parse::{Parse, ParseStream}, + Ident, Path, Type, +}; + +mod kw { + use syn::custom_keyword; + custom_keyword!(with); +} + +/// The inside of a `#[wrap_incoming]` attribute +#[derive(Default)] +pub struct Meta { + pub type_to_wrap: Option, + pub wrapper_type: Option, +} + +impl Meta { + /// Check if the given attribute is a wrap_incoming attribute. If it is, parse it. + pub fn from_attribute(attr: &syn::Attribute) -> syn::Result> { + if attr.path.is_ident("wrap_incoming") { + if attr.tokens.is_empty() { + Ok(Some(Self::default())) + } else { + attr.parse_args().map(Some) + } + } else { + Ok(None) + } + } +} + +impl Parse for Meta { + fn parse(input: ParseStream) -> syn::Result { + let mut type_to_wrap = None; + let mut wrapper_type = try_parse_wrapper_type(input)?; + + if wrapper_type.is_none() && input.peek(Ident) { + type_to_wrap = Some(input.parse()?); + wrapper_type = try_parse_wrapper_type(input)?; + } + + if input.is_empty() { + Ok(Self { type_to_wrap, wrapper_type }) + } else { + Err(input.error("expected end of attribute args")) + } + } +} + +fn try_parse_wrapper_type(input: ParseStream) -> syn::Result> { + if input.peek(kw::with) { + input.parse::()?; + Ok(Some(input.parse()?)) + } else { + Ok(None) + } +} diff --git a/src/lib.rs b/src/lib.rs index b43ac198..47f2bfc4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -200,6 +200,9 @@ use serde_urlencoded; #[cfg(feature = "with-ruma-api-macros")] pub use ruma_api_macros::ruma_api; +#[cfg(feature = "with-ruma-api-macros")] +pub use ruma_api_macros::SendRecv; + #[cfg(feature = "with-ruma-api-macros")] #[doc(hidden)] /// This module is used to support the generated code from ruma-api-macros. @@ -213,15 +216,25 @@ pub mod exports { pub use url; } +/// A type that can be sent as well as received. Types that implement this trait have a +/// corresponding 'Incoming' type, which is either just `Self`, or another type that has the same +/// fields with some types exchanged by ones that allow fallible deserialization, e.g. `EventResult` +/// from ruma_events. +pub trait SendRecv { + /// The 'Incoming' variant of `Self`. + type Incoming; +} + /// A Matrix API endpoint. /// /// The type implementing this trait contains any data needed to make a request to the endpoint. -pub trait Endpoint: - TryFrom>, Error = Error> + TryInto>, Error = Error> +pub trait Endpoint: SendRecv + TryInto>, Error = Error> +where + ::Incoming: TryFrom>, Error = Error>, + ::Incoming: TryFrom>, Error = Error>, { /// Data returned in a successful response from the endpoint. - type Response: TryFrom>, Error = Error> - + TryInto>, Error = Error>; + type Response: SendRecv + TryInto>, Error = Error>; /// Metadata about the endpoint. const METADATA: Metadata; @@ -356,7 +369,7 @@ mod tests { use serde::{de::IntoDeserializer, Deserialize, Serialize}; use serde_json; - use crate::{Endpoint, Error, Metadata}; + use crate::{Endpoint, Error, Metadata, SendRecv}; /// A request to create a new room alias. #[derive(Debug)] @@ -365,6 +378,10 @@ mod tests { pub room_alias: RoomAliasId, // path } + impl SendRecv for Request { + type Incoming = Self; + } + impl Endpoint for Request { type Response = Response; @@ -428,6 +445,10 @@ mod tests { #[derive(Clone, Copy, Debug)] pub struct Response; + impl SendRecv for Response { + type Incoming = Self; + } + impl TryFrom>> for Response { type Error = Error; diff --git a/tests/ruma_api_macros.rs b/tests/ruma_api_macros.rs index a99d24dd..d43c9269 100644 --- a/tests/ruma_api_macros.rs +++ b/tests/ruma_api_macros.rs @@ -1,5 +1,6 @@ pub mod some_endpoint { use ruma_api::ruma_api; + use ruma_events::{tag::TagEventContent, EventResult}; ruma_api! { metadata { @@ -40,6 +41,10 @@ pub mod some_endpoint { // You can use serde attributes on any kind of field #[serde(skip_serializing_if = "Option::is_none")] pub optional_flag: Option, + + /// The user's tags for the room. + #[wrap_incoming(with EventResult)] + pub tags: TagEventContent, } } }