diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index cb1b2cff..fe12fedf 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -2,20 +2,21 @@ use std::convert::{TryFrom, TryInto as _}; -use proc_macro2::{Span, TokenStream}; +use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{ braced, parse::{Parse, ParseStream}, - Field, FieldValue, Ident, Token, Type, + Field, FieldValue, Token, Type, }; -mod attribute; -mod metadata; -mod request; -mod response; +pub(crate) mod attribute; +pub(crate) mod metadata; +pub(crate) mod request; +pub(crate) mod response; use self::{metadata::Metadata, request::Request, response::Response}; +use crate::util; /// Removes `serde` attributes from struct fields. pub fn strip_serde_attrs(field: &Field) -> Field { @@ -121,159 +122,12 @@ impl ToTokens for Api { TokenStream::new() }; - let (request_path_string, parse_request_path) = if self.request.has_path_fields() { - let path_string = path.value(); + let (request_path_string, parse_request_path) = + util::request_path_string_and_parse(&self.request, &self.metadata); - assert!(path_string.starts_with('/'), "path needs to start with '/'"); - assert!( - path_string.chars().filter(|c| *c == ':').count() - == self.request.path_field_count(), - "number of declared path parameters needs to match amount of placeholders in path" - ); + let request_query_string = util::build_query_string(&self.request); - let format_call = { - let mut format_string = path_string.clone(); - let mut format_args = Vec::new(); - - while let Some(start_of_segment) = format_string.find(':') { - // ':' should only ever appear at the start of a segment - assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/"); - - let end_of_segment = match format_string[start_of_segment..].find('/') { - Some(rel_pos) => start_of_segment + rel_pos, - None => format_string.len(), - }; - - let path_var = Ident::new( - &format_string[start_of_segment + 1..end_of_segment], - Span::call_site(), - ); - format_args.push(quote! { - ruma_api::exports::percent_encoding::utf8_percent_encode( - &request.#path_var.to_string(), - ruma_api::exports::percent_encoding::NON_ALPHANUMERIC, - ) - }); - format_string.replace_range(start_of_segment..end_of_segment, "{}"); - } - - quote! { - format!(#format_string, #(#format_args),*) - } - }; - - let path_fields = path_string[1..] - .split('/') - .enumerate() - .filter(|(_, s)| s.starts_with(':')) - .map(|(i, segment)| { - let path_var = &segment[1..]; - let path_var_ident = Ident::new(path_var, Span::call_site()); - - quote! { - #path_var_ident: { - use std::ops::Deref as _; - use ruma_api::error::RequestDeserializationError; - - let segment = path_segments.get(#i).unwrap().as_bytes(); - let decoded = match ruma_api::exports::percent_encoding::percent_decode( - segment - ).decode_utf8() { - Ok(x) => x, - Err(err) => { - return Err( - RequestDeserializationError::new(err, request).into() - ); - } - }; - match std::convert::TryFrom::try_from(decoded.deref()) { - Ok(val) => val, - Err(err) => { - return Err( - RequestDeserializationError::new(err, request).into() - ); - } - } - } - } - }); - - (format_call, quote! { #(#path_fields,)* }) - } else { - (quote! { metadata.path.to_owned() }, TokenStream::new()) - }; - - let request_query_string = if let Some(field) = self.request.query_map_field() { - let field_name = field.ident.as_ref().expect("expected field to have identifier"); - let field_type = &field.ty; - - quote!({ - // This function exists so that the compiler will throw an - // error when the type of the field with the query_map - // attribute doesn't implement IntoIterator - // - // This is necessary because the ruma_serde::urlencoded::to_string - // call will result in a runtime error when the type cannot be - // encoded as a list key-value pairs (?key1=value1&key2=value2) - // - // By asserting that it implements the iterator trait, we can - // ensure that it won't fail. - fn assert_trait_impl() - where - T: std::iter::IntoIterator, - {} - assert_trait_impl::<#field_type>(); - - let request_query = RequestQuery(request.#field_name); - format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?) - }) - } else if self.request.has_query_fields() { - let request_query_init_fields = self.request.request_query_init_fields(); - - quote!({ - let request_query = RequestQuery { - #request_query_init_fields - }; - - format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?) - }) - } else { - quote! { - String::new() - } - }; - - let extract_request_query = if self.request.query_map_field().is_some() { - quote! { - let request_query = match ruma_api::exports::ruma_serde::urlencoded::from_str( - &request.uri().query().unwrap_or("") - ) { - Ok(query) => query, - Err(err) => { - return Err( - ruma_api::error::RequestDeserializationError::new(err, request).into() - ); - } - }; - } - } else if self.request.has_query_fields() { - quote! { - let request_query: RequestQuery = - match ruma_api::exports::ruma_serde::urlencoded::from_str( - &request.uri().query().unwrap_or("") - ) { - Ok(query) => query, - Err(err) => { - return Err( - ruma_api::error::RequestDeserializationError::new(err, request) - .into() - ); - } - }; - } - } else { - TokenStream::new() - }; + let extract_request_query = util::extract_request_query(&self.request); let parse_request_query = if let Some(field) = self.request.query_map_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); @@ -327,42 +181,9 @@ impl ToTokens for Api { TokenStream::new() }; - let request_body = if let Some(field) = self.request.newtype_raw_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote!(request.#field_name) - } else if self.request.has_body_fields() || self.request.newtype_body_field().is_some() { - let request_body_initializers = if let Some(field) = self.request.newtype_body_field() { - let field_name = - field.ident.as_ref().expect("expected field to have an identifier"); - quote! { (request.#field_name) } - } else { - let initializers = self.request.request_body_init_fields(); - quote! { { #initializers } } - }; + let request_body = util::build_request_body(&self.request); - quote! { - { - let request_body = RequestBody #request_body_initializers; - ruma_api::exports::serde_json::to_vec(&request_body)? - } - } - } else { - quote!(Vec::new()) - }; - - let parse_request_body = if let Some(field) = self.request.newtype_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { - #field_name: request_body.0, - } - } else if let Some(field) = self.request.newtype_raw_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { - #field_name: request.into_body(), - } - } else { - self.request.request_init_body_fields() - }; + let parse_request_body = util::parse_request_body(&self.request); let extract_response_headers = if self.response.has_header_fields() { quote! { diff --git a/ruma-api-macros/src/api/attribute.rs b/ruma-api-macros/src/api/attribute.rs index d958ac13..1d93eb97 100644 --- a/ruma-api-macros/src/api/attribute.rs +++ b/ruma-api-macros/src/api/attribute.rs @@ -24,10 +24,6 @@ pub enum Meta { impl Meta { /// Check if the given attribute is a ruma_api attribute. If it is, parse it. - /// - /// # Panics - /// - /// Panics if the given attribute is a ruma_api attribute, but fails to parse. pub fn from_attribute(attr: &syn::Attribute) -> syn::Result> { if attr.path.is_ident("ruma_api") { attr.parse_args().map(Some) diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index 408060fc..3e6395b6 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -6,9 +6,12 @@ use proc_macro2::TokenStream; use quote::{quote, quote_spanned, ToTokens}; use syn::{spanned::Spanned, Field, Ident}; -use crate::api::{ - attribute::{Meta, MetaNameValue}, - strip_serde_attrs, RawRequest, +use crate::{ + api::{ + attribute::{Meta, MetaNameValue}, + strip_serde_attrs, RawRequest, + }, + util, }; /// The result of processing the `request` section of the macro. @@ -209,26 +212,13 @@ impl TryFrom for Request { field_kind = Some(match meta { Meta::Word(ident) => { match &ident.to_string()[..] { - s @ "body" | s @ "raw_body" => { - if let Some(f) = &newtype_body_field { - let mut error = syn::Error::new_spanned( - field, - "There can only be one newtype body field", - ); - error.combine(syn::Error::new_spanned( - f, - "Previous newtype body field", - )); - return Err(error); - } - - newtype_body_field = Some(field.clone()); - match s { - "body" => RequestFieldKind::NewtypeBody, - "raw_body" => RequestFieldKind::NewtypeRawBody, - _ => unreachable!(), - } - } + attr @ "body" | attr @ "raw_body" => util::req_res_meta_word( + attr, + &field, + &mut newtype_body_field, + RequestFieldKind::NewtypeBody, + RequestFieldKind::NewtypeRawBody, + )?, "path" => RequestFieldKind::Path, "query" => RequestFieldKind::Query, "query_map" => { @@ -255,17 +245,12 @@ impl TryFrom for Request { } } } - Meta::NameValue(MetaNameValue { name, value }) => { - if name != "header" { - return Err(syn::Error::new_spanned( - name, - "Invalid #[ruma_api] argument with value, expected `header`" - )); - } - - header = Some(value); - RequestFieldKind::Header - } + Meta::NameValue(MetaNameValue { name, value }) => util::req_res_name_value( + name, + value, + &mut header, + RequestFieldKind::Header, + )?, }); } diff --git a/ruma-api-macros/src/api/response.rs b/ruma-api-macros/src/api/response.rs index 2bdef543..35a99d0f 100644 --- a/ruma-api-macros/src/api/response.rs +++ b/ruma-api-macros/src/api/response.rs @@ -6,9 +6,12 @@ use proc_macro2::TokenStream; use quote::{quote, quote_spanned, ToTokens}; use syn::{spanned::Spanned, Field, Ident}; -use crate::api::{ - attribute::{Meta, MetaNameValue}, - strip_serde_attrs, RawResponse, +use crate::{ + api::{ + attribute::{Meta, MetaNameValue}, + strip_serde_attrs, RawResponse, + }, + util, }; /// The result of processing the `response` section of the macro. @@ -169,26 +172,13 @@ impl TryFrom for Response { field_kind = Some(match meta { Meta::Word(ident) => match &ident.to_string()[..] { - s @ "body" | s @ "raw_body" => { - if let Some(f) = &newtype_body_field { - let mut error = syn::Error::new_spanned( - field, - "There can only be one newtype body field", - ); - error.combine(syn::Error::new_spanned( - f, - "Previous newtype body field", - )); - return Err(error); - } - - newtype_body_field = Some(field.clone()); - match s { - "body" => ResponseFieldKind::NewtypeBody, - "raw_body" => ResponseFieldKind::NewtypeRawBody, - _ => unreachable!(), - } - } + s @ "body" | s @ "raw_body" => util::req_res_meta_word( + s, + &field, + &mut newtype_body_field, + ResponseFieldKind::NewtypeBody, + ResponseFieldKind::NewtypeRawBody, + )?, _ => { return Err(syn::Error::new_spanned( ident, @@ -196,17 +186,12 @@ impl TryFrom for Response { )); } }, - Meta::NameValue(MetaNameValue { name, value }) => { - if name != "header" { - return Err(syn::Error::new_spanned( - name, - "Invalid #[ruma_api] argument with value, expected `header`", - )); - } - - header = Some(value); - ResponseFieldKind::Header - } + Meta::NameValue(MetaNameValue { name, value }) => util::req_res_name_value( + name, + value, + &mut header, + ResponseFieldKind::Header, + )?, }); } @@ -340,7 +325,7 @@ impl ResponseField { } } - /// Whether or not the reponse field has a #[wrap_incoming] attribute. + /// Whether or not the response field has a #[wrap_incoming] attribute. fn has_wrap_incoming_attr(&self) -> bool { self.field().attrs.iter().any(|attr| { attr.path.segments.len() == 1 && attr.path.segments[0].ident == "wrap_incoming" diff --git a/ruma-api-macros/src/lib.rs b/ruma-api-macros/src/lib.rs index cad8e4a9..a65f1859 100644 --- a/ruma-api-macros/src/lib.rs +++ b/ruma-api-macros/src/lib.rs @@ -22,6 +22,7 @@ use syn::parse_macro_input; use self::api::{Api, RawApi}; mod api; +mod util; #[proc_macro] pub fn ruma_api(input: TokenStream) -> TokenStream { diff --git a/ruma-api-macros/src/util.rs b/ruma-api-macros/src/util.rs new file mode 100644 index 00000000..256bb93a --- /dev/null +++ b/ruma-api-macros/src/util.rs @@ -0,0 +1,259 @@ +//! Functions to aid the `Api::to_tokens` method. + +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::Ident; + +use crate::api::{metadata::Metadata, request::Request}; + +/// The first item in the tuple generates code for the request path from +/// the `Metadata` and `Request` structs. The second item in the returned tuple +/// is the code to generate a Request struct field created from any segments +/// of the path that start with ":". +/// +/// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is +/// used for implementing `TryFrom>>`, from path strings deserialized to ruma types. +pub(crate) fn request_path_string_and_parse( + request: &Request, + metadata: &Metadata, +) -> (TokenStream, TokenStream) { + if request.has_path_fields() { + let path_string = metadata.path.value(); + + assert!(path_string.starts_with('/'), "path needs to start with '/'"); + assert!( + path_string.chars().filter(|c| *c == ':').count() == request.path_field_count(), + "number of declared path parameters needs to match amount of placeholders in path" + ); + + let format_call = { + let mut format_string = path_string.clone(); + let mut format_args = Vec::new(); + + while let Some(start_of_segment) = format_string.find(':') { + // ':' should only ever appear at the start of a segment + assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/"); + + let end_of_segment = match format_string[start_of_segment..].find('/') { + Some(rel_pos) => start_of_segment + rel_pos, + None => format_string.len(), + }; + + let path_var = Ident::new( + &format_string[start_of_segment + 1..end_of_segment], + Span::call_site(), + ); + format_args.push(quote! { + ruma_api::exports::percent_encoding::utf8_percent_encode( + &request.#path_var.to_string(), + ruma_api::exports::percent_encoding::NON_ALPHANUMERIC, + ) + }); + format_string.replace_range(start_of_segment..end_of_segment, "{}"); + } + + quote! { + format!(#format_string, #(#format_args),*) + } + }; + + let path_fields = + path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map( + |(i, segment)| { + let path_var = &segment[1..]; + let path_var_ident = Ident::new(path_var, Span::call_site()); + + quote! { + #path_var_ident: { + use std::ops::Deref as _; + use ruma_api::error::RequestDeserializationError; + + let segment = path_segments.get(#i).unwrap().as_bytes(); + let decoded = match ruma_api::exports::percent_encoding::percent_decode( + segment + ).decode_utf8() { + Ok(x) => x, + Err(err) => { + return Err( + RequestDeserializationError::new(err, request).into() + ); + } + }; + match std::convert::TryFrom::try_from(decoded.deref()) { + Ok(val) => val, + Err(err) => { + return Err( + RequestDeserializationError::new(err, request).into() + ); + } + } + } + } + }, + ); + + (format_call, quote! { #(#path_fields,)* }) + } else { + (quote! { metadata.path.to_owned() }, TokenStream::new()) + } +} + +/// The function determines the type of query string that needs to be built +/// and then builds it using `ruma_serde::urlencoded::to_string`. +pub(crate) fn build_query_string(request: &Request) -> TokenStream { + if let Some(field) = request.query_map_field() { + let field_name = field.ident.as_ref().expect("expected field to have identifier"); + let field_type = &field.ty; + + quote!({ + // This function exists so that the compiler will throw an + // error when the type of the field with the query_map + // attribute doesn't implement IntoIterator + // + // This is necessary because the ruma_serde::urlencoded::to_string + // call will result in a runtime error when the type cannot be + // encoded as a list key-value pairs (?key1=value1&key2=value2) + // + // By asserting that it implements the iterator trait, we can + // ensure that it won't fail. + fn assert_trait_impl() + where + T: std::iter::IntoIterator, + {} + assert_trait_impl::<#field_type>(); + + let request_query = RequestQuery(request.#field_name); + format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?) + }) + } else if request.has_query_fields() { + let request_query_init_fields = request.request_query_init_fields(); + + quote!({ + let request_query = RequestQuery { + #request_query_init_fields + }; + + format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?) + }) + } else { + quote! { + String::new() + } + } +} + +/// Deserialize the query string. +pub(crate) fn extract_request_query(request: &Request) -> TokenStream { + if request.query_map_field().is_some() { + quote! { + let request_query = match ruma_api::exports::ruma_serde::urlencoded::from_str( + &request.uri().query().unwrap_or("") + ) { + Ok(query) => query, + Err(err) => { + return Err( + ruma_api::error::RequestDeserializationError::new(err, request).into() + ); + } + }; + } + } else if request.has_query_fields() { + quote! { + let request_query: RequestQuery = + match ruma_api::exports::ruma_serde::urlencoded::from_str( + &request.uri().query().unwrap_or("") + ) { + Ok(query) => query, + Err(err) => { + return Err( + ruma_api::error::RequestDeserializationError::new(err, request) + .into() + ); + } + }; + } + } else { + TokenStream::new() + } +} + +/// Generates the code to initialize a `Request`. +/// +/// Used to construct an `http::Request`s body. +pub(crate) fn build_request_body(request: &Request) -> TokenStream { + if let Some(field) = request.newtype_raw_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote!(request.#field_name) + } else if request.has_body_fields() || request.newtype_body_field().is_some() { + let request_body_initializers = if let Some(field) = request.newtype_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { (request.#field_name) } + } else { + let initializers = request.request_body_init_fields(); + quote! { { #initializers } } + }; + + quote! { + { + let request_body = RequestBody #request_body_initializers; + ruma_api::exports::serde_json::to_vec(&request_body)? + } + } + } else { + quote!(Vec::new()) + } +} + +pub(crate) fn parse_request_body(request: &Request) -> TokenStream { + if let Some(field) = request.newtype_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { + #field_name: request_body.0, + } + } else if let Some(field) = request.newtype_raw_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { + #field_name: request.into_body(), + } + } else { + request.request_init_body_fields() + } +} + +pub(crate) fn req_res_meta_word( + attr_kind: &str, + field: &syn::Field, + newtype_body_field: &mut Option, + body_field_kind: T, + raw_field_kind: T, +) -> syn::Result { + if let Some(f) = &newtype_body_field { + let mut error = syn::Error::new_spanned(field, "There can only be one newtype body field"); + error.combine(syn::Error::new_spanned(f, "Previous newtype body field")); + return Err(error); + } + + *newtype_body_field = Some(field.clone()); + Ok(match attr_kind { + "body" => body_field_kind, + "raw_body" => raw_field_kind, + _ => unreachable!(), + }) +} + +pub(crate) fn req_res_name_value( + name: Ident, + value: Ident, + header: &mut Option, + field_kind: T, +) -> syn::Result { + if name != "header" { + return Err(syn::Error::new_spanned( + name, + "Invalid #[ruma_api] argument with value, expected `header`", + )); + } + + *header = Some(value); + Ok(field_kind) +}