api: Enforce consistent path field order
This commit is contained in:
		
							parent
							
								
									6ec01bfdb4
								
							
						
					
					
						commit
						764e96a254
					
				| @ -5,7 +5,6 @@ mod header_override; | ||||
| mod manual_endpoint_impl; | ||||
| mod no_fields; | ||||
| mod optional_headers; | ||||
| mod path_arg_ordering; | ||||
| mod ruma_api; | ||||
| mod ruma_api_lifetime; | ||||
| mod ruma_api_macros; | ||||
|  | ||||
| @ -1,42 +0,0 @@ | ||||
| use ruma_common::api::{ruma_api, IncomingRequest as _}; | ||||
| 
 | ||||
| ruma_api! { | ||||
|     metadata: { | ||||
|         description: "Does something.", | ||||
|         method: GET, | ||||
|         name: "some_path_args", | ||||
|         unstable_path: "/_matrix/:one/a/:two/b/:three/c", | ||||
|         rate_limited: false, | ||||
|         authentication: None, | ||||
|     } | ||||
| 
 | ||||
|     request: { | ||||
|         #[ruma_api(path)] | ||||
|         pub three: String, | ||||
| 
 | ||||
|         #[ruma_api(path)] | ||||
|         pub one: String, | ||||
| 
 | ||||
|         #[ruma_api(path)] | ||||
|         pub two: String, | ||||
|     } | ||||
| 
 | ||||
|     response: {} | ||||
| } | ||||
| 
 | ||||
| #[test] | ||||
| fn path_ordering_is_correct() { | ||||
|     let request = http::Request::builder() | ||||
|         .method("GET") | ||||
|         // This explicitly puts wrong values in the URI, as now we rely on the side-supplied
 | ||||
|         // path_args slice, so this is just to ensure it *is* using that slice.
 | ||||
|         .uri("https://www.rust-lang.org/_matrix/non/a/non/b/non/c") | ||||
|         .body("") | ||||
|         .unwrap(); | ||||
| 
 | ||||
|     let resp = Request::try_from_http_request(request, &["1", "2", "3"]).unwrap(); | ||||
| 
 | ||||
|     assert_eq!(resp.one, "1"); | ||||
|     assert_eq!(resp.two, "2"); | ||||
|     assert_eq!(resp.three, "3"); | ||||
| } | ||||
| @ -12,7 +12,12 @@ use syn::{ | ||||
|     Attribute, Field, Token, Type, | ||||
| }; | ||||
| 
 | ||||
| use self::{api_metadata::Metadata, api_request::Request, api_response::Response}; | ||||
| use self::{ | ||||
|     api_metadata::Metadata, | ||||
|     api_request::Request, | ||||
|     api_response::Response, | ||||
|     request::{RequestField, RequestFieldKind}, | ||||
| }; | ||||
| use crate::util::import_ruma_common; | ||||
| 
 | ||||
