api: Replace access_token Option with SendAccessToken enum

This commit is contained in:
Johannes Becker 2021-04-22 16:12:15 +02:00 committed by Jonas Platte
parent a3a756c339
commit 0ca5785ead
15 changed files with 112 additions and 56 deletions

View File

@ -131,23 +131,34 @@ impl Request {
.collect();
for auth in &metadata.authentication {
if auth.value == "AccessToken" {
let attrs = &auth.attrs;
header_kvs.extend(quote! {
let attrs = &auth.attrs;
let (if_required_expr, none_expr) = if auth.value == "AccessToken" {
(
quote! { Some(access_token) },
quote! { return Err(#ruma_api::error::IntoHttpError::NeedsAuthentication) },
)
} else {
(quote! { None }, quote! { None })
};
header_kvs.extend(quote! {{
let access_token = match access_token {
#ruma_api::SendAccessToken::IfRequired(access_token) => #if_required_expr,
#ruma_api::SendAccessToken::Always(access_token) => Some(access_token),
#ruma_api::SendAccessToken::None => #none_expr,
};
if let Some(access_token) = access_token {
#( #attrs )*
req_headers.insert(
#http::header::AUTHORIZATION,
#http::header::HeaderValue::from_str(
&::std::format!(
"Bearer {}",
access_token.ok_or(
#ruma_api::error::IntoHttpError::NeedsAuthentication
)?
)
&::std::format!("Bearer {}", access_token)
)?
);
});
}
}
}});
}
let request_body = if let Some(field) = self.newtype_raw_body_field() {
@ -200,7 +211,7 @@ impl Request {
fn try_into_http_request(
self,
base_url: &::std::primitive::str,
access_token: ::std::option::Option<&::std::primitive::str>,
access_token: #ruma_api::SendAccessToken<'_>,
) -> ::std::result::Result<
#http::Request<::std::vec::Vec<::std::primitive::u8>>,
#ruma_api::error::IntoHttpError,

View File

@ -212,6 +212,20 @@ pub mod exports {
use error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError};
/// An enum to control whether an access token should be added to outgoing requests
#[derive(Debug, Clone)]
pub enum SendAccessToken<'a> {
/// Add the given access token to the request only if the `METADATA` on the request requires it
IfRequired(&'a str),
/// Always add the access token
Always(&'a str),
/// Don't add an access token. This will lead to an error if the request endpoint requires
/// authentication
None,
}
/// A request type for a Matrix API endpoint, used for sending requests.
pub trait OutgoingRequest: Sized {
/// A type capturing the expected error conditions the server can return.
@ -234,7 +248,7 @@ pub trait OutgoingRequest: Sized {
fn try_into_http_request(
self,
base_url: &str,
access_token: Option<&str>,
access_token: SendAccessToken<'_>,
) -> Result<http::Request<Vec<u8>>, IntoHttpError>;
}
@ -258,7 +272,7 @@ pub trait OutgoingRequestAppserviceExt: OutgoingRequest {
fn try_into_http_request_with_user_id(
self,
base_url: &str,
access_token: Option<&str>,
access_token: SendAccessToken<'_>,
user_id: UserId,
) -> Result<http::Request<Vec<u8>>, IntoHttpError> {
let mut http_request = self.try_into_http_request(base_url, access_token)?;

View File

@ -1,5 +1,6 @@
use ruma_api::{
ruma_api, IncomingRequest as _, OutgoingRequest as _, OutgoingRequestAppserviceExt as _,
SendAccessToken,
};
use ruma_identifiers::{user_id, UserId};
@ -47,7 +48,8 @@ fn request_serde() {
user: user_id!("@bazme:ruma.io"),
};
let http_req = req.clone().try_into_http_request("https://homeserver.tld", None).unwrap();
let http_req =
req.clone().try_into_http_request("https://homeserver.tld", SendAccessToken::None).unwrap();
let req2 = Request::try_from_http_request(http_req).unwrap();
assert_eq!(req.hello, req2.hello);
@ -70,8 +72,13 @@ fn request_with_user_id_serde() {
};
let user_id = user_id!("@_virtual_:ruma.io");
let http_req =
req.try_into_http_request_with_user_id("https://homeserver.tld", None, user_id).unwrap();
let http_req = req
.try_into_http_request_with_user_id(
"https://homeserver.tld",
SendAccessToken::None,
user_id,
)
.unwrap();
let query = http_req.uri().query().unwrap();
@ -124,7 +131,11 @@ mod without_query {
let user_id = user_id!("@_virtual_:ruma.io");
let http_req = req
.try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)
.try_into_http_request_with_user_id(
"https://homeserver.tld",
SendAccessToken::None,
user_id,
)
.unwrap();
let query = http_req.uri().query().unwrap();

View File

@ -1,5 +1,5 @@
use http::header::{Entry, CONTENT_TYPE};
use ruma_api::{ruma_api, OutgoingRequest as _, OutgoingResponse as _};
use ruma_api::{ruma_api, OutgoingRequest as _, OutgoingResponse as _, SendAccessToken};
ruma_api! {
metadata: {
@ -45,7 +45,8 @@ fn response_content_type_override() {
#[test]
fn request_content_type_override() {
let req = Request { location: None, stuff: "magic".into() };
let mut http_req = req.try_into_http_request("https://homeserver.tld", None).unwrap();
let mut http_req =
req.try_into_http_request("https://homeserver.tld", SendAccessToken::None).unwrap();
assert_eq!(
match http_req.headers_mut().entry(CONTENT_TYPE) {

View File

@ -6,7 +6,7 @@ use http::{header::CONTENT_TYPE, method::Method};
use ruma_api::{
error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError, ServerError, Void},
AuthScheme, EndpointError, IncomingRequest, IncomingResponse, Metadata, OutgoingRequest,
OutgoingResponse,
OutgoingResponse, SendAccessToken,
};
use ruma_identifiers::{RoomAliasId, RoomId};
use ruma_serde::Outgoing;
@ -41,7 +41,7 @@ impl OutgoingRequest for Request {
fn try_into_http_request(
self,
base_url: &str,
_access_token: Option<&str>,
_access_token: SendAccessToken<'_>,
) -> Result<http::Request<Vec<u8>>, IntoHttpError> {
let url = (base_url.to_owned() + METADATA.path)
.replace(":room_alias", &self.room_alias.to_string());

View File

@ -1,4 +1,4 @@
use ruma_api::{ruma_api, OutgoingRequest as _, OutgoingResponse as _};
use ruma_api::{ruma_api, OutgoingRequest as _, OutgoingResponse as _, SendAccessToken};
ruma_api! {
metadata: {
@ -17,7 +17,8 @@ ruma_api! {
#[test]
fn empty_request_http_repr() {
let req = Request {};
let http_req = req.try_into_http_request("https://homeserver.tld", None).unwrap();
let http_req =
req.try_into_http_request("https://homeserver.tld", SendAccessToken::None).unwrap();
assert!(http_req.body().is_empty());
}

View File

@ -147,7 +147,7 @@ mod helper_tests {
#[cfg(feature = "server")]
#[cfg(test)]
mod tests {
use ruma_api::{exports::http, OutgoingRequest};
use ruma_api::{exports::http, OutgoingRequest, SendAccessToken};
use ruma_events::AnyEvent;
use ruma_serde::Raw;
use serde_json::json;
@ -165,7 +165,10 @@ mod tests {
let events = vec![dummy_event];
let req: http::Request<Vec<u8>> = Request { events: &events, txn_id: "any_txn_id" }
.try_into_http_request("https://homeserver.tld", Some("auth_tok"))
.try_into_http_request(
"https://homeserver.tld",
SendAccessToken::IfRequired("auth_tok"),
)
.unwrap();
let json_body: serde_json::Value = serde_json::from_slice(&req.body()).unwrap();

View File

@ -76,7 +76,7 @@ mod tests {
#[cfg(feature = "client")]
#[test]
fn construct_request_from_refs() {
use ruma_api::OutgoingRequest as _;
use ruma_api::{OutgoingRequest as _, SendAccessToken};
use ruma_identifiers::server_name;
let req = super::Request {
@ -84,7 +84,7 @@ mod tests {
since: Some("hello"),
server: Some(&server_name!("test.tld")),
}
.try_into_http_request("https://homeserver.tld", Some("auth_tok"))
.try_into_http_request("https://homeserver.tld", SendAccessToken::IfRequired("auth_tok"))
.unwrap();
let uri = req.uri();

View File

@ -134,7 +134,7 @@ pub enum Direction {
#[cfg(all(test, feature = "client"))]
mod tests {
use js_int::uint;
use ruma_api::OutgoingRequest;
use ruma_api::{OutgoingRequest, SendAccessToken};
use ruma_identifiers::room_id;
use super::{Direction, Request};
@ -160,8 +160,12 @@ mod tests {
filter: Some(filter),
};
let request: http::Request<Vec<u8>> =
req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap();
let request: http::Request<Vec<u8>> = req
.try_into_http_request(
"https://homeserver.tld",
SendAccessToken::IfRequired("auth_tok"),
)
.unwrap();
assert_eq!(
"from=token\
&to=token2\
@ -186,8 +190,12 @@ mod tests {
filter: None,
};
let request =
req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap();
let request = req
.try_into_http_request(
"https://homeserver.tld",
SendAccessToken::IfRequired("auth_tok"),
)
.unwrap();
assert_eq!("from=token&to=token2&dir=b&limit=0", request.uri().query().unwrap(),);
}
@ -203,8 +211,12 @@ mod tests {
filter: Some(RoomEventFilter::default()),
};
let request: http::Request<Vec<u8>> =
req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap();
let request: http::Request<Vec<u8>> = req
.try_into_http_request(
"https://homeserver.tld",
SendAccessToken::IfRequired("auth_tok"),
)
.unwrap();
assert_eq!(
"from=token&to=token2&dir=b&limit=0&filter=%7B%7D",
request.uri().query().unwrap(),

View File

@ -201,7 +201,7 @@ mod tests {
#[test]
#[cfg(feature = "client")]
fn serialize_login_request_body() {
use ruma_api::OutgoingRequest;
use ruma_api::{OutgoingRequest, SendAccessToken};
use serde_json::Value as JsonValue;
use super::{LoginInfo, Medium, Request, UserIdentifier};
@ -211,7 +211,7 @@ mod tests {
device_id: None,
initial_device_display_name: Some("test"),
}
.try_into_http_request("https://homeserver.tld", None)
.try_into_http_request("https://homeserver.tld", SendAccessToken::None)
.unwrap();
let req_body_value: JsonValue = serde_json::from_slice(req.body()).unwrap();
@ -235,7 +235,7 @@ mod tests {
device_id: None,
initial_device_display_name: Some("test"),
}
.try_into_http_request("https://homeserver.tld", None)
.try_into_http_request("https://homeserver.tld", SendAccessToken::None)
.unwrap();
let req_body_value: JsonValue = serde_json::from_slice(req.body()).unwrap();

View File

@ -46,14 +46,14 @@ impl Response {
#[cfg(all(test, feature = "client"))]
mod tests {
use ruma_api::OutgoingRequest;
use ruma_api::{OutgoingRequest, SendAccessToken};
use super::Request;
#[test]
fn serialize_sso_login_request_uri() {
let req: http::Request<Vec<u8>> = Request { redirect_url: "https://example.com/sso" }
.try_into_http_request("https://homeserver.tld", None)
.try_into_http_request("https://homeserver.tld", SendAccessToken::None)
.unwrap();
assert_eq!(

View File

@ -53,14 +53,14 @@ impl Response {
#[cfg(all(test, feature = "client"))]
mod tests {
use ruma_api::OutgoingRequest as _;
use ruma_api::{OutgoingRequest as _, SendAccessToken};
use super::Request;
#[test]
fn serialize_sso_login_with_provider_request_uri() {
let req = Request { idp_id: "provider", redirect_url: "https://example.com/sso" }
.try_into_http_request("https://homeserver.tld", None)
.try_into_http_request("https://homeserver.tld", SendAccessToken::None)
.unwrap();
assert_eq!(

View File

@ -1,6 +1,6 @@
//! [GET /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}](https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-rooms-roomid-state-eventtype-statekey)
use ruma_api::ruma_api;
use ruma_api::{ruma_api, SendAccessToken};
use ruma_events::{AnyStateEventContent, EventType};
use ruma_identifiers::RoomId;
use ruma_serde::{Outgoing, Raw};
@ -68,7 +68,7 @@ impl<'a> ruma_api::OutgoingRequest for Request<'a> {
fn try_into_http_request(
self,
base_url: &str,
access_token: Option<&str>,
access_token: SendAccessToken<'_>,
) -> Result<http::Request<Vec<u8>>, ruma_api::error::IntoHttpError> {
use std::borrow::Cow;
@ -87,17 +87,20 @@ impl<'a> ruma_api::OutgoingRequest for Request<'a> {
url.push_str(&Cow::from(utf8_percent_encode(&self.state_key, NON_ALPHANUMERIC)));
}
let access_token = match access_token {
SendAccessToken::IfRequired(access_token) | SendAccessToken::Always(access_token) => {
access_token
}
SendAccessToken::None => {
return Err(ruma_api::error::IntoHttpError::NeedsAuthentication)
}
};
http::Request::builder()
.method(http::Method::GET)
.uri(url)
.header(CONTENT_TYPE, "application/json")
.header(
AUTHORIZATION,
HeaderValue::from_str(&format!(
"Bearer {}",
access_token.ok_or(ruma_api::error::IntoHttpError::NeedsAuthentication)?
))?,
)
.header(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", access_token))?)
.body(Vec::new())
.map_err(Into::into)
}

View File

@ -777,7 +777,7 @@ mod tests {
mod client_tests {
use std::time::Duration;
use ruma_api::OutgoingRequest as _;
use ruma_api::{OutgoingRequest as _, SendAccessToken};
use super::{Filter, PresenceState, Request};
@ -790,7 +790,7 @@ mod client_tests {
set_presence: &PresenceState::Offline,
timeout: Some(Duration::from_millis(30000)),
}
.try_into_http_request("https://homeserver.tld", Some("auth_tok"))
.try_into_http_request("https://homeserver.tld", SendAccessToken::IfRequired("auth_tok"))
.unwrap();
let uri = req.uri();

View File

@ -113,7 +113,7 @@ use async_stream::try_stream;
use futures_core::stream::Stream;
use http::{uri::Uri, Response as HttpResponse};
use hyper::client::{Client as HyperClient, HttpConnector};
use ruma_api::{AuthScheme, OutgoingRequest};
use ruma_api::{AuthScheme, OutgoingRequest, SendAccessToken};
use ruma_client_api::r0::sync::sync_events::{
Filter as SyncFilter, Request as SyncRequest, Response as SyncResponse,
};
@ -349,12 +349,12 @@ impl Client {
let access_token = if Request::METADATA.authentication == AuthScheme::AccessToken {
session = client.session.lock().unwrap();
if let Some(s) = &*session {
Some(s.access_token.as_str())
SendAccessToken::IfRequired(s.access_token.as_str())
} else {
return Err(Error::AuthenticationRequired);
}
} else {
None
SendAccessToken::None
};
request.try_into_http_request(&client.homeserver_url.to_string(), access_token)?