api: Don't extract request path arguments in IncomingRequest impls

… instead requiring callers to pass them as a list of strings.
Parsing is still done within the trait implementations though.
This commit is contained in:
Jonathan de Jong 2022-02-02 11:57:29 +01:00 committed by GitHub
parent f7a10a7e47
commit abd702cfbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 101 additions and 98 deletions

View File

@ -8,8 +8,8 @@ use crate::auth_scheme::AuthScheme;
impl Request { impl Request {
pub fn expand_incoming(&self, ruma_api: &TokenStream) -> TokenStream { pub fn expand_incoming(&self, ruma_api: &TokenStream) -> TokenStream {
let http = quote! { #ruma_api::exports::http }; let http = quote! { #ruma_api::exports::http };
let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
let serde = quote! { #ruma_api::exports::serde };
let serde_json = quote! { #ruma_api::exports::serde_json }; let serde_json = quote! { #ruma_api::exports::serde_json };
let method = &self.method; let method = &self.method;
@ -33,36 +33,22 @@ impl Request {
"number of declared path parameters needs to match amount of placeholders in path" "number of declared path parameters needs to match amount of placeholders in path"
); );
let path_var_decls = path_string[1..]
.split('/')
.enumerate()
.filter(|(_, seg)| seg.starts_with(':'))
.map(|(i, seg)| {
let path_var = Ident::new(&seg[1..], Span::call_site());
quote! {
let #path_var = {
let segment = path_segments[#i].as_bytes();
let decoded =
#percent_encoding::percent_decode(segment).decode_utf8()?;
::std::convert::TryFrom::try_from(&*decoded)?
};
}
});
let parse_request_path = quote! {
let path_segments: ::std::vec::Vec<&::std::primitive::str> =
request.uri().path()[1..].split('/').collect();
#(#path_var_decls)*
};
let path_vars = path_string[1..] let path_vars = path_string[1..]
.split('/') .split('/')
.filter(|seg| seg.starts_with(':')) .filter(|seg| seg.starts_with(':'))
.map(|seg| Ident::new(&seg[1..], Span::call_site())); .map(|seg| Ident::new(&seg[1..], Span::call_site()));
(parse_request_path, quote! { #(#path_vars,)* }) let vars = path_vars.clone();
let parse_request_path = quote! {
let (#(#path_vars,)*) = #serde::Deserialize::deserialize(
#serde::de::value::SeqDeserializer::<_, #serde::de::value::Error>::new(
path_args.iter().map(::std::convert::AsRef::as_ref)
)
)?;
};
(parse_request_path, quote! { #(#vars,)* })
} else { } else {
(TokenStream::new(), TokenStream::new()) (TokenStream::new(), TokenStream::new())
}; };
@ -226,9 +212,14 @@ impl Request {
const METADATA: #ruma_api::Metadata = self::METADATA; const METADATA: #ruma_api::Metadata = self::METADATA;
fn try_from_http_request<T: ::std::convert::AsRef<[::std::primitive::u8]>>( fn try_from_http_request<B, S>(
request: #http::Request<T>, request: #http::Request<B>,
) -> ::std::result::Result<Self, #ruma_api::error::FromHttpRequestError> { path_args: &[S],
) -> ::std::result::Result<Self, #ruma_api::error::FromHttpRequestError>
where
B: ::std::convert::AsRef<[::std::primitive::u8]>,
S: ::std::convert::AsRef<::std::primitive::str>,
{
if request.method() != #http::Method::#method { if request.method() != #http::Method::#method {
return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch { return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch {
expected: #http::Method::#method, expected: #http::Method::#method,

View File

@ -336,10 +336,17 @@ pub trait IncomingRequest: Sized {
/// Metadata about the endpoint. /// Metadata about the endpoint.
const METADATA: Metadata; const METADATA: Metadata;
/// Tries to turn the given `http::Request` into this request type. /// Tries to turn the given `http::Request` into this request type,
fn try_from_http_request<T: AsRef<[u8]>>( /// together with the corresponding path arguments.
req: http::Request<T>, ///
) -> Result<Self, FromHttpRequestError>; /// Note: The strings in path_args need to be percent-decoded.
fn try_from_http_request<B, S>(
req: http::Request<B>,
path_args: &[S],
) -> Result<Self, FromHttpRequestError>
where
B: AsRef<[u8]>,
S: AsRef<str>;
} }
/// A request type for a Matrix API endpoint, used for sending responses. /// A request type for a Matrix API endpoint, used for sending responses.

View File

@ -54,7 +54,7 @@ fn request_serde() {
.clone() .clone()
.try_into_http_request::<Vec<u8>>("https://homeserver.tld", SendAccessToken::None) .try_into_http_request::<Vec<u8>>("https://homeserver.tld", SendAccessToken::None)
.unwrap(); .unwrap();
let req2 = Request::try_from_http_request(http_req).unwrap(); let req2 = Request::try_from_http_request(http_req, &["barVal", "@bazme:ruma.io"]).unwrap();
assert_eq!(req.hello, req2.hello); assert_eq!(req.hello, req2.hello);
assert_eq!(req.world, req2.world); assert_eq!(req.world, req2.world);

View File

@ -2,8 +2,6 @@
#![allow(clippy::exhaustive_structs)] #![allow(clippy::exhaustive_structs)]
use std::convert::TryFrom;
use bytes::BufMut; use bytes::BufMut;
use http::{header::CONTENT_TYPE, method::Method}; use http::{header::CONTENT_TYPE, method::Method};
use ruma_api::{ use ruma_api::{
@ -69,16 +67,16 @@ impl IncomingRequest for Request {
const METADATA: Metadata = METADATA; const METADATA: Metadata = METADATA;
fn try_from_http_request<T: AsRef<[u8]>>( fn try_from_http_request<B, S>(
request: http::Request<T>, request: http::Request<B>,
) -> Result<Self, FromHttpRequestError> { path_args: &[S],
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); ) -> Result<Self, FromHttpRequestError> where B: AsRef<[u8]>, S: AsRef<str> {
let room_alias = { let (room_alias,) = serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::<
let decoded = _,
percent_encoding::percent_decode(path_segments[5].as_bytes()).decode_utf8()?; serde::de::value::Error,
>::new(
TryFrom::try_from(&*decoded)? path_args.iter().map(::std::convert::AsRef::as_ref),
}; ))?;
let request_body: RequestBody = serde_json::from_slice(request.body().as_ref())?; let request_body: RequestBody = serde_json::from_slice(request.body().as_ref())?;

View File

@ -67,6 +67,7 @@ mod tests {
.uri("https://matrix.org/_matrix/client/r0/user/@foo:bar.com/filter") .uri("https://matrix.org/_matrix/client/r0/user/@foo:bar.com/filter")
.body(b"{}" as &[u8]) .body(b"{}" as &[u8])
.unwrap(), .unwrap(),
&["@foo:bar.com"]
), ),
Ok(IncomingRequest { user_id, filter }) Ok(IncomingRequest { user_id, filter })
if user_id == "@foo:bar.com" && filter.is_empty() if user_id == "@foo:bar.com" && filter.is_empty()

View File

@ -122,6 +122,7 @@ mod tests {
let req = IncomingRequest::try_from_http_request( let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
&["!dummy:example.org"],
); );
assert_matches!( assert_matches!(

View File

@ -87,6 +87,7 @@ mod tests {
.method("PUT") .method("PUT")
.uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url") .uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url")
.body(&[] as &[u8]).unwrap(), .body(&[] as &[u8]).unwrap(),
&["@foo:bar.org"],
).unwrap(), ).unwrap(),
IncomingRequest { user_id, avatar_url: None, .. } if user_id == "@foo:bar.org" IncomingRequest { user_id, avatar_url: None, .. } if user_id == "@foo:bar.org"
); );
@ -99,6 +100,7 @@ mod tests {
.uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url") .uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url")
.body(serde_json::to_vec(&serde_json::json!({ "avatar_url": "" })).unwrap()) .body(serde_json::to_vec(&serde_json::json!({ "avatar_url": "" })).unwrap())
.unwrap(), .unwrap(),
&["@foo:bar.org"],
).unwrap(), ).unwrap(),
IncomingRequest { user_id, avatar_url: None, .. } if user_id == "@foo:bar.org" IncomingRequest { user_id, avatar_url: None, .. } if user_id == "@foo:bar.org"
); );

View File

@ -112,35 +112,33 @@ impl ruma_api::IncomingRequest for IncomingRequest {
const METADATA: ruma_api::Metadata = METADATA; const METADATA: ruma_api::Metadata = METADATA;
fn try_from_http_request<T: AsRef<[u8]>>( fn try_from_http_request<B, S>(
request: http::Request<T>, _request: http::Request<B>,
) -> Result<Self, ruma_api::error::FromHttpRequestError> { path_args: &[S],
use std::convert::TryFrom; ) -> Result<Self, ruma_api::error::FromHttpRequestError>
where
B: AsRef<[u8]>,
S: AsRef<str>,
{
// FIXME: find a way to make this if-else collapse with serde recognizing trailing Option
let (room_id, event_type, state_key): (Box<RoomId>, EventType, String) =
if path_args.len() == 3 {
serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::<
_,
serde::de::value::Error,
>::new(
path_args.iter().map(::std::convert::AsRef::as_ref),
))?
} else {
let (a, b) = serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::<
_,
serde::de::value::Error,
>::new(
path_args.iter().map(::std::convert::AsRef::as_ref),
))?;
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); (a, b, "".into())
};
let room_id = {
let decoded =
percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8()?;
Box::<RoomId>::try_from(&*decoded)?
};
let event_type = {
let decoded =
percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8()?;
EventType::try_from(&*decoded)?
};
let state_key = match path_segments.get(7) {
Some(segment) => {
let decoded = percent_encoding::percent_decode(segment.as_bytes()).decode_utf8()?;
String::try_from(&*decoded)?
}
None => "".into(),
};
Ok(Self { room_id, event_type, state_key }) Ok(Self { room_id, event_type, state_key })
} }

View File

@ -138,31 +138,33 @@ impl ruma_api::IncomingRequest for IncomingRequest {
const METADATA: ruma_api::Metadata = METADATA; const METADATA: ruma_api::Metadata = METADATA;
fn try_from_http_request<T: AsRef<[u8]>>( fn try_from_http_request<B, S>(
request: http::Request<T>, request: http::Request<B>,
) -> Result<Self, ruma_api::error::FromHttpRequestError> { path_args: &[S],
use std::{borrow::Cow, convert::TryFrom}; ) -> Result<Self, ruma_api::error::FromHttpRequestError>
where
B: AsRef<[u8]>,
S: AsRef<str>,
{
// FIXME: find a way to make this if-else collapse with serde recognizing trailing Option
let (room_id, event_type, state_key): (Box<RoomId>, String, String) =
if path_args.len() == 3 {
serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::<
_,
serde::de::value::Error,
>::new(
path_args.iter().map(::std::convert::AsRef::as_ref),
))?
} else {
let (a, b) = serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::<
_,
serde::de::value::Error,
>::new(
path_args.iter().map(::std::convert::AsRef::as_ref),
))?;
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect(); (a, b, "".into())
};
let room_id = {
let decoded =
percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8()?;
Box::<RoomId>::try_from(&*decoded)?
};
let event_type = percent_encoding::percent_decode(path_segments[6].as_bytes())
.decode_utf8()?
.into_owned();
let state_key = path_segments
.get(7)
.map(|segment| percent_encoding::percent_decode(segment.as_bytes()).decode_utf8())
.transpose()?
// Last URL segment is optional, but not present is the same semantically as empty
.unwrap_or(Cow::Borrowed(""))
.into_owned();
let body = serde_json::from_slice(request.body().as_ref())?; let body = serde_json::from_slice(request.body().as_ref())?;

View File

@ -672,6 +672,7 @@ mod server_tests {
let req = IncomingRequest::try_from_http_request( let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
&[] as &[String],
) )
.unwrap(); .unwrap();
@ -693,6 +694,7 @@ mod server_tests {
let req = IncomingRequest::try_from_http_request( let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
&[] as &[String],
) )
.unwrap(); .unwrap();
@ -718,6 +720,7 @@ mod server_tests {
let req = IncomingRequest::try_from_http_request( let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(), http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
&[] as &[String],
) )
.unwrap(); .unwrap();