diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index c776b126..0a0f6eef 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -91,14 +91,6 @@ impl ToTokens for Api { let rate_limited = &self.metadata.rate_limited; let requires_authentication = &self.metadata.requires_authentication; - let non_auth_endpoint_impl = if requires_authentication.value { - quote! { - impl ::ruma_api::NonAuthEndpoint for Request {} - } - } else { - TokenStream::new() - }; - let request_type = &self.request; let response_type = &self.response; @@ -163,18 +155,27 @@ impl ToTokens for Api { TokenStream::new() }; - let extract_request_body = - if self.request.has_body_fields() || self.request.newtype_body_field().is_some() { - quote! { - let request_body: ::Incoming = - ::ruma_api::try_deserialize!( - request, - ::ruma_api::exports::serde_json::from_slice(request.body().as_slice()) - ); - } + let extract_request_body = if self.request.has_body_fields() + || self.request.newtype_body_field().is_some() + { + let body_lifetimes = if self.request.has_body_lifetimes() { + // duplicate the anonymous lifetime as many times as needed + let anon = quote! { '_, }; + let lifetimes = vec![&anon].repeat(self.request.body_lifetime_count()); + quote! { < #( #lifetimes )* >} } else { TokenStream::new() }; + quote! { + let request_body: ::Incoming = + ::ruma_api::try_deserialize!( + request, + ::ruma_api::exports::serde_json::from_slice(request.body().as_slice()) + ); + } + } else { + TokenStream::new() + }; let parse_request_headers = if self.request.has_header_fields() { self.request.parse_headers_from_request() @@ -194,19 +195,29 @@ impl ToTokens for Api { TokenStream::new() }; - let typed_response_body_decl = - if self.response.has_body_fields() || self.response.newtype_body_field().is_some() { - quote! { - let response_body: ::Incoming = - ::ruma_api::try_deserialize!( - response, - ::ruma_api::exports::serde_json::from_slice(response.body().as_slice()), - ); - } + let typed_response_body_decl = if self.response.has_body_fields() + || self.response.newtype_body_field().is_some() + { + let body_lifetimes = if self.response.has_body_lifetimes() { + // duplicate the anonymous lifetime as many times as needed + let anon = quote! { '_, }; + let lifetimes = vec![&anon].repeat(self.response.body_lifetime_count()); + quote! { < #( #lifetimes )* >} } else { TokenStream::new() }; + quote! { + let response_body: ::Incoming = + ::ruma_api::try_deserialize!( + response, + ::ruma_api::exports::serde_json::from_slice(response.body().as_slice()), + ); + } + } else { + TokenStream::new() + }; + let response_init_fields = self.response.init_fields(); let serialize_response_headers = self.response.apply_header_fields(); @@ -222,28 +233,24 @@ impl ToTokens for Api { let error = &self.error; - let res_life = self.response.lifetimes().collect::>(); - let req_life = self.request.lifetimes().collect::>(); + let response_lifetimes = self.response.combine_lifetimes(); + let request_lifetimes = self.request.combine_lifetimes(); - let response_lifetimes = util::generics_to_tokens(res_life.iter().cloned()); - let request_lifetimes = util::generics_to_tokens(req_life.iter().cloned()); - - let endpoint_impl_lifetimes = if res_life != req_life { - let diff = - res_life.into_iter().filter(|resp| !req_life.contains(resp)).collect::>(); - - util::generics_to_tokens(req_life.iter().cloned().chain(diff)) + let non_auth_endpoint_impl = if requires_authentication.value { + quote! { + impl #request_lifetimes ::ruma_api::NonAuthEndpoint for Request #request_lifetimes {} + } } else { - request_lifetimes.clone() + TokenStream::new() }; let api = quote! { // FIXME: These can't conflict with other imports, but it would still be nice not to // bring anything into scope that code outside the macro could then rely on. - use ::std::convert::TryInto as _; + // use ::std::convert::TryInto as _; use ::ruma_api::exports::serde::de::Error as _; - use ::ruma_api::exports::serde::Deserialize as _; + // use ::ruma_api::exports::serde::Deserialize as _; use ::ruma_api::Endpoint as _; #[doc = #request_doc] @@ -319,7 +326,7 @@ impl ToTokens for Api { } } - impl #endpoint_impl_lifetimes ::ruma_api::Endpoint for Request #request_lifetimes { + impl #request_lifetimes ::ruma_api::Endpoint for Request #request_lifetimes { type Response = Response #response_lifetimes; type ResponseError = #error; diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index 4b7d7a77..10e62f52 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -14,13 +14,21 @@ use crate::{ util, }; +#[derive(Debug, Default)] +pub struct RequetLifetimes { + body: BTreeSet, + path: BTreeSet, + query: BTreeSet, + header: BTreeSet, +} + /// The result of processing the `request` section of the macro. pub struct Request { /// The fields of the request. fields: Vec, /// The collected lifetime identifiers from the declared fields. - lifetimes: Vec, + lifetimes: RequetLifetimes, } impl Request { @@ -101,15 +109,58 @@ impl Request { self.fields.iter().filter_map(|field| field.as_body_field()) } - /// Whether any field has a lifetime. - pub fn contains_lifetimes(&self) -> bool { - self.fields.iter().any(|f| util::has_lifetime(&f.field().ty)) + /// The number of unique lifetime annotations for `body` fields. + pub fn body_lifetime_count(&self) -> usize { + self.lifetimes.body.len() } - pub fn lifetimes(&self) -> impl Iterator { - self.lifetimes.iter() + /// Whether any `body` field has a lifetime annotation. + pub fn has_body_lifetimes(&self) -> bool { + !self.lifetimes.body.is_empty() } + /// Whether any `query` field has a lifetime annotation. + pub fn has_query_lifetimes(&self) -> bool { + !self.lifetimes.query.is_empty() + } + + /// Whether any field has a lifetime. + pub fn contains_lifetimes(&self) -> bool { + !(self.lifetimes.body.is_empty() + && self.lifetimes.path.is_empty() + && self.lifetimes.query.is_empty() + && self.lifetimes.header.is_empty()) + } + + /// The combination of every fields unique lifetime annotation. + pub fn combine_lifetimes(&self) -> TokenStream { + util::generics_to_tokens( + self.lifetimes + .body + .iter() + .chain(self.lifetimes.path.iter()) + .chain(self.lifetimes.query.iter()) + .chain(self.lifetimes.header.iter()) + .collect::>() + .into_iter(), + ) + } + + /// The lifetimes on fields with the `query` attribute. + pub fn query_lifetimes(&self) -> TokenStream { + util::generics_to_tokens(self.lifetimes.query.iter()) + } + + /// The lifetimes on fields with the `body` attribute. + pub fn body_lifetimes(&self) -> TokenStream { + util::generics_to_tokens(self.lifetimes.body.iter()) + } + + // /// The lifetimes on fields with the `header` attribute. + // pub fn header_lifetimes(&self) -> TokenStream { + // util::generics_to_tokens(self.lifetimes.header.iter()) + // } + /// Produces an iterator over all the header fields. pub fn header_fields(&self) -> impl Iterator { self.fields.iter().filter(|field| field.is_header()) @@ -201,7 +252,7 @@ impl TryFrom for Request { fn try_from(raw: RawRequest) -> syn::Result { let mut newtype_body_field = None; let mut query_map_field = None; - let mut lifetimes = BTreeSet::new(); + let mut lifetimes = RequetLifetimes::default(); let fields = raw .fields @@ -210,8 +261,6 @@ impl TryFrom for Request { let mut field_kind = None; let mut header = None; - util::copy_lifetime_ident(&mut lifetimes, &field.ty); - for attr in mem::replace(&mut field.attrs, Vec::new()) { let meta = match Meta::from_attribute(&attr)? { Some(m) => m, @@ -273,6 +322,16 @@ impl TryFrom for Request { }); } + match field_kind.unwrap_or(RequestFieldKind::Body) { + RequestFieldKind::Header => util::copy_lifetime_ident(&mut lifetimes.header, &field.ty), + RequestFieldKind::Body => util::copy_lifetime_ident(&mut lifetimes.body, &field.ty), + RequestFieldKind::NewtypeBody => util::copy_lifetime_ident(&mut lifetimes.body, &field.ty), + RequestFieldKind::NewtypeRawBody => util::copy_lifetime_ident(&mut lifetimes.body, &field.ty), + RequestFieldKind::Path => util::copy_lifetime_ident(&mut lifetimes.path, &field.ty), + RequestFieldKind::Query => util::copy_lifetime_ident(&mut lifetimes.query, &field.ty), + RequestFieldKind::QueryMap => util::copy_lifetime_ident(&mut lifetimes.query, &field.ty), + } + Ok(RequestField::new( field_kind.unwrap_or(RequestFieldKind::Body), field, @@ -297,7 +356,7 @@ impl TryFrom for Request { )); } - Ok(Self { fields, lifetimes: lifetimes.into_iter().collect() }) + Ok(Self { fields, lifetimes }) } } @@ -311,29 +370,29 @@ impl ToTokens for Request { quote! { { #(#fields),* } } }; - let request_generics = util::generics_to_tokens(self.lifetimes.iter()); + let request_generics = self.combine_lifetimes(); let request_body_struct = if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) { let field = Field { ident: None, colon_token: None, ..body_field.field().clone() }; - let derive_deserialize = if util::has_lifetime(&body_field.field().ty) { - TokenStream::new() + // Though we don't track the difference between new type body and body + // for lifetimes, the outer check and the macro failing if it encounters + // an illegal combination of field attributes, is enough to guarantee `body_lifetimes` + // correctness. + let (derive_deserialize, lifetimes) = if self.has_body_lifetimes() { + (TokenStream::new(), self.body_lifetimes()) } else { - quote!(::ruma_api::exports::serde::Deserialize) + (quote!(::ruma_api::exports::serde::Deserialize), TokenStream::new()) }; - Some((derive_deserialize, quote! { (#field); })) + Some((derive_deserialize, quote! { #lifetimes (#field); })) } else if self.has_body_fields() { let fields = self.fields.iter().filter(|f| f.is_body()); - let (derive_deserialize, lifetimes) = - if fields.clone().any(|f| util::has_lifetime(&f.field().ty)) { - ( - TokenStream::new(), - util::collect_generic_idents(fields.clone().map(|f| &f.field().ty)), - ) - } else { - (quote!(::ruma_api::exports::serde::Deserialize), TokenStream::new()) - }; + let (derive_deserialize, lifetimes) = if self.has_body_lifetimes() { + (TokenStream::new(), self.body_lifetimes()) + } else { + (quote!(::ruma_api::exports::serde::Deserialize), TokenStream::new()) + }; let fields = fields.map(RequestField::field); Some((derive_deserialize, quote! { #lifetimes { #(#fields),* } })) @@ -355,12 +414,13 @@ impl ToTokens for Request { let request_query_struct = if let Some(f) = self.query_map_field() { let field = Field { ident: None, colon_token: None, ..f.clone() }; - let lifetime = util::collect_generic_idents(Some(&field.ty).into_iter()); + let lifetime = self.query_lifetimes(); quote! { /// Data in the request's query string. #[derive( Debug, + ::ruma_api::Outgoing, ::ruma_api::exports::serde::Deserialize, ::ruma_api::exports::serde::Serialize, )] @@ -368,12 +428,13 @@ impl ToTokens for Request { } } else if self.has_query_fields() { let fields = self.fields.iter().filter_map(RequestField::as_query_field); - let lifetime = util::collect_generic_idents(fields.clone().map(|f| &f.ty)); + let lifetime = self.query_lifetimes(); quote! { /// Data in the request's query string. #[derive( Debug, + ::ruma_api::Outgoing, ::ruma_api::exports::serde::Deserialize, ::ruma_api::exports::serde::Serialize, )] diff --git a/ruma-api-macros/src/api/response.rs b/ruma-api-macros/src/api/response.rs index 51a34434..2b19347e 100644 --- a/ruma-api-macros/src/api/response.rs +++ b/ruma-api-macros/src/api/response.rs @@ -14,13 +14,19 @@ use crate::{ util, }; +#[derive(Debug, Default)] +pub struct ResponseLifetimes { + body: BTreeSet, + header: BTreeSet, +} + /// The result of processing the `response` section of the macro. pub struct Response { /// The fields of the response. fields: Vec, /// The collected lifetime identifiers from the declared fields. - lifetimes: Vec, + lifetimes: ResponseLifetimes, } impl Response { @@ -34,13 +40,34 @@ impl Response { self.fields.iter().any(|field| field.is_header()) } - /// Whether any field has a lifetime. - pub fn contains_lifetimes(&self) -> bool { - self.fields.iter().any(|f| util::has_lifetime(&f.field().ty)) + /// Whether any `body` field has a lifetime annotation. + pub fn has_body_lifetimes(&self) -> bool { + !self.lifetimes.body.is_empty() } - pub fn lifetimes(&self) -> impl Iterator { - self.lifetimes.iter() + /// The number of unique lifetime annotations for `body` fields. + pub fn body_lifetime_count(&self) -> usize { + self.lifetimes.body.len() + } + + /// Whether any field has a lifetime annotation. + pub fn contains_lifetimes(&self) -> bool { + !(self.lifetimes.body.is_empty() && self.lifetimes.header.is_empty()) + } + + pub fn combine_lifetimes(&self) -> TokenStream { + util::generics_to_tokens( + self.lifetimes + .body + .iter() + .chain(self.lifetimes.header.iter()) + .collect::>() + .into_iter(), + ) + } + + pub fn body_lifetimes(&self) -> TokenStream { + util::generics_to_tokens(self.lifetimes.body.iter()) } /// Produces code for a response struct initializer. @@ -169,7 +196,7 @@ impl TryFrom for Response { fn try_from(raw: RawResponse) -> syn::Result { let mut newtype_body_field = None; - let mut lifetimes = BTreeSet::new(); + let mut lifetimes = ResponseLifetimes::default(); let fields = raw .fields @@ -178,8 +205,6 @@ impl TryFrom for Response { let mut field_kind = None; let mut header = None; - util::copy_lifetime_ident(&mut lifetimes, &field.ty); - for attr in mem::replace(&mut field.attrs, Vec::new()) { let meta = match Meta::from_attribute(&attr)? { Some(m) => m, @@ -222,12 +247,22 @@ impl TryFrom for Response { } Ok(match field_kind.unwrap_or(ResponseFieldKind::Body) { - ResponseFieldKind::Body => ResponseField::Body(field), + ResponseFieldKind::Body => { + util::copy_lifetime_ident(&mut lifetimes.body, &field.ty); + ResponseField::Body(field) + } ResponseFieldKind::Header => { + util::copy_lifetime_ident(&mut lifetimes.header, &field.ty); ResponseField::Header(field, header.expect("missing header name")) } - ResponseFieldKind::NewtypeBody => ResponseField::NewtypeBody(field), - ResponseFieldKind::NewtypeRawBody => ResponseField::NewtypeRawBody(field), + ResponseFieldKind::NewtypeBody => { + util::copy_lifetime_ident(&mut lifetimes.body, &field.ty); + ResponseField::NewtypeBody(field) + } + ResponseFieldKind::NewtypeRawBody => { + util::copy_lifetime_ident(&mut lifetimes.body, &field.ty); + ResponseField::NewtypeRawBody(field) + } }) }) .collect::>>()?; @@ -240,7 +275,7 @@ impl TryFrom for Response { )); } - Ok(Self { fields, lifetimes: lifetimes.into_iter().collect() }) + Ok(Self { fields, lifetimes }) } } @@ -255,39 +290,35 @@ impl ToTokens for Response { quote! { { #(#fields),* } } }; - let (derive_deserialize, def) = if let Some(body_field) = - self.fields.iter().find(|f| f.is_newtype_body()) - { - let field = Field { ident: None, colon_token: None, ..body_field.field().clone() }; + let (derive_deserialize, def) = + if let Some(body_field) = self.fields.iter().find(|f| f.is_newtype_body()) { + let field = Field { ident: None, colon_token: None, ..body_field.field().clone() }; - let (derive_deserialize, lifetimes) = if util::has_lifetime(&body_field.field().ty) { - ( - TokenStream::new(), - util::collect_generic_idents(Some(&body_field.field().ty).into_iter()), - ) - } else { - (quote!(::ruma_api::exports::serde::Deserialize), TokenStream::new()) - }; - - (derive_deserialize, quote! { #lifetimes (#field); }) - } else if self.has_body_fields() { - let fields = self.fields.iter().filter(|f| f.is_body()); - let (derive_deserialize, lifetimes) = - if fields.clone().any(|f| util::has_lifetime(&f.field().ty)) { - ( - TokenStream::new(), - util::collect_generic_idents(fields.clone().map(|f| &f.field().ty)), - ) + // Though we don't track the difference between new type body and body + // for lifetimes, the outer check and the macro failing if it encounters + // an illegal combination of field attributes, is enough to guarantee `body_lifetimes` + // correctness. + let (derive_deserialize, lifetimes) = if self.has_body_lifetimes() { + (TokenStream::new(), self.body_lifetimes()) } else { (quote!(::ruma_api::exports::serde::Deserialize), TokenStream::new()) }; - let fields = fields.map(ResponseField::field); + (derive_deserialize, quote! { #lifetimes (#field); }) + } else if self.has_body_fields() { + let fields = self.fields.iter().filter(|f| f.is_body()); + let (derive_deserialize, lifetimes) = if self.has_body_lifetimes() { + (TokenStream::new(), self.body_lifetimes()) + } else { + (quote!(::ruma_api::exports::serde::Deserialize), TokenStream::new()) + }; - (derive_deserialize, quote!( #lifetimes { #(#fields),* })) - } else { - (TokenStream::new(), quote!({})) - }; + let fields = fields.map(ResponseField::field); + + (derive_deserialize, quote!( #lifetimes { #(#fields),* })) + } else { + (TokenStream::new(), quote!({})) + }; let response_body_struct = quote! { /// Data in the response body. @@ -300,7 +331,7 @@ impl ToTokens for Response { struct ResponseBody #def }; - let response_generics = util::generics_to_tokens(self.lifetimes.iter()); + let response_generics = self.combine_lifetimes(); let response = quote! { #[derive(Debug, Clone, ::ruma_api::Outgoing)] #[incoming_no_deserialize] diff --git a/ruma-api-macros/src/util.rs b/ruma-api-macros/src/util.rs index 0022c904..3cc002ea 100644 --- a/ruma-api-macros/src/util.rs +++ b/ruma-api-macros/src/util.rs @@ -10,35 +10,6 @@ use syn::{ use crate::api::{metadata::Metadata, request::Request}; -/// Whether or not the request field has a lifetime. -pub fn has_lifetime(ty: &Type) -> bool { - let mut found_lifetime = false; - if let Type::Path(TypePath { path, .. }) = ty { - for seg in &path.segments { - #[allow(clippy::blocks_in_if_conditions)] // TODO - if match &seg.arguments { - PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => { - args.clone().iter().any(|gen| { - if let GenericArgument::Type(ty) = gen { - has_lifetime(ty) - } else { - matches!(gen, GenericArgument::Lifetime(_)) - } - }) - } - PathArguments::Parenthesized(ParenthesizedGenericArguments { inputs, .. }) => { - inputs.iter().any(|ty| has_lifetime(ty)) - } - _ => false, - } { - found_lifetime = true; - } - } - } - - matches!(ty, Type::Reference(_) | Type::Slice(_)) || found_lifetime -} - pub fn copy_lifetime_ident(lifetimes: &mut BTreeSet, ty: &Type) { match ty { Type::Path(TypePath { path, .. }) => { @@ -182,7 +153,6 @@ pub(crate) fn request_path_string_and_parse( pub(crate) fn build_query_string(request: &Request) -> TokenStream { if let Some(field) = request.query_map_field() { let field_name = field.ident.as_ref().expect("expected field to have identifier"); - let field_type = &field.ty; quote!({ // This function exists so that the compiler will throw an @@ -195,13 +165,14 @@ pub(crate) fn build_query_string(request: &Request) -> TokenStream { // // By asserting that it implements the iterator trait, we can // ensure that it won't fail. - fn assert_trait_impl() + fn assert_trait_impl(_: &T) where T: ::std::iter::IntoIterator, {} - assert_trait_impl::<#field_type>(); let request_query = RequestQuery(self.#field_name); + assert_trait_impl(&request_query.0); + format_args!( "?{}", ::ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)? @@ -237,8 +208,13 @@ pub(crate) fn extract_request_query(request: &Request) -> TokenStream { ); } } else if request.has_query_fields() { + let request_query_type = if request.has_query_lifetimes() { + quote! { IncomingRequestQuery } + } else { + quote! { RequestQuery } + }; quote! { - let request_query: RequestQuery = ::ruma_api::try_deserialize!( + let request_query: #request_query_type = ::ruma_api::try_deserialize!( request, ::ruma_api::exports::ruma_serde::urlencoded::from_str( &request.uri().query().unwrap_or("")