diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index db87d774..6058c088 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -139,15 +139,20 @@ impl ToTokens for Api { self.request.request_init_query_fields() }; - let add_headers_to_request = if self.request.has_header_fields() { - let add_headers = self.request.add_headers_to_request(); - quote! { - let headers = http_request.headers_mut(); - #add_headers - } - } else { - TokenStream::new() - }; + let mut header_kvs = self.request.append_header_kvs(); + if requires_authentication.value { + header_kvs.push(quote! { + ruma_api::exports::http::header::AUTHORIZATION, + ruma_api::exports::http::header::HeaderValue::from_str( + &format!( + "Bearer {}", + access_token.ok_or_else( + ruma_api::error::IntoHttpError::needs_authentication + )? + ) + )? + }); + } let extract_request_headers = if self.request.has_header_fields() { quote! { @@ -245,29 +250,6 @@ impl ToTokens for Api { } } - impl std::convert::TryFrom for ::ruma_api::exports::http::Request> { - type Error = ::ruma_api::error::IntoHttpError; - - #[allow(unused_mut, unused_variables)] - fn try_from(request: Request) -> Result { - let metadata = Request::METADATA; - let path_and_query = #request_path_string + &#request_query_string; - let mut http_request = ::ruma_api::exports::http::Request::new(#request_body); - - *http_request.method_mut() = ::ruma_api::exports::http::Method::#method; - *http_request.uri_mut() = ::ruma_api::exports::http::uri::Builder::new() - .path_and_query(path_and_query.as_str()) - .build() - // The ruma_api! macro guards against invalid path input, but if there are - // invalid (non ASCII) bytes in the fields with the query attribute this will panic. - .unwrap(); - - { #add_headers_to_request } - - Ok(http_request) - } - } - #[doc = #response_doc] #response_type @@ -324,6 +306,31 @@ impl ToTokens for Api { rate_limited: #rate_limited, requires_authentication: #requires_authentication, }; + + #[allow(unused_mut, unused_variables)] + fn try_into_http_request( + self, + base_url: &str, + access_token: ::std::option::Option<&str>, + ) -> ::std::result::Result< + ::ruma_api::exports::http::Request>, + ::ruma_api::error::IntoHttpError, + > { + let metadata = Request::METADATA; + + let http_request = ::ruma_api::exports::http::Request::builder() + .method(::ruma_api::exports::http::Method::#method) + .uri(::std::format!( + "{}{}{}", + base_url.strip_suffix("/").unwrap_or(base_url), + #request_path_string, + #request_query_string, + )) + #( .header(#header_kvs) )* + .body(#request_body)?; + + Ok(http_request) + } } #non_auth_endpoint_impl diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index c60c6cc0..f694a865 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -22,8 +22,8 @@ pub struct Request { impl Request { /// Produces code to add necessary HTTP headers to an `http::Request`. - pub fn add_headers_to_request(&self) -> TokenStream { - let append_stmts = self.header_fields().map(|request_field| { + pub fn append_header_kvs(&self) -> Vec { + self.header_fields().map(|request_field| { let (field, header_name) = match request_field { RequestField::Header(field, header_name) => (field, header_name), _ => unreachable!("expected request field to be header variant"), @@ -32,16 +32,10 @@ impl Request { let field_name = &field.ident; quote! { - headers.append( - ::ruma_api::exports::http::header::#header_name, - ::ruma_api::exports::http::header::HeaderValue::from_str(request.#field_name.as_ref())?, - ); + ::ruma_api::exports::http::header::#header_name, + ::ruma_api::exports::http::header::HeaderValue::from_str(self.#field_name.as_ref())? } - }); - - quote! { - #(#append_stmts)* - } + }).collect() } /// Produces code to extract fields from the HTTP headers in an `http::Request`. @@ -136,12 +130,12 @@ impl Request { /// Produces code for a struct initializer for body fields on a variable named `request`. pub fn request_body_init_fields(&self) -> TokenStream { - self.struct_init_fields(RequestFieldKind::Body, quote!(request)) + self.struct_init_fields(RequestFieldKind::Body, quote!(self)) } /// Produces code for a struct initializer for query string fields on a variable named `request`. pub fn request_query_init_fields(&self) -> TokenStream { - self.struct_init_fields(RequestFieldKind::Query, quote!(request)) + self.struct_init_fields(RequestFieldKind::Query, quote!(self)) } /// Produces code for a struct initializer for body fields on a variable named `request_body`. diff --git a/ruma-api-macros/src/util.rs b/ruma-api-macros/src/util.rs index ba13cd47..1a7220b8 100644 --- a/ruma-api-macros/src/util.rs +++ b/ruma-api-macros/src/util.rs @@ -45,7 +45,7 @@ pub(crate) fn request_path_string_and_parse( ); format_args.push(quote! { ruma_api::exports::percent_encoding::utf8_percent_encode( - &request.#path_var.to_string(), + &self.#path_var.to_string(), ruma_api::exports::percent_encoding::NON_ALPHANUMERIC, ) }); @@ -53,7 +53,7 @@ pub(crate) fn request_path_string_and_parse( } quote! { - format!(#format_string, #(#format_args),*) + format_args!(#format_string, #(#format_args),*) } }; @@ -110,8 +110,11 @@ pub(crate) fn build_query_string(request: &Request) -> TokenStream { {} assert_trait_impl::<#field_type>(); - let request_query = RequestQuery(request.#field_name); - format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?) + let request_query = RequestQuery(self.#field_name); + format_args!( + "?{}", + ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)? + ) }) } else if request.has_query_fields() { let request_query_init_fields = request.request_query_init_fields(); @@ -121,12 +124,13 @@ pub(crate) fn build_query_string(request: &Request) -> TokenStream { #request_query_init_fields }; - format!("?{}", ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)?) + format_args!( + "?{}", + ruma_api::exports::ruma_serde::urlencoded::to_string(request_query)? + ) }) } else { - quote! { - String::new() - } + quote! { "" } } } @@ -161,11 +165,11 @@ pub(crate) fn extract_request_query(request: &Request) -> TokenStream { pub(crate) fn build_request_body(request: &Request) -> TokenStream { if let Some(field) = request.newtype_raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote!(request.#field_name) + quote!(self.#field_name) } else if request.has_body_fields() || request.newtype_body_field().is_some() { let request_body_initializers = if let Some(field) = request.newtype_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { (request.#field_name) } + quote! { (self.#field_name) } } else { let initializers = request.request_body_init_fields(); quote! { { #initializers } } diff --git a/ruma-api/src/error.rs b/ruma-api/src/error.rs index 2d7cf6ff..95223fe1 100644 --- a/ruma-api/src/error.rs +++ b/ruma-api/src/error.rs @@ -16,40 +16,62 @@ impl crate::EndpointError for Void { Err(ResponseDeserializationError::from_response(response)) } } + /// An error when converting one of ruma's endpoint-specific request or response /// types to the corresponding http type. #[derive(Debug)] -pub struct IntoHttpError(SerializationError); +pub struct IntoHttpError(InnerIntoHttpError); + +impl IntoHttpError { + // For usage in macros + #[doc(hidden)] + pub fn needs_authentication() -> Self { + Self(InnerIntoHttpError::NeedsAuthentication) + } +} #[doc(hidden)] impl From for IntoHttpError { fn from(err: serde_json::Error) -> Self { - Self(SerializationError::Json(err)) + Self(InnerIntoHttpError::Json(err)) } } #[doc(hidden)] impl From for IntoHttpError { fn from(err: ruma_serde::urlencoded::ser::Error) -> Self { - Self(SerializationError::Query(err)) + Self(InnerIntoHttpError::Query(err)) } } #[doc(hidden)] impl From for IntoHttpError { fn from(err: http::header::InvalidHeaderValue) -> Self { - Self(SerializationError::Header(err)) + Self(InnerIntoHttpError::Header(err)) + } +} + +#[doc(hidden)] +impl From for IntoHttpError { + fn from(err: http::Error) -> Self { + Self(InnerIntoHttpError::Http(err)) } } impl Display for IntoHttpError { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match &self.0 { - SerializationError::Json(err) => write!(f, "JSON serialization failed: {}", err), - SerializationError::Query(err) => { + InnerIntoHttpError::NeedsAuthentication => write!( + f, + "This endpoint has to be converted to http::Request using \ + try_into_authenticated_http_request" + ), + InnerIntoHttpError::Json(err) => write!(f, "JSON serialization failed: {}", err), + InnerIntoHttpError::Query(err) => { write!(f, "Query parameter serialization failed: {}", err) } - SerializationError::Header(err) => write!(f, "Header serialization failed: {}", err), + InnerIntoHttpError::Header(err) => write!(f, "Header serialization failed: {}", err), + InnerIntoHttpError::Http(err) => write!(f, "HTTP request construction failed: {}", err), } } } @@ -201,10 +223,12 @@ impl Display for ServerError { impl std::error::Error for ServerError {} #[derive(Debug)] -enum SerializationError { +enum InnerIntoHttpError { + NeedsAuthentication, Json(serde_json::Error), Query(ruma_serde::urlencoded::ser::Error), Header(http::header::InvalidHeaderValue), + Http(http::Error), } /// This type is public so it is accessible from `ruma_api!` generated code. diff --git a/ruma-api/src/lib.rs b/ruma-api/src/lib.rs index aba96d46..67297aa7 100644 --- a/ruma-api/src/lib.rs +++ b/ruma-api/src/lib.rs @@ -41,7 +41,7 @@ use http::Method; /// // Struct fields for each piece of data expected /// // in the response from this API endpoint. /// } -/// +/// /// // The error returned when a response fails, defaults to `Void`. /// error: path::to::Error /// } @@ -245,7 +245,7 @@ pub trait EndpointError: Sized { /// A Matrix API endpoint. /// /// The type implementing this trait contains any data needed to make a request to the endpoint. -pub trait Endpoint: Outgoing + TryInto>, Error = IntoHttpError> +pub trait Endpoint: Outgoing where ::Incoming: TryFrom>, Error = FromHttpRequestError>, ::Incoming: TryFrom< @@ -260,6 +260,20 @@ where /// Metadata about the endpoint. const METADATA: Metadata; + + /// Tries to convert this request into an `http::Request`. + /// + /// This method should only fail when called on endpoints that require authentication. It may + /// also fail with a serialization error in case of bugs in Ruma though. + /// + /// The endpoints path will be appended to the given `base_url`, for example + /// `https://matrix.org`. Since all paths begin with a slash, it is not necessary for the + /// `base_url` to have a trailing slash. If it has one however, it will be ignored. + fn try_into_http_request( + self, + base_url: &str, + access_token: Option<&str>, + ) -> Result>, IntoHttpError>; } /// A Matrix API endpoint that doesn't require authentication. @@ -357,26 +371,24 @@ mod tests { name: "create_alias", path: "/_matrix/client/r0/directory/room/:room_alias", rate_limited: false, - requires_authentication: true, + requires_authentication: false, }; - } - impl TryFrom for http::Request> { - type Error = IntoHttpError; - - fn try_from(request: Request) -> Result>, Self::Error> { + fn try_into_http_request( + self, + base_url: &str, + _access_token: Option<&str>, + ) -> Result>, IntoHttpError> { let metadata = Request::METADATA; - let path = metadata - .path - .to_string() - .replace(":room_alias", &request.room_alias.to_string()); + let url = (base_url.to_owned() + metadata.path) + .replace(":room_alias", &self.room_alias.to_string()); - let request_body = RequestBody { room_id: request.room_id }; + let request_body = RequestBody { room_id: self.room_id }; let http_request = http::Request::builder() .method(metadata.method) - .uri(path) + .uri(url) .body(serde_json::to_vec(&request_body)?) // this cannot fail because we don't give user-supplied data to any of the // builder methods diff --git a/ruma-api/tests/conversions.rs b/ruma-api/tests/conversions.rs index b9094df9..c8c598b0 100644 --- a/ruma-api/tests/conversions.rs +++ b/ruma-api/tests/conversions.rs @@ -47,7 +47,7 @@ fn request_serde() -> Result<(), Box> { baz: UserId::try_from("@bazme:ruma.io")?, }; - let http_req = http::Request::>::try_from(req.clone())?; + let http_req = req.clone().try_into_http_request("https://homeserver.tld", None)?; let req2 = Request::try_from(http_req)?; assert_eq!(req.hello, req2.hello); diff --git a/ruma-api/tests/no_fields.rs b/ruma-api/tests/no_fields.rs index db131af2..38ce16d0 100644 --- a/ruma-api/tests/no_fields.rs +++ b/ruma-api/tests/no_fields.rs @@ -1,6 +1,6 @@ use std::convert::TryFrom; -use ruma_api::ruma_api; +use ruma_api::{ruma_api, Endpoint}; ruma_api! { metadata: { @@ -19,7 +19,7 @@ ruma_api! { #[test] fn empty_request_http_repr() { let req = Request {}; - let http_req = http::Request::>::try_from(req).unwrap(); + let http_req = req.try_into_http_request("https://homeserver.tld", None).unwrap(); assert!(http_req.body().is_empty()); } diff --git a/ruma-client-api/src/r0/message/get_message_events.rs b/ruma-client-api/src/r0/message/get_message_events.rs index e09cb9f2..bcae9b72 100644 --- a/ruma-client-api/src/r0/message/get_message_events.rs +++ b/ruma-client-api/src/r0/message/get_message_events.rs @@ -117,9 +117,10 @@ pub enum Direction { mod tests { use super::{Direction, Request}; - use std::convert::{TryFrom, TryInto}; + use std::convert::TryFrom; use js_int::uint; + use ruma_api::Endpoint; use ruma_identifiers::RoomId; use crate::r0::filter::{LazyLoadOptions, RoomEventFilter}; @@ -143,7 +144,8 @@ mod tests { filter: Some(filter), }; - let request: http::Request> = req.try_into().unwrap(); + let request: http::Request> = + req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap(); assert_eq!( "from=token&to=token2&dir=b&limit=0&filter=%7B%22not_types%22%3A%5B%22type%22%5D%2C%22not_rooms%22%3A%5B%22room%22%2C%22room2%22%2C%22room3%22%5D%2C%22rooms%22%3A%5B%22%21roomid%3Aexample.org%22%5D%2C%22lazy_load_members%22%3Atrue%2C%22include_redundant_members%22%3Atrue%7D", request.uri().query().unwrap() @@ -162,7 +164,8 @@ mod tests { filter: None, }; - let request: http::Request> = req.try_into().unwrap(); + let request = + req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap(); assert_eq!("from=token&to=token2&dir=b&limit=0", request.uri().query().unwrap(),); } @@ -178,7 +181,8 @@ mod tests { filter: Some(RoomEventFilter::default()), }; - let request: http::Request> = req.try_into().unwrap(); + let request: http::Request> = + req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap(); assert_eq!( "from=token&to=token2&dir=b&limit=0&filter=%7B%7D", request.uri().query().unwrap(), diff --git a/ruma-client-api/src/r0/session/login.rs b/ruma-client-api/src/r0/session/login.rs index 552dd4b0..623a5763 100644 --- a/ruma-client-api/src/r0/session/login.rs +++ b/ruma-client-api/src/r0/session/login.rs @@ -141,8 +141,7 @@ mod user_serde; #[cfg(test)] mod tests { - use std::convert::TryInto; - + use ruma_api::Endpoint; use serde_json::{from_value as from_json_value, json, Value as JsonValue}; use super::{LoginInfo, Medium, Request, UserInfo}; @@ -193,7 +192,7 @@ mod tests { device_id: None, initial_device_display_name: Some("test".into()), } - .try_into() + .try_into_http_request("https://homeserver.tld", None) .unwrap(); let req_body_value: JsonValue = serde_json::from_slice(req.body()).unwrap(); diff --git a/ruma-client-api/src/r0/sync/sync_events.rs b/ruma-client-api/src/r0/sync/sync_events.rs index fb389273..e0a28dfa 100644 --- a/ruma-client-api/src/r0/sync/sync_events.rs +++ b/ruma-client-api/src/r0/sync/sync_events.rs @@ -403,6 +403,7 @@ impl DeviceLists { mod tests { use std::{convert::TryInto, time::Duration}; + use ruma_api::Endpoint; use serde_json::{from_value as from_json_value, json, to_value as to_json_value}; use matches::assert_matches; @@ -418,7 +419,7 @@ mod tests { set_presence: PresenceState::Offline, timeout: Some(Duration::from_millis(30000)), } - .try_into() + .try_into_http_request("https://homeserver.tld", Some("auth_tok")) .unwrap(); let uri = req.uri();