api-macros: Refactor request code generation

This commit is contained in:
Jonas Platte 2021-04-10 16:47:34 +02:00
parent a20f03894e
commit 23ba0bc164
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
3 changed files with 188 additions and 163 deletions

View File

@ -32,47 +32,47 @@ pub(crate) struct Request {
impl Request { impl Request {
/// Whether or not this request has any data in the HTTP body. /// Whether or not this request has any data in the HTTP body.
pub fn has_body_fields(&self) -> bool { pub(super) fn has_body_fields(&self) -> bool {
self.fields.iter().any(|field| field.is_body()) self.fields.iter().any(|field| field.is_body())
} }
/// Whether or not this request has any data in HTTP headers. /// Whether or not this request has any data in HTTP headers.
pub fn has_header_fields(&self) -> bool { fn has_header_fields(&self) -> bool {
self.fields.iter().any(|field| field.is_header()) self.fields.iter().any(|field| field.is_header())
} }
/// Whether or not this request has any data in the URL path. /// Whether or not this request has any data in the URL path.
pub fn has_path_fields(&self) -> bool { fn has_path_fields(&self) -> bool {
self.fields.iter().any(|field| field.is_path()) self.fields.iter().any(|field| field.is_path())
} }
/// Whether or not this request has any data in the query string. /// Whether or not this request has any data in the query string.
pub fn has_query_fields(&self) -> bool { fn has_query_fields(&self) -> bool {
self.fields.iter().any(|field| field.is_query()) self.fields.iter().any(|field| field.is_query())
} }
/// Produces an iterator over all the body fields. /// Produces an iterator over all the body fields.
pub fn body_fields(&self) -> impl Iterator<Item = &Field> { pub(super) fn body_fields(&self) -> impl Iterator<Item = &Field> {
self.fields.iter().filter_map(|field| field.as_body_field()) self.fields.iter().filter_map(|field| field.as_body_field())
} }
/// The number of unique lifetime annotations for `body` fields. /// The number of unique lifetime annotations for `body` fields.
pub fn body_lifetime_count(&self) -> usize { fn body_lifetime_count(&self) -> usize {
self.lifetimes.body.len() self.lifetimes.body.len()
} }
/// Whether any `body` field has a lifetime annotation. /// Whether any `body` field has a lifetime annotation.
pub fn has_body_lifetimes(&self) -> bool { fn has_body_lifetimes(&self) -> bool {
!self.lifetimes.body.is_empty() !self.lifetimes.body.is_empty()
} }
/// Whether any `query` field has a lifetime annotation. /// Whether any `query` field has a lifetime annotation.
pub fn has_query_lifetimes(&self) -> bool { fn has_query_lifetimes(&self) -> bool {
!self.lifetimes.query.is_empty() !self.lifetimes.query.is_empty()
} }
/// Whether any field has a lifetime. /// Whether any field has a lifetime.
pub fn contains_lifetimes(&self) -> bool { fn contains_lifetimes(&self) -> bool {
!(self.lifetimes.body.is_empty() !(self.lifetimes.body.is_empty()
&& self.lifetimes.path.is_empty() && self.lifetimes.path.is_empty()
&& self.lifetimes.query.is_empty() && self.lifetimes.query.is_empty()
@ -80,7 +80,7 @@ impl Request {
} }
/// The combination of every fields unique lifetime annotation. /// The combination of every fields unique lifetime annotation.
pub fn combine_lifetimes(&self) -> TokenStream { fn combine_lifetimes(&self) -> TokenStream {
util::unique_lifetimes_to_tokens( util::unique_lifetimes_to_tokens(
[ [
&self.lifetimes.body, &self.lifetimes.body,
@ -94,22 +94,22 @@ impl Request {
} }
/// The lifetimes on fields with the `query` attribute. /// The lifetimes on fields with the `query` attribute.
pub fn query_lifetimes(&self) -> TokenStream { fn query_lifetimes(&self) -> TokenStream {
util::unique_lifetimes_to_tokens(&self.lifetimes.query) util::unique_lifetimes_to_tokens(&self.lifetimes.query)
} }
/// The lifetimes on fields with the `body` attribute. /// The lifetimes on fields with the `body` attribute.
pub fn body_lifetimes(&self) -> TokenStream { fn body_lifetimes(&self) -> TokenStream {
util::unique_lifetimes_to_tokens(&self.lifetimes.body) util::unique_lifetimes_to_tokens(&self.lifetimes.body)
} }
/// Produces an iterator over all the header fields. /// Produces an iterator over all the header fields.
pub fn header_fields(&self) -> impl Iterator<Item = &RequestField> { fn header_fields(&self) -> impl Iterator<Item = &RequestField> {
self.fields.iter().filter(|field| field.is_header()) self.fields.iter().filter(|field| field.is_header())
} }
/// Gets the number of path fields. /// Gets the number of path fields.
pub fn path_field_count(&self) -> usize { fn path_field_count(&self) -> usize {
self.fields.iter().filter(|field| field.is_path()).count() self.fields.iter().filter(|field| field.is_path()).count()
} }
@ -119,12 +119,12 @@ impl Request {
} }
/// Returns the body field. /// Returns the body field.
pub fn newtype_raw_body_field(&self) -> Option<&Field> { fn newtype_raw_body_field(&self) -> Option<&Field> {
self.fields.iter().find_map(RequestField::as_newtype_raw_body_field) self.fields.iter().find_map(RequestField::as_newtype_raw_body_field)
} }
/// Returns the query map field. /// Returns the query map field.
pub fn query_map_field(&self) -> Option<&Field> { fn query_map_field(&self) -> Option<&Field> {
self.fields.iter().find_map(RequestField::as_query_map_field) self.fields.iter().find_map(RequestField::as_query_map_field)
} }
@ -135,8 +135,8 @@ impl Request {
request_field_kind: RequestFieldKind, request_field_kind: RequestFieldKind,
src: TokenStream, src: TokenStream,
) -> TokenStream { ) -> TokenStream {
let process_field = |f: &RequestField| { let fields =
f.field_of_kind(request_field_kind).map(|field| { self.fields.iter().filter_map(|f| f.field_of_kind(request_field_kind)).map(|field| {
let field_name = let field_name =
field.ident.as_ref().expect("expected field to have an identifier"); field.ident.as_ref().expect("expected field to have an identifier");
let span = field.span(); let span = field.span();
@ -147,25 +147,43 @@ impl Request {
#( #cfg_attrs )* #( #cfg_attrs )*
#field_name: #src.#field_name #field_name: #src.#field_name
} }
}) });
};
let mut fields = vec![];
let mut new_type_body = None;
for field in &self.fields {
if let RequestField::NewtypeRawBody(_) = field {
new_type_body = process_field(field);
} else {
fields.extend(process_field(field));
}
}
// Move field that consumes `request` to the end of the init list.
fields.extend(new_type_body);
quote! { #(#fields,)* } quote! { #(#fields,)* }
} }
/// Produces code for a struct initializer for the given field kind to be accessed through the
/// given variable name.
fn vars(
&self,
request_field_kind: RequestFieldKind,
src: TokenStream,
) -> (TokenStream, TokenStream) {
let (decls, names): (TokenStream, Vec<_>) = self
.fields
.iter()
.filter_map(|f| f.field_of_kind(request_field_kind))
.map(|field| {
let field_name =
field.ident.as_ref().expect("expected field to have an identifier");
let span = field.span();
let cfg_attrs =
field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>();
let decl = quote_spanned! {span=>
#( #cfg_attrs )*
let #field_name = #src.#field_name;
};
(decl, field_name)
})
.unzip();
let names = quote! { #(#names,)* };
(decls, names)
}
pub(super) fn expand( pub(super) fn expand(
&self, &self,
metadata: &Metadata, metadata: &Metadata,
@ -197,16 +215,7 @@ impl Request {
let incoming_request_type = let incoming_request_type =
if self.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) }; if self.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) };
let extract_request_path = if self.has_path_fields() { let (request_path_string, parse_request_path, path_vars) = if self.has_path_fields() {
quote! {
let path_segments: ::std::vec::Vec<&::std::primitive::str> =
request.uri().path()[1..].split('/').collect();
}
} else {
TokenStream::new()
};
let (request_path_string, parse_request_path) = if self.has_path_fields() {
let path_string = metadata.path.value(); let path_string = metadata.path.value();
assert!(path_string.starts_with('/'), "path needs to start with '/'"); assert!(path_string.starts_with('/'), "path needs to start with '/'");
@ -246,26 +255,38 @@ impl Request {
} }
}; };
let path_fields = let path_var_decls = path_string[1..]
path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map( .split('/')
|(i, segment)| { .enumerate()
let path_var = &segment[1..]; .filter(|(_, seg)| seg.starts_with(':'))
let path_var_ident = Ident::new(path_var, Span::call_site()); .map(|(i, seg)| {
let path_var = Ident::new(&seg[1..], Span::call_site());
quote! { quote! {
#path_var_ident: { let #path_var = {
let segment = path_segments[#i].as_bytes(); let segment = path_segments[#i].as_bytes();
let decoded = let decoded =
#percent_encoding::percent_decode(segment).decode_utf8()?; #percent_encoding::percent_decode(segment).decode_utf8()?;
::std::convert::TryFrom::try_from(&*decoded)? ::std::convert::TryFrom::try_from(&*decoded)?
};
} }
} });
},
);
(format_call, quote! { #(#path_fields,)* }) let parse_request_path = quote! {
let path_segments: ::std::vec::Vec<&::std::primitive::str> =
request.uri().path()[1..].split('/').collect();
#(#path_var_decls)*
};
let path_vars = path_string[1..]
.split('/')
.filter(|seg| seg.starts_with(':'))
.map(|seg| Ident::new(&seg[1..], Span::call_site()));
(format_call, parse_request_path, quote! { #(#path_vars,)* })
} else { } else {
(quote! { metadata.path.to_owned() }, TokenStream::new()) (quote! { metadata.path.to_owned() }, TokenStream::new(), TokenStream::new())
}; };
let request_query_string = if let Some(field) = self.query_map_field() { let request_query_string = if let Some(field) = self.query_map_field() {
@ -315,31 +336,30 @@ impl Request {
quote! { "" } quote! { "" }
}; };
let extract_request_query = if self.query_map_field().is_some() { let (parse_query, query_vars) = if let Some(field) = self.query_map_field() {
quote! { let field_name = field.ident.as_ref().expect("expected field to have an identifier");
let request_query = #ruma_serde::urlencoded::from_str( let parse = quote! {
let #field_name = #ruma_serde::urlencoded::from_str(
&request.uri().query().unwrap_or(""), &request.uri().query().unwrap_or(""),
)?; )?;
} };
(parse, quote! { #field_name, })
} else if self.has_query_fields() { } else if self.has_query_fields() {
quote! { let (decls, names) = self.vars(RequestFieldKind::Query, quote!(request_query));
let parse = quote! {
let request_query: <RequestQuery as #ruma_serde::Outgoing>::Incoming = let request_query: <RequestQuery as #ruma_serde::Outgoing>::Incoming =
#ruma_serde::urlencoded::from_str( #ruma_serde::urlencoded::from_str(
&request.uri().query().unwrap_or("") &request.uri().query().unwrap_or("")
)?; )?;
}
} else { #decls
TokenStream::new()
}; };
let parse_request_query = if let Some(field) = self.query_map_field() { (parse, names)
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request_query,
}
} else { } else {
self.struct_init_fields(RequestFieldKind::Query, quote!(request_query)) (TokenStream::new(), TokenStream::new())
}; };
let mut header_kvs: TokenStream = self let mut header_kvs: TokenStream = self
@ -395,16 +415,62 @@ impl Request {
} }
} }
let extract_request_headers = if self.has_header_fields() { let (parse_headers, header_vars) = if self.has_header_fields() {
quote! { let (decls, names): (TokenStream, Vec<_>) = self
let headers = request.headers(); .header_fields()
} .map(|request_field| {
} else { let (field, header_name) = match request_field {
TokenStream::new() RequestField::Header(field, header_name) => (field, header_name),
_ => panic!("expected request field to be header variant"),
}; };
let extract_request_body = if self.has_body_fields() || self.newtype_body_field().is_some() let field_name = &field.ident;
{ let header_name_string = header_name.to_string();
let (some_case, none_case) = match &field.ty {
syn::Type::Path(syn::TypePath {
path: syn::Path { segments, .. }, ..
}) if segments.last().unwrap().ident == "Option" => {
(quote! { Some(str_value.to_owned()) }, quote! { None })
}
_ => (
quote! { str_value.to_owned() },
quote! {
return Err(
#ruma_api::error::HeaderDeserializationError::MissingHeader(
#header_name_string.into()
).into(),
)
},
),
};
let decl = quote! {
let #field_name = match headers.get(#http::header::#header_name) {
Some(header_value) => {
let str_value = header_value.to_str()?;
#some_case
}
None => #none_case,
};
};
(decl, field_name)
})
.unzip();
let parse = quote! {
let headers = request.headers();
#decls
};
(parse, quote! { #(#names,)* })
} else {
(TokenStream::new(), TokenStream::new())
};
let extract_body = if self.has_body_fields() || self.newtype_body_field().is_some() {
let body_lifetimes = if self.has_body_lifetimes() { let body_lifetimes = if self.has_body_lifetimes() {
// duplicate the anonymous lifetime as many times as needed // duplicate the anonymous lifetime as many times as needed
let lifetimes = std::iter::repeat(quote! { '_ }).take(self.body_lifetime_count()); let lifetimes = std::iter::repeat(quote! { '_ }).take(self.body_lifetime_count());
@ -412,6 +478,7 @@ impl Request {
} else { } else {
TokenStream::new() TokenStream::new()
}; };
quote! { quote! {
let request_body: < let request_body: <
RequestBody #body_lifetimes RequestBody #body_lifetimes
@ -432,52 +499,6 @@ impl Request {
TokenStream::new() TokenStream::new()
}; };
let parse_request_headers = if self.has_header_fields() {
let fields = self.header_fields().map(|request_field| {
let (field, header_name) = match request_field {
RequestField::Header(field, header_name) => (field, header_name),
_ => panic!("expected request field to be header variant"),
};
let field_name = &field.ident;
let header_name_string = header_name.to_string();
let (some_case, none_case) = match &field.ty {
syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
if segments.last().unwrap().ident == "Option" =>
{
(quote! { Some(str_value.to_owned()) }, quote! { None })
}
_ => (
quote! { str_value.to_owned() },
quote! {
return Err(
#ruma_api::error::HeaderDeserializationError::MissingHeader(
#header_name_string.into()
).into(),
)
},
),
};
quote! {
#field_name: match headers.get(#http::header::#header_name) {
Some(header_value) => {
let str_value = header_value.to_str()?;
#some_case
}
None => #none_case,
}
}
});
quote! {
#(#fields,)*
}
} else {
TokenStream::new()
};
let request_body = if let Some(field) = self.newtype_raw_body_field() { let request_body = if let Some(field) = self.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! { self.#field_name } quote! { self.#field_name }
@ -501,18 +522,22 @@ impl Request {
quote! { Vec::new() } quote! { Vec::new() }
}; };
let parse_request_body = if let Some(field) = self.newtype_body_field() { let (parse_body, body_vars) = if let Some(field) = self.newtype_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! { let parse = quote! {
#field_name: request_body.0, let #field_name = request_body.0;
} };
(parse, quote! { #field_name, })
} else if let Some(field) = self.newtype_raw_body_field() { } else if let Some(field) = self.newtype_raw_body_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! { let parse = quote! {
#field_name: request.into_body(), let #field_name = request.into_body();
} };
(parse, quote! { #field_name, })
} else { } else {
self.struct_init_fields(RequestFieldKind::Body, quote!(request_body)) self.vars(RequestFieldKind::Body, quote!(request_body))
}; };
let request_generics = self.combine_lifetimes(); let request_generics = self.combine_lifetimes();
@ -699,16 +724,18 @@ impl Request {
}); });
} }
#extract_request_path #parse_request_path
#extract_request_query #parse_query
#extract_request_headers #parse_headers
#extract_request_body
#extract_body
#parse_body
Ok(Self { Ok(Self {
#parse_request_path #path_vars
#parse_request_query #query_vars
#parse_request_headers #header_vars
#parse_request_body #body_vars
}) })
} }
} }

