From abd702cfbcc1c8cbf439e3817ab80baa4528d253 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 2 Feb 2022 11:57:29 +0100 Subject: [PATCH] api: Don't extract request path arguments in IncomingRequest impls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … instead requiring callers to pass them as a list of strings. Parsing is still done within the trait implementations though. --- .../ruma-api-macros/src/request/incoming.rs | 49 +++++++---------- crates/ruma-api/src/lib.rs | 15 ++++-- crates/ruma-api/tests/conversions.rs | 2 +- crates/ruma-api/tests/manual_endpoint_impl.rs | 22 ++++---- .../src/r0/filter/create_filter.rs | 1 + .../src/r0/membership/get_member_events.rs | 1 + .../src/r0/profile/set_avatar_url.rs | 2 + .../src/r0/state/get_state_events_for_key.rs | 54 +++++++++---------- .../src/r0/state/send_state_event.rs | 50 ++++++++--------- .../src/r0/sync/sync_events.rs | 3 ++ 10 files changed, 101 insertions(+), 98 deletions(-) diff --git a/crates/ruma-api-macros/src/request/incoming.rs b/crates/ruma-api-macros/src/request/incoming.rs index 3190ad60..b0bc76be 100644 --- a/crates/ruma-api-macros/src/request/incoming.rs +++ b/crates/ruma-api-macros/src/request/incoming.rs @@ -8,8 +8,8 @@ use crate::auth_scheme::AuthScheme; impl Request { pub fn expand_incoming(&self, ruma_api: &TokenStream) -> TokenStream { let http = quote! { #ruma_api::exports::http }; - let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; + let serde = quote! { #ruma_api::exports::serde }; let serde_json = quote! { #ruma_api::exports::serde_json }; let method = &self.method; @@ -33,36 +33,22 @@ impl Request { "number of declared path parameters needs to match amount of placeholders in path" ); - let path_var_decls = path_string[1..] - .split('/') - .enumerate() - .filter(|(_, seg)| seg.starts_with(':')) - .map(|(i, seg)| { - let path_var = Ident::new(&seg[1..], Span::call_site()); - quote! { - let #path_var = { - let segment = path_segments[#i].as_bytes(); - let decoded = - #percent_encoding::percent_decode(segment).decode_utf8()?; - - ::std::convert::TryFrom::try_from(&*decoded)? - }; - } - }); - - let parse_request_path = quote! { - let path_segments: ::std::vec::Vec<&::std::primitive::str> = - request.uri().path()[1..].split('/').collect(); - - #(#path_var_decls)* - }; - let path_vars = path_string[1..] .split('/') .filter(|seg| seg.starts_with(':')) .map(|seg| Ident::new(&seg[1..], Span::call_site())); - (parse_request_path, quote! { #(#path_vars,)* }) + let vars = path_vars.clone(); + + let parse_request_path = quote! { + let (#(#path_vars,)*) = #serde::Deserialize::deserialize( + #serde::de::value::SeqDeserializer::<_, #serde::de::value::Error>::new( + path_args.iter().map(::std::convert::AsRef::as_ref) + ) + )?; + }; + + (parse_request_path, quote! { #(#vars,)* }) } else { (TokenStream::new(), TokenStream::new()) }; @@ -226,9 +212,14 @@ impl Request { const METADATA: #ruma_api::Metadata = self::METADATA; - fn try_from_http_request>( - request: #http::Request, - ) -> ::std::result::Result { + fn try_from_http_request( + request: #http::Request, + path_args: &[S], + ) -> ::std::result::Result + where + B: ::std::convert::AsRef<[::std::primitive::u8]>, + S: ::std::convert::AsRef<::std::primitive::str>, + { if request.method() != #http::Method::#method { return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch { expected: #http::Method::#method, diff --git a/crates/ruma-api/src/lib.rs b/crates/ruma-api/src/lib.rs index 2b69ac03..72142db5 100644 --- a/crates/ruma-api/src/lib.rs +++ b/crates/ruma-api/src/lib.rs @@ -336,10 +336,17 @@ pub trait IncomingRequest: Sized { /// Metadata about the endpoint. const METADATA: Metadata; - /// Tries to turn the given `http::Request` into this request type. - fn try_from_http_request>( - req: http::Request, - ) -> Result; + /// Tries to turn the given `http::Request` into this request type, + /// together with the corresponding path arguments. + /// + /// Note: The strings in path_args need to be percent-decoded. + fn try_from_http_request( + req: http::Request, + path_args: &[S], + ) -> Result + where + B: AsRef<[u8]>, + S: AsRef; } /// A request type for a Matrix API endpoint, used for sending responses. diff --git a/crates/ruma-api/tests/conversions.rs b/crates/ruma-api/tests/conversions.rs index 91d2b230..61a89113 100644 --- a/crates/ruma-api/tests/conversions.rs +++ b/crates/ruma-api/tests/conversions.rs @@ -54,7 +54,7 @@ fn request_serde() { .clone() .try_into_http_request::>("https://homeserver.tld", SendAccessToken::None) .unwrap(); - let req2 = Request::try_from_http_request(http_req).unwrap(); + let req2 = Request::try_from_http_request(http_req, &["barVal", "@bazme:ruma.io"]).unwrap(); assert_eq!(req.hello, req2.hello); assert_eq!(req.world, req2.world); diff --git a/crates/ruma-api/tests/manual_endpoint_impl.rs b/crates/ruma-api/tests/manual_endpoint_impl.rs index 87edb4e7..24580f63 100644 --- a/crates/ruma-api/tests/manual_endpoint_impl.rs +++ b/crates/ruma-api/tests/manual_endpoint_impl.rs @@ -2,8 +2,6 @@ #![allow(clippy::exhaustive_structs)] -use std::convert::TryFrom; - use bytes::BufMut; use http::{header::CONTENT_TYPE, method::Method}; use ruma_api::{ @@ -69,16 +67,16 @@ impl IncomingRequest for Request { const METADATA: Metadata = METADATA; - fn try_from_http_request>( - request: http::Request, - ) -> Result { - let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); - let room_alias = { - let decoded = - percent_encoding::percent_decode(path_segments[5].as_bytes()).decode_utf8()?; - - TryFrom::try_from(&*decoded)? - }; + fn try_from_http_request( + request: http::Request, + path_args: &[S], + ) -> Result where B: AsRef<[u8]>, S: AsRef { + let (room_alias,) = serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::< + _, + serde::de::value::Error, + >::new( + path_args.iter().map(::std::convert::AsRef::as_ref), + ))?; let request_body: RequestBody = serde_json::from_slice(request.body().as_ref())?; diff --git a/crates/ruma-client-api/src/r0/filter/create_filter.rs b/crates/ruma-client-api/src/r0/filter/create_filter.rs index 43b1ce44..a4e740df 100644 --- a/crates/ruma-client-api/src/r0/filter/create_filter.rs +++ b/crates/ruma-client-api/src/r0/filter/create_filter.rs @@ -67,6 +67,7 @@ mod tests { .uri("https://matrix.org/_matrix/client/r0/user/@foo:bar.com/filter") .body(b"{}" as &[u8]) .unwrap(), + &["@foo:bar.com"] ), Ok(IncomingRequest { user_id, filter }) if user_id == "@foo:bar.com" && filter.is_empty() diff --git a/crates/ruma-client-api/src/r0/membership/get_member_events.rs b/crates/ruma-client-api/src/r0/membership/get_member_events.rs index 6814c839..3c9ea3a8 100644 --- a/crates/ruma-client-api/src/r0/membership/get_member_events.rs +++ b/crates/ruma-client-api/src/r0/membership/get_member_events.rs @@ -122,6 +122,7 @@ mod tests { let req = IncomingRequest::try_from_http_request( http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), + &["!dummy:example.org"], ); assert_matches!( diff --git a/crates/ruma-client-api/src/r0/profile/set_avatar_url.rs b/crates/ruma-client-api/src/r0/profile/set_avatar_url.rs index 9e9ea395..f0b93c03 100644 --- a/crates/ruma-client-api/src/r0/profile/set_avatar_url.rs +++ b/crates/ruma-client-api/src/r0/profile/set_avatar_url.rs @@ -87,6 +87,7 @@ mod tests { .method("PUT") .uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url") .body(&[] as &[u8]).unwrap(), + &["@foo:bar.org"], ).unwrap(), IncomingRequest { user_id, avatar_url: None, .. } if user_id == "@foo:bar.org" ); @@ -99,6 +100,7 @@ mod tests { .uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url") .body(serde_json::to_vec(&serde_json::json!({ "avatar_url": "" })).unwrap()) .unwrap(), + &["@foo:bar.org"], ).unwrap(), IncomingRequest { user_id, avatar_url: None, .. } if user_id == "@foo:bar.org" ); diff --git a/crates/ruma-client-api/src/r0/state/get_state_events_for_key.rs b/crates/ruma-client-api/src/r0/state/get_state_events_for_key.rs index 44be8af9..bdfc582d 100644 --- a/crates/ruma-client-api/src/r0/state/get_state_events_for_key.rs +++ b/crates/ruma-client-api/src/r0/state/get_state_events_for_key.rs @@ -112,35 +112,33 @@ impl ruma_api::IncomingRequest for IncomingRequest { const METADATA: ruma_api::Metadata = METADATA; - fn try_from_http_request>( - request: http::Request, - ) -> Result { - use std::convert::TryFrom; + fn try_from_http_request( + _request: http::Request, + path_args: &[S], + ) -> Result + where + B: AsRef<[u8]>, + S: AsRef, + { + // FIXME: find a way to make this if-else collapse with serde recognizing trailing Option + let (room_id, event_type, state_key): (Box, EventType, String) = + if path_args.len() == 3 { + serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::< + _, + serde::de::value::Error, + >::new( + path_args.iter().map(::std::convert::AsRef::as_ref), + ))? + } else { + let (a, b) = serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::< + _, + serde::de::value::Error, + >::new( + path_args.iter().map(::std::convert::AsRef::as_ref), + ))?; - let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); - - let room_id = { - let decoded = - percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8()?; - - Box::::try_from(&*decoded)? - }; - - let event_type = { - let decoded = - percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8()?; - - EventType::try_from(&*decoded)? - }; - - let state_key = match path_segments.get(7) { - Some(segment) => { - let decoded = percent_encoding::percent_decode(segment.as_bytes()).decode_utf8()?; - - String::try_from(&*decoded)? - } - None => "".into(), - }; + (a, b, "".into()) + }; Ok(Self { room_id, event_type, state_key }) } diff --git a/crates/ruma-client-api/src/r0/state/send_state_event.rs b/crates/ruma-client-api/src/r0/state/send_state_event.rs index af5c7943..61f7bc5e 100644 --- a/crates/ruma-client-api/src/r0/state/send_state_event.rs +++ b/crates/ruma-client-api/src/r0/state/send_state_event.rs @@ -138,31 +138,33 @@ impl ruma_api::IncomingRequest for IncomingRequest { const METADATA: ruma_api::Metadata = METADATA; - fn try_from_http_request>( - request: http::Request, - ) -> Result { - use std::{borrow::Cow, convert::TryFrom}; + fn try_from_http_request( + request: http::Request, + path_args: &[S], + ) -> Result + where + B: AsRef<[u8]>, + S: AsRef, + { + // FIXME: find a way to make this if-else collapse with serde recognizing trailing Option + let (room_id, event_type, state_key): (Box, String, String) = + if path_args.len() == 3 { + serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::< + _, + serde::de::value::Error, + >::new( + path_args.iter().map(::std::convert::AsRef::as_ref), + ))? + } else { + let (a, b) = serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::< + _, + serde::de::value::Error, + >::new( + path_args.iter().map(::std::convert::AsRef::as_ref), + ))?; - let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); - - let room_id = { - let decoded = - percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8()?; - - Box::::try_from(&*decoded)? - }; - - let event_type = percent_encoding::percent_decode(path_segments[6].as_bytes()) - .decode_utf8()? - .into_owned(); - - let state_key = path_segments - .get(7) - .map(|segment| percent_encoding::percent_decode(segment.as_bytes()).decode_utf8()) - .transpose()? - // Last URL segment is optional, but not present is the same semantically as empty - .unwrap_or(Cow::Borrowed("")) - .into_owned(); + (a, b, "".into()) + }; let body = serde_json::from_slice(request.body().as_ref())?; diff --git a/crates/ruma-client-api/src/r0/sync/sync_events.rs b/crates/ruma-client-api/src/r0/sync/sync_events.rs index ddc6ee0e..df1af961 100644 --- a/crates/ruma-client-api/src/r0/sync/sync_events.rs +++ b/crates/ruma-client-api/src/r0/sync/sync_events.rs @@ -672,6 +672,7 @@ mod server_tests { let req = IncomingRequest::try_from_http_request( http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), + &[] as &[String], ) .unwrap(); @@ -693,6 +694,7 @@ mod server_tests { let req = IncomingRequest::try_from_http_request( http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), + &[] as &[String], ) .unwrap(); @@ -718,6 +720,7 @@ mod server_tests { let req = IncomingRequest::try_from_http_request( http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), + &[] as &[String], ) .unwrap();