api-macros: Refactor request code generation

This commit is contained in:
Jonas Platte 2021-04-10 16:47:34 +02:00
parent a20f03894e
commit 23ba0bc164
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
3 changed files with 188 additions and 163 deletions

View File

@ -32,47 +32,47 @@ pub(crate) struct Request {
impl Request {
/// Whether or not this request has any data in the HTTP body.
pub fn has_body_fields(&self) -> bool {
pub(super) fn has_body_fields(&self) -> bool {
self.fields.iter().any(|field| field.is_body())
}
/// Whether or not this request has any data in HTTP headers.
pub fn has_header_fields(&self) -> bool {
fn has_header_fields(&self) -> bool {
self.fields.iter().any(|field| field.is_header())
}
/// Whether or not this request has any data in the URL path.
pub fn has_path_fields(&self) -> bool {
fn has_path_fields(&self) -> bool {
self.fields.iter().any(|field| field.is_path())
}
/// Whether or not this request has any data in the query string.
pub fn has_query_fields(&self) -> bool {
fn has_query_fields(&self) -> bool {
self.fields.iter().any(|field| field.is_query())
}
/// Produces an iterator over all the body fields.
pub fn body_fields(&self) -> impl Iterator<Item = &Field> {
pub(super) fn body_fields(&self) -> impl Iterator<Item = &Field> {
self.fields.iter().filter_map(|field| field.as_body_field())
}
/// The number of unique lifetime annotations for `body` fields.
pub fn body_lifetime_count(&self) -> usize {
fn body_lifetime_count(&self) -> usize {
self.lifetimes.body.len()
}
/// Whether any `body` field has a lifetime annotation.
pub fn has_body_lifetimes(&self) -> bool {
fn has_body_lifetimes(&self) -> bool {
!self.lifetimes.body.is_empty()
}
/// Whether any `query` field has a lifetime annotation.
pub fn has_query_lifetimes(&self) -> bool {
fn has_query_lifetimes(&self) -> bool {
!self.lifetimes.query.is_empty()
}
/// Whether any field has a lifetime.
pub fn contains_lifetimes(&self) -> bool {
fn contains_lifetimes(&self) -> bool {
!(self.lifetimes.body.is_empty()
&& self.lifetimes.path.is_empty()
&& self.lifetimes.query.is_empty()
@ -80,7 +80,7 @@ impl Request {
}
/// The combination of every fields unique lifetime annotation.
pub fn combine_lifetimes(&self) -> TokenStream {
fn combine_lifetimes(&self) -> TokenStream {
util::unique_lifetimes_to_tokens(
[
&self.lifetimes.body,
@ -94,22 +94,22 @@ impl Request {
}
/// The lifetimes on fields with the `query` attribute.
pub fn query_lifetimes(&self) -> TokenStream {
fn query_lifetimes(&self) -> TokenStream {
util::unique_lifetimes_to_tokens(&self.lifetimes.query)
}
/// The lifetimes on fields with the `body` attribute.
pub fn body_lifetimes(&self) -> TokenStream {
fn body_lifetimes(&self) -> TokenStream {
util::unique_lifetimes_to_tokens(&self.lifetimes.body)
}
/// Produces an iterator over all the header fields.
pub fn header_fields(&self) -> impl Iterator<Item = &RequestField> {
fn header_fields(&self) -> impl Iterator<Item = &RequestField> {
self.fields.iter().filter(|field| field.is_header())
}
/// Gets the number of path fields.
pub fn path_field_count(&self) -> usize {
fn path_field_count(&self) -> usize {
self.fields.iter().filter(|field| field.is_path()).count()
}
@ -119,12 +119,12 @@ impl Request {
}
/// Returns the body field.
pub fn newtype_raw_body_field(&self) -> Option<&Field> {
fn newtype_raw_body_field(&self) -> Option<&Field> {
self.fields.iter().find_map(RequestField::as_newtype_raw_body_field)
}
/// Returns the query map field.
pub fn query_map_field(&self) -> Option<&Field> {
fn query_map_field(&self) -> Option<&Field> {
self.fields.iter().find_map(RequestField::as_query_map_field)
}
@ -135,8 +135,8 @@ impl Request {
request_field_kind: RequestFieldKind,
src: TokenStream,
) -> TokenStream {
let process_field = |f: &RequestField| {
f.field_of_kind(request_field_kind).map(|field| {
let fields =
self.fields.iter().filter_map(|f| f.field_of_kind(request_field_kind)).map(|field| {
let field_name =
field.ident.as_ref().expect("expected field to have an identifier");
let span = field.span();
@ -147,25 +147,43 @@ impl Request {
#( #cfg_attrs )*
#field_name: #src.#field_name
}
})
};
let mut fields = vec![];
let mut new_type_body = None;
for field in &self.fields {
if let RequestField::NewtypeRawBody(_) = field {
new_type_body = process_field(field);
} else {
fields.extend(process_field(field));
}
}
// Move field that consumes `request` to the end of the init list.
fields.extend(new_type_body);
});
quote! { #(#fields,)* }
}
/// Produces code for a struct initializer for the given field kind to be accessed through the
/// given variable name.
fn vars(
&self,
request_field_kind: RequestFieldKind,
src: TokenStream,
) -> (TokenStream, TokenStream) {
let (decls, names): (TokenStream, Vec<_>) = self
.fields
.iter()
.filter_map(|f| f.field_of_kind(request_field_kind))
.map(|field| {
let field_name =
field.ident.as_ref().expect("expected field to have an identifier");
let span = field.span();
let cfg_attrs =
field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>();
let decl = quote_spanned! {span=>
#( #cfg_attrs )*
let #field_name = #src.#field_name;
};
(decl, field_name)
})
.unzip();
let names = quote! { #(#names,)* };
(decls, names)
}
pub(super) fn expand(
&self,
metadata: &Metadata,
@ -197,16 +215,7 @@ impl Request {
let incoming_request_type =
if self.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) };
let extract_request_path = if self.has_path_fields() {
quote! {
let path_segments: ::std::vec::Vec<&::std::primitive::str> =
request.uri().path()[1..].split('/').collect();
}
} else {
TokenStream::new()
};
let (request_path_string, parse_request_path) = if self.has_path_fields() {
let (request_path_string, parse_request_path, path_vars) = if self.has_path_fields() {
let path_string = metadata.path.value();
assert!(path_string.starts_with('/'), "path needs to start with '/'");
@ -246,26 +255,38 @@ impl Request {
}
};
let path_fields =
path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map(
|(i, segment)| {
let path_var = &segment[1..];
let path_var_ident = Ident::new(path_var, Span::call_site());
quote! {
#path_var_ident: {
let segment = path_segments[#i].as_bytes();
let decoded =
#percent_encoding::percent_decode(segment).decode_utf8()?;
let path_var_decls = path_string[1..]
.split('/')
.enumerate()
.filter(|(_, seg)| seg.starts_with(':'))
.map(|(i, seg)| {
let path_var = Ident::new(&seg[1..], Span::call_site());
quote! {
let #path_var = {
let segment = path_segments[#i].as_bytes();
let decoded =
#percent_encoding::percent_decode(segment).decode_utf8()?;
::std::convert::TryFrom::try_from(&*decoded)?
}
}
},
);
::std::convert::TryFrom::try_from(&*decoded)?
};
}
});
(format_call, quote! { #(#path_fields,)* })
let parse_request_path = quote! {
let path_segments: ::std::vec::Vec<&::std::primitive::str> =
request.uri().path()[1..].split('/').collect();
#(#path_var_decls)*
};
let path_vars = path_string[1..]
.split('/')
.filter(|seg| seg.starts_with(':'))
.map(|seg| Ident::new(&seg[1..], Span::call_site()));
(format_call, parse_request_path, quote! { #(#path_vars,)* })
} else {
(quote! { metadata.path.to_owned() }, TokenStream::new())
(quote! { metadata.path.to_owned() }, TokenStream::new(), TokenStream::new())
};
let request_query_string = if let Some(field) = self.query_map_field() {
@ -315,31 +336,30 @@ impl Request {
quote! { "" }
};
let extract_request_query = if self.query_map_field().is_some() {
quote! {
let request_query = #ruma_serde::urlencoded::from_str(
let (parse_query, query_vars) = if let Some(field) = self.query_map_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
let parse = quote! {
let #field_name = #ruma_serde::urlencoded::from_str(
&request.uri().query().unwrap_or(""),
)?;
}
};
(parse, quote! { #field_name, })
} else if self.has_query_fields() {
quote! {
let (decls, names) = self.vars(RequestFieldKind::Query, quote!(request_query));
let parse = quote! {
let request_query: <RequestQuery as #ruma_serde::Outgoing>::Incoming =
#ruma_serde::urlencoded::from_str(
&request.uri().query().unwrap_or("")
)?;
}
} else {
TokenStream::new()
};
let parse_request_query = if let Some(field) = self.query_map_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
#decls
};
quote! {
#field_name: request_query,
}
(parse, names)
} else {
self.struct_init_fields(RequestFieldKind::Query, quote!(request_query))
(TokenStream::new(), TokenStream::new())
};
let mut header_kvs: TokenStream = self
@ -395,16 +415,62 @@ impl Request {
}
}
let extract_request_headers = if self.has_header_fields() {
quote! {
let (parse_headers, header_vars) = if self.has_header_fields() {
let (decls, names): (TokenStream, Vec<_>) = self
.header_fields()
.map(|request_field| {
let (field, header_name) = match request_field {
RequestField::Header(field, header_name) => (field, header_name),
_ => panic!("expected request field to be header variant"),
};
let field_name = &field.ident;
let header_name_string = header_name.to_string();
let (some_case, none_case) = match &field.ty {
syn::Type::Path(syn::TypePath {
path: syn::Path { segments, .. }, ..
}) if segments.last().unwrap().ident == "Option" => {
(quote! { Some(str_value.to_owned()) }, quote! { None })
}
_ => (
quote! { str_value.to_owned() },
quote! {
return Err(
#ruma_api::error::HeaderDeserializationError::MissingHeader(
#header_name_string.into()
).into(),
)
},
),
};
let decl = quote! {
let #field_name = match headers.get(#http::header::#header_name) {
Some(header_value) => {
let str_value = header_value.to_str()?;
#some_case
}
None => #none_case,
};
};
(decl, field_name)
})
.unzip();
let parse = quote! {
let headers = request.headers();
}
#decls
};
(parse, quote! { #(#names,)* })
} else {
TokenStream::new()
(TokenStream::new(), TokenStream::new())
};
let extract_request_body = if self.has_body_fields() || self.newtype_body_field().is_some()
{
let extract_body = if self.has_body_fields() || self.newtype_body_field().is_some() {
let body_lifetimes = if self.has_body_lifetimes() {
// duplicate the anonymous lifetime as many times as needed
let lifetimes = std::iter::repeat(quote! { '_ }).take(self.body_lifetime_count());
@ -412,6 +478,7 @@ impl Request {
} else {
TokenStream::new()
};
quote! {
let request_body: <
RequestBody #body_lifetimes
@ -432,52 +499,6 @@ impl Request {
TokenStream::new()
};
let parse_request_headers = if self.has_header_fields() {
let fields = self.header_fields().map(|request_field| {
let (field, header_name) = match request_field {
RequestField::Header(field, header_name) => (field, header_name),
_ => panic!("expected request field to be header variant"),
};
let field_name = &field.ident;
let header_name_string = header_name.to_string();
let (some_case, none_case) = match &field.ty {
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
if segments.last().unwrap().ident == "Option" =>
{
(quote! { Some(str_value.to_owned()) }, quote! { None })
}
_ => (
quote! { str_value.to_owned() },
quote! {
return Err(
#ruma_api::error::HeaderDeserializationError::MissingHeader(
#header_name_string.into()
).into(),
)
},
),
};
quote! {
#field_name: match headers.get(#http::header::#header_name) {
Some(header_value) => {
let str_value = header_value.to_str()?;
#some_case
}
None => #none_case,
}
}
});
quote! {
#(#fields,)*
}
} else {
TokenStream::new()
};
let request_body = if let Some(field) = self.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! { self.#field_name }
@ -501,18 +522,22 @@ impl Request {
quote! { Vec::new() }
};
let parse_request_body = if let Some(field) = self.newtype_body_field() {
let (parse_body, body_vars) = if let Some(field) = self.newtype_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request_body.0,
}
let parse = quote! {
let #field_name = request_body.0;
};
(parse, quote! { #field_name, })
} else if let Some(field) = self.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request.into_body(),
}
let parse = quote! {
let #field_name = request.into_body();
};
(parse, quote! { #field_name, })
} else {
self.struct_init_fields(RequestFieldKind::Body, quote!(request_body))
self.vars(RequestFieldKind::Body, quote!(request_body))
};
let request_generics = self.combine_lifetimes();
@ -699,16 +724,18 @@ impl Request {
});
}
#extract_request_path
#extract_request_query
#extract_request_headers
#extract_request_body
#parse_request_path
#parse_query
#parse_headers
#extract_body
#parse_body
Ok(Self {
#parse_request_path
#parse_request_query
#parse_request_headers
#parse_request_body
#path_vars
#query_vars
#header_vars
#body_vars
})
}
}

View File

@ -8,7 +8,7 @@ ruma_api! {
description: "Does something.",
method: POST,
name: "my_endpoint",
path: "/_matrix/foo/:bar/:baz",
path: "/_matrix/foo/:bar/:user",
rate_limited: false,
authentication: None,
}
@ -24,7 +24,7 @@ ruma_api! {
#[ruma_api(path)]
pub bar: String,
#[ruma_api(path)]
pub baz: UserId,
pub user: UserId,
}
response: {
@ -44,7 +44,7 @@ fn request_serde() -> Result<(), Box<dyn std::error::Error + 'static>> {
q1: "query_param_special_chars %/&@!".to_owned(),
q2: 55,
bar: "barVal".to_owned(),
baz: user_id!("@bazme:ruma.io"),
user: user_id!("@bazme:ruma.io"),
};
let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?;
@ -55,7 +55,7 @@ fn request_serde() -> Result<(), Box<dyn std::error::Error + 'static>> {
assert_eq!(req.q1, req2.q1);
assert_eq!(req.q2, req2.q2);
assert_eq!(req.bar, req2.bar);
assert_eq!(req.baz, req2.baz);
assert_eq!(req.user, req2.user);
Ok(())
}
@ -68,12 +68,12 @@ fn request_with_user_id_serde() -> Result<(), Box<dyn std::error::Error + 'stati
q1: "query_param_special_chars %/&@!".to_owned(),
q2: 55,
bar: "barVal".to_owned(),
baz: user_id!("@bazme:ruma.io"),
user: user_id!("@bazme:ruma.io"),
};
let user_id = user_id!("@_virtual_:ruma.io");
let http_req =
req.clone().try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)?;
req.try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)?;
let query = http_req.uri().query().unwrap();
@ -93,7 +93,7 @@ mod without_query {
description: "Does something without query.",
method: POST,
name: "my_endpoint",
path: "/_matrix/foo/:bar/:baz",
path: "/_matrix/foo/:bar/:user",
rate_limited: false,
authentication: None,
}
@ -105,7 +105,7 @@ mod without_query {
#[ruma_api(path)]
pub bar: String,
#[ruma_api(path)]
pub baz: UserId,
pub user: UserId,
}
response: {
@ -124,15 +124,12 @@ mod without_query {
hello: "hi".to_owned(),
world: "test".to_owned(),
bar: "barVal".to_owned(),
baz: user_id!("@bazme:ruma.io"),
user: user_id!("@bazme:ruma.io"),
};
let user_id = user_id!("@_virtual_:ruma.io");
let http_req = req.clone().try_into_http_request_with_user_id(
"https://homeserver.tld",
None,
user_id,
)?;
let http_req =
req.try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)?;
let query = http_req.uri().query().unwrap();

View File

@ -1,6 +1,7 @@
pub mod some_endpoint {
use ruma_api::ruma_api;
use ruma_events::{tag::TagEvent, AnyRoomEvent};
use ruma_identifiers::UserId;
use ruma_serde::Raw;
ruma_api! {
@ -8,7 +9,7 @@ pub mod some_endpoint {
description: "Does something.",
method: POST, // An `http::Method` constant. No imports required.
name: "some_endpoint",
path: "/_matrix/some/endpoint/:baz",
path: "/_matrix/some/endpoint/:user",
#[cfg(all())]
rate_limited: true,
@ -23,7 +24,7 @@ pub mod some_endpoint {
request: {
// With no attribute on the field, it will be put into the body of the request.
pub foo: String,
pub a_field: String,
// This value will be put into the "Content-Type" HTTP header.
#[ruma_api(header = CONTENT_TYPE)]
@ -34,9 +35,9 @@ pub mod some_endpoint {
pub bar: String,
// This value will be inserted into the request's URL in place of the
// ":baz" path component.
// ":user" path component.
#[ruma_api(path)]
pub baz: String,
pub user: UserId,
}
response: {
@ -65,7 +66,7 @@ pub mod newtype_body_endpoint {
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct MyCustomType {
pub foo: String,
pub a_field: String,
}
ruma_api! {
@ -95,7 +96,7 @@ pub mod newtype_raw_body_endpoint {
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct MyCustomType {
pub foo: String,
pub a_field: String,
}
ruma_api! {