api: Enforce consistent path field order

This commit is contained in:
Jonas Platte 2022-10-20 20:39:13 +02:00
parent 6ec01bfdb4
commit 764e96a254
No known key found for this signature in database
GPG Key ID: 7D261D771D915378
8 changed files with 73 additions and 159 deletions

View File

@ -5,7 +5,6 @@ mod header_override;
mod manual_endpoint_impl;
mod no_fields;
mod optional_headers;
mod path_arg_ordering;
mod ruma_api;
mod ruma_api_lifetime;
mod ruma_api_macros;

View File

@ -1,42 +0,0 @@
use ruma_common::api::{ruma_api, IncomingRequest as _};
ruma_api! {
metadata: {
description: "Does something.",
method: GET,
name: "some_path_args",
unstable_path: "/_matrix/:one/a/:two/b/:three/c",
rate_limited: false,
authentication: None,
}
request: {
#[ruma_api(path)]
pub three: String,
#[ruma_api(path)]
pub one: String,
#[ruma_api(path)]
pub two: String,
}
response: {}
}
#[test]
fn path_ordering_is_correct() {
let request = http::Request::builder()
.method("GET")
// This explicitly puts wrong values in the URI, as now we rely on the side-supplied
// path_args slice, so this is just to ensure it *is* using that slice.
.uri("https://www.rust-lang.org/_matrix/non/a/non/b/non/c")
.body("")
.unwrap();
let resp = Request::try_from_http_request(request, &["1", "2", "3"]).unwrap();
assert_eq!(resp.one, "1");
assert_eq!(resp.two, "2");
assert_eq!(resp.three, "3");
}

View File

