diff --git a/crates/ruma-client-api/src/state/get_state_events_for_key.rs b/crates/ruma-client-api/src/state/get_state_events_for_key.rs index 25b9249a..1ca51e87 100644 --- a/crates/ruma-client-api/src/state/get_state_events_for_key.rs +++ b/crates/ruma-client-api/src/state/get_state_events_for_key.rs @@ -85,25 +85,13 @@ pub mod v3 { use http::header; use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; - let room_id_percent = utf8_percent_encode(self.room_id.as_str(), NON_ALPHANUMERIC); - let event_type = self.event_type.to_string(); - let event_type_percent = utf8_percent_encode(&event_type, NON_ALPHANUMERIC); - - let mut url = format!( - "{}{}", - base_url.strip_suffix('/').unwrap_or(base_url), - ruma_common::api::select_path( - considering_versions, - &METADATA, - None, - Some(format_args!( - "/_matrix/client/r0/rooms/{room_id_percent}/state/{event_type_percent}", - )), - Some(format_args!( - "/_matrix/client/v3/rooms/{room_id_percent}/state/{event_type_percent}", - )), - )? - ); + let mut url = ruma_common::api::make_endpoint_url( + &METADATA, + considering_versions, + base_url, + &[&self.room_id, &self.event_type], + None, + )?; if !self.state_key.is_empty() { url.push('/'); diff --git a/crates/ruma-client-api/src/state/send_state_event.rs b/crates/ruma-client-api/src/state/send_state_event.rs index f43f9f15..e93e445b 100644 --- a/crates/ruma-client-api/src/state/send_state_event.rs +++ b/crates/ruma-client-api/src/state/send_state_event.rs @@ -127,25 +127,13 @@ pub mod v3 { use http::header::{self, HeaderValue}; use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; - let room_id_percent = utf8_percent_encode(self.room_id.as_str(), NON_ALPHANUMERIC); - let event_type = self.event_type.to_string(); - let event_type_percent = utf8_percent_encode(&event_type, NON_ALPHANUMERIC); - - let mut url = format!( - "{}{}", - base_url.strip_suffix('/').unwrap_or(base_url), - ruma_common::api::select_path( - considering_versions, - &METADATA, - None, - Some(format_args!( - "/_matrix/client/r0/rooms/{room_id_percent}/state/{event_type_percent}", - )), - Some(format_args!( - "/_matrix/client/v3/rooms/{room_id_percent}/state/{event_type_percent}", - )), - )? - ); + let mut url = ruma_common::api::make_endpoint_url( + &METADATA, + considering_versions, + base_url, + &[&self.room_id, &self.event_type], + None, + )?; // Last URL segment is optional, that is why this trait impl is not generated. if !self.state_key.is_empty() { diff --git a/crates/ruma-common/src/api.rs b/crates/ruma-common/src/api.rs index 5b3cb229..b56fea70 100644 --- a/crates/ruma-common/src/api.rs +++ b/crates/ruma-common/src/api.rs @@ -12,9 +12,14 @@ //! //! [apis]: https://spec.matrix.org/v1.2/#matrix-apis -use std::{convert::TryInto as _, error::Error as StdError, fmt}; +use std::{ + convert::TryInto as _, + error::Error as StdError, + fmt::{Display, Write}, +}; use bytes::BufMut; +use percent_encoding::utf8_percent_encode; use tracing::warn; use crate::UserId; @@ -402,28 +407,50 @@ pub enum AuthScheme { // This function needs to be public, yet hidden, as it is used the code generated by `ruma_api!`. #[doc(hidden)] pub fn make_endpoint_url( + metadata: &Metadata, + versions: &[MatrixVersion], base_url: &str, - path: fmt::Arguments<'_>, + path_args: &[&dyn Display], query_string: Option<&str>, -) -> String { - let base = base_url.strip_suffix('/').unwrap_or(base_url); - match query_string { - Some(query) => format!("{base}{path}?{query}"), - None => format!("{base}{path}"), +) -> Result { + let path_with_placeholders = select_path(metadata, versions)?; + + let mut res = base_url.strip_suffix('/').unwrap_or(base_url).to_owned(); + let mut segments = path_with_placeholders.split('/'); + let mut path_args = path_args.iter(); + + let first_segment = segments.next().expect("split iterator is never empty"); + assert!(first_segment.is_empty(), "endpoint paths must start with '/'"); + + for segment in segments { + if segment.starts_with(':') { + let arg = path_args + .next() + .expect("number of placeholders must match number of arguments") + .to_string(); + let arg = utf8_percent_encode(&arg, percent_encoding::NON_ALPHANUMERIC); + + write!(res, "/{arg}").expect("writing to a String using fmt::Write can't fail"); + } else { + res.reserve(segment.len() + 1); + res.push('/'); + res.push_str(segment); + } } + + if let Some(query) = query_string { + res.push('?'); + res.push_str(query); + } + + Ok(res) } // This function helps picks the right path (or an error) from a set of matrix versions. -// -// This function needs to be public, yet hidden, as it is used the code generated by `ruma_api!`. -#[doc(hidden)] -pub fn select_path<'a>( - versions: &'_ [MatrixVersion], - metadata: &'_ Metadata, - unstable: Option>, - r0: Option>, - stable: Option>, -) -> Result, IntoHttpError> { +fn select_path<'a>( + metadata: &'a Metadata, + versions: &[MatrixVersion], +) -> Result<&'a str, IntoHttpError> { match metadata.versioning_decision_for(versions) { VersioningDecision::Removed => Err(IntoHttpError::EndpointRemoved( metadata.removed.expect("VersioningDecision::Removed implies metadata.removed"), @@ -457,15 +484,95 @@ pub fn select_path<'a>( ); } - if let Some(r0) = r0 { + if let Some(r0) = metadata.r0_path { if versions.iter().all(|&v| v == MatrixVersion::V1_0) { // Endpoint was added in 1.0, we return the r0 variant. return Ok(r0); } } - Ok(stable.expect("metadata.added enforces the stable path to exist")) + Ok(metadata.stable_path.expect("metadata.added enforces the stable path to exist")) } - VersioningDecision::Unstable => unstable.ok_or(IntoHttpError::NoUnstablePath), + VersioningDecision::Unstable => metadata.unstable_path.ok_or(IntoHttpError::NoUnstablePath), + } +} + +#[cfg(test)] +mod tests { + use super::{ + error::IntoHttpError, + select_path, AuthScheme, + MatrixVersion::{V1_0, V1_1, V1_2}, + Metadata, + }; + use assert_matches::assert_matches; + use http::Method; + + const BASE: Metadata = Metadata { + description: "", + method: Method::GET, + name: "test_endpoint", + unstable_path: None, + r0_path: None, + stable_path: None, + rate_limited: false, + authentication: AuthScheme::None, + added: None, + deprecated: None, + removed: None, + }; + + // TODO add test that can hook into tracing and verify the deprecation warning is emitted + + #[test] + fn select_stable() { + let meta = Metadata { added: Some(V1_1), stable_path: Some("s"), ..BASE }; + assert_matches!(select_path(&meta, &[V1_0, V1_1]), Ok("s")); + } + + #[test] + fn select_unstable() { + let meta = Metadata { unstable_path: Some("u"), ..BASE }; + assert_matches!(select_path(&meta, &[V1_0]), Ok("u")); + } + + #[test] + fn select_r0() { + let meta = Metadata { added: Some(V1_0), r0_path: Some("r"), ..BASE }; + assert_matches!(select_path(&meta, &[V1_0]), Ok("r")); + } + + #[test] + fn select_removed_err() { + let meta = Metadata { + added: Some(V1_0), + deprecated: Some(V1_1), + removed: Some(V1_2), + unstable_path: Some("u"), + r0_path: Some("r"), + stable_path: Some("s"), + ..BASE + }; + assert_matches!(select_path(&meta, &[V1_2]), Err(IntoHttpError::EndpointRemoved(V1_2))); + } + + #[test] + fn partially_removed_but_stable() { + let meta = Metadata { + added: Some(V1_0), + deprecated: Some(V1_1), + removed: Some(V1_2), + r0_path: Some("r"), + stable_path: Some("s"), + ..BASE + }; + assert_matches!(select_path(&meta, &[V1_1]), Ok("s")); + } + + #[test] + fn no_unstable() { + let meta = + Metadata { added: Some(V1_1), r0_path: Some("r"), stable_path: Some("s"), ..BASE }; + assert_matches!(select_path(&meta, &[V1_0]), Err(IntoHttpError::NoUnstablePath)); } } diff --git a/crates/ruma-common/src/lib.rs b/crates/ruma-common/src/lib.rs index 8c9f23a2..4ef84a42 100644 --- a/crates/ruma-common/src/lib.rs +++ b/crates/ruma-common/src/lib.rs @@ -72,7 +72,6 @@ pub mod exports { pub use bytes; #[cfg(feature = "api")] pub use http; - pub use percent_encoding; pub use ruma_macros; pub use serde; pub use serde_json; diff --git a/crates/ruma-common/tests/api/manual_endpoint_impl.rs b/crates/ruma-common/tests/api/manual_endpoint_impl.rs index 6faa13c0..5c1965ef 100644 --- a/crates/ruma-common/tests/api/manual_endpoint_impl.rs +++ b/crates/ruma-common/tests/api/manual_endpoint_impl.rs @@ -49,17 +49,13 @@ impl OutgoingRequest for Request { _access_token: SendAccessToken<'_>, considering_versions: &'_ [MatrixVersion], ) -> Result, IntoHttpError> { - let url = format!( - "{}{}", + let url = ruma_common::api::make_endpoint_url( + &METADATA, + considering_versions, base_url, - ruma_common::api::select_path( - considering_versions, - &METADATA, - Some(format_args!("/_matrix/client/unstable/directory/room/{}", self.room_alias)), - Some(format_args!("/_matrix/client/r0/directory/room/{}", self.room_alias)), - Some(format_args!("/_matrix/client/v3/directory/room/{}", self.room_alias)), - )? - ); + &[&self.room_alias], + None, + )?; let request_body = RequestBody { room_id: self.room_id }; diff --git a/crates/ruma-common/tests/api/mod.rs b/crates/ruma-common/tests/api/mod.rs index 6fdca848..e61dcaf6 100644 --- a/crates/ruma-common/tests/api/mod.rs +++ b/crates/ruma-common/tests/api/mod.rs @@ -9,4 +9,3 @@ mod path_arg_ordering; mod ruma_api; mod ruma_api_lifetime; mod ruma_api_macros; -mod select_path; diff --git a/crates/ruma-common/tests/api/select_path.rs b/crates/ruma-common/tests/api/select_path.rs deleted file mode 100644 index 8a384cb2..00000000 --- a/crates/ruma-common/tests/api/select_path.rs +++ /dev/null @@ -1,100 +0,0 @@ -use assert_matches::assert_matches; -use http::Method; -use ruma_common::api::{ - error::IntoHttpError, - select_path, - MatrixVersion::{V1_0, V1_1, V1_2}, - Metadata, -}; - -const BASE: Metadata = Metadata { - description: "", - method: Method::GET, - name: "test_endpoint", - unstable_path: Some("/unstable/path"), - r0_path: Some("/r0/path"), - stable_path: Some("/stable/path"), - rate_limited: false, - authentication: ruma_common::api::AuthScheme::None, - added: None, - deprecated: None, - removed: None, -}; - -const U: &str = "u"; -const S: &str = "s"; -const R: &str = "r"; - -// TODO add test that can hook into tracing and verify the deprecation warning is emitted - -#[test] -fn select_stable() { - let meta = Metadata { added: Some(V1_1), ..BASE }; - - let res = select_path(&[V1_0, V1_1], &meta, None, None, Some(format_args!("{S}"))) - .unwrap() - .to_string(); - - assert_eq!(res, S); -} - -#[test] -fn select_unstable() { - let meta = BASE; - - let res = - select_path(&[V1_0], &meta, Some(format_args!("{U}")), None, None).unwrap().to_string(); - - assert_eq!(res, U); -} - -#[test] -fn select_r0() { - let meta = Metadata { added: Some(V1_0), ..BASE }; - - let res = - select_path(&[V1_0], &meta, None, Some(format_args!("{R}")), Some(format_args!("{S}"))) - .unwrap() - .to_string(); - - assert_eq!(res, R); -} - -#[test] -fn select_removed_err() { - let meta = Metadata { added: Some(V1_0), deprecated: Some(V1_1), removed: Some(V1_2), ..BASE }; - - let res = select_path( - &[V1_2], - &meta, - Some(format_args!("{U}")), - Some(format_args!("{R}")), - Some(format_args!("{S}")), - ) - .unwrap_err(); - - assert_matches!(res, IntoHttpError::EndpointRemoved(V1_2)); -} - -#[test] -fn partially_removed_but_stable() { - let meta = Metadata { added: Some(V1_0), deprecated: Some(V1_1), removed: Some(V1_2), ..BASE }; - - let res = - select_path(&[V1_1], &meta, None, Some(format_args!("{R}")), Some(format_args!("{S}"))) - .unwrap() - .to_string(); - - assert_eq!(res, S); -} - -#[test] -fn no_unstable() { - let meta = Metadata { added: Some(V1_1), ..BASE }; - - let res = - select_path(&[V1_0], &meta, None, Some(format_args!("{R}")), Some(format_args!("{S}"))) - .unwrap_err(); - - assert_matches!(res, IntoHttpError::NoUnstablePath); -} diff --git a/crates/ruma-macros/src/api/request/outgoing.rs b/crates/ruma-macros/src/api/request/outgoing.rs index 47ecfc8f..2d7884fc 100644 --- a/crates/ruma-macros/src/api/request/outgoing.rs +++ b/crates/ruma-macros/src/api/request/outgoing.rs @@ -1,40 +1,20 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::{Field, LitStr}; +use syn::Field; use super::{Request, RequestField}; -use crate::api::{auth_scheme::AuthScheme, util}; +use crate::api::auth_scheme::AuthScheme; impl Request { pub fn expand_outgoing(&self, ruma_common: &TokenStream) -> TokenStream { let bytes = quote! { #ruma_common::exports::bytes }; let http = quote! { #ruma_common::exports::http }; - let percent_encoding = quote! { #ruma_common::exports::percent_encoding }; let method = &self.method; let error_ty = &self.error_ty; - let (unstable_path, r0_path, stable_path) = if self.has_path_fields() { - let path_format_args_call_with_percent_encoding = |s: &LitStr| -> TokenStream { - util::path_format_args_call(s.value(), &percent_encoding) - }; - - ( - self.unstable_path.as_ref().map(path_format_args_call_with_percent_encoding), - self.r0_path.as_ref().map(path_format_args_call_with_percent_encoding), - self.stable_path.as_ref().map(path_format_args_call_with_percent_encoding), - ) - } else { - ( - self.unstable_path.as_ref().map(|path| quote! { format_args!(#path) }), - self.r0_path.as_ref().map(|path| quote! { format_args!(#path) }), - self.stable_path.as_ref().map(|path| quote! { format_args!(#path) }), - ) - }; - - let unstable_path = util::map_option_literal(&unstable_path); - let r0_path = util::map_option_literal(&r0_path); - let stable_path = util::map_option_literal(&stable_path); + let path_fields = + self.path_fields_ordered().map(|f| f.ident.as_ref().expect("path fields have a name")); let request_query_string = if let Some(field) = self.query_map_field() { let field_name = field.ident.as_ref().expect("expected field to have identifier"); @@ -194,16 +174,12 @@ impl Request { let mut req_builder = #http::Request::builder() .method(#http::Method::#method) .uri(#ruma_common::api::make_endpoint_url( + &metadata, + considering_versions, base_url, - #ruma_common::api::select_path( - considering_versions, - &metadata, - #unstable_path, - #r0_path, - #stable_path, - )?, + &[ #( &self.#path_fields ),* ], #request_query_string, - )); + )?); if let Some(mut req_headers) = req_builder.headers_mut() { #header_kvs diff --git a/crates/ruma-macros/src/api/util.rs b/crates/ruma-macros/src/api/util.rs index 5449d3f3..dfc001a1 100644 --- a/crates/ruma-macros/src/api/util.rs +++ b/crates/ruma-macros/src/api/util.rs @@ -2,7 +2,7 @@ use std::collections::BTreeSet; -use proc_macro2::{Ident, Span, TokenStream}; +use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{parse_quote, visit::Visit, Attribute, Lifetime, NestedMeta, Type}; @@ -54,34 +54,3 @@ pub fn extract_cfg(attr: &Attribute) -> Option { Some(list.nested.pop().unwrap().into_value()) } - -pub fn path_format_args_call( - mut format_string: String, - percent_encoding: &TokenStream, -) -> TokenStream { - let mut format_args = Vec::new(); - - while let Some(start_of_segment) = format_string.find(':') { - // ':' should only ever appear at the start of a segment - assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/"); - - let end_of_segment = match format_string[start_of_segment..].find('/') { - Some(rel_pos) => start_of_segment + rel_pos, - None => format_string.len(), - }; - - let path_var = - Ident::new(&format_string[start_of_segment + 1..end_of_segment], Span::call_site()); - format_args.push(quote! { - #percent_encoding::utf8_percent_encode( - &::std::string::ToString::to_string(&self.#path_var), - #percent_encoding::NON_ALPHANUMERIC, - ) - }); - format_string.replace_range(start_of_segment..end_of_segment, "{}"); - } - - quote! { - format_args!(#format_string, #(#format_args),*) - } -}