api-macros: Move ruma_api! trait impl generation into derive macros
This commit is contained in:
		
							parent
							
								
									fae75410a9
								
							
						
					
					
						commit
						696c9fba4e
					
				| @ -2,17 +2,27 @@ | ||||
| 
 | ||||
| use proc_macro2::TokenStream; | ||||
| use quote::quote; | ||||
| use syn::Type; | ||||
| use syn::{ | ||||
|     braced, | ||||
|     parse::{Parse, ParseStream}, | ||||
|     Attribute, Field, Token, Type, | ||||
| }; | ||||
| 
 | ||||
| pub(crate) mod attribute; | ||||
| pub(crate) mod metadata; | ||||
| pub(crate) mod parse; | ||||
| pub(crate) mod request; | ||||
| pub(crate) mod response; | ||||
| mod metadata; | ||||
| mod request; | ||||
| mod response; | ||||
| 
 | ||||
| use self::{metadata::Metadata, request::Request, response::Response}; | ||||
| use crate::util; | ||||
| 
 | ||||
| mod kw { | ||||
|     use syn::custom_keyword; | ||||
| 
 | ||||
|     custom_keyword!(error); | ||||
|     custom_keyword!(request); | ||||
|     custom_keyword!(response); | ||||
| } | ||||
| 
 | ||||
| /// The result of processing the `ruma_api` macro, ready for output back to source code.
 | ||||
