diff --git a/crates/ruma-api-macros/src/api.rs b/crates/ruma-api-macros/src/api.rs index a104b0c1..561764d1 100644 --- a/crates/ruma-api-macros/src/api.rs +++ b/crates/ruma-api-macros/src/api.rs @@ -2,17 +2,27 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::Type; +use syn::{ + braced, + parse::{Parse, ParseStream}, + Attribute, Field, Token, Type, +}; -pub(crate) mod attribute; -pub(crate) mod metadata; -pub(crate) mod parse; -pub(crate) mod request; -pub(crate) mod response; +mod metadata; +mod request; +mod response; use self::{metadata::Metadata, request::Request, response::Response}; use crate::util; +mod kw { + use syn::custom_keyword; + + custom_keyword!(error); + custom_keyword!(request); + custom_keyword!(response); +} + /// The result of processing the `ruma_api` macro, ready for output back to source code. pub struct Api { /// The `metadata` section of the macro. @@ -28,65 +38,129 @@ pub struct Api { error_ty: Option, } -pub fn expand_all(api: Api) -> syn::Result { - let ruma_api = util::import_ruma_api(); - let http = quote! { #ruma_api::exports::http }; +impl Api { + pub fn expand_all(self) -> TokenStream { + let ruma_api = util::import_ruma_api(); + let http = quote! { #ruma_api::exports::http }; - let metadata = &api.metadata; - let description = &metadata.description; - let method = &metadata.method; - let name = &metadata.name; - let path = &metadata.path; - let rate_limited: TokenStream = metadata - .rate_limited - .iter() - .map(|r| { - let attrs = &r.attrs; - let value = &r.value; - quote! { - #( #attrs )* - rate_limited: #value, - } - }) - .collect(); - let authentication: TokenStream = api - .metadata - .authentication - .iter() - .map(|r| { - let attrs = &r.attrs; - let value = &r.value; - quote! { - #( #attrs )* - authentication: #ruma_api::AuthScheme::#value, - } - }) - .collect(); + let metadata = &self.metadata; + let description = &metadata.description; + let method = &metadata.method; + let name = &metadata.name; + let path = &metadata.path; + let rate_limited: TokenStream = metadata + .rate_limited + .iter() + .map(|r| { + let attrs = &r.attrs; + let value = &r.value; + quote! { + #( #attrs )* + rate_limited: #value, + } + }) + .collect(); + let authentication: TokenStream = self + .metadata + .authentication + .iter() + .map(|r| { + let attrs = &r.attrs; + let value = &r.value; + quote! { + #( #attrs )* + authentication: #ruma_api::AuthScheme::#value, + } + }) + .collect(); - let error_ty = api - .error_ty - .map_or_else(|| quote! { #ruma_api::error::MatrixError }, |err_ty| quote! { #err_ty }); + let error_ty = self + .error_ty + .map_or_else(|| quote! { #ruma_api::error::MatrixError }, |err_ty| quote! { #err_ty }); - let request = api.request.map(|req| req.expand(metadata, &error_ty, &ruma_api)); - let response = api.response.map(|res| res.expand(metadata, &error_ty, &ruma_api)); + let request = self.request.map(|req| req.expand(metadata, &error_ty, &ruma_api)); + let response = self.response.map(|res| res.expand(metadata, &error_ty, &ruma_api)); - let metadata_doc = format!("Metadata for the `{}` API endpoint.", name.value()); + let metadata_doc = format!("Metadata for the `{}` API endpoint.", name.value()); - Ok(quote! { - #[doc = #metadata_doc] - pub const METADATA: #ruma_api::Metadata = #ruma_api::Metadata { - description: #description, - method: #http::Method::#method, - name: #name, - path: #path, - #rate_limited - #authentication + quote! { + #[doc = #metadata_doc] + pub const METADATA: #ruma_api::Metadata = #ruma_api::Metadata { + description: #description, + method: #http::Method::#method, + name: #name, + path: #path, + #rate_limited + #authentication + }; + + #request + #response + + #[cfg(not(any(feature = "client", feature = "server")))] + type _SilenceUnusedError = #error_ty; + } + } +} + +impl Parse for Api { + fn parse(input: ParseStream<'_>) -> syn::Result { + let metadata: Metadata = input.parse()?; + + let req_attrs = input.call(Attribute::parse_outer)?; + let (request, attributes) = if input.peek(kw::request) { + let request = parse_request(input, req_attrs)?; + let after_req_attrs = input.call(Attribute::parse_outer)?; + + (Some(request), after_req_attrs) + } else { + // There was no `request` field so the attributes are for `response` + (None, req_attrs) }; - #request - #response + let response = if input.peek(kw::response) { + Some(parse_response(input, attributes)?) + } else if !attributes.is_empty() { + return Err(syn::Error::new_spanned( + &attributes[0], + "attributes are not supported on the error type", + )); + } else { + None + }; - #[cfg(not(any(feature = "client", feature = "server")))] - type _SilenceUnusedError = #error_ty; - }) + let error_ty = input + .peek(kw::error) + .then(|| { + let _: kw::error = input.parse()?; + let _: Token![:] = input.parse()?; + + input.parse() + }) + .transpose()?; + + Ok(Self { metadata, request, response, error_ty }) + } +} + +fn parse_request(input: ParseStream<'_>, attributes: Vec) -> syn::Result { + let request_kw: kw::request = input.parse()?; + let _: Token![:] = input.parse()?; + let fields; + braced!(fields in input); + + let fields = fields.parse_terminated::<_, Token![,]>(Field::parse_named)?; + + Ok(Request { request_kw, attributes, fields }) +} + +fn parse_response(input: ParseStream<'_>, attributes: Vec) -> syn::Result { + let response_kw: kw::response = input.parse()?; + let _: Token![:] = input.parse()?; + let fields; + braced!(fields in input); + + let fields = fields.parse_terminated::<_, Token![,]>(Field::parse_named)?; + + Ok(Response { attributes, fields, response_kw }) } diff --git a/crates/ruma-api-macros/src/api/metadata.rs b/crates/ruma-api-macros/src/api/metadata.rs index 39bd4d09..b97b1813 100644 --- a/crates/ruma-api-macros/src/api/metadata.rs +++ b/crates/ruma-api-macros/src/api/metadata.rs @@ -1,6 +1,5 @@ //! Details of the `metadata` section of the procedural macro. -use proc_macro2::TokenStream; use quote::ToTokens; use syn::{ braced, @@ -8,7 +7,7 @@ use syn::{ Attribute, Ident, LitBool, LitStr, Token, }; -use crate::util; +use crate::{auth_scheme::AuthScheme, util}; mod kw { syn::custom_keyword!(metadata); @@ -18,11 +17,6 @@ mod kw { syn::custom_keyword!(path); syn::custom_keyword!(rate_limited); syn::custom_keyword!(authentication); - - syn::custom_keyword!(None); - syn::custom_keyword!(AccessToken); - syn::custom_keyword!(ServerSignatures); - syn::custom_keyword!(QueryOnlyAccessToken); } /// A field of Metadata that contains attribute macros @@ -124,42 +118,6 @@ impl Parse for Metadata { } } -pub enum AuthScheme { - None(kw::None), - AccessToken(kw::AccessToken), - ServerSignatures(kw::ServerSignatures), - QueryOnlyAccessToken(kw::QueryOnlyAccessToken), -} - -impl Parse for AuthScheme { - fn parse(input: ParseStream<'_>) -> syn::Result { - let lookahead = input.lookahead1(); - - if lookahead.peek(kw::None) { - input.parse().map(Self::None) - } else if lookahead.peek(kw::AccessToken) { - input.parse().map(Self::AccessToken) - } else if lookahead.peek(kw::ServerSignatures) { - input.parse().map(Self::ServerSignatures) - } else if lookahead.peek(kw::QueryOnlyAccessToken) { - input.parse().map(Self::QueryOnlyAccessToken) - } else { - Err(lookahead.error()) - } - } -} - -impl ToTokens for AuthScheme { - fn to_tokens(&self, tokens: &mut TokenStream) { - match self { - AuthScheme::None(kw) => kw.to_tokens(tokens), - AuthScheme::AccessToken(kw) => kw.to_tokens(tokens), - AuthScheme::ServerSignatures(kw) => kw.to_tokens(tokens), - AuthScheme::QueryOnlyAccessToken(kw) => kw.to_tokens(tokens), - } - } -} - enum Field { Description, Method, diff --git a/crates/ruma-api-macros/src/api/parse.rs b/crates/ruma-api-macros/src/api/parse.rs deleted file mode 100644 index b1c7db4f..00000000 --- a/crates/ruma-api-macros/src/api/parse.rs +++ /dev/null @@ -1,350 +0,0 @@ -use std::{collections::BTreeSet, mem}; - -use syn::{ - braced, - parse::{Parse, ParseStream}, - spanned::Spanned, - visit::Visit, - Attribute, Field, Ident, Lifetime, Token, Type, -}; - -use super::{ - attribute::{Meta, MetaNameValue}, - request::{RequestField, RequestFieldKind, RequestLifetimes}, - response::{ResponseField, ResponseFieldKind}, - Api, Metadata, Request, Response, -}; - -mod kw { - use syn::custom_keyword; - - custom_keyword!(error); - custom_keyword!(request); - custom_keyword!(response); -} - -impl Parse for Api { - fn parse(input: ParseStream<'_>) -> syn::Result { - let metadata: Metadata = input.parse()?; - - let req_attrs = input.call(Attribute::parse_outer)?; - let (request, attributes) = if input.peek(kw::request) { - let request = parse_request(input, req_attrs)?; - let after_req_attrs = input.call(Attribute::parse_outer)?; - - (Some(request), after_req_attrs) - } else { - // There was no `request` field so the attributes are for `response` - (None, req_attrs) - }; - - let response = if input.peek(kw::response) { - Some(parse_response(input, attributes)?) - } else if !attributes.is_empty() { - return Err(syn::Error::new_spanned( - &attributes[0], - "attributes are not supported on the error type", - )); - } else { - None - }; - - let error_ty = input - .peek(kw::error) - .then(|| { - let _: kw::error = input.parse()?; - let _: Token![:] = input.parse()?; - - input.parse() - }) - .transpose()?; - - if let Some(req) = &request { - let newtype_body_field = req.newtype_body_field(); - if metadata.method == "GET" && (req.has_body_fields() || newtype_body_field.is_some()) { - let mut combined_error: Option = None; - let mut add_error = |field| { - let error = - syn::Error::new_spanned(field, "GET endpoints can't have body fields"); - if let Some(combined_error_ref) = &mut combined_error { - combined_error_ref.combine(error); - } else { - combined_error = Some(error); - } - }; - - for field in req.body_fields() { - add_error(field); - } - - if let Some(field) = newtype_body_field { - add_error(field); - } - - return Err(combined_error.unwrap()); - } - } - - Ok(Self { metadata, request, response, error_ty }) - } -} - -fn parse_request(input: ParseStream<'_>, attributes: Vec) -> syn::Result { - let request_kw: kw::request = input.parse()?; - let _: Token![:] = input.parse()?; - let fields; - braced!(fields in input); - - let mut newtype_body_field = None; - let mut query_map_field = None; - let mut lifetimes = RequestLifetimes::default(); - - let fields: Vec<_> = fields - .parse_terminated::(Field::parse_named)? - .into_iter() - .map(|mut field| { - let mut field_kind = None; - let mut header = None; - - for attr in mem::take(&mut field.attrs) { - let meta = match Meta::from_attribute(&attr)? { - Some(m) => m, - None => { - field.attrs.push(attr); - continue; - } - }; - - if field_kind.is_some() { - return Err(syn::Error::new_spanned( - attr, - "There can only be one field kind attribute", - )); - } - - field_kind = Some(match meta { - Meta::Word(ident) => match &ident.to_string()[..] { - attr @ "body" | attr @ "raw_body" => req_res_meta_word( - attr, - &field, - &mut newtype_body_field, - RequestFieldKind::NewtypeBody, - RequestFieldKind::NewtypeRawBody, - )?, - "path" => RequestFieldKind::Path, - "query" => RequestFieldKind::Query, - "query_map" => { - if let Some(f) = &query_map_field { - let mut error = syn::Error::new_spanned( - field, - "There can only be one query map field", - ); - error.combine(syn::Error::new_spanned( - f, - "Previous query map field", - )); - return Err(error); - } - - query_map_field = Some(field.clone()); - RequestFieldKind::QueryMap - } - _ => { - return Err(syn::Error::new_spanned( - ident, - "Invalid #[ruma_api] argument, expected one of \ - `body`, `path`, `query`, `query_map`", - )); - } - }, - Meta::NameValue(MetaNameValue { name, value }) => { - req_res_name_value(name, value, &mut header, RequestFieldKind::Header)? - } - }); - } - - match field_kind.unwrap_or(RequestFieldKind::Body) { - RequestFieldKind::Header => { - collect_lifetime_idents(&mut lifetimes.header, &field.ty) - } - RequestFieldKind::Body => collect_lifetime_idents(&mut lifetimes.body, &field.ty), - RequestFieldKind::NewtypeBody => { - collect_lifetime_idents(&mut lifetimes.body, &field.ty) - } - RequestFieldKind::NewtypeRawBody => { - collect_lifetime_idents(&mut lifetimes.body, &field.ty) - } - RequestFieldKind::Path => collect_lifetime_idents(&mut lifetimes.path, &field.ty), - RequestFieldKind::Query => collect_lifetime_idents(&mut lifetimes.query, &field.ty), - RequestFieldKind::QueryMap => { - collect_lifetime_idents(&mut lifetimes.query, &field.ty) - } - } - - Ok(RequestField::new(field_kind.unwrap_or(RequestFieldKind::Body), field, header)) - }) - .collect::>()?; - - if newtype_body_field.is_some() && fields.iter().any(|f| f.is_body()) { - // TODO: highlight conflicting fields, - return Err(syn::Error::new_spanned( - request_kw, - "Can't have both a newtype body field and regular body fields", - )); - } - - if query_map_field.is_some() && fields.iter().any(|f| f.is_query()) { - return Err(syn::Error::new_spanned( - // TODO: raw, - request_kw, - "Can't have both a query map field and regular query fields", - )); - } - - // TODO when/if `&[(&str, &str)]` is supported remove this - if query_map_field.is_some() && !lifetimes.query.is_empty() { - return Err(syn::Error::new_spanned( - request_kw, - "Lifetimes are not allowed for query_map fields", - )); - } - - Ok(Request { attributes, fields, lifetimes }) -} - -fn parse_response(input: ParseStream<'_>, attributes: Vec) -> syn::Result { - let response_kw: kw::response = input.parse()?; - let _: Token![:] = input.parse()?; - let fields; - braced!(fields in input); - - let mut newtype_body_field = None; - - let fields: Vec<_> = fields - .parse_terminated::(Field::parse_named)? - .into_iter() - .map(|mut field| { - if has_lifetime(&field.ty) { - return Err(syn::Error::new( - field.ident.span(), - "Lifetimes on Response fields cannot be supported until GAT are stable", - )); - } - - let mut field_kind = None; - let mut header = None; - - for attr in mem::take(&mut field.attrs) { - let meta = match Meta::from_attribute(&attr)? { - Some(m) => m, - None => { - field.attrs.push(attr); - continue; - } - }; - - if field_kind.is_some() { - return Err(syn::Error::new_spanned( - attr, - "There can only be one field kind attribute", - )); - } - - field_kind = Some(match meta { - Meta::Word(ident) => match &ident.to_string()[..] { - s @ "body" | s @ "raw_body" => req_res_meta_word( - s, - &field, - &mut newtype_body_field, - ResponseFieldKind::NewtypeBody, - ResponseFieldKind::NewtypeRawBody, - )?, - _ => { - return Err(syn::Error::new_spanned( - ident, - "Invalid #[ruma_api] argument with value, expected `body`", - )); - } - }, - Meta::NameValue(MetaNameValue { name, value }) => { - req_res_name_value(name, value, &mut header, ResponseFieldKind::Header)? - } - }); - } - - Ok(match field_kind.unwrap_or(ResponseFieldKind::Body) { - ResponseFieldKind::Body => ResponseField::Body(field), - ResponseFieldKind::Header => { - ResponseField::Header(field, header.expect("missing header name")) - } - ResponseFieldKind::NewtypeBody => ResponseField::NewtypeBody(field), - ResponseFieldKind::NewtypeRawBody => ResponseField::NewtypeRawBody(field), - }) - }) - .collect::>()?; - - if newtype_body_field.is_some() && fields.iter().any(|f| f.is_body()) { - // TODO: highlight conflicting fields, - return Err(syn::Error::new_spanned( - response_kw, - "Can't have both a newtype body field and regular body fields", - )); - } - - Ok(Response { attributes, fields }) -} - -fn has_lifetime(ty: &Type) -> bool { - let mut lifetimes = BTreeSet::new(); - collect_lifetime_idents(&mut lifetimes, ty); - !lifetimes.is_empty() -} - -fn collect_lifetime_idents(lifetimes: &mut BTreeSet, ty: &Type) { - struct Visitor<'lt>(&'lt mut BTreeSet); - impl<'ast> Visit<'ast> for Visitor<'_> { - fn visit_lifetime(&mut self, lt: &'ast Lifetime) { - self.0.insert(lt.clone()); - } - } - - Visitor(lifetimes).visit_type(ty) -} - -fn req_res_meta_word( - attr_kind: &str, - field: &Field, - newtype_body_field: &mut Option, - body_field_kind: T, - raw_field_kind: T, -) -> syn::Result { - if let Some(f) = &newtype_body_field { - let mut error = syn::Error::new_spanned(field, "There can only be one newtype body field"); - error.combine(syn::Error::new_spanned(f, "Previous newtype body field")); - return Err(error); - } - - *newtype_body_field = Some(field.clone()); - Ok(match attr_kind { - "body" => body_field_kind, - "raw_body" => raw_field_kind, - _ => unreachable!(), - }) -} - -fn req_res_name_value( - name: Ident, - value: Ident, - header: &mut Option, - field_kind: T, -) -> syn::Result { - if name != "header" { - return Err(syn::Error::new_spanned( - name, - "Invalid #[ruma_api] argument with value, expected `header`", - )); - } - - *header = Some(value); - Ok(field_kind) -} diff --git a/crates/ruma-api-macros/src/api/request.rs b/crates/ruma-api-macros/src/api/request.rs index 20c8cc41..66fb3b76 100644 --- a/crates/ruma-api-macros/src/api/request.rs +++ b/crates/ruma-api-macros/src/api/request.rs @@ -1,129 +1,67 @@ //! Details of the `request` section of the procedural macro. -use std::collections::BTreeSet; +use std::collections::btree_map::{BTreeMap, Entry}; use proc_macro2::TokenStream; use quote::quote; -use syn::{Attribute, Field, Ident, Lifetime}; +use syn::{ + parse_quote, punctuated::Punctuated, spanned::Spanned, visit::Visit, Attribute, Field, Ident, + Lifetime, Token, +}; -use crate::util; - -use super::metadata::Metadata; - -mod incoming; -mod outgoing; - -#[derive(Default)] -pub(super) struct RequestLifetimes { - pub body: BTreeSet, - pub path: BTreeSet, - pub query: BTreeSet, - pub header: BTreeSet, -} +use super::{kw, metadata::Metadata}; +use crate::util::{all_cfgs, all_cfgs_expr, extract_cfg}; /// The result of processing the `request` section of the macro. pub(crate) struct Request { + /// The `request` keyword + pub(super) request_kw: kw::request, + /// The attributes that will be applied to the struct definition. pub(super) attributes: Vec, /// The fields of the request. - pub(super) fields: Vec, - - /// The collected lifetime identifiers from the declared fields. - pub(super) lifetimes: RequestLifetimes, + pub(super) fields: Punctuated, } impl Request { - /// Whether or not this request has any data in the HTTP body. - pub(super) fn has_body_fields(&self) -> bool { - self.fields.iter().any(|field| field.is_body()) - } - - /// Whether or not this request has any data in HTTP headers. - fn has_header_fields(&self) -> bool { - self.fields.iter().any(|field| field.is_header()) - } - - /// Whether or not this request has any data in the URL path. - fn has_path_fields(&self) -> bool { - self.fields.iter().any(|field| field.is_path()) - } - - /// Whether or not this request has any data in the query string. - fn has_query_fields(&self) -> bool { - self.fields.iter().any(|field| field.is_query()) - } - - /// Produces an iterator over all the body fields. - pub(super) fn body_fields(&self) -> impl Iterator { - self.fields.iter().filter_map(|field| field.as_body_field()) - } - - /// Whether any `body` field has a lifetime annotation. - fn has_body_lifetimes(&self) -> bool { - !self.lifetimes.body.is_empty() - } - - /// Whether any `query` field has a lifetime annotation. - fn has_query_lifetimes(&self) -> bool { - !self.lifetimes.query.is_empty() - } - - /// Whether any field has a lifetime. - fn contains_lifetimes(&self) -> bool { - !(self.lifetimes.body.is_empty() - && self.lifetimes.path.is_empty() - && self.lifetimes.query.is_empty() - && self.lifetimes.header.is_empty()) - } - /// The combination of every fields unique lifetime annotation. - fn combine_lifetimes(&self) -> TokenStream { - util::unique_lifetimes_to_tokens( - [ - &self.lifetimes.body, - &self.lifetimes.path, - &self.lifetimes.query, - &self.lifetimes.header, - ] - .iter() - .flat_map(|set| set.iter()), - ) - } + fn all_lifetimes(&self) -> BTreeMap> { + let mut lifetimes = BTreeMap::new(); - /// The lifetimes on fields with the `query` attribute. - fn query_lifetimes(&self) -> TokenStream { - util::unique_lifetimes_to_tokens(&self.lifetimes.query) - } + struct Visitor<'lt> { + field_cfg: Option, + lifetimes: &'lt mut BTreeMap>, + } - /// The lifetimes on fields with the `body` attribute. - fn body_lifetimes(&self) -> TokenStream { - util::unique_lifetimes_to_tokens(&self.lifetimes.body) - } + impl<'ast> Visit<'ast> for Visitor<'_> { + fn visit_lifetime(&mut self, lt: &'ast Lifetime) { + match self.lifetimes.entry(lt.clone()) { + Entry::Vacant(v) => { + v.insert(self.field_cfg.clone()); + } + Entry::Occupied(mut o) => { + let lifetime_cfg = o.get_mut(); - /// Produces an iterator over all the header fields. - fn header_fields(&self) -> impl Iterator { - self.fields.iter().filter(|field| field.is_header()) - } + // If at least one field uses this lifetime and has no cfg attribute, we + // don't need a cfg attribute for the lifetime either. + *lifetime_cfg = Option::zip(lifetime_cfg.as_ref(), self.field_cfg.as_ref()) + .map(|(a, b)| { + let expr_a = extract_cfg(a); + let expr_b = extract_cfg(b); + parse_quote! { #[cfg( any( #expr_a, #expr_b ) )] } + }); + } + } + } + } - /// Gets the number of path fields. - fn path_field_count(&self) -> usize { - self.fields.iter().filter(|field| field.is_path()).count() - } + for field in &self.fields { + let field_cfg = if field.attrs.is_empty() { None } else { all_cfgs(&field.attrs) }; + Visitor { lifetimes: &mut lifetimes, field_cfg }.visit_type(&field.ty); + } - /// Returns the body field. - pub fn newtype_body_field(&self) -> Option<&Field> { - self.fields.iter().find_map(RequestField::as_newtype_body_field) - } - - /// Returns the body field. - fn newtype_raw_body_field(&self) -> Option<&Field> { - self.fields.iter().find_map(RequestField::as_newtype_raw_body_field) - } - - /// Returns the query map field. - fn query_map_field(&self) -> Option<&Field> { - self.fields.iter().find_map(RequestField::as_query_map_field) + lifetimes } pub(super) fn expand( @@ -132,8 +70,8 @@ impl Request { error_ty: &TokenStream, ruma_api: &TokenStream, ) -> TokenStream { + let ruma_api_macros = quote! { #ruma_api::exports::ruma_api_macros }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; - let serde = quote! { #ruma_api::exports::serde }; let docs = format!( "Data for a request to the `{}` API endpoint.\n\n{}", @@ -142,239 +80,44 @@ impl Request { ); let struct_attributes = &self.attributes; - let request_body_struct = - if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) { - let field = Field { ident: None, colon_token: None, ..body_field.field().clone() }; - // Though we don't track the difference between new type body and body - // for lifetimes, the outer check and the macro failing if it encounters - // an illegal combination of field attributes, is enough to guarantee - // `body_lifetimes` correctness. - let (derive_deserialize, lifetimes) = if self.has_body_lifetimes() { - (TokenStream::new(), self.body_lifetimes()) - } else { - (quote! { #serde::Deserialize }, TokenStream::new()) - }; + let method = &metadata.method; + let path = &metadata.path; + let auth_attributes = metadata.authentication.iter().map(|field| { + let cfg_expr = all_cfgs_expr(&field.attrs); + let value = &field.value; - Some((derive_deserialize, quote! { #lifetimes (#field); })) - } else if self.has_body_fields() { - let fields = self.fields.iter().filter(|f| f.is_body()); - let (derive_deserialize, lifetimes) = if self.has_body_lifetimes() { - (TokenStream::new(), self.body_lifetimes()) - } else { - (quote! { #serde::Deserialize }, TokenStream::new()) - }; - let fields = fields.map(RequestField::field); - - Some((derive_deserialize, quote! { #lifetimes { #(#fields),* } })) - } else { - None + match cfg_expr { + Some(expr) => quote! { #[cfg_attr(#expr, ruma_api(authentication = #value))] }, + None => quote! { #[ruma_api(authentication = #value)] }, } - .map(|(derive_deserialize, def)| { - quote! { - /// Data in the request body. - #[derive( - Debug, - #ruma_serde::Outgoing, - #serde::Serialize, - #derive_deserialize - )] - struct RequestBody #def - } - }); + }); - let request_query_struct = if let Some(f) = self.query_map_field() { - let field = Field { ident: None, colon_token: None, ..f.clone() }; - let (derive_deserialize, lifetime) = if self.has_query_lifetimes() { - (TokenStream::new(), self.query_lifetimes()) - } else { - (quote! { #serde::Deserialize }, TokenStream::new()) - }; - - quote! { - /// Data in the request's query string. - #[derive( - Debug, - #ruma_serde::Outgoing, - #serde::Serialize, - #derive_deserialize - )] - struct RequestQuery #lifetime (#field); - } - } else if self.has_query_fields() { - let fields = self.fields.iter().filter_map(RequestField::as_query_field); - let (derive_deserialize, lifetime) = if self.has_query_lifetimes() { - (TokenStream::new(), self.query_lifetimes()) - } else { - (quote! { #serde::Deserialize }, TokenStream::new()) - }; - - quote! { - /// Data in the request's query string. - #[derive( - Debug, - #ruma_serde::Outgoing, - #serde::Serialize, - #derive_deserialize - )] - struct RequestQuery #lifetime { - #(#fields),* - } - } - } else { - TokenStream::new() - }; - - let lifetimes = self.combine_lifetimes(); - let fields = self.fields.iter().map(|request_field| request_field.field()); - - let outgoing_request_impl = self.expand_outgoing(metadata, error_ty, &lifetimes, ruma_api); - let incoming_request_impl = self.expand_incoming(metadata, error_ty, ruma_api); + let request_ident = Ident::new("Request", self.request_kw.span()); + let lifetimes = self.all_lifetimes(); + let lifetimes = lifetimes.iter().map(|(lt, attr)| quote! { #attr #lt }); + let fields = &self.fields; quote! { #[doc = #docs] - #[derive(Debug, Clone, #ruma_serde::Outgoing, #ruma_serde::_FakeDeriveSerde)] + #[derive( + Clone, + Debug, + #ruma_api_macros::Request, + #ruma_serde::Outgoing, + #ruma_serde::_FakeDeriveSerde, + )] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] - #[incoming_derive(!Deserialize)] + #[incoming_derive(!Deserialize, #ruma_api_macros::_FakeDeriveRumaApi)] + #[ruma_api( + method = #method, + path = #path, + error_ty = #error_ty, + )] + #( #auth_attributes )* #( #struct_attributes )* - pub struct Request #lifetimes { - #(#fields),* + pub struct #request_ident < #(#lifetimes),* > { + #fields } - - #request_body_struct - #request_query_struct - - #outgoing_request_impl - #incoming_request_impl } } } - -/// The types of fields that a request can have. -pub(crate) enum RequestField { - /// JSON data in the body of the request. - Body(Field), - - /// Data in an HTTP header. - Header(Field, Ident), - - /// A specific data type in the body of the request. - NewtypeBody(Field), - - /// Arbitrary bytes in the body of the request. - NewtypeRawBody(Field), - - /// Data that appears in the URL path. - Path(Field), - - /// Data that appears in the query string. - Query(Field), - - /// Data that appears in the query string as dynamic key-value pairs. - QueryMap(Field), -} - -impl RequestField { - /// Creates a new `RequestField`. - pub(super) fn new(kind: RequestFieldKind, field: Field, header: Option) -> Self { - match kind { - RequestFieldKind::Body => RequestField::Body(field), - RequestFieldKind::Header => { - RequestField::Header(field, header.expect("missing header name")) - } - RequestFieldKind::NewtypeBody => RequestField::NewtypeBody(field), - RequestFieldKind::NewtypeRawBody => RequestField::NewtypeRawBody(field), - RequestFieldKind::Path => RequestField::Path(field), - RequestFieldKind::Query => RequestField::Query(field), - RequestFieldKind::QueryMap => RequestField::QueryMap(field), - } - } - - /// Whether or not this request field is a body kind. - pub(super) fn is_body(&self) -> bool { - matches!(self, RequestField::Body(..)) - } - - /// Whether or not this request field is a header kind. - fn is_header(&self) -> bool { - matches!(self, RequestField::Header(..)) - } - - /// Whether or not this request field is a newtype body kind. - fn is_newtype_body(&self) -> bool { - matches!(self, RequestField::NewtypeBody(..)) - } - - /// Whether or not this request field is a path kind. - fn is_path(&self) -> bool { - matches!(self, RequestField::Path(..)) - } - - /// Whether or not this request field is a query string kind. - pub(super) fn is_query(&self) -> bool { - matches!(self, RequestField::Query(..)) - } - - /// Return the contained field if this request field is a body kind. - fn as_body_field(&self) -> Option<&Field> { - self.field_of_kind(RequestFieldKind::Body) - } - - /// Return the contained field if this request field is a body kind. - fn as_newtype_body_field(&self) -> Option<&Field> { - self.field_of_kind(RequestFieldKind::NewtypeBody) - } - - /// Return the contained field if this request field is a raw body kind. - fn as_newtype_raw_body_field(&self) -> Option<&Field> { - self.field_of_kind(RequestFieldKind::NewtypeRawBody) - } - - /// Return the contained field if this request field is a query kind. - fn as_query_field(&self) -> Option<&Field> { - self.field_of_kind(RequestFieldKind::Query) - } - - /// Return the contained field if this request field is a query map kind. - fn as_query_map_field(&self) -> Option<&Field> { - self.field_of_kind(RequestFieldKind::QueryMap) - } - - /// Gets the inner `Field` value. - fn field(&self) -> &Field { - match self { - RequestField::Body(field) - | RequestField::Header(field, _) - | RequestField::NewtypeBody(field) - | RequestField::NewtypeRawBody(field) - | RequestField::Path(field) - | RequestField::Query(field) - | RequestField::QueryMap(field) => field, - } - } - - /// Gets the inner `Field` value if it's of the provided kind. - fn field_of_kind(&self, kind: RequestFieldKind) -> Option<&Field> { - match (self, kind) { - (RequestField::Body(field), RequestFieldKind::Body) - | (RequestField::Header(field, _), RequestFieldKind::Header) - | (RequestField::NewtypeBody(field), RequestFieldKind::NewtypeBody) - | (RequestField::NewtypeRawBody(field), RequestFieldKind::NewtypeRawBody) - | (RequestField::Path(field), RequestFieldKind::Path) - | (RequestField::Query(field), RequestFieldKind::Query) - | (RequestField::QueryMap(field), RequestFieldKind::QueryMap) => Some(field), - _ => None, - } - } -} - -/// The types of fields that a request can have, without their values. -#[derive(Clone, Copy, PartialEq, Eq)] -pub(crate) enum RequestFieldKind { - Body, - Header, - NewtypeBody, - NewtypeRawBody, - Path, - Query, - QueryMap, -} diff --git a/crates/ruma-api-macros/src/api/response.rs b/crates/ruma-api-macros/src/api/response.rs index 5aa412f1..b3b8ebe2 100644 --- a/crates/ruma-api-macros/src/api/response.rs +++ b/crates/ruma-api-macros/src/api/response.rs @@ -2,76 +2,40 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::{Attribute, Field, Ident}; +use syn::{punctuated::Punctuated, spanned::Spanned, Attribute, Field, Ident, Token}; -use super::metadata::Metadata; - -mod incoming; -mod outgoing; +use super::{kw, metadata::Metadata}; /// The result of processing the `response` section of the macro. pub(crate) struct Response { + /// The `response` keyword + pub(super) response_kw: kw::response, + /// The attributes that will be applied to the struct definition. pub attributes: Vec, /// The fields of the response. - pub fields: Vec, + pub fields: Punctuated, } impl Response { - /// Whether or not this response has any data in the HTTP body. - fn has_body_fields(&self) -> bool { - self.fields.iter().any(|field| field.is_body()) - } - - /// Whether or not this response has any data in HTTP headers. - fn has_header_fields(&self) -> bool { - self.fields.iter().any(|field| field.is_header()) - } - - /// Gets the newtype body field, if this response has one. - fn newtype_body_field(&self) -> Option<&Field> { - self.fields.iter().find_map(ResponseField::as_newtype_body_field) - } - - /// Gets the newtype raw body field, if this response has one. - fn newtype_raw_body_field(&self) -> Option<&Field> { - self.fields.iter().find_map(ResponseField::as_newtype_raw_body_field) - } - pub(super) fn expand( &self, metadata: &Metadata, error_ty: &TokenStream, ruma_api: &TokenStream, ) -> TokenStream { + let ruma_api_macros = quote! { #ruma_api::exports::ruma_api_macros }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; - let serde = quote! { #ruma_api::exports::serde }; let docs = format!("Data in the response from the `{}` API endpoint.", metadata.name.value()); let struct_attributes = &self.attributes; - let def = if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) { - let field = Field { ident: None, colon_token: None, ..body_field.field().clone() }; - quote! { (#field); } - } else if self.has_body_fields() { - let fields = self.fields.iter().filter(|f| f.is_body()).map(ResponseField::field); - quote! { { #(#fields),* } } - } else { - quote! { {} } - }; - - let response_body_struct = quote! { - /// Data in the response body. - #[derive(Debug, #ruma_serde::Outgoing, #serde::Deserialize, #serde::Serialize)] - struct ResponseBody #def - }; - let has_test_exhaustive_field = self .fields .iter() - .filter_map(|f| f.field().ident.as_ref()) + .filter_map(|f| f.ident.as_ref()) .any(|ident| ident == "__test_exhaustive"); let non_exhaustive_attr = if has_test_exhaustive_field { @@ -80,99 +44,24 @@ impl Response { quote! { #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] } }; - let fields = self.fields.iter().map(|response_field| response_field.field()); - - let outgoing_response_impl = self.expand_outgoing(ruma_api); - let incoming_response_impl = self.expand_incoming(error_ty, ruma_api); - + let response_ident = Ident::new("Response", self.response_kw.span()); + let fields = &self.fields; quote! { #[doc = #docs] - #[derive(Debug, Clone, #ruma_serde::Outgoing, #ruma_serde::_FakeDeriveSerde)] + #[derive( + Clone, + Debug, + #ruma_api_macros::Response, + #ruma_serde::Outgoing, + #ruma_serde::_FakeDeriveSerde, + )] #non_exhaustive_attr - #[incoming_derive(!Deserialize)] + #[incoming_derive(!Deserialize, #ruma_api_macros::_FakeDeriveRumaApi)] + #[ruma_api(error_ty = #error_ty)] #( #struct_attributes )* - pub struct Response { - #(#fields),* + pub struct #response_ident { + #fields } - - #response_body_struct - - #outgoing_response_impl - #incoming_response_impl } } } - -/// The types of fields that a response can have. -pub(crate) enum ResponseField { - /// JSON data in the body of the response. - Body(Field), - - /// Data in an HTTP header. - Header(Field, Ident), - - /// A specific data type in the body of the response. - NewtypeBody(Field), - - /// Arbitrary bytes in the body of the response. - NewtypeRawBody(Field), -} - -impl ResponseField { - /// Gets the inner `Field` value. - fn field(&self) -> &Field { - match self { - ResponseField::Body(field) - | ResponseField::Header(field, _) - | ResponseField::NewtypeBody(field) - | ResponseField::NewtypeRawBody(field) => field, - } - } - - /// Whether or not this response field is a body kind. - pub(super) fn is_body(&self) -> bool { - self.as_body_field().is_some() - } - - /// Whether or not this response field is a header kind. - fn is_header(&self) -> bool { - matches!(self, ResponseField::Header(..)) - } - - /// Whether or not this response field is a newtype body kind. - fn is_newtype_body(&self) -> bool { - self.as_newtype_body_field().is_some() - } - - /// Return the contained field if this response field is a body kind. - fn as_body_field(&self) -> Option<&Field> { - match self { - ResponseField::Body(field) => Some(field), - _ => None, - } - } - - /// Return the contained field if this response field is a newtype body kind. - fn as_newtype_body_field(&self) -> Option<&Field> { - match self { - ResponseField::NewtypeBody(field) => Some(field), - _ => None, - } - } - - /// Return the contained field if this response field is a newtype raw body kind. - fn as_newtype_raw_body_field(&self) -> Option<&Field> { - match self { - ResponseField::NewtypeRawBody(field) => Some(field), - _ => None, - } - } -} - -/// The types of fields that a response can have, without their values. -pub(crate) enum ResponseFieldKind { - Body, - Header, - NewtypeBody, - NewtypeRawBody, -} diff --git a/crates/ruma-api-macros/src/api/attribute.rs b/crates/ruma-api-macros/src/attribute.rs similarity index 63% rename from crates/ruma-api-macros/src/api/attribute.rs rename to crates/ruma-api-macros/src/attribute.rs index d563045d..a8c76869 100644 --- a/crates/ruma-api-macros/src/api/attribute.rs +++ b/crates/ruma-api-macros/src/attribute.rs @@ -2,17 +2,42 @@ use syn::{ parse::{Parse, ParseStream}, - Ident, Token, + Ident, Lit, Token, Type, }; +/// Value type used for request and response struct attributes +#[allow(clippy::large_enum_variant)] +pub enum MetaValue { + Lit(Lit), + Type(Type), +} + +impl Parse for MetaValue { + fn parse(input: ParseStream<'_>) -> syn::Result { + if input.peek(Lit) { + input.parse().map(Self::Lit) + } else { + input.parse().map(Self::Type) + } + } +} + /// Like syn::MetaNameValue, but expects an identifier as the value. Also, we don't care about the /// the span of the equals sign, so we don't have the `eq_token` field from syn::MetaNameValue. -pub struct MetaNameValue { +pub struct MetaNameValue { /// The part left of the equals sign pub name: Ident, /// The part right of the equals sign - pub value: Ident, + pub value: V, +} + +impl Parse for MetaNameValue { + fn parse(input: ParseStream<'_>) -> syn::Result { + let ident = input.parse()?; + let _: Token![=] = input.parse()?; + Ok(MetaNameValue { name: ident, value: input.parse()? }) + } } /// Like syn::Meta, but only parses ruma_api attributes @@ -21,7 +46,7 @@ pub enum Meta { Word(Ident), /// A name-value pair, like `header = CONTENT_TYPE` in `#[ruma_api(header = CONTENT_TYPE)]` - NameValue(MetaNameValue), + NameValue(MetaNameValue), } impl Meta { diff --git a/crates/ruma-api-macros/src/auth_scheme.rs b/crates/ruma-api-macros/src/auth_scheme.rs new file mode 100644 index 00000000..90338aee --- /dev/null +++ b/crates/ruma-api-macros/src/auth_scheme.rs @@ -0,0 +1,46 @@ +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::parse::{Parse, ParseStream}; + +mod kw { + syn::custom_keyword!(None); + syn::custom_keyword!(AccessToken); + syn::custom_keyword!(ServerSignatures); + syn::custom_keyword!(QueryOnlyAccessToken); +} + +pub enum AuthScheme { + None(kw::None), + AccessToken(kw::AccessToken), + ServerSignatures(kw::ServerSignatures), + QueryOnlyAccessToken(kw::QueryOnlyAccessToken), +} + +impl Parse for AuthScheme { + fn parse(input: ParseStream<'_>) -> syn::Result { + let lookahead = input.lookahead1(); + + if lookahead.peek(kw::None) { + input.parse().map(Self::None) + } else if lookahead.peek(kw::AccessToken) { + input.parse().map(Self::AccessToken) + } else if lookahead.peek(kw::ServerSignatures) { + input.parse().map(Self::ServerSignatures) + } else if lookahead.peek(kw::QueryOnlyAccessToken) { + input.parse().map(Self::QueryOnlyAccessToken) + } else { + Err(lookahead.error()) + } + } +} + +impl ToTokens for AuthScheme { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + AuthScheme::None(kw) => kw.to_tokens(tokens), + AuthScheme::AccessToken(kw) => kw.to_tokens(tokens), + AuthScheme::ServerSignatures(kw) => kw.to_tokens(tokens), + AuthScheme::QueryOnlyAccessToken(kw) => kw.to_tokens(tokens), + } + } +} diff --git a/crates/ruma-api-macros/src/lib.rs b/crates/ruma-api-macros/src/lib.rs index 6b6f11de..e8e80c0f 100644 --- a/crates/ruma-api-macros/src/lib.rs +++ b/crates/ruma-api-macros/src/lib.rs @@ -11,15 +11,44 @@ #![recursion_limit = "256"] use proc_macro::TokenStream; -use syn::parse_macro_input; - -use self::api::Api; +use syn::{parse_macro_input, DeriveInput}; mod api; +mod attribute; +mod auth_scheme; +mod request; +mod response; mod util; +use api::Api; +use request::expand_derive_request; +use response::expand_derive_response; + #[proc_macro] pub fn ruma_api(input: TokenStream) -> TokenStream { let api = parse_macro_input!(input as Api); - api::expand_all(api).unwrap_or_else(syn::Error::into_compile_error).into() + api.expand_all().into() +} + +/// Internal helper taking care of the request-specific parts of `ruma_api!`. +#[proc_macro_derive(Request, attributes(ruma_api))] +pub fn derive_request(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + expand_derive_request(input).unwrap_or_else(syn::Error::into_compile_error).into() +} + +/// Internal helper taking care of the response-specific parts of `ruma_api!`. +#[proc_macro_derive(Response, attributes(ruma_api))] +pub fn derive_response(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + expand_derive_response(input).unwrap_or_else(syn::Error::into_compile_error).into() +} + +/// A derive macro that generates no code, but registers the ruma_api attribute so both +/// `#[ruma_api(...)]` and `#[cfg_attr(..., ruma_api(...))]` are accepted on the type, its fields +/// and (in case the input is an enum) variants fields. +#[doc(hidden)] +#[proc_macro_derive(_FakeDeriveRumaApi, attributes(ruma_api))] +pub fn fake_derive_ruma_api(_input: TokenStream) -> TokenStream { + TokenStream::new() } diff --git a/crates/ruma-api-macros/src/request.rs b/crates/ruma-api-macros/src/request.rs new file mode 100644 index 00000000..99425cc5 --- /dev/null +++ b/crates/ruma-api-macros/src/request.rs @@ -0,0 +1,491 @@ +use std::{ + collections::BTreeSet, + convert::{TryFrom, TryInto}, + mem, +}; + +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{ + parse::{Parse, ParseStream}, + parse_quote, + punctuated::Punctuated, + DeriveInput, Field, Generics, Ident, Lifetime, Lit, LitStr, Token, Type, +}; + +use crate::{ + attribute::{Meta, MetaNameValue, MetaValue}, + auth_scheme::AuthScheme, + util::{collect_lifetime_idents, import_ruma_api}, +}; + +mod incoming; +mod outgoing; + +pub fn expand_derive_request(input: DeriveInput) -> syn::Result { + let fields = match input.data { + syn::Data::Struct(s) => s.fields, + _ => panic!("This derive macro only works on structs"), + }; + + let mut lifetimes = RequestLifetimes::default(); + let fields = fields + .into_iter() + .map(|f| { + let f = RequestField::try_from(f)?; + let ty = &f.field().ty; + + match &f { + RequestField::Header(..) => collect_lifetime_idents(&mut lifetimes.header, ty), + RequestField::Body(_) => collect_lifetime_idents(&mut lifetimes.body, ty), + RequestField::NewtypeBody(_) => collect_lifetime_idents(&mut lifetimes.body, ty), + RequestField::NewtypeRawBody(_) => collect_lifetime_idents(&mut lifetimes.body, ty), + RequestField::Path(_) => collect_lifetime_idents(&mut lifetimes.path, ty), + RequestField::Query(_) => collect_lifetime_idents(&mut lifetimes.query, ty), + RequestField::QueryMap(_) => collect_lifetime_idents(&mut lifetimes.query, ty), + } + + Ok(f) + }) + .collect::>()?; + + let mut authentication = None; + let mut error_ty = None; + let mut method = None; + let mut path = None; + + for attr in input.attrs { + if !attr.path.is_ident("ruma_api") { + continue; + } + + let meta = attr.parse_args_with(Punctuated::<_, Token![,]>::parse_terminated)?; + for MetaNameValue { name, value } in meta { + match value { + MetaValue::Type(t) if name == "authentication" => { + authentication = Some(parse_quote!(#t)); + } + MetaValue::Type(t) if name == "method" => { + method = Some(parse_quote!(#t)); + } + MetaValue::Type(t) if name == "error_ty" => { + error_ty = Some(t); + } + MetaValue::Lit(Lit::Str(s)) if name == "path" => { + path = Some(s); + } + _ => unreachable!("invalid ruma_api({}) attribute", name), + } + } + } + + let request = Request { + ident: input.ident, + generics: input.generics, + fields, + lifetimes, + authentication: authentication.expect("missing authentication attribute"), + method: method.expect("missing method attribute"), + path: path.expect("missing path attribute"), + error_ty: error_ty.expect("missing error_ty attribute"), + }; + + request.check()?; + Ok(request.expand_all()) +} + +#[derive(Default)] +struct RequestLifetimes { + pub body: BTreeSet, + pub path: BTreeSet, + pub query: BTreeSet, + pub header: BTreeSet, +} + +struct Request { + ident: Ident, + generics: Generics, + lifetimes: RequestLifetimes, + fields: Vec, + + authentication: AuthScheme, + method: Ident, + path: LitStr, + error_ty: Type, +} + +impl Request { + fn body_fields(&self) -> impl Iterator { + self.fields.iter().filter_map(RequestField::as_body_field) + } + + fn query_fields(&self) -> impl Iterator { + self.fields.iter().filter_map(RequestField::as_query_field) + } + + fn has_body_fields(&self) -> bool { + self.fields.iter().any(|f| matches!(f, RequestField::Body(..))) + } + + fn has_header_fields(&self) -> bool { + self.fields.iter().any(|f| matches!(f, RequestField::Header(..))) + } + + fn has_path_fields(&self) -> bool { + self.fields.iter().any(|f| matches!(f, RequestField::Path(..))) + } + + fn has_query_fields(&self) -> bool { + self.fields.iter().any(|f| matches!(f, RequestField::Query(..))) + } + + fn has_lifetimes(&self) -> bool { + !(self.lifetimes.body.is_empty() + && self.lifetimes.path.is_empty() + && self.lifetimes.query.is_empty() + && self.lifetimes.header.is_empty()) + } + + fn header_fields(&self) -> impl Iterator { + self.fields.iter().filter(|f| matches!(f, RequestField::Header(..))) + } + + fn path_field_count(&self) -> usize { + self.fields.iter().filter(|f| matches!(f, RequestField::Path(..))).count() + } + + fn newtype_body_field(&self) -> Option<&Field> { + self.fields.iter().find_map(RequestField::as_newtype_body_field) + } + + fn newtype_raw_body_field(&self) -> Option<&Field> { + self.fields.iter().find_map(RequestField::as_newtype_raw_body_field) + } + + fn query_map_field(&self) -> Option<&Field> { + self.fields.iter().find_map(RequestField::as_query_map_field) + } + + fn expand_all(&self) -> TokenStream { + let ruma_api = import_ruma_api(); + let ruma_api_macros = quote! { #ruma_api::exports::ruma_api_macros }; + let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; + let serde = quote! { #ruma_api::exports::serde }; + + let request_body_def = if let Some(body_field) = self.newtype_body_field() { + let field = Field { ident: None, colon_token: None, ..body_field.clone() }; + Some(quote! { (#field); }) + } else if self.has_body_fields() { + let fields = self.fields.iter().filter_map(RequestField::as_body_field); + Some(quote! { { #(#fields),* } }) + } else { + None + }; + + let request_body_struct = request_body_def.map(|def| { + // Though we don't track the difference between newtype body and body + // for lifetimes, the outer check and the macro failing if it encounters + // an illegal combination of field attributes, is enough to guarantee + // `body_lifetimes` correctness. + let (derive_deserialize, generics) = if self.lifetimes.body.is_empty() { + (quote! { #serde::Deserialize }, TokenStream::new()) + } else { + let lifetimes = &self.lifetimes.body; + (TokenStream::new(), quote! { < #(#lifetimes),* > }) + }; + + quote! { + /// Data in the request body. + #[derive( + Debug, + #ruma_api_macros::_FakeDeriveRumaApi, + #ruma_serde::Outgoing, + #serde::Serialize, + #derive_deserialize + )] + struct RequestBody #generics #def + } + }); + + let request_query_def = if let Some(f) = self.query_map_field() { + let field = Field { ident: None, colon_token: None, ..f.clone() }; + Some(quote! { (#field); }) + } else if self.has_query_fields() { + let fields = self.fields.iter().filter_map(RequestField::as_query_field); + Some(quote! { { #(#fields),* } }) + } else { + None + }; + + let request_query_struct = request_query_def.map(|def| { + let (derive_deserialize, generics) = if self.lifetimes.query.is_empty() { + (quote! { #serde::Deserialize }, TokenStream::new()) + } else { + let lifetimes = &self.lifetimes.query; + (TokenStream::new(), quote! { < #(#lifetimes),* > }) + }; + + quote! { + /// Data in the request's query string. + #[derive( + Debug, + #ruma_api_macros::_FakeDeriveRumaApi, + #ruma_serde::Outgoing, + #serde::Serialize, + #derive_deserialize + )] + struct RequestQuery #generics #def + } + }); + + let outgoing_request_impl = self.expand_outgoing(&ruma_api); + let incoming_request_impl = self.expand_incoming(&ruma_api); + + quote! { + #request_body_struct + #request_query_struct + + #outgoing_request_impl + #incoming_request_impl + } + } + + pub(super) fn check(&self) -> syn::Result<()> { + // TODO: highlight problematic fields + + let newtype_body_fields = self.fields.iter().filter(|field| { + matches!(field, RequestField::NewtypeBody(_) | RequestField::NewtypeRawBody(_)) + }); + + let has_newtype_body_field = match newtype_body_fields.count() { + 0 => false, + 1 => true, + _ => { + return Err(syn::Error::new_spanned( + &self.ident, + "Can't have more than one newtype body field", + )) + } + }; + + let query_map_fields = + self.fields.iter().filter(|f| matches!(f, RequestField::QueryMap(_))); + let has_query_map_field = match query_map_fields.count() { + 0 => false, + 1 => true, + _ => { + return Err(syn::Error::new_spanned( + &self.ident, + "Can't have more than one query_map field", + )) + } + }; + + let has_body_fields = self.body_fields().count() > 0; + let has_query_fields = self.query_fields().count() > 0; + + if has_newtype_body_field && has_body_fields { + return Err(syn::Error::new_spanned( + &self.ident, + "Can't have both a newtype body field and regular body fields", + )); + } + + if has_query_map_field && has_query_fields { + return Err(syn::Error::new_spanned( + &self.ident, + "Can't have both a query map field and regular query fields", + )); + } + + // TODO when/if `&[(&str, &str)]` is supported remove this + if has_query_map_field && !self.lifetimes.query.is_empty() { + return Err(syn::Error::new_spanned( + &self.ident, + "Lifetimes are not allowed for query_map fields", + )); + } + + if self.method == "GET" && (has_body_fields || has_newtype_body_field) { + return Err(syn::Error::new_spanned( + &self.ident, + "GET endpoints can't have body fields", + )); + } + + Ok(()) + } +} + +/// The types of fields that a request can have. +enum RequestField { + /// JSON data in the body of the request. + Body(Field), + + /// Data in an HTTP header. + Header(Field, Ident), + + /// A specific data type in the body of the request. + NewtypeBody(Field), + + /// Arbitrary bytes in the body of the request. + NewtypeRawBody(Field), + + /// Data that appears in the URL path. + Path(Field), + + /// Data that appears in the query string. + Query(Field), + + /// Data that appears in the query string as dynamic key-value pairs. + QueryMap(Field), +} + +impl RequestField { + /// Creates a new `RequestField`. + fn new(kind: RequestFieldKind, field: Field, header: Option) -> Self { + match kind { + RequestFieldKind::Body => RequestField::Body(field), + RequestFieldKind::Header => { + RequestField::Header(field, header.expect("missing header name")) + } + RequestFieldKind::NewtypeBody => RequestField::NewtypeBody(field), + RequestFieldKind::NewtypeRawBody => RequestField::NewtypeRawBody(field), + RequestFieldKind::Path => RequestField::Path(field), + RequestFieldKind::Query => RequestField::Query(field), + RequestFieldKind::QueryMap => RequestField::QueryMap(field), + } + } + + /// Return the contained field if this request field is a body kind. + pub fn as_body_field(&self) -> Option<&Field> { + self.field_of_kind(RequestFieldKind::Body) + } + + /// Return the contained field if this request field is a body kind. + pub fn as_newtype_body_field(&self) -> Option<&Field> { + self.field_of_kind(RequestFieldKind::NewtypeBody) + } + + /// Return the contained field if this request field is a raw body kind. + pub fn as_newtype_raw_body_field(&self) -> Option<&Field> { + self.field_of_kind(RequestFieldKind::NewtypeRawBody) + } + + /// Return the contained field if this request field is a query kind. + pub fn as_query_field(&self) -> Option<&Field> { + self.field_of_kind(RequestFieldKind::Query) + } + + /// Return the contained field if this request field is a query map kind. + pub fn as_query_map_field(&self) -> Option<&Field> { + self.field_of_kind(RequestFieldKind::QueryMap) + } + + /// Gets the inner `Field` value. + pub fn field(&self) -> &Field { + match self { + RequestField::Body(field) + | RequestField::Header(field, _) + | RequestField::NewtypeBody(field) + | RequestField::NewtypeRawBody(field) + | RequestField::Path(field) + | RequestField::Query(field) + | RequestField::QueryMap(field) => field, + } + } + + /// Gets the inner `Field` value if it's of the provided kind. + fn field_of_kind(&self, kind: RequestFieldKind) -> Option<&Field> { + match (self, kind) { + (RequestField::Body(field), RequestFieldKind::Body) + | (RequestField::Header(field, _), RequestFieldKind::Header) + | (RequestField::NewtypeBody(field), RequestFieldKind::NewtypeBody) + | (RequestField::NewtypeRawBody(field), RequestFieldKind::NewtypeRawBody) + | (RequestField::Path(field), RequestFieldKind::Path) + | (RequestField::Query(field), RequestFieldKind::Query) + | (RequestField::QueryMap(field), RequestFieldKind::QueryMap) => Some(field), + _ => None, + } + } +} + +impl TryFrom for RequestField { + type Error = syn::Error; + + fn try_from(mut field: Field) -> syn::Result { + let mut field_kind = None; + let mut header = None; + + for attr in mem::take(&mut field.attrs) { + let meta = match Meta::from_attribute(&attr)? { + Some(m) => m, + None => { + field.attrs.push(attr); + continue; + } + }; + + if field_kind.is_some() { + return Err(syn::Error::new_spanned( + attr, + "There can only be one field kind attribute", + )); + } + + field_kind = Some(match meta { + Meta::Word(ident) => match &ident.to_string()[..] { + "body" => RequestFieldKind::Body, + "raw_body" => RequestFieldKind::NewtypeRawBody, + "path" => RequestFieldKind::Path, + "query" => RequestFieldKind::Query, + "query_map" => RequestFieldKind::QueryMap, + _ => { + return Err(syn::Error::new_spanned( + ident, + "Invalid #[ruma_api] argument, expected one of \ + `body`, `raw_body`, `path`, `query`, `query_map`", + )); + } + }, + Meta::NameValue(MetaNameValue { name, value }) => { + if name != "header" { + return Err(syn::Error::new_spanned( + name, + "Invalid #[ruma_api] argument with value, expected `header`", + )); + } + + header = Some(value); + RequestFieldKind::Header + } + }); + } + + Ok(RequestField::new(field_kind.unwrap_or(RequestFieldKind::Body), field, header)) + } +} + +impl Parse for RequestField { + fn parse(input: ParseStream<'_>) -> syn::Result { + input.call(Field::parse_named)?.try_into() + } +} + +impl ToTokens for RequestField { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.field().to_tokens(tokens) + } +} + +/// The types of fields that a request can have, without their values. +#[derive(Clone, Copy, PartialEq, Eq)] +enum RequestFieldKind { + Body, + Header, + NewtypeBody, + NewtypeRawBody, + Path, + Query, + QueryMap, +} diff --git a/crates/ruma-api-macros/src/api/request/incoming.rs b/crates/ruma-api-macros/src/request/incoming.rs similarity index 91% rename from crates/ruma-api-macros/src/api/request/incoming.rs rename to crates/ruma-api-macros/src/request/incoming.rs index 6fd0ce64..272dcd4f 100644 --- a/crates/ruma-api-macros/src/api/request/incoming.rs +++ b/crates/ruma-api-macros/src/request/incoming.rs @@ -2,23 +2,19 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; use super::{Request, RequestField, RequestFieldKind}; -use crate::api::metadata::{AuthScheme, Metadata}; +use crate::auth_scheme::AuthScheme; impl Request { - pub fn expand_incoming( - &self, - metadata: &Metadata, - error_ty: &TokenStream, - ruma_api: &TokenStream, - ) -> TokenStream { + pub fn expand_incoming(&self, ruma_api: &TokenStream) -> TokenStream { let http = quote! { #ruma_api::exports::http }; let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let serde_json = quote! { #ruma_api::exports::serde_json }; - let method = &metadata.method; + let method = &self.method; + let error_ty = &self.error_ty; - let incoming_request_type = if self.contains_lifetimes() { + let incoming_request_type = if self.has_lifetimes() { quote! { IncomingRequest } } else { quote! { Request } @@ -28,7 +24,7 @@ impl Request { // except this one. If we get errors about missing fields in IncomingRequest for // a path field look here. let (parse_request_path, path_vars) = if self.has_path_fields() { - let path_string = metadata.path.value(); + let path_string = self.path.value(); assert!(path_string.starts_with('/'), "path needs to start with '/'"); assert!( @@ -172,7 +168,7 @@ impl Request { let extract_body = (self.has_body_fields() || self.newtype_body_field().is_some()).then(|| { - let body_lifetimes = self.has_body_lifetimes().then(|| { + let body_lifetimes = (!self.lifetimes.body.is_empty()).then(|| { // duplicate the anonymous lifetime as many times as needed let lifetimes = std::iter::repeat(quote! { '_ }).take(self.lifetimes.body.len()); @@ -218,16 +214,12 @@ impl Request { self.vars(RequestFieldKind::Body, quote! { request_body }) }; - let non_auth_impls = metadata.authentication.iter().filter_map(|auth| { - matches!(auth.value, AuthScheme::None(_)).then(|| { - let attrs = &auth.attrs; - quote! { - #( #attrs )* - #[automatically_derived] - #[cfg(feature = "server")] - impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {} - } - }) + let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| { + quote! { + #[automatically_derived] + #[cfg(feature = "server")] + impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {} + } }); quote! { @@ -265,7 +257,7 @@ impl Request { } } - #(#non_auth_impls)* + #non_auth_impl } } diff --git a/crates/ruma-api-macros/src/api/request/outgoing.rs b/crates/ruma-api-macros/src/request/outgoing.rs similarity index 79% rename from crates/ruma-api-macros/src/api/request/outgoing.rs rename to crates/ruma-api-macros/src/request/outgoing.rs index 30c87d76..c4068e25 100644 --- a/crates/ruma-api-macros/src/api/request/outgoing.rs +++ b/crates/ruma-api-macros/src/request/outgoing.rs @@ -1,26 +1,21 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; -use crate::api::metadata::{AuthScheme, Metadata}; +use crate::auth_scheme::AuthScheme; use super::{Request, RequestField, RequestFieldKind}; impl Request { - pub fn expand_outgoing( - &self, - metadata: &Metadata, - error_ty: &TokenStream, - lifetimes: &TokenStream, - ruma_api: &TokenStream, - ) -> TokenStream { + pub fn expand_outgoing(&self, ruma_api: &TokenStream) -> TokenStream { let bytes = quote! { #ruma_api::exports::bytes }; let http = quote! { #ruma_api::exports::http }; let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; - let method = &metadata.method; + let method = &self.method; + let error_ty = &self.error_ty; let request_path_string = if self.has_path_fields() { - let mut format_string = metadata.path.value(); + let mut format_string = self.path.value(); let mut format_args = Vec::new(); while let Some(start_of_segment) = format_string.find(':') { @@ -132,38 +127,31 @@ impl Request { }) .collect(); - for auth in &metadata.authentication { - let attrs = &auth.attrs; - - let hdr_kv = match auth.value { - AuthScheme::AccessToken(_) => quote! { - #( #attrs )* + let hdr_kv = match self.authentication { + AuthScheme::AccessToken(_) => quote! { + req_headers.insert( + #http::header::AUTHORIZATION, + ::std::convert::TryFrom::<_>::try_from(::std::format!( + "Bearer {}", + access_token + .get_required_for_endpoint() + .ok_or(#ruma_api::error::IntoHttpError::NeedsAuthentication)?, + ))?, + ); + }, + AuthScheme::None(_) => quote! { + if let Some(access_token) = access_token.get_not_required_for_endpoint() { req_headers.insert( #http::header::AUTHORIZATION, - ::std::convert::TryFrom::<_>::try_from(::std::format!( - "Bearer {}", - access_token - .get_required_for_endpoint() - .ok_or(#ruma_api::error::IntoHttpError::NeedsAuthentication)?, - ))?, + ::std::convert::TryFrom::<_>::try_from( + ::std::format!("Bearer {}", access_token), + )? ); - }, - AuthScheme::None(_) => quote! { - if let Some(access_token) = access_token.get_not_required_for_endpoint() { - #( #attrs )* - req_headers.insert( - #http::header::AUTHORIZATION, - ::std::convert::TryFrom::<_>::try_from( - ::std::format!("Bearer {}", access_token), - )? - ); - } - }, - AuthScheme::QueryOnlyAccessToken(_) | AuthScheme::ServerSignatures(_) => quote! {}, - }; - - header_kvs.extend(hdr_kv); - } + } + }, + AuthScheme::QueryOnlyAccessToken(_) | AuthScheme::ServerSignatures(_) => quote! {}, + }; + header_kvs.extend(hdr_kv); let request_body = if let Some(field) = self.newtype_raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); @@ -185,22 +173,21 @@ impl Request { quote! { ::default() } }; - let non_auth_impls = metadata.authentication.iter().filter_map(|auth| { - matches!(auth.value, AuthScheme::None(_)).then(|| { - let attrs = &auth.attrs; - quote! { - #( #attrs )* - #[automatically_derived] - #[cfg(feature = "client")] - impl #lifetimes #ruma_api::OutgoingNonAuthRequest for Request #lifetimes {} - } - }) + let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); + + let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| { + quote! { + #[automatically_derived] + #[cfg(feature = "client")] + impl #impl_generics #ruma_api::OutgoingNonAuthRequest + for Request #ty_generics #where_clause {} + } }); quote! { #[automatically_derived] #[cfg(feature = "client")] - impl #lifetimes #ruma_api::OutgoingRequest for Request #lifetimes { + impl #impl_generics #ruma_api::OutgoingRequest for Request #ty_generics #where_clause { type EndpointError = #error_ty; type IncomingResponse = ::Incoming; @@ -236,7 +223,7 @@ impl Request { } } - #(#non_auth_impls)* + #non_auth_impl } } diff --git a/crates/ruma-api-macros/src/response.rs b/crates/ruma-api-macros/src/response.rs new file mode 100644 index 00000000..9542bcb5 --- /dev/null +++ b/crates/ruma-api-macros/src/response.rs @@ -0,0 +1,314 @@ +use std::{ + convert::{TryFrom, TryInto}, + mem, +}; + +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{ + parse::{Parse, ParseStream}, + punctuated::Punctuated, + visit::Visit, + DeriveInput, Field, Generics, Ident, Lifetime, Token, Type, +}; + +use crate::{ + attribute::{Meta, MetaNameValue, MetaValue}, + util, +}; + +mod incoming; +mod outgoing; + +pub fn expand_derive_response(input: DeriveInput) -> syn::Result { + let fields = match input.data { + syn::Data::Struct(s) => s.fields, + _ => panic!("This derive macro only works on structs"), + }; + + let fields = fields.into_iter().map(ResponseField::try_from).collect::>()?; + let mut error_ty = None; + for attr in input.attrs { + if !attr.path.is_ident("ruma_api") { + continue; + } + + let meta = attr.parse_args_with(Punctuated::<_, Token![,]>::parse_terminated)?; + for MetaNameValue { name, value } in meta { + match value { + MetaValue::Type(t) if name == "error_ty" => { + error_ty = Some(t); + } + _ => unreachable!("invalid ruma_api({}) attribute", name), + } + } + } + + let response = Response { + ident: input.ident, + generics: input.generics, + fields, + error_ty: error_ty.unwrap(), + }; + + response.check()?; + Ok(response.expand_all()) +} + +struct Response { + ident: Ident, + generics: Generics, + fields: Vec, + error_ty: Type, +} + +impl Response { + /// Whether or not this request has any data in the HTTP body. + fn has_body_fields(&self) -> bool { + self.fields.iter().any(|f| matches!(f, ResponseField::Body(_))) + } + + /// Returns the body field. + fn newtype_body_field(&self) -> Option<&Field> { + self.fields.iter().find_map(ResponseField::as_newtype_body_field) + } + + /// Returns the body field. + fn newtype_raw_body_field(&self) -> Option<&Field> { + self.fields.iter().find_map(ResponseField::as_newtype_raw_body_field) + } + + /// Whether or not this request has any data in the URL path. + fn has_header_fields(&self) -> bool { + self.fields.iter().any(|f| matches!(f, &ResponseField::Header(..))) + } + + fn expand_all(&self) -> TokenStream { + let ruma_api = util::import_ruma_api(); + let ruma_api_macros = quote! { #ruma_api::exports::ruma_api_macros }; + let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; + let serde = quote! { #ruma_api::exports::serde }; + + let response_body_struct = + self.fields.iter().all(|f| !matches!(f, ResponseField::NewtypeRawBody(_))).then(|| { + let newtype_body_field = + self.fields.iter().find(|f| matches!(f, ResponseField::NewtypeBody(_))); + let def = if let Some(body_field) = newtype_body_field { + let field = + Field { ident: None, colon_token: None, ..body_field.field().clone() }; + quote! { (#field); } + } else { + let fields = self.fields.iter().filter_map(|f| f.as_body_field()); + quote! { { #(#fields),* } } + }; + + quote! { + /// Data in the response body. + #[derive( + Debug, + #ruma_api_macros::_FakeDeriveRumaApi, + #ruma_serde::Outgoing, + #serde::Deserialize, + #serde::Serialize, + )] + struct ResponseBody #def + } + }); + + let outgoing_response_impl = self.expand_outgoing(&ruma_api); + let incoming_response_impl = self.expand_incoming(&self.error_ty, &ruma_api); + + quote! { + #response_body_struct + + #outgoing_response_impl + #incoming_response_impl + } + } + + pub fn check(&self) -> syn::Result<()> { + // TODO: highlight problematic fields + + if !self.generics.params.is_empty() || self.generics.where_clause.is_some() { + panic!("This macro doesn't support generic types"); + } + + let newtype_body_fields = self.fields.iter().filter(|f| { + matches!(f, ResponseField::NewtypeBody(_) | ResponseField::NewtypeRawBody(_)) + }); + + let has_newtype_body_field = match newtype_body_fields.count() { + 0 => false, + 1 => true, + _ => { + return Err(syn::Error::new_spanned( + &self.ident, + "Can't have more than one newtype body field", + )) + } + }; + + let has_body_fields = self.fields.iter().any(|f| matches!(f, ResponseField::Body(_))); + if has_newtype_body_field && has_body_fields { + return Err(syn::Error::new_spanned( + &self.ident, + "Can't have both a newtype body field and regular body fields", + )); + } + + Ok(()) + } +} + +/// The types of fields that a response can have. +enum ResponseField { + /// JSON data in the body of the response. + Body(Field), + + /// Data in an HTTP header. + Header(Field, Ident), + + /// A specific data type in the body of the response. + NewtypeBody(Field), + + /// Arbitrary bytes in the body of the response. + NewtypeRawBody(Field), +} + +impl ResponseField { + /// Gets the inner `Field` value. + fn field(&self) -> &Field { + match self { + ResponseField::Body(field) + | ResponseField::Header(field, _) + | ResponseField::NewtypeBody(field) + | ResponseField::NewtypeRawBody(field) => field, + } + } + + /// Return the contained field if this response field is a body kind. + fn as_body_field(&self) -> Option<&Field> { + match self { + ResponseField::Body(field) => Some(field), + _ => None, + } + } + + /// Return the contained field if this response field is a newtype body kind. + fn as_newtype_body_field(&self) -> Option<&Field> { + match self { + ResponseField::NewtypeBody(field) => Some(field), + _ => None, + } + } + + /// Return the contained field if this response field is a newtype raw body kind. + fn as_newtype_raw_body_field(&self) -> Option<&Field> { + match self { + ResponseField::NewtypeRawBody(field) => Some(field), + _ => None, + } + } +} + +impl TryFrom for ResponseField { + type Error = syn::Error; + + fn try_from(mut field: Field) -> syn::Result { + if has_lifetime(&field.ty) { + return Err(syn::Error::new_spanned( + field.ident, + "Lifetimes on Response fields cannot be supported until GAT are stable", + )); + } + + let mut field_kind = None; + let mut header = None; + + for attr in mem::take(&mut field.attrs) { + let meta = match Meta::from_attribute(&attr)? { + Some(m) => m, + None => { + field.attrs.push(attr); + continue; + } + }; + + if field_kind.is_some() { + return Err(syn::Error::new_spanned( + attr, + "There can only be one field kind attribute", + )); + } + + field_kind = Some(match meta { + Meta::Word(ident) => match &ident.to_string()[..] { + "body" => ResponseFieldKind::NewtypeBody, + "raw_body" => ResponseFieldKind::NewtypeRawBody, + _ => { + return Err(syn::Error::new_spanned( + ident, + "Invalid #[ruma_api] argument with value, expected `body`", + )); + } + }, + Meta::NameValue(MetaNameValue { name, value }) => { + if name != "header" { + return Err(syn::Error::new_spanned( + name, + "Invalid #[ruma_api] argument with value, expected `header`", + )); + } + + header = Some(value); + ResponseFieldKind::Header + } + }); + } + + Ok(match field_kind.unwrap_or(ResponseFieldKind::Body) { + ResponseFieldKind::Body => ResponseField::Body(field), + ResponseFieldKind::Header => { + ResponseField::Header(field, header.expect("missing header name")) + } + ResponseFieldKind::NewtypeBody => ResponseField::NewtypeBody(field), + ResponseFieldKind::NewtypeRawBody => ResponseField::NewtypeRawBody(field), + }) + } +} + +impl Parse for ResponseField { + fn parse(input: ParseStream<'_>) -> syn::Result { + input.call(Field::parse_named)?.try_into() + } +} + +impl ToTokens for ResponseField { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.field().to_tokens(tokens) + } +} + +/// The types of fields that a response can have, without their values. +enum ResponseFieldKind { + Body, + Header, + NewtypeBody, + NewtypeRawBody, +} + +fn has_lifetime(ty: &Type) -> bool { + struct Visitor { + found_lifetime: bool, + } + + impl<'ast> Visit<'ast> for Visitor { + fn visit_lifetime(&mut self, _lt: &'ast Lifetime) { + self.found_lifetime = true; + } + } + + let mut vis = Visitor { found_lifetime: false }; + vis.visit_type(ty); + vis.found_lifetime +} diff --git a/crates/ruma-api-macros/src/api/response/incoming.rs b/crates/ruma-api-macros/src/response/incoming.rs similarity index 98% rename from crates/ruma-api-macros/src/api/response/incoming.rs rename to crates/ruma-api-macros/src/response/incoming.rs index b128b66c..37c0e58f 100644 --- a/crates/ruma-api-macros/src/api/response/incoming.rs +++ b/crates/ruma-api-macros/src/response/incoming.rs @@ -1,10 +1,11 @@ use proc_macro2::TokenStream; use quote::quote; +use syn::Type; use super::{Response, ResponseField}; impl Response { - pub fn expand_incoming(&self, error_ty: &TokenStream, ruma_api: &TokenStream) -> TokenStream { + pub fn expand_incoming(&self, error_ty: &Type, ruma_api: &TokenStream) -> TokenStream { let http = quote! { #ruma_api::exports::http }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let serde_json = quote! { #ruma_api::exports::serde_json }; diff --git a/crates/ruma-api-macros/src/api/response/outgoing.rs b/crates/ruma-api-macros/src/response/outgoing.rs similarity index 100% rename from crates/ruma-api-macros/src/api/response/outgoing.rs rename to crates/ruma-api-macros/src/response/outgoing.rs diff --git a/crates/ruma-api-macros/src/util.rs b/crates/ruma-api-macros/src/util.rs index a5cf4735..0d21ab00 100644 --- a/crates/ruma-api-macros/src/util.rs +++ b/crates/ruma-api-macros/src/util.rs @@ -5,25 +5,9 @@ use std::collections::BTreeSet; use proc_macro2::TokenStream; use proc_macro_crate::{crate_name, FoundCrate}; use quote::{format_ident, quote}; -use syn::{AttrStyle, Attribute, Lifetime}; +use syn::{parse_quote, visit::Visit, AttrStyle, Attribute, Lifetime, NestedMeta, Type}; -/// Generates a `TokenStream` of lifetime identifiers `<'lifetime>`. -pub(crate) fn unique_lifetimes_to_tokens<'a, I: IntoIterator>( - lifetimes: I, -) -> TokenStream { - let lifetimes = lifetimes.into_iter().collect::>(); - if lifetimes.is_empty() { - TokenStream::new() - } else { - quote! { < #( #lifetimes ),* > } - } -} - -pub(crate) fn is_valid_endpoint_path(string: &str) -> bool { - string.as_bytes().iter().all(|b| (0x21..=0x7E).contains(b)) -} - -pub(crate) fn import_ruma_api() -> TokenStream { +pub fn import_ruma_api() -> TokenStream { if let Ok(FoundCrate::Name(name)) = crate_name("ruma-api") { let import = format_ident!("{}", name); quote! { ::#import } @@ -41,6 +25,48 @@ pub(crate) fn import_ruma_api() -> TokenStream { } } -pub(crate) fn is_cfg_attribute(attr: &Attribute) -> bool { +pub fn is_valid_endpoint_path(string: &str) -> bool { + string.as_bytes().iter().all(|b| (0x21..=0x7E).contains(b)) +} + +pub fn collect_lifetime_idents(lifetimes: &mut BTreeSet, ty: &Type) { + struct Visitor<'lt>(&'lt mut BTreeSet); + impl<'ast> Visit<'ast> for Visitor<'_> { + fn visit_lifetime(&mut self, lt: &'ast Lifetime) { + self.0.insert(lt.clone()); + } + } + + Visitor(lifetimes).visit_type(ty) +} + +pub fn is_cfg_attribute(attr: &Attribute) -> bool { matches!(attr.style, AttrStyle::Outer) && attr.path.is_ident("cfg") } + +pub fn all_cfgs_expr(cfgs: &[Attribute]) -> Option { + let sub_cfgs: Vec<_> = cfgs.iter().filter_map(extract_cfg).collect(); + (!sub_cfgs.is_empty()).then(|| quote! { all( #(#sub_cfgs),* ) }) +} + +pub fn all_cfgs(cfgs: &[Attribute]) -> Option { + let cfg_expr = all_cfgs_expr(cfgs)?; + Some(parse_quote! { #[cfg( #cfg_expr )] }) +} + +pub fn extract_cfg(attr: &Attribute) -> Option { + if !attr.path.is_ident("cfg") { + return None; + } + + let meta = attr.parse_meta().expect("cfg attribute can be parsed to syn::Meta"); + let mut list = match meta { + syn::Meta::List(l) => l, + _ => panic!("unexpected cfg syntax"), + }; + + assert!(list.path.is_ident("cfg"), "expected cfg attributes only"); + assert_eq!(list.nested.len(), 1, "expected one item inside cfg()"); + + Some(list.nested.pop().unwrap().into_value()) +} diff --git a/crates/ruma-api/src/lib.rs b/crates/ruma-api/src/lib.rs index 8c1fe9fa..155d7fc6 100644 --- a/crates/ruma-api/src/lib.rs +++ b/crates/ruma-api/src/lib.rs @@ -205,6 +205,7 @@ pub mod exports { pub use bytes; pub use http; pub use percent_encoding; + pub use ruma_api_macros; pub use ruma_serde; pub use serde; pub use serde_json; diff --git a/crates/ruma-api/tests/ruma_api_lifetime.rs b/crates/ruma-api/tests/ruma_api_lifetime.rs index d8347b66..f7f51e8a 100644 --- a/crates/ruma-api/tests/ruma_api_lifetime.rs +++ b/crates/ruma-api/tests/ruma_api_lifetime.rs @@ -1,6 +1,6 @@ #![allow(clippy::exhaustive_structs)] -#[derive(Copy, Clone, Debug, ruma_serde::Outgoing, serde::Serialize)] +/*#[derive(Copy, Clone, Debug, ruma_serde::Outgoing, serde::Serialize)] pub struct OtherThing<'t> { pub some: &'t str, pub t: &'t [u8], @@ -31,7 +31,7 @@ mod empty_response { response: {} } -} +}*/ mod nested_types { use ruma_api::ruma_api; @@ -59,7 +59,7 @@ mod nested_types { } } -mod full_request_response { +/*mod full_request_response { use ruma_api::ruma_api; use super::{IncomingOtherThing, OtherThing}; @@ -159,4 +159,4 @@ mod query_fields { response: {} } -} +}*/