Use a dedicated method for conversion from http::Request to Ruma request types

This commit is contained in:
Jonas Platte 2020-12-13 12:59:21 +01:00
parent ac4446ab5b
commit 95d21552e0
No known key found for this signature in database
GPG Key ID: CC154DE0E30B7C67
9 changed files with 179 additions and 184 deletions

View File

@ -243,27 +243,6 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
#[doc = #request_doc]
#request_type
impl ::std::convert::TryFrom<#http::Request<Vec<u8>>> for #incoming_request_type {
type Error = #ruma_api::error::FromHttpRequestError;
#[allow(unused_variables)]
fn try_from(
request: #http::Request<Vec<u8>>
) -> ::std::result::Result<Self, Self::Error> {
#extract_request_path
#extract_request_query
#extract_request_headers
#extract_request_body
Ok(Self {
#parse_request_path
#parse_request_query
#parse_request_headers
#parse_request_body
})
}
}
#[doc = #response_doc]
#response_type
@ -367,6 +346,23 @@ pub fn expand_all(api: Api) -> syn::Result<TokenStream> {
#[doc = #metadata_doc]
const METADATA: #ruma_api::Metadata = self::METADATA;
#[allow(unused_variables)]
fn try_from_http_request(
request: #http::Request<Vec<u8>>
) -> ::std::result::Result<Self, #ruma_api::error::FromHttpRequestError> {
#extract_request_path
#extract_request_query
#extract_request_headers
#extract_request_body
Ok(Self {
#parse_request_path
#parse_request_query
#parse_request_headers
#parse_request_body
})
}
}
#non_auth_endpoint_impls

View File