@ -12,7 +12,12 @@ use syn::{
Attribute, Field, Token, Type,
};
use self::{api_metadata::Metadata, api_request::Request, api_response::Response};
use self::{
api_metadata::Metadata,
api_request::Request,
api_response::Response,
request::{RequestField, RequestFieldKind},
};
use crate::util::import_ruma_common;
mod api_metadata;
@ -50,7 +55,8 @@ pub struct Api {
impl Api {
pub fn expand_all(self) -> TokenStream {
let maybe_error = ensure_feature_presence().map(syn::Error::to_compile_error);
let maybe_feature_error = ensure_feature_presence().map(syn::Error::to_compile_error);
let maybe_path_error = self.check_paths().err().map(syn::Error::into_compile_error);
let ruma_common = import_ruma_common();
let http = quote! { #ruma_common::exports::http };
@ -79,7 +85,8 @@ impl Api {
let metadata_doc = format!("Metadata for the `{}` API endpoint.", name.value());
quote! {
#maybe_error
#maybe_feature_error
#maybe_path_error
#[doc = #metadata_doc]
pub const METADATA: #ruma_common::api::Metadata = #ruma_common::api::Metadata {
@ -103,6 +110,54 @@ impl Api {
type _SilenceUnusedError = #error_ty;
}
}
fn check_paths(&self) -> syn::Result<()> {
let mut path_iter = self
.metadata
.unstable_path
.iter()
.chain(&self.metadata.r0_path)
.chain(&self.metadata.stable_path);
let path = path_iter.next().ok_or_else(|| {
syn::Error::new(Span::call_site(), "at least one path metadata field must be set")
})?;
let path_args = get_path_args(&path.value());
for extra_path in path_iter {
let extra_path_args = get_path_args(&extra_path.value());
if extra_path_args != path_args {
return Err(syn::Error::new(
Span::call_site(),
"paths have different path parameters",
));
}
}
if let Some(req) = &self.request {
let path_field_names: Vec<_> = req
.fields
.iter()
.cloned()
.filter_map(|f| match RequestField::try_from(f) {
Ok(RequestField { kind: RequestFieldKind::Path, inner }) => {
Some(Ok(inner.ident.unwrap().to_string()))
}
Ok(_) => None,
Err(e) => Some(Err(e)),
})
.collect::<syn::Result<_>>()?;
if path_args != path_field_names {
return Err(syn::Error::new_spanned(
req.request_kw,
"path fields must be in the same order as they appear in the path segments",
));
}
}
Ok(())
}
}
impl Parse for Api {
@ -215,3 +270,7 @@ fn ensure_feature_presence() -> Option<&'static syn::Error> {
RESULT.as_ref().err()
}
fn get_path_args(path: &str) -> Vec<String> {
path.split('/').filter_map(|s| s.strip_prefix(':').map(ToOwned::to_owned)).collect()
}

View File

@ -84,9 +84,6 @@ impl Request {
let method = &metadata.method;
let authentication = &metadata.authentication;
let unstable_attr = metadata.unstable_path.as_ref().map(|p| quote! { unstable = #p, });
let r0_attr = metadata.r0_path.as_ref().map(|p| quote! { r0 = #p, });
let stable_attr = metadata.stable_path.as_ref().map(|p| quote! { stable = #p, });
let request_ident = Ident::new("Request", self.request_kw.span());
let lifetimes = self.all_lifetimes();
@ -107,9 +104,6 @@ impl Request {
#[ruma_api(
method = #method,
authentication = #authentication,
#unstable_attr
#r0_attr
#stable_attr
error_ty = #error_ty,
)]
#( #struct_attributes )*

View File

@ -2,7 +2,7 @@
use syn::{
parse::{Parse, ParseStream},
Ident, LitStr, Token, Type,
Ident, Token, Type,
};
mod kw {
@ -62,9 +62,6 @@ pub enum DeriveRequestMeta {
Authentication(Type),
Method(Type),
ErrorTy(Type),
UnstablePath(LitStr),
R0Path(LitStr),
StablePath(LitStr),
}
impl Parse for DeriveRequestMeta {
@ -82,18 +79,6 @@ impl Parse for DeriveRequestMeta {
let _: kw::error_ty = input.parse()?;
let _: Token![=] = input.parse()?;
input.parse().map(Self::ErrorTy)
} else if lookahead.peek(kw::unstable) {
let _: kw::unstable = input.parse()?;
let _: Token![=] = input.parse()?;
input.parse().map(Self::UnstablePath)
} else if lookahead.peek(kw::r0) {
let _: kw::r0 = input.parse()?;
let _: Token![=] = input.parse()?;
input.parse().map(Self::R0Path)
} else if lookahead.peek(kw::stable) {
let _: kw::stable = input.parse()?;
let _: Token![=] = input.parse()?;
input.parse().map(Self::StablePath)
} else {
Err(lookahead.error())
}

View File

@ -1,4 +1,4 @@
use std::collections::{BTreeMap, BTreeSet};
use std::collections::BTreeSet;
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
@ -6,7 +6,7 @@ use syn::{
parse::{Parse, ParseStream},
parse_quote,
punctuated::Punctuated,
DeriveInput, Field, Generics, Ident, Lifetime, LitStr, Token, Type,
DeriveInput, Field, Generics, Ident, Lifetime, Token, Type,
};
use super::{
@ -49,9 +49,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> {
let mut authentication = None;
let mut error_ty = None;
let mut method = None;
let mut unstable_path = None;
let mut r0_path = None;
let mut stable_path = None;
for attr in input.attrs {
if !attr.path.is_ident("ruma_api") {
@ -65,9 +62,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> {
DeriveRequestMeta::Authentication(t) => authentication = Some(parse_quote!(#t)),
DeriveRequestMeta::Method(t) => method = Some(parse_quote!(#t)),
DeriveRequestMeta::ErrorTy(t) => error_ty = Some(t),
DeriveRequestMeta::UnstablePath(s) => unstable_path = Some(s),
DeriveRequestMeta::R0Path(s) => r0_path = Some(s),
DeriveRequestMeta::StablePath(s) => stable_path = Some(s),
}
}
}
@ -79,9 +73,6 @@ pub fn expand_derive_request(input: DeriveInput) -> syn::Result<TokenStream> {
lifetimes,
authentication: authentication.expect("missing authentication attribute"),
method: method.expect("missing method attribute"),
unstable_path,
r0_path,
stable_path,
error_ty: error_ty.expect("missing error_ty attribute"),
};
@ -105,9 +96,6 @@ struct Request {
authentication: AuthScheme,
method: Ident,
unstable_path: Option<LitStr>,
r0_path: Option<LitStr>,
stable_path: Option<LitStr>,
error_ty: Type,
}
@ -149,27 +137,8 @@ impl Request {
self.fields.iter().filter_map(RequestField::as_header_field)
}
fn path_fields_ordered(&self) -> impl Iterator<Item = &Field> {
let map: BTreeMap<String, &Field> = self
.fields
.iter()
.filter_map(RequestField::as_path_field)
.map(|f| (f.ident.as_ref().unwrap().to_string(), f))
.collect();
self.stable_path
.as_ref()
.or(self.r0_path.as_ref())
.or(self.unstable_path.as_ref())
.expect("one of the paths to be defined")
.value()
.split('/')
.filter_map(|s| {
s.strip_prefix(':')
.map(|s| *map.get(s).expect("path args have already been checked"))
})
.collect::<Vec<_>>()
.into_iter()
fn path_fields(&self) -> impl Iterator<Item = &Field> {
self.fields.iter().filter_map(RequestField::as_path_field)
}
fn raw_body_field(&self) -> Option<&Field> {
@ -252,13 +221,6 @@ impl Request {
pub(super) fn check(&self) -> syn::Result<()> {
// TODO: highlight problematic fields
let path_fields: Vec<_> =
self.fields.iter().filter_map(RequestField::as_path_field).collect();
self.check_path(&path_fields, self.unstable_path.as_ref())?;
self.check_path(&path_fields, self.r0_path.as_ref())?;
self.check_path(&path_fields, self.stable_path.as_ref())?;
let newtype_body_fields = self.fields.iter().filter(|f| {
matches!(&f.kind, RequestFieldKind::NewtypeBody | RequestFieldKind::RawBody)
});
@ -322,58 +284,16 @@ impl Request {
Ok(())
}
fn check_path(&self, fields: &[&Field], path: Option<&LitStr>) -> syn::Result<()> {
let path = if let Some(lit) = path { lit } else { return Ok(()) };
let path_args: Vec<_> = path
.value()
.split('/')
.filter_map(|s| s.strip_prefix(':').map(str::to_string))
.collect();
let field_map: BTreeMap<_, _> =
fields.iter().map(|&f| (f.ident.as_ref().unwrap().to_string(), f)).collect();
// test if all macro fields exist in the path
for (name, field) in field_map.iter() {
if !path_args.contains(name) {
return Err({
let mut err = syn::Error::new_spanned(
field,
"this path argument field is not defined in...",
);
err.combine(syn::Error::new_spanned(path, "...this path."));
err
});
}
}
// test if all path fields exists in macro fields
for arg in &path_args {
if !field_map.contains_key(arg) {
return Err(syn::Error::new_spanned(
path,
format!(
"a corresponding request path argument field for \"{}\" does not exist",
arg
),
));
}
}
Ok(())
}
}
/// A field of the request struct.
struct RequestField {
inner: Field,
kind: RequestFieldKind,
pub(super) struct RequestField {
pub(super) inner: Field,
pub(super) kind: RequestFieldKind,
}
/// The kind of a request field.
enum RequestFieldKind {
pub(super) enum RequestFieldKind {
/// JSON data in the body of the request.
Body,

View File

@ -24,8 +24,7 @@ impl Request {
// except this one. If we get errors about missing fields in IncomingRequest for
// a path field look here.
let (parse_request_path, path_vars) = if self.has_path_fields() {
let path_vars: Vec<_> =
self.path_fields_ordered().filter_map(|f| f.ident.as_ref()).collect();
let path_vars: Vec<_> = self.path_fields().filter_map(|f| f.ident.as_ref()).collect();
let parse_request_path = quote! {
let (#(#path_vars,)*) = #serde::Deserialize::deserialize(

View File

@ -14,7 +14,7 @@ impl Request {
let error_ty = &self.error_ty;
let path_fields =
self.path_fields_ordered().map(|f| f.ident.as_ref().expect("path fields have a name"));
self.path_fields().map(|f| f.ident.as_ref().expect("path fields have a name"));
let request_query_string = if let Some(field) = self.query_map_field() {
let field_name = field.ident.as_ref().expect("expected field to have identifier");