diff --git a/crates/ruma-common/tests/api/mod.rs b/crates/ruma-common/tests/api/mod.rs index e61dcaf6..a01f92eb 100644 --- a/crates/ruma-common/tests/api/mod.rs +++ b/crates/ruma-common/tests/api/mod.rs @@ -5,7 +5,6 @@ mod header_override; mod manual_endpoint_impl; mod no_fields; mod optional_headers; -mod path_arg_ordering; mod ruma_api; mod ruma_api_lifetime; mod ruma_api_macros; diff --git a/crates/ruma-common/tests/api/path_arg_ordering.rs b/crates/ruma-common/tests/api/path_arg_ordering.rs deleted file mode 100644 index 8cff98dc..00000000 --- a/crates/ruma-common/tests/api/path_arg_ordering.rs +++ /dev/null @@ -1,42 +0,0 @@ -use ruma_common::api::{ruma_api, IncomingRequest as _}; - -ruma_api! { - metadata: { - description: "Does something.", - method: GET, - name: "some_path_args", - unstable_path: "/_matrix/:one/a/:two/b/:three/c", - rate_limited: false, - authentication: None, - } - - request: { - #[ruma_api(path)] - pub three: String, - - #[ruma_api(path)] - pub one: String, - - #[ruma_api(path)] - pub two: String, - } - - response: {} -} - -#[test] -fn path_ordering_is_correct() { - let request = http::Request::builder() - .method("GET") - // This explicitly puts wrong values in the URI, as now we rely on the side-supplied - // path_args slice, so this is just to ensure it *is* using that slice. - .uri("https://www.rust-lang.org/_matrix/non/a/non/b/non/c") - .body("") - .unwrap(); - - let resp = Request::try_from_http_request(request, &["1", "2", "3"]).unwrap(); - - assert_eq!(resp.one, "1"); - assert_eq!(resp.two, "2"); - assert_eq!(resp.three, "3"); -} diff --git a/crates/ruma-macros/src/api.rs b/crates/ruma-macros/src/api.rs index 41dfb797..7ea57710 100644 --- a/crates/ruma-macros/src/api.rs +++ b/crates/ruma-macros/src/api.rs @@ -12,7 +12,12 @@ use syn::{ Attribute, Field, Token, Type, }; -use self::{api_metadata::Metadata, api_request::Request, api_response::Response}; +use self::{ + api_metadata::Metadata, + api_request::Request, + api_response::Response, + request::{RequestField, RequestFieldKind}, +}; use crate::util::import_ruma_common; mod api_metadata; @@ -50,7 +55,8 @@ pub struct Api { impl Api { pub fn expand_all(self) -> TokenStream { - let maybe_error = ensure_feature_presence().map(syn::Error::to_compile_error); + let maybe_feature_error = ensure_feature_presence().map(syn::Error::to_compile_error); + let maybe_path_error = self.check_paths().err().map(syn::Error::into_compile_error); let ruma_common = import_ruma_common(); let http = quote! { #ruma_common::exports::http }; @@ -79,7 +85,8 @@ impl Api { let metadata_doc = format!("Metadata for the `{}` API endpoint.", name.value()); quote! { - #maybe_error + #maybe_feature_error + #maybe_path_error #[doc = #metadata_doc] pub const METADATA: #ruma_common::api::Metadata = #ruma_common::api::Metadata { @@ -103,6 +110,54 @@ impl Api { type _SilenceUnusedError = #error_ty; } } + + fn check_paths(&self) -> syn::Result<()> { + let mut path_iter = self + .metadata + .unstable_path + .iter() + .chain(&self.metadata.r0_path) + .chain(&self.metadata.stable_path); + + let path = path_iter.next().ok_or_else(|| { + syn::Error::new(Span::call_site(), "at least one path metadata field must be set") + })?; + let path_args = get_path_args(&path.value()); + + for extra_path in path_iter { + let extra_path_args = get_path_args(&extra_path.value()); + if extra_path_args != path_args { + return Err(syn::Error::new( + Span::call_site(), + "paths have different path parameters", + )); + } + } + + if let Some(req) = &self.request { + let path_field_names: Vec<_> = req + .fields + .iter() + .cloned() + .filter_map(|f| match RequestField::try_from(f) { + Ok(RequestField { kind: RequestFieldKind::Path, inner }) => { + Some(Ok(inner.ident.unwrap().to_string())) + } + Ok(_) => None, + Err(e) => Some(Err(e)), + }) + .collect::>()?; + + if path_args != path_field_names { + return Err(syn::Error::new_spanned( + req.request_kw, + "path fields must be in the same order as they appear in the path segments", + )); + } + } + + Ok(()) + } } impl Parse for Api { @@ -215,3 +270,7 @@ fn ensure_feature_presence() -> Option<&'static syn::Error> { RESULT.as_ref().err() } + +fn get_path_args(path: &str) -> Vec { + path.split('/').filter_map(|s| s.strip_prefix(':').map(ToOwned::to_owned)).collect() +} diff --git a/crates/ruma-macros/src/api/api_request.rs b/crates/ruma-macros/src/api/api_request.rs index d3aa957c..88aed2bb 100644 --- a/crates/ruma-macros/src/api/api_request.rs +++ b/crates/ruma-macros/src/api/api_request.rs @@ -84,9 +84,6 @@ impl Request { let method = &metadata.method; let authentication = &metadata.authentication; - let unstable_attr = metadata.unstable_path.as_ref().map(|p| quote! { unstable = #p, }); - let r0_attr = metadata.r0_path.as_ref().map(|p| quote! { r0 = #p, }); - let stable_attr = metadata.stable_path.as_ref().map(|p| quote! { stable = #p, }); let request_ident = Ident::new("Request", self.request_kw.span()); let lifetimes = self.all_lifetimes(); @@ -107,9 +104,6 @@ impl Request { #[ruma_api( method = #method, authentication = #authentication, - #unstable_attr - #r0_attr - #stable_attr error_ty = #error_ty, )] #( #struct_attributes )* diff --git a/crates/ruma-macros/src/api/attribute.rs b/crates/ruma-macros/src/api/attribute.rs index d0a1d50b..14e9b426 100644 --- a/crates/ruma-macros/src/api/attribute.rs +++ b/crates/ruma-macros/src/api/attribute.rs @@ -2,7 +2,7 @@ use syn::{ parse::{Parse, ParseStream}, - Ident, LitStr, Token, Type, + Ident, Token, Type, }; mod kw { @@ -62,9 +62,6 @@ pub enum DeriveRequestMeta { Authentication(Type), Method(Type), ErrorTy(Type), - UnstablePath(LitStr), - R0Path(LitStr), - StablePath(LitStr), } impl Parse for DeriveRequestMeta { @@ -82,18 +79,6 @@ impl Parse for DeriveRequestMeta { let _: kw::error_ty = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(Self::ErrorTy) - } else if lookahead.peek(kw::unstable) { - let _: kw::unstable = input.parse()?; - let _: Token![=] = input.parse()?; - input.parse().map(Self::UnstablePath) - } else if lookahead.peek(kw::r0) { - let _: kw::r0 = input.parse()?; - let _: Token![=] = input.parse()?; - input.parse().map(Self::R0Path) - } else if lookahead.peek(kw::stable) { - let _: kw::stable = input.parse()?; - let _: Token![=] = input.parse()?; - input.parse().map(Self::StablePath) } else { Err(lookahead.error()) } diff --git a/crates/ruma-macros/src/api/request.rs b/crates/ruma-macros/src/api/request.rs index 865fe8f3..bccd3c23 100644 --- a/crates/ruma-macros/src/api/request.rs +++ b/crates/ruma-macros/src/api/request.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::BTreeSet; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; @@ -6,7 +6,7 @@ use syn::{ parse::{Parse, ParseStream}, parse_quote, punctuated::Punctuated, - DeriveInput, Field, Generics, Ident, Lifetime, LitStr, Token, Type, + DeriveInput, Field, Generics, Ident, Lifetime, Token, Type, }; use super::{ @@ -49,9 +49,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result { let mut authentication = None; let mut error_ty = None; let mut method = None; - let mut unstable_path = None; - let mut r0_path = None; - let mut stable_path = None; for attr in input.attrs { if !attr.path.is_ident("ruma_api") { @@ -65,9 +62,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result { DeriveRequestMeta::Authentication(t) => authentication = Some(parse_quote!(#t)), DeriveRequestMeta::Method(t) => method = Some(parse_quote!(#t)), DeriveRequestMeta::ErrorTy(t) => error_ty = Some(t), - DeriveRequestMeta::UnstablePath(s) => unstable_path = Some(s), - DeriveRequestMeta::R0Path(s) => r0_path = Some(s), - DeriveRequestMeta::StablePath(s) => stable_path = Some(s), } } } @@ -79,9 +73,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result { lifetimes, authentication: authentication.expect("missing authentication attribute"), method: method.expect("missing method attribute"), - unstable_path, - r0_path, - stable_path, error_ty: error_ty.expect("missing error_ty attribute"), }; @@ -105,9 +96,6 @@ struct Request { authentication: AuthScheme, method: Ident, - unstable_path: Option, - r0_path: Option, - stable_path: Option, error_ty: Type, } @@ -149,27 +137,8 @@ impl Request { self.fields.iter().filter_map(RequestField::as_header_field) } - fn path_fields_ordered(&self) -> impl Iterator { - let map: BTreeMap = self - .fields - .iter() - .filter_map(RequestField::as_path_field) - .map(|f| (f.ident.as_ref().unwrap().to_string(), f)) - .collect(); - - self.stable_path - .as_ref() - .or(self.r0_path.as_ref()) - .or(self.unstable_path.as_ref()) - .expect("one of the paths to be defined") - .value() - .split('/') - .filter_map(|s| { - s.strip_prefix(':') - .map(|s| *map.get(s).expect("path args have already been checked")) - }) - .collect::>() - .into_iter() + fn path_fields(&self) -> impl Iterator { + self.fields.iter().filter_map(RequestField::as_path_field) } fn raw_body_field(&self) -> Option<&Field> { @@ -252,13 +221,6 @@ impl Request { pub(super) fn check(&self) -> syn::Result<()> { // TODO: highlight problematic fields - let path_fields: Vec<_> = - self.fields.iter().filter_map(RequestField::as_path_field).collect(); - - self.check_path(&path_fields, self.unstable_path.as_ref())?; - self.check_path(&path_fields, self.r0_path.as_ref())?; - self.check_path(&path_fields, self.stable_path.as_ref())?; - let newtype_body_fields = self.fields.iter().filter(|f| { matches!(&f.kind, RequestFieldKind::NewtypeBody | RequestFieldKind::RawBody) }); @@ -322,58 +284,16 @@ impl Request { Ok(()) } - - fn check_path(&self, fields: &[&Field], path: Option<&LitStr>) -> syn::Result<()> { - let path = if let Some(lit) = path { lit } else { return Ok(()) }; - - let path_args: Vec<_> = path - .value() - .split('/') - .filter_map(|s| s.strip_prefix(':').map(str::to_string)) - .collect(); - - let field_map: BTreeMap<_, _> = - fields.iter().map(|&f| (f.ident.as_ref().unwrap().to_string(), f)).collect(); - - // test if all macro fields exist in the path - for (name, field) in field_map.iter() { - if !path_args.contains(name) { - return Err({ - let mut err = syn::Error::new_spanned( - field, - "this path argument field is not defined in...", - ); - err.combine(syn::Error::new_spanned(path, "...this path.")); - err - }); - } - } - - // test if all path fields exists in macro fields - for arg in &path_args { - if !field_map.contains_key(arg) { - return Err(syn::Error::new_spanned( - path, - format!( - "a corresponding request path argument field for \"{}\" does not exist", - arg - ), - )); - } - } - - Ok(()) - } } /// A field of the request struct. -struct RequestField { - inner: Field, - kind: RequestFieldKind, +pub(super) struct RequestField { + pub(super) inner: Field, + pub(super) kind: RequestFieldKind, } /// The kind of a request field. -enum RequestFieldKind { +pub(super) enum RequestFieldKind { /// JSON data in the body of the request. Body, diff --git a/crates/ruma-macros/src/api/request/incoming.rs b/crates/ruma-macros/src/api/request/incoming.rs index 71a56c8d..80f2fba4 100644 --- a/crates/ruma-macros/src/api/request/incoming.rs +++ b/crates/ruma-macros/src/api/request/incoming.rs @@ -24,8 +24,7 @@ impl Request { // except this one. If we get errors about missing fields in IncomingRequest for // a path field look here. let (parse_request_path, path_vars) = if self.has_path_fields() { - let path_vars: Vec<_> = - self.path_fields_ordered().filter_map(|f| f.ident.as_ref()).collect(); + let path_vars: Vec<_> = self.path_fields().filter_map(|f| f.ident.as_ref()).collect(); let parse_request_path = quote! { let (#(#path_vars,)*) = #serde::Deserialize::deserialize( diff --git a/crates/ruma-macros/src/api/request/outgoing.rs b/crates/ruma-macros/src/api/request/outgoing.rs index 6b89c2f2..d954d672 100644 --- a/crates/ruma-macros/src/api/request/outgoing.rs +++ b/crates/ruma-macros/src/api/request/outgoing.rs @@ -14,7 +14,7 @@ impl Request { let error_ty = &self.error_ty; let path_fields = - self.path_fields_ordered().map(|f| f.ident.as_ref().expect("path fields have a name")); + self.path_fields().map(|f| f.ident.as_ref().expect("path fields have a name")); let request_query_string = if let Some(field) = self.query_map_field() { let field_name = field.ident.as_ref().expect("expected field to have identifier");