api: Move the majority of endpoint URL building out of macro code

This commit is contained in:
Jonas Platte 2022-09-29 12:40:14 +02:00 committed by Jonas Platte
parent a6e23d731e
commit 8290d712f2
9 changed files with 156 additions and 234 deletions

View File

@ -85,25 +85,13 @@ pub mod v3 {
use http::header; use http::header;
use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
let room_id_percent = utf8_percent_encode(self.room_id.as_str(), NON_ALPHANUMERIC); let mut url = ruma_common::api::make_endpoint_url(
let event_type = self.event_type.to_string();
let event_type_percent = utf8_percent_encode(&event_type, NON_ALPHANUMERIC);
let mut url = format!(
"{}{}",
base_url.strip_suffix('/').unwrap_or(base_url),
ruma_common::api::select_path(
considering_versions,
&METADATA, &METADATA,
considering_versions,
base_url,
&[&self.room_id, &self.event_type],
None, None,
Some(format_args!( )?;
"/_matrix/client/r0/rooms/{room_id_percent}/state/{event_type_percent}",
)),
Some(format_args!(
"/_matrix/client/v3/rooms/{room_id_percent}/state/{event_type_percent}",
)),
)?
);
if !self.state_key.is_empty() { if !self.state_key.is_empty() {
url.push('/'); url.push('/');

View File

@ -127,25 +127,13 @@ pub mod v3 {
use http::header::{self, HeaderValue}; use http::header::{self, HeaderValue};
use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
let room_id_percent = utf8_percent_encode(self.room_id.as_str(), NON_ALPHANUMERIC); let mut url = ruma_common::api::make_endpoint_url(
let event_type = self.event_type.to_string();
let event_type_percent = utf8_percent_encode(&event_type, NON_ALPHANUMERIC);
let mut url = format!(
"{}{}",
base_url.strip_suffix('/').unwrap_or(base_url),
ruma_common::api::select_path(
considering_versions,
&METADATA, &METADATA,
considering_versions,
base_url,
&[&self.room_id, &self.event_type],
None, None,
Some(format_args!( )?;
"/_matrix/client/r0/rooms/{room_id_percent}/state/{event_type_percent}",
)),
Some(format_args!(
"/_matrix/client/v3/rooms/{room_id_percent}/state/{event_type_percent}",
)),
)?
);
// Last URL segment is optional, that is why this trait impl is not generated. // Last URL segment is optional, that is why this trait impl is not generated.
if !self.state_key.is_empty() { if !self.state_key.is_empty() {

View File

@ -12,9 +12,14 @@
//! //!
//! [apis]: https://spec.matrix.org/v1.2/#matrix-apis //! [apis]: https://spec.matrix.org/v1.2/#matrix-apis
use std::{convert::TryInto as _, error::Error as StdError, fmt}; use std::{
convert::TryInto as _,
error::Error as StdError,
fmt::{Display, Write},
};
use bytes::BufMut; use bytes::BufMut;
use percent_encoding::utf8_percent_encode;
use tracing::warn; use tracing::warn;
use crate::UserId; use crate::UserId;
@ -402,28 +407,50 @@ pub enum AuthScheme {
// This function needs to be public, yet hidden, as it is used the code generated by `ruma_api!`. // This function needs to be public, yet hidden, as it is used the code generated by `ruma_api!`.
#[doc(hidden)] #[doc(hidden)]
pub fn make_endpoint_url( pub fn make_endpoint_url(
metadata: &Metadata,
versions: &[MatrixVersion],
base_url: &str, base_url: &str,
path: fmt::Arguments<'_>, path_args: &[&dyn Display],
query_string: Option<&str>, query_string: Option<&str>,
) -> String { ) -> Result<String, IntoHttpError> {
let base = base_url.strip_suffix('/').unwrap_or(base_url); let path_with_placeholders = select_path(metadata, versions)?;
match query_string {
Some(query) => format!("{base}{path}?{query}"), let mut res = base_url.strip_suffix('/').unwrap_or(base_url).to_owned();
None => format!("{base}{path}"), let mut segments = path_with_placeholders.split('/');
let mut path_args = path_args.iter();
let first_segment = segments.next().expect("split iterator is never empty");
assert!(first_segment.is_empty(), "endpoint paths must start with '/'");
for segment in segments {
if segment.starts_with(':') {
let arg = path_args
.next()
.expect("number of placeholders must match number of arguments")
.to_string();
let arg = utf8_percent_encode(&arg, percent_encoding::NON_ALPHANUMERIC);
write!(res, "/{arg}").expect("writing to a String using fmt::Write can't fail");
} else {
res.reserve(segment.len() + 1);
res.push('/');
res.push_str(segment);
} }
}
if let Some(query) = query_string {
res.push('?');
res.push_str(query);
}
Ok(res)
} }
// This function helps picks the right path (or an error) from a set of matrix versions. // This function helps picks the right path (or an error) from a set of matrix versions.
// fn select_path<'a>(
// This function needs to be public, yet hidden, as it is used the code generated by `ruma_api!`. metadata: &'a Metadata,
#[doc(hidden)] versions: &[MatrixVersion],
pub fn select_path<'a>( ) -> Result<&'a str, IntoHttpError> {
versions: &'_ [MatrixVersion],
metadata: &'_ Metadata,
unstable: Option<fmt::Arguments<'a>>,
r0: Option<fmt::Arguments<'a>>,
stable: Option<fmt::Arguments<'a>>,
) -> Result<fmt::Arguments<'a>, IntoHttpError> {
match metadata.versioning_decision_for(versions) { match metadata.versioning_decision_for(versions) {
VersioningDecision::Removed => Err(IntoHttpError::EndpointRemoved( VersioningDecision::Removed => Err(IntoHttpError::EndpointRemoved(
metadata.removed.expect("VersioningDecision::Removed implies metadata.removed"), metadata.removed.expect("VersioningDecision::Removed implies metadata.removed"),
@ -457,15 +484,95 @@ pub fn select_path<'a>(
); );
} }
if let Some(r0) = r0 { if let Some(r0) = metadata.r0_path {
if versions.iter().all(|&v| v == MatrixVersion::V1_0) { if versions.iter().all(|&v| v == MatrixVersion::V1_0) {
// Endpoint was added in 1.0, we return the r0 variant. // Endpoint was added in 1.0, we return the r0 variant.
return Ok(r0); return Ok(r0);
} }
} }
Ok(stable.expect("metadata.added enforces the stable path to exist")) Ok(metadata.stable_path.expect("metadata.added enforces the stable path to exist"))
} }
VersioningDecision::Unstable => unstable.ok_or(IntoHttpError::NoUnstablePath), VersioningDecision::Unstable => metadata.unstable_path.ok_or(IntoHttpError::NoUnstablePath),
}
}
#[cfg(test)]
mod tests {
use super::{
error::IntoHttpError,
select_path, AuthScheme,
MatrixVersion::{V1_0, V1_1, V1_2},
Metadata,
};
use assert_matches::assert_matches;
use http::Method;
const BASE: Metadata = Metadata {
description: "",
method: Method::GET,
name: "test_endpoint",
unstable_path: None,
r0_path: None,
stable_path: None,
rate_limited: false,
authentication: AuthScheme::None,
added: None,
deprecated: None,
removed: None,
};
// TODO add test that can hook into tracing and verify the deprecation warning is emitted
#[test]
fn select_stable() {
let meta = Metadata { added: Some(V1_1), stable_path: Some("s"), ..BASE };
assert_matches!(select_path(&meta, &[V1_0, V1_1]), Ok("s"));
}
#[test]
fn select_unstable() {
let meta = Metadata { unstable_path: Some("u"), ..BASE };
assert_matches!(select_path(&meta, &[V1_0]), Ok("u"));
}
#[test]
fn select_r0() {
let meta = Metadata { added: Some(V1_0), r0_path: Some("r"), ..BASE };
assert_matches!(select_path(&meta, &[V1_0]), Ok("r"));
}
#[test]
fn select_removed_err() {
let meta = Metadata {
added: Some(V1_0),
deprecated: Some(V1_1),
removed: Some(V1_2),
unstable_path: Some("u"),
r0_path: Some("r"),
stable_path: Some("s"),
..BASE
};
assert_matches!(select_path(&meta, &[V1_2]), Err(IntoHttpError::EndpointRemoved(V1_2)));
}
#[test]
fn partially_removed_but_stable() {
let meta = Metadata {
added: Some(V1_0),
deprecated: Some(V1_1),
removed: Some(V1_2),
r0_path: Some("r"),
stable_path: Some("s"),
..BASE
};
assert_matches!(select_path(&meta, &[V1_1]), Ok("s"));
}
#[test]
fn no_unstable() {
let meta =
Metadata { added: Some(V1_1), r0_path: Some("r"), stable_path: Some("s"), ..BASE };
assert_matches!(select_path(&meta, &[V1_0]), Err(IntoHttpError::NoUnstablePath));
} }
} }

View File

@ -72,7 +72,6 @@ pub mod exports {
pub use bytes; pub use bytes;
#[cfg(feature = "api")] #[cfg(feature = "api")]
pub use http; pub use http;
pub use percent_encoding;
pub use ruma_macros; pub use ruma_macros;
pub use serde; pub use serde;
pub use serde_json; pub use serde_json;

View File

@ -49,17 +49,13 @@ impl OutgoingRequest for Request {
_access_token: SendAccessToken<'_>, _access_token: SendAccessToken<'_>,
considering_versions: &'_ [MatrixVersion], considering_versions: &'_ [MatrixVersion],
) -> Result<http::Request<T>, IntoHttpError> { ) -> Result<http::Request<T>, IntoHttpError> {
let url = format!( let url = ruma_common::api::make_endpoint_url(
"{}{}",
base_url,
ruma_common::api::select_path(
considering_versions,
&METADATA, &METADATA,
Some(format_args!("/_matrix/client/unstable/directory/room/{}", self.room_alias)), considering_versions,
Some(format_args!("/_matrix/client/r0/directory/room/{}", self.room_alias)), base_url,
Some(format_args!("/_matrix/client/v3/directory/room/{}", self.room_alias)), &[&self.room_alias],
)? None,
); )?;
let request_body = RequestBody { room_id: self.room_id }; let request_body = RequestBody { room_id: self.room_id };

View File

@ -9,4 +9,3 @@ mod path_arg_ordering;
mod ruma_api; mod ruma_api;
mod ruma_api_lifetime; mod ruma_api_lifetime;
mod ruma_api_macros; mod ruma_api_macros;
mod select_path;

View File

@ -1,100 +0,0 @@
use assert_matches::assert_matches;
use http::Method;
use ruma_common::api::{
error::IntoHttpError,
select_path,
MatrixVersion::{V1_0, V1_1, V1_2},
Metadata,
};
const BASE: Metadata = Metadata {
description: "",
method: Method::GET,
name: "test_endpoint",
unstable_path: Some("/unstable/path"),
r0_path: Some("/r0/path"),
stable_path: Some("/stable/path"),
rate_limited: false,
authentication: ruma_common::api::AuthScheme::None,
added: None,
deprecated: None,
removed: None,
};
const U: &str = "u";
const S: &str = "s";
const R: &str = "r";
// TODO add test that can hook into tracing and verify the deprecation warning is emitted
#[test]
fn select_stable() {
let meta = Metadata { added: Some(V1_1), ..BASE };
let res = select_path(&[V1_0, V1_1], &meta, None, None, Some(format_args!("{S}")))
.unwrap()
.to_string();
assert_eq!(res, S);
}
#[test]
fn select_unstable() {
let meta = BASE;
let res =
select_path(&[V1_0], &meta, Some(format_args!("{U}")), None, None).unwrap().to_string();
assert_eq!(res, U);
}
#[test]
fn select_r0() {
let meta = Metadata { added: Some(V1_0), ..BASE };
let res =
select_path(&[V1_0], &meta, None, Some(format_args!("{R}")), Some(format_args!("{S}")))
.unwrap()
.to_string();
assert_eq!(res, R);
}
#[test]
fn select_removed_err() {
let meta = Metadata { added: Some(V1_0), deprecated: Some(V1_1), removed: Some(V1_2), ..BASE };
let res = select_path(
&[V1_2],
&meta,
Some(format_args!("{U}")),
Some(format_args!("{R}")),
Some(format_args!("{S}")),
)
.unwrap_err();
assert_matches!(res, IntoHttpError::EndpointRemoved(V1_2));
}
#[test]
fn partially_removed_but_stable() {
let meta = Metadata { added: Some(V1_0), deprecated: Some(V1_1), removed: Some(V1_2), ..BASE };
let res =
select_path(&[V1_1], &meta, None, Some(format_args!("{R}")), Some(format_args!("{S}")))
.unwrap()
.to_string();
assert_eq!(res, S);
}
#[test]
fn no_unstable() {
let meta = Metadata { added: Some(V1_1), ..BASE };
let res =
select_path(&[V1_0], &meta, None, Some(format_args!("{R}")), Some(format_args!("{S}")))
.unwrap_err();
assert_matches!(res, IntoHttpError::NoUnstablePath);
}

View File

@ -1,40 +1,20 @@
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::quote; use quote::quote;
use syn::{Field, LitStr}; use syn::Field;
use super::{Request, RequestField}; use super::{Request, RequestField};
use crate::api::{auth_scheme::AuthScheme, util}; use crate::api::auth_scheme::AuthScheme;
impl Request { impl Request {
pub fn expand_outgoing(&self, ruma_common: &TokenStream) -> TokenStream { pub fn expand_outgoing(&self, ruma_common: &TokenStream) -> TokenStream {
let bytes = quote! { #ruma_common::exports::bytes }; let bytes = quote! { #ruma_common::exports::bytes };
let http = quote! { #ruma_common::exports::http }; let http = quote! { #ruma_common::exports::http };
let percent_encoding = quote! { #ruma_common::exports::percent_encoding };
let method = &self.method; let method = &self.method;
let error_ty = &self.error_ty; let error_ty = &self.error_ty;
let (unstable_path, r0_path, stable_path) = if self.has_path_fields() { let path_fields =
let path_format_args_call_with_percent_encoding = |s: &LitStr| -> TokenStream { self.path_fields_ordered().map(|f| f.ident.as_ref().expect("path fields have a name"));
util::path_format_args_call(s.value(), &percent_encoding)
};
(
self.unstable_path.as_ref().map(path_format_args_call_with_percent_encoding),
self.r0_path.as_ref().map(path_format_args_call_with_percent_encoding),
self.stable_path.as_ref().map(path_format_args_call_with_percent_encoding),
)
} else {
(
self.unstable_path.as_ref().map(|path| quote! { format_args!(#path) }),
self.r0_path.as_ref().map(|path| quote! { format_args!(#path) }),
self.stable_path.as_ref().map(|path| quote! { format_args!(#path) }),
)
};
let unstable_path = util::map_option_literal(&unstable_path);
let r0_path = util::map_option_literal(&r0_path);
let stable_path = util::map_option_literal(&stable_path);
let request_query_string = if let Some(field) = self.query_map_field() { let request_query_string = if let Some(field) = self.query_map_field() {
let field_name = field.ident.as_ref().expect("expected field to have identifier"); let field_name = field.ident.as_ref().expect("expected field to have identifier");
@ -194,16 +174,12 @@ impl Request {
let mut req_builder = #http::Request::builder() let mut req_builder = #http::Request::builder()
.method(#http::Method::#method) .method(#http::Method::#method)
.uri(#ruma_common::api::make_endpoint_url( .uri(#ruma_common::api::make_endpoint_url(
base_url,
#ruma_common::api::select_path(
considering_versions,
&metadata, &metadata,
#unstable_path, considering_versions,
#r0_path, base_url,
#stable_path, &[ #( &self.#path_fields ),* ],
)?,
#request_query_string, #request_query_string,
)); )?);
if let Some(mut req_headers) = req_builder.headers_mut() { if let Some(mut req_headers) = req_builder.headers_mut() {
#header_kvs #header_kvs

View File

@ -2,7 +2,7 @@
use std::collections::BTreeSet; use std::collections::BTreeSet;
use proc_macro2::{Ident, Span, TokenStream}; use proc_macro2::TokenStream;
use quote::{quote, ToTokens}; use quote::{quote, ToTokens};
use syn::{parse_quote, visit::Visit, Attribute, Lifetime, NestedMeta, Type}; use syn::{parse_quote, visit::Visit, Attribute, Lifetime, NestedMeta, Type};
@ -54,34 +54,3 @@ pub fn extract_cfg(attr: &Attribute) -> Option<NestedMeta> {
Some(list.nested.pop().unwrap().into_value()) Some(list.nested.pop().unwrap().into_value())
} }
pub fn path_format_args_call(
mut format_string: String,
percent_encoding: &TokenStream,
) -> TokenStream {
let mut format_args = Vec::new();
while let Some(start_of_segment) = format_string.find(':') {
// ':' should only ever appear at the start of a segment
assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/");
let end_of_segment = match format_string[start_of_segment..].find('/') {
Some(rel_pos) => start_of_segment + rel_pos,
None => format_string.len(),
};
let path_var =
Ident::new(&format_string[start_of_segment + 1..end_of_segment], Span::call_site());
format_args.push(quote! {
#percent_encoding::utf8_percent_encode(
&::std::string::ToString::to_string(&self.#path_var),
#percent_encoding::NON_ALPHANUMERIC,
)
});
format_string.replace_range(start_of_segment..end_of_segment, "{}");
}
quote! {
format_args!(#format_string, #(#format_args),*)
}
}