436 lines
15 KiB
Rust
436 lines
15 KiB
Rust
//! Details of the `ruma_api` procedural macro.
|
|
|
|
use proc_macro2::TokenStream;
|
|
use quote::{quote, ToTokens};
|
|
use syn::{
|
|
parse::{Parse, ParseStream},
|
|
Token, Type,
|
|
};
|
|
|
|
pub(crate) mod attribute;
|
|
pub(crate) mod metadata;
|
|
pub(crate) mod request;
|
|
pub(crate) mod response;
|
|
|
|
use self::{metadata::Metadata, request::Request, response::Response};
|
|
use crate::util;
|
|
|
|
/// The result of processing the `ruma_api` macro, ready for output back to source code.
|
|
pub struct Api {
|
|
/// The `metadata` section of the macro.
|
|
metadata: Metadata,
|
|
|
|
/// The `request` section of the macro.
|
|
request: Request,
|
|
|
|
/// The `response` section of the macro.
|
|
response: Response,
|
|
|
|
/// The `error` section of the macro.
|
|
error_ty: TokenStream,
|
|
}
|
|
|
|
impl Parse for Api {
|
|
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
|
|
let ruma_api = util::import_ruma_api();
|
|
|
|
let metadata: Metadata = input.parse()?;
|
|
let request: Request = input.parse()?;
|
|
let response: Response = input.parse()?;
|
|
let error_ty = match input.parse::<ErrorType>() {
|
|
Ok(err) => err.ty.to_token_stream(),
|
|
Err(_) => quote! { #ruma_api::error::Void },
|
|
};
|
|
|
|
let newtype_body_field = request.newtype_body_field();
|
|
if metadata.method == "GET" && (request.has_body_fields() || newtype_body_field.is_some()) {
|
|
let mut combined_error: Option<syn::Error> = None;
|
|
let mut add_error = |field| {
|
|
let error = syn::Error::new_spanned(field, "GET endpoints can't have body fields");
|
|
if let Some(combined_error_ref) = &mut combined_error {
|
|
combined_error_ref.combine(error);
|
|
} else {
|
|
combined_error = Some(error);
|
|
}
|
|
};
|
|
|
|
for field in request.body_fields() {
|
|
add_error(field);
|
|
}
|
|
|
|
if let Some(field) = newtype_body_field {
|
|
add_error(field);
|
|
}
|
|
|
|
return Err(combined_error.unwrap());
|
|
}
|
|
|
|
Ok(Self { metadata, request, response, error_ty })
|
|
}
|
|
}
|
|
|
|
pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
|
|
// Guarantee `ruma_api` is available and named something we can refer to.
|
|
let ruma_api = util::import_ruma_api();
|
|
let http = quote! { #ruma_api::exports::http };
|
|
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
|
|
let serde_json = quote! { #ruma_api::exports::serde_json };
|
|
|
|
let description = &api.metadata.description;
|
|
let method = &api.metadata.method;
|
|
// We don't (currently) use this literal as a literal in the generated code. Instead we just
|
|
// put it into doc comments, for which the span information is irrelevant. So we can work
|
|
// with only the literal's value from here on.
|
|
let name = &api.metadata.name.value();
|
|
let path = &api.metadata.path;
|
|
let rate_limited: TokenStream = api
|
|
.metadata
|
|
.rate_limited
|
|
.iter()
|
|
.map(|r| {
|
|
let attrs = &r.attrs;
|
|
let value = &r.value;
|
|
quote! {
|
|
#( #attrs )*
|
|
rate_limited: #value,
|
|
}
|
|
})
|
|
.collect();
|
|
let authentication: TokenStream = api
|
|
.metadata
|
|
.authentication
|
|
.iter()
|
|
.map(|r| {
|
|
let attrs = &r.attrs;
|
|
let value = &r.value;
|
|
quote! {
|
|
#( #attrs )*
|
|
authentication: #ruma_api::AuthScheme::#value,
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
let request_type = &api.request;
|
|
let response_type = &api.response;
|
|
|
|
let incoming_request_type =
|
|
if api.request.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) };
|
|
|
|
let extract_request_path = if api.request.has_path_fields() {
|
|
quote! {
|
|
let path_segments: ::std::vec::Vec<&::std::primitive::str> =
|
|
request.uri().path()[1..].split('/').collect();
|
|
}
|
|
} else {
|
|
TokenStream::new()
|
|
};
|
|
|
|
let (request_path_string, parse_request_path) =
|
|
util::request_path_string_and_parse(&api.request, &api.metadata, &ruma_api);
|
|
|
|
let request_query_string = util::build_query_string(&api.request, &ruma_api);
|
|
let extract_request_query = util::extract_request_query(&api.request, &ruma_api);
|
|
|
|
let parse_request_query = if let Some(field) = api.request.query_map_field() {
|
|
let field_name = field.ident.as_ref().expect("expected field to have an identifier");
|
|
|
|
quote! {
|
|
#field_name: request_query,
|
|
}
|
|
} else {
|
|
api.request.request_init_query_fields()
|
|
};
|
|
|
|
let mut header_kvs = api.request.append_header_kvs();
|
|
for auth in &api.metadata.authentication {
|
|
if auth.value == "AccessToken" {
|
|
let attrs = &auth.attrs;
|
|
header_kvs.extend(quote! {
|
|
#( #attrs )*
|
|
req_headers.insert(
|
|
#http::header::AUTHORIZATION,
|
|
#http::header::HeaderValue::from_str(
|
|
&::std::format!(
|
|
"Bearer {}",
|
|
access_token.ok_or(
|
|
#ruma_api::error::IntoHttpError::NeedsAuthentication
|
|
)?
|
|
)
|
|
)?
|
|
);
|
|
});
|
|
}
|
|
}
|
|
|
|
let extract_request_headers = if api.request.has_header_fields() {
|
|
quote! {
|
|
let headers = request.headers();
|
|
}
|
|
} else {
|
|
TokenStream::new()
|
|
};
|
|
|
|
let extract_request_body =
|
|
if api.request.has_body_fields() || api.request.newtype_body_field().is_some() {
|
|
let body_lifetimes = if api.request.has_body_lifetimes() {
|
|
// duplicate the anonymous lifetime as many times as needed
|
|
let lifetimes =
|
|
std::iter::repeat(quote! { '_ }).take(api.request.body_lifetime_count());
|
|
quote! { < #( #lifetimes, )* >}
|
|
} else {
|
|
TokenStream::new()
|
|
};
|
|
quote! {
|
|
let request_body: <
|
|
RequestBody #body_lifetimes
|
|
as #ruma_serde::Outgoing
|
|
>::Incoming = {
|
|
// If the request body is completely empty, pretend it is an empty JSON object
|
|
// instead. This allows requests with only optional body parameters to be
|
|
// deserialized in that case.
|
|
let json = match request.body().as_slice() {
|
|
b"" => b"{}",
|
|
body => body,
|
|
};
|
|
|
|
#ruma_api::try_deserialize!(request, #serde_json::from_slice(json))
|
|
};
|
|
}
|
|
} else {
|
|
TokenStream::new()
|
|
};
|
|
|
|
let parse_request_headers = if api.request.has_header_fields() {
|
|
api.request.parse_headers_from_request()
|
|
} else {
|
|
TokenStream::new()
|
|
};
|
|
|
|
let request_body = util::build_request_body(&api.request, &ruma_api);
|
|
let parse_request_body = util::parse_request_body(&api.request);
|
|
|
|
let extract_response_headers = if api.response.has_header_fields() {
|
|
quote! {
|
|
let mut headers = response.headers().clone();
|
|
}
|
|
} else {
|
|
TokenStream::new()
|
|
};
|
|
|
|
let typed_response_body_decl =
|
|
if api.response.has_body_fields() || api.response.newtype_body_field().is_some() {
|
|
quote! {
|
|
let response_body: <
|
|
ResponseBody
|
|
as #ruma_serde::Outgoing
|
|
>::Incoming = {
|
|
// If the reponse body is completely empty, pretend it is an empty JSON object
|
|
// instead. This allows reponses with only optional body parameters to be
|
|
// deserialized in that case.
|
|
let json = match response.body().as_slice() {
|
|
b"" => b"{}",
|
|
body => body,
|
|
};
|
|
|
|
#ruma_api::try_deserialize!(
|
|
response,
|
|
#serde_json::from_slice(json),
|
|
)
|
|
};
|
|
}
|
|
} else {
|
|
TokenStream::new()
|
|
};
|
|
|
|
let response_init_fields = api.response.init_fields();
|
|
let serialize_response_headers = api.response.apply_header_fields();
|
|
|
|
let body = api.response.to_body();
|
|
|
|
let metadata_doc = format!("Metadata for the `{}` API endpoint.", name);
|
|
let request_doc =
|
|
format!("Data for a request to the `{}` API endpoint.\n\n{}", name, description.value());
|
|
let response_doc = format!("Data in the response from the `{}` API endpoint.", name);
|
|
|
|
let error = &api.error_ty;
|
|
let request_lifetimes = api.request.combine_lifetimes();
|
|
|
|
let non_auth_endpoint_impls: TokenStream = api
|
|
.metadata
|
|
.authentication
|
|
.iter()
|
|
.map(|auth| {
|
|
if auth.value != "None" {
|
|
TokenStream::new()
|
|
} else {
|
|
let attrs = &auth.attrs;
|
|
quote! {
|
|
#( #attrs )*
|
|
#[automatically_derived]
|
|
#[cfg(feature = "client")]
|
|
impl #request_lifetimes #ruma_api::OutgoingNonAuthRequest
|
|
for Request #request_lifetimes
|
|
{}
|
|
|
|
#( #attrs )*
|
|
#[automatically_derived]
|
|
#[cfg(feature = "server")]
|
|
impl #ruma_api::IncomingNonAuthRequest for #incoming_request_type {}
|
|
}
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
Ok(quote! {
|
|
#[doc = #request_doc]
|
|
#request_type
|
|
|
|
#[doc = #response_doc]
|
|
#response_type
|
|
|
|
#[automatically_derived]
|
|
#[cfg(feature = "server")]
|
|
impl ::std::convert::TryFrom<Response> for #http::Response<Vec<u8>> {
|
|
type Error = #ruma_api::error::IntoHttpError;
|
|
|
|
fn try_from(response: Response) -> ::std::result::Result<Self, Self::Error> {
|
|
let mut resp_builder = #http::Response::builder()
|
|
.header(#http::header::CONTENT_TYPE, "application/json");
|
|
|
|
let mut headers = resp_builder
|
|
.headers_mut()
|
|
.expect("`http::ResponseBuilder` is in unusable state");
|
|
#serialize_response_headers
|
|
|
|
// This cannot fail because we parse each header value
|
|
// checking for errors as each value is inserted and
|
|
// we only allow keys from the `http::header` module.
|
|
let response = resp_builder.body(#body).unwrap();
|
|
Ok(response)
|
|
}
|
|
}
|
|
|
|
#[automatically_derived]
|
|
#[cfg(feature = "client")]
|
|
impl ::std::convert::TryFrom<#http::Response<Vec<u8>>> for Response {
|
|
type Error = #ruma_api::error::FromHttpResponseError<#error>;
|
|
|
|
fn try_from(
|
|
response: #http::Response<Vec<u8>>,
|
|
) -> ::std::result::Result<Self, Self::Error> {
|
|
if response.status().as_u16() < 400 {
|
|
#extract_response_headers
|
|
|
|
#typed_response_body_decl
|
|
|
|
Ok(Self {
|
|
#response_init_fields
|
|
})
|
|
} else {
|
|
match <#error as #ruma_api::EndpointError>::try_from_response(response) {
|
|
Ok(err) => Err(#ruma_api::error::ServerError::Known(err).into()),
|
|
Err(response_err) => {
|
|
Err(#ruma_api::error::ServerError::Unknown(response_err).into())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[doc = #metadata_doc]
|
|
pub const METADATA: #ruma_api::Metadata = #ruma_api::Metadata {
|
|
description: #description,
|
|
method: #http::Method::#method,
|
|
name: #name,
|
|
path: #path,
|
|
#rate_limited
|
|
#authentication
|
|
};
|
|
|
|
#[automatically_derived]
|
|
#[cfg(feature = "client")]
|
|
impl #request_lifetimes #ruma_api::OutgoingRequest for Request #request_lifetimes {
|
|
type EndpointError = #error;
|
|
type IncomingResponse = <Response as #ruma_serde::Outgoing>::Incoming;
|
|
|
|
#[doc = #metadata_doc]
|
|
const METADATA: #ruma_api::Metadata = self::METADATA;
|
|
|
|
fn try_into_http_request(
|
|
self,
|
|
base_url: &::std::primitive::str,
|
|
access_token: ::std::option::Option<&str>,
|
|
) -> ::std::result::Result<#http::Request<Vec<u8>>, #ruma_api::error::IntoHttpError> {
|
|
let metadata = self::METADATA;
|
|
|
|
let mut req_builder = #http::Request::builder()
|
|
.method(#http::Method::#method)
|
|
.uri(::std::format!(
|
|
"{}{}{}",
|
|
base_url.strip_suffix('/').unwrap_or(base_url),
|
|
#request_path_string,
|
|
#request_query_string,
|
|
))
|
|
.header(#ruma_api::exports::http::header::CONTENT_TYPE, "application/json");
|
|
|
|
let mut req_headers = req_builder
|
|
.headers_mut()
|
|
.expect("`http::RequestBuilder` is in unusable state");
|
|
|
|
#header_kvs
|
|
|
|
let http_request = req_builder.body(#request_body)?;
|
|
|
|
Ok(http_request)
|
|
}
|
|
}
|
|
|
|
#[automatically_derived]
|
|
#[cfg(feature = "server")]
|
|
impl #ruma_api::IncomingRequest for #incoming_request_type {
|
|
type EndpointError = #error;
|
|
type OutgoingResponse = Response;
|
|
|
|
#[doc = #metadata_doc]
|
|
const METADATA: #ruma_api::Metadata = self::METADATA;
|
|
|
|
fn try_from_http_request(
|
|
request: #http::Request<Vec<u8>>
|
|
) -> ::std::result::Result<Self, #ruma_api::error::FromHttpRequestError> {
|
|
#extract_request_path
|
|
#extract_request_query
|
|
#extract_request_headers
|
|
#extract_request_body
|
|
|
|
Ok(Self {
|
|
#parse_request_path
|
|
#parse_request_query
|
|
#parse_request_headers
|
|
#parse_request_body
|
|
})
|
|
}
|
|
}
|
|
|
|
#non_auth_endpoint_impls
|
|
})
|
|
}
|
|
|
|
mod kw {
|
|
syn::custom_keyword!(error);
|
|
}
|
|
|
|
pub struct ErrorType {
|
|
pub error_kw: kw::error,
|
|
pub ty: Type,
|
|
}
|
|
|
|
impl Parse for ErrorType {
|
|
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
|
|
let error_kw = input.parse::<kw::error>()?;
|
|
input.parse::<Token![:]>()?;
|
|
let ty = input.parse()?;
|
|
|
|
Ok(Self { error_kw, ty })
|
|
}
|
|
}
|