Add SendRecv trait + derive macro to allow receiving requests, sending responses
This commit is contained in:
parent
e383ae98ea
commit
f558b55692
@ -22,6 +22,9 @@ serde_json = "1.0.41"
|
||||
serde_urlencoded = "0.6.1"
|
||||
url = { version = "2.1.0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
ruma-events = "0.15.1"
|
||||
|
||||
[features]
|
||||
default = ["with-ruma-api-macros"]
|
||||
with-ruma-api-macros = ["percent-encoding", "ruma-api-macros", "serde", "url"]
|
||||
|
@ -15,7 +15,7 @@ edition = "2018"
|
||||
[dependencies]
|
||||
proc-macro2 = "1.0.6"
|
||||
quote = "1.0.2"
|
||||
syn = { version = "1.0.8", features = ["full"] }
|
||||
syn = { version = "1.0.8", features = ["full", "extra-traits"] }
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
@ -85,10 +85,20 @@ impl ToTokens for Api {
|
||||
let rate_limited = &self.metadata.rate_limited;
|
||||
let requires_authentication = &self.metadata.requires_authentication;
|
||||
|
||||
let request = &self.request;
|
||||
let request_types = quote! { #request };
|
||||
let response = &self.response;
|
||||
let response_types = quote! { #response };
|
||||
let request_type = &self.request;
|
||||
let response_type = &self.response;
|
||||
|
||||
let request_try_from_type = if self.request.uses_wrap_incoming() {
|
||||
quote!(IncomingRequest)
|
||||
} else {
|
||||
quote!(Request)
|
||||
};
|
||||
|
||||
let response_try_from_type = if self.response.uses_wrap_incoming() {
|
||||
quote!(IncomingResponse)
|
||||
} else {
|
||||
quote!(Response)
|
||||
};
|
||||
|
||||
let extract_request_path = if self.request.has_path_fields() {
|
||||
quote! {
|
||||
@ -110,7 +120,7 @@ impl ToTokens for Api {
|
||||
let request_path_init_fields = self.request.request_path_init_fields();
|
||||
|
||||
let path_segments = path_str[1..].split('/');
|
||||
let path_segment_push = path_segments.map(|segment| {
|
||||
let path_segment_push = path_segments.clone().map(|segment| {
|
||||
let arg = if segment.starts_with(':') {
|
||||
let path_var = &segment[1..];
|
||||
let path_var_ident = Ident::new(path_var, Span::call_site());
|
||||
@ -136,10 +146,8 @@ impl ToTokens for Api {
|
||||
#(#path_segment_push)*
|
||||
};
|
||||
|
||||
let path_fields = path_segments
|
||||
.enumerate()
|
||||
.filter(|(_, s)| s.starts_with(':'))
|
||||
.map(|(i, segment)| {
|
||||
let path_fields = path_segments.enumerate().filter(|(_, s)| s.starts_with(':')).map(
|
||||
|(i, segment)| {
|
||||
let path_var = &segment[1..];
|
||||
let path_var_ident = Ident::new(path_var, Span::call_site());
|
||||
let path_field = self
|
||||
@ -158,7 +166,8 @@ impl ToTokens for Api {
|
||||
.map_err(|e: ruma_api::exports::serde_json::error::Error| e)?
|
||||
}
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
let parse_tokens = quote! {
|
||||
#(#path_fields,)*
|
||||
@ -223,7 +232,12 @@ impl ToTokens for Api {
|
||||
TokenStream::new()
|
||||
};
|
||||
|
||||
let extract_request_query = if self.request.has_query_fields() {
|
||||
let extract_request_query = if self.request.query_map_field().is_some() {
|
||||
quote! {
|
||||
let request_query =
|
||||
ruma_api::exports::serde_urlencoded::from_str(&request.uri().query().unwrap_or(""))?;
|
||||
}
|
||||
} else if self.request.has_query_fields() {
|
||||
quote! {
|
||||
let request_query: RequestQuery =
|
||||
ruma_api::exports::serde_urlencoded::from_str(&request.uri().query().unwrap_or(""))?;
|
||||
@ -232,7 +246,13 @@ impl ToTokens for Api {
|
||||
TokenStream::new()
|
||||
};
|
||||
|
||||
let parse_request_query = if self.request.has_query_fields() {
|
||||
let parse_request_query = if let Some(field) = self.request.query_map_field() {
|
||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||
|
||||
quote! {
|
||||
#field_name: request_query
|
||||
}
|
||||
} else if self.request.has_query_fields() {
|
||||
self.request.request_init_query_fields()
|
||||
} else {
|
||||
TokenStream::new()
|
||||
@ -290,15 +310,14 @@ impl ToTokens for Api {
|
||||
}
|
||||
};
|
||||
|
||||
let extract_request_body = if let Some(field) = self.request.newtype_body_field() {
|
||||
let ty = &field.ty;
|
||||
let extract_request_body = if self.request.newtype_body_field().is_some() {
|
||||
quote! {
|
||||
let request_body: #ty =
|
||||
let request_body =
|
||||
ruma_api::exports::serde_json::from_slice(request.body().as_slice())?;
|
||||
}
|
||||
} else if self.request.has_body_fields() {
|
||||
quote! {
|
||||
let request_body: RequestBody =
|
||||
let request_body: <RequestBody as ruma_api::SendRecv>::Incoming =
|
||||
ruma_api::exports::serde_json::from_slice(request.body().as_slice())?;
|
||||
}
|
||||
} else {
|
||||
@ -306,10 +325,7 @@ impl ToTokens for Api {
|
||||
};
|
||||
|
||||
let parse_request_body = if let Some(field) = self.request.newtype_body_field() {
|
||||
let field_name = field
|
||||
.ident
|
||||
.as_ref()
|
||||
.expect("expected field to have an identifier");
|
||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||
|
||||
quote! {
|
||||
#field_name: request_body,
|
||||
@ -320,18 +336,15 @@ impl ToTokens for Api {
|
||||
TokenStream::new()
|
||||
};
|
||||
|
||||
let try_deserialize_response_body = if let Some(field) = self.response.newtype_body_field()
|
||||
{
|
||||
let field_type = &field.ty;
|
||||
let response_body_type_annotation = if self.response.has_body_fields() {
|
||||
quote!(: <ResponseBody as ruma_api::SendRecv>::Incoming)
|
||||
} else {
|
||||
TokenStream::new()
|
||||
};
|
||||
|
||||
let try_deserialize_response_body = if self.response.has_body() {
|
||||
quote! {
|
||||
ruma_api::exports::serde_json::from_slice::<#field_type>(
|
||||
http_response.into_body().as_slice(),
|
||||
)?
|
||||
}
|
||||
} else if self.response.has_body_fields() {
|
||||
quote! {
|
||||
ruma_api::exports::serde_json::from_slice::<ResponseBody>(
|
||||
ruma_api::exports::serde_json::from_slice(
|
||||
http_response.into_body().as_slice(),
|
||||
)?
|
||||
}
|
||||
@ -383,9 +396,9 @@ impl ToTokens for Api {
|
||||
use std::convert::TryInto as _;
|
||||
|
||||
#[doc = #request_doc]
|
||||
#request_types
|
||||
#request_type
|
||||
|
||||
impl std::convert::TryFrom<ruma_api::exports::http::Request<Vec<u8>>> for Request {
|
||||
impl std::convert::TryFrom<ruma_api::exports::http::Request<Vec<u8>>> for #request_try_from_type {
|
||||
type Error = ruma_api::Error;
|
||||
|
||||
#[allow(unused_variables)]
|
||||
@ -395,7 +408,7 @@ impl ToTokens for Api {
|
||||
#extract_request_headers
|
||||
#extract_request_body
|
||||
|
||||
Ok(Request {
|
||||
Ok(Self {
|
||||
#parse_request_path
|
||||
#parse_request_query
|
||||
#parse_request_headers
|
||||
@ -433,7 +446,7 @@ impl ToTokens for Api {
|
||||
}
|
||||
|
||||
#[doc = #response_doc]
|
||||
#response_types
|
||||
#response_type
|
||||
|
||||
impl std::convert::TryFrom<Response> for ruma_api::exports::http::Response<Vec<u8>> {
|
||||
type Error = ruma_api::Error;
|
||||
@ -449,7 +462,7 @@ impl ToTokens for Api {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::TryFrom<ruma_api::exports::http::Response<Vec<u8>>> for Response {
|
||||
impl std::convert::TryFrom<ruma_api::exports::http::Response<Vec<u8>>> for #response_try_from_type {
|
||||
type Error = ruma_api::Error;
|
||||
|
||||
#[allow(unused_variables)]
|
||||
@ -459,8 +472,10 @@ impl ToTokens for Api {
|
||||
if http_response.status().is_success() {
|
||||
#extract_response_headers
|
||||
|
||||
let response_body = #try_deserialize_response_body;
|
||||
Ok(Response {
|
||||
let response_body #response_body_type_annotation =
|
||||
#try_deserialize_response_body;
|
||||
|
||||
Ok(Self {
|
||||
#response_init_fields
|
||||
})
|
||||
} else {
|
||||
|
@ -90,6 +90,11 @@ impl Request {
|
||||
self.fields.iter().filter_map(|field| field.as_body_field())
|
||||
}
|
||||
|
||||
/// Whether any field has a #[wrap_incoming] attribute.
|
||||
pub fn uses_wrap_incoming(&self) -> bool {
|
||||
self.fields.iter().any(|f| f.has_wrap_incoming_attr())
|
||||
}
|
||||
|
||||
/// Produces an iterator over all the header fields.
|
||||
pub fn header_fields(&self) -> impl Iterator<Item = &RequestField> {
|
||||
self.fields.iter().filter(|field| field.is_header())
|
||||
@ -102,15 +107,8 @@ impl Request {
|
||||
|
||||
/// Gets the path field with the given name.
|
||||
pub fn path_field(&self, name: &str) -> Option<&Field> {
|
||||
self.fields
|
||||
.iter()
|
||||
.flat_map(|f| f.field_of_kind(RequestFieldKind::Path))
|
||||
.find(|field| {
|
||||
field
|
||||
.ident
|
||||
.as_ref()
|
||||
.expect("expected field to have an identifier")
|
||||
== name
|
||||
self.fields.iter().flat_map(|f| f.field_of_kind(RequestFieldKind::Path)).find(|field| {
|
||||
field.ident.as_ref().expect("expected field to have an identifier") == name
|
||||
})
|
||||
}
|
||||
|
||||
@ -273,8 +271,8 @@ impl TryFrom<RawRequest> for Request {
|
||||
.collect::<syn::Result<Vec<_>>>()?;
|
||||
|
||||
if newtype_body_field.is_some() && fields.iter().any(|f| f.is_body()) {
|
||||
// TODO: highlight conflicting fields,
|
||||
return Err(syn::Error::new_spanned(
|
||||
// TODO: raw,
|
||||
raw.request_kw,
|
||||
"Can't have both a newtype body field and regular body fields",
|
||||
));
|
||||
@ -295,7 +293,8 @@ impl TryFrom<RawRequest> for Request {
|
||||
impl ToTokens for Request {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
let request_struct_header = quote! {
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, ruma_api::SendRecv)]
|
||||
#[incoming_no_deserialize]
|
||||
pub struct Request
|
||||
};
|
||||
|
||||
@ -312,21 +311,44 @@ impl ToTokens for Request {
|
||||
}
|
||||
};
|
||||
|
||||
let request_body_struct = if let Some(field) = self.newtype_body_field() {
|
||||
let request_body_struct =
|
||||
if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) {
|
||||
let field = body_field.field();
|
||||
let ty = &field.ty;
|
||||
let span = field.span();
|
||||
let derive_deserialize = if body_field.has_wrap_incoming_attr() {
|
||||
TokenStream::new()
|
||||
} else {
|
||||
quote!(ruma_api::exports::serde::Deserialize)
|
||||
};
|
||||
|
||||
quote_spanned! {span=>
|
||||
/// Data in the request body.
|
||||
#[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)]
|
||||
#[derive(
|
||||
Debug,
|
||||
ruma_api::SendRecv,
|
||||
ruma_api::exports::serde::Serialize,
|
||||
#derive_deserialize
|
||||
)]
|
||||
struct RequestBody(#ty);
|
||||
}
|
||||
} else if self.has_body_fields() {
|
||||
let fields = self.fields.iter().filter_map(RequestField::as_body_field);
|
||||
let fields = self.fields.iter().filter(|f| f.is_body());
|
||||
let derive_deserialize = if fields.clone().any(|f| f.has_wrap_incoming_attr()) {
|
||||
TokenStream::new()
|
||||
} else {
|
||||
quote!(ruma_api::exports::serde::Deserialize)
|
||||
};
|
||||
let fields = fields.map(RequestField::field);
|
||||
|
||||
quote! {
|
||||
/// Data in the request body.
|
||||
#[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)]
|
||||
#[derive(
|
||||
Debug,
|
||||
ruma_api::SendRecv,
|
||||
ruma_api::exports::serde::Serialize,
|
||||
#derive_deserialize
|
||||
)]
|
||||
struct RequestBody {
|
||||
#(#fields),*
|
||||
}
|
||||
@ -449,6 +471,11 @@ impl RequestField {
|
||||
self.kind() == RequestFieldKind::Header
|
||||
}
|
||||
|
||||
/// Whether or not this request field is a newtype body kind.
|
||||
fn is_newtype_body(&self) -> bool {
|
||||
self.kind() == RequestFieldKind::NewtypeBody
|
||||
}
|
||||
|
||||
/// Whether or not this request field is a path kind.
|
||||
fn is_path(&self) -> bool {
|
||||
self.kind() == RequestFieldKind::Path
|
||||
@ -504,6 +531,13 @@ impl RequestField {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether or not the request field has a #[wrap_incoming] attribute.
|
||||
fn has_wrap_incoming_attr(&self) -> bool {
|
||||
self.field().attrs.iter().any(|attr| {
|
||||
attr.path.segments.len() == 1 && attr.path.segments[0].ident == "wrap_incoming"
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// The types of fields that a request can have, without their values.
|
||||
|
@ -38,6 +38,11 @@ impl Response {
|
||||
self.fields.iter().any(|field| !field.is_header())
|
||||
}
|
||||
|
||||
/// Whether any field has a #[wrap_incoming] attribute.
|
||||
pub fn uses_wrap_incoming(&self) -> bool {
|
||||
self.fields.iter().any(|f| f.has_wrap_incoming_attr())
|
||||
}
|
||||
|
||||
/// Produces code for a response struct initializer.
|
||||
pub fn init_fields(&self) -> TokenStream {
|
||||
let fields = self.fields.iter().map(|response_field| match response_field {
|
||||
@ -83,10 +88,8 @@ impl Response {
|
||||
pub fn apply_header_fields(&self) -> TokenStream {
|
||||
let header_calls = self.fields.iter().filter_map(|response_field| {
|
||||
if let ResponseField::Header(ref field, ref header_name) = *response_field {
|
||||
let field_name = field
|
||||
.ident
|
||||
.as_ref()
|
||||
.expect("expected field to have an identifier");
|
||||
let field_name =
|
||||
field.ident.as_ref().expect("expected field to have an identifier");
|
||||
let span = field.span();
|
||||
|
||||
Some(quote_spanned! {span=>
|
||||
@ -105,19 +108,14 @@ impl Response {
|
||||
/// Produces code to initialize the struct that will be used to create the response body.
|
||||
pub fn to_body(&self) -> TokenStream {
|
||||
if let Some(field) = self.newtype_body_field() {
|
||||
let field_name = field
|
||||
.ident
|
||||
.as_ref()
|
||||
.expect("expected field to have an identifier");
|
||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||
let span = field.span();
|
||||
quote_spanned!(span=> response.#field_name)
|
||||
} else {
|
||||
let fields = self.fields.iter().filter_map(|response_field| {
|
||||
if let ResponseField::Body(ref field) = *response_field {
|
||||
let field_name = field
|
||||
.ident
|
||||
.as_ref()
|
||||
.expect("expected field to have an identifier");
|
||||
let field_name =
|
||||
field.ident.as_ref().expect("expected field to have an identifier");
|
||||
let span = field.span();
|
||||
|
||||
Some(quote_spanned! {span=>
|
||||
@ -220,8 +218,8 @@ impl TryFrom<RawResponse> for Response {
|
||||
.collect::<syn::Result<Vec<_>>>()?;
|
||||
|
||||
if newtype_body_field.is_some() && fields.iter().any(|f| f.is_body()) {
|
||||
// TODO: highlight conflicting fields,
|
||||
return Err(syn::Error::new_spanned(
|
||||
// TODO: raw,
|
||||
raw.response_kw,
|
||||
"Can't have both a newtype body field and regular body fields",
|
||||
));
|
||||
@ -234,7 +232,8 @@ impl TryFrom<RawResponse> for Response {
|
||||
impl ToTokens for Response {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
let response_struct_header = quote! {
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, ruma_api::SendRecv)]
|
||||
#[incoming_no_deserialize]
|
||||
pub struct Response
|
||||
};
|
||||
|
||||
@ -251,21 +250,44 @@ impl ToTokens for Response {
|
||||
}
|
||||
};
|
||||
|
||||
let response_body_struct = if let Some(field) = self.newtype_body_field() {
|
||||
let response_body_struct =
|
||||
if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) {
|
||||
let field = body_field.field();
|
||||
let ty = &field.ty;
|
||||
let span = field.span();
|
||||
let derive_deserialize = if body_field.has_wrap_incoming_attr() {
|
||||
TokenStream::new()
|
||||
} else {
|
||||
quote!(ruma_api::exports::serde::Deserialize)
|
||||
};
|
||||
|
||||
quote_spanned! {span=>
|
||||
/// Data in the response body.
|
||||
#[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)]
|
||||
#[derive(
|
||||
Debug,
|
||||
ruma_api::SendRecv,
|
||||
ruma_api::exports::serde::Serialize,
|
||||
#derive_deserialize
|
||||
)]
|
||||
struct ResponseBody(#ty);
|
||||
}
|
||||
} else if self.has_body_fields() {
|
||||
let fields = self.fields.iter().filter_map(ResponseField::as_body_field);
|
||||
let fields = self.fields.iter().filter(|f| f.is_body());
|
||||
let derive_deserialize = if fields.clone().any(|f| f.has_wrap_incoming_attr()) {
|
||||
TokenStream::new()
|
||||
} else {
|
||||
quote!(ruma_api::exports::serde::Deserialize)
|
||||
};
|
||||
let fields = fields.map(ResponseField::field);
|
||||
|
||||
quote! {
|
||||
/// Data in the response body.
|
||||
#[derive(Debug, ruma_api::exports::serde::Deserialize, ruma_api::exports::serde::Serialize)]
|
||||
#[derive(
|
||||
Debug,
|
||||
ruma_api::SendRecv,
|
||||
ruma_api::exports::serde::Serialize,
|
||||
#derive_deserialize
|
||||
)]
|
||||
struct ResponseBody {
|
||||
#(#fields),*
|
||||
}
|
||||
@ -317,6 +339,11 @@ impl ResponseField {
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether or not this response field is a newtype body kind.
|
||||
fn is_newtype_body(&self) -> bool {
|
||||
self.as_newtype_body_field().is_some()
|
||||
}
|
||||
|
||||
/// Return the contained field if this response field is a body kind.
|
||||
fn as_body_field(&self) -> Option<&Field> {
|
||||
match self {
|
||||
@ -332,6 +359,13 @@ impl ResponseField {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether or not the reponse field has a #[wrap_incoming] attribute.
|
||||
fn has_wrap_incoming_attr(&self) -> bool {
|
||||
self.field().attrs.iter().any(|attr| {
|
||||
attr.path.segments.len() == 1 && attr.path.segments[0].ident == "wrap_incoming"
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// The types of fields that a response can have, without their values.
|
||||
|
@ -17,16 +17,27 @@ use std::convert::TryFrom as _;
|
||||
|
||||
use proc_macro::TokenStream;
|
||||
use quote::ToTokens;
|
||||
use syn::{parse_macro_input, DeriveInput};
|
||||
|
||||
use crate::api::{Api, RawApi};
|
||||
use self::{
|
||||
api::{Api, RawApi},
|
||||
send_recv::expand_send_recv,
|
||||
};
|
||||
|
||||
mod api;
|
||||
mod send_recv;
|
||||
|
||||
#[proc_macro]
|
||||
pub fn ruma_api(input: TokenStream) -> TokenStream {
|
||||
let raw_api = syn::parse_macro_input!(input as RawApi);
|
||||
let raw_api = parse_macro_input!(input as RawApi);
|
||||
match Api::try_from(raw_api) {
|
||||
Ok(api) => api.into_token_stream().into(),
|
||||
Err(err) => err.to_compile_error().into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[proc_macro_derive(SendRecv, attributes(wrap_incoming, incoming_no_deserialize))]
|
||||
pub fn derive_send_recv(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
expand_send_recv(input).unwrap_or_else(|err| err.to_compile_error()).into()
|
||||
}
|
||||
|
195
ruma-api-macros/src/send_recv.rs
Normal file
195
ruma-api-macros/src/send_recv.rs
Normal file
@ -0,0 +1,195 @@
|
||||
use std::mem;
|
||||
|
||||
use proc_macro2::{Ident, Span, TokenStream};
|
||||
use quote::{quote, ToTokens};
|
||||
use syn::{
|
||||
parse_quote, punctuated::Pair, spanned::Spanned, Attribute, Data, DeriveInput, Fields,
|
||||
GenericArgument, Path, PathArguments, Type, TypePath,
|
||||
};
|
||||
|
||||
mod wrap_incoming;
|
||||
|
||||
use wrap_incoming::Meta;
|
||||
|
||||
pub fn expand_send_recv(input: DeriveInput) -> syn::Result<TokenStream> {
|
||||
let derive_deserialize = if no_deserialize_in_attrs(&input.attrs) {
|
||||
TokenStream::new()
|
||||
} else {
|
||||
quote!(#[derive(ruma_api::exports::serde::Deserialize)])
|
||||
};
|
||||
|
||||
let mut fields: Vec<_> = match input.data {
|
||||
Data::Enum(_) | Data::Union(_) => {
|
||||
panic!("#[derive(SendRecv)] is only supported for structs")
|
||||
}
|
||||
Data::Struct(s) => match s.fields {
|
||||
Fields::Named(fs) => fs.named.into_pairs().map(Pair::into_value).collect(),
|
||||
Fields::Unnamed(fs) => fs.unnamed.into_pairs().map(Pair::into_value).collect(),
|
||||
Fields::Unit => return Ok(impl_send_recv_incoming_self(input.ident)),
|
||||
},
|
||||
};
|
||||
|
||||
let mut any_attribute = false;
|
||||
|
||||
for field in &mut fields {
|
||||
let mut field_meta = None;
|
||||
|
||||
let mut remaining_attrs = Vec::new();
|
||||
for attr in mem::replace(&mut field.attrs, Vec::new()) {
|
||||
if let Some(meta) = Meta::from_attribute(&attr)? {
|
||||
if field_meta.is_some() {
|
||||
return Err(syn::Error::new_spanned(
|
||||
attr,
|
||||
"duplicate #[wrap_incoming] attribute",
|
||||
));
|
||||
}
|
||||
field_meta = Some(meta);
|
||||
any_attribute = true;
|
||||
} else {
|
||||
remaining_attrs.push(attr);
|
||||
}
|
||||
}
|
||||
field.attrs = remaining_attrs;
|
||||
|
||||
if let Some(attr) = field_meta {
|
||||
if let Some(type_to_wrap) = attr.type_to_wrap {
|
||||
wrap_generic_arg(&type_to_wrap, &mut field.ty, attr.wrapper_type.as_ref())?;
|
||||
} else {
|
||||
wrap_ty(&mut field.ty, attr.wrapper_type)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !any_attribute {
|
||||
return Ok(impl_send_recv_incoming_self(input.ident));
|
||||
}
|
||||
|
||||
let vis = input.vis;
|
||||
let doc = format!("\"Incoming\" variant of [{ty}](struct.{ty}.html).", ty = input.ident);
|
||||
let original_ident = input.ident;
|
||||
let incoming_ident = Ident::new(&format!("Incoming{}", original_ident), Span::call_site());
|
||||
|
||||
Ok(quote! {
|
||||
#[doc = #doc]
|
||||
#derive_deserialize
|
||||
#vis struct #incoming_ident {
|
||||
#(#fields,)*
|
||||
}
|
||||
|
||||
impl ruma_api::SendRecv for #original_ident {
|
||||
type Incoming = #incoming_ident;
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn no_deserialize_in_attrs(attrs: &[Attribute]) -> bool {
|
||||
for attr in attrs {
|
||||
match &attr.path {
|
||||
Path { leading_colon: None, segments }
|
||||
if segments.len() == 1 && segments[0].ident == "incoming_no_deserialize" =>
|
||||
{
|
||||
return true
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn impl_send_recv_incoming_self(ident: Ident) -> TokenStream {
|
||||
quote! {
|
||||
impl ruma_api::SendRecv for #ident {
|
||||
type Incoming = Self;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn wrap_ty(ty: &mut Type, path: Option<Path>) -> syn::Result<()> {
|
||||
if let Some(wrap_ty) = path {
|
||||
*ty = parse_quote!(#wrap_ty<#ty>);
|
||||
} else {
|
||||
match ty {
|
||||
Type::Path(TypePath { path, .. }) => {
|
||||
let ty_ident = &mut path.segments.last_mut().unwrap().ident;
|
||||
let ident = Ident::new(&format!("Incoming{}", ty_ident), Span::call_site());
|
||||
*ty_ident = parse_quote!(#ident);
|
||||
}
|
||||
_ => return Err(syn::Error::new_spanned(ty, "Can't wrap this type")),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn wrap_generic_arg(type_to_wrap: &Type, of: &mut Type, with: Option<&Path>) -> syn::Result<()> {
|
||||
let mut span = None;
|
||||
wrap_generic_arg_impl(type_to_wrap, of, with, &mut span)?;
|
||||
|
||||
if span.is_some() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(syn::Error::new_spanned(
|
||||
of,
|
||||
format!(
|
||||
"Couldn't find generic argument `{}` in this type",
|
||||
type_to_wrap.to_token_stream()
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn wrap_generic_arg_impl(
|
||||
type_to_wrap: &Type,
|
||||
of: &mut Type,
|
||||
with: Option<&Path>,
|
||||
span: &mut Option<Span>,
|
||||
) -> syn::Result<()> {
|
||||
// TODO: Support things like array types?
|
||||
let ty_path = match of {
|
||||
Type::Path(TypePath { path, .. }) => path,
|
||||
_ => return Ok(()),
|
||||
};
|
||||
|
||||
let args = match &mut ty_path.segments.last_mut().unwrap().arguments {
|
||||
PathArguments::AngleBracketed(ab) => &mut ab.args,
|
||||
_ => return Ok(()),
|
||||
};
|
||||
|
||||
for arg in args.iter_mut() {
|
||||
let ty = match arg {
|
||||
GenericArgument::Type(ty) => ty,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
if ty == type_to_wrap {
|
||||
if let Some(s) = span {
|
||||
let mut error = syn::Error::new(
|
||||
*s,
|
||||
format!(
|
||||
"`{}` found multiple times, this is not currently supported",
|
||||
type_to_wrap.to_token_stream()
|
||||
),
|
||||
);
|
||||
error.combine(syn::Error::new_spanned(ty, "second occurrence"));
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
*span = Some(ty.span());
|
||||
|
||||
if let Some(wrapper_type) = with {
|
||||
*ty = parse_quote!(#wrapper_type<#ty>);
|
||||
} else if let Type::Path(TypePath { path, .. }) = ty {
|
||||
let ty_ident = &mut path.segments.last_mut().unwrap().ident;
|
||||
let ident = Ident::new(&format!("Incoming{}", ty_ident), Span::call_site());
|
||||
*ty_ident = parse_quote!(#ident);
|
||||
} else {
|
||||
return Err(syn::Error::new_spanned(ty, "Can't wrap this type"));
|
||||
}
|
||||
} else {
|
||||
wrap_generic_arg_impl(type_to_wrap, ty, with, span)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
58
ruma-api-macros/src/send_recv/wrap_incoming.rs
Normal file
58
ruma-api-macros/src/send_recv/wrap_incoming.rs
Normal file
@ -0,0 +1,58 @@
|
||||
use syn::{
|
||||
parse::{Parse, ParseStream},
|
||||
Ident, Path, Type,
|
||||
};
|
||||
|
||||
mod kw {
|
||||
use syn::custom_keyword;
|
||||
custom_keyword!(with);
|
||||
}
|
||||
|
||||
/// The inside of a `#[wrap_incoming]` attribute
|
||||
#[derive(Default)]
|
||||
pub struct Meta {
|
||||
pub type_to_wrap: Option<Type>,
|
||||
pub wrapper_type: Option<Path>,
|
||||
}
|
||||
|
||||
impl Meta {
|
||||
/// Check if the given attribute is a wrap_incoming attribute. If it is, parse it.
|
||||
pub fn from_attribute(attr: &syn::Attribute) -> syn::Result<Option<Self>> {
|
||||
if attr.path.is_ident("wrap_incoming") {
|
||||
if attr.tokens.is_empty() {
|
||||
Ok(Some(Self::default()))
|
||||
} else {
|
||||
attr.parse_args().map(Some)
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for Meta {
|
||||
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||
let mut type_to_wrap = None;
|
||||
let mut wrapper_type = try_parse_wrapper_type(input)?;
|
||||
|
||||
if wrapper_type.is_none() && input.peek(Ident) {
|
||||
type_to_wrap = Some(input.parse()?);
|
||||
wrapper_type = try_parse_wrapper_type(input)?;
|
||||
}
|
||||
|
||||
if input.is_empty() {
|
||||
Ok(Self { type_to_wrap, wrapper_type })
|
||||
} else {
|
||||
Err(input.error("expected end of attribute args"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_parse_wrapper_type(input: ParseStream) -> syn::Result<Option<Path>> {
|
||||
if input.peek(kw::with) {
|
||||
input.parse::<kw::with>()?;
|
||||
Ok(Some(input.parse()?))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
31
src/lib.rs
31
src/lib.rs
@ -200,6 +200,9 @@ use serde_urlencoded;
|
||||
#[cfg(feature = "with-ruma-api-macros")]
|
||||
pub use ruma_api_macros::ruma_api;
|
||||
|
||||
#[cfg(feature = "with-ruma-api-macros")]
|
||||
pub use ruma_api_macros::SendRecv;
|
||||
|
||||
#[cfg(feature = "with-ruma-api-macros")]
|
||||
#[doc(hidden)]
|
||||
/// This module is used to support the generated code from ruma-api-macros.
|
||||
@ -213,15 +216,25 @@ pub mod exports {
|
||||
pub use url;
|
||||
}
|
||||
|
||||
/// A type that can be sent as well as received. Types that implement this trait have a
|
||||
/// corresponding 'Incoming' type, which is either just `Self`, or another type that has the same
|
||||
/// fields with some types exchanged by ones that allow fallible deserialization, e.g. `EventResult`
|
||||
/// from ruma_events.
|
||||
pub trait SendRecv {
|
||||
/// The 'Incoming' variant of `Self`.
|
||||
type Incoming;
|
||||
}
|
||||
|
||||
/// A Matrix API endpoint.
|
||||
///
|
||||
/// The type implementing this trait contains any data needed to make a request to the endpoint.
|
||||
pub trait Endpoint:
|
||||
TryFrom<http::Request<Vec<u8>>, Error = Error> + TryInto<http::Request<Vec<u8>>, Error = Error>
|
||||
pub trait Endpoint: SendRecv + TryInto<http::Request<Vec<u8>>, Error = Error>
|
||||
where
|
||||
<Self as SendRecv>::Incoming: TryFrom<http::Request<Vec<u8>>, Error = Error>,
|
||||
<Self::Response as SendRecv>::Incoming: TryFrom<http::Response<Vec<u8>>, Error = Error>,
|
||||
{
|
||||
/// Data returned in a successful response from the endpoint.
|
||||
type Response: TryFrom<http::Response<Vec<u8>>, Error = Error>
|
||||
+ TryInto<http::Response<Vec<u8>>, Error = Error>;
|
||||
type Response: SendRecv + TryInto<http::Response<Vec<u8>>, Error = Error>;
|
||||
|
||||
/// Metadata about the endpoint.
|
||||
const METADATA: Metadata;
|
||||
@ -356,7 +369,7 @@ mod tests {
|
||||
use serde::{de::IntoDeserializer, Deserialize, Serialize};
|
||||
use serde_json;
|
||||
|
||||
use crate::{Endpoint, Error, Metadata};
|
||||
use crate::{Endpoint, Error, Metadata, SendRecv};
|
||||
|
||||
/// A request to create a new room alias.
|
||||
#[derive(Debug)]
|
||||
@ -365,6 +378,10 @@ mod tests {
|
||||
pub room_alias: RoomAliasId, // path
|
||||
}
|
||||
|
||||
impl SendRecv for Request {
|
||||
type Incoming = Self;
|
||||
}
|
||||
|
||||
impl Endpoint for Request {
|
||||
type Response = Response;
|
||||
|
||||
@ -428,6 +445,10 @@ mod tests {
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Response;
|
||||
|
||||
impl SendRecv for Response {
|
||||
type Incoming = Self;
|
||||
}
|
||||
|
||||
impl TryFrom<http::Response<Vec<u8>>> for Response {
|
||||
type Error = Error;
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
pub mod some_endpoint {
|
||||
use ruma_api::ruma_api;
|
||||
use ruma_events::{tag::TagEventContent, EventResult};
|
||||
|
||||
ruma_api! {
|
||||
metadata {
|
||||
@ -40,6 +41,10 @@ pub mod some_endpoint {
|
||||
// You can use serde attributes on any kind of field
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub optional_flag: Option<bool>,
|
||||
|
||||
/// The user's tags for the room.
|
||||
#[wrap_incoming(with EventResult)]
|
||||
pub tags: TagEventContent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user