| pub struct Api { | ||||
|     /// The `metadata` section of the macro.
 | ||||
| @ -28,65 +38,129 @@ pub struct Api { | ||||
|     error_ty: Option<Type>, | ||||
| } | ||||
| 
 | ||||
| pub fn expand_all(api: Api) -> syn::Result<TokenStream> { | ||||
|     let ruma_api = util::import_ruma_api(); | ||||
|     let http = quote! { #ruma_api::exports::http }; | ||||
| impl Api { | ||||
|     pub fn expand_all(self) -> TokenStream { | ||||
|         let ruma_api = util::import_ruma_api(); | ||||
|         let http = quote! { #ruma_api::exports::http }; | ||||
| 
 | ||||
|     let metadata = &api.metadata; | ||||
|     let description = &metadata.description; | ||||
|     let method = &metadata.method; | ||||
|     let name = &metadata.name; | ||||
|     let path = &metadata.path; | ||||
|     let rate_limited: TokenStream = metadata | ||||
|         .rate_limited | ||||
|         .iter() | ||||
|         .map(|r| { | ||||
|             let attrs = &r.attrs; | ||||
|             let value = &r.value; | ||||
|             quote! { | ||||
|                 #( #attrs )* | ||||
|                 rate_limited: #value, | ||||
|             } | ||||
|         }) | ||||
|         .collect(); | ||||
|     let authentication: TokenStream = api | ||||
|         .metadata | ||||
|         .authentication | ||||
|         .iter() | ||||
|         .map(|r| { | ||||
|             let attrs = &r.attrs; | ||||
|             let value = &r.value; | ||||
|             quote! { | ||||
|                 #( #attrs )* | ||||
|                 authentication: #ruma_api::AuthScheme::#value, | ||||
|             } | ||||
|         }) | ||||
|         .collect(); | ||||
|         let metadata = &self.metadata; | ||||
|         let description = &metadata.description; | ||||
|         let method = &metadata.method; | ||||
|         let name = &metadata.name; | ||||
|         let path = &metadata.path; | ||||
|         let rate_limited: TokenStream = metadata | ||||
|             .rate_limited | ||||
|             .iter() | ||||
|             .map(|r| { | ||||
|                 let attrs = &r.attrs; | ||||
|                 let value = &r.value; | ||||
|                 quote! { | ||||
|                     #( #attrs )* | ||||
|                     rate_limited: #value, | ||||
|                 } | ||||
|             }) | ||||
|             .collect(); | ||||
|         let authentication: TokenStream = self | ||||
|             .metadata | ||||
|             .authentication | ||||
|             .iter() | ||||
|             .map(|r| { | ||||
|                 let attrs = &r.attrs; | ||||
|                 let value = &r.value; | ||||
|                 quote! { | ||||
|                     #( #attrs )* | ||||
|                     authentication: #ruma_api::AuthScheme::#value, | ||||
|                 } | ||||
|             }) | ||||
|             .collect(); | ||||
| 
 | ||||
|     let error_ty = api | ||||
|         .error_ty | ||||
|         .map_or_else(|| quote! { #ruma_api::error::MatrixError }, |err_ty| quote! { #err_ty }); | ||||
|         let error_ty = self | ||||
|             .error_ty | ||||
|             .map_or_else(|| quote! { #ruma_api::error::MatrixError }, |err_ty| quote! { #err_ty }); | ||||
| 
 | ||||
|     let request = api.request.map(|req| req.expand(metadata, &error_ty, &ruma_api)); | ||||
|     let response = api.response.map(|res| res.expand(metadata, &error_ty, &ruma_api)); | ||||
|         let request = self.request.map(|req| req.expand(metadata, &error_ty, &ruma_api)); | ||||
|         let response = self.response.map(|res| res.expand(metadata, &error_ty, &ruma_api)); | ||||
| 
 | ||||
|     let metadata_doc = format!("Metadata for the `{}` API endpoint.", name.value()); | ||||
|         let metadata_doc = format!("Metadata for the `{}` API endpoint.", name.value()); | ||||
| 
 | ||||
|     Ok(quote! { | ||||
|         #[doc = #metadata_doc] | ||||
|         pub const METADATA: #ruma_api::Metadata = #ruma_api::Metadata { | ||||
|             description: #description, | ||||
|             method: #http::Method::#method, | ||||
|             name: #name, | ||||
|             path: #path, | ||||
|             #rate_limited | ||||
|             #authentication | ||||
|         quote! { | ||||
|             #[doc = #metadata_doc] | ||||
|             pub const METADATA: #ruma_api::Metadata = #ruma_api::Metadata { | ||||
|                 description: #description, | ||||
|                 method: #http::Method::#method, | ||||
|                 name: #name, | ||||
|                 path: #path, | ||||
|                 #rate_limited | ||||
|                 #authentication | ||||
|             }; | ||||
| 
 | ||||
|             #request | ||||
|             #response | ||||
| 
 | ||||
|             #[cfg(not(any(feature = "client", feature = "server")))] | ||||
|             type _SilenceUnusedError = #error_ty; | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Parse for Api { | ||||
|     fn parse(input: ParseStream<'_>) -> syn::Result<Self> { | ||||
|         let metadata: Metadata = input.parse()?; | ||||
| 
 | ||||
|         let req_attrs = input.call(Attribute::parse_outer)?; | ||||
|         let (request, attributes) = if input.peek(kw::request) { | ||||
|             let request = parse_request(input, req_attrs)?; | ||||
|             let after_req_attrs = input.call(Attribute::parse_outer)?; | ||||
| 
 | ||||
|             (Some(request), after_req_attrs) | ||||
|         } else { | ||||
|             // There was no `request` field so the attributes are for `response`
 | ||||
|             (None, req_attrs) | ||||
|         }; | ||||
| 
 | ||||
|         #request | ||||
|         #response | ||||
|         let response = if input.peek(kw::response) { | ||||
|             Some(parse_response(input, attributes)?) | ||||
|         } else if !attributes.is_empty() { | ||||
|             return Err(syn::Error::new_spanned( | ||||
|                 &attributes[0], | ||||
|                 "attributes are not supported on the error type", | ||||
|             )); | ||||
|         } else { | ||||
|             None | ||||
|         }; | ||||
| 
 | ||||
|         #[cfg(not(any(feature = "client", feature = "server")))] | ||||
|         type _SilenceUnusedError = #error_ty; | ||||
|     }) | ||||
|         let error_ty = input | ||||
|             .peek(kw::error) | ||||
|             .then(|| { | ||||
|                 let _: kw::error = input.parse()?; | ||||
|                 let _: Token![:] = input.parse()?; | ||||
| 
 | ||||
|                 input.parse() | ||||
|             }) | ||||
|             .transpose()?; | ||||
| 
 | ||||
|         Ok(Self { metadata, request, response, error_ty }) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn parse_request(input: ParseStream<'_>, attributes: Vec<Attribute>) -> syn::Result<Request> { | ||||
|     let request_kw: kw::request = input.parse()?; | ||||
|     let _: Token![:] = input.parse()?; | ||||
|     let fields; | ||||
|     braced!(fields in input); | ||||
| 
 | ||||
|     let fields = fields.parse_terminated::<_, Token![,]>(Field::parse_named)?; | ||||
| 
 | ||||
|     Ok(Request { request_kw, attributes, fields }) | ||||
| } | ||||
| 
 | ||||
| fn parse_response(input: ParseStream<'_>, attributes: Vec<Attribute>) -> syn::Result<Response> { | ||||
|     let response_kw: kw::response = input.parse()?; | ||||
|     let _: Token![:] = input.parse()?; | ||||
|     let fields; | ||||
|     braced!(fields in input); | ||||
| 
 | ||||
|     let fields = fields.parse_terminated::<_, Token![,]>(Field::parse_named)?; | ||||
| 
 | ||||
|     Ok(Response { attributes, fields, response_kw }) | ||||
| } | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| //! Details of the `metadata` section of the procedural macro.
 | ||||
| 
 | ||||
| use proc_macro2::TokenStream; | ||||
| use quote::ToTokens; | ||||
| use syn::{ | ||||
|     braced, | ||||
| @ -8,7 +7,7 @@ use syn::{ | ||||
|     Attribute, Ident, LitBool, LitStr, Token, | ||||
| }; | ||||
| 
 | ||||
| use crate::util; | ||||
| use crate::{auth_scheme::AuthScheme, util}; | ||||
| 
 | ||||
| mod kw { | ||||
|     syn::custom_keyword!(metadata); | ||||
| @ -18,11 +17,6 @@ mod kw { | ||||
|     syn::custom_keyword!(path); | ||||
|     syn::custom_keyword!(rate_limited); | ||||
|     syn::custom_keyword!(authentication); | ||||
| 
 | ||||
|     syn::custom_keyword!(None); | ||||
|     syn::custom_keyword!(AccessToken); | ||||
|     syn::custom_keyword!(ServerSignatures); | ||||
|     syn::custom_keyword!(QueryOnlyAccessToken); | ||||
| } | ||||
| 
 | ||||
| /// A field of Metadata that contains attribute macros
 | ||||
| @ -124,42 +118,6 @@ impl Parse for Metadata { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub enum AuthScheme { | ||||
|     None(kw::None), | ||||
|     AccessToken(kw::AccessToken), | ||||
|     ServerSignatures(kw::ServerSignatures), | ||||
|     QueryOnlyAccessToken(kw::QueryOnlyAccessToken), | ||||
| } | ||||
| 
 | ||||
| impl Parse for AuthScheme { | ||||
|     fn parse(input: ParseStream<'_>) -> syn::Result<Self> { | ||||
|         let lookahead = input.lookahead1(); | ||||
| 
 | ||||
|         if lookahead.peek(kw::None) { | ||||
|             input.parse().map(Self::None) | ||||
|         } else if lookahead.peek(kw::AccessToken) { | ||||
|             input.parse().map(Self::AccessToken) | ||||
|         } else if lookahead.peek(kw::ServerSignatures) { | ||||
|             input.parse().map(Self::ServerSignatures) | ||||
|         } else if lookahead.peek(kw::QueryOnlyAccessToken) { | ||||
|             input.parse().map(Self::QueryOnlyAccessToken) | ||||
|         } else { | ||||
|             Err(lookahead.error()) | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl ToTokens for AuthScheme { | ||||
|     fn to_tokens(&self, tokens: &mut TokenStream) { | ||||
|         match self { | ||||
|             AuthScheme::None(kw) => kw.to_tokens(tokens), | ||||
|             AuthScheme::AccessToken(kw) => kw.to_tokens(tokens), | ||||
|             AuthScheme::ServerSignatures(kw) => kw.to_tokens(tokens), | ||||
|             AuthScheme::QueryOnlyAccessToken(kw) => kw.to_tokens(tokens), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| enum Field { | ||||
|     Description, | ||||
|     Method, | ||||
|  | ||||
| @ -1,350 +0,0 @@ | ||||
| use std::{collections::BTreeSet, mem}; | ||||
| 
 | ||||
| use syn::{ | ||||
|     braced, | ||||
|     parse::{Parse, ParseStream}, | ||||
|     spanned::Spanned, | ||||
|     visit::Visit, | ||||
|     Attribute, Field, Ident, Lifetime, Token, Type, | ||||
| }; | ||||
| 
 | ||||
| use super::{ | ||||
|     attribute::{Meta, MetaNameValue}, | ||||
|     request::{RequestField, RequestFieldKind, RequestLifetimes}, | ||||
|     response::{ResponseField, ResponseFieldKind}, | ||||
|     Api, Metadata, Request, Response, | ||||
| }; | ||||
| 
 | ||||
| mod kw { | ||||
|     use syn::custom_keyword; | ||||
| 
 | ||||
|     custom_keyword!(error); | ||||
|     custom_keyword!(request); | ||||
|     custom_keyword!(response); | ||||
| } | ||||
| 
 | ||||
| impl Parse for Api { | ||||
|     fn parse(input: ParseStream<'_>) -> syn::Result<Self> { | ||||
|         let metadata: Metadata = input.parse()?; | ||||
| 
 | ||||
|         let req_attrs = input.call(Attribute::parse_outer)?; | ||||
|         let (request, attributes) = if input.peek(kw::request) { | ||||
|             let request = parse_request(input, req_attrs)?; | ||||
|             let after_req_attrs = input.call(Attribute::parse_outer)?; | ||||
| 
 | ||||
|             (Some(request), after_req_attrs) | ||||
|         } else { | ||||
|             // There was no `request` field so the attributes are for `response`
 | ||||
|             (None, req_attrs) | ||||
|         }; | ||||
| 
 | ||||
|         let response = if input.peek(kw::response) { | ||||
|             Some(parse_response(input, attributes)?) | ||||
|         } else if !attributes.is_empty() { | ||||
|             return Err(syn::Error::new_spanned( | ||||
|                 &attributes[0], | ||||
|                 "attributes are not supported on the error type", | ||||
|             )); | ||||
|         } else { | ||||
|             None | ||||
|         }; | ||||
| 
 | ||||
|         let error_ty = input | ||||
|             .peek(kw::error) | ||||
|             .then(|| { | ||||
|                 let _: kw::error = input.parse()?; | ||||
|                 let _: Token![:] = input.parse()?; | ||||
| 
 | ||||
|                 input.parse() | ||||
|             }) | ||||
|             .transpose()?; | ||||
| 
 | ||||
|         if let Some(req) = &request { | ||||
|             let newtype_body_field = req.newtype_body_field(); | ||||
|             if metadata.method == "GET" && (req.has_body_fields() || newtype_body_field.is_some()) { | ||||
|                 let mut combined_error: Option<syn::Error> = None; | ||||
|                 let mut add_error = |field| { | ||||
|                     let error = | ||||
|                         syn::Error::new_spanned(field, "GET endpoints can't have body fields"); | ||||
|                     if let Some(combined_error_ref) = &mut combined_error { | ||||
|                         combined_error_ref.combine(error); | ||||
|                     } else { | ||||
|                         combined_error = Some(error); | ||||
|                     } | ||||
|                 }; | ||||
| 
 | ||||
|                 for field in req.body_fields() { | ||||
|                     add_error(field); | ||||
|                 } | ||||
| 
 | ||||
|                 if let Some(field) = newtype_body_field { | ||||
|                     add_error(field); | ||||
|                 } | ||||
| 
 | ||||
|                 return Err(combined_error.unwrap()); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         Ok(Self { metadata, request, response, error_ty }) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn parse_request(input: ParseStream<'_>, attributes: Vec<Attribute>) -> syn::Result<Request> { | ||||
|     let request_kw: kw::request = input.parse()?; | ||||
|     let _: Token![:] = input.parse()?; | ||||
|     let fields; | ||||
|     braced!(fields in input); | ||||
| 
 | ||||
|     let mut newtype_body_field = None; | ||||
|     let mut query_map_field = None; | ||||
|     let mut lifetimes = RequestLifetimes::default(); | ||||
| 
 | ||||
|     let fields: Vec<_> = fields | ||||
|         .parse_terminated::<Field, Token![,]>(Field::parse_named)? | ||||
|         .into_iter() | ||||
|         .map(|mut field| { | ||||
|             let mut field_kind = None; | ||||
|             let mut header = None; | ||||
| 
 | ||||
|             for attr in mem::take(&mut field.attrs) { | ||||
|                 let meta = match Meta::from_attribute(&attr)? { | ||||
|                     Some(m) => m, | ||||
|                     None => { | ||||
|                         field.attrs.push(attr); | ||||
|                         continue; | ||||
|                     } | ||||
|                 }; | ||||
| 
 | ||||
|                 if field_kind.is_some() { | ||||
|                     return Err(syn::Error::new_spanned( | ||||
|                         attr, | ||||
|                         "There can only be one field kind attribute", | ||||
|                     )); | ||||
|                 } | ||||
| 
 | ||||
|                 field_kind = Some(match meta { | ||||
|                     Meta::Word(ident) => match &ident.to_string()[..] { | ||||
|                         attr @ "body" | attr @ "raw_body" => req_res_meta_word( | ||||
|                             attr, | ||||
|                             &field, | ||||
|                             &mut newtype_body_field, | ||||
|                             RequestFieldKind::NewtypeBody, | ||||
|                             RequestFieldKind::NewtypeRawBody, | ||||
|                         )?, | ||||
|                         "path" => RequestFieldKind::Path, | ||||
|                         "query" => RequestFieldKind::Query, | ||||
|                         "query_map" => { | ||||
|                             if let Some(f) = &query_map_field { | ||||
|                                 let mut error = syn::Error::new_spanned( | ||||
|                                     field, | ||||
|                                     "There can only be one query map field", | ||||
|                                 ); | ||||
|                                 error.combine(syn::Error::new_spanned( | ||||
|                                     f, | ||||
|                                     "Previous query map field", | ||||
|                                 )); | ||||
|                                 return Err(error); | ||||
|                             } | ||||
| 
 | ||||
|                             query_map_field = Some(field.clone()); | ||||
|                             RequestFieldKind::QueryMap | ||||
|                         } | ||||
|                         _ => { | ||||
|                             return Err(syn::Error::new_spanned( | ||||
|                                 ident, | ||||
|                                 "Invalid #[ruma_api] argument, expected one of \ | ||||
|                                      `body`, `path`, `query`, `query_map`",
 | ||||
|                             )); | ||||
|                         } | ||||
|                     }, | ||||
|                     Meta::NameValue(MetaNameValue { name, value }) => { | ||||
|                         req_res_name_value(name, value, &mut header, RequestFieldKind::Header)? | ||||
|                     } | ||||
|                 }); | ||||
|             } | ||||
| 
 | ||||
|             match field_kind.unwrap_or(RequestFieldKind::Body) { | ||||
|                 RequestFieldKind::Header => { | ||||
|                     collect_lifetime_idents(&mut lifetimes.header, &field.ty) | ||||
|                 } | ||||
|                 RequestFieldKind::Body => collect_lifetime_idents(&mut lifetimes.body, &field.ty), | ||||
|                 RequestFieldKind::NewtypeBody => { | ||||
|                     collect_lifetime_idents(&mut lifetimes.body, &field.ty) | ||||
|                 } | ||||
|                 RequestFieldKind::NewtypeRawBody => { | ||||
|                     collect_lifetime_idents(&mut lifetimes.body, &field.ty) | ||||
|                 } | ||||
|                 RequestFieldKind::Path => collect_lifetime_idents(&mut lifetimes.path, &field.ty), | ||||
|                 RequestFieldKind::Query => collect_lifetime_idents(&mut lifetimes.query, &field.ty), | ||||
|                 RequestFieldKind::QueryMap => { | ||||
|                     collect_lifetime_idents(&mut lifetimes.query, &field.ty) | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             Ok(RequestField::new(field_kind.unwrap_or(RequestFieldKind::Body), field, header)) | ||||
|         }) | ||||
|         .collect::<syn::Result<_>>()?; | ||||
| 
 | ||||
|     if newtype_body_field.is_some() && fields.iter().any(|f| f.is_body()) { | ||||
|         // TODO: highlight conflicting fields,
 | ||||
|         return Err(syn::Error::new_spanned( | ||||
|             request_kw, | ||||
|             "Can't have both a newtype body field and regular body fields", | ||||
|         )); | ||||
|     } | ||||
| 
 | ||||
|     if query_map_field.is_some() && fields.iter().any(|f| f.is_query()) { | ||||
|         return Err(syn::Error::new_spanned( | ||||
|             // TODO: raw,
 | ||||
|             request_kw, | ||||
|             "Can't have both a query map field and regular query fields", | ||||
|         )); | ||||
|     } | ||||
| 
 | ||||
|     // TODO when/if `&[(&str, &str)]` is supported remove this
 | ||||
|     if query_map_field.is_some() && !lifetimes.query.is_empty() { | ||||
|         return Err(syn::Error::new_spanned( | ||||
|             request_kw, | ||||
|             "Lifetimes are not allowed for query_map fields", | ||||
|         )); | ||||
|     } | ||||
| 
 | ||||
|     Ok(Request { attributes, fields, lifetimes }) | ||||
| } | ||||
| 
 | ||||
| fn parse_response(input: ParseStream<'_>, attributes: Vec<Attribute>) -> syn::Result<Response> { | ||||
|     let response_kw: kw::response = input.parse()?; | ||||
|     let _: Token![:] = input.parse()?; | ||||
|     let fields; | ||||
|     braced!(fields in input); | ||||
| 
 | ||||
|     let mut newtype_body_field = None; | ||||
| 
 | ||||
|     let fields: Vec<_> = fields | ||||
|         .parse_terminated::<Field, Token![,]>(Field::parse_named)? | ||||
|         .into_iter() | ||||
|         .map(|mut field| { | ||||
|             if has_lifetime(&field.ty) { | ||||
|                 return Err(syn::Error::new( | ||||
|                     field.ident.span(), | ||||
|                     "Lifetimes on Response fields cannot be supported until GAT are stable", | ||||
|                 )); | ||||
|             } | ||||
| 
 | ||||
|             let mut field_kind = None; | ||||
|             let mut header = None; | ||||
| 
 | ||||
|             for attr in mem::take(&mut field.attrs) { | ||||
|                 let meta = match Meta::from_attribute(&attr)? { | ||||
|                     Some(m) => m, | ||||
|                     None => { | ||||
|                         field.attrs.push(attr); | ||||
|                         continue; | ||||
|                     } | ||||
|                 }; | ||||
| 
 | ||||
|                 if field_kind.is_some() { | ||||
|                     return Err(syn::Error::new_spanned( | ||||
|                         attr, | ||||
|                         "There can only be one field kind attribute", | ||||
|                     )); | ||||
|                 } | ||||
| 
 | ||||
|                 field_kind = Some(match meta { | ||||
|                     Meta::Word(ident) => match &ident.to_string()[..] { | ||||
|                         s @ "body" | s @ "raw_body" => req_res_meta_word( | ||||
|                             s, | ||||
|                             &field, | ||||
|                             &mut newtype_body_field, | ||||
|                             ResponseFieldKind::NewtypeBody, | ||||
|                             ResponseFieldKind::NewtypeRawBody, | ||||
|                         )?, | ||||
|                         _ => { | ||||
|                             return Err(syn::Error::new_spanned( | ||||
|                                 ident, | ||||
|                                 "Invalid #[ruma_api] argument with value, expected `body`", | ||||
|                             )); | ||||
|                         } | ||||
|                     }, | ||||
|                     Meta::NameValue(MetaNameValue { name, value }) => { | ||||
|                         req_res_name_value(name, value, &mut header, ResponseFieldKind::Header)? | ||||
|                     } | ||||
|                 }); | ||||
|             } | ||||
| 
 | ||||
|             Ok(match field_kind.unwrap_or(ResponseFieldKind::Body) { | ||||
|                 ResponseFieldKind::Body => ResponseField::Body(field), | ||||
|                 ResponseFieldKind::Header => { | ||||
|                     ResponseField::Header(field, header.expect("missing header name")) | ||||
|                 } | ||||
|                 ResponseFieldKind::NewtypeBody => ResponseField::NewtypeBody(field), | ||||
|                 ResponseFieldKind::NewtypeRawBody => ResponseField::NewtypeRawBody(field), | ||||
|             }) | ||||
|         }) | ||||
|         .collect::<syn::Result<_>>()?; | ||||
| 
 | ||||
|     if newtype_body_field.is_some() && fields.iter().any(|f| f.is_body()) { | ||||
|         // TODO: highlight conflicting fields,
 | ||||
|         return Err(syn::Error::new_spanned( | ||||
|             response_kw, | ||||
|             "Can't have both a newtype body field and regular body fields", | ||||
|         )); | ||||
|     } | ||||
| 
 | ||||
|     Ok(Response { attributes, fields }) | ||||
| } | ||||
| 
 | ||||
| fn has_lifetime(ty: &Type) -> bool { | ||||
|     let mut lifetimes = BTreeSet::new(); | ||||
|     collect_lifetime_idents(&mut lifetimes, ty); | ||||
|     !lifetimes.is_empty() | ||||
| } | ||||
| 
 | ||||
| fn collect_lifetime_idents(lifetimes: &mut BTreeSet<Lifetime>, ty: &Type) { | ||||
|     struct Visitor<'lt>(&'lt mut BTreeSet<Lifetime>); | ||||
|     impl<'ast> Visit<'ast> for Visitor<'_> { | ||||
|         fn visit_lifetime(&mut self, lt: &'ast Lifetime) { | ||||
|             self.0.insert(lt.clone()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     Visitor(lifetimes).visit_type(ty) | ||||
| } | ||||
| 
 | ||||
| fn req_res_meta_word<T>( | ||||
|     attr_kind: &str, | ||||
|     field: &Field, | ||||
|     newtype_body_field: &mut Option<Field>, | ||||
|     body_field_kind: T, | ||||
|     raw_field_kind: T, | ||||
| ) -> syn::Result<T> { | ||||
|     if let Some(f) = &newtype_body_field { | ||||
|         let mut error = syn::Error::new_spanned(field, "There can only be one newtype body field"); | ||||
|         error.combine(syn::Error::new_spanned(f, "Previous newtype body field")); | ||||
|         return Err(error); | ||||
|     } | ||||
| 
 | ||||
|     *newtype_body_field = Some(field.clone()); | ||||
|     Ok(match attr_kind { | ||||
|         "body" => body_field_kind, | ||||
|         "raw_body" => raw_field_kind, | ||||
|         _ => unreachable!(), | ||||
|     }) | ||||
| } | ||||
| 
 | ||||
| fn req_res_name_value<T>( | ||||
|     name: Ident, | ||||
|     value: Ident, | ||||
|     header: &mut Option<Ident>, | ||||
|     field_kind: T, | ||||
| ) -> syn::Result<T> { | ||||
|     if name != "header" { | ||||
|         return Err(syn::Error::new_spanned( | ||||
|             name, | ||||
|             "Invalid #[ruma_api] argument with value, expected `header`", | ||||
|         )); | ||||
|     } | ||||
| 
 | ||||
|     *header = Some(value); | ||||
|     Ok(field_kind) | ||||
| } | ||||
| @ -1,129 +1,67 @@ | ||||
| //! Details of the `request` section of the procedural macro.
 | ||||
| 
 | ||||
| use std::collections::BTreeSet; | ||||
| use std::collections::btree_map::{BTreeMap, Entry}; | ||||
| 
 | ||||
| use proc_macro2::TokenStream; | ||||
| use quote::quote; | ||||
| use syn::{Attribute, Field, Ident, Lifetime}; | ||||
| use syn::{ | ||||
|     parse_quote, punctuated::Punctuated, spanned::Spanned, visit::Visit, Attribute, Field, Ident, | ||||
|     Lifetime, Token, | ||||
| }; | ||||
| 
 | ||||
| use crate::util; | ||||
| 
 | ||||
| use super::metadata::Metadata; | ||||
| 
 | ||||
| mod incoming; | ||||
| mod outgoing; | ||||
| 
 | ||||
| #[derive(Default)] | ||||
| pub(super) struct RequestLifetimes { | ||||
|     pub body: BTreeSet<Lifetime>, | ||||
|     pub path: BTreeSet<Lifetime>, | ||||
|     pub query: BTreeSet<Lifetime>, | ||||
|     pub header: BTreeSet<Lifetime>, | ||||
| } | ||||
| use super::{kw, metadata::Metadata}; | ||||
| use crate::util::{all_cfgs, all_cfgs_expr, extract_cfg}; | ||||
| 
 | ||||
| /// The result of processing the `request` section of the macro.
 | ||||
| pub(crate) struct Request { | ||||
|     /// The `request` keyword
 | ||||
|     pub(super) request_kw: kw::request, | ||||
| 
 | ||||
|     /// The attributes that will be applied to the struct definition.
 | ||||
|     pub(super) attributes: Vec<Attribute>, | ||||
| 
 | ||||
|     /// The fields of the request.
 | ||||
|     pub(super) fields: Vec<RequestField>, | ||||
| 
 | ||||
|     /// The collected lifetime identifiers from the declared fields.
 | ||||
|     pub(super) lifetimes: RequestLifetimes, | ||||
|     pub(super) fields: Punctuated<Field, Token![,]>, | ||||
| } | ||||
| 
 | ||||
| impl Request { | ||||
|     /// Whether or not this request has any data in the HTTP body.
 | ||||
|     pub(super) fn has_body_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|field| field.is_body()) | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this request has any data in HTTP headers.
 | ||||
|     fn has_header_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|field| field.is_header()) | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this request has any data in the URL path.
 | ||||
|     fn has_path_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|field| field.is_path()) | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this request has any data in the query string.
 | ||||
|     fn has_query_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|field| field.is_query()) | ||||
|     } | ||||
| 
 | ||||
|     /// Produces an iterator over all the body fields.
 | ||||
|     pub(super) fn body_fields(&self) -> impl Iterator<Item = &Field> { | ||||
|         self.fields.iter().filter_map(|field| field.as_body_field()) | ||||
|     } | ||||
| 
 | ||||
|     /// Whether any `body` field has a lifetime annotation.
 | ||||
|     fn has_body_lifetimes(&self) -> bool { | ||||
|         !self.lifetimes.body.is_empty() | ||||
|     } | ||||
| 
 | ||||
|     /// Whether any `query` field has a lifetime annotation.
 | ||||
|     fn has_query_lifetimes(&self) -> bool { | ||||
|         !self.lifetimes.query.is_empty() | ||||
|     } | ||||
| 
 | ||||
|     /// Whether any field has a lifetime.
 | ||||
|     fn contains_lifetimes(&self) -> bool { | ||||
|         !(self.lifetimes.body.is_empty() | ||||
|             && self.lifetimes.path.is_empty() | ||||
|             && self.lifetimes.query.is_empty() | ||||
|             && self.lifetimes.header.is_empty()) | ||||
|     } | ||||
| 
 | ||||
|     /// The combination of every fields unique lifetime annotation.
 | ||||
|     fn combine_lifetimes(&self) -> TokenStream { | ||||
|         util::unique_lifetimes_to_tokens( | ||||
|             [ | ||||
|                 &self.lifetimes.body, | ||||
|                 &self.lifetimes.path, | ||||
|                 &self.lifetimes.query, | ||||
|                 &self.lifetimes.header, | ||||
|             ] | ||||
|             .iter() | ||||
|             .flat_map(|set| set.iter()), | ||||
|         ) | ||||
|     } | ||||
|     fn all_lifetimes(&self) -> BTreeMap<Lifetime, Option<Attribute>> { | ||||
|         let mut lifetimes = BTreeMap::new(); | ||||
| 
 | ||||
|     /// The lifetimes on fields with the `query` attribute.
 | ||||
|     fn query_lifetimes(&self) -> TokenStream { | ||||
|         util::unique_lifetimes_to_tokens(&self.lifetimes.query) | ||||
|     } | ||||
|         struct Visitor<'lt> { | ||||
|             field_cfg: Option<Attribute>, | ||||
|             lifetimes: &'lt mut BTreeMap<Lifetime, Option<Attribute>>, | ||||
|         } | ||||
| 
 | ||||
|     /// The lifetimes on fields with the `body` attribute.
 | ||||
|     fn body_lifetimes(&self) -> TokenStream { | ||||
|         util::unique_lifetimes_to_tokens(&self.lifetimes.body) | ||||
|     } | ||||
|         impl<'ast> Visit<'ast> for Visitor<'_> { | ||||
|             fn visit_lifetime(&mut self, lt: &'ast Lifetime) { | ||||
|                 match self.lifetimes.entry(lt.clone()) { | ||||
|                     Entry::Vacant(v) => { | ||||
|                         v.insert(self.field_cfg.clone()); | ||||
|                     } | ||||
|                     Entry::Occupied(mut o) => { | ||||
|                         let lifetime_cfg = o.get_mut(); | ||||
| 
 | ||||
|     /// Produces an iterator over all the header fields.
 | ||||
|     fn header_fields(&self) -> impl Iterator<Item = &RequestField> { | ||||
|         self.fields.iter().filter(|field| field.is_header()) | ||||
|     } | ||||
|                         // If at least one field uses this lifetime and has no cfg attribute, we
 | ||||
|                         // don't need a cfg attribute for the lifetime either.
 | ||||
|                         *lifetime_cfg = Option::zip(lifetime_cfg.as_ref(), self.field_cfg.as_ref()) | ||||
|                             .map(|(a, b)| { | ||||
|                                 let expr_a = extract_cfg(a); | ||||
|                                 let expr_b = extract_cfg(b); | ||||
|                                 parse_quote! { #[cfg( any( #expr_a, #expr_b ) )] } | ||||
|                             }); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|     /// Gets the number of path fields.
 | ||||
|     fn path_field_count(&self) -> usize { | ||||
|         self.fields.iter().filter(|field| field.is_path()).count() | ||||
|     } | ||||
|         for field in &self.fields { | ||||
|             let field_cfg = if field.attrs.is_empty() { None } else { all_cfgs(&field.attrs) }; | ||||
|             Visitor { lifetimes: &mut lifetimes, field_cfg }.visit_type(&field.ty); | ||||
|         } | ||||
| 
 | ||||
|     /// Returns the body field.
 | ||||
|     pub fn newtype_body_field(&self) -> Option<&Field> { | ||||
|         self.fields.iter().find_map(RequestField::as_newtype_body_field) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns the body field.
 | ||||
|     fn newtype_raw_body_field(&self) -> Option<&Field> { | ||||
|         self.fields.iter().find_map(RequestField::as_newtype_raw_body_field) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns the query map field.
 | ||||
|     fn query_map_field(&self) -> Option<&Field> { | ||||
|         self.fields.iter().find_map(RequestField::as_query_map_field) | ||||
|         lifetimes | ||||
|     } | ||||
| 
 | ||||
|     pub(super) fn expand( | ||||
| @ -132,8 +70,8 @@ impl Request { | ||||
|         error_ty: &TokenStream, | ||||
|         ruma_api: &TokenStream, | ||||
|     ) -> TokenStream { | ||||
|         let ruma_api_macros = quote! { #ruma_api::exports::ruma_api_macros }; | ||||
|         let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; | ||||
|         let serde = quote! { #ruma_api::exports::serde }; | ||||
| 
 | ||||
|         let docs = format!( | ||||
|             "Data for a request to the `{}` API endpoint.\n\n{}", | ||||
| @ -142,239 +80,44 @@ impl Request { | ||||
|         ); | ||||
|         let struct_attributes = &self.attributes; | ||||
| 
 | ||||
|         let request_body_struct = | ||||
|             if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) { | ||||
|                 let field = Field { ident: None, colon_token: None, ..body_field.field().clone() }; | ||||
|                 // Though we don't track the difference between new type body and body
 | ||||
|                 // for lifetimes, the outer check and the macro failing if it encounters
 | ||||
|                 // an illegal combination of field attributes, is enough to guarantee
 | ||||
|                 // `body_lifetimes` correctness.
 | ||||
|                 let (derive_deserialize, lifetimes) = if self.has_body_lifetimes() { | ||||
|                     (TokenStream::new(), self.body_lifetimes()) | ||||
|                 } else { | ||||
|                     (quote! { #serde::Deserialize }, TokenStream::new()) | ||||
|                 }; | ||||
|         let method = &metadata.method; | ||||
|         let path = &metadata.path; | ||||
|         let auth_attributes = metadata.authentication.iter().map(|field| { | ||||
|             let cfg_expr = all_cfgs_expr(&field.attrs); | ||||
|             let value = &field.value; | ||||
| 
 | ||||
|                 Some((derive_deserialize, quote! { #lifetimes (#field); })) | ||||
|             } else if self.has_body_fields() { | ||||
|                 let fields = self.fields.iter().filter(|f| f.is_body()); | ||||
|                 let (derive_deserialize, lifetimes) = if self.has_body_lifetimes() { | ||||
|                     (TokenStream::new(), self.body_lifetimes()) | ||||
|                 } else { | ||||
|                     (quote! { #serde::Deserialize }, TokenStream::new()) | ||||
|                 }; | ||||
|                 let fields = fields.map(RequestField::field); | ||||
| 
 | ||||
|                 Some((derive_deserialize, quote! { #lifetimes { #(#fields),* } })) | ||||
|             } else { | ||||
|                 None | ||||
|             match cfg_expr { | ||||
|                 Some(expr) => quote! { #[cfg_attr(#expr, ruma_api(authentication = #value))] }, | ||||
|                 None => quote! { #[ruma_api(authentication = #value)] }, | ||||
|             } | ||||
|             .map(|(derive_deserialize, def)| { | ||||
|                 quote! { | ||||
|                     /// Data in the request body.
 | ||||
|                     #[derive(
 | ||||
|                         Debug, | ||||
|                         #ruma_serde::Outgoing, | ||||
|                         #serde::Serialize, | ||||
|                         #derive_deserialize | ||||
|                     )] | ||||
|                     struct RequestBody #def | ||||
|                 } | ||||
|             }); | ||||
|         }); | ||||
| 
 | ||||
|         let request_query_struct = if let Some(f) = self.query_map_field() { | ||||
|             let field = Field { ident: None, colon_token: None, ..f.clone() }; | ||||
|             let (derive_deserialize, lifetime) = if self.has_query_lifetimes() { | ||||
|                 (TokenStream::new(), self.query_lifetimes()) | ||||
|             } else { | ||||
|                 (quote! { #serde::Deserialize }, TokenStream::new()) | ||||
|             }; | ||||
| 
 | ||||
|             quote! { | ||||
|                 /// Data in the request's query string.
 | ||||
|                 #[derive(
 | ||||
|                     Debug, | ||||
|                     #ruma_serde::Outgoing, | ||||
|                     #serde::Serialize, | ||||
|                     #derive_deserialize | ||||
|                 )] | ||||
|                 struct RequestQuery #lifetime (#field); | ||||
|             } | ||||
|         } else if self.has_query_fields() { | ||||
|             let fields = self.fields.iter().filter_map(RequestField::as_query_field); | ||||
|             let (derive_deserialize, lifetime) = if self.has_query_lifetimes() { | ||||
|                 (TokenStream::new(), self.query_lifetimes()) | ||||
|             } else { | ||||
|                 (quote! { #serde::Deserialize }, TokenStream::new()) | ||||
|             }; | ||||
| 
 | ||||
|             quote! { | ||||
|                 /// Data in the request's query string.
 | ||||
|                 #[derive(
 | ||||
|                     Debug, | ||||
|                     #ruma_serde::Outgoing, | ||||
|                     #serde::Serialize, | ||||
|                     #derive_deserialize | ||||
|                 )] | ||||
|                 struct RequestQuery #lifetime { | ||||
|                     #(#fields),* | ||||
|                 } | ||||
|             } | ||||
|         } else { | ||||
|             TokenStream::new() | ||||
|         }; | ||||
| 
 | ||||
|         let lifetimes = self.combine_lifetimes(); | ||||
|         let fields = self.fields.iter().map(|request_field| request_field.field()); | ||||
| 
 | ||||
|         let outgoing_request_impl = self.expand_outgoing(metadata, error_ty, &lifetimes, ruma_api); | ||||
|         let incoming_request_impl = self.expand_incoming(metadata, error_ty, ruma_api); | ||||
|         let request_ident = Ident::new("Request", self.request_kw.span()); | ||||
|         let lifetimes = self.all_lifetimes(); | ||||
|         let lifetimes = lifetimes.iter().map(|(lt, attr)| quote! { #attr #lt }); | ||||
|         let fields = &self.fields; | ||||
| 
 | ||||
|         quote! { | ||||
|             #[doc = #docs] | ||||
|             #[derive(Debug, Clone, #ruma_serde::Outgoing, #ruma_serde::_FakeDeriveSerde)] | ||||
|             #[derive(
 | ||||
|                 Clone, | ||||
|                 Debug, | ||||
|                 #ruma_api_macros::Request, | ||||
|                 #ruma_serde::Outgoing, | ||||
|                 #ruma_serde::_FakeDeriveSerde, | ||||
|             )] | ||||
|             #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] | ||||
|             #[incoming_derive(!Deserialize)] | ||||
|             #[incoming_derive(!Deserialize, #ruma_api_macros::_FakeDeriveRumaApi)] | ||||
|             #[ruma_api(
 | ||||
|                 method = #method, | ||||
|                 path = #path, | ||||
|                 error_ty = #error_ty, | ||||
|             )] | ||||
|             #( #auth_attributes )* | ||||
|             #( #struct_attributes )* | ||||
|             pub struct Request #lifetimes { | ||||
|                 #(#fields),* | ||||
|             pub struct #request_ident < #(#lifetimes),* > { | ||||
|                 #fields | ||||
|             } | ||||
| 
 | ||||
|             #request_body_struct | ||||
|             #request_query_struct | ||||
| 
 | ||||
|             #outgoing_request_impl | ||||
|             #incoming_request_impl | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The types of fields that a request can have.
 | ||||
| pub(crate) enum RequestField { | ||||
|     /// JSON data in the body of the request.
 | ||||
|     Body(Field), | ||||
| 
 | ||||
|     /// Data in an HTTP header.
 | ||||
|     Header(Field, Ident), | ||||
| 
 | ||||
|     /// A specific data type in the body of the request.
 | ||||
|     NewtypeBody(Field), | ||||
| 
 | ||||
|     /// Arbitrary bytes in the body of the request.
 | ||||
|     NewtypeRawBody(Field), | ||||
| 
 | ||||
|     /// Data that appears in the URL path.
 | ||||
|     Path(Field), | ||||
| 
 | ||||
|     /// Data that appears in the query string.
 | ||||
|     Query(Field), | ||||
| 
 | ||||
|     /// Data that appears in the query string as dynamic key-value pairs.
 | ||||
|     QueryMap(Field), | ||||
| } | ||||
| 
 | ||||
| impl RequestField { | ||||
|     /// Creates a new `RequestField`.
 | ||||
|     pub(super) fn new(kind: RequestFieldKind, field: Field, header: Option<Ident>) -> Self { | ||||
|         match kind { | ||||
|             RequestFieldKind::Body => RequestField::Body(field), | ||||
|             RequestFieldKind::Header => { | ||||
|                 RequestField::Header(field, header.expect("missing header name")) | ||||
|             } | ||||
|             RequestFieldKind::NewtypeBody => RequestField::NewtypeBody(field), | ||||
|             RequestFieldKind::NewtypeRawBody => RequestField::NewtypeRawBody(field), | ||||
|             RequestFieldKind::Path => RequestField::Path(field), | ||||
|             RequestFieldKind::Query => RequestField::Query(field), | ||||
|             RequestFieldKind::QueryMap => RequestField::QueryMap(field), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this request field is a body kind.
 | ||||
|     pub(super) fn is_body(&self) -> bool { | ||||
|         matches!(self, RequestField::Body(..)) | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this request field is a header kind.
 | ||||
|     fn is_header(&self) -> bool { | ||||
|         matches!(self, RequestField::Header(..)) | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this request field is a newtype body kind.
 | ||||
|     fn is_newtype_body(&self) -> bool { | ||||
|         matches!(self, RequestField::NewtypeBody(..)) | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this request field is a path kind.
 | ||||
|     fn is_path(&self) -> bool { | ||||
|         matches!(self, RequestField::Path(..)) | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this request field is a query string kind.
 | ||||
|     pub(super) fn is_query(&self) -> bool { | ||||
|         matches!(self, RequestField::Query(..)) | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this request field is a body kind.
 | ||||
|     fn as_body_field(&self) -> Option<&Field> { | ||||
|         self.field_of_kind(RequestFieldKind::Body) | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this request field is a body kind.
 | ||||
|     fn as_newtype_body_field(&self) -> Option<&Field> { | ||||
|         self.field_of_kind(RequestFieldKind::NewtypeBody) | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this request field is a raw body kind.
 | ||||
|     fn as_newtype_raw_body_field(&self) -> Option<&Field> { | ||||
|         self.field_of_kind(RequestFieldKind::NewtypeRawBody) | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this request field is a query kind.
 | ||||
|     fn as_query_field(&self) -> Option<&Field> { | ||||
|         self.field_of_kind(RequestFieldKind::Query) | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this request field is a query map kind.
 | ||||
|     fn as_query_map_field(&self) -> Option<&Field> { | ||||
|         self.field_of_kind(RequestFieldKind::QueryMap) | ||||
|     } | ||||
| 
 | ||||
|     /// Gets the inner `Field` value.
 | ||||
|     fn field(&self) -> &Field { | ||||
|         match self { | ||||
|             RequestField::Body(field) | ||||
|             | RequestField::Header(field, _) | ||||
|             | RequestField::NewtypeBody(field) | ||||
|             | RequestField::NewtypeRawBody(field) | ||||
|             | RequestField::Path(field) | ||||
|             | RequestField::Query(field) | ||||
|             | RequestField::QueryMap(field) => field, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Gets the inner `Field` value if it's of the provided kind.
 | ||||
|     fn field_of_kind(&self, kind: RequestFieldKind) -> Option<&Field> { | ||||
|         match (self, kind) { | ||||
|             (RequestField::Body(field), RequestFieldKind::Body) | ||||
|             | (RequestField::Header(field, _), RequestFieldKind::Header) | ||||
|             | (RequestField::NewtypeBody(field), RequestFieldKind::NewtypeBody) | ||||
|             | (RequestField::NewtypeRawBody(field), RequestFieldKind::NewtypeRawBody) | ||||
|             | (RequestField::Path(field), RequestFieldKind::Path) | ||||
|             | (RequestField::Query(field), RequestFieldKind::Query) | ||||
|             | (RequestField::QueryMap(field), RequestFieldKind::QueryMap) => Some(field), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The types of fields that a request can have, without their values.
 | ||||
| #[derive(Clone, Copy, PartialEq, Eq)] | ||||
| pub(crate) enum RequestFieldKind { | ||||
|     Body, | ||||
|     Header, | ||||
|     NewtypeBody, | ||||
|     NewtypeRawBody, | ||||
|     Path, | ||||
|     Query, | ||||
|     QueryMap, | ||||
| } | ||||
|  | ||||
| @ -2,76 +2,40 @@ | ||||
| 
 | ||||
| use proc_macro2::TokenStream; | ||||
| use quote::quote; | ||||
| use syn::{Attribute, Field, Ident}; | ||||
| use syn::{punctuated::Punctuated, spanned::Spanned, Attribute, Field, Ident, Token}; | ||||
| 
 | ||||
| use super::metadata::Metadata; | ||||
| 
 | ||||
| mod incoming; | ||||
| mod outgoing; | ||||
| use super::{kw, metadata::Metadata}; | ||||
| 
 | ||||
| /// The result of processing the `response` section of the macro.
 | ||||
| pub(crate) struct Response { | ||||
|     /// The `response` keyword
 | ||||
|     pub(super) response_kw: kw::response, | ||||
| 
 | ||||
|     /// The attributes that will be applied to the struct definition.
 | ||||
|     pub attributes: Vec<Attribute>, | ||||
| 
 | ||||
|     /// The fields of the response.
 | ||||
|     pub fields: Vec<ResponseField>, | ||||
|     pub fields: Punctuated<Field, Token![,]>, | ||||
| } | ||||
| 
 | ||||
| impl Response { | ||||
|     /// Whether or not this response has any data in the HTTP body.
 | ||||
|     fn has_body_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|field| field.is_body()) | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this response has any data in HTTP headers.
 | ||||
|     fn has_header_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|field| field.is_header()) | ||||
|     } | ||||
| 
 | ||||
|     /// Gets the newtype body field, if this response has one.
 | ||||
|     fn newtype_body_field(&self) -> Option<&Field> { | ||||
|         self.fields.iter().find_map(ResponseField::as_newtype_body_field) | ||||
|     } | ||||
| 
 | ||||
|     /// Gets the newtype raw body field, if this response has one.
 | ||||
|     fn newtype_raw_body_field(&self) -> Option<&Field> { | ||||
|         self.fields.iter().find_map(ResponseField::as_newtype_raw_body_field) | ||||
|     } | ||||
| 
 | ||||
|     pub(super) fn expand( | ||||
|         &self, | ||||
|         metadata: &Metadata, | ||||
|         error_ty: &TokenStream, | ||||
|         ruma_api: &TokenStream, | ||||
|     ) -> TokenStream { | ||||
|         let ruma_api_macros = quote! { #ruma_api::exports::ruma_api_macros }; | ||||
|         let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; | ||||
|         let serde = quote! { #ruma_api::exports::serde }; | ||||
| 
 | ||||
|         let docs = | ||||
|             format!("Data in the response from the `{}` API endpoint.", metadata.name.value()); | ||||
|         let struct_attributes = &self.attributes; | ||||
| 
 | ||||
|         let def = if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) { | ||||
|             let field = Field { ident: None, colon_token: None, ..body_field.field().clone() }; | ||||
|             quote! { (#field); } | ||||
|         } else if self.has_body_fields() { | ||||
|             let fields = self.fields.iter().filter(|f| f.is_body()).map(ResponseField::field); | ||||
|             quote! { { #(#fields),* } } | ||||
|         } else { | ||||
|             quote! { {} } | ||||
|         }; | ||||
| 
 | ||||
|         let response_body_struct = quote! { | ||||
|             /// Data in the response body.
 | ||||
|             #[derive(Debug, #ruma_serde::Outgoing, #serde::Deserialize, #serde::Serialize)] | ||||
|             struct ResponseBody #def | ||||
|         }; | ||||
| 
 | ||||
|         let has_test_exhaustive_field = self | ||||
|             .fields | ||||
|             .iter() | ||||
|             .filter_map(|f| f.field().ident.as_ref()) | ||||
|             .filter_map(|f| f.ident.as_ref()) | ||||
|             .any(|ident| ident == "__test_exhaustive"); | ||||
| 
 | ||||
|         let non_exhaustive_attr = if has_test_exhaustive_field { | ||||
| @ -80,99 +44,24 @@ impl Response { | ||||
|             quote! { #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] } | ||||
|         }; | ||||
| 
 | ||||
|         let fields = self.fields.iter().map(|response_field| response_field.field()); | ||||
| 
 | ||||
|         let outgoing_response_impl = self.expand_outgoing(ruma_api); | ||||
|         let incoming_response_impl = self.expand_incoming(error_ty, ruma_api); | ||||
| 
 | ||||
|         let response_ident = Ident::new("Response", self.response_kw.span()); | ||||
|         let fields = &self.fields; | ||||
|         quote! { | ||||
|             #[doc = #docs] | ||||
|             #[derive(Debug, Clone, #ruma_serde::Outgoing, #ruma_serde::_FakeDeriveSerde)] | ||||
|             #[derive(
 | ||||
|                 Clone, | ||||
|                 Debug, | ||||
|                 #ruma_api_macros::Response, | ||||
|                 #ruma_serde::Outgoing, | ||||
|                 #ruma_serde::_FakeDeriveSerde, | ||||
|             )] | ||||
|             #non_exhaustive_attr | ||||
|             #[incoming_derive(!Deserialize)] | ||||
|             #[incoming_derive(!Deserialize, #ruma_api_macros::_FakeDeriveRumaApi)] | ||||
|             #[ruma_api(error_ty = #error_ty)] | ||||
|             #( #struct_attributes )* | ||||
|             pub struct Response { | ||||
|                 #(#fields),* | ||||
|             pub struct #response_ident { | ||||
|                 #fields | ||||
|             } | ||||
| 
 | ||||
|             #response_body_struct | ||||
| 
 | ||||
|             #outgoing_response_impl | ||||
|             #incoming_response_impl | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The types of fields that a response can have.
 | ||||
| pub(crate) enum ResponseField { | ||||
|     /// JSON data in the body of the response.
 | ||||
|     Body(Field), | ||||
| 
 | ||||
|     /// Data in an HTTP header.
 | ||||
|     Header(Field, Ident), | ||||
| 
 | ||||
|     /// A specific data type in the body of the response.
 | ||||
|     NewtypeBody(Field), | ||||
| 
 | ||||
|     /// Arbitrary bytes in the body of the response.
 | ||||
|     NewtypeRawBody(Field), | ||||
| } | ||||
| 
 | ||||
| impl ResponseField { | ||||
|     /// Gets the inner `Field` value.
 | ||||
|     fn field(&self) -> &Field { | ||||
|         match self { | ||||
|             ResponseField::Body(field) | ||||
|             | ResponseField::Header(field, _) | ||||
|             | ResponseField::NewtypeBody(field) | ||||
|             | ResponseField::NewtypeRawBody(field) => field, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this response field is a body kind.
 | ||||
|     pub(super) fn is_body(&self) -> bool { | ||||
|         self.as_body_field().is_some() | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this response field is a header kind.
 | ||||
|     fn is_header(&self) -> bool { | ||||
|         matches!(self, ResponseField::Header(..)) | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this response field is a newtype body kind.
 | ||||
|     fn is_newtype_body(&self) -> bool { | ||||
|         self.as_newtype_body_field().is_some() | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this response field is a body kind.
 | ||||
|     fn as_body_field(&self) -> Option<&Field> { | ||||
|         match self { | ||||
|             ResponseField::Body(field) => Some(field), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this response field is a newtype body kind.
 | ||||
|     fn as_newtype_body_field(&self) -> Option<&Field> { | ||||
|         match self { | ||||
|             ResponseField::NewtypeBody(field) => Some(field), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this response field is a newtype raw body kind.
 | ||||
|     fn as_newtype_raw_body_field(&self) -> Option<&Field> { | ||||
|         match self { | ||||
|             ResponseField::NewtypeRawBody(field) => Some(field), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The types of fields that a response can have, without their values.
 | ||||
| pub(crate) enum ResponseFieldKind { | ||||
|     Body, | ||||
|     Header, | ||||
|     NewtypeBody, | ||||
|     NewtypeRawBody, | ||||
| } | ||||
|  | ||||
| @ -2,17 +2,42 @@ | ||||
| 
 | ||||
| use syn::{ | ||||
|     parse::{Parse, ParseStream}, | ||||
|     Ident, Token, | ||||
|     Ident, Lit, Token, Type, | ||||
| }; | ||||
| 
 | ||||
| /// Value type used for request and response struct attributes
 | ||||
| #[allow(clippy::large_enum_variant)] | ||||
| pub enum MetaValue { | ||||
|     Lit(Lit), | ||||
|     Type(Type), | ||||
| } | ||||
| 
 | ||||
| impl Parse for MetaValue { | ||||
|     fn parse(input: ParseStream<'_>) -> syn::Result<Self> { | ||||
|         if input.peek(Lit) { | ||||
|             input.parse().map(Self::Lit) | ||||
|         } else { | ||||
|             input.parse().map(Self::Type) | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Like syn::MetaNameValue, but expects an identifier as the value. Also, we don't care about the
 | ||||
| /// the span of the equals sign, so we don't have the `eq_token` field from syn::MetaNameValue.
 | ||||
| pub struct MetaNameValue { | ||||
| pub struct MetaNameValue<V> { | ||||
|     /// The part left of the equals sign
 | ||||
|     pub name: Ident, | ||||
| 
 | ||||
|     /// The part right of the equals sign
 | ||||
|     pub value: Ident, | ||||
|     pub value: V, | ||||
| } | ||||
| 
 | ||||
| impl<V: Parse> Parse for MetaNameValue<V> { | ||||
|     fn parse(input: ParseStream<'_>) -> syn::Result<Self> { | ||||
|         let ident = input.parse()?; | ||||
|         let _: Token![=] = input.parse()?; | ||||
|         Ok(MetaNameValue { name: ident, value: input.parse()? }) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Like syn::Meta, but only parses ruma_api attributes
 | ||||
| @ -21,7 +46,7 @@ pub enum Meta { | ||||
|     Word(Ident), | ||||
| 
 | ||||
|     /// A name-value pair, like `header = CONTENT_TYPE` in `#[ruma_api(header = CONTENT_TYPE)]`
 | ||||
|     NameValue(MetaNameValue), | ||||
|     NameValue(MetaNameValue<Ident>), | ||||
| } | ||||
| 
 | ||||
| impl Meta { | ||||
							
								
								
									
										46
									
								
								crates/ruma-api-macros/src/auth_scheme.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								crates/ruma-api-macros/src/auth_scheme.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,46 @@ | ||||
| use proc_macro2::TokenStream; | ||||
| use quote::ToTokens; | ||||
| use syn::parse::{Parse, ParseStream}; | ||||
| 
 | ||||
| mod kw { | ||||
|     syn::custom_keyword!(None); | ||||
|     syn::custom_keyword!(AccessToken); | ||||
|     syn::custom_keyword!(ServerSignatures); | ||||
|     syn::custom_keyword!(QueryOnlyAccessToken); | ||||
| } | ||||
| 
 | ||||
| pub enum AuthScheme { | ||||
|     None(kw::None), | ||||
|     AccessToken(kw::AccessToken), | ||||
|     ServerSignatures(kw::ServerSignatures), | ||||
|     QueryOnlyAccessToken(kw::QueryOnlyAccessToken), | ||||
| } | ||||
| 
 | ||||
| impl Parse for AuthScheme { | ||||
|     fn parse(input: ParseStream<'_>) -> syn::Result<Self> { | ||||
|         let lookahead = input.lookahead1(); | ||||
| 
 | ||||
|         if lookahead.peek(kw::None) { | ||||
|             input.parse().map(Self::None) | ||||
|         } else if lookahead.peek(kw::AccessToken) { | ||||
|             input.parse().map(Self::AccessToken) | ||||
|         } else if lookahead.peek(kw::ServerSignatures) { | ||||
|             input.parse().map(Self::ServerSignatures) | ||||
|         } else if lookahead.peek(kw::QueryOnlyAccessToken) { | ||||
|             input.parse().map(Self::QueryOnlyAccessToken) | ||||
|         } else { | ||||
|             Err(lookahead.error()) | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl ToTokens for AuthScheme { | ||||
|     fn to_tokens(&self, tokens: &mut TokenStream) { | ||||
|         match self { | ||||
|             AuthScheme::None(kw) => kw.to_tokens(tokens), | ||||
|             AuthScheme::AccessToken(kw) => kw.to_tokens(tokens), | ||||
|             AuthScheme::ServerSignatures(kw) => kw.to_tokens(tokens), | ||||
|             AuthScheme::QueryOnlyAccessToken(kw) => kw.to_tokens(tokens), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -11,15 +11,44 @@ | ||||
| #![recursion_limit = "256"] | ||||
| 
 | ||||
| use proc_macro::TokenStream; | ||||
| use syn::parse_macro_input; | ||||
| 
 | ||||
| use self::api::Api; | ||||
| use syn::{parse_macro_input, DeriveInput}; | ||||
| 
 | ||||
| mod api; | ||||
| mod attribute; | ||||
| mod auth_scheme; | ||||
| mod request; | ||||
| mod response; | ||||
| mod util; | ||||
| 
 | ||||
| use api::Api; | ||||
| use request::expand_derive_request; | ||||
| use response::expand_derive_response; | ||||
| 
 | ||||
| #[proc_macro] | ||||
| pub fn ruma_api(input: TokenStream) -> TokenStream { | ||||
|     let api = parse_macro_input!(input as Api); | ||||
|     api::expand_all(api).unwrap_or_else(syn::Error::into_compile_error).into() | ||||
|     api.expand_all().into() | ||||
| } | ||||
| 
 | ||||
| /// Internal helper taking care of the request-specific parts of `ruma_api!`.
 | ||||
| #[proc_macro_derive(Request, attributes(ruma_api))] | ||||
| pub fn derive_request(input: TokenStream) -> TokenStream { | ||||
|     let input = parse_macro_input!(input as DeriveInput); | ||||
|     expand_derive_request(input).unwrap_or_else(syn::Error::into_compile_error).into() | ||||
| } | ||||
| 
 | ||||
| /// Internal helper taking care of the response-specific parts of `ruma_api!`.
 | ||||
| #[proc_macro_derive(Response, attributes(ruma_api))] | ||||
| pub fn derive_response(input: TokenStream) -> TokenStream { | ||||
|     let input = parse_macro_input!(input as DeriveInput); | ||||
|     expand_derive_response(input).unwrap_or_else(syn::Error::into_compile_error).into() | ||||
| } | ||||
| 
 | ||||
| /// A derive macro that generates no code, but registers the ruma_api attribute so both
 | ||||
| /// `#[ruma_api(...)]` and `#[cfg_attr(..., ruma_api(...))]` are accepted on the type, its fields
 | ||||
| /// and (in case the input is an enum) variants fields.
 | ||||
| #[doc(hidden)] | ||||
| #[proc_macro_derive(_FakeDeriveRumaApi, attributes(ruma_api))] | ||||
| pub fn fake_derive_ruma_api(_input: TokenStream) -> TokenStream { | ||||
|     TokenStream::new() | ||||
| } | ||||
|  | ||||
							
								
								
									
										491
									
								
								crates/ruma-api-macros/src/request.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										491
									
								
								crates/ruma-api-macros/src/request.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,491 @@ | ||||
| use std::{ | ||||
|     collections::BTreeSet, | ||||
|     convert::{TryFrom, TryInto}, | ||||
|     mem, | ||||
| }; | ||||
| 
 | ||||
| use proc_macro2::TokenStream; | ||||
| use quote::{quote, ToTokens}; | ||||
| use syn::{ | ||||
|     parse::{Parse, ParseStream}, | ||||
|     parse_quote, | ||||
|     punctuated::Punctuated, | ||||
|     DeriveInput, Field, Generics, Ident, Lifetime, Lit, LitStr, Token, Type, | ||||
| }; | ||||
| 
 | ||||
| use crate::{ | ||||
|     attribute::{Meta, MetaNameValue, MetaValue}, | ||||
|     auth_scheme::AuthScheme, | ||||
|     util::{collect_lifetime_idents, import_ruma_api}, | ||||
| }; | ||||
| 
 | ||||
| mod incoming; | ||||
| mod outgoing; | ||||
| 
 | ||||
| pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> { | ||||
|     let fields = match input.data { | ||||
|         syn::Data::Struct(s) => s.fields, | ||||
|         _ => panic!("This derive macro only works on structs"), | ||||
|     }; | ||||
| 
 | ||||
|     let mut lifetimes = RequestLifetimes::default(); | ||||
|     let fields = fields | ||||
|         .into_iter() | ||||
|         .map(|f| { | ||||
|             let f = RequestField::try_from(f)?; | ||||
|             let ty = &f.field().ty; | ||||
| 
 | ||||
|             match &f { | ||||
|                 RequestField::Header(..) => collect_lifetime_idents(&mut lifetimes.header, ty), | ||||
|                 RequestField::Body(_) => collect_lifetime_idents(&mut lifetimes.body, ty), | ||||
|                 RequestField::NewtypeBody(_) => collect_lifetime_idents(&mut lifetimes.body, ty), | ||||
|                 RequestField::NewtypeRawBody(_) => collect_lifetime_idents(&mut lifetimes.body, ty), | ||||
|                 RequestField::Path(_) => collect_lifetime_idents(&mut lifetimes.path, ty), | ||||
|                 RequestField::Query(_) => collect_lifetime_idents(&mut lifetimes.query, ty), | ||||
|                 RequestField::QueryMap(_) => collect_lifetime_idents(&mut lifetimes.query, ty), | ||||
|             } | ||||
| 
 | ||||
|             Ok(f) | ||||
|         }) | ||||
|         .collect::<syn::Result<_>>()?; | ||||
| 
 | ||||
|     let mut authentication = None; | ||||
|     let mut error_ty = None; | ||||
|     let mut method = None; | ||||
|     let mut path = None; | ||||
| 
 | ||||
|     for attr in input.attrs { | ||||
|         if !attr.path.is_ident("ruma_api") { | ||||
|             continue; | ||||
|         } | ||||
| 
 | ||||
|         let meta = attr.parse_args_with(Punctuated::<_, Token![,]>::parse_terminated)?; | ||||
|         for MetaNameValue { name, value } in meta { | ||||
|             match value { | ||||
|                 MetaValue::Type(t) if name == "authentication" => { | ||||
|                     authentication = Some(parse_quote!(#t)); | ||||
|                 } | ||||
|                 MetaValue::Type(t) if name == "method" => { | ||||
|                     method = Some(parse_quote!(#t)); | ||||
|                 } | ||||
|                 MetaValue::Type(t) if name == "error_ty" => { | ||||
|                     error_ty = Some(t); | ||||
|                 } | ||||
|                 MetaValue::Lit(Lit::Str(s)) if name == "path" => { | ||||
|                     path = Some(s); | ||||
|                 } | ||||
|                 _ => unreachable!("invalid ruma_api({}) attribute", name), | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     let request = Request { | ||||
|         ident: input.ident, | ||||
|         generics: input.generics, | ||||
|         fields, | ||||
|         lifetimes, | ||||
|         authentication: authentication.expect("missing authentication attribute"), | ||||
|         method: method.expect("missing method attribute"), | ||||
|         path: path.expect("missing path attribute"), | ||||
|         error_ty: error_ty.expect("missing error_ty attribute"), | ||||
|     }; | ||||
| 
 | ||||
|     request.check()?; | ||||
|     Ok(request.expand_all()) | ||||
| } | ||||
| 
 | ||||
| #[derive(Default)] | ||||
| struct RequestLifetimes { | ||||
|     pub body: BTreeSet<Lifetime>, | ||||
|     pub path: BTreeSet<Lifetime>, | ||||
|     pub query: BTreeSet<Lifetime>, | ||||
|     pub header: BTreeSet<Lifetime>, | ||||
| } | ||||
| 
 | ||||
| struct Request { | ||||
|     ident: Ident, | ||||
|     generics: Generics, | ||||
|     lifetimes: RequestLifetimes, | ||||
|     fields: Vec<RequestField>, | ||||
| 
 | ||||
|     authentication: AuthScheme, | ||||
|     method: Ident, | ||||
|     path: LitStr, | ||||
|     error_ty: Type, | ||||
| } | ||||
| 
 | ||||
| impl Request { | ||||
|     fn body_fields(&self) -> impl Iterator<Item = &Field> { | ||||
|         self.fields.iter().filter_map(RequestField::as_body_field) | ||||
|     } | ||||
| 
 | ||||
|     fn query_fields(&self) -> impl Iterator<Item = &Field> { | ||||
|         self.fields.iter().filter_map(RequestField::as_query_field) | ||||
|     } | ||||
| 
 | ||||
|     fn has_body_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|f| matches!(f, RequestField::Body(..))) | ||||
|     } | ||||
| 
 | ||||
|     fn has_header_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|f| matches!(f, RequestField::Header(..))) | ||||
|     } | ||||
| 
 | ||||
|     fn has_path_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|f| matches!(f, RequestField::Path(..))) | ||||
|     } | ||||
| 
 | ||||
|     fn has_query_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|f| matches!(f, RequestField::Query(..))) | ||||
|     } | ||||
| 
 | ||||
|     fn has_lifetimes(&self) -> bool { | ||||
|         !(self.lifetimes.body.is_empty() | ||||
|             && self.lifetimes.path.is_empty() | ||||
|             && self.lifetimes.query.is_empty() | ||||
|             && self.lifetimes.header.is_empty()) | ||||
|     } | ||||
| 
 | ||||
|     fn header_fields(&self) -> impl Iterator<Item = &RequestField> { | ||||
|         self.fields.iter().filter(|f| matches!(f, RequestField::Header(..))) | ||||
|     } | ||||
| 
 | ||||
|     fn path_field_count(&self) -> usize { | ||||
|         self.fields.iter().filter(|f| matches!(f, RequestField::Path(..))).count() | ||||
|     } | ||||
| 
 | ||||
|     fn newtype_body_field(&self) -> Option<&Field> { | ||||
|         self.fields.iter().find_map(RequestField::as_newtype_body_field) | ||||
|     } | ||||
| 
 | ||||
|     fn newtype_raw_body_field(&self) -> Option<&Field> { | ||||
|         self.fields.iter().find_map(RequestField::as_newtype_raw_body_field) | ||||
|     } | ||||
| 
 | ||||
|     fn query_map_field(&self) -> Option<&Field> { | ||||
|         self.fields.iter().find_map(RequestField::as_query_map_field) | ||||
|     } | ||||
| 
 | ||||
|     fn expand_all(&self) -> TokenStream { | ||||
|         let ruma_api = import_ruma_api(); | ||||
|         let ruma_api_macros = quote! { #ruma_api::exports::ruma_api_macros }; | ||||
|         let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; | ||||
|         let serde = quote! { #ruma_api::exports::serde }; | ||||
| 
 | ||||
|         let request_body_def = if let Some(body_field) = self.newtype_body_field() { | ||||
|             let field = Field { ident: None, colon_token: None, ..body_field.clone() }; | ||||
|             Some(quote! { (#field); }) | ||||
|         } else if self.has_body_fields() { | ||||
|             let fields = self.fields.iter().filter_map(RequestField::as_body_field); | ||||
|             Some(quote! { { #(#fields),* } }) | ||||
|         } else { | ||||
|             None | ||||
|         }; | ||||
| 
 | ||||
|         let request_body_struct = request_body_def.map(|def| { | ||||
|             // Though we don't track the difference between newtype body and body
 | ||||
|             // for lifetimes, the outer check and the macro failing if it encounters
 | ||||
|             // an illegal combination of field attributes, is enough to guarantee
 | ||||
|             // `body_lifetimes` correctness.
 | ||||
|             let (derive_deserialize, generics) = if self.lifetimes.body.is_empty() { | ||||
|                 (quote! { #serde::Deserialize }, TokenStream::new()) | ||||
|             } else { | ||||
|                 let lifetimes = &self.lifetimes.body; | ||||
|                 (TokenStream::new(), quote! { < #(#lifetimes),* > }) | ||||
|             }; | ||||
| 
 | ||||
|             quote! { | ||||
|                 /// Data in the request body.
 | ||||
|                 #[derive(
 | ||||
|                     Debug, | ||||
|                     #ruma_api_macros::_FakeDeriveRumaApi, | ||||
|                     #ruma_serde::Outgoing, | ||||
|                     #serde::Serialize, | ||||
|                     #derive_deserialize | ||||
|                 )] | ||||
|                 struct RequestBody #generics #def | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         let request_query_def = if let Some(f) = self.query_map_field() { | ||||
|             let field = Field { ident: None, colon_token: None, ..f.clone() }; | ||||
|             Some(quote! { (#field); }) | ||||
|         } else if self.has_query_fields() { | ||||
|             let fields = self.fields.iter().filter_map(RequestField::as_query_field); | ||||
|             Some(quote! { { #(#fields),* } }) | ||||
|         } else { | ||||
|             None | ||||
|         }; | ||||
| 
 | ||||
|         let request_query_struct = request_query_def.map(|def| { | ||||
|             let (derive_deserialize, generics) = if self.lifetimes.query.is_empty() { | ||||
|                 (quote! { #serde::Deserialize }, TokenStream::new()) | ||||
|             } else { | ||||
|                 let lifetimes = &self.lifetimes.query; | ||||
|                 (TokenStream::new(), quote! { < #(#lifetimes),* > }) | ||||
|             }; | ||||
| 
 | ||||
|             quote! { | ||||
|                 /// Data in the request's query string.
 | ||||
|                 #[derive(
 | ||||
|                     Debug, | ||||
|                     #ruma_api_macros::_FakeDeriveRumaApi, | ||||
|                     #ruma_serde::Outgoing, | ||||
|                     #serde::Serialize, | ||||
|                     #derive_deserialize | ||||
|                 )] | ||||
|                 struct RequestQuery #generics #def | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         let outgoing_request_impl = self.expand_outgoing(&ruma_api); | ||||
|         let incoming_request_impl = self.expand_incoming(&ruma_api); | ||||
| 
 | ||||
|         quote! { | ||||
|             #request_body_struct | ||||
|             #request_query_struct | ||||
| 
 | ||||
|             #outgoing_request_impl | ||||
|             #incoming_request_impl | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub(super) fn check(&self) -> syn::Result<()> { | ||||
|         // TODO: highlight problematic fields
 | ||||
| 
 | ||||
|         let newtype_body_fields = self.fields.iter().filter(|field| { | ||||
|             matches!(field, RequestField::NewtypeBody(_) | RequestField::NewtypeRawBody(_)) | ||||
|         }); | ||||
| 
 | ||||
|         let has_newtype_body_field = match newtype_body_fields.count() { | ||||
|             0 => false, | ||||
|             1 => true, | ||||
|             _ => { | ||||
|                 return Err(syn::Error::new_spanned( | ||||
|                     &self.ident, | ||||
|                     "Can't have more than one newtype body field", | ||||
|                 )) | ||||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         let query_map_fields = | ||||
|             self.fields.iter().filter(|f| matches!(f, RequestField::QueryMap(_))); | ||||
|         let has_query_map_field = match query_map_fields.count() { | ||||
|             0 => false, | ||||
|             1 => true, | ||||
|             _ => { | ||||
|                 return Err(syn::Error::new_spanned( | ||||
|                     &self.ident, | ||||
|                     "Can't have more than one query_map field", | ||||
|                 )) | ||||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         let has_body_fields = self.body_fields().count() > 0; | ||||
|         let has_query_fields = self.query_fields().count() > 0; | ||||
| 
 | ||||
|         if has_newtype_body_field && has_body_fields { | ||||
|             return Err(syn::Error::new_spanned( | ||||
|                 &self.ident, | ||||
|                 "Can't have both a newtype body field and regular body fields", | ||||
|             )); | ||||
|         } | ||||
| 
 | ||||
|         if has_query_map_field && has_query_fields { | ||||
|             return Err(syn::Error::new_spanned( | ||||
|                 &self.ident, | ||||
|                 "Can't have both a query map field and regular query fields", | ||||
|             )); | ||||
|         } | ||||
| 
 | ||||
|         // TODO when/if `&[(&str, &str)]` is supported remove this
 | ||||
|         if has_query_map_field && !self.lifetimes.query.is_empty() { | ||||
|             return Err(syn::Error::new_spanned( | ||||
|                 &self.ident, | ||||
|                 "Lifetimes are not allowed for query_map fields", | ||||
|             )); | ||||
|         } | ||||
| 
 | ||||
|         if self.method == "GET" && (has_body_fields || has_newtype_body_field) { | ||||
|             return Err(syn::Error::new_spanned( | ||||
|                 &self.ident, | ||||
|                 "GET endpoints can't have body fields", | ||||
|             )); | ||||
|         } | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The types of fields that a request can have.
 | ||||
| enum RequestField { | ||||
|     /// JSON data in the body of the request.
 | ||||
|     Body(Field), | ||||
| 
 | ||||
|     /// Data in an HTTP header.
 | ||||
|     Header(Field, Ident), | ||||
| 
 | ||||
|     /// A specific data type in the body of the request.
 | ||||
|     NewtypeBody(Field), | ||||
| 
 | ||||
|     /// Arbitrary bytes in the body of the request.
 | ||||
|     NewtypeRawBody(Field), | ||||
| 
 | ||||
|     /// Data that appears in the URL path.
 | ||||
|     Path(Field), | ||||
| 
 | ||||
|     /// Data that appears in the query string.
 | ||||
|     Query(Field), | ||||
| 
 | ||||
|     /// Data that appears in the query string as dynamic key-value pairs.
 | ||||
|     QueryMap(Field), | ||||
| } | ||||
| 
 | ||||
| impl RequestField { | ||||
|     /// Creates a new `RequestField`.
 | ||||
|     fn new(kind: RequestFieldKind, field: Field, header: Option<Ident>) -> Self { | ||||
|         match kind { | ||||
|             RequestFieldKind::Body => RequestField::Body(field), | ||||
|             RequestFieldKind::Header => { | ||||
|                 RequestField::Header(field, header.expect("missing header name")) | ||||
|             } | ||||
|             RequestFieldKind::NewtypeBody => RequestField::NewtypeBody(field), | ||||
|             RequestFieldKind::NewtypeRawBody => RequestField::NewtypeRawBody(field), | ||||
|             RequestFieldKind::Path => RequestField::Path(field), | ||||
|             RequestFieldKind::Query => RequestField::Query(field), | ||||
|             RequestFieldKind::QueryMap => RequestField::QueryMap(field), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this request field is a body kind.
 | ||||
|     pub fn as_body_field(&self) -> Option<&Field> { | ||||
|         self.field_of_kind(RequestFieldKind::Body) | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this request field is a body kind.
 | ||||
|     pub fn as_newtype_body_field(&self) -> Option<&Field> { | ||||
|         self.field_of_kind(RequestFieldKind::NewtypeBody) | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this request field is a raw body kind.
 | ||||
|     pub fn as_newtype_raw_body_field(&self) -> Option<&Field> { | ||||
|         self.field_of_kind(RequestFieldKind::NewtypeRawBody) | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this request field is a query kind.
 | ||||
|     pub fn as_query_field(&self) -> Option<&Field> { | ||||
|         self.field_of_kind(RequestFieldKind::Query) | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this request field is a query map kind.
 | ||||
|     pub fn as_query_map_field(&self) -> Option<&Field> { | ||||
|         self.field_of_kind(RequestFieldKind::QueryMap) | ||||
|     } | ||||
| 
 | ||||
|     /// Gets the inner `Field` value.
 | ||||
|     pub fn field(&self) -> &Field { | ||||
|         match self { | ||||
|             RequestField::Body(field) | ||||
|             | RequestField::Header(field, _) | ||||
|             | RequestField::NewtypeBody(field) | ||||
|             | RequestField::NewtypeRawBody(field) | ||||
|             | RequestField::Path(field) | ||||
|             | RequestField::Query(field) | ||||
|             | RequestField::QueryMap(field) => field, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Gets the inner `Field` value if it's of the provided kind.
 | ||||
|     fn field_of_kind(&self, kind: RequestFieldKind) -> Option<&Field> { | ||||
|         match (self, kind) { | ||||
|             (RequestField::Body(field), RequestFieldKind::Body) | ||||
|             | (RequestField::Header(field, _), RequestFieldKind::Header) | ||||
|             | (RequestField::NewtypeBody(field), RequestFieldKind::NewtypeBody) | ||||
|             | (RequestField::NewtypeRawBody(field), RequestFieldKind::NewtypeRawBody) | ||||
|             | (RequestField::Path(field), RequestFieldKind::Path) | ||||
|             | (RequestField::Query(field), RequestFieldKind::Query) | ||||
|             | (RequestField::QueryMap(field), RequestFieldKind::QueryMap) => Some(field), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl TryFrom<Field> for RequestField { | ||||
|     type Error = syn::Error; | ||||
| 
 | ||||
|     fn try_from(mut field: Field) -> syn::Result<Self> { | ||||
|         let mut field_kind = None; | ||||
|         let mut header = None; | ||||
| 
 | ||||
|         for attr in mem::take(&mut field.attrs) { | ||||
|             let meta = match Meta::from_attribute(&attr)? { | ||||
|                 Some(m) => m, | ||||
|                 None => { | ||||
|                     field.attrs.push(attr); | ||||
|                     continue; | ||||
|                 } | ||||
|             }; | ||||
| 
 | ||||
|             if field_kind.is_some() { | ||||
|                 return Err(syn::Error::new_spanned( | ||||
|                     attr, | ||||
|                     "There can only be one field kind attribute", | ||||
|                 )); | ||||
|             } | ||||
| 
 | ||||
|             field_kind = Some(match meta { | ||||
|                 Meta::Word(ident) => match &ident.to_string()[..] { | ||||
|                     "body" => RequestFieldKind::Body, | ||||
|                     "raw_body" => RequestFieldKind::NewtypeRawBody, | ||||
|                     "path" => RequestFieldKind::Path, | ||||
|                     "query" => RequestFieldKind::Query, | ||||
|                     "query_map" => RequestFieldKind::QueryMap, | ||||
|                     _ => { | ||||
|                         return Err(syn::Error::new_spanned( | ||||
|                             ident, | ||||
|                             "Invalid #[ruma_api] argument, expected one of \ | ||||
|                             `body`, `raw_body`, `path`, `query`, `query_map`",
 | ||||
|                         )); | ||||
|                     } | ||||
|                 }, | ||||
|                 Meta::NameValue(MetaNameValue { name, value }) => { | ||||
|                     if name != "header" { | ||||
|                         return Err(syn::Error::new_spanned( | ||||
|                             name, | ||||
|                             "Invalid #[ruma_api] argument with value, expected `header`", | ||||
|                         )); | ||||
|                     } | ||||
| 
 | ||||
|                     header = Some(value); | ||||
|                     RequestFieldKind::Header | ||||
|                 } | ||||
|             }); | ||||
|         } | ||||
| 
 | ||||
|         Ok(RequestField::new(field_kind.unwrap_or(RequestFieldKind::Body), field, header)) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Parse for RequestField { | ||||
|     fn parse(input: ParseStream<'_>) -> syn::Result<Self> { | ||||
|         input.call(Field::parse_named)?.try_into() | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl ToTokens for RequestField { | ||||
|     fn to_tokens(&self, tokens: &mut TokenStream) { | ||||
|         self.field().to_tokens(tokens) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The types of fields that a request can have, without their values.
 | ||||
| #[derive(Clone, Copy, PartialEq, Eq)] | ||||
| enum RequestFieldKind { | ||||
|     Body, | ||||
|     Header, | ||||
|     NewtypeBody, | ||||
|     NewtypeRawBody, | ||||
|     Path, | ||||
|     Query, | ||||
|     QueryMap, | ||||
| } | ||||
| @ -2,23 +2,19 @@ use proc_macro2::{Ident, Span, TokenStream}; | ||||
| use quote::quote; | ||||
| 
 | ||||
| use super::{Request, RequestField, RequestFieldKind}; | ||||
| use crate::api::metadata::{AuthScheme, Metadata}; | ||||
| use crate::auth_scheme::AuthScheme; | ||||
| 
 | ||||
| impl Request { | ||||
|     pub fn expand_incoming( | ||||
|         &self, | ||||
|         metadata: &Metadata, | ||||
|         error_ty: &TokenStream, | ||||
|         ruma_api: &TokenStream, | ||||
|     ) -> TokenStream { | ||||
|     pub fn expand_incoming(&self, ruma_api: &TokenStream) -> TokenStream { | ||||
|         let http = quote! { #ruma_api::exports::http }; | ||||
|         let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; | ||||
|         let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; | ||||
|         let serde_json = quote! { #ruma_api::exports::serde_json }; | ||||
| 
 | ||||
|         let method = &metadata.method; | ||||
|         let method = &self.method; | ||||
|         let error_ty = &self.error_ty; | ||||
| 
 | ||||
|         let incoming_request_type = if self.contains_lifetimes() { | ||||
|         let incoming_request_type = if self.has_lifetimes() { | ||||
|             quote! { IncomingRequest } | ||||
|         } else { | ||||
|             quote! { Request } | ||||
| @ -28,7 +24,7 @@ impl Request { | ||||
|         // except this one. If we get errors about missing fields in IncomingRequest for
 | ||||
|         // a path field look here.
 | ||||
|         let (parse_request_path, path_vars) = if self.has_path_fields() { | ||||
|             let path_string = metadata.path.value(); | ||||
|             let path_string = self.path.value(); | ||||
| 
 | ||||
|             assert!(path_string.starts_with('/'), "path needs to start with '/'"); | ||||
|             assert!( | ||||
| @ -172,7 +168,7 @@ impl Request { | ||||
| 
 | ||||
|         let extract_body = | ||||
|             (self.has_body_fields() || self.newtype_body_field().is_some()).then(|| { | ||||
|                 let body_lifetimes = self.has_body_lifetimes().then(|| { | ||||
|                 let body_lifetimes = (!self.lifetimes.body.is_empty()).then(|| { | ||||
|                     // duplicate the anonymous lifetime as many times as needed
 | ||||
|                     let lifetimes = | ||||
|                         std::iter::repeat(quote! { '_ }).take(self.lifetimes.body.len()); | ||||
| @ -218,16 +214,12 @@ impl Request { | ||||
|             self.vars(RequestFieldKind::Body, quote! { request_body }) | ||||
|         }; | ||||
| 
 | ||||
|         let non_auth_impls = metadata.authentication.iter().filter_map(|auth| { | ||||
|             matches!(auth.value, AuthScheme::None(_)).then(|| { | ||||
|                 let attrs = &auth.attrs; | ||||
|                 quote! { | ||||
|                     #( #attrs )* | ||||
|                     #[automatically_derived] | ||||
|                     #[cfg(feature = "server")] | ||||
|                     impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {} | ||||
|                 } | ||||
|             }) | ||||
|         let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| { | ||||
|             quote! { | ||||
|                 #[automatically_derived] | ||||
|                 #[cfg(feature = "server")] | ||||
|                 impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {} | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         quote! { | ||||
| @ -265,7 +257,7 @@ impl Request { | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             #(#non_auth_impls)* | ||||
|             #non_auth_impl | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| @ -1,26 +1,21 @@ | ||||
| use proc_macro2::{Ident, Span, TokenStream}; | ||||
| use quote::quote; | ||||
| 
 | ||||
| use crate::api::metadata::{AuthScheme, Metadata}; | ||||
| use crate::auth_scheme::AuthScheme; | ||||
| 
 | ||||
| use super::{Request, RequestField, RequestFieldKind}; | ||||
| 
 | ||||
| impl Request { | ||||
|     pub fn expand_outgoing( | ||||
|         &self, | ||||
|         metadata: &Metadata, | ||||
|         error_ty: &TokenStream, | ||||
|         lifetimes: &TokenStream, | ||||
|         ruma_api: &TokenStream, | ||||
|     ) -> TokenStream { | ||||
|     pub fn expand_outgoing(&self, ruma_api: &TokenStream) -> TokenStream { | ||||
|         let bytes = quote! { #ruma_api::exports::bytes }; | ||||
|         let http = quote! { #ruma_api::exports::http }; | ||||
|         let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; | ||||
|         let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; | ||||
| 
 | ||||
|         let method = &metadata.method; | ||||
|         let method = &self.method; | ||||
|         let error_ty = &self.error_ty; | ||||
|         let request_path_string = if self.has_path_fields() { | ||||
|             let mut format_string = metadata.path.value(); | ||||
|             let mut format_string = self.path.value(); | ||||
|             let mut format_args = Vec::new(); | ||||
| 
 | ||||
|             while let Some(start_of_segment) = format_string.find(':') { | ||||
| @ -132,38 +127,31 @@ impl Request { | ||||
|             }) | ||||
|             .collect(); | ||||
| 
 | ||||
|         for auth in &metadata.authentication { | ||||
|             let attrs = &auth.attrs; | ||||
| 
 | ||||
|             let hdr_kv = match auth.value { | ||||
|                 AuthScheme::AccessToken(_) => quote! { | ||||
|                     #( #attrs )* | ||||
|         let hdr_kv = match self.authentication { | ||||
|             AuthScheme::AccessToken(_) => quote! { | ||||
|                 req_headers.insert( | ||||
|                     #http::header::AUTHORIZATION, | ||||
|                     ::std::convert::TryFrom::<_>::try_from(::std::format!( | ||||
|                         "Bearer {}", | ||||
|                         access_token | ||||
|                             .get_required_for_endpoint() | ||||
|                             .ok_or(#ruma_api::error::IntoHttpError::NeedsAuthentication)?, | ||||
|                     ))?, | ||||
|                 ); | ||||
|             }, | ||||
|             AuthScheme::None(_) => quote! { | ||||
|                 if let Some(access_token) = access_token.get_not_required_for_endpoint() { | ||||
|                     req_headers.insert( | ||||
|                         #http::header::AUTHORIZATION, | ||||
|                         ::std::convert::TryFrom::<_>::try_from(::std::format!( | ||||
|                             "Bearer {}", | ||||
|                             access_token | ||||
|                                 .get_required_for_endpoint() | ||||
|                                 .ok_or(#ruma_api::error::IntoHttpError::NeedsAuthentication)?, | ||||
|                         ))?, | ||||
|                         ::std::convert::TryFrom::<_>::try_from( | ||||
|                             ::std::format!("Bearer {}", access_token), | ||||
|                         )? | ||||
|                     ); | ||||
|                 }, | ||||
|                 AuthScheme::None(_) => quote! { | ||||
|                     if let Some(access_token) = access_token.get_not_required_for_endpoint() { | ||||
|                         #( #attrs )* | ||||
|                         req_headers.insert( | ||||
|                             #http::header::AUTHORIZATION, | ||||
|                             ::std::convert::TryFrom::<_>::try_from( | ||||
|                                 ::std::format!("Bearer {}", access_token), | ||||
|                             )? | ||||
|                         ); | ||||
|                     } | ||||
|                 }, | ||||
|                 AuthScheme::QueryOnlyAccessToken(_) | AuthScheme::ServerSignatures(_) => quote! {}, | ||||
|             }; | ||||
| 
 | ||||
|             header_kvs.extend(hdr_kv); | ||||
|         } | ||||
|                 } | ||||
|             }, | ||||
|             AuthScheme::QueryOnlyAccessToken(_) | AuthScheme::ServerSignatures(_) => quote! {}, | ||||
|         }; | ||||
|         header_kvs.extend(hdr_kv); | ||||
| 
 | ||||
|         let request_body = if let Some(field) = self.newtype_raw_body_field() { | ||||
|             let field_name = field.ident.as_ref().expect("expected field to have an identifier"); | ||||
| @ -185,22 +173,21 @@ impl Request { | ||||
|             quote! { <T as ::std::default::Default>::default() } | ||||
|         }; | ||||
| 
 | ||||
|         let non_auth_impls = metadata.authentication.iter().filter_map(|auth| { | ||||
|             matches!(auth.value, AuthScheme::None(_)).then(|| { | ||||
|                 let attrs = &auth.attrs; | ||||
|                 quote! { | ||||
|                     #( #attrs )* | ||||
|                     #[automatically_derived] | ||||
|                     #[cfg(feature = "client")] | ||||
|                     impl #lifetimes #ruma_api::OutgoingNonAuthRequest for Request #lifetimes {} | ||||
|                 } | ||||
|             }) | ||||
|         let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); | ||||
| 
 | ||||
|         let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| { | ||||
|             quote! { | ||||
|                 #[automatically_derived] | ||||
|                 #[cfg(feature = "client")] | ||||
|                 impl #impl_generics #ruma_api::OutgoingNonAuthRequest | ||||
|                     for Request #ty_generics #where_clause {} | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         quote! { | ||||
|             #[automatically_derived] | ||||
|             #[cfg(feature = "client")] | ||||
|             impl #lifetimes #ruma_api::OutgoingRequest for Request #lifetimes { | ||||
|             impl #impl_generics #ruma_api::OutgoingRequest for Request #ty_generics #where_clause { | ||||
|                 type EndpointError = #error_ty; | ||||
|                 type IncomingResponse = <Response as #ruma_serde::Outgoing>::Incoming; | ||||
| 
 | ||||
| @ -236,7 +223,7 @@ impl Request { | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             #(#non_auth_impls)* | ||||
|             #non_auth_impl | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
							
								
								
									
										314
									
								
								crates/ruma-api-macros/src/response.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										314
									
								
								crates/ruma-api-macros/src/response.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,314 @@ | ||||
| use std::{ | ||||
|     convert::{TryFrom, TryInto}, | ||||
|     mem, | ||||
| }; | ||||
| 
 | ||||
| use proc_macro2::TokenStream; | ||||
| use quote::{quote, ToTokens}; | ||||
| use syn::{ | ||||
|     parse::{Parse, ParseStream}, | ||||
|     punctuated::Punctuated, | ||||
|     visit::Visit, | ||||
|     DeriveInput, Field, Generics, Ident, Lifetime, Token, Type, | ||||
| }; | ||||
| 
 | ||||
| use crate::{ | ||||
|     attribute::{Meta, MetaNameValue, MetaValue}, | ||||
|     util, | ||||
| }; | ||||
| 
 | ||||
| mod incoming; | ||||
| mod outgoing; | ||||
| 
 | ||||
| pub fn expand_derive_response(input: DeriveInput) -> syn::Result<TokenStream> { | ||||
|     let fields = match input.data { | ||||
|         syn::Data::Struct(s) => s.fields, | ||||
|         _ => panic!("This derive macro only works on structs"), | ||||
|     }; | ||||
| 
 | ||||
|     let fields = fields.into_iter().map(ResponseField::try_from).collect::<syn::Result<_>>()?; | ||||
|     let mut error_ty = None; | ||||
|     for attr in input.attrs { | ||||
|         if !attr.path.is_ident("ruma_api") { | ||||
|             continue; | ||||
|         } | ||||
| 
 | ||||
|         let meta = attr.parse_args_with(Punctuated::<_, Token![,]>::parse_terminated)?; | ||||
|         for MetaNameValue { name, value } in meta { | ||||
|             match value { | ||||
|                 MetaValue::Type(t) if name == "error_ty" => { | ||||
|                     error_ty = Some(t); | ||||
|                 } | ||||
|                 _ => unreachable!("invalid ruma_api({}) attribute", name), | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     let response = Response { | ||||
|         ident: input.ident, | ||||
|         generics: input.generics, | ||||
|         fields, | ||||
|         error_ty: error_ty.unwrap(), | ||||
|     }; | ||||
| 
 | ||||
|     response.check()?; | ||||
|     Ok(response.expand_all()) | ||||
| } | ||||
| 
 | ||||
| struct Response { | ||||
|     ident: Ident, | ||||
|     generics: Generics, | ||||
|     fields: Vec<ResponseField>, | ||||
|     error_ty: Type, | ||||
| } | ||||
| 
 | ||||
| impl Response { | ||||
|     /// Whether or not this request has any data in the HTTP body.
 | ||||
|     fn has_body_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|f| matches!(f, ResponseField::Body(_))) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns the body field.
 | ||||
|     fn newtype_body_field(&self) -> Option<&Field> { | ||||
|         self.fields.iter().find_map(ResponseField::as_newtype_body_field) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns the body field.
 | ||||
|     fn newtype_raw_body_field(&self) -> Option<&Field> { | ||||
|         self.fields.iter().find_map(ResponseField::as_newtype_raw_body_field) | ||||
|     } | ||||
| 
 | ||||
|     /// Whether or not this request has any data in the URL path.
 | ||||
|     fn has_header_fields(&self) -> bool { | ||||
|         self.fields.iter().any(|f| matches!(f, &ResponseField::Header(..))) | ||||
|     } | ||||
| 
 | ||||
|     fn expand_all(&self) -> TokenStream { | ||||
|         let ruma_api = util::import_ruma_api(); | ||||
|         let ruma_api_macros = quote! { #ruma_api::exports::ruma_api_macros }; | ||||
|         let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; | ||||
|         let serde = quote! { #ruma_api::exports::serde }; | ||||
| 
 | ||||
|         let response_body_struct = | ||||
|             self.fields.iter().all(|f| !matches!(f, ResponseField::NewtypeRawBody(_))).then(|| { | ||||
|                 let newtype_body_field = | ||||
|                     self.fields.iter().find(|f| matches!(f, ResponseField::NewtypeBody(_))); | ||||
|                 let def = if let Some(body_field) = newtype_body_field { | ||||
|                     let field = | ||||
|                         Field { ident: None, colon_token: None, ..body_field.field().clone() }; | ||||
|                     quote! { (#field); } | ||||
|                 } else { | ||||
|                     let fields = self.fields.iter().filter_map(|f| f.as_body_field()); | ||||
|                     quote! { { #(#fields),* } } | ||||
|                 }; | ||||
| 
 | ||||
|                 quote! { | ||||
|                     /// Data in the response body.
 | ||||
|                     #[derive(
 | ||||
|                         Debug, | ||||
|                         #ruma_api_macros::_FakeDeriveRumaApi, | ||||
|                         #ruma_serde::Outgoing, | ||||
|                         #serde::Deserialize, | ||||
|                         #serde::Serialize, | ||||
|                     )] | ||||
|                     struct ResponseBody #def | ||||
|                 } | ||||
|             }); | ||||
| 
 | ||||
|         let outgoing_response_impl = self.expand_outgoing(&ruma_api); | ||||
|         let incoming_response_impl = self.expand_incoming(&self.error_ty, &ruma_api); | ||||
| 
 | ||||
|         quote! { | ||||
|             #response_body_struct | ||||
| 
 | ||||
|             #outgoing_response_impl | ||||
|             #incoming_response_impl | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn check(&self) -> syn::Result<()> { | ||||
|         // TODO: highlight problematic fields
 | ||||
| 
 | ||||
|         if !self.generics.params.is_empty() || self.generics.where_clause.is_some() { | ||||
|             panic!("This macro doesn't support generic types"); | ||||
|         } | ||||
| 
 | ||||
|         let newtype_body_fields = self.fields.iter().filter(|f| { | ||||
|             matches!(f, ResponseField::NewtypeBody(_) | ResponseField::NewtypeRawBody(_)) | ||||
|         }); | ||||
| 
 | ||||
|         let has_newtype_body_field = match newtype_body_fields.count() { | ||||
|             0 => false, | ||||
|             1 => true, | ||||
|             _ => { | ||||
|                 return Err(syn::Error::new_spanned( | ||||
|                     &self.ident, | ||||
|                     "Can't have more than one newtype body field", | ||||
|                 )) | ||||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         let has_body_fields = self.fields.iter().any(|f| matches!(f, ResponseField::Body(_))); | ||||
|         if has_newtype_body_field && has_body_fields { | ||||
|             return Err(syn::Error::new_spanned( | ||||
|                 &self.ident, | ||||
|                 "Can't have both a newtype body field and regular body fields", | ||||
|             )); | ||||
|         } | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The types of fields that a response can have.
 | ||||
| enum ResponseField { | ||||
|     /// JSON data in the body of the response.
 | ||||
|     Body(Field), | ||||
| 
 | ||||
|     /// Data in an HTTP header.
 | ||||
|     Header(Field, Ident), | ||||
| 
 | ||||
|     /// A specific data type in the body of the response.
 | ||||
|     NewtypeBody(Field), | ||||
| 
 | ||||
|     /// Arbitrary bytes in the body of the response.
 | ||||
|     NewtypeRawBody(Field), | ||||
| } | ||||
| 
 | ||||
| impl ResponseField { | ||||
|     /// Gets the inner `Field` value.
 | ||||
|     fn field(&self) -> &Field { | ||||
|         match self { | ||||
|             ResponseField::Body(field) | ||||
|             | ResponseField::Header(field, _) | ||||
|             | ResponseField::NewtypeBody(field) | ||||
|             | ResponseField::NewtypeRawBody(field) => field, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this response field is a body kind.
 | ||||
|     fn as_body_field(&self) -> Option<&Field> { | ||||
|         match self { | ||||
|             ResponseField::Body(field) => Some(field), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this response field is a newtype body kind.
 | ||||
|     fn as_newtype_body_field(&self) -> Option<&Field> { | ||||
|         match self { | ||||
|             ResponseField::NewtypeBody(field) => Some(field), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Return the contained field if this response field is a newtype raw body kind.
 | ||||
|     fn as_newtype_raw_body_field(&self) -> Option<&Field> { | ||||
|         match self { | ||||
|             ResponseField::NewtypeRawBody(field) => Some(field), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl TryFrom<Field> for ResponseField { | ||||
|     type Error = syn::Error; | ||||
| 
 | ||||
|     fn try_from(mut field: Field) -> syn::Result<Self> { | ||||
|         if has_lifetime(&field.ty) { | ||||
|             return Err(syn::Error::new_spanned( | ||||
|                 field.ident, | ||||
|                 "Lifetimes on Response fields cannot be supported until GAT are stable", | ||||
|             )); | ||||
|         } | ||||
| 
 | ||||
|         let mut field_kind = None; | ||||
|         let mut header = None; | ||||
| 
 | ||||
|         for attr in mem::take(&mut field.attrs) { | ||||
|             let meta = match Meta::from_attribute(&attr)? { | ||||
|                 Some(m) => m, | ||||
|                 None => { | ||||
|                     field.attrs.push(attr); | ||||
|                     continue; | ||||
|                 } | ||||
|             }; | ||||
| 
 | ||||
|             if field_kind.is_some() { | ||||
|                 return Err(syn::Error::new_spanned( | ||||
|                     attr, | ||||
|                     "There can only be one field kind attribute", | ||||
|                 )); | ||||
|             } | ||||
| 
 | ||||
|             field_kind = Some(match meta { | ||||
|                 Meta::Word(ident) => match &ident.to_string()[..] { | ||||
|                     "body" => ResponseFieldKind::NewtypeBody, | ||||
|                     "raw_body" => ResponseFieldKind::NewtypeRawBody, | ||||
|                     _ => { | ||||
|                         return Err(syn::Error::new_spanned( | ||||
|                             ident, | ||||
|                             "Invalid #[ruma_api] argument with value, expected `body`", | ||||
|                         )); | ||||
|                     } | ||||
|                 }, | ||||
|                 Meta::NameValue(MetaNameValue { name, value }) => { | ||||
|                     if name != "header" { | ||||
|                         return Err(syn::Error::new_spanned( | ||||
|                             name, | ||||
|                             "Invalid #[ruma_api] argument with value, expected `header`", | ||||
|                         )); | ||||
|                     } | ||||
| 
 | ||||
|                     header = Some(value); | ||||
|                     ResponseFieldKind::Header | ||||
|                 } | ||||
|             }); | ||||
|         } | ||||
| 
 | ||||
|         Ok(match field_kind.unwrap_or(ResponseFieldKind::Body) { | ||||
|             ResponseFieldKind::Body => ResponseField::Body(field), | ||||
|             ResponseFieldKind::Header => { | ||||
|                 ResponseField::Header(field, header.expect("missing header name")) | ||||
|             } | ||||
|             ResponseFieldKind::NewtypeBody => ResponseField::NewtypeBody(field), | ||||
|             ResponseFieldKind::NewtypeRawBody => ResponseField::NewtypeRawBody(field), | ||||
|         }) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Parse for ResponseField { | ||||
|     fn parse(input: ParseStream<'_>) -> syn::Result<Self> { | ||||
|         input.call(Field::parse_named)?.try_into() | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl ToTokens for ResponseField { | ||||
|     fn to_tokens(&self, tokens: &mut TokenStream) { | ||||
|         self.field().to_tokens(tokens) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The types of fields that a response can have, without their values.
 | ||||
| enum ResponseFieldKind { | ||||
|     Body, | ||||
|     Header, | ||||
|     NewtypeBody, | ||||
|     NewtypeRawBody, | ||||
| } | ||||
| 
 | ||||
| fn has_lifetime(ty: &Type) -> bool { | ||||
|     struct Visitor { | ||||
|         found_lifetime: bool, | ||||
|     } | ||||
| 
 | ||||
|     impl<'ast> Visit<'ast> for Visitor { | ||||
|         fn visit_lifetime(&mut self, _lt: &'ast Lifetime) { | ||||
|             self.found_lifetime = true; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     let mut vis = Visitor { found_lifetime: false }; | ||||
|     vis.visit_type(ty); | ||||
|     vis.found_lifetime | ||||
| } | ||||
| @ -1,10 +1,11 @@ | ||||
| use proc_macro2::TokenStream; | ||||
| use quote::quote; | ||||
| use syn::Type; | ||||
| 
 | ||||
| use super::{Response, ResponseField}; | ||||
| 
 | ||||
| impl Response { | ||||
|     pub fn expand_incoming(&self, error_ty: &TokenStream, ruma_api: &TokenStream) -> TokenStream { | ||||
|     pub fn expand_incoming(&self, error_ty: &Type, ruma_api: &TokenStream) -> TokenStream { | ||||
|         let http = quote! { #ruma_api::exports::http }; | ||||
|         let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; | ||||
|         let serde_json = quote! { #ruma_api::exports::serde_json }; | ||||
| @ -5,25 +5,9 @@ use std::collections::BTreeSet; | ||||
| use proc_macro2::TokenStream; | ||||
| use proc_macro_crate::{crate_name, FoundCrate}; | ||||
| use quote::{format_ident, quote}; | ||||
| use syn::{AttrStyle, Attribute, Lifetime}; | ||||
| use syn::{parse_quote, visit::Visit, AttrStyle, Attribute, Lifetime, NestedMeta, Type}; | ||||
| 
 | ||||
| /// Generates a `TokenStream` of lifetime identifiers `<'lifetime>`.
 | ||||
| pub(crate) fn unique_lifetimes_to_tokens<'a, I: IntoIterator<Item = &'a Lifetime>>( | ||||
|     lifetimes: I, | ||||
| ) -> TokenStream { | ||||
|     let lifetimes = lifetimes.into_iter().collect::<BTreeSet<_>>(); | ||||
|     if lifetimes.is_empty() { | ||||
|         TokenStream::new() | ||||
|     } else { | ||||
|         quote! { < #( #lifetimes ),* > } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub(crate) fn is_valid_endpoint_path(string: &str) -> bool { | ||||
|     string.as_bytes().iter().all(|b| (0x21..=0x7E).contains(b)) | ||||
| } | ||||
| 
 | ||||
| pub(crate) fn import_ruma_api() -> TokenStream { | ||||
| pub fn import_ruma_api() -> TokenStream { | ||||
|     if let Ok(FoundCrate::Name(name)) = crate_name("ruma-api") { | ||||
|         let import = format_ident!("{}", name); | ||||
|         quote! { ::#import } | ||||
| @ -41,6 +25,48 @@ pub(crate) fn import_ruma_api() -> TokenStream { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub(crate) fn is_cfg_attribute(attr: &Attribute) -> bool { | ||||
| pub fn is_valid_endpoint_path(string: &str) -> bool { | ||||
|     string.as_bytes().iter().all(|b| (0x21..=0x7E).contains(b)) | ||||
| } | ||||
| 
 | ||||
| pub fn collect_lifetime_idents(lifetimes: &mut BTreeSet<Lifetime>, ty: &Type) { | ||||
|     struct Visitor<'lt>(&'lt mut BTreeSet<Lifetime>); | ||||
|     impl<'ast> Visit<'ast> for Visitor<'_> { | ||||
|         fn visit_lifetime(&mut self, lt: &'ast Lifetime) { | ||||
|             self.0.insert(lt.clone()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     Visitor(lifetimes).visit_type(ty) | ||||
| } | ||||
| 
 | ||||
| pub fn is_cfg_attribute(attr: &Attribute) -> bool { | ||||
|     matches!(attr.style, AttrStyle::Outer) && attr.path.is_ident("cfg") | ||||
| } | ||||
| 
 | ||||
| pub fn all_cfgs_expr(cfgs: &[Attribute]) -> Option<TokenStream> { | ||||
|     let sub_cfgs: Vec<_> = cfgs.iter().filter_map(extract_cfg).collect(); | ||||
|     (!sub_cfgs.is_empty()).then(|| quote! { all( #(#sub_cfgs),* ) }) | ||||
| } | ||||
| 
 | ||||
| pub fn all_cfgs(cfgs: &[Attribute]) -> Option<Attribute> { | ||||
|     let cfg_expr = all_cfgs_expr(cfgs)?; | ||||
|     Some(parse_quote! { #[cfg( #cfg_expr )] }) | ||||
| } | ||||
| 
 | ||||
| pub fn extract_cfg(attr: &Attribute) -> Option<NestedMeta> { | ||||
|     if !attr.path.is_ident("cfg") { | ||||
|         return None; | ||||
|     } | ||||
| 
 | ||||
|     let meta = attr.parse_meta().expect("cfg attribute can be parsed to syn::Meta"); | ||||
|     let mut list = match meta { | ||||
|         syn::Meta::List(l) => l, | ||||
|         _ => panic!("unexpected cfg syntax"), | ||||
|     }; | ||||
| 
 | ||||
|     assert!(list.path.is_ident("cfg"), "expected cfg attributes only"); | ||||
|     assert_eq!(list.nested.len(), 1, "expected one item inside cfg()"); | ||||
| 
 | ||||
|     Some(list.nested.pop().unwrap().into_value()) | ||||
| } | ||||
|  | ||||
| @ -205,6 +205,7 @@ pub mod exports { | ||||
|     pub use bytes; | ||||
|     pub use http; | ||||
|     pub use percent_encoding; | ||||
|     pub use ruma_api_macros; | ||||
|     pub use ruma_serde; | ||||
|     pub use serde; | ||||
|     pub use serde_json; | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| #![allow(clippy::exhaustive_structs)] | ||||
| 
 | ||||
| #[derive(Copy, Clone, Debug, ruma_serde::Outgoing, serde::Serialize)] | ||||
| /*#[derive(Copy, Clone, Debug, ruma_serde::Outgoing, serde::Serialize)]
 | ||||
| pub struct OtherThing<'t> { | ||||
|     pub some: &'t str, | ||||
|     pub t: &'t [u8], | ||||
| @ -31,7 +31,7 @@ mod empty_response { | ||||
| 
 | ||||
|         response: {} | ||||
|     } | ||||
| } | ||||
| }*/ | ||||
| 
 | ||||
| mod nested_types { | ||||
|     use ruma_api::ruma_api; | ||||
| @ -59,7 +59,7 @@ mod nested_types { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| mod full_request_response { | ||||
| /*mod full_request_response {
 | ||||
|     use ruma_api::ruma_api; | ||||
| 
 | ||||
|     use super::{IncomingOtherThing, OtherThing}; | ||||
| @ -159,4 +159,4 @@ mod query_fields { | ||||
| 
 | ||||
|         response: {} | ||||
|     } | ||||
| } | ||||
| }*/ | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user