api-macros: Remove RawApi, some refactoring

This commit is contained in:
Jonas Platte 2020-11-27 21:12:10 +01:00
parent 183f427143
commit 761aecbe4e
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
2 changed files with 257 additions and 308 deletions

View File

@ -1,7 +1,5 @@
//! Details of the `ruma_api` procedural macro. //! Details of the `ruma_api` procedural macro.
use std::convert::TryFrom;
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::{quote, ToTokens}; use quote::{quote, ToTokens};
use syn::{ use syn::{
@ -36,29 +34,23 @@ pub struct Api {
response: Response, response: Response,
/// The `error` section of the macro. /// The `error` section of the macro.
error: TokenStream, error_ty: TokenStream,
} }
impl TryFrom<RawApi> for Api { impl Parse for Api {
type Error = syn::Error; fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
fn try_from(raw_api: RawApi) -> syn::Result<Self> {
let import_path = util::import_ruma_api(); let import_path = util::import_ruma_api();
let res = Self { let metadata: Metadata = input.parse()?;
metadata: raw_api.metadata, let request: Request = input.parse()?;
request: raw_api.request, let response: Response = input.parse()?;
response: raw_api.response, let error_ty = match input.parse::<ErrorType>() {
error: match raw_api.error { Ok(err) => err.ty.to_token_stream(),
Some(raw_err) => raw_err.ty.to_token_stream(), Err(_) => quote! { #import_path::error::Void },
None => quote! { #import_path::error::Void },
},
}; };
let newtype_body_field = res.request.newtype_body_field(); let newtype_body_field = request.newtype_body_field();
if res.metadata.method == "GET" if metadata.method == "GET" && (request.has_body_fields() || newtype_body_field.is_some()) {
&& (res.request.has_body_fields() || newtype_body_field.is_some())
{
let mut combined_error: Option<syn::Error> = None; let mut combined_error: Option<syn::Error> = None;
let mut add_error = |field| { let mut add_error = |field| {
let error = syn::Error::new_spanned(field, "GET endpoints can't have body fields"); let error = syn::Error::new_spanned(field, "GET endpoints can't have body fields");
@ -69,7 +61,7 @@ impl TryFrom<RawApi> for Api {
} }
}; };
for field in res.request.body_fields() { for field in request.body_fields() {
add_error(field); add_error(field);
} }
@ -77,95 +69,90 @@ impl TryFrom<RawApi> for Api {
add_error(field); add_error(field);
} }
Err(combined_error.unwrap()) return Err(combined_error.unwrap());
} else {
Ok(res)
} }
Ok(Self { metadata, request, response, error_ty })
} }
} }
impl ToTokens for Api { pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
fn to_tokens(&self, tokens: &mut TokenStream) { // Guarantee `ruma_api` is available and named something we can refer to.
// Guarantee `ruma_api` is available and named something we can refer to. let ruma_api_import = util::import_ruma_api();
let ruma_api_import = util::import_ruma_api();
let description = &self.metadata.description; let description = &api.metadata.description;
let method = &self.metadata.method; let method = &api.metadata.method;
// We don't (currently) use this literal as a literal in the generated code. Instead we just // We don't (currently) use this literal as a literal in the generated code. Instead we just
// put it into doc comments, for which the span information is irrelevant. So we can work // put it into doc comments, for which the span information is irrelevant. So we can work
// with only the literal's value from here on. // with only the literal's value from here on.
let name = &self.metadata.name.value(); let name = &api.metadata.name.value();
let path = &self.metadata.path; let path = &api.metadata.path;
let rate_limited = &self.metadata.rate_limited; let rate_limited = &api.metadata.rate_limited;
let authentication = &self.metadata.authentication; let authentication = &api.metadata.authentication;
let request_type = &self.request; let request_type = &api.request;
let response_type = &self.response; let response_type = &api.response;
let incoming_request_type = if self.request.contains_lifetimes() { let incoming_request_type =
quote!(IncomingRequest) if api.request.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) };
} else {
quote!(Request)
};
let extract_request_path = if self.request.has_path_fields() { let extract_request_path = if api.request.has_path_fields() {
quote! { quote! {
let path_segments: ::std::vec::Vec<&::std::primitive::str> = let path_segments: ::std::vec::Vec<&::std::primitive::str> =
request.uri().path()[1..].split('/').collect(); request.uri().path()[1..].split('/').collect();
}
} else {
TokenStream::new()
};
let (request_path_string, parse_request_path) =
util::request_path_string_and_parse(&self.request, &self.metadata, &ruma_api_import);
let request_query_string = util::build_query_string(&self.request, &ruma_api_import);
let extract_request_query = util::extract_request_query(&self.request, &ruma_api_import);
let parse_request_query = if let Some(field) = self.request.query_map_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request_query,
}
} else {
self.request.request_init_query_fields()
};
let mut header_kvs = self.request.append_header_kvs();
if authentication == "AccessToken" {
header_kvs.extend(quote! {
req_builder = req_builder.header(
#ruma_api_import::exports::http::header::AUTHORIZATION,
#ruma_api_import::exports::http::header::HeaderValue::from_str(
&::std::format!(
"Bearer {}",
access_token.ok_or(
#ruma_api_import::error::IntoHttpError::NeedsAuthentication
)?
)
)?
);
});
} }
} else {
TokenStream::new()
};
let extract_request_headers = if self.request.has_header_fields() { let (request_path_string, parse_request_path) =
quote! { util::request_path_string_and_parse(&api.request, &api.metadata, &ruma_api_import);
let headers = request.headers();
}
} else {
TokenStream::new()
};
let extract_request_body = if self.request.has_body_fields() let request_query_string = util::build_query_string(&api.request, &ruma_api_import);
|| self.request.newtype_body_field().is_some()
{ let extract_request_query = util::extract_request_query(&api.request, &ruma_api_import);
let body_lifetimes = if self.request.has_body_lifetimes() {
let parse_request_query = if let Some(field) = api.request.query_map_field() {
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
quote! {
#field_name: request_query,
}
} else {
api.request.request_init_query_fields()
};
let mut header_kvs = api.request.append_header_kvs();
if authentication == "AccessToken" {
header_kvs.extend(quote! {
req_builder = req_builder.header(
#ruma_api_import::exports::http::header::AUTHORIZATION,
#ruma_api_import::exports::http::header::HeaderValue::from_str(
&::std::format!(
"Bearer {}",
access_token.ok_or(
#ruma_api_import::error::IntoHttpError::NeedsAuthentication
)?
)
)?
);
});
}
let extract_request_headers = if api.request.has_header_fields() {
quote! {
let headers = request.headers();
}
} else {
TokenStream::new()
};
let extract_request_body =
if api.request.has_body_fields() || api.request.newtype_body_field().is_some() {
let body_lifetimes = if api.request.has_body_lifetimes() {
// duplicate the anonymous lifetime as many times as needed // duplicate the anonymous lifetime as many times as needed
let lifetimes = let lifetimes =
std::iter::repeat(quote! { '_ }).take(self.request.body_lifetime_count()); std::iter::repeat(quote! { '_ }).take(api.request.body_lifetime_count());
quote! { < #( #lifetimes, )* >} quote! { < #( #lifetimes, )* >}
} else { } else {
TokenStream::new() TokenStream::new()
@ -184,239 +171,207 @@ impl ToTokens for Api {
TokenStream::new() TokenStream::new()
}; };
let parse_request_headers = if self.request.has_header_fields() { let parse_request_headers = if api.request.has_header_fields() {
self.request.parse_headers_from_request() api.request.parse_headers_from_request()
} else { } else {
TokenStream::new() TokenStream::new()
}; };
let request_body = util::build_request_body(&self.request, &ruma_api_import); let request_body = util::build_request_body(&api.request, &ruma_api_import);
let parse_request_body = util::parse_request_body(&self.request); let parse_request_body = util::parse_request_body(&api.request);
let extract_response_headers = if self.response.has_header_fields() { let extract_response_headers = if api.response.has_header_fields() {
quote! { quote! {
let mut headers = response.headers().clone(); let mut headers = response.headers().clone();
} }
} else { } else {
TokenStream::new() TokenStream::new()
}; };
let typed_response_body_decl = if self.response.has_body_fields() let typed_response_body_decl = if api.response.has_body_fields()
|| self.response.newtype_body_field().is_some() || api.response.newtype_body_field().is_some()
{
quote! {
let response_body: <
ResponseBody
as #ruma_api_import::exports::ruma_common::Outgoing
>::Incoming =
#ruma_api_import::try_deserialize!(
response,
#ruma_api_import::exports::serde_json::from_slice(response.body().as_slice()),
);
}
} else {
TokenStream::new()
};
let response_init_fields = api.response.init_fields();
let serialize_response_headers = api.response.apply_header_fields();
let body = api.response.to_body();
let metadata_doc = format!("Metadata for the `{}` API endpoint.", name);
let request_doc =
format!("Data for a request to the `{}` API endpoint.\n\n{}", name, description.value());
let response_doc = format!("Data in the response from the `{}` API endpoint.", name);
let error = &api.error_ty;
let request_lifetimes = api.request.combine_lifetimes();
let non_auth_endpoint_impls = if authentication != "None" {
TokenStream::new()
} else {
quote! {
impl #request_lifetimes #ruma_api_import::OutgoingNonAuthRequest
for Request #request_lifetimes
{}
impl #ruma_api_import::IncomingNonAuthRequest for #incoming_request_type {}
}
};
Ok(quote! {
#[doc = #request_doc]
#request_type
impl ::std::convert::TryFrom<#ruma_api_import::exports::http::Request<Vec<u8>>>
for #incoming_request_type
{ {
quote! { type Error = #ruma_api_import::error::FromHttpRequestError;
let response_body: <
ResponseBody #[allow(unused_variables)]
as #ruma_api_import::exports::ruma_common::Outgoing fn try_from(
>::Incoming = request: #ruma_api_import::exports::http::Request<Vec<u8>>
#ruma_api_import::try_deserialize!( ) -> ::std::result::Result<Self, Self::Error> {
response, #extract_request_path
#ruma_api_import::exports::serde_json::from_slice(response.body().as_slice()), #extract_request_query
); #extract_request_headers
#extract_request_body
Ok(Self {
#parse_request_path
#parse_request_query
#parse_request_headers
#parse_request_body
})
} }
} else { }
TokenStream::new()
};
let response_init_fields = self.response.init_fields(); #[doc = #response_doc]
#response_type
let serialize_response_headers = self.response.apply_header_fields(); impl ::std::convert::TryFrom<Response> for #ruma_api_import::exports::http::Response<Vec<u8>> {
type Error = #ruma_api_import::error::IntoHttpError;
let body = self.response.to_body(); #[allow(unused_variables)]
fn try_from(response: Response) -> ::std::result::Result<Self, Self::Error> {
let mut resp_builder = #ruma_api_import::exports::http::Response::builder()
.header(#ruma_api_import::exports::http::header::CONTENT_TYPE, "application/json");
let metadata_doc = format!("Metadata for the `{}` API endpoint.", name); let mut headers =
let request_doc = format!( resp_builder.headers_mut().expect("`http::ResponseBuilder` is in unusable state");
"Data for a request to the `{}` API endpoint.\n\n{}", #serialize_response_headers
name,
description.value()
);
let response_doc = format!("Data in the response from the `{}` API endpoint.", name);
let error = &self.error; // This cannot fail because we parse each header value
// checking for errors as each value is inserted and
let request_lifetimes = self.request.combine_lifetimes(); // we only allow keys from the `http::header` module.
let response = resp_builder.body(#body).unwrap();
let non_auth_endpoint_impls = if authentication != "None" { Ok(response)
TokenStream::new()
} else {
quote! {
impl #request_lifetimes #ruma_api_import::OutgoingNonAuthRequest
for Request #request_lifetimes
{}
impl #ruma_api_import::IncomingNonAuthRequest for #incoming_request_type {}
} }
}; }
let api = quote! { impl ::std::convert::TryFrom<#ruma_api_import::exports::http::Response<Vec<u8>>> for Response {
#[doc = #request_doc] type Error = #ruma_api_import::error::FromHttpResponseError<#error>;
#request_type
impl ::std::convert::TryFrom<#ruma_api_import::exports::http::Request<Vec<u8>>> #[allow(unused_variables)]
for #incoming_request_type fn try_from(
{ response: #ruma_api_import::exports::http::Response<Vec<u8>>,
type Error = #ruma_api_import::error::FromHttpRequestError; ) -> ::std::result::Result<Self, Self::Error> {
if response.status().as_u16() < 400 {
#extract_response_headers
#[allow(unused_variables)] #typed_response_body_decl
fn try_from(
request: #ruma_api_import::exports::http::Request<Vec<u8>>
) -> ::std::result::Result<Self, Self::Error> {
#extract_request_path
#extract_request_query
#extract_request_headers
#extract_request_body
Ok(Self { Ok(Self {
#parse_request_path #response_init_fields
#parse_request_query
#parse_request_headers
#parse_request_body
}) })
} } else {
} match <#error as #ruma_api_import::EndpointError>::try_from_response(response) {
Ok(err) => Err(#ruma_api_import::error::ServerError::Known(err).into()),
#[doc = #response_doc] Err(response_err) => {
#response_type Err(#ruma_api_import::error::ServerError::Unknown(response_err).into())
impl ::std::convert::TryFrom<Response> for #ruma_api_import::exports::http::Response<Vec<u8>> {
type Error = #ruma_api_import::error::IntoHttpError;
#[allow(unused_variables)]
fn try_from(response: Response) -> ::std::result::Result<Self, Self::Error> {
let mut resp_builder = #ruma_api_import::exports::http::Response::builder()
.header(#ruma_api_import::exports::http::header::CONTENT_TYPE, "application/json");
let mut headers =
resp_builder.headers_mut().expect("`http::ResponseBuilder` is in unusable state");
#serialize_response_headers
// This cannot fail because we parse each header value
// checking for errors as each value is inserted and
// we only allow keys from the `http::header` module.
let response = resp_builder.body(#body).unwrap();
Ok(response)
}
}
impl ::std::convert::TryFrom<#ruma_api_import::exports::http::Response<Vec<u8>>> for Response {
type Error = #ruma_api_import::error::FromHttpResponseError<#error>;
#[allow(unused_variables)]
fn try_from(
response: #ruma_api_import::exports::http::Response<Vec<u8>>,
) -> ::std::result::Result<Self, Self::Error> {
if response.status().as_u16() < 400 {
#extract_response_headers
#typed_response_body_decl
Ok(Self {
#response_init_fields
})
} else {
match <#error as #ruma_api_import::EndpointError>::try_from_response(response) {
Ok(err) => Err(#ruma_api_import::error::ServerError::Known(err).into()),
Err(response_err) => {
Err(#ruma_api_import::error::ServerError::Unknown(response_err).into())
}
} }
} }
} }
} }
}
#[doc = #metadata_doc] #[doc = #metadata_doc]
pub const METADATA: #ruma_api_import::Metadata = #ruma_api_import::Metadata { pub const METADATA: #ruma_api_import::Metadata = #ruma_api_import::Metadata {
description: #description, description: #description,
method: #ruma_api_import::exports::http::Method::#method, method: #ruma_api_import::exports::http::Method::#method,
name: #name, name: #name,
path: #path, path: #path,
rate_limited: #rate_limited, rate_limited: #rate_limited,
authentication: #ruma_api_import::AuthScheme::#authentication, authentication: #ruma_api_import::AuthScheme::#authentication,
};
impl #request_lifetimes #ruma_api_import::OutgoingRequest
for Request #request_lifetimes
{
type EndpointError = #error;
type IncomingResponse =
<Response as #ruma_api_import::exports::ruma_common::Outgoing>::Incoming;
#[doc = #metadata_doc]
const METADATA: #ruma_api_import::Metadata = self::METADATA;
#[allow(unused_mut, unused_variables)]
fn try_into_http_request(
self,
base_url: &::std::primitive::str,
access_token: ::std::option::Option<&str>,
) -> ::std::result::Result<
#ruma_api_import::exports::http::Request<Vec<u8>>,
#ruma_api_import::error::IntoHttpError,
> {
let metadata = self::METADATA;
let mut req_builder = #ruma_api_import::exports::http::Request::builder()
.method(#ruma_api_import::exports::http::Method::#method)
.uri(::std::format!(
"{}{}{}",
// FIXME: Once MSRV is >= 1.45.0, switch to
// base_url.strip_suffix('/').unwrap_or(base_url),
match base_url.as_bytes().last() {
Some(b'/') => &base_url[..base_url.len() - 1],
_ => base_url,
},
#request_path_string,
#request_query_string,
));
#header_kvs
let http_request = req_builder.body(#request_body)?;
Ok(http_request)
}
}
impl #ruma_api_import::IncomingRequest for #incoming_request_type {
type EndpointError = #error;
type OutgoingResponse = Response;
#[doc = #metadata_doc]
const METADATA: #ruma_api_import::Metadata = self::METADATA;
}
#non_auth_endpoint_impls
}; };
api.to_tokens(tokens); impl #request_lifetimes #ruma_api_import::OutgoingRequest
} for Request #request_lifetimes
} {
type EndpointError = #error;
type IncomingResponse =
<Response as #ruma_api_import::exports::ruma_common::Outgoing>::Incoming;
/// The entire `ruma_api!` macro structure directly as it appears in the source code.. #[doc = #metadata_doc]
pub struct RawApi { const METADATA: #ruma_api_import::Metadata = self::METADATA;
/// The `metadata` section of the macro.
pub metadata: Metadata,
/// The `request` section of the macro. #[allow(unused_mut, unused_variables)]
pub request: Request, fn try_into_http_request(
self,
base_url: &::std::primitive::str,
access_token: ::std::option::Option<&str>,
) -> ::std::result::Result<
#ruma_api_import::exports::http::Request<Vec<u8>>,
#ruma_api_import::error::IntoHttpError,
> {
let metadata = self::METADATA;
/// The `response` section of the macro. let mut req_builder = #ruma_api_import::exports::http::Request::builder()
pub response: Response, .method(#ruma_api_import::exports::http::Method::#method)
.uri(::std::format!(
"{}{}{}",
// FIXME: Once MSRV is >= 1.45.0, switch to
// base_url.strip_suffix('/').unwrap_or(base_url),
match base_url.as_bytes().last() {
Some(b'/') => &base_url[..base_url.len() - 1],
_ => base_url,
},
#request_path_string,
#request_query_string,
));
/// The `error` section of the macro. #header_kvs
pub error: Option<ErrorType>,
}
impl Parse for RawApi { let http_request = req_builder.body(#request_body)?;
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
Ok(Self { Ok(http_request)
metadata: input.parse()?, }
request: input.parse()?, }
response: input.parse()?,
error: input.parse().ok(), impl #ruma_api_import::IncomingRequest for #incoming_request_type {
}) type EndpointError = #error;
} type OutgoingResponse = Response;
#[doc = #metadata_doc]
const METADATA: #ruma_api_import::Metadata = self::METADATA;
}
#non_auth_endpoint_impls
})
} }
mod kw { mod kw {

View File

@ -13,22 +13,16 @@
#![allow(clippy::unknown_clippy_lints)] #![allow(clippy::unknown_clippy_lints)]
#![recursion_limit = "256"] #![recursion_limit = "256"]
use std::convert::TryFrom as _;
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::ToTokens;
use syn::parse_macro_input; use syn::parse_macro_input;
use self::api::{Api, RawApi}; use self::api::Api;
mod api; mod api;
mod util; mod util;
#[proc_macro] #[proc_macro]
pub fn ruma_api(input: TokenStream) -> TokenStream { pub fn ruma_api(input: TokenStream) -> TokenStream {
let raw_api = parse_macro_input!(input as RawApi); let api = parse_macro_input!(input as Api);
match Api::try_from(raw_api) { api::expand_all(api).unwrap_or_else(|err| err.to_compile_error()).into()
Ok(api) => api.into_token_stream().into(),
Err(err) => err.to_compile_error().into(),
}
} }