diff --git a/crates/ruma-macros/src/api.rs b/crates/ruma-macros/src/api.rs index 4e7ae788..e8210f08 100644 --- a/crates/ruma-macros/src/api.rs +++ b/crates/ruma-macros/src/api.rs @@ -74,8 +74,8 @@ impl Api { |err_ty| quote! { #err_ty }, ); - let request = self.request.map(|req| req.expand(metadata, &ruma_common)); - let response = self.response.map(|res| res.expand(metadata, &ruma_common)); + let request = self.request.map(|req| req.expand(metadata, &error_ty, &ruma_common)); + let response = self.response.map(|res| res.expand(metadata, &error_ty, &ruma_common)); let metadata_doc = format!("Metadata for the `{}` API endpoint.", name.value()); @@ -93,11 +93,11 @@ impl Api { history: #history, }; - #[allow(unused)] - type EndpointError = #error_ty; - #request #response + + #[cfg(not(any(feature = "client", feature = "server")))] + type _SilenceUnusedError = #error_ty; } } diff --git a/crates/ruma-macros/src/api/api_request.rs b/crates/ruma-macros/src/api/api_request.rs index 8618c5b9..42b8e166 100644 --- a/crates/ruma-macros/src/api/api_request.rs +++ b/crates/ruma-macros/src/api/api_request.rs @@ -67,7 +67,12 @@ impl Request { lifetimes } - pub(super) fn expand(&self, metadata: &Metadata, ruma_common: &TokenStream) -> TokenStream { + pub(super) fn expand( + &self, + metadata: &Metadata, + error_ty: &TokenStream, + ruma_common: &TokenStream, + ) -> TokenStream { let ruma_macros = quote! { #ruma_common::exports::ruma_macros }; let docs = format!( @@ -93,6 +98,7 @@ impl Request { )] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] #[incoming_derive(!Deserialize, #ruma_macros::_FakeDeriveRumaApi)] + #[ruma_api(error_ty = #error_ty)] #( #struct_attributes )* pub struct #request_ident < #(#lifetimes),* > { #fields diff --git a/crates/ruma-macros/src/api/api_response.rs b/crates/ruma-macros/src/api/api_response.rs index ab1d4670..d881aef7 100644 --- a/crates/ruma-macros/src/api/api_response.rs +++ b/crates/ruma-macros/src/api/api_response.rs @@ -19,7 +19,12 @@ pub(crate) struct Response { } impl Response { - pub(super) fn expand(&self, metadata: &Metadata, ruma_common: &TokenStream) -> TokenStream { + pub(super) fn expand( + &self, + metadata: &Metadata, + error_ty: &TokenStream, + ruma_common: &TokenStream, + ) -> TokenStream { let ruma_macros = quote! { #ruma_common::exports::ruma_macros }; let docs = @@ -39,6 +44,7 @@ impl Response { )] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] #[incoming_derive(!Deserialize, #ruma_macros::_FakeDeriveRumaApi)] + #[ruma_api(error_ty = #error_ty)] #( #struct_attributes )* pub struct #response_ident { #fields diff --git a/crates/ruma-macros/src/api/attribute.rs b/crates/ruma-macros/src/api/attribute.rs index d2d8a4fb..2b649468 100644 --- a/crates/ruma-macros/src/api/attribute.rs +++ b/crates/ruma-macros/src/api/attribute.rs @@ -2,7 +2,7 @@ use syn::{ parse::{Parse, ParseStream}, - Ident, Token, + Ident, Token, Type, }; mod kw { @@ -12,6 +12,8 @@ mod kw { syn::custom_keyword!(query); syn::custom_keyword!(query_map); syn::custom_keyword!(header); + syn::custom_keyword!(error_ty); + syn::custom_keyword!(manual_body_serde); } pub enum RequestMeta { @@ -51,6 +53,23 @@ impl Parse for RequestMeta { } } +pub enum DeriveRequestMeta { + ErrorTy(Type), +} + +impl Parse for DeriveRequestMeta { + fn parse(input: ParseStream<'_>) -> syn::Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::error_ty) { + let _: kw::error_ty = input.parse()?; + let _: Token![=] = input.parse()?; + input.parse().map(Self::ErrorTy) + } else { + Err(lookahead.error()) + } + } +} + pub enum ResponseMeta { NewtypeBody, RawBody, @@ -75,3 +94,25 @@ impl Parse for ResponseMeta { } } } + +#[allow(clippy::large_enum_variant)] +pub enum DeriveResponseMeta { + ManualBodySerde, + ErrorTy(Type), +} + +impl Parse for DeriveResponseMeta { + fn parse(input: ParseStream<'_>) -> syn::Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::manual_body_serde) { + let _: kw::manual_body_serde = input.parse()?; + Ok(Self::ManualBodySerde) + } else if lookahead.peek(kw::error_ty) { + let _: kw::error_ty = input.parse()?; + let _: Token![=] = input.parse()?; + input.parse().map(Self::ErrorTy) + } else { + Err(lookahead.error()) + } + } +} diff --git a/crates/ruma-macros/src/api/request.rs b/crates/ruma-macros/src/api/request.rs index 96b87b7d..c90a409f 100644 --- a/crates/ruma-macros/src/api/request.rs +++ b/crates/ruma-macros/src/api/request.rs @@ -4,10 +4,14 @@ use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{ parse::{Parse, ParseStream}, - DeriveInput, Field, Generics, Ident, Lifetime, + punctuated::Punctuated, + DeriveInput, Field, Generics, Ident, Lifetime, Token, Type, }; -use super::{attribute::RequestMeta, util::collect_lifetime_idents}; +use super::{ + attribute::{DeriveRequestMeta, RequestMeta}, + util::collect_lifetime_idents, +}; use crate::util::import_ruma_common; mod incoming; @@ -40,7 +44,29 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result { }) .collect::>()?; - let request = Request { ident: input.ident, generics: input.generics, fields, lifetimes }; + let mut error_ty = None; + + for attr in input.attrs { + if !attr.path.is_ident("ruma_api") { + continue; + } + + let metas = + attr.parse_args_with(Punctuated::::parse_terminated)?; + for meta in metas { + match meta { + DeriveRequestMeta::ErrorTy(t) => error_ty = Some(t), + } + } + } + + let request = Request { + ident: input.ident, + generics: input.generics, + fields, + lifetimes, + error_ty: error_ty.expect("missing error_ty attribute"), + }; let ruma_common = import_ruma_common(); let test = request.check(&ruma_common)?; @@ -65,6 +91,8 @@ struct Request { generics: Generics, lifetimes: RequestLifetimes, fields: Vec, + + error_ty: Type, } impl Request { diff --git a/crates/ruma-macros/src/api/request/incoming.rs b/crates/ruma-macros/src/api/request/incoming.rs index 50cefb29..ee909034 100644 --- a/crates/ruma-macros/src/api/request/incoming.rs +++ b/crates/ruma-macros/src/api/request/incoming.rs @@ -10,6 +10,8 @@ impl Request { let serde = quote! { #ruma_common::exports::serde }; let serde_json = quote! { #ruma_common::exports::serde_json }; + let error_ty = &self.error_ty; + let incoming_request_type = if self.has_lifetimes() { quote! { IncomingRequest } } else { @@ -178,7 +180,7 @@ impl Request { #[automatically_derived] #[cfg(feature = "server")] impl #ruma_common::api::IncomingRequest for #incoming_request_type { - type EndpointError = self::EndpointError; + type EndpointError = #error_ty; type OutgoingResponse = Response; const METADATA: #ruma_common::api::Metadata = METADATA; diff --git a/crates/ruma-macros/src/api/request/outgoing.rs b/crates/ruma-macros/src/api/request/outgoing.rs index 0e5361ea..fe79e558 100644 --- a/crates/ruma-macros/src/api/request/outgoing.rs +++ b/crates/ruma-macros/src/api/request/outgoing.rs @@ -9,6 +9,8 @@ impl Request { let bytes = quote! { #ruma_common::exports::bytes }; let http = quote! { #ruma_common::exports::http }; + let error_ty = &self.error_ty; + let path_fields = self.path_fields().map(|f| f.ident.as_ref().expect("path fields have a name")); @@ -122,7 +124,7 @@ impl Request { #[automatically_derived] #[cfg(feature = "client")] impl #impl_generics #ruma_common::api::OutgoingRequest for Request #ty_generics #where_clause { - type EndpointError = self::EndpointError; + type EndpointError = #error_ty; type IncomingResponse = Response; const METADATA: #ruma_common::api::Metadata = METADATA; diff --git a/crates/ruma-macros/src/api/response.rs b/crates/ruma-macros/src/api/response.rs index 33602c44..01eeb5fb 100644 --- a/crates/ruma-macros/src/api/response.rs +++ b/crates/ruma-macros/src/api/response.rs @@ -4,20 +4,17 @@ use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{ parse::{Parse, ParseStream}, + punctuated::Punctuated, visit::Visit, - DeriveInput, Field, Generics, Ident, Lifetime, Type, + DeriveInput, Field, Generics, Ident, Lifetime, Token, Type, }; -use super::attribute::ResponseMeta; +use super::attribute::{DeriveResponseMeta, ResponseMeta}; use crate::util::import_ruma_common; mod incoming; mod outgoing; -mod kw { - syn::custom_keyword!(manual_body_serde); -} - pub fn expand_derive_response(input: DeriveInput) -> syn::Result { let fields = match input.data { syn::Data::Struct(s) => s.fields, @@ -26,18 +23,29 @@ pub fn expand_derive_response(input: DeriveInput) -> syn::Result { let fields = fields.into_iter().map(ResponseField::try_from).collect::>()?; let mut manual_body_serde = false; + let mut error_ty = None; for attr in input.attrs { if !attr.path.is_ident("ruma_api") { continue; } - let _ = attr.parse_args::()?; - - manual_body_serde = true; + let metas = + attr.parse_args_with(Punctuated::::parse_terminated)?; + for meta in metas { + match meta { + DeriveResponseMeta::ManualBodySerde => manual_body_serde = true, + DeriveResponseMeta::ErrorTy(t) => error_ty = Some(t), + } + } } - let response = - Response { ident: input.ident, generics: input.generics, fields, manual_body_serde }; + let response = Response { + ident: input.ident, + generics: input.generics, + fields, + manual_body_serde, + error_ty: error_ty.unwrap(), + }; response.check()?; Ok(response.expand_all()) @@ -48,6 +56,7 @@ struct Response { generics: Generics, fields: Vec, manual_body_serde: bool, + error_ty: Type, } impl Response { @@ -100,7 +109,7 @@ impl Response { }); let outgoing_response_impl = self.expand_outgoing(&ruma_common); - let incoming_response_impl = self.expand_incoming(&ruma_common); + let incoming_response_impl = self.expand_incoming(&self.error_ty, &ruma_common); quote! { #response_body_struct diff --git a/crates/ruma-macros/src/api/response/incoming.rs b/crates/ruma-macros/src/api/response/incoming.rs index 168865a8..59260c53 100644 --- a/crates/ruma-macros/src/api/response/incoming.rs +++ b/crates/ruma-macros/src/api/response/incoming.rs @@ -1,10 +1,11 @@ use proc_macro2::TokenStream; use quote::quote; +use syn::Type; use super::{Response, ResponseFieldKind}; impl Response { - pub fn expand_incoming(&self, ruma_common: &TokenStream) -> TokenStream { + pub fn expand_incoming(&self, error_ty: &Type, ruma_common: &TokenStream) -> TokenStream { let http = quote! { #ruma_common::exports::http }; let serde_json = quote! { #ruma_common::exports::serde_json }; @@ -103,41 +104,37 @@ impl Response { } }; - let method_body = quote! { - if response.status().as_u16() < 400 { - #extract_response_headers - #typed_response_body_decl - - ::std::result::Result::Ok(Self { - #response_init_fields - }) - } else { - match ::try_from_http_response( - response - ) { - ::std::result::Result::Ok(err) => { - Err(#ruma_common::api::error::ServerError::Known(err).into()) - } - ::std::result::Result::Err(response_err) => { - Err(#ruma_common::api::error::ServerError::Unknown(response_err).into()) - } - } - } - }; - quote! { #[automatically_derived] #[cfg(feature = "client")] impl #ruma_common::api::IncomingResponse for Response { - type EndpointError = EndpointError; + type EndpointError = #error_ty; fn try_from_http_response>( response: #http::Response, ) -> ::std::result::Result< Self, - #ruma_common::api::error::FromHttpResponseError, + #ruma_common::api::error::FromHttpResponseError<#error_ty>, > { - #method_body + if response.status().as_u16() < 400 { + #extract_response_headers + #typed_response_body_decl + + ::std::result::Result::Ok(Self { + #response_init_fields + }) + } else { + match <#error_ty as #ruma_common::api::EndpointError>::try_from_http_response( + response + ) { + ::std::result::Result::Ok(err) => { + Err(#ruma_common::api::error::ServerError::Known(err).into()) + } + ::std::result::Result::Err(response_err) => { + Err(#ruma_common::api::error::ServerError::Unknown(response_err).into()) + } + } + } } } }