api-macros: Inline request codegen methods
This commit is contained in:
parent
2e2609b752
commit
a20f03894e
@ -173,6 +173,7 @@ impl Request {
|
||||
ruma_api: &TokenStream,
|
||||
) -> TokenStream {
|
||||
let http = quote! { #ruma_api::exports::http };
|
||||
let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
|
||||
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
||||
let serde = quote! { #ruma_api::exports::serde };
|
||||
let serde_json = quote! { #ruma_api::exports::serde_json };
|
||||
@ -205,11 +206,131 @@ impl Request {
|
||||
TokenStream::new()
|
||||
};
|
||||
|
||||
let (request_path_string, parse_request_path) =
|
||||
self.path_string_and_parse(metadata, &ruma_api);
|
||||
let (request_path_string, parse_request_path) = if self.has_path_fields() {
|
||||
let path_string = metadata.path.value();
|
||||
|
||||
let request_query_string = self.build_query_string(&ruma_api);
|
||||
let extract_request_query = self.extract_request_query(&ruma_api);
|
||||
assert!(path_string.starts_with('/'), "path needs to start with '/'");
|
||||
assert!(
|
||||
path_string.chars().filter(|c| *c == ':').count() == self.path_field_count(),
|
||||
"number of declared path parameters needs to match amount of placeholders in path"
|
||||
);
|
||||
|
||||
let format_call = {
|
||||
let mut format_string = path_string.clone();
|
||||
let mut format_args = Vec::new();
|
||||
|
||||
while let Some(start_of_segment) = format_string.find(':') {
|
||||
// ':' should only ever appear at the start of a segment
|
||||
assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/");
|
||||
|
||||
let end_of_segment = match format_string[start_of_segment..].find('/') {
|
||||
Some(rel_pos) => start_of_segment + rel_pos,
|
||||
None => format_string.len(),
|
||||
};
|
||||
|
||||
let path_var = Ident::new(
|
||||
&format_string[start_of_segment + 1..end_of_segment],
|
||||
Span::call_site(),
|
||||
);
|
||||
format_args.push(quote! {
|
||||
#percent_encoding::utf8_percent_encode(
|
||||
&self.#path_var.to_string(),
|
||||
#percent_encoding::NON_ALPHANUMERIC,
|
||||
)
|
||||
});
|
||||
format_string.replace_range(start_of_segment..end_of_segment, "{}");
|
||||
}
|
||||
|
||||
quote! {
|
||||
format_args!(#format_string, #(#format_args),*)
|
||||
}
|
||||
};
|
||||
|
||||
let path_fields =
|
||||
path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map(
|
||||
|(i, segment)| {
|
||||
let path_var = &segment[1..];
|
||||
let path_var_ident = Ident::new(path_var, Span::call_site());
|
||||
quote! {
|
||||
#path_var_ident: {
|
||||
let segment = path_segments[#i].as_bytes();
|
||||
let decoded =
|
||||
#percent_encoding::percent_decode(segment).decode_utf8()?;
|
||||
|
||||
::std::convert::TryFrom::try_from(&*decoded)?
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
(format_call, quote! { #(#path_fields,)* })
|
||||
} else {
|
||||
(quote! { metadata.path.to_owned() }, TokenStream::new())
|
||||
};
|
||||
|
||||
let request_query_string = if let Some(field) = self.query_map_field() {
|
||||
let field_name = field.ident.as_ref().expect("expected field to have identifier");
|
||||
|
||||
quote!({
|
||||
// This function exists so that the compiler will throw an error when the type of
|
||||
// the field with the query_map attribute doesn't implement
|
||||
// `IntoIterator<Item = (String, String)>`.
|
||||
//
|
||||
// This is necessary because the `ruma_serde::urlencoded::to_string` call will
|
||||
// result in a runtime error when the type cannot be encoded as a list key-value
|
||||
// pairs (?key1=value1&key2=value2).
|
||||
//
|
||||
// By asserting that it implements the iterator trait, we can ensure that it won't
|
||||
// fail.
|
||||
fn assert_trait_impl<T>(_: &T)
|
||||
where
|
||||
T: ::std::iter::IntoIterator<
|
||||
Item = (::std::string::String, ::std::string::String),
|
||||
>,
|
||||
{}
|
||||
|
||||
let request_query = RequestQuery(self.#field_name);
|
||||
assert_trait_impl(&request_query.0);
|
||||
|
||||
format_args!(
|
||||
"?{}",
|
||||
#ruma_serde::urlencoded::to_string(request_query)?
|
||||
)
|
||||
})
|
||||
} else if self.has_query_fields() {
|
||||
let request_query_init_fields =
|
||||
self.struct_init_fields(RequestFieldKind::Query, quote!(self));
|
||||
|
||||
quote!({
|
||||
let request_query = RequestQuery {
|
||||
#request_query_init_fields
|
||||
};
|
||||
|
||||
format_args!(
|
||||
"?{}",
|
||||
#ruma_serde::urlencoded::to_string(request_query)?
|
||||
)
|
||||
})
|
||||
} else {
|
||||
quote! { "" }
|
||||
};
|
||||
|
||||
let extract_request_query = if self.query_map_field().is_some() {
|
||||
quote! {
|
||||
let request_query = #ruma_serde::urlencoded::from_str(
|
||||
&request.uri().query().unwrap_or(""),
|
||||
)?;
|
||||
}
|
||||
} else if self.has_query_fields() {
|
||||
quote! {
|
||||
let request_query: <RequestQuery as #ruma_serde::Outgoing>::Incoming =
|
||||
#ruma_serde::urlencoded::from_str(
|
||||
&request.uri().query().unwrap_or("")
|
||||
)?;
|
||||
}
|
||||
} else {
|
||||
TokenStream::new()
|
||||
};
|
||||
|
||||
let parse_request_query = if let Some(field) = self.query_map_field() {
|
||||
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
||||
@ -593,158 +714,6 @@ impl Request {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Deserialize the query string.
|
||||
fn extract_request_query(&self, ruma_api: &TokenStream) -> TokenStream {
|
||||
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
||||
|
||||
if self.query_map_field().is_some() {
|
||||
quote! {
|
||||
let request_query = #ruma_serde::urlencoded::from_str(
|
||||
&request.uri().query().unwrap_or(""),
|
||||
)?;
|
||||
}
|
||||
} else if self.has_query_fields() {
|
||||
quote! {
|
||||
let request_query: <RequestQuery as #ruma_serde::Outgoing>::Incoming =
|
||||
#ruma_serde::urlencoded::from_str(
|
||||
&request.uri().query().unwrap_or("")
|
||||
)?;
|
||||
}
|
||||
} else {
|
||||
TokenStream::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// The function determines the type of query string that needs to be built
|
||||
/// and then builds it using `ruma_serde::urlencoded::to_string`.
|
||||
fn build_query_string(&self, ruma_api: &TokenStream) -> TokenStream {
|
||||
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
||||
|
||||
if let Some(field) = self.query_map_field() {
|
||||
let field_name = field.ident.as_ref().expect("expected field to have identifier");
|
||||
|
||||
quote!({
|
||||
// This function exists so that the compiler will throw an error when the type of
|
||||
// the field with the query_map attribute doesn't implement
|
||||
// `IntoIterator<Item = (String, String)>`.
|
||||
//
|
||||
// This is necessary because the `ruma_serde::urlencoded::to_string` call will
|
||||
// result in a runtime error when the type cannot be encoded as a list key-value
|
||||
// pairs (?key1=value1&key2=value2).
|
||||
//
|
||||
// By asserting that it implements the iterator trait, we can ensure that it won't
|
||||
// fail.
|
||||
fn assert_trait_impl<T>(_: &T)
|
||||
where
|
||||
T: ::std::iter::IntoIterator<
|
||||
Item = (::std::string::String, ::std::string::String),
|
||||
>,
|
||||
{}
|
||||
|
||||
let request_query = RequestQuery(self.#field_name);
|
||||
assert_trait_impl(&request_query.0);
|
||||
|
||||
format_args!(
|
||||
"?{}",
|
||||
#ruma_serde::urlencoded::to_string(request_query)?
|
||||
)
|
||||
})
|
||||
} else if self.has_query_fields() {
|
||||
let request_query_init_fields =
|
||||
self.struct_init_fields(RequestFieldKind::Query, quote!(self));
|
||||
|
||||
quote!({
|
||||
let request_query = RequestQuery {
|
||||
#request_query_init_fields
|
||||
};
|
||||
|
||||
format_args!(
|
||||
"?{}",
|
||||
#ruma_serde::urlencoded::to_string(request_query)?
|
||||
)
|
||||
})
|
||||
} else {
|
||||
quote! { "" }
|
||||
}
|
||||
}
|
||||
|
||||
/// The first item in the tuple generates code for the request path from the `Metadata` and
|
||||
/// `Request` structs. The second item in the returned tuple is the code to generate a Request
|
||||
/// struct field created from any segments of the path that start with ":".
|
||||
///
|
||||
/// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is
|
||||
/// used for implementing `TryFrom<http::Request<Vec<u8>>>`, from path strings deserialized to
|
||||
/// Ruma types.
|
||||
fn path_string_and_parse(
|
||||
&self,
|
||||
metadata: &Metadata,
|
||||
ruma_api: &TokenStream,
|
||||
) -> (TokenStream, TokenStream) {
|
||||
let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
|
||||
|
||||
if self.has_path_fields() {
|
||||
let path_string = metadata.path.value();
|
||||
|
||||
assert!(path_string.starts_with('/'), "path needs to start with '/'");
|
||||
assert!(
|
||||
path_string.chars().filter(|c| *c == ':').count() == self.path_field_count(),
|
||||
"number of declared path parameters needs to match amount of placeholders in path"
|
||||
);
|
||||
|
||||
let format_call = {
|
||||
let mut format_string = path_string.clone();
|
||||
let mut format_args = Vec::new();
|
||||
|
||||
while let Some(start_of_segment) = format_string.find(':') {
|
||||
// ':' should only ever appear at the start of a segment
|
||||
assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/");
|
||||
|
||||
let end_of_segment = match format_string[start_of_segment..].find('/') {
|
||||
Some(rel_pos) => start_of_segment + rel_pos,
|
||||
None => format_string.len(),
|
||||
};
|
||||
|
||||
let path_var = Ident::new(
|
||||
&format_string[start_of_segment + 1..end_of_segment],
|
||||
Span::call_site(),
|
||||
);
|
||||
format_args.push(quote! {
|
||||
#percent_encoding::utf8_percent_encode(
|
||||
&self.#path_var.to_string(),
|
||||
#percent_encoding::NON_ALPHANUMERIC,
|
||||
)
|
||||
});
|
||||
format_string.replace_range(start_of_segment..end_of_segment, "{}");
|
||||
}
|
||||
|
||||
quote! {
|
||||
format_args!(#format_string, #(#format_args),*)
|
||||
}
|
||||
};
|
||||
|
||||
let path_fields =
|
||||
path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map(
|
||||
|(i, segment)| {
|
||||
let path_var = &segment[1..];
|
||||
let path_var_ident = Ident::new(path_var, Span::call_site());
|
||||
quote! {
|
||||
#path_var_ident: {
|
||||
let segment = path_segments[#i].as_bytes();
|
||||
let decoded =
|
||||
#percent_encoding::percent_decode(segment).decode_utf8()?;
|
||||
|
||||
::std::convert::TryFrom::try_from(&*decoded)?
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
(format_call, quote! { #(#path_fields,)* })
|
||||
} else {
|
||||
(quote! { metadata.path.to_owned() }, TokenStream::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The types of fields that a request can have.
|
||||
|
Loading…
x
Reference in New Issue
Block a user