| mod api_metadata; | ||||
| @ -50,7 +55,8 @@ pub struct Api { | ||||
| 
 | ||||
| impl Api { | ||||
|     pub fn expand_all(self) -> TokenStream { | ||||
|         let maybe_error = ensure_feature_presence().map(syn::Error::to_compile_error); | ||||
|         let maybe_feature_error = ensure_feature_presence().map(syn::Error::to_compile_error); | ||||
|         let maybe_path_error = self.check_paths().err().map(syn::Error::into_compile_error); | ||||
| 
 | ||||
|         let ruma_common = import_ruma_common(); | ||||
|         let http = quote! { #ruma_common::exports::http }; | ||||
| @ -79,7 +85,8 @@ impl Api { | ||||
|         let metadata_doc = format!("Metadata for the `{}` API endpoint.", name.value()); | ||||
| 
 | ||||
|         quote! { | ||||
|             #maybe_error | ||||
|             #maybe_feature_error | ||||
|             #maybe_path_error | ||||
| 
 | ||||
|             #[doc = #metadata_doc] | ||||
|             pub const METADATA: #ruma_common::api::Metadata = #ruma_common::api::Metadata { | ||||
| @ -103,6 +110,54 @@ impl Api { | ||||
|             type _SilenceUnusedError = #error_ty; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn check_paths(&self) -> syn::Result<()> { | ||||
|         let mut path_iter = self | ||||
|             .metadata | ||||
|             .unstable_path | ||||
|             .iter() | ||||
|             .chain(&self.metadata.r0_path) | ||||
|             .chain(&self.metadata.stable_path); | ||||
| 
 | ||||
|         let path = path_iter.next().ok_or_else(|| { | ||||
|             syn::Error::new(Span::call_site(), "at least one path metadata field must be set") | ||||
|         })?; | ||||
|         let path_args = get_path_args(&path.value()); | ||||
| 
 | ||||
|         for extra_path in path_iter { | ||||
|             let extra_path_args = get_path_args(&extra_path.value()); | ||||
|             if extra_path_args != path_args { | ||||
|                 return Err(syn::Error::new( | ||||
|                     Span::call_site(), | ||||
|                     "paths have different path parameters", | ||||
|                 )); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         if let Some(req) = &self.request { | ||||
|             let path_field_names: Vec<_> = req | ||||
|                 .fields | ||||
|                 .iter() | ||||
|                 .cloned() | ||||
|                 .filter_map(|f| match RequestField::try_from(f) { | ||||
|                     Ok(RequestField { kind: RequestFieldKind::Path, inner }) => { | ||||
|                         Some(Ok(inner.ident.unwrap().to_string())) | ||||
|                     } | ||||
|                     Ok(_) => None, | ||||
|                     Err(e) => Some(Err(e)), | ||||
|                 }) | ||||
|                 .collect::<syn::Result<_>>()?; | ||||
| 
 | ||||
|             if path_args != path_field_names { | ||||
|                 return Err(syn::Error::new_spanned( | ||||
|                     req.request_kw, | ||||
|                     "path fields must be in the same order as they appear in the path segments", | ||||
|                 )); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Parse for Api { | ||||
| @ -215,3 +270,7 @@ fn ensure_feature_presence() -> Option<&'static syn::Error> { | ||||
| 
 | ||||
|     RESULT.as_ref().err() | ||||
| } | ||||
| 
 | ||||
| fn get_path_args(path: &str) -> Vec<String> { | ||||
|     path.split('/').filter_map(|s| s.strip_prefix(':').map(ToOwned::to_owned)).collect() | ||||
| } | ||||
|  | ||||
| @ -84,9 +84,6 @@ impl Request { | ||||
| 
 | ||||
|         let method = &metadata.method; | ||||
|         let authentication = &metadata.authentication; | ||||
|         let unstable_attr = metadata.unstable_path.as_ref().map(|p| quote! { unstable = #p, }); | ||||
|         let r0_attr = metadata.r0_path.as_ref().map(|p| quote! { r0 = #p, }); | ||||
|         let stable_attr = metadata.stable_path.as_ref().map(|p| quote! { stable = #p, }); | ||||
| 
 | ||||
|         let request_ident = Ident::new("Request", self.request_kw.span()); | ||||
|         let lifetimes = self.all_lifetimes(); | ||||
| @ -107,9 +104,6 @@ impl Request { | ||||
|             #[ruma_api(
 | ||||
|                 method = #method, | ||||
|                 authentication = #authentication, | ||||
|                 #unstable_attr | ||||
|                 #r0_attr | ||||
|                 #stable_attr | ||||
|                 error_ty = #error_ty, | ||||
|             )] | ||||
|             #( #struct_attributes )* | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
| 
 | ||||
| use syn::{ | ||||
|     parse::{Parse, ParseStream}, | ||||
|     Ident, LitStr, Token, Type, | ||||
|     Ident, Token, Type, | ||||
| }; | ||||
| 
 | ||||
| mod kw { | ||||
| @ -62,9 +62,6 @@ pub enum DeriveRequestMeta { | ||||
|     Authentication(Type), | ||||
|     Method(Type), | ||||
|     ErrorTy(Type), | ||||
|     UnstablePath(LitStr), | ||||
|     R0Path(LitStr), | ||||
|     StablePath(LitStr), | ||||
| } | ||||
| 
 | ||||
| impl Parse for DeriveRequestMeta { | ||||
| @ -82,18 +79,6 @@ impl Parse for DeriveRequestMeta { | ||||
|             let _: kw::error_ty = input.parse()?; | ||||
|             let _: Token![=] = input.parse()?; | ||||
|             input.parse().map(Self::ErrorTy) | ||||
|         } else if lookahead.peek(kw::unstable) { | ||||
|             let _: kw::unstable = input.parse()?; | ||||
|             let _: Token![=] = input.parse()?; | ||||
|             input.parse().map(Self::UnstablePath) | ||||
|         } else if lookahead.peek(kw::r0) { | ||||
|             let _: kw::r0 = input.parse()?; | ||||
|             let _: Token![=] = input.parse()?; | ||||
|             input.parse().map(Self::R0Path) | ||||
|         } else if lookahead.peek(kw::stable) { | ||||
|             let _: kw::stable = input.parse()?; | ||||
|             let _: Token![=] = input.parse()?; | ||||
|             input.parse().map(Self::StablePath) | ||||
|         } else { | ||||
|             Err(lookahead.error()) | ||||
|         } | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| use std::collections::{BTreeMap, BTreeSet}; | ||||
| use std::collections::BTreeSet; | ||||
| 
 | ||||
| use proc_macro2::TokenStream; | ||||
| use quote::{quote, ToTokens}; | ||||
| @ -6,7 +6,7 @@ use syn::{ | ||||
|     parse::{Parse, ParseStream}, | ||||
|     parse_quote, | ||||
|     punctuated::Punctuated, | ||||
|     DeriveInput, Field, Generics, Ident, Lifetime, LitStr, Token, Type, | ||||
|     DeriveInput, Field, Generics, Ident, Lifetime, Token, Type, | ||||
| }; | ||||
| 
 | ||||
| use super::{ | ||||
| @ -49,9 +49,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> { | ||||
|     let mut authentication = None; | ||||
|     let mut error_ty = None; | ||||
|     let mut method = None; | ||||
|     let mut unstable_path = None; | ||||
|     let mut r0_path = None; | ||||
|     let mut stable_path = None; | ||||
| 
 | ||||
|     for attr in input.attrs { | ||||
|         if !attr.path.is_ident("ruma_api") { | ||||
| @ -65,9 +62,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> { | ||||
|                 DeriveRequestMeta::Authentication(t) => authentication = Some(parse_quote!(#t)), | ||||
|                 DeriveRequestMeta::Method(t) => method = Some(parse_quote!(#t)), | ||||
|                 DeriveRequestMeta::ErrorTy(t) => error_ty = Some(t), | ||||
|                 DeriveRequestMeta::UnstablePath(s) => unstable_path = Some(s), | ||||
|                 DeriveRequestMeta::R0Path(s) => r0_path = Some(s), | ||||
|                 DeriveRequestMeta::StablePath(s) => stable_path = Some(s), | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| @ -79,9 +73,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> { | ||||
|         lifetimes, | ||||
|         authentication: authentication.expect("missing authentication attribute"), | ||||
|         method: method.expect("missing method attribute"), | ||||
|         unstable_path, | ||||
|         r0_path, | ||||
|         stable_path, | ||||
|         error_ty: error_ty.expect("missing error_ty attribute"), | ||||
|     }; | ||||
| 
 | ||||
| @ -105,9 +96,6 @@ struct Request { | ||||
| 
 | ||||
|     authentication: AuthScheme, | ||||
|     method: Ident, | ||||
|     unstable_path: Option<LitStr>, | ||||
|     r0_path: Option<LitStr>, | ||||
|     stable_path: Option<LitStr>, | ||||
|     error_ty: Type, | ||||
| } | ||||
| 
 | ||||
| @ -149,27 +137,8 @@ impl Request { | ||||
|         self.fields.iter().filter_map(RequestField::as_header_field) | ||||
|     } | ||||
| 
 | ||||
|     fn path_fields_ordered(&self) -> impl Iterator<Item = &Field> { | ||||
|         let map: BTreeMap<String, &Field> = self | ||||
|             .fields | ||||
|             .iter() | ||||
|             .filter_map(RequestField::as_path_field) | ||||
|             .map(|f| (f.ident.as_ref().unwrap().to_string(), f)) | ||||
|             .collect(); | ||||
| 
 | ||||
|         self.stable_path | ||||
|             .as_ref() | ||||
|             .or(self.r0_path.as_ref()) | ||||
|             .or(self.unstable_path.as_ref()) | ||||
|             .expect("one of the paths to be defined") | ||||
|             .value() | ||||
|             .split('/') | ||||
|             .filter_map(|s| { | ||||
|                 s.strip_prefix(':') | ||||
|                     .map(|s| *map.get(s).expect("path args have already been checked")) | ||||
|             }) | ||||
|             .collect::<Vec<_>>() | ||||
|             .into_iter() | ||||
|     fn path_fields(&self) -> impl Iterator<Item = &Field> { | ||||
|         self.fields.iter().filter_map(RequestField::as_path_field) | ||||
|     } | ||||
| 
 | ||||
|     fn raw_body_field(&self) -> Option<&Field> { | ||||
| @ -252,13 +221,6 @@ impl Request { | ||||
|     pub(super) fn check(&self) -> syn::Result<()> { | ||||
|         // TODO: highlight problematic fields
 | ||||
| 
 | ||||
|         let path_fields: Vec<_> = | ||||
|             self.fields.iter().filter_map(RequestField::as_path_field).collect(); | ||||
| 
 | ||||
|         self.check_path(&path_fields, self.unstable_path.as_ref())?; | ||||
|         self.check_path(&path_fields, self.r0_path.as_ref())?; | ||||
|         self.check_path(&path_fields, self.stable_path.as_ref())?; | ||||
| 
 | ||||
|         let newtype_body_fields = self.fields.iter().filter(|f| { | ||||
|             matches!(&f.kind, RequestFieldKind::NewtypeBody | RequestFieldKind::RawBody) | ||||
|         }); | ||||
| @ -322,58 +284,16 @@ impl Request { | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     fn check_path(&self, fields: &[&Field], path: Option<&LitStr>) -> syn::Result<()> { | ||||
|         let path = if let Some(lit) = path { lit } else { return Ok(()) }; | ||||
| 
 | ||||
|         let path_args: Vec<_> = path | ||||
|             .value() | ||||
|             .split('/') | ||||
|             .filter_map(|s| s.strip_prefix(':').map(str::to_string)) | ||||
|             .collect(); | ||||
| 
 | ||||
|         let field_map: BTreeMap<_, _> = | ||||
|             fields.iter().map(|&f| (f.ident.as_ref().unwrap().to_string(), f)).collect(); | ||||
| 
 | ||||
|         // test if all macro fields exist in the path
 | ||||
|         for (name, field) in field_map.iter() { | ||||
|             if !path_args.contains(name) { | ||||
|                 return Err({ | ||||
|                     let mut err = syn::Error::new_spanned( | ||||
|                         field, | ||||
|                         "this path argument field is not defined in...", | ||||
|                     ); | ||||
|                     err.combine(syn::Error::new_spanned(path, "...this path.")); | ||||
|                     err | ||||
|                 }); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // test if all path fields exists in macro fields
 | ||||
|         for arg in &path_args { | ||||
|             if !field_map.contains_key(arg) { | ||||
|                 return Err(syn::Error::new_spanned( | ||||
|                     path, | ||||
|                     format!( | ||||
|                         "a corresponding request path argument field for \"{}\" does not exist", | ||||
|                         arg | ||||
|                     ), | ||||
|                 )); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// A field of the request struct.
 | ||||
| struct RequestField { | ||||
|     inner: Field, | ||||
|     kind: RequestFieldKind, | ||||
| pub(super) struct RequestField { | ||||
|     pub(super) inner: Field, | ||||
|     pub(super) kind: RequestFieldKind, | ||||
| } | ||||
| 
 | ||||
| /// The kind of a request field.
 | ||||
| enum RequestFieldKind { | ||||
| pub(super) enum RequestFieldKind { | ||||
|     /// JSON data in the body of the request.
 | ||||
|     Body, | ||||
| 
 | ||||
|  | ||||
| @ -24,8 +24,7 @@ impl Request { | ||||
|         // except this one. If we get errors about missing fields in IncomingRequest for
 | ||||
|         // a path field look here.
 | ||||
|         let (parse_request_path, path_vars) = if self.has_path_fields() { | ||||
|             let path_vars: Vec<_> = | ||||
|                 self.path_fields_ordered().filter_map(|f| f.ident.as_ref()).collect(); | ||||
|             let path_vars: Vec<_> = self.path_fields().filter_map(|f| f.ident.as_ref()).collect(); | ||||
| 
 | ||||
|             let parse_request_path = quote! { | ||||
|                 let (#(#path_vars,)*) = #serde::Deserialize::deserialize( | ||||
|  | ||||
| @ -14,7 +14,7 @@ impl Request { | ||||
|         let error_ty = &self.error_ty; | ||||
| 
 | ||||
|         let path_fields = | ||||
|             self.path_fields_ordered().map(|f| f.ident.as_ref().expect("path fields have a name")); | ||||
|             self.path_fields().map(|f| f.ident.as_ref().expect("path fields have a name")); | ||||
| 
 | ||||
|         let request_query_string = if let Some(field) = self.query_map_field() { | ||||
|             let field_name = field.ident.as_ref().expect("expected field to have identifier"); | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user