api: Enforce consistent path field order
This commit is contained in:
parent
6ec01bfdb4
commit
764e96a254
@ -5,7 +5,6 @@ mod header_override;
|
||||
mod manual_endpoint_impl;
|
||||
mod no_fields;
|
||||
mod optional_headers;
|
||||
mod path_arg_ordering;
|
||||
mod ruma_api;
|
||||
mod ruma_api_lifetime;
|
||||
mod ruma_api_macros;
|
||||
|
@ -1,42 +0,0 @@
|
||||
use ruma_common::api::{ruma_api, IncomingRequest as _};
|
||||
|
||||
ruma_api! {
|
||||
metadata: {
|
||||
description: "Does something.",
|
||||
method: GET,
|
||||
name: "some_path_args",
|
||||
unstable_path: "/_matrix/:one/a/:two/b/:three/c",
|
||||
rate_limited: false,
|
||||
authentication: None,
|
||||
}
|
||||
|
||||
request: {
|
||||
#[ruma_api(path)]
|
||||
pub three: String,
|
||||
|
||||
#[ruma_api(path)]
|
||||
pub one: String,
|
||||
|
||||
#[ruma_api(path)]
|
||||
pub two: String,
|
||||
}
|
||||
|
||||
response: {}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn path_ordering_is_correct() {
|
||||
let request = http::Request::builder()
|
||||
.method("GET")
|
||||
// This explicitly puts wrong values in the URI, as now we rely on the side-supplied
|
||||
// path_args slice, so this is just to ensure it *is* using that slice.
|
||||
.uri("https://www.rust-lang.org/_matrix/non/a/non/b/non/c")
|
||||
.body("")
|
||||
.unwrap();
|
||||
|
||||
let resp = Request::try_from_http_request(request, &["1", "2", "3"]).unwrap();
|
||||
|
||||
assert_eq!(resp.one, "1");
|
||||
assert_eq!(resp.two, "2");
|
||||
assert_eq!(resp.three, "3");
|
||||
}
|
@ -12,7 +12,12 @@ use syn::{
|
||||
Attribute, Field, Token, Type,
|
||||
};
|
||||
|
||||
use self::{api_metadata::Metadata, api_request::Request, api_response::Response};
|
||||
use self::{
|
||||
api_metadata::Metadata,
|
||||
api_request::Request,
|
||||
api_response::Response,
|
||||
request::{RequestField, RequestFieldKind},
|
||||
};
|
||||
use crate::util::import_ruma_common;
|
||||
|
||||
mod api_metadata;
|
||||
@ -50,7 +55,8 @@ pub struct Api {
|
||||
|
||||
impl Api {
|
||||
pub fn expand_all(self) -> TokenStream {
|
||||
let maybe_error = ensure_feature_presence().map(syn::Error::to_compile_error);
|
||||
let maybe_feature_error = ensure_feature_presence().map(syn::Error::to_compile_error);
|
||||
let maybe_path_error = self.check_paths().err().map(syn::Error::into_compile_error);
|
||||
|
||||
let ruma_common = import_ruma_common();
|
||||
let http = quote! { #ruma_common::exports::http };
|
||||
@ -79,7 +85,8 @@ impl Api {
|
||||
let metadata_doc = format!("Metadata for the `{}` API endpoint.", name.value());
|
||||
|
||||
quote! {
|
||||
#maybe_error
|
||||
#maybe_feature_error
|
||||
#maybe_path_error
|
||||
|
||||
#[doc = #metadata_doc]
|
||||
pub const METADATA: #ruma_common::api::Metadata = #ruma_common::api::Metadata {
|
||||
@ -103,6 +110,54 @@ impl Api {
|
||||
type _SilenceUnusedError = #error_ty;
|
||||
}
|
||||
}
|
||||
|
||||
fn check_paths(&self) -> syn::Result<()> {
|
||||
let mut path_iter = self
|
||||
.metadata
|
||||
.unstable_path
|
||||
.iter()
|
||||
.chain(&self.metadata.r0_path)
|
||||
.chain(&self.metadata.stable_path);
|
||||
|
||||
let path = path_iter.next().ok_or_else(|| {
|
||||
syn::Error::new(Span::call_site(), "at least one path metadata field must be set")
|
||||
})?;
|
||||
let path_args = get_path_args(&path.value());
|
||||
|
||||
for extra_path in path_iter {
|
||||
let extra_path_args = get_path_args(&extra_path.value());
|
||||
if extra_path_args != path_args {
|
||||
return Err(syn::Error::new(
|
||||
Span::call_site(),
|
||||
"paths have different path parameters",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(req) = &self.request {
|
||||
let path_field_names: Vec<_> = req
|
||||
.fields
|
||||
.iter()
|
||||
.cloned()
|
||||
.filter_map(|f| match RequestField::try_from(f) {
|
||||
Ok(RequestField { kind: RequestFieldKind::Path, inner }) => {
|
||||
Some(Ok(inner.ident.unwrap().to_string()))
|
||||
}
|
||||
Ok(_) => None,
|
||||
Err(e) => Some(Err(e)),
|
||||
})
|
||||
.collect::<syn::Result<_>>()?;
|
||||
|
||||
if path_args != path_field_names {
|
||||
return Err(syn::Error::new_spanned(
|
||||
req.request_kw,
|
||||
"path fields must be in the same order as they appear in the path segments",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for Api {
|
||||
@ -215,3 +270,7 @@ fn ensure_feature_presence() -> Option<&'static syn::Error> {
|
||||
|
||||
RESULT.as_ref().err()
|
||||
}
|
||||
|
||||
fn get_path_args(path: &str) -> Vec<String> {
|
||||
path.split('/').filter_map(|s| s.strip_prefix(':').map(ToOwned::to_owned)).collect()
|
||||
}
|
||||
|
@ -84,9 +84,6 @@ impl Request {
|
||||
|
||||
let method = &metadata.method;
|
||||
let authentication = &metadata.authentication;
|
||||
let unstable_attr = metadata.unstable_path.as_ref().map(|p| quote! { unstable = #p, });
|
||||
let r0_attr = metadata.r0_path.as_ref().map(|p| quote! { r0 = #p, });
|
||||
let stable_attr = metadata.stable_path.as_ref().map(|p| quote! { stable = #p, });
|
||||
|
||||
let request_ident = Ident::new("Request", self.request_kw.span());
|
||||
let lifetimes = self.all_lifetimes();
|
||||
@ -107,9 +104,6 @@ impl Request {
|
||||
#[ruma_api(
|
||||
method = #method,
|
||||
authentication = #authentication,
|
||||
#unstable_attr
|
||||
#r0_attr
|
||||
#stable_attr
|
||||
error_ty = #error_ty,
|
||||
)]
|
||||
#( #struct_attributes )*
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
use syn::{
|
||||
parse::{Parse, ParseStream},
|
||||
Ident, LitStr, Token, Type,
|
||||
Ident, Token, Type,
|
||||
};
|
||||
|
||||
mod kw {
|
||||
@ -62,9 +62,6 @@ pub enum DeriveRequestMeta {
|
||||
Authentication(Type),
|
||||
Method(Type),
|
||||
ErrorTy(Type),
|
||||
UnstablePath(LitStr),
|
||||
R0Path(LitStr),
|
||||
StablePath(LitStr),
|
||||
}
|
||||
|
||||
impl Parse for DeriveRequestMeta {
|
||||
@ -82,18 +79,6 @@ impl Parse for DeriveRequestMeta {
|
||||
let _: kw::error_ty = input.parse()?;
|
||||
let _: Token![=] = input.parse()?;
|
||||
input.parse().map(Self::ErrorTy)
|
||||
} else if lookahead.peek(kw::unstable) {
|
||||
let _: kw::unstable = input.parse()?;
|
||||
let _: Token![=] = input.parse()?;
|
||||
input.parse().map(Self::UnstablePath)
|
||||
} else if lookahead.peek(kw::r0) {
|
||||
let _: kw::r0 = input.parse()?;
|
||||
let _: Token![=] = input.parse()?;
|
||||
input.parse().map(Self::R0Path)
|
||||
} else if lookahead.peek(kw::stable) {
|
||||
let _: kw::stable = input.parse()?;
|
||||
let _: Token![=] = input.parse()?;
|
||||
input.parse().map(Self::StablePath)
|
||||
} else {
|
||||
Err(lookahead.error())
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::{quote, ToTokens};
|
||||
@ -6,7 +6,7 @@ use syn::{
|
||||
parse::{Parse, ParseStream},
|
||||
parse_quote,
|
||||
punctuated::Punctuated,
|
||||
DeriveInput, Field, Generics, Ident, Lifetime, LitStr, Token, Type,
|
||||
DeriveInput, Field, Generics, Ident, Lifetime, Token, Type,
|
||||
};
|
||||
|
||||
use super::{
|
||||
@ -49,9 +49,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> {
|
||||
let mut authentication = None;
|
||||
let mut error_ty = None;
|
||||
let mut method = None;
|
||||
let mut unstable_path = None;
|
||||
let mut r0_path = None;
|
||||
let mut stable_path = None;
|
||||
|
||||
for attr in input.attrs {
|
||||
if !attr.path.is_ident("ruma_api") {
|
||||
@ -65,9 +62,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> {
|
||||
DeriveRequestMeta::Authentication(t) => authentication = Some(parse_quote!(#t)),
|
||||
DeriveRequestMeta::Method(t) => method = Some(parse_quote!(#t)),
|
||||
DeriveRequestMeta::ErrorTy(t) => error_ty = Some(t),
|
||||
DeriveRequestMeta::UnstablePath(s) => unstable_path = Some(s),
|
||||
DeriveRequestMeta::R0Path(s) => r0_path = Some(s),
|
||||
DeriveRequestMeta::StablePath(s) => stable_path = Some(s),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -79,9 +73,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> {
|
||||
lifetimes,
|
||||
authentication: authentication.expect("missing authentication attribute"),
|
||||
method: method.expect("missing method attribute"),
|
||||
unstable_path,
|
||||
r0_path,
|
||||
stable_path,
|
||||
error_ty: error_ty.expect("missing error_ty attribute"),
|
||||
};
|
||||
|
||||
@ -105,9 +96,6 @@ struct Request {
|
||||
|
||||
authentication: AuthScheme,
|
||||
method: Ident,
|
||||
unstable_path: Option<LitStr>,
|
||||
r0_path: Option<LitStr>,
|
||||
stable_path: Option<LitStr>,
|
||||
error_ty: Type,
|
||||
}
|
||||
|
||||
@ -149,27 +137,8 @@ impl Request {
|
||||
self.fields.iter().filter_map(RequestField::as_header_field)
|
||||
}
|
||||
|
||||
fn path_fields_ordered(&self) -> impl Iterator<Item = &Field> {
|
||||
let map: BTreeMap<String, &Field> = self
|
||||
.fields
|
||||
.iter()
|
||||
.filter_map(RequestField::as_path_field)
|
||||
.map(|f| (f.ident.as_ref().unwrap().to_string(), f))
|
||||
.collect();
|
||||
|
||||
self.stable_path
|
||||
.as_ref()
|
||||
.or(self.r0_path.as_ref())
|
||||
.or(self.unstable_path.as_ref())
|
||||
.expect("one of the paths to be defined")
|
||||
.value()
|
||||
.split('/')
|
||||
.filter_map(|s| {
|
||||
s.strip_prefix(':')
|
||||
.map(|s| *map.get(s).expect("path args have already been checked"))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
fn path_fields(&self) -> impl Iterator<Item = &Field> {
|
||||
self.fields.iter().filter_map(RequestField::as_path_field)
|
||||
}
|
||||
|
||||
fn raw_body_field(&self) -> Option<&Field> {
|
||||
@ -252,13 +221,6 @@ impl Request {
|
||||
pub(super) fn check(&self) -> syn::Result<()> {
|
||||
// TODO: highlight problematic fields
|
||||
|
||||
let path_fields: Vec<_> =
|
||||
self.fields.iter().filter_map(RequestField::as_path_field).collect();
|
||||
|
||||
self.check_path(&path_fields, self.unstable_path.as_ref())?;
|
||||
self.check_path(&path_fields, self.r0_path.as_ref())?;
|
||||
self.check_path(&path_fields, self.stable_path.as_ref())?;
|
||||
|
||||
let newtype_body_fields = self.fields.iter().filter(|f| {
|
||||
matches!(&f.kind, RequestFieldKind::NewtypeBody | RequestFieldKind::RawBody)
|
||||
});
|
||||
@ -322,58 +284,16 @@ impl Request {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_path(&self, fields: &[&Field], path: Option<&LitStr>) -> syn::Result<()> {
|
||||
let path = if let Some(lit) = path { lit } else { return Ok(()) };
|
||||
|
||||
let path_args: Vec<_> = path
|
||||
.value()
|
||||
.split('/')
|
||||
.filter_map(|s| s.strip_prefix(':').map(str::to_string))
|
||||
.collect();
|
||||
|
||||
let field_map: BTreeMap<_, _> =
|
||||
fields.iter().map(|&f| (f.ident.as_ref().unwrap().to_string(), f)).collect();
|
||||
|
||||
// test if all macro fields exist in the path
|
||||
for (name, field) in field_map.iter() {
|
||||
if !path_args.contains(name) {
|
||||
return Err({
|
||||
let mut err = syn::Error::new_spanned(
|
||||
field,
|
||||
"this path argument field is not defined in...",
|
||||
);
|
||||
err.combine(syn::Error::new_spanned(path, "...this path."));
|
||||
err
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// test if all path fields exists in macro fields
|
||||
for arg in &path_args {
|
||||
if !field_map.contains_key(arg) {
|
||||
return Err(syn::Error::new_spanned(
|
||||
path,
|
||||
format!(
|
||||
"a corresponding request path argument field for \"{}\" does not exist",
|
||||
arg
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A field of the request struct.
|
||||
struct RequestField {
|
||||
inner: Field,
|
||||
kind: RequestFieldKind,
|
||||
pub(super) struct RequestField {
|
||||
pub(super) inner: Field,
|
||||
pub(super) kind: RequestFieldKind,
|
||||
}
|
||||
|
||||
/// The kind of a request field.
|
||||
enum RequestFieldKind {
|
||||
pub(super) enum RequestFieldKind {
|
||||
/// JSON data in the body of the request.
|
||||
Body,
|
||||
|
||||
|
@ -24,8 +24,7 @@ impl Request {
|
||||
// except this one. If we get errors about missing fields in IncomingRequest for
|
||||
// a path field look here.
|
||||
let (parse_request_path, path_vars) = if self.has_path_fields() {
|
||||
let path_vars: Vec<_> =
|
||||
self.path_fields_ordered().filter_map(|f| f.ident.as_ref()).collect();
|
||||
let path_vars: Vec<_> = self.path_fields().filter_map(|f| f.ident.as_ref()).collect();
|
||||
|
||||
let parse_request_path = quote! {
|
||||
let (#(#path_vars,)*) = #serde::Deserialize::deserialize(
|
||||
|
@ -14,7 +14,7 @@ impl Request {
|
||||
let error_ty = &self.error_ty;
|
||||
|
||||
let path_fields =
|
||||
self.path_fields_ordered().map(|f| f.ident.as_ref().expect("path fields have a name"));
|
||||
self.path_fields().map(|f| f.ident.as_ref().expect("path fields have a name"));
|
||||
|
||||
let request_query_string = if let Some(field) = self.query_map_field() {
|
||||
let field_name = field.ident.as_ref().expect("expected field to have identifier");
|
||||
|
Loading…
x
Reference in New Issue
Block a user