diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index 0d2f0a6c..ba72bff7 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -1,7 +1,5 @@ //! Details of the `ruma_api` procedural macro. -use std::convert::TryFrom; - use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{ @@ -36,29 +34,23 @@ pub struct Api { response: Response, /// The `error` section of the macro. - error: TokenStream, + error_ty: TokenStream, } -impl TryFrom for Api { - type Error = syn::Error; - - fn try_from(raw_api: RawApi) -> syn::Result { +impl Parse for Api { + fn parse(input: ParseStream<'_>) -> syn::Result { let import_path = util::import_ruma_api(); - let res = Self { - metadata: raw_api.metadata, - request: raw_api.request, - response: raw_api.response, - error: match raw_api.error { - Some(raw_err) => raw_err.ty.to_token_stream(), - None => quote! { #import_path::error::Void }, - }, + let metadata: Metadata = input.parse()?; + let request: Request = input.parse()?; + let response: Response = input.parse()?; + let error_ty = match input.parse::() { + Ok(err) => err.ty.to_token_stream(), + Err(_) => quote! { #import_path::error::Void }, }; - let newtype_body_field = res.request.newtype_body_field(); - if res.metadata.method == "GET" - && (res.request.has_body_fields() || newtype_body_field.is_some()) - { + let newtype_body_field = request.newtype_body_field(); + if metadata.method == "GET" && (request.has_body_fields() || newtype_body_field.is_some()) { let mut combined_error: Option = None; let mut add_error = |field| { let error = syn::Error::new_spanned(field, "GET endpoints can't have body fields"); @@ -69,7 +61,7 @@ impl TryFrom for Api { } }; - for field in res.request.body_fields() { + for field in request.body_fields() { add_error(field); } @@ -77,95 +69,90 @@ impl TryFrom for Api { add_error(field); } - Err(combined_error.unwrap()) - } else { - Ok(res) + return Err(combined_error.unwrap()); } + + Ok(Self { metadata, request, response, error_ty }) } } -impl ToTokens for Api { - fn to_tokens(&self, tokens: &mut TokenStream) { - // Guarantee `ruma_api` is available and named something we can refer to. - let ruma_api_import = util::import_ruma_api(); +pub fn expand_all(api: Api) -> syn::Result { + // Guarantee `ruma_api` is available and named something we can refer to. + let ruma_api_import = util::import_ruma_api(); - let description = &self.metadata.description; - let method = &self.metadata.method; - // We don't (currently) use this literal as a literal in the generated code. Instead we just - // put it into doc comments, for which the span information is irrelevant. So we can work - // with only the literal's value from here on. - let name = &self.metadata.name.value(); - let path = &self.metadata.path; - let rate_limited = &self.metadata.rate_limited; - let authentication = &self.metadata.authentication; + let description = &api.metadata.description; + let method = &api.metadata.method; + // We don't (currently) use this literal as a literal in the generated code. Instead we just + // put it into doc comments, for which the span information is irrelevant. So we can work + // with only the literal's value from here on. + let name = &api.metadata.name.value(); + let path = &api.metadata.path; + let rate_limited = &api.metadata.rate_limited; + let authentication = &api.metadata.authentication; - let request_type = &self.request; - let response_type = &self.response; + let request_type = &api.request; + let response_type = &api.response; - let incoming_request_type = if self.request.contains_lifetimes() { - quote!(IncomingRequest) - } else { - quote!(Request) - }; + let incoming_request_type = + if api.request.contains_lifetimes() { quote!(IncomingRequest) } else { quote!(Request) }; - let extract_request_path = if self.request.has_path_fields() { - quote! { - let path_segments: ::std::vec::Vec<&::std::primitive::str> = - request.uri().path()[1..].split('/').collect(); - } - } else { - TokenStream::new() - }; - - let (request_path_string, parse_request_path) = - util::request_path_string_and_parse(&self.request, &self.metadata, &ruma_api_import); - - let request_query_string = util::build_query_string(&self.request, &ruma_api_import); - - let extract_request_query = util::extract_request_query(&self.request, &ruma_api_import); - - let parse_request_query = if let Some(field) = self.request.query_map_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - - quote! { - #field_name: request_query, - } - } else { - self.request.request_init_query_fields() - }; - - let mut header_kvs = self.request.append_header_kvs(); - if authentication == "AccessToken" { - header_kvs.extend(quote! { - req_builder = req_builder.header( - #ruma_api_import::exports::http::header::AUTHORIZATION, - #ruma_api_import::exports::http::header::HeaderValue::from_str( - &::std::format!( - "Bearer {}", - access_token.ok_or( - #ruma_api_import::error::IntoHttpError::NeedsAuthentication - )? - ) - )? - ); - }); + let extract_request_path = if api.request.has_path_fields() { + quote! { + let path_segments: ::std::vec::Vec<&::std::primitive::str> = + request.uri().path()[1..].split('/').collect(); } + } else { + TokenStream::new() + }; - let extract_request_headers = if self.request.has_header_fields() { - quote! { - let headers = request.headers(); - } - } else { - TokenStream::new() - }; + let (request_path_string, parse_request_path) = + util::request_path_string_and_parse(&api.request, &api.metadata, &ruma_api_import); - let extract_request_body = if self.request.has_body_fields() - || self.request.newtype_body_field().is_some() - { - let body_lifetimes = if self.request.has_body_lifetimes() { + let request_query_string = util::build_query_string(&api.request, &ruma_api_import); + + let extract_request_query = util::extract_request_query(&api.request, &ruma_api_import); + + let parse_request_query = if let Some(field) = api.request.query_map_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + + quote! { + #field_name: request_query, + } + } else { + api.request.request_init_query_fields() + }; + + let mut header_kvs = api.request.append_header_kvs(); + if authentication == "AccessToken" { + header_kvs.extend(quote! { + req_builder = req_builder.header( + #ruma_api_import::exports::http::header::AUTHORIZATION, + #ruma_api_import::exports::http::header::HeaderValue::from_str( + &::std::format!( + "Bearer {}", + access_token.ok_or( + #ruma_api_import::error::IntoHttpError::NeedsAuthentication + )? + ) + )? + ); + }); + } + + let extract_request_headers = if api.request.has_header_fields() { + quote! { + let headers = request.headers(); + } + } else { + TokenStream::new() + }; + + let extract_request_body = + if api.request.has_body_fields() || api.request.newtype_body_field().is_some() { + let body_lifetimes = if api.request.has_body_lifetimes() { // duplicate the anonymous lifetime as many times as needed let lifetimes = - std::iter::repeat(quote! { '_ }).take(self.request.body_lifetime_count()); + std::iter::repeat(quote! { '_ }).take(api.request.body_lifetime_count()); quote! { < #( #lifetimes, )* >} } else { TokenStream::new() @@ -184,239 +171,207 @@ impl ToTokens for Api { TokenStream::new() }; - let parse_request_headers = if self.request.has_header_fields() { - self.request.parse_headers_from_request() - } else { - TokenStream::new() - }; + let parse_request_headers = if api.request.has_header_fields() { + api.request.parse_headers_from_request() + } else { + TokenStream::new() + }; - let request_body = util::build_request_body(&self.request, &ruma_api_import); + let request_body = util::build_request_body(&api.request, &ruma_api_import); - let parse_request_body = util::parse_request_body(&self.request); + let parse_request_body = util::parse_request_body(&api.request); - let extract_response_headers = if self.response.has_header_fields() { - quote! { - let mut headers = response.headers().clone(); - } - } else { - TokenStream::new() - }; + let extract_response_headers = if api.response.has_header_fields() { + quote! { + let mut headers = response.headers().clone(); + } + } else { + TokenStream::new() + }; - let typed_response_body_decl = if self.response.has_body_fields() - || self.response.newtype_body_field().is_some() + let typed_response_body_decl = if api.response.has_body_fields() + || api.response.newtype_body_field().is_some() + { + quote! { + let response_body: < + ResponseBody + as #ruma_api_import::exports::ruma_common::Outgoing + >::Incoming = + #ruma_api_import::try_deserialize!( + response, + #ruma_api_import::exports::serde_json::from_slice(response.body().as_slice()), + ); + } + } else { + TokenStream::new() + }; + + let response_init_fields = api.response.init_fields(); + + let serialize_response_headers = api.response.apply_header_fields(); + + let body = api.response.to_body(); + + let metadata_doc = format!("Metadata for the `{}` API endpoint.", name); + let request_doc = + format!("Data for a request to the `{}` API endpoint.\n\n{}", name, description.value()); + let response_doc = format!("Data in the response from the `{}` API endpoint.", name); + + let error = &api.error_ty; + + let request_lifetimes = api.request.combine_lifetimes(); + + let non_auth_endpoint_impls = if authentication != "None" { + TokenStream::new() + } else { + quote! { + impl #request_lifetimes #ruma_api_import::OutgoingNonAuthRequest + for Request #request_lifetimes + {} + + impl #ruma_api_import::IncomingNonAuthRequest for #incoming_request_type {} + } + }; + + Ok(quote! { + #[doc = #request_doc] + #request_type + + impl ::std::convert::TryFrom<#ruma_api_import::exports::http::Request>> + for #incoming_request_type { - quote! { - let response_body: < - ResponseBody - as #ruma_api_import::exports::ruma_common::Outgoing - >::Incoming = - #ruma_api_import::try_deserialize!( - response, - #ruma_api_import::exports::serde_json::from_slice(response.body().as_slice()), - ); + type Error = #ruma_api_import::error::FromHttpRequestError; + + #[allow(unused_variables)] + fn try_from( + request: #ruma_api_import::exports::http::Request> + ) -> ::std::result::Result { + #extract_request_path + #extract_request_query + #extract_request_headers + #extract_request_body + + Ok(Self { + #parse_request_path + #parse_request_query + #parse_request_headers + #parse_request_body + }) } - } else { - TokenStream::new() - }; + } - let response_init_fields = self.response.init_fields(); + #[doc = #response_doc] + #response_type - let serialize_response_headers = self.response.apply_header_fields(); + impl ::std::convert::TryFrom for #ruma_api_import::exports::http::Response> { + type Error = #ruma_api_import::error::IntoHttpError; - let body = self.response.to_body(); + #[allow(unused_variables)] + fn try_from(response: Response) -> ::std::result::Result { + let mut resp_builder = #ruma_api_import::exports::http::Response::builder() + .header(#ruma_api_import::exports::http::header::CONTENT_TYPE, "application/json"); - let metadata_doc = format!("Metadata for the `{}` API endpoint.", name); - let request_doc = format!( - "Data for a request to the `{}` API endpoint.\n\n{}", - name, - description.value() - ); - let response_doc = format!("Data in the response from the `{}` API endpoint.", name); + let mut headers = + resp_builder.headers_mut().expect("`http::ResponseBuilder` is in unusable state"); + #serialize_response_headers - let error = &self.error; - - let request_lifetimes = self.request.combine_lifetimes(); - - let non_auth_endpoint_impls = if authentication != "None" { - TokenStream::new() - } else { - quote! { - impl #request_lifetimes #ruma_api_import::OutgoingNonAuthRequest - for Request #request_lifetimes - {} - - impl #ruma_api_import::IncomingNonAuthRequest for #incoming_request_type {} + // This cannot fail because we parse each header value + // checking for errors as each value is inserted and + // we only allow keys from the `http::header` module. + let response = resp_builder.body(#body).unwrap(); + Ok(response) } - }; + } - let api = quote! { - #[doc = #request_doc] - #request_type + impl ::std::convert::TryFrom<#ruma_api_import::exports::http::Response>> for Response { + type Error = #ruma_api_import::error::FromHttpResponseError<#error>; - impl ::std::convert::TryFrom<#ruma_api_import::exports::http::Request>> - for #incoming_request_type - { - type Error = #ruma_api_import::error::FromHttpRequestError; + #[allow(unused_variables)] + fn try_from( + response: #ruma_api_import::exports::http::Response>, + ) -> ::std::result::Result { + if response.status().as_u16() < 400 { + #extract_response_headers - #[allow(unused_variables)] - fn try_from( - request: #ruma_api_import::exports::http::Request> - ) -> ::std::result::Result { - #extract_request_path - #extract_request_query - #extract_request_headers - #extract_request_body + #typed_response_body_decl Ok(Self { - #parse_request_path - #parse_request_query - #parse_request_headers - #parse_request_body + #response_init_fields }) - } - } - - #[doc = #response_doc] - #response_type - - impl ::std::convert::TryFrom for #ruma_api_import::exports::http::Response> { - type Error = #ruma_api_import::error::IntoHttpError; - - #[allow(unused_variables)] - fn try_from(response: Response) -> ::std::result::Result { - let mut resp_builder = #ruma_api_import::exports::http::Response::builder() - .header(#ruma_api_import::exports::http::header::CONTENT_TYPE, "application/json"); - - let mut headers = - resp_builder.headers_mut().expect("`http::ResponseBuilder` is in unusable state"); - #serialize_response_headers - - // This cannot fail because we parse each header value - // checking for errors as each value is inserted and - // we only allow keys from the `http::header` module. - let response = resp_builder.body(#body).unwrap(); - Ok(response) - } - } - - impl ::std::convert::TryFrom<#ruma_api_import::exports::http::Response>> for Response { - type Error = #ruma_api_import::error::FromHttpResponseError<#error>; - - #[allow(unused_variables)] - fn try_from( - response: #ruma_api_import::exports::http::Response>, - ) -> ::std::result::Result { - if response.status().as_u16() < 400 { - #extract_response_headers - - #typed_response_body_decl - - Ok(Self { - #response_init_fields - }) - } else { - match <#error as #ruma_api_import::EndpointError>::try_from_response(response) { - Ok(err) => Err(#ruma_api_import::error::ServerError::Known(err).into()), - Err(response_err) => { - Err(#ruma_api_import::error::ServerError::Unknown(response_err).into()) - } + } else { + match <#error as #ruma_api_import::EndpointError>::try_from_response(response) { + Ok(err) => Err(#ruma_api_import::error::ServerError::Known(err).into()), + Err(response_err) => { + Err(#ruma_api_import::error::ServerError::Unknown(response_err).into()) } } } } + } - #[doc = #metadata_doc] - pub const METADATA: #ruma_api_import::Metadata = #ruma_api_import::Metadata { - description: #description, - method: #ruma_api_import::exports::http::Method::#method, - name: #name, - path: #path, - rate_limited: #rate_limited, - authentication: #ruma_api_import::AuthScheme::#authentication, - }; - - impl #request_lifetimes #ruma_api_import::OutgoingRequest - for Request #request_lifetimes - { - type EndpointError = #error; - type IncomingResponse = - ::Incoming; - - #[doc = #metadata_doc] - const METADATA: #ruma_api_import::Metadata = self::METADATA; - - #[allow(unused_mut, unused_variables)] - fn try_into_http_request( - self, - base_url: &::std::primitive::str, - access_token: ::std::option::Option<&str>, - ) -> ::std::result::Result< - #ruma_api_import::exports::http::Request>, - #ruma_api_import::error::IntoHttpError, - > { - let metadata = self::METADATA; - - let mut req_builder = #ruma_api_import::exports::http::Request::builder() - .method(#ruma_api_import::exports::http::Method::#method) - .uri(::std::format!( - "{}{}{}", - // FIXME: Once MSRV is >= 1.45.0, switch to - // base_url.strip_suffix('/').unwrap_or(base_url), - match base_url.as_bytes().last() { - Some(b'/') => &base_url[..base_url.len() - 1], - _ => base_url, - }, - #request_path_string, - #request_query_string, - )); - - #header_kvs - - let http_request = req_builder.body(#request_body)?; - - Ok(http_request) - } - } - - impl #ruma_api_import::IncomingRequest for #incoming_request_type { - type EndpointError = #error; - type OutgoingResponse = Response; - - #[doc = #metadata_doc] - const METADATA: #ruma_api_import::Metadata = self::METADATA; - } - - #non_auth_endpoint_impls + #[doc = #metadata_doc] + pub const METADATA: #ruma_api_import::Metadata = #ruma_api_import::Metadata { + description: #description, + method: #ruma_api_import::exports::http::Method::#method, + name: #name, + path: #path, + rate_limited: #rate_limited, + authentication: #ruma_api_import::AuthScheme::#authentication, }; - api.to_tokens(tokens); - } -} + impl #request_lifetimes #ruma_api_import::OutgoingRequest + for Request #request_lifetimes + { + type EndpointError = #error; + type IncomingResponse = + ::Incoming; -/// The entire `ruma_api!` macro structure directly as it appears in the source code.. -pub struct RawApi { - /// The `metadata` section of the macro. - pub metadata: Metadata, + #[doc = #metadata_doc] + const METADATA: #ruma_api_import::Metadata = self::METADATA; - /// The `request` section of the macro. - pub request: Request, + #[allow(unused_mut, unused_variables)] + fn try_into_http_request( + self, + base_url: &::std::primitive::str, + access_token: ::std::option::Option<&str>, + ) -> ::std::result::Result< + #ruma_api_import::exports::http::Request>, + #ruma_api_import::error::IntoHttpError, + > { + let metadata = self::METADATA; - /// The `response` section of the macro. - pub response: Response, + let mut req_builder = #ruma_api_import::exports::http::Request::builder() + .method(#ruma_api_import::exports::http::Method::#method) + .uri(::std::format!( + "{}{}{}", + // FIXME: Once MSRV is >= 1.45.0, switch to + // base_url.strip_suffix('/').unwrap_or(base_url), + match base_url.as_bytes().last() { + Some(b'/') => &base_url[..base_url.len() - 1], + _ => base_url, + }, + #request_path_string, + #request_query_string, + )); - /// The `error` section of the macro. - pub error: Option, -} + #header_kvs -impl Parse for RawApi { - fn parse(input: ParseStream<'_>) -> syn::Result { - Ok(Self { - metadata: input.parse()?, - request: input.parse()?, - response: input.parse()?, - error: input.parse().ok(), - }) - } + let http_request = req_builder.body(#request_body)?; + + Ok(http_request) + } + } + + impl #ruma_api_import::IncomingRequest for #incoming_request_type { + type EndpointError = #error; + type OutgoingResponse = Response; + + #[doc = #metadata_doc] + const METADATA: #ruma_api_import::Metadata = self::METADATA; + } + + #non_auth_endpoint_impls + }) } mod kw { diff --git a/ruma-api-macros/src/lib.rs b/ruma-api-macros/src/lib.rs index ff0886b8..baec0a4f 100644 --- a/ruma-api-macros/src/lib.rs +++ b/ruma-api-macros/src/lib.rs @@ -13,22 +13,16 @@ #![allow(clippy::unknown_clippy_lints)] #![recursion_limit = "256"] -use std::convert::TryFrom as _; - use proc_macro::TokenStream; -use quote::ToTokens; use syn::parse_macro_input; -use self::api::{Api, RawApi}; +use self::api::Api; mod api; mod util; #[proc_macro] pub fn ruma_api(input: TokenStream) -> TokenStream { - let raw_api = parse_macro_input!(input as RawApi); - match Api::try_from(raw_api) { - Ok(api) => api.into_token_stream().into(), - Err(err) => err.to_compile_error().into(), - } + let api = parse_macro_input!(input as Api); + api::expand_all(api).unwrap_or_else(|err| err.to_compile_error()).into() }