Refactor large blocks of Api::to_tokens into separate functions

This commit is contained in:
Ragotzy.devin 2020-07-01 17:35:18 -04:00 committed by GitHub
parent ff2cbc282b
commit b08b1d1819
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 312 additions and 265 deletions

View File

@ -2,20 +2,21 @@
use std::convert::{TryFrom, TryInto as _}; use std::convert::{TryFrom, TryInto as _};
use proc_macro2::{Span, TokenStream}; use proc_macro2::TokenStream;
use quote::{quote, ToTokens}; use quote::{quote, ToTokens};
use syn::{ use syn::{
braced, braced,
parse::{Parse, ParseStream}, parse::{Parse, ParseStream},
Field, FieldValue, Ident, Token, Type, Field, FieldValue, Token, Type,
}; };
mod attribute; pub(crate) mod attribute;
mod metadata; pub(crate) mod metadata;
mod request; pub(crate) mod request;
mod response; pub(crate) mod response;
use self::{metadata::Metadata, request::Request, response::Response}; use self::{metadata::Metadata, request::Request, response::Response};
use crate::util;
/// Removes `serde` attributes from struct fields. /// Removes `serde` attributes from struct fields.
pub fn strip_serde_attrs(field: &Field) -> Field { pub fn strip_serde_attrs(field: &Field) -> Field {
@ -121,159 +122,12 @@ impl ToTokens for Api {
TokenStream::new() TokenStream::new()
}; };
let (request_path_string, parse_request_path) = if self.request.has_path_fields() { let (request_path_string, parse_request_path) =
let path_string = path.value(); util::request_path_string_and_parse(&self.request, &self.metadata);
assert!(path_string.starts_with('/'), "path needs to start with '/'"); let request_query_string = util::build_query_string(&self.request);
assert!(
path_string.chars().filter(|c| *c == ':').count()
== self.request.path_field_count(),
"number of declared path parameters needs to match amount of placeholders in path"
);
let format_call = { let extract_request_query = util::extract_request_query(&self.request);
let mut format_string = path_string.clone();
let mut format_args = Vec::new();
while let Some(start_of_segment) = format_string.find(':') {
// ':' should only ever appear at the start of a segment
assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/");
let end_of_segment = match format_string[start_of_segment..].find('/') {
Some(rel_pos) => start_of_segment + rel_pos,
None => format_string.len(),
};
let path_var = Ident::new(
&format_string[start_of_segment + 1..end_of_segment],
Span::call_site(),
);
format_args.push(quote! {
ruma_api::exports::percent_encoding::utf8_percent_encode(
&request.#path_var.to_string(),
ruma_api::exports::percent_encoding::NON_ALPHANUMERIC,
)
});
format_string.replace_range(start_of_segment..end_of_segment, "{}");
}
quote! {
format!(#format_string, #(#format_args),*)
}
};
let path_fields = path_string[1..]
.split('/')
.enumerate()
.filter(|(_, s)| s.starts_with(':'))
.map(|(i, segment)| {
let path_var = &segment[1..];
let path_var_ident = Ident::new(path_var, Span::call_site());
quote! {
#path_var_ident: {
use std::ops::Deref as _;
use ruma_api::error::RequestDeserializationError;
let segment = path_segments.get(#i).unwrap().as_bytes();
let decoded = match ruma_api::exports::percent_encoding::percent_decode(
segment
).decode_utf8() {
Ok(x) => x,
Err(err) => {
return Err(
RequestDeserializationError::new(err, request).into()
);
}
};
match std::convert::TryFrom::try_from(decoded.deref()) {
Ok(val) => val,
Err(err) => {
return Err(
RequestDeserializationError::new(err, request).into()
);
}
}
}
}
});
(format_call, quote! { #(#path_fields,)* })
} else {
(quote! { metadata.path.to_owned() }, TokenStream::new())
};
let request_query_string = if let Some(field) = self.request.query_map_field() {
let field_name = field.ident.as_ref().expect("expected field to have identifier");
let field_type = &field.ty;
quote!({
// This function exists so that the compiler will throw an
// error when the type of the field with the query_map
// attribute doesn't implement IntoIterator<Item = (String, String)>
//
// This is necessary because the ruma_serde::urlencoded::to_string
// call will result in a runtime error when the type cannot be
// encoded as a list key-value pairs (?key1=value1&key2=value2)
//
// By asserting that it implements the iterator trait, we can
// ensure that it won't fail.
fn assert_trait_impl<T>()
where
T: std::iter::IntoIterator<Item = (std::string::String, std::string::String)>,
{}
assert_trait_impl::<#field_type>();
let request_query = RequestQuery(request.#field_name);
format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?)
})
} else if self.request.has_query_fields() {
let request_query_init_fields = self.request.request_query_init_fields();
quote!({
let request_query = RequestQuery {
#request_query_init_fields
};
format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?)
})
} else {
quote! {
String::new()
}
};
let extract_request_query = if self.request.query_map_field().is_some() {
quote! {
let request_query = match ruma_api::exports::ruma_serde::urlencoded::from_str(
&request.uri().query().unwrap_or("")
) {
Ok(query) => query,
Err(err) => {
return Err(
ruma_api::error::RequestDeserializationError::new(err, request).into()
);
}
};
}
} else if self.request.has_query_fields() {
quote! {
let request_query: RequestQuery =
match ruma_api::exports::ruma_serde::urlencoded::from_str(
&request.uri().query().unwrap_or("")
) {
Ok(query) => query,
Err(err) => {
return Err(
ruma_api::error::RequestDeserializationError::new(err, request)
.into()
);
}
};
}
} else {
TokenStream::new()
};
let parse_request_query = if let Some(field) = self.request.query_map_field() { let parse_request_query = if let Some(field) = self.request.query_map_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let field_name = field.ident.as_ref().expect("expected field to have an identifier");
@ -327,42 +181,9 @@ impl ToTokens for Api {
TokenStream::new() TokenStream::new()
}; };
let request_body = if let Some(field) = self.request.newtype_raw_body_field() { let request_body = util::build_request_body(&self.request);
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote!(request.#field_name)
} else if self.request.has_body_fields() || self.request.newtype_body_field().is_some() {
let request_body_initializers = if let Some(field) = self.request.newtype_body_field() {
let field_name =
field.ident.as_ref().expect("expected field to have an identifier");
quote! { (request.#field_name) }
} else {
let initializers = self.request.request_body_init_fields();
quote! { { #initializers } }
};
quote! { let parse_request_body = util::parse_request_body(&self.request);
{
let request_body = RequestBody #request_body_initializers;
ruma_api::exports::serde_json::to_vec(&request_body)?
}
}
} else {
quote!(Vec::new())
};
let parse_request_body = if let Some(field) = self.request.newtype_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request_body.0,
}
} else if let Some(field) = self.request.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request.into_body(),
}
} else {
self.request.request_init_body_fields()
};
let extract_response_headers = if self.response.has_header_fields() { let extract_response_headers = if self.response.has_header_fields() {
quote! { quote! {

View File

@ -24,10 +24,6 @@ pub enum Meta {
impl Meta { impl Meta {
/// Check if the given attribute is a ruma_api attribute. If it is, parse it. /// Check if the given attribute is a ruma_api attribute. If it is, parse it.
///
/// # Panics
///
/// Panics if the given attribute is a ruma_api attribute, but fails to parse.
pub fn from_attribute(attr: &syn::Attribute) -> syn::Result<Option<Self>> { pub fn from_attribute(attr: &syn::Attribute) -> syn::Result<Option<Self>> {
if attr.path.is_ident("ruma_api") { if attr.path.is_ident("ruma_api") {
attr.parse_args().map(Some) attr.parse_args().map(Some)

View File

@ -6,9 +6,12 @@ use proc_macro2::TokenStream;
use quote::{quote, quote_spanned, ToTokens}; use quote::{quote, quote_spanned, ToTokens};
use syn::{spanned::Spanned, Field, Ident}; use syn::{spanned::Spanned, Field, Ident};
use crate::api::{ use crate::{
api::{
attribute::{Meta, MetaNameValue}, attribute::{Meta, MetaNameValue},
strip_serde_attrs, RawRequest, strip_serde_attrs, RawRequest,
},
util,
}; };
/// The result of processing the `request` section of the macro. /// The result of processing the `request` section of the macro.
@ -209,26 +212,13 @@ impl TryFrom<RawRequest> for Request {
field_kind = Some(match meta { field_kind = Some(match meta {
Meta::Word(ident) => { Meta::Word(ident) => {
match &ident.to_string()[..] { match &ident.to_string()[..] {
s @ "body" | s @ "raw_body" => { attr @ "body" | attr @ "raw_body" => util::req_res_meta_word(
if let Some(f) = &newtype_body_field { attr,
let mut error = syn::Error::new_spanned( &field,
field, &mut newtype_body_field,
"There can only be one newtype body field", RequestFieldKind::NewtypeBody,
); RequestFieldKind::NewtypeRawBody,
error.combine(syn::Error::new_spanned( )?,
f,
"Previous newtype body field",
));
return Err(error);
}
newtype_body_field = Some(field.clone());
match s {
"body" => RequestFieldKind::NewtypeBody,
"raw_body" => RequestFieldKind::NewtypeRawBody,
_ => unreachable!(),
}
}
"path" => RequestFieldKind::Path, "path" => RequestFieldKind::Path,
"query" => RequestFieldKind::Query, "query" => RequestFieldKind::Query,
"query_map" => { "query_map" => {
@ -255,17 +245,12 @@ impl TryFrom<RawRequest> for Request {
} }
} }
} }
Meta::NameValue(MetaNameValue { name, value }) => { Meta::NameValue(MetaNameValue { name, value }) => util::req_res_name_value(
if name != "header" {
return Err(syn::Error::new_spanned(
name, name,
"Invalid #[ruma_api] argument with value, expected `header`" value,
)); &mut header,
} RequestFieldKind::Header,
)?,
header = Some(value);
RequestFieldKind::Header
}
}); });
} }

View File

@ -6,9 +6,12 @@ use proc_macro2::TokenStream;
use quote::{quote, quote_spanned, ToTokens}; use quote::{quote, quote_spanned, ToTokens};
use syn::{spanned::Spanned, Field, Ident}; use syn::{spanned::Spanned, Field, Ident};
use crate::api::{ use crate::{
api::{
attribute::{Meta, MetaNameValue}, attribute::{Meta, MetaNameValue},
strip_serde_attrs, RawResponse, strip_serde_attrs, RawResponse,
},
util,
}; };
/// The result of processing the `response` section of the macro. /// The result of processing the `response` section of the macro.
@ -169,26 +172,13 @@ impl TryFrom<RawResponse> for Response {
field_kind = Some(match meta { field_kind = Some(match meta {
Meta::Word(ident) => match &ident.to_string()[..] { Meta::Word(ident) => match &ident.to_string()[..] {
s @ "body" | s @ "raw_body" => { s @ "body" | s @ "raw_body" => util::req_res_meta_word(
if let Some(f) = &newtype_body_field { s,
let mut error = syn::Error::new_spanned( &field,
field, &mut newtype_body_field,
"There can only be one newtype body field", ResponseFieldKind::NewtypeBody,
); ResponseFieldKind::NewtypeRawBody,
error.combine(syn::Error::new_spanned( )?,
f,
"Previous newtype body field",
));
return Err(error);
}
newtype_body_field = Some(field.clone());
match s {
"body" => ResponseFieldKind::NewtypeBody,
"raw_body" => ResponseFieldKind::NewtypeRawBody,
_ => unreachable!(),
}
}
_ => { _ => {
return Err(syn::Error::new_spanned( return Err(syn::Error::new_spanned(
ident, ident,
@ -196,17 +186,12 @@ impl TryFrom<RawResponse> for Response {
)); ));
} }
}, },
Meta::NameValue(MetaNameValue { name, value }) => { Meta::NameValue(MetaNameValue { name, value }) => util::req_res_name_value(
if name != "header" {
return Err(syn::Error::new_spanned(
name, name,
"Invalid #[ruma_api] argument with value, expected `header`", value,
)); &mut header,
} ResponseFieldKind::Header,
)?,
header = Some(value);
ResponseFieldKind::Header
}
}); });
} }
@ -340,7 +325,7 @@ impl ResponseField {
} }
} }
/// Whether or not the reponse field has a #[wrap_incoming] attribute. /// Whether or not the response field has a #[wrap_incoming] attribute.
fn has_wrap_incoming_attr(&self) -> bool { fn has_wrap_incoming_attr(&self) -> bool {
self.field().attrs.iter().any(|attr| { self.field().attrs.iter().any(|attr| {
attr.path.segments.len() == 1 && attr.path.segments[0].ident == "wrap_incoming" attr.path.segments.len() == 1 && attr.path.segments[0].ident == "wrap_incoming"

View File

@ -22,6 +22,7 @@ use syn::parse_macro_input;
use self::api::{Api, RawApi}; use self::api::{Api, RawApi};
mod api; mod api;
mod util;
#[proc_macro] #[proc_macro]
pub fn ruma_api(input: TokenStream) -> TokenStream { pub fn ruma_api(input: TokenStream) -> TokenStream {

259
ruma-api-macros/src/util.rs Normal file
View File

@ -0,0 +1,259 @@
//! Functions to aid the `Api::to_tokens` method.
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::Ident;
use crate::api::{metadata::Metadata, request::Request};
/// The first item in the tuple generates code for the request path from
/// the `Metadata` and `Request` structs. The second item in the returned tuple
/// is the code to generate a Request struct field created from any segments
/// of the path that start with ":".
///
/// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is
/// used for implementing `TryFrom<http::Request<Vec<u8>>>`, from path strings deserialized to ruma types.
pub(crate) fn request_path_string_and_parse(
request: &Request,
metadata: &Metadata,
) -> (TokenStream, TokenStream) {
if request.has_path_fields() {
let path_string = metadata.path.value();
assert!(path_string.starts_with('/'), "path needs to start with '/'");
assert!(
path_string.chars().filter(|c| *c == ':').count() == request.path_field_count(),
"number of declared path parameters needs to match amount of placeholders in path"
);
let format_call = {
let mut format_string = path_string.clone();
let mut format_args = Vec::new();
while let Some(start_of_segment) = format_string.find(':') {
// ':' should only ever appear at the start of a segment
assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/");
let end_of_segment = match format_string[start_of_segment..].find('/') {
Some(rel_pos) => start_of_segment + rel_pos,
None => format_string.len(),
};
let path_var = Ident::new(
&format_string[start_of_segment + 1..end_of_segment],
Span::call_site(),
);
format_args.push(quote! {
ruma_api::exports::percent_encoding::utf8_percent_encode(
&request.#path_var.to_string(),
ruma_api::exports::percent_encoding::NON_ALPHANUMERIC,
)
});
format_string.replace_range(start_of_segment..end_of_segment, "{}");
}
quote! {
format!(#format_string, #(#format_args),*)
}
};
let path_fields =
path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map(
|(i, segment)| {
let path_var = &segment[1..];
let path_var_ident = Ident::new(path_var, Span::call_site());
quote! {
#path_var_ident: {
use std::ops::Deref as _;
use ruma_api::error::RequestDeserializationError;
let segment = path_segments.get(#i).unwrap().as_bytes();
let decoded = match ruma_api::exports::percent_encoding::percent_decode(
segment
).decode_utf8() {
Ok(x) => x,
Err(err) => {
return Err(
RequestDeserializationError::new(err, request).into()
);
}
};
match std::convert::TryFrom::try_from(decoded.deref()) {
Ok(val) => val,
Err(err) => {
return Err(
RequestDeserializationError::new(err, request).into()
);
}
}
}
}
},
);
(format_call, quote! { #(#path_fields,)* })
} else {
(quote! { metadata.path.to_owned() }, TokenStream::new())
}
}
/// The function determines the type of query string that needs to be built
/// and then builds it using `ruma_serde::urlencoded::to_string`.
pub(crate) fn build_query_string(request: &Request) -> TokenStream {
if let Some(field) = request.query_map_field() {
let field_name = field.ident.as_ref().expect("expected field to have identifier");
let field_type = &field.ty;
quote!({
// This function exists so that the compiler will throw an
// error when the type of the field with the query_map
// attribute doesn't implement IntoIterator<Item = (String, String)>
//
// This is necessary because the ruma_serde::urlencoded::to_string
// call will result in a runtime error when the type cannot be
// encoded as a list key-value pairs (?key1=value1&key2=value2)
//
// By asserting that it implements the iterator trait, we can
// ensure that it won't fail.
fn assert_trait_impl<T>()
where
T: std::iter::IntoIterator<Item = (std::string::String, std::string::String)>,
{}
assert_trait_impl::<#field_type>();
let request_query = RequestQuery(request.#field_name);
format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?)
})
} else if request.has_query_fields() {
let request_query_init_fields = request.request_query_init_fields();
quote!({
let request_query = RequestQuery {
#request_query_init_fields
};
format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?)
})
} else {
quote! {
String::new()
}
}
}
/// Deserialize the query string.
pub(crate) fn extract_request_query(request: &Request) -> TokenStream {
if request.query_map_field().is_some() {
quote! {
let request_query = match ruma_api::exports::ruma_serde::urlencoded::from_str(
&request.uri().query().unwrap_or("")
) {
Ok(query) => query,
Err(err) => {
return Err(
ruma_api::error::RequestDeserializationError::new(err, request).into()
);
}
};
}
} else if request.has_query_fields() {
quote! {
let request_query: RequestQuery =
match ruma_api::exports::ruma_serde::urlencoded::from_str(
&request.uri().query().unwrap_or("")
) {
Ok(query) => query,
Err(err) => {
return Err(
ruma_api::error::RequestDeserializationError::new(err, request)
.into()
);
}
};
}
} else {
TokenStream::new()
}
}
/// Generates the code to initialize a `Request`.
///
/// Used to construct an `http::Request`s body.
pub(crate) fn build_request_body(request: &Request) -> TokenStream {
if let Some(field) = request.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote!(request.#field_name)
} else if request.has_body_fields() || request.newtype_body_field().is_some() {
let request_body_initializers = if let Some(field) = request.newtype_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! { (request.#field_name) }
} else {
let initializers = request.request_body_init_fields();
quote! { { #initializers } }
};
quote! {
{
let request_body = RequestBody #request_body_initializers;
ruma_api::exports::serde_json::to_vec(&request_body)?
}
}
} else {
quote!(Vec::new())
}
}
pub(crate) fn parse_request_body(request: &Request) -> TokenStream {
if let Some(field) = request.newtype_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request_body.0,
}
} else if let Some(field) = request.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request.into_body(),
}
} else {
request.request_init_body_fields()
}
}
pub(crate) fn req_res_meta_word<T>(
attr_kind: &str,
field: &syn::Field,
newtype_body_field: &mut Option<syn::Field>,
body_field_kind: T,
raw_field_kind: T,
) -> syn::Result<T> {
if let Some(f) = &newtype_body_field {
let mut error = syn::Error::new_spanned(field, "There can only be one newtype body field");
error.combine(syn::Error::new_spanned(f, "Previous newtype body field"));
return Err(error);
}
*newtype_body_field = Some(field.clone());
Ok(match attr_kind {
"body" => body_field_kind,
"raw_body" => raw_field_kind,
_ => unreachable!(),
})
}
pub(crate) fn req_res_name_value<T>(
name: Ident,
value: Ident,
header: &mut Option<Ident>,
field_kind: T,
) -> syn::Result<T> {
if name != "header" {
return Err(syn::Error::new_spanned(
name,
"Invalid #[ruma_api] argument with value, expected `header`",
));
}
*header = Some(value);
Ok(field_kind)
}