diff --git a/ruma-api/src/lib.rs b/ruma-api/src/lib.rs index 5b1c0565..f9940611 100644 --- a/ruma-api/src/lib.rs +++ b/ruma-api/src/lib.rs @@ -25,7 +25,8 @@ use std::{ error::Error as StdError, }; -use http::Method; +use http::{uri::PathAndQuery, Method}; +use ruma_identifiers::UserId; /// Generates a `ruma_api::Endpoint` from a concise definition. /// @@ -226,7 +227,7 @@ pub trait EndpointError: StdError + Sized + 'static { } /// A request type for a Matrix API endpoint, used for sending requests. -pub trait OutgoingRequest { +pub trait OutgoingRequest: Sized { /// A type capturing the expected error conditions the server can return. type EndpointError: EndpointError; @@ -254,6 +255,44 @@ pub trait OutgoingRequest { ) -> Result>, IntoHttpError>; } +/// An extension to `OutgoingRequest` which provides Appservice specific methods +pub trait OutgoingRequestAppserviceExt: OutgoingRequest { + /// Tries to convert this request into an `http::Request` and appends a virtual `user_id` to + /// [assert Appservice identity][id_assert]. + /// + /// [id_assert]: https://matrix.org/docs/spec/application_service/r0.1.2#identity-assertion + fn try_into_http_request_with_user_id( + self, + base_url: &str, + access_token: Option<&str>, + user_id: UserId, + ) -> Result>, IntoHttpError> { + let mut http_request = self.try_into_http_request(base_url, access_token)?; + let user_id_query = + ruma_serde::urlencoded::to_string(&[("user_id", &user_id.into_string())])?; + + let uri = http_request.uri().to_owned(); + let mut parts = uri.into_parts(); + + let path_and_query_with_user_id = match &parts.path_and_query { + Some(path_and_query) => match path_and_query.query() { + Some(_) => format!("{}&{}", path_and_query, user_id_query), + None => format!("{}?{}", path_and_query, user_id_query), + }, + None => format!("/?{}", user_id_query), + }; + + parts.path_and_query = + Some(PathAndQuery::try_from(path_and_query_with_user_id).map_err(http::Error::from)?); + + *http_request.uri_mut() = parts.try_into().map_err(http::Error::from)?; + + Ok(http_request) + } +} + +impl OutgoingRequestAppserviceExt for T {} + /// A request type for a Matrix API endpoint, used for receiving requests. pub trait IncomingRequest: Sized { /// A type capturing the error conditions that can be returned in the response. diff --git a/ruma-api/tests/conversions.rs b/ruma-api/tests/conversions.rs index 69e6307f..2c6875eb 100644 --- a/ruma-api/tests/conversions.rs +++ b/ruma-api/tests/conversions.rs @@ -1,4 +1,6 @@ -use ruma_api::{ruma_api, IncomingRequest as _, OutgoingRequest as _}; +use ruma_api::{ + ruma_api, IncomingRequest as _, OutgoingRequest as _, OutgoingRequestAppserviceExt as _, +}; use ruma_identifiers::{user_id, UserId}; ruma_api! { @@ -57,3 +59,85 @@ fn request_serde() -> Result<(), Box> { Ok(()) } + +#[test] +fn request_with_user_id_serde() -> Result<(), Box> { + let req = Request { + hello: "hi".to_owned(), + world: "test".to_owned(), + q1: "query_param_special_chars %/&@!".to_owned(), + q2: 55, + bar: "barVal".to_owned(), + baz: user_id!("@bazme:ruma.io"), + }; + + let user_id = user_id!("@_virtual_:ruma.io"); + let http_req = + req.clone().try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)?; + + let query = http_req.uri().query().unwrap(); + + assert_eq!( + query, + "q1=query_param_special_chars+%25%2F%26%40%21&q2=55&user_id=%40_virtual_%3Aruma.io" + ); + + Ok(()) +} + +mod without_query { + use super::*; + + ruma_api! { + metadata: { + description: "Does something without query.", + method: POST, + name: "my_endpoint", + path: "/_matrix/foo/:bar/:baz", + rate_limited: false, + authentication: None, + } + + request: { + pub hello: String, + #[ruma_api(header = CONTENT_TYPE)] + pub world: String, + #[ruma_api(path)] + pub bar: String, + #[ruma_api(path)] + pub baz: UserId, + } + + response: { + pub hello: String, + #[ruma_api(header = CONTENT_TYPE)] + pub world: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub optional_flag: Option, + } + } + + #[test] + fn request_without_query_with_user_id_serde() -> Result<(), Box> + { + let req = Request { + hello: "hi".to_owned(), + world: "test".to_owned(), + bar: "barVal".to_owned(), + baz: user_id!("@bazme:ruma.io"), + }; + + let user_id = user_id!("@_virtual_:ruma.io"); + let http_req = req.clone().try_into_http_request_with_user_id( + "https://homeserver.tld", + None, + user_id, + )?; + + let query = http_req.uri().query().unwrap(); + + assert_eq!(query, "user_id=%40_virtual_%3Aruma.io"); + + Ok(()) + } +}