api-macros: Turn request codegen helper functions into methods
This commit is contained in:
parent
06a2a27a99
commit
455eb31c74
@ -17,7 +17,7 @@ edition = "2018"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
proc-macro2 = "1.0.24"
|
proc-macro2 = "1.0.24"
|
||||||
quote = "1.0.8"
|
quote = "1.0.8"
|
||||||
syn = { version = "1.0.57", features = ["full", "extra-traits"] }
|
syn = { version = "1.0.57", features = ["full", "extra-traits", "visit"] }
|
||||||
proc-macro-crate = "1.0.0"
|
proc-macro-crate = "1.0.0"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
|
@ -324,10 +324,10 @@ impl Request {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let (request_path_string, parse_request_path) =
|
let (request_path_string, parse_request_path) =
|
||||||
path_string_and_parse(self, metadata, &ruma_api);
|
self.path_string_and_parse(metadata, &ruma_api);
|
||||||
|
|
||||||
let request_query_string = build_query_string(self, &ruma_api);
|
let request_query_string = self.build_query_string(&ruma_api);
|
||||||
let extract_request_query = extract_request_query(self, &ruma_api);
|
let extract_request_query = self.extract_request_query(&ruma_api);
|
||||||
|
|
||||||
let parse_request_query = if let Some(field) = self.query_map_field() {
|
let parse_request_query = if let Some(field) = self.query_map_field() {
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
@ -403,8 +403,8 @@ impl Request {
|
|||||||
TokenStream::new()
|
TokenStream::new()
|
||||||
};
|
};
|
||||||
|
|
||||||
let request_body = build_request_body(self, &ruma_api);
|
let request_body = self.build_request_body(&ruma_api);
|
||||||
let parse_request_body = parse_request_body(self);
|
let parse_request_body = self.parse_request_body();
|
||||||
|
|
||||||
let request_generics = self.combine_lifetimes();
|
let request_generics = self.combine_lifetimes();
|
||||||
|
|
||||||
@ -552,7 +552,10 @@ impl Request {
|
|||||||
#request_path_string,
|
#request_path_string,
|
||||||
#request_query_string,
|
#request_query_string,
|
||||||
))
|
))
|
||||||
.header(#ruma_api::exports::http::header::CONTENT_TYPE, "application/json");
|
.header(
|
||||||
|
#ruma_api::exports::http::header::CONTENT_TYPE,
|
||||||
|
"application/json",
|
||||||
|
);
|
||||||
|
|
||||||
let mut req_headers = req_builder
|
let mut req_headers = req_builder
|
||||||
.headers_mut()
|
.headers_mut()
|
||||||
@ -583,6 +586,7 @@ impl Request {
|
|||||||
received: request.method().clone(),
|
received: request.method().clone(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#extract_request_path
|
#extract_request_path
|
||||||
#extract_request_query
|
#extract_request_query
|
||||||
#extract_request_headers
|
#extract_request_headers
|
||||||
@ -598,6 +602,216 @@ impl Request {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Deserialize the query string.
|
||||||
|
fn extract_request_query(&self, ruma_api: &TokenStream) -> TokenStream {
|
||||||
|
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
||||||
|
|
||||||
|
if self.query_map_field().is_some() {
|
||||||
|
quote! {
|
||||||
|
let request_query = #ruma_api::try_deserialize!(
|
||||||
|
request,
|
||||||
|
#ruma_serde::urlencoded::from_str(
|
||||||
|
&request.uri().query().unwrap_or("")
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else if self.has_query_fields() {
|
||||||
|
quote! {
|
||||||
|
let request_query: <RequestQuery as #ruma_serde::Outgoing>::Incoming =
|
||||||
|
#ruma_api::try_deserialize!(
|
||||||
|
request,
|
||||||
|
#ruma_serde::urlencoded::from_str(
|
||||||
|
&request.uri().query().unwrap_or("")
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
TokenStream::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates the code to initialize a `Request`.
|
||||||
|
///
|
||||||
|
/// Used to construct an `http::Request`s body.
|
||||||
|
fn build_request_body(&self, ruma_api: &TokenStream) -> TokenStream {
|
||||||
|
let serde_json = quote! { #ruma_api::exports::serde_json };
|
||||||
|
|
||||||
|
if let Some(field) = self.newtype_raw_body_field() {
|
||||||
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
quote!(self.#field_name)
|
||||||
|
} else if self.has_body_fields() || self.newtype_body_field().is_some() {
|
||||||
|
let request_body_initializers = if let Some(field) = self.newtype_body_field() {
|
||||||
|
let field_name =
|
||||||
|
field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
quote! { (self.#field_name) }
|
||||||
|
} else {
|
||||||
|
let initializers = self.request_body_init_fields();
|
||||||
|
quote! { { #initializers } }
|
||||||
|
};
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
{
|
||||||
|
let request_body = RequestBody #request_body_initializers;
|
||||||
|
#serde_json::to_vec(&request_body)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote!(Vec::new())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_request_body(&self) -> TokenStream {
|
||||||
|
if let Some(field) = self.newtype_body_field() {
|
||||||
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
quote! {
|
||||||
|
#field_name: request_body.0,
|
||||||
|
}
|
||||||
|
} else if let Some(field) = self.newtype_raw_body_field() {
|
||||||
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||||
|
quote! {
|
||||||
|
#field_name: request.into_body(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.request_init_body_fields()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The function determines the type of query string that needs to be built
|
||||||
|
/// and then builds it using `ruma_serde::urlencoded::to_string`.
|
||||||
|
fn build_query_string(&self, ruma_api: &TokenStream) -> TokenStream {
|
||||||
|
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
||||||
|
|
||||||
|
if let Some(field) = self.query_map_field() {
|
||||||
|
let field_name = field.ident.as_ref().expect("expected field to have identifier");
|
||||||
|
|
||||||
|
quote!({
|
||||||
|
// This function exists so that the compiler will throw an
|
||||||
|
// error when the type of the field with the query_map
|
||||||
|
// attribute doesn't implement IntoIterator<Item = (String, String)>
|
||||||
|
//
|
||||||
|
// This is necessary because the ruma_serde::urlencoded::to_string
|
||||||
|
// call will result in a runtime error when the type cannot be
|
||||||
|
// encoded as a list key-value pairs (?key1=value1&key2=value2)
|
||||||
|
//
|
||||||
|
// By asserting that it implements the iterator trait, we can
|
||||||
|
// ensure that it won't fail.
|
||||||
|
fn assert_trait_impl<T>(_: &T)
|
||||||
|
where
|
||||||
|
T: ::std::iter::IntoIterator<
|
||||||
|
Item = (::std::string::String, ::std::string::String)
|
||||||
|
>,
|
||||||
|
{}
|
||||||
|
|
||||||
|
let request_query = RequestQuery(self.#field_name);
|
||||||
|
assert_trait_impl(&request_query.0);
|
||||||
|
|
||||||
|
format_args!(
|
||||||
|
"?{}",
|
||||||
|
#ruma_serde::urlencoded::to_string(request_query)?
|
||||||
|
)
|
||||||
|
})
|
||||||
|
} else if self.has_query_fields() {
|
||||||
|
let request_query_init_fields = self.request_query_init_fields();
|
||||||
|
|
||||||
|
quote!({
|
||||||
|
let request_query = RequestQuery {
|
||||||
|
#request_query_init_fields
|
||||||
|
};
|
||||||
|
|
||||||
|
format_args!(
|
||||||
|
"?{}",
|
||||||
|
#ruma_serde::urlencoded::to_string(request_query)?
|
||||||
|
)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
quote! { "" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The first item in the tuple generates code for the request path from
|
||||||
|
/// the `Metadata` and `Request` structs. The second item in the returned tuple
|
||||||
|
/// is the code to generate a Request struct field created from any segments
|
||||||
|
/// of the path that start with ":".
|
||||||
|
///
|
||||||
|
/// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is
|
||||||
|
/// used for implementing `TryFrom<http::Request<Vec<u8>>>`, from path strings deserialized to Ruma
|
||||||
|
/// types.
|
||||||
|
pub(crate) fn path_string_and_parse(
|
||||||
|
&self,
|
||||||
|
metadata: &Metadata,
|
||||||
|
ruma_api: &TokenStream,
|
||||||
|
) -> (TokenStream, TokenStream) {
|
||||||
|
let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
|
||||||
|
|
||||||
|
if self.has_path_fields() {
|
||||||
|
let path_string = metadata.path.value();
|
||||||
|
|
||||||
|
assert!(path_string.starts_with('/'), "path needs to start with '/'");
|
||||||
|
assert!(
|
||||||
|
path_string.chars().filter(|c| *c == ':').count() == self.path_field_count(),
|
||||||
|
"number of declared path parameters needs to match amount of placeholders in path"
|
||||||
|
);
|
||||||
|
|
||||||
|
let format_call = {
|
||||||
|
let mut format_string = path_string.clone();
|
||||||
|
let mut format_args = Vec::new();
|
||||||
|
|
||||||
|
while let Some(start_of_segment) = format_string.find(':') {
|
||||||
|
// ':' should only ever appear at the start of a segment
|
||||||
|
assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/");
|
||||||
|
|
||||||
|
let end_of_segment = match format_string[start_of_segment..].find('/') {
|
||||||
|
Some(rel_pos) => start_of_segment + rel_pos,
|
||||||
|
None => format_string.len(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let path_var = Ident::new(
|
||||||
|
&format_string[start_of_segment + 1..end_of_segment],
|
||||||
|
Span::call_site(),
|
||||||
|
);
|
||||||
|
format_args.push(quote! {
|
||||||
|
#percent_encoding::utf8_percent_encode(
|
||||||
|
&self.#path_var.to_string(),
|
||||||
|
#percent_encoding::NON_ALPHANUMERIC,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
format_string.replace_range(start_of_segment..end_of_segment, "{}");
|
||||||
|
}
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
format_args!(#format_string, #(#format_args),*)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let path_fields =
|
||||||
|
path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map(
|
||||||
|
|(i, segment)| {
|
||||||
|
let path_var = &segment[1..];
|
||||||
|
let path_var_ident = Ident::new(path_var, Span::call_site());
|
||||||
|
quote! {
|
||||||
|
#path_var_ident: {
|
||||||
|
let segment = path_segments[#i].as_bytes();
|
||||||
|
let decoded = #ruma_api::try_deserialize!(
|
||||||
|
request,
|
||||||
|
#percent_encoding::percent_decode(segment)
|
||||||
|
.decode_utf8(),
|
||||||
|
);
|
||||||
|
|
||||||
|
#ruma_api::try_deserialize!(
|
||||||
|
request,
|
||||||
|
::std::convert::TryFrom::try_from(&*decoded),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
(format_call, quote! { #(#path_fields,)* })
|
||||||
|
} else {
|
||||||
|
(quote! { metadata.path.to_owned() }, TokenStream::new())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The types of fields that a request can have.
|
/// The types of fields that a request can have.
|
||||||
@ -750,210 +964,3 @@ pub(crate) enum RequestFieldKind {
|
|||||||
/// See the similarly named variant of `RequestField`.
|
/// See the similarly named variant of `RequestField`.
|
||||||
QueryMap,
|
QueryMap,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The first item in the tuple generates code for the request path from
|
|
||||||
/// the `Metadata` and `Request` structs. The second item in the returned tuple
|
|
||||||
/// is the code to generate a Request struct field created from any segments
|
|
||||||
/// of the path that start with ":".
|
|
||||||
///
|
|
||||||
/// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is
|
|
||||||
/// used for implementing `TryFrom<http::Request<Vec<u8>>>`, from path strings deserialized to Ruma
|
|
||||||
/// types.
|
|
||||||
pub(crate) fn path_string_and_parse(
|
|
||||||
request: &Request,
|
|
||||||
metadata: &Metadata,
|
|
||||||
ruma_api: &TokenStream,
|
|
||||||
) -> (TokenStream, TokenStream) {
|
|
||||||
let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
|
|
||||||
|
|
||||||
if request.has_path_fields() {
|
|
||||||
let path_string = metadata.path.value();
|
|
||||||
|
|
||||||
assert!(path_string.starts_with('/'), "path needs to start with '/'");
|
|
||||||
assert!(
|
|
||||||
path_string.chars().filter(|c| *c == ':').count() == request.path_field_count(),
|
|
||||||
"number of declared path parameters needs to match amount of placeholders in path"
|
|
||||||
);
|
|
||||||
|
|
||||||
let format_call = {
|
|
||||||
let mut format_string = path_string.clone();
|
|
||||||
let mut format_args = Vec::new();
|
|
||||||
|
|
||||||
while let Some(start_of_segment) = format_string.find(':') {
|
|
||||||
// ':' should only ever appear at the start of a segment
|
|
||||||
assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/");
|
|
||||||
|
|
||||||
let end_of_segment = match format_string[start_of_segment..].find('/') {
|
|
||||||
Some(rel_pos) => start_of_segment + rel_pos,
|
|
||||||
None => format_string.len(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let path_var = Ident::new(
|
|
||||||
&format_string[start_of_segment + 1..end_of_segment],
|
|
||||||
Span::call_site(),
|
|
||||||
);
|
|
||||||
format_args.push(quote! {
|
|
||||||
#percent_encoding::utf8_percent_encode(
|
|
||||||
&self.#path_var.to_string(),
|
|
||||||
#percent_encoding::NON_ALPHANUMERIC,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
format_string.replace_range(start_of_segment..end_of_segment, "{}");
|
|
||||||
}
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
format_args!(#format_string, #(#format_args),*)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let path_fields =
|
|
||||||
path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map(
|
|
||||||
|(i, segment)| {
|
|
||||||
let path_var = &segment[1..];
|
|
||||||
let path_var_ident = Ident::new(path_var, Span::call_site());
|
|
||||||
quote! {
|
|
||||||
#path_var_ident: {
|
|
||||||
let segment = path_segments[#i].as_bytes();
|
|
||||||
let decoded = #ruma_api::try_deserialize!(
|
|
||||||
request,
|
|
||||||
#percent_encoding::percent_decode(segment)
|
|
||||||
.decode_utf8(),
|
|
||||||
);
|
|
||||||
|
|
||||||
#ruma_api::try_deserialize!(
|
|
||||||
request,
|
|
||||||
::std::convert::TryFrom::try_from(&*decoded),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
(format_call, quote! { #(#path_fields,)* })
|
|
||||||
} else {
|
|
||||||
(quote! { metadata.path.to_owned() }, TokenStream::new())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The function determines the type of query string that needs to be built
|
|
||||||
/// and then builds it using `ruma_serde::urlencoded::to_string`.
|
|
||||||
fn build_query_string(request: &Request, ruma_api: &TokenStream) -> TokenStream {
|
|
||||||
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
|
||||||
|
|
||||||
if let Some(field) = request.query_map_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have identifier");
|
|
||||||
|
|
||||||
quote!({
|
|
||||||
// This function exists so that the compiler will throw an
|
|
||||||
// error when the type of the field with the query_map
|
|
||||||
// attribute doesn't implement IntoIterator<Item = (String, String)>
|
|
||||||
//
|
|
||||||
// This is necessary because the ruma_serde::urlencoded::to_string
|
|
||||||
// call will result in a runtime error when the type cannot be
|
|
||||||
// encoded as a list key-value pairs (?key1=value1&key2=value2)
|
|
||||||
//
|
|
||||||
// By asserting that it implements the iterator trait, we can
|
|
||||||
// ensure that it won't fail.
|
|
||||||
fn assert_trait_impl<T>(_: &T)
|
|
||||||
where
|
|
||||||
T: ::std::iter::IntoIterator<Item = (::std::string::String, ::std::string::String)>,
|
|
||||||
{}
|
|
||||||
|
|
||||||
let request_query = RequestQuery(self.#field_name);
|
|
||||||
assert_trait_impl(&request_query.0);
|
|
||||||
|
|
||||||
format_args!(
|
|
||||||
"?{}",
|
|
||||||
#ruma_serde::urlencoded::to_string(request_query)?
|
|
||||||
)
|
|
||||||
})
|
|
||||||
} else if request.has_query_fields() {
|
|
||||||
let request_query_init_fields = request.request_query_init_fields();
|
|
||||||
|
|
||||||
quote!({
|
|
||||||
let request_query = RequestQuery {
|
|
||||||
#request_query_init_fields
|
|
||||||
};
|
|
||||||
|
|
||||||
format_args!(
|
|
||||||
"?{}",
|
|
||||||
#ruma_serde::urlencoded::to_string(request_query)?
|
|
||||||
)
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
quote! { "" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Deserialize the query string.
|
|
||||||
fn extract_request_query(request: &Request, ruma_api: &TokenStream) -> TokenStream {
|
|
||||||
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
|
||||||
|
|
||||||
if request.query_map_field().is_some() {
|
|
||||||
quote! {
|
|
||||||
let request_query = #ruma_api::try_deserialize!(
|
|
||||||
request,
|
|
||||||
#ruma_serde::urlencoded::from_str(
|
|
||||||
&request.uri().query().unwrap_or("")
|
|
||||||
),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} else if request.has_query_fields() {
|
|
||||||
quote! {
|
|
||||||
let request_query: <RequestQuery as #ruma_serde::Outgoing>::Incoming =
|
|
||||||
#ruma_api::try_deserialize!(
|
|
||||||
request,
|
|
||||||
#ruma_serde::urlencoded::from_str(
|
|
||||||
&request.uri().query().unwrap_or("")
|
|
||||||
),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
TokenStream::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates the code to initialize a `Request`.
|
|
||||||
///
|
|
||||||
/// Used to construct an `http::Request`s body.
|
|
||||||
fn build_request_body(request: &Request, ruma_api: &TokenStream) -> TokenStream {
|
|
||||||
let serde_json = quote! { #ruma_api::exports::serde_json };
|
|
||||||
|
|
||||||
if let Some(field) = request.newtype_raw_body_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
quote!(self.#field_name)
|
|
||||||
} else if request.has_body_fields() || request.newtype_body_field().is_some() {
|
|
||||||
let request_body_initializers = if let Some(field) = request.newtype_body_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
quote! { (self.#field_name) }
|
|
||||||
} else {
|
|
||||||
let initializers = request.request_body_init_fields();
|
|
||||||
quote! { { #initializers } }
|
|
||||||
};
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
{
|
|
||||||
let request_body = RequestBody #request_body_initializers;
|
|
||||||
#serde_json::to_vec(&request_body)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
quote!(Vec::new())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_request_body(request: &Request) -> TokenStream {
|
|
||||||
if let Some(field) = request.newtype_body_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
quote! {
|
|
||||||
#field_name: request_body.0,
|
|
||||||
}
|
|
||||||
} else if let Some(field) = request.newtype_raw_body_field() {
|
|
||||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
||||||
quote! {
|
|
||||||
#field_name: request.into_body(),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
request.request_init_body_fields()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user