diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index 90a874c7..f352cd5c 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -2,14 +2,13 @@ use std::convert::{TryFrom, TryInto as _}; -use proc_macro2::{Span, TokenStream}; -use proc_macro_crate::crate_name; +use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{ braced, parse::{Parse, ParseStream}, spanned::Spanned, - Field, FieldValue, Ident, Token, Type, + Field, FieldValue, Token, Type, }; pub(crate) mod attribute; @@ -27,18 +26,6 @@ pub fn strip_serde_attrs(field: &Field) -> Field { field } -pub fn import_ruma_api() -> TokenStream { - if let Ok(possibly_renamed) = crate_name("ruma-api") { - let import = Ident::new(&possibly_renamed, Span::call_site()); - quote! { ::#import } - } else if let Ok(possibly_renamed) = crate_name("ruma") { - let import = Ident::new(&possibly_renamed, Span::call_site()); - quote! { ::#import::api } - } else { - quote! { ::ruma_api } - } -} - /// The result of processing the `ruma_api` macro, ready for output back to source code. pub struct Api { /// The `metadata` section of the macro. @@ -48,20 +35,23 @@ pub struct Api { /// The `response` section of the macro. response: Response, /// The `error` section of the macro. - error: Type, + error: TokenStream, } impl TryFrom for Api { type Error = syn::Error; fn try_from(raw_api: RawApi) -> syn::Result { + let import_path = util::import_ruma_api(); + let res = Self { metadata: raw_api.metadata.try_into()?, request: raw_api.request.try_into()?, response: raw_api.response.try_into()?, - error: raw_api - .error - .map_or(syn::parse_str::("ruma_api::error::Void").unwrap(), |err| err.ty), + error: match raw_api.error { + Some(raw_err) => raw_err.ty.to_token_stream(), + None => quote! { #import_path::error::Void }, + }, }; let newtype_body_field = res.request.newtype_body_field(); @@ -96,7 +86,7 @@ impl TryFrom for Api { impl ToTokens for Api { fn to_tokens(&self, tokens: &mut TokenStream) { // Guarantee `ruma_api` is available and named something we can refer to. - let ruma_api_import = import_ruma_api(); + let ruma_api_import = util::import_ruma_api(); let description = &self.metadata.description; let method = &self.metadata.method; @@ -127,11 +117,11 @@ impl ToTokens for Api { }; let (request_path_string, parse_request_path) = - util::request_path_string_and_parse(&self.request, &self.metadata); + util::request_path_string_and_parse(&self.request, &self.metadata, &ruma_api_import); - let request_query_string = util::build_query_string(&self.request); + let request_query_string = util::build_query_string(&self.request, &ruma_api_import); - let extract_request_query = util::extract_request_query(&self.request); + let extract_request_query = util::extract_request_query(&self.request, &ruma_api_import); let parse_request_query = if let Some(field) = self.request.query_map_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); @@ -194,7 +184,7 @@ impl ToTokens for Api { TokenStream::new() }; - let request_body = util::build_request_body(&self.request); + let request_body = util::build_request_body(&self.request, &ruma_api_import); let parse_request_body = util::parse_request_body(&self.request); diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index c2b576df..e41f25c3 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -9,7 +9,7 @@ use syn::{spanned::Spanned, Field, Ident, Lifetime}; use crate::{ api::{ attribute::{Meta, MetaNameValue}, - import_ruma_api, strip_serde_attrs, RawRequest, + strip_serde_attrs, RawRequest, }, util, }; @@ -37,7 +37,7 @@ pub struct Request { impl Request { /// Produces code to add necessary HTTP headers to an `http::Request`. pub fn append_header_kvs(&self) -> Vec { - let ruma_api = &self.ruma_api_import; + let import_path = &self.ruma_api_import; self.header_fields().map(|request_field| { let (field, header_name) = match request_field { RequestField::Header(field, header_name) => (field, header_name), @@ -47,15 +47,15 @@ impl Request { let field_name = &field.ident; quote! { - #ruma_api::exports::http::header::#header_name, - #ruma_api::exports::http::header::HeaderValue::from_str(self.#field_name.as_ref())? + #import_path::exports::http::header::#header_name, + #import_path::exports::http::header::HeaderValue::from_str(self.#field_name.as_ref())? } }).collect() } /// Produces code to extract fields from the HTTP headers in an `http::Request`. pub fn parse_headers_from_request(&self) -> TokenStream { - let ruma_api = &self.ruma_api_import; + let import_path = &self.ruma_api_import; let fields = self.header_fields().map(|request_field| { let (field, header_name) = match request_field { RequestField::Header(field, header_name) => (field, header_name), @@ -67,16 +67,16 @@ impl Request { quote! { #field_name: match headers - .get(#ruma_api::exports::http::header::#header_name) + .get(#import_path::exports::http::header::#header_name) .and_then(|v| v.to_str().ok()) // FIXME: Should have a distinct error message { Some(header) => header.to_owned(), None => { - use #ruma_api::exports::serde::de::Error as _; + use #import_path::exports::serde::de::Error as _; // FIXME: Not a missing json field, a missing header! - return Err(#ruma_api::error::RequestDeserializationError::new( - #ruma_api::exports::serde_json::Error::missing_field( + return Err(#import_path::error::RequestDeserializationError::new( + #import_path::exports::serde_json::Error::missing_field( #header_name_string ), request, @@ -382,13 +382,13 @@ impl TryFrom for Request { )); } - Ok(Self { fields, lifetimes, ruma_api_import: import_ruma_api() }) + Ok(Self { fields, lifetimes, ruma_api_import: util::import_ruma_api() }) } } impl ToTokens for Request { fn to_tokens(&self, tokens: &mut TokenStream) { - let ruma_api = &self.ruma_api_import; + let import_path = &self.ruma_api_import; let request_def = if self.fields.is_empty() { quote!(;) } else { @@ -409,7 +409,7 @@ impl ToTokens for Request { let (derive_deserialize, lifetimes) = if self.has_body_lifetimes() { (TokenStream::new(), self.body_lifetimes()) } else { - (quote!(#ruma_api::exports::serde::Deserialize), TokenStream::new()) + (quote!(#import_path::exports::serde::Deserialize), TokenStream::new()) }; Some((derive_deserialize, quote! { #lifetimes (#field); })) @@ -418,7 +418,7 @@ impl ToTokens for Request { let (derive_deserialize, lifetimes) = if self.has_body_lifetimes() { (TokenStream::new(), self.body_lifetimes()) } else { - (quote!(#ruma_api::exports::serde::Deserialize), TokenStream::new()) + (quote!(#import_path::exports::serde::Deserialize), TokenStream::new()) }; let fields = fields.map(RequestField::field); @@ -431,8 +431,8 @@ impl ToTokens for Request { /// Data in the request body. #[derive( Debug, - #ruma_api::Outgoing, - #ruma_api::exports::serde::Serialize, + #import_path::Outgoing, + #import_path::exports::serde::Serialize, #derive_deserialize )] struct RequestBody #def @@ -444,15 +444,15 @@ impl ToTokens for Request { let (derive_deserialize, lifetime) = if self.has_query_lifetimes() { (TokenStream::new(), self.query_lifetimes()) } else { - (quote!(#ruma_api::exports::serde::Deserialize), TokenStream::new()) + (quote!(#import_path::exports::serde::Deserialize), TokenStream::new()) }; quote! { /// Data in the request's query string. #[derive( Debug, - #ruma_api::Outgoing, - #ruma_api::exports::serde::Serialize, + #import_path::Outgoing, + #import_path::exports::serde::Serialize, #derive_deserialize )] struct RequestQuery #lifetime (#field); @@ -462,15 +462,15 @@ impl ToTokens for Request { let (derive_deserialize, lifetime) = if self.has_query_lifetimes() { (TokenStream::new(), self.query_lifetimes()) } else { - (quote!(#ruma_api::exports::serde::Deserialize), TokenStream::new()) + (quote!(#import_path::exports::serde::Deserialize), TokenStream::new()) }; quote! { /// Data in the request's query string. #[derive( Debug, - #ruma_api::Outgoing, - #ruma_api::exports::serde::Serialize, + #import_path::Outgoing, + #import_path::exports::serde::Serialize, #derive_deserialize )] struct RequestQuery #lifetime { @@ -482,7 +482,7 @@ impl ToTokens for Request { }; let request = quote! { - #[derive(Debug, Clone, #ruma_api::Outgoing)] + #[derive(Debug, Clone, #import_path::Outgoing)] #[incoming_no_deserialize] pub struct Request #request_generics #request_def diff --git a/ruma-api-macros/src/api/response.rs b/ruma-api-macros/src/api/response.rs index 597ae484..cd077d89 100644 --- a/ruma-api-macros/src/api/response.rs +++ b/ruma-api-macros/src/api/response.rs @@ -9,7 +9,7 @@ use syn::{spanned::Spanned, Field, Ident}; use crate::{ api::{ attribute::{Meta, MetaNameValue}, - import_ruma_api, strip_serde_attrs, RawResponse, + strip_serde_attrs, RawResponse, }, util, }; @@ -36,7 +36,7 @@ impl Response { /// Produces code for a response struct initializer. pub fn init_fields(&self) -> TokenStream { - let ruma_api = &self.ruma_api_import; + let import_path = &self.ruma_api_import; let mut fields = vec![]; let mut new_type_raw_body = None; @@ -56,9 +56,9 @@ impl Response { } ResponseField::Header(_, header_name) => { quote_spanned! {span=> - #field_name: #ruma_api::try_deserialize!( + #field_name: #import_path::try_deserialize!( response, - headers.remove(#ruma_api::exports::http::header::#header_name) + headers.remove(#import_path::exports::http::header::#header_name) .expect("response missing expected header") .to_str() ) @@ -91,7 +91,7 @@ impl Response { /// Produces code to add necessary HTTP headers to an `http::Response`. pub fn apply_header_fields(&self) -> TokenStream { - let ruma_api = &self.ruma_api_import; + let import_path = &self.ruma_api_import; let header_calls = self.fields.iter().filter_map(|response_field| { if let ResponseField::Header(ref field, ref header_name) = *response_field { @@ -100,7 +100,7 @@ impl Response { let span = field.span(); Some(quote_spanned! {span=> - .header(#ruma_api::exports::http::header::#header_name, response.#field_name) + .header(#import_path::exports::http::header::#header_name, response.#field_name) }) } else { None @@ -112,7 +112,7 @@ impl Response { /// Produces code to initialize the struct that will be used to create the response body. pub fn to_body(&self) -> TokenStream { - let ruma_api = &self.ruma_api_import; + let import_path = &self.ruma_api_import; if let Some(field) = self.newtype_raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); @@ -147,7 +147,7 @@ impl Response { } }; - quote!(#ruma_api::exports::serde_json::to_vec(&#body)?) + quote!(#import_path::exports::serde_json::to_vec(&#body)?) } /// Gets the newtype body field, if this response has one. @@ -234,13 +234,13 @@ impl TryFrom for Response { )); } - Ok(Self { fields, ruma_api_import: import_ruma_api() }) + Ok(Self { fields, ruma_api_import: util::import_ruma_api() }) } } impl ToTokens for Response { fn to_tokens(&self, tokens: &mut TokenStream) { - let ruma_api = &self.ruma_api_import; + let import_path = &self.ruma_api_import; let response_def = if self.fields.is_empty() { quote!(;) @@ -269,15 +269,15 @@ impl ToTokens for Response { /// Data in the response body. #[derive( Debug, - #ruma_api::Outgoing, - #ruma_api::exports::serde::Deserialize, - #ruma_api::exports::serde::Serialize, + #import_path::Outgoing, + #import_path::exports::serde::Deserialize, + #import_path::exports::serde::Serialize, )] struct ResponseBody #def }; let response = quote! { - #[derive(Debug, Clone, #ruma_api::Outgoing)] + #[derive(Debug, Clone, #import_path::Outgoing)] #[incoming_no_deserialize] pub struct Response #response_def diff --git a/ruma-api-macros/src/derive_outgoing.rs b/ruma-api-macros/src/derive_outgoing.rs index 2a5dfc29..96006630 100644 --- a/ruma-api-macros/src/derive_outgoing.rs +++ b/ruma-api-macros/src/derive_outgoing.rs @@ -6,6 +6,8 @@ use syn::{ PathArguments, Type, TypeGenerics, TypePath, TypeReference, TypeSlice, Variant, }; +use crate::util::import_ruma_api; + enum StructKind { Struct, Tuple, @@ -18,10 +20,12 @@ enum DataKind { } pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result { + let import_path = import_ruma_api(); + let derive_deserialize = if no_deserialize_in_attrs(&input.attrs) { TokenStream::new() } else { - quote!(::ruma_api::exports::serde::Deserialize) + quote! { #import_path::exports::serde::Deserialize } }; let input_attrs = @@ -42,7 +46,7 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result { }; match data { - DataKind::Unit => Ok(impl_outgoing_with_incoming_self(&input)), + DataKind::Unit => Ok(impl_outgoing_with_incoming_self(&input, &import_path)), DataKind::Enum(mut vars) => { let mut found_lifetime = false; for var in &mut vars { @@ -57,7 +61,7 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result { let (original_impl_gen, original_ty_gen, _) = input.generics.split_for_impl(); if !found_lifetime { - return Ok(impl_outgoing_with_incoming_self(&input)); + return Ok(impl_outgoing_with_incoming_self(&input, &import_path)); } let vis = input.vis; @@ -73,7 +77,7 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result { #( #input_attrs )* #vis enum #incoming_ident #ty_gen { #( #vars, )* } - impl #original_impl_gen ::ruma_api::Outgoing for #original_ident #original_ty_gen { + impl #original_impl_gen #import_path::Outgoing for #original_ident #original_ty_gen { type Incoming = #incoming_ident #impl_gen; } }) @@ -90,7 +94,7 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result { let (original_impl_gen, original_ty_gen, _) = input.generics.split_for_impl(); if !found_lifetime { - return Ok(impl_outgoing_with_incoming_self(&input)); + return Ok(impl_outgoing_with_incoming_self(&input, &import_path)); } let vis = input.vis; @@ -111,7 +115,7 @@ pub fn expand_derive_outgoing(input: DeriveInput) -> syn::Result { #( #input_attrs )* #vis struct #incoming_ident #ty_gen #struct_def - impl #original_impl_gen ::ruma_api::Outgoing for #original_ident #original_ty_gen { + impl #original_impl_gen #import_path::Outgoing for #original_ident #original_ty_gen { type Incoming = #incoming_ident #impl_gen; } }) @@ -129,12 +133,12 @@ fn no_deserialize_in_attrs(attrs: &[Attribute]) -> bool { attrs.iter().any(|attr| attr.path.is_ident("incoming_no_deserialize")) } -fn impl_outgoing_with_incoming_self(input: &DeriveInput) -> TokenStream { +fn impl_outgoing_with_incoming_self(input: &DeriveInput, import_path: &TokenStream) -> TokenStream { let ident = &input.ident; let (impl_gen, ty_gen, _) = input.generics.split_for_impl(); quote! { - impl #impl_gen ::ruma_api::Outgoing for #ident #ty_gen { + impl #impl_gen #import_path::Outgoing for #ident #ty_gen { type Incoming = Self; } } diff --git a/ruma-api-macros/src/util.rs b/ruma-api-macros/src/util.rs index bb649e82..e43d9ad3 100644 --- a/ruma-api-macros/src/util.rs +++ b/ruma-api-macros/src/util.rs @@ -1,6 +1,7 @@ //! Functions to aid the `Api::to_tokens` method. use proc_macro2::{Span, TokenStream}; +use proc_macro_crate::crate_name; use quote::quote; use std::collections::BTreeSet; use syn::{ @@ -150,6 +151,7 @@ pub fn has_lifetime(ty: &Type) -> bool { pub(crate) fn request_path_string_and_parse( request: &Request, metadata: &Metadata, + import_path: &TokenStream, ) -> (TokenStream, TokenStream) { if request.has_path_fields() { let path_string = metadata.path.value(); @@ -178,9 +180,9 @@ pub(crate) fn request_path_string_and_parse( Span::call_site(), ); format_args.push(quote! { - ::ruma_api::exports::percent_encoding::utf8_percent_encode( + #import_path::exports::percent_encoding::utf8_percent_encode( &self.#path_var.to_string(), - ::ruma_api::exports::percent_encoding::NON_ALPHANUMERIC, + #import_path::exports::percent_encoding::NON_ALPHANUMERIC, ) }); format_string.replace_range(start_of_segment..end_of_segment, "{}"); @@ -198,16 +200,16 @@ pub(crate) fn request_path_string_and_parse( let path_var_ident = Ident::new(path_var, Span::call_site()); quote! { #path_var_ident: { - use ::ruma_api::error::RequestDeserializationError; + use #import_path::error::RequestDeserializationError; let segment = path_segments.get(#i).unwrap().as_bytes(); - let decoded = ::ruma_api::try_deserialize!( + let decoded = #import_path::try_deserialize!( request, - ::ruma_api::exports::percent_encoding::percent_decode(segment) + #import_path::exports::percent_encoding::percent_decode(segment) .decode_utf8(), ); - ::ruma_api::try_deserialize!( + #import_path::try_deserialize!( request, ::std::convert::TryFrom::try_from(&*decoded), ) @@ -224,7 +226,7 @@ pub(crate) fn request_path_string_and_parse( /// The function determines the type of query string that needs to be built /// and then builds it using `ruma_serde::urlencoded::to_string`. -pub(crate) fn build_query_string(request: &Request) -> TokenStream { +pub(crate) fn build_query_string(request: &Request, import_path: &TokenStream) -> TokenStream { if let Some(field) = request.query_map_field() { let field_name = field.ident.as_ref().expect("expected field to have identifier"); @@ -249,7 +251,7 @@ pub(crate) fn build_query_string(request: &Request) -> TokenStream { format_args!( "?{}", - ::ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)? + #import_path::exports::ruma_serde::urlencoded::to_string(request_query)? ) }) } else if request.has_query_fields() { @@ -262,7 +264,7 @@ pub(crate) fn build_query_string(request: &Request) -> TokenStream { format_args!( "?{}", - ::ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)? + #import_path::exports::ruma_serde::urlencoded::to_string(request_query)? ) }) } else { @@ -271,22 +273,22 @@ pub(crate) fn build_query_string(request: &Request) -> TokenStream { } /// Deserialize the query string. -pub(crate) fn extract_request_query(request: &Request) -> TokenStream { +pub(crate) fn extract_request_query(request: &Request, import_path: &TokenStream) -> TokenStream { if request.query_map_field().is_some() { quote! { - let request_query = ::ruma_api::try_deserialize!( + let request_query = #import_path::try_deserialize!( request, - ::ruma_api::exports::ruma_serde::urlencoded::from_str( + #import_path::exports::ruma_serde::urlencoded::from_str( &request.uri().query().unwrap_or("") ), ); } } else if request.has_query_fields() { quote! { - let request_query: ::Incoming = - ::ruma_api::try_deserialize!( + let request_query: ::Incoming = + #import_path::try_deserialize!( request, - ::ruma_api::exports::ruma_serde::urlencoded::from_str( + #import_path::exports::ruma_serde::urlencoded::from_str( &request.uri().query().unwrap_or("") ), ); @@ -299,7 +301,7 @@ pub(crate) fn extract_request_query(request: &Request) -> TokenStream { /// Generates the code to initialize a `Request`. /// /// Used to construct an `http::Request`s body. -pub(crate) fn build_request_body(request: &Request) -> TokenStream { +pub(crate) fn build_request_body(request: &Request, import_path: &TokenStream) -> TokenStream { if let Some(field) = request.newtype_raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); quote!(self.#field_name) @@ -315,7 +317,7 @@ pub(crate) fn build_request_body(request: &Request) -> TokenStream { quote! { { let request_body = RequestBody #request_body_initializers; - ::ruma_api::exports::serde_json::to_vec(&request_body)? + #import_path::exports::serde_json::to_vec(&request_body)? } } } else { @@ -380,3 +382,15 @@ pub(crate) fn req_res_name_value( pub(crate) fn is_valid_endpoint_path(string: &str) -> bool { string.as_bytes().iter().all(|b| (0x21..=0x7E).contains(b)) } + +pub fn import_ruma_api() -> TokenStream { + if let Ok(possibly_renamed) = crate_name("ruma-api") { + let import = Ident::new(&possibly_renamed, Span::call_site()); + quote! { ::#import } + } else if let Ok(possibly_renamed) = crate_name("ruma") { + let import = Ident::new(&possibly_renamed, Span::call_site()); + quote! { ::#import::api } + } else { + quote! { ::ruma_api } + } +} diff --git a/ruma-api/src/lib.rs b/ruma-api/src/lib.rs index 426ec064..6754dd4e 100644 --- a/ruma-api/src/lib.rs +++ b/ruma-api/src/lib.rs @@ -311,18 +311,18 @@ pub struct Metadata { #[macro_export] macro_rules! try_deserialize { ($kind:ident, $call:expr $(,)?) => { - ::ruma_api::try_deserialize!(@$kind, $kind, $call) + $crate::try_deserialize!(@$kind, $kind, $call) }; (@request, $kind:ident, $call:expr) => { match $call { Ok(val) => val, - Err(err) => return Err(::ruma_api::error::RequestDeserializationError::new(err, $kind).into()), + Err(err) => return Err($crate::error::RequestDeserializationError::new(err, $kind).into()), } }; (@response, $kind:ident, $call:expr) => { match $call { Ok(val) => val, - Err(err) => return Err(::ruma_api::error::ResponseDeserializationError::new(err, $kind).into()), + Err(err) => return Err($crate::error::ResponseDeserializationError::new(err, $kind).into()), } }; }