@ -257,7 +257,7 @@ pub trait OutgoingRequest {
}
/// A request type for a Matrix API endpoint. (trait used for receiving requests)
pub trait IncomingRequest: TryFrom<http::Request<Vec<u8>>, Error = FromHttpRequestError> {
pub trait IncomingRequest: Sized {
/// A type capturing the error conditions that can be returned in the response.
type EndpointError: EndpointError;
@ -266,6 +266,9 @@ pub trait IncomingRequest: TryFrom<http::Request<Vec<u8>>, Error = FromHttpReque
/// Metadata about the endpoint.
const METADATA: Metadata;
/// Tries to turn the given `http::Request` into this request type.
fn try_from_http_request(req: http::Request<Vec<u8>>) -> Result<Self, FromHttpRequestError>;
}
/// Marker trait for requests that don't require authentication. (for the client side)

View File

@ -1,6 +1,4 @@
use std::convert::TryFrom;
use ruma_api::{ruma_api, OutgoingRequest as _};
use ruma_api::{ruma_api, IncomingRequest as _, OutgoingRequest as _};
use ruma_identifiers::{user_id, UserId};
ruma_api! {
@ -48,7 +46,7 @@ fn request_serde() -> Result<(), Box<dyn std::error::Error + 'static>> {
};
let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?;
let req2 = Request::try_from(http_req)?;
let req2 = Request::try_from_http_request(http_req)?;
assert_eq!(req.hello, req2.hello);
assert_eq!(req.world, req2.world);

View File

@ -67,12 +67,10 @@ impl IncomingRequest for Request {
type OutgoingResponse = Response;
const METADATA: Metadata = METADATA;
}
impl TryFrom<http::Request<Vec<u8>>> for Request {
type Error = FromHttpRequestError;
fn try_from(request: http::Request<Vec<u8>>) -> Result<Self, Self::Error> {
fn try_from_http_request(
request: http::Request<Vec<u8>>,
) -> Result<Self, FromHttpRequestError> {
let request_body: RequestBody = match serde_json::from_slice(request.body().as_slice()) {
Ok(body) => body,
Err(err) => {

View File

@ -84,9 +84,8 @@ pub enum MembershipEventFilter {
#[cfg(test)]
mod tests {
use std::convert::TryInto;
use matches::assert_matches;
use ruma_api::IncomingRequest as _;
use super::{IncomingRequest, MembershipEventFilter};
@ -103,8 +102,9 @@ mod tests {
.build()
.unwrap();
let req: Result<IncomingRequest, _> =
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap().try_into();
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(),
);
assert_matches!(
req,

View File

@ -68,55 +68,6 @@ const METADATA: Metadata = Metadata {
authentication: AuthScheme::AccessToken,
};
impl TryFrom<http::Request<Vec<u8>>> for IncomingRequest {
type Error = FromHttpRequestError;
fn try_from(request: http::Request<Vec<u8>>) -> Result<Self, Self::Error> {
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let room_id = {
let decoded =
match percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
match RoomId::try_from(&*decoded) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
let txn_id =
match percent_encoding::percent_decode(path_segments[7].as_bytes()).decode_utf8() {
Ok(val) => val.into_owned(),
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
let content = {
let request_body: Box<RawJsonValue> =
match serde_json::from_slice(request.body().as_slice()) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
let event_type = {
match percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
match AnyMessageEventContent::from_parts(&event_type, request_body) {
Ok(content) => content,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
Ok(Self { room_id, txn_id, content })
}
}
/// Data in the response body.
#[derive(Debug, Deserialize, Serialize)]
struct ResponseBody {
@ -206,4 +157,51 @@ impl ruma_api::IncomingRequest for IncomingRequest {
/// Metadata for the `send_message_event` endpoint.
const METADATA: Metadata = METADATA;
fn try_from_http_request(
request: http::Request<Vec<u8>>,
) -> Result<Self, FromHttpRequestError> {
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let room_id = {
let decoded =
match percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
match RoomId::try_from(&*decoded) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
let txn_id =
match percent_encoding::percent_decode(path_segments[7].as_bytes()).decode_utf8() {
Ok(val) => val.into_owned(),
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
let content = {
let request_body: Box<RawJsonValue> =
match serde_json::from_slice(request.body().as_slice()) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
let event_type = {
match percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
match AnyMessageEventContent::from_parts(&event_type, request_body) {
Ok(content) => content,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
Ok(Self { room_id, txn_id, content })
}
}

View File

@ -61,49 +61,6 @@ const METADATA: Metadata = Metadata {
authentication: AuthScheme::AccessToken,
};
impl TryFrom<http::Request<Vec<u8>>> for IncomingRequest {
type Error = FromHttpRequestError;
fn try_from(request: http::Request<Vec<u8>>) -> Result<Self, Self::Error> {
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let room_id = {
let decoded =
match percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
match RoomId::try_from(&*decoded) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
let content = {
let request_body: Box<RawJsonValue> =
match serde_json::from_slice(request.body().as_slice()) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
let event_type = {
match percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
match AnyStateEventContent::from_parts(&event_type, request_body) {
Ok(content) => content,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
Ok(Self { room_id, content })
}
}
/// Data in the response body.
#[derive(Debug, Deserialize, Serialize)]
struct ResponseBody {
@ -192,4 +149,45 @@ impl ruma_api::IncomingRequest for IncomingRequest {
/// Metadata for the `send_message_event` endpoint.
const METADATA: Metadata = METADATA;
fn try_from_http_request(
request: http::Request<Vec<u8>>,
) -> Result<Self, FromHttpRequestError> {
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let room_id = {
let decoded =
match percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
match RoomId::try_from(&*decoded) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
let content = {
let request_body: Box<RawJsonValue> =
match serde_json::from_slice(request.body().as_slice()) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
let event_type = {
match percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
match AnyStateEventContent::from_parts(&event_type, request_body) {
Ok(content) => content,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
Ok(Self { room_id, content })
}
}

View File

@ -64,55 +64,6 @@ const METADATA: Metadata = Metadata {
authentication: AuthScheme::AccessToken,
};
impl TryFrom<http::Request<Vec<u8>>> for IncomingRequest {
type Error = FromHttpRequestError;
fn try_from(request: http::Request<Vec<u8>>) -> Result<Self, Self::Error> {
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let room_id = {
let decoded =
match percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
match RoomId::try_from(&*decoded) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
let state_key =
match percent_encoding::percent_decode(path_segments[7].as_bytes()).decode_utf8() {
Ok(val) => val.into_owned(),
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
let content = {
let request_body: Box<RawJsonValue> =
match serde_json::from_slice(request.body().as_slice()) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
let event_type = {
match percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
match AnyStateEventContent::from_parts(&event_type, request_body) {
Ok(content) => content,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
Ok(Self { room_id, state_key, content })
}
}
/// Data in the response body.
#[derive(Debug, Deserialize, Serialize)]
struct ResponseBody {
@ -202,4 +153,51 @@ impl ruma_api::IncomingRequest for IncomingRequest {
/// Metadata for the `send_message_event` endpoint.
const METADATA: Metadata = METADATA;
fn try_from_http_request(
request: http::Request<Vec<u8>>,
) -> Result<Self, FromHttpRequestError> {
let path_segments: Vec<&str> = request.uri().path()[1..].split('/').collect();
let room_id = {
let decoded =
match percent_encoding::percent_decode(path_segments[4].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
match RoomId::try_from(&*decoded) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
let state_key =
match percent_encoding::percent_decode(path_segments[7].as_bytes()).decode_utf8() {
Ok(val) => val.into_owned(),
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
let content = {
let request_body: Box<RawJsonValue> =
match serde_json::from_slice(request.body().as_slice()) {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
};
let event_type = {
match percent_encoding::percent_decode(path_segments[6].as_bytes()).decode_utf8() {
Ok(val) => val,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
match AnyStateEventContent::from_parts(&event_type, request_body) {
Ok(content) => content,
Err(err) => return Err(RequestDeserializationError::new(err, request).into()),
}
};
Ok(Self { room_id, state_key, content })
}
}

View File

@ -530,9 +530,9 @@ impl DeviceLists {
#[cfg(test)]
mod tests {
use std::{convert::TryInto, time::Duration};
use std::time::Duration;
use ruma_api::OutgoingRequest;
use ruma_api::{IncomingRequest as _, OutgoingRequest as _};
use serde_json::{from_value as from_json_value, json, to_value as to_json_value};
use matches::assert_matches;
@ -578,8 +578,10 @@ mod tests {
.build()
.unwrap();
let req: IncomingRequest =
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap().try_into().unwrap();
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(),
)
.unwrap();
assert_matches!(req.filter, Some(IncomingFilter::FilterId(id)) if id == "myfilter");
assert_eq!(req.since, Some("myts".into()));
@ -597,8 +599,10 @@ mod tests {
.build()
.unwrap();
let req: IncomingRequest =
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap().try_into().unwrap();
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(),
)
.unwrap();
assert_matches!(req.filter, None);
assert_eq!(req.since, None);
@ -620,8 +624,10 @@ mod tests {
.build()
.unwrap();
let req: IncomingRequest =
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap().try_into().unwrap();
let req = IncomingRequest::try_from_http_request(
http::Request::builder().uri(uri).body(Vec::<u8>::new()).unwrap(),
)
.unwrap();
assert_matches!(req.filter, Some(IncomingFilter::FilterId(id)) if id == "EOKFFmdZYF");
assert_eq!(req.since, None);