View File

@ -8,7 +8,7 @@ ruma_api! {
description: "Does something.", description: "Does something.",
method: POST, method: POST,
name: "my_endpoint", name: "my_endpoint",
path: "/_matrix/foo/:bar/:baz", path: "/_matrix/foo/:bar/:user",
rate_limited: false, rate_limited: false,
authentication: None, authentication: None,
} }
@ -24,7 +24,7 @@ ruma_api! {
#[ruma_api(path)] #[ruma_api(path)]
pub bar: String, pub bar: String,
#[ruma_api(path)] #[ruma_api(path)]
pub baz: UserId, pub user: UserId,
} }
response: { response: {
@ -44,7 +44,7 @@ fn request_serde() -> Result<(), Box<dyn std::error::Error + 'static>> {
q1: "query_param_special_chars %/&@!".to_owned(), q1: "query_param_special_chars %/&@!".to_owned(),
q2: 55, q2: 55,
bar: "barVal".to_owned(), bar: "barVal".to_owned(),
baz: user_id!("@bazme:ruma.io"), user: user_id!("@bazme:ruma.io"),
}; };
let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?; let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?;
@ -55,7 +55,7 @@ fn request_serde() -> Result<(), Box<dyn std::error::Error + 'static>> {
assert_eq!(req.q1, req2.q1); assert_eq!(req.q1, req2.q1);
assert_eq!(req.q2, req2.q2); assert_eq!(req.q2, req2.q2);
assert_eq!(req.bar, req2.bar); assert_eq!(req.bar, req2.bar);
assert_eq!(req.baz, req2.baz); assert_eq!(req.user, req2.user);
Ok(()) Ok(())
} }
@ -68,12 +68,12 @@ fn request_with_user_id_serde() -> Result<(), Box<dyn std::error::Error + 'stati
q1: "query_param_special_chars %/&@!".to_owned(), q1: "query_param_special_chars %/&@!".to_owned(),
q2: 55, q2: 55,
bar: "barVal".to_owned(), bar: "barVal".to_owned(),
baz: user_id!("@bazme:ruma.io"), user: user_id!("@bazme:ruma.io"),
}; };
let user_id = user_id!("@_virtual_:ruma.io"); let user_id = user_id!("@_virtual_:ruma.io");
let http_req = let http_req =
req.clone().try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)?; req.try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)?;
let query = http_req.uri().query().unwrap(); let query = http_req.uri().query().unwrap();
@ -93,7 +93,7 @@ mod without_query {
description: "Does something without query.", description: "Does something without query.",
method: POST, method: POST,
name: "my_endpoint", name: "my_endpoint",
path: "/_matrix/foo/:bar/:baz", path: "/_matrix/foo/:bar/:user",
rate_limited: false, rate_limited: false,
authentication: None, authentication: None,
} }
@ -105,7 +105,7 @@ mod without_query {
#[ruma_api(path)] #[ruma_api(path)]
pub bar: String, pub bar: String,
#[ruma_api(path)] #[ruma_api(path)]
pub baz: UserId, pub user: UserId,
} }
response: { response: {
@ -124,15 +124,12 @@ mod without_query {
hello: "hi".to_owned(), hello: "hi".to_owned(),
world: "test".to_owned(), world: "test".to_owned(),
bar: "barVal".to_owned(), bar: "barVal".to_owned(),
baz: user_id!("@bazme:ruma.io"), user: user_id!("@bazme:ruma.io"),
}; };
let user_id = user_id!("@_virtual_:ruma.io"); let user_id = user_id!("@_virtual_:ruma.io");
let http_req = req.clone().try_into_http_request_with_user_id( let http_req =
"https://homeserver.tld", req.try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)?;
None,
user_id,
)?;
let query = http_req.uri().query().unwrap(); let query = http_req.uri().query().unwrap();

View File

@ -1,6 +1,7 @@
pub mod some_endpoint { pub mod some_endpoint {
use ruma_api::ruma_api; use ruma_api::ruma_api;
use ruma_events::{tag::TagEvent, AnyRoomEvent}; use ruma_events::{tag::TagEvent, AnyRoomEvent};
use ruma_identifiers::UserId;
use ruma_serde::Raw; use ruma_serde::Raw;
ruma_api! { ruma_api! {
@ -8,7 +9,7 @@ pub mod some_endpoint {
description: "Does something.", description: "Does something.",
method: POST, // An `http::Method` constant. No imports required. method: POST, // An `http::Method` constant. No imports required.
name: "some_endpoint", name: "some_endpoint",
path: "/_matrix/some/endpoint/:baz", path: "/_matrix/some/endpoint/:user",
#[cfg(all())] #[cfg(all())]
rate_limited: true, rate_limited: true,
@ -23,7 +24,7 @@ pub mod some_endpoint {
request: { request: {
// With no attribute on the field, it will be put into the body of the request. // With no attribute on the field, it will be put into the body of the request.
pub foo: String, pub a_field: String,
// This value will be put into the "Content-Type" HTTP header. // This value will be put into the "Content-Type" HTTP header.
#[ruma_api(header = CONTENT_TYPE)] #[ruma_api(header = CONTENT_TYPE)]
@ -34,9 +35,9 @@ pub mod some_endpoint {
pub bar: String, pub bar: String,
// This value will be inserted into the request's URL in place of the // This value will be inserted into the request's URL in place of the
// ":baz" path component. // ":user" path component.
#[ruma_api(path)] #[ruma_api(path)]
pub baz: String, pub user: UserId,
} }
response: { response: {
@ -65,7 +66,7 @@ pub mod newtype_body_endpoint {
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct MyCustomType { pub struct MyCustomType {
pub foo: String, pub a_field: String,
} }
ruma_api! { ruma_api! {
@ -95,7 +96,7 @@ pub mod newtype_raw_body_endpoint {
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct MyCustomType { pub struct MyCustomType {
pub foo: String, pub a_field: String,
} }
ruma_api! { ruma_api! {