api: Don't extract request path arguments in IncomingRequest impls

… instead requiring callers to pass them as a list of strings.
Parsing is still done within the trait implementations though.
This commit is contained in:
Jonathan de Jong 2022-02-02 11:57:29 +01:00 committed by GitHub
parent f7a10a7e47
commit abd702cfbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 101 additions and 98 deletions

View File

@ -8,8 +8,8 @@ use crate::auth_scheme::AuthScheme;
impl Request {
pub fn expand_incoming(&self, ruma_api: &TokenStream) -> TokenStream {
let http = quote! { #ruma_api::exports::http };
let percent_encoding = quote! { #ruma_api::exports::percent_encoding };
let ruma_serde = quote! { #ruma_api::exports::ruma_serde };
let serde = quote! { #ruma_api::exports::serde };
let serde_json = quote! { #ruma_api::exports::serde_json };
let method = &self.method;
@ -33,36 +33,22 @@ impl Request {
"number of declared path parameters needs to match amount of placeholders in path"
);
let path_var_decls = path_string[1..]
.split('/')
.enumerate()
.filter(|(_, seg)| seg.starts_with(':'))
.map(|(i, seg)| {
let path_var = Ident::new(&seg[1..], Span::call_site());
quote! {
let #path_var = {
let segment = path_segments[#i].as_bytes();
let decoded =
#percent_encoding::percent_decode(segment).decode_utf8()?;
::std::convert::TryFrom::try_from(&*decoded)?
};
}
});
let parse_request_path = quote! {
let path_segments: ::std::vec::Vec<&::std::primitive::str> =
request.uri().path()[1..].split('/').collect();
#(#path_var_decls)*
};
let path_vars = path_string[1..]
.split('/')
.filter(|seg| seg.starts_with(':'))
.map(|seg| Ident::new(&seg[1..], Span::call_site()));
(parse_request_path, quote! { #(#path_vars,)* })
let vars = path_vars.clone();
let parse_request_path = quote! {
let (#(#path_vars,)*) = #serde::Deserialize::deserialize(
#serde::de::value::SeqDeserializer::<_, #serde::de::value::Error>::new(
path_args.iter().map(::std::convert::AsRef::as_ref)
)
)?;
};
(parse_request_path, quote! { #(#vars,)* })
} else {
(TokenStream::new(), TokenStream::new())
};
@ -226,9 +212,14 @@ impl Request {
const METADATA: #ruma_api::Metadata = self::METADATA;
fn try_from_http_request<T: ::std::convert::AsRef<[::std::primitive::u8]>>(
request: #http::Request<T>,
) -> ::std::result::Result<Self, #ruma_api::error::FromHttpRequestError> {
fn try_from_http_request<B, S>(
request: #http::Request<B>,
path_args: &[S],
) -> ::std::result::Result<Self, #ruma_api::error::FromHttpRequestError>
where
B: ::std::convert::AsRef<[::std::primitive::u8]>,
S: ::std::convert::AsRef<::std::primitive::str>,
{
if request.method() != #http::Method::#method {
return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch {
expected: #http::Method::#method,

View File

@ -336,10 +336,17 @@ pub trait IncomingRequest: Sized {
/// Metadata about the endpoint.
const METADATA: Metadata;
/// Tries to turn the given `http::Request` into this request type.
fn try_from_http_request<T: AsRef<[u8]>>(
req: http::Request<T>,
) -> Result<Self, FromHttpRequestError>;
/// Tries to turn the given `http::Request` into this request type,
/// together with the corresponding path arguments.
///
/// Note: The strings in path_args need to be percent-decoded.
fn try_from_http_request<B, S>(
req: http::Request<B>,
path_args: &[S],
) -> Result<Self, FromHttpRequestError>
where
B: AsRef<[u8]>,
S: AsRef<str>;
}
/// A request type for a Matrix API endpoint, used for sending responses.

View File

@ -54,7 +54,7 @@ fn request_serde() {
.clone()
.try_into_http_request::<Vec<u8>>("https://homeserver.tld", SendAccessToken::None)
.unwrap();
let req2 = Request::try_from_http_request(http_req).unwrap();
let req2 = Request::try_from_http_request(http_req, &["barVal", "@bazme:ruma.io"]).unwrap();
assert_eq!(req.hello, req2.hello);
assert_eq!(req.world, req2.world);

View File

@ -2,8 +2,6 @@
#![allow(clippy::exhaustive_structs)]
use std::convert::TryFrom;
use bytes::BufMut;
use http::{header::CONTENT_TYPE, method::Method};
use ruma_api::{
@ -69,16 +67,16 @@ impl IncomingRequest for Request {
const METADATA: Metadata = METADATA;
fn try_from_http_request<T: AsRef<[u8]>>(
request: http::Request<T>,
) -> Result<Self, FromHttpRequestError> {
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let room_alias = {
let decoded =
percent_encoding::percent_decode(path_segments[5].as_bytes()).decode_utf8()?;
TryFrom::try_from(&*decoded)?
};
fn try_from_http_request<B, S>(
request: http::Request<B>,
path_args: &[S],
) -> Result<Self, FromHttpRequestError> where B: AsRef<[u8]>, S: AsRef<str> {
let (room_alias,) = serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::<
_,
serde::de::value::Error,
>::new(
path_args.iter().map(::std::convert::AsRef::as_ref),
))?;
let request_body: RequestBody = serde_json::from_slice(request.body().as_ref())?;

View File

@ -67,6 +67,7 @@ mod tests {
.uri("https://matrix.org/_matrix/client/r0/user/@foo:bar.com/filter")
.body(b"{}" as &[u8])
.unwrap(),
&["@foo:bar.com"]
),
Ok(IncomingRequest { user_id, filter })
if user_id == "@foo:bar.com" && filter.is_empty()

View File

@ -122,6 +122,7 @@ mod tests {
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
&["!dummy:example.org"],
);
assert_matches!(

View File

@ -87,6 +87,7 @@ mod tests {
.method("PUT")
.uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url")
.body(&[] as &[u8]).unwrap(),
&["@foo:bar.org"],
).unwrap(),
IncomingRequest { user_id, avatar_url: None, .. } if user_id == "@foo:bar.org"
);
@ -99,6 +100,7 @@ mod tests {
.uri("https://bar.org/_matrix/client/r0/profile/@foo:bar.org/avatar_url")
.body(serde_json::to_vec(&serde_json::json!({ "avatar_url": "" })).unwrap())
.unwrap(),
&["@foo:bar.org"],
).unwrap(),
IncomingRequest { user_id, avatar_url: None, .. } if user_id == "@foo:bar.org"
);

View File

@ -112,35 +112,33 @@ impl ruma_api::IncomingRequest for IncomingRequest {
const METADATA: ruma_api::Metadata = METADATA;
fn try_from_http_request<T: AsRef<[u8]>>(
request: http::Request<T>,
) -> Result<Self, ruma_api::error::FromHttpRequestError> {
use std::convert::TryFrom;
fn try_from_http_request<B, S>(
_request: http::Request<B>,
path_args: &[S],
) -> Result<Self, ruma_api::error::FromHttpRequestError>
where
B: AsRef<[u8]>,
S: AsRef<str>,
{
// FIXME: find a way to make this if-else collapse with serde recognizing trailing Option
let (room_id, event_type, state_key): (Box<RoomId>, EventType, String) =
if path_args.len() == 3 {
serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::<
_,
serde::de::value::Error,
>::new(
path_args.iter().map(::std::convert::AsRef::as_ref),
))?
} else {
let (a, b) = serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::<
_,
serde::de::value::Error,
>::new(
path_args.iter().map(::std::convert::AsRef::as_ref),
))?;
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let room_id = {
let decoded =
percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8()?;
Box::<RoomId>::try_from(&*decoded)?
};
let event_type = {
let decoded =
percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8()?;
EventType::try_from(&*decoded)?
};
let state_key = match path_segments.get(7) {
Some(segment) => {
let decoded = percent_encoding::percent_decode(segment.as_bytes()).decode_utf8()?;
String::try_from(&*decoded)?
}
None => "".into(),
};
(a, b, "".into())
};
Ok(Self { room_id, event_type, state_key })
}

View File

@ -138,31 +138,33 @@ impl ruma_api::IncomingRequest for IncomingRequest {
const METADATA: ruma_api::Metadata = METADATA;
fn try_from_http_request<T: AsRef<[u8]>>(
request: http::Request<T>,
) -> Result<Self, ruma_api::error::FromHttpRequestError> {
use std::{borrow::Cow, convert::TryFrom};
fn try_from_http_request<B, S>(
request: http::Request<B>,
path_args: &[S],
) -> Result<Self, ruma_api::error::FromHttpRequestError>
where
B: AsRef<[u8]>,
S: AsRef<str>,
{
// FIXME: find a way to make this if-else collapse with serde recognizing trailing Option
let (room_id, event_type, state_key): (Box<RoomId>, String, String) =
if path_args.len() == 3 {
serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::<
_,
serde::de::value::Error,
>::new(
path_args.iter().map(::std::convert::AsRef::as_ref),
))?
} else {
let (a, b) = serde::Deserialize::deserialize(serde::de::value::SeqDeserializer::<
_,
serde::de::value::Error,
>::new(
path_args.iter().map(::std::convert::AsRef::as_ref),
))?;
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let room_id = {
let decoded =
percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8()?;
Box::<RoomId>::try_from(&*decoded)?
};
let event_type = percent_encoding::percent_decode(path_segments[6].as_bytes())
.decode_utf8()?
.into_owned();
let state_key = path_segments
.get(7)
.map(|segment| percent_encoding::percent_decode(segment.as_bytes()).decode_utf8())
.transpose()?
// Last URL segment is optional, but not present is the same semantically as empty
.unwrap_or(Cow::Borrowed(""))
.into_owned();
(a, b, "".into())
};
let body = serde_json::from_slice(request.body().as_ref())?;

View File

@ -672,6 +672,7 @@ mod server_tests {
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
&[] as &[String],
)
.unwrap();
@ -693,6 +694,7 @@ mod server_tests {
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
&[] as &[String],
)
.unwrap();
@ -718,6 +720,7 @@ mod server_tests {
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(&[] as &[u8]).unwrap(),
&[] as &[String],
)
.unwrap();