api: Replace access_token Option with SendAccessToken enum
This commit is contained in:
parent
a3a756c339
commit
0ca5785ead
@ -131,23 +131,34 @@ impl Request {
|
||||
.collect();
|
||||
|
||||
for auth in &metadata.authentication {
|
||||
if auth.value == "AccessToken" {
|
||||
let attrs = &auth.attrs;
|
||||
header_kvs.extend(quote! {
|
||||
let attrs = &auth.attrs;
|
||||
|
||||
let (if_required_expr, none_expr) = if auth.value == "AccessToken" {
|
||||
(
|
||||
quote! { Some(access_token) },
|
||||
quote! { return Err(#ruma_api::error::IntoHttpError::NeedsAuthentication) },
|
||||
)
|
||||
} else {
|
||||
(quote! { None }, quote! { None })
|
||||
};
|
||||
|
||||
header_kvs.extend(quote! {{
|
||||
let access_token = match access_token {
|
||||
#ruma_api::SendAccessToken::IfRequired(access_token) => #if_required_expr,
|
||||
#ruma_api::SendAccessToken::Always(access_token) => Some(access_token),
|
||||
#ruma_api::SendAccessToken::None => #none_expr,
|
||||
};
|
||||
|
||||
if let Some(access_token) = access_token {
|
||||
#( #attrs )*
|
||||
req_headers.insert(
|
||||
#http::header::AUTHORIZATION,
|
||||
#http::header::HeaderValue::from_str(
|
||||
&::std::format!(
|
||||
"Bearer {}",
|
||||
access_token.ok_or(
|
||||
#ruma_api::error::IntoHttpError::NeedsAuthentication
|
||||
)?
|
||||
)
|
||||
&::std::format!("Bearer {}", access_token)
|
||||
)?
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}});
|
||||
}
|
||||
|
||||
let request_body = if let Some(field) = self.newtype_raw_body_field() {
|
||||
@ -200,7 +211,7 @@ impl Request {
|
||||
fn try_into_http_request(
|
||||
self,
|
||||
base_url: &::std::primitive::str,
|
||||
access_token: ::std::option::Option<&::std::primitive::str>,
|
||||
access_token: #ruma_api::SendAccessToken<'_>,
|
||||
) -> ::std::result::Result<
|
||||
#http::Request<::std::vec::Vec<::std::primitive::u8>>,
|
||||
#ruma_api::error::IntoHttpError,
|
||||
|
@ -212,6 +212,20 @@ pub mod exports {
|
||||
|
||||
use error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError};
|
||||
|
||||
/// An enum to control whether an access token should be added to outgoing requests
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SendAccessToken<'a> {
|
||||
/// Add the given access token to the request only if the `METADATA` on the request requires it
|
||||
IfRequired(&'a str),
|
||||
|
||||
/// Always add the access token
|
||||
Always(&'a str),
|
||||
|
||||
/// Don't add an access token. This will lead to an error if the request endpoint requires
|
||||
/// authentication
|
||||
None,
|
||||
}
|
||||
|
||||
/// A request type for a Matrix API endpoint, used for sending requests.
|
||||
pub trait OutgoingRequest: Sized {
|
||||
/// A type capturing the expected error conditions the server can return.
|
||||
@ -234,7 +248,7 @@ pub trait OutgoingRequest: Sized {
|
||||
fn try_into_http_request(
|
||||
self,
|
||||
base_url: &str,
|
||||
access_token: Option<&str>,
|
||||
access_token: SendAccessToken<'_>,
|
||||
) -> Result<http::Request<Vec<u8>>, IntoHttpError>;
|
||||
}
|
||||
|
||||
@ -258,7 +272,7 @@ pub trait OutgoingRequestAppserviceExt: OutgoingRequest {
|
||||
fn try_into_http_request_with_user_id(
|
||||
self,
|
||||
base_url: &str,
|
||||
access_token: Option<&str>,
|
||||
access_token: SendAccessToken<'_>,
|
||||
user_id: UserId,
|
||||
) -> Result<http::Request<Vec<u8>>, IntoHttpError> {
|
||||
let mut http_request = self.try_into_http_request(base_url, access_token)?;
|
||||
|
@ -1,5 +1,6 @@
|
||||
use ruma_api::{
|
||||
ruma_api, IncomingRequest as _, OutgoingRequest as _, OutgoingRequestAppserviceExt as _,
|
||||
SendAccessToken,
|
||||
};
|
||||
use ruma_identifiers::{user_id, UserId};
|
||||
|
||||
@ -47,7 +48,8 @@ fn request_serde() {
|
||||
user: user_id!("@bazme:ruma.io"),
|
||||
};
|
||||
|
||||
let http_req = req.clone().try_into_http_request("https://homeserver.tld", None).unwrap();
|
||||
let http_req =
|
||||
req.clone().try_into_http_request("https://homeserver.tld", SendAccessToken::None).unwrap();
|
||||
let req2 = Request::try_from_http_request(http_req).unwrap();
|
||||
|
||||
assert_eq!(req.hello, req2.hello);
|
||||
@ -70,8 +72,13 @@ fn request_with_user_id_serde() {
|
||||
};
|
||||
|
||||
let user_id = user_id!("@_virtual_:ruma.io");
|
||||
let http_req =
|
||||
req.try_into_http_request_with_user_id("https://homeserver.tld", None, user_id).unwrap();
|
||||
let http_req = req
|
||||
.try_into_http_request_with_user_id(
|
||||
"https://homeserver.tld",
|
||||
SendAccessToken::None,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let query = http_req.uri().query().unwrap();
|
||||
|
||||
@ -124,7 +131,11 @@ mod without_query {
|
||||
|
||||
let user_id = user_id!("@_virtual_:ruma.io");
|
||||
let http_req = req
|
||||
.try_into_http_request_with_user_id("https://homeserver.tld", None, user_id)
|
||||
.try_into_http_request_with_user_id(
|
||||
"https://homeserver.tld",
|
||||
SendAccessToken::None,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let query = http_req.uri().query().unwrap();
|
||||
|
@ -1,5 +1,5 @@
|
||||
use http::header::{Entry, CONTENT_TYPE};
|
||||
use ruma_api::{ruma_api, OutgoingRequest as _, OutgoingResponse as _};
|
||||
use ruma_api::{ruma_api, OutgoingRequest as _, OutgoingResponse as _, SendAccessToken};
|
||||
|
||||
ruma_api! {
|
||||
metadata: {
|
||||
@ -45,7 +45,8 @@ fn response_content_type_override() {
|
||||
#[test]
|
||||
fn request_content_type_override() {
|
||||
let req = Request { location: None, stuff: "magic".into() };
|
||||
let mut http_req = req.try_into_http_request("https://homeserver.tld", None).unwrap();
|
||||
let mut http_req =
|
||||
req.try_into_http_request("https://homeserver.tld", SendAccessToken::None).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
match http_req.headers_mut().entry(CONTENT_TYPE) {
|
||||
|
@ -6,7 +6,7 @@ use http::{header::CONTENT_TYPE, method::Method};
|
||||
use ruma_api::{
|
||||
error::{FromHttpRequestError, FromHttpResponseError, IntoHttpError, ServerError, Void},
|
||||
AuthScheme, EndpointError, IncomingRequest, IncomingResponse, Metadata, OutgoingRequest,
|
||||
OutgoingResponse,
|
||||
OutgoingResponse, SendAccessToken,
|
||||
};
|
||||
use ruma_identifiers::{RoomAliasId, RoomId};
|
||||
use ruma_serde::Outgoing;
|
||||
@ -41,7 +41,7 @@ impl OutgoingRequest for Request {
|
||||
fn try_into_http_request(
|
||||
self,
|
||||
base_url: &str,
|
||||
_access_token: Option<&str>,
|
||||
_access_token: SendAccessToken<'_>,
|
||||
) -> Result<http::Request<Vec<u8>>, IntoHttpError> {
|
||||
let url = (base_url.to_owned() + METADATA.path)
|
||||
.replace(":room_alias", &self.room_alias.to_string());
|
||||
|
@ -1,4 +1,4 @@
|
||||
use ruma_api::{ruma_api, OutgoingRequest as _, OutgoingResponse as _};
|
||||
use ruma_api::{ruma_api, OutgoingRequest as _, OutgoingResponse as _, SendAccessToken};
|
||||
|
||||
ruma_api! {
|
||||
metadata: {
|
||||
@ -17,7 +17,8 @@ ruma_api! {
|
||||
#[test]
|
||||
fn empty_request_http_repr() {
|
||||
let req = Request {};
|
||||
let http_req = req.try_into_http_request("https://homeserver.tld", None).unwrap();
|
||||
let http_req =
|
||||
req.try_into_http_request("https://homeserver.tld", SendAccessToken::None).unwrap();
|
||||
|
||||
assert!(http_req.body().is_empty());
|
||||
}
|
||||
|
@ -147,7 +147,7 @@ mod helper_tests {
|
||||
#[cfg(feature = "server")]
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use ruma_api::{exports::http, OutgoingRequest};
|
||||
use ruma_api::{exports::http, OutgoingRequest, SendAccessToken};
|
||||
use ruma_events::AnyEvent;
|
||||
use ruma_serde::Raw;
|
||||
use serde_json::json;
|
||||
@ -165,7 +165,10 @@ mod tests {
|
||||
let events = vec![dummy_event];
|
||||
|
||||
let req: http::Request<Vec<u8>> = Request { events: &events, txn_id: "any_txn_id" }
|
||||
.try_into_http_request("https://homeserver.tld", Some("auth_tok"))
|
||||
.try_into_http_request(
|
||||
"https://homeserver.tld",
|
||||
SendAccessToken::IfRequired("auth_tok"),
|
||||
)
|
||||
.unwrap();
|
||||
let json_body: serde_json::Value = serde_json::from_slice(&req.body()).unwrap();
|
||||
|
||||
|
@ -76,7 +76,7 @@ mod tests {
|
||||
#[cfg(feature = "client")]
|
||||
#[test]
|
||||
fn construct_request_from_refs() {
|
||||
use ruma_api::OutgoingRequest as _;
|
||||
use ruma_api::{OutgoingRequest as _, SendAccessToken};
|
||||
use ruma_identifiers::server_name;
|
||||
|
||||
let req = super::Request {
|
||||
@ -84,7 +84,7 @@ mod tests {
|
||||
since: Some("hello"),
|
||||
server: Some(&server_name!("test.tld")),
|
||||
}
|
||||
.try_into_http_request("https://homeserver.tld", Some("auth_tok"))
|
||||
.try_into_http_request("https://homeserver.tld", SendAccessToken::IfRequired("auth_tok"))
|
||||
.unwrap();
|
||||
|
||||
let uri = req.uri();
|
||||
|
@ -134,7 +134,7 @@ pub enum Direction {
|
||||
#[cfg(all(test, feature = "client"))]
|
||||
mod tests {
|
||||
use js_int::uint;
|
||||
use ruma_api::OutgoingRequest;
|
||||
use ruma_api::{OutgoingRequest, SendAccessToken};
|
||||
use ruma_identifiers::room_id;
|
||||
|
||||
use super::{Direction, Request};
|
||||
@ -160,8 +160,12 @@ mod tests {
|
||||
filter: Some(filter),
|
||||
};
|
||||
|
||||
let request: http::Request<Vec<u8>> =
|
||||
req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap();
|
||||
let request: http::Request<Vec<u8>> = req
|
||||
.try_into_http_request(
|
||||
"https://homeserver.tld",
|
||||
SendAccessToken::IfRequired("auth_tok"),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
"from=token\
|
||||
&to=token2\
|
||||
@ -186,8 +190,12 @@ mod tests {
|
||||
filter: None,
|
||||
};
|
||||
|
||||
let request =
|
||||
req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap();
|
||||
let request = req
|
||||
.try_into_http_request(
|
||||
"https://homeserver.tld",
|
||||
SendAccessToken::IfRequired("auth_tok"),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!("from=token&to=token2&dir=b&limit=0", request.uri().query().unwrap(),);
|
||||
}
|
||||
|
||||
@ -203,8 +211,12 @@ mod tests {
|
||||
filter: Some(RoomEventFilter::default()),
|
||||
};
|
||||
|
||||
let request: http::Request<Vec<u8>> =
|
||||
req.try_into_http_request("https://homeserver.tld", Some("auth_tok")).unwrap();
|
||||
let request: http::Request<Vec<u8>> = req
|
||||
.try_into_http_request(
|
||||
"https://homeserver.tld",
|
||||
SendAccessToken::IfRequired("auth_tok"),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
"from=token&to=token2&dir=b&limit=0&filter=%7B%7D",
|
||||
request.uri().query().unwrap(),
|
||||
|
@ -201,7 +201,7 @@ mod tests {
|
||||
#[test]
|
||||
#[cfg(feature = "client")]
|
||||
fn serialize_login_request_body() {
|
||||
use ruma_api::OutgoingRequest;
|
||||
use ruma_api::{OutgoingRequest, SendAccessToken};
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
use super::{LoginInfo, Medium, Request, UserIdentifier};
|
||||
@ -211,7 +211,7 @@ mod tests {
|
||||
device_id: None,
|
||||
initial_device_display_name: Some("test"),
|
||||
}
|
||||
.try_into_http_request("https://homeserver.tld", None)
|
||||
.try_into_http_request("https://homeserver.tld", SendAccessToken::None)
|
||||
.unwrap();
|
||||
|
||||
let req_body_value: JsonValue = serde_json::from_slice(req.body()).unwrap();
|
||||
@ -235,7 +235,7 @@ mod tests {
|
||||
device_id: None,
|
||||
initial_device_display_name: Some("test"),
|
||||
}
|
||||
.try_into_http_request("https://homeserver.tld", None)
|
||||
.try_into_http_request("https://homeserver.tld", SendAccessToken::None)
|
||||
.unwrap();
|
||||
|
||||
let req_body_value: JsonValue = serde_json::from_slice(req.body()).unwrap();
|
||||
|
@ -46,14 +46,14 @@ impl Response {
|
||||
|
||||
#[cfg(all(test, feature = "client"))]
|
||||
mod tests {
|
||||
use ruma_api::OutgoingRequest;
|
||||
use ruma_api::{OutgoingRequest, SendAccessToken};
|
||||
|
||||
use super::Request;
|
||||
|
||||
#[test]
|
||||
fn serialize_sso_login_request_uri() {
|
||||
let req: http::Request<Vec<u8>> = Request { redirect_url: "https://example.com/sso" }
|
||||
.try_into_http_request("https://homeserver.tld", None)
|
||||
.try_into_http_request("https://homeserver.tld", SendAccessToken::None)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -53,14 +53,14 @@ impl Response {
|
||||
|
||||
#[cfg(all(test, feature = "client"))]
|
||||
mod tests {
|
||||
use ruma_api::OutgoingRequest as _;
|
||||
use ruma_api::{OutgoingRequest as _, SendAccessToken};
|
||||
|
||||
use super::Request;
|
||||
|
||||
#[test]
|
||||
fn serialize_sso_login_with_provider_request_uri() {
|
||||
let req = Request { idp_id: "provider", redirect_url: "https://example.com/sso" }
|
||||
.try_into_http_request("https://homeserver.tld", None)
|
||||
.try_into_http_request("https://homeserver.tld", SendAccessToken::None)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -1,6 +1,6 @@
|
||||
//! [GET /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}](https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-rooms-roomid-state-eventtype-statekey)
|
||||
|
||||
use ruma_api::ruma_api;
|
||||
use ruma_api::{ruma_api, SendAccessToken};
|
||||
use ruma_events::{AnyStateEventContent, EventType};
|
||||
use ruma_identifiers::RoomId;
|
||||
use ruma_serde::{Outgoing, Raw};
|
||||
@ -68,7 +68,7 @@ impl<'a> ruma_api::OutgoingRequest for Request<'a> {
|
||||
fn try_into_http_request(
|
||||
self,
|
||||
base_url: &str,
|
||||
access_token: Option<&str>,
|
||||
access_token: SendAccessToken<'_>,
|
||||
) -> Result<http::Request<Vec<u8>>, ruma_api::error::IntoHttpError> {
|
||||
use std::borrow::Cow;
|
||||
|
||||
@ -87,17 +87,20 @@ impl<'a> ruma_api::OutgoingRequest for Request<'a> {
|
||||
url.push_str(&Cow::from(utf8_percent_encode(&self.state_key, NON_ALPHANUMERIC)));
|
||||
}
|
||||
|
||||
let access_token = match access_token {
|
||||
SendAccessToken::IfRequired(access_token) | SendAccessToken::Always(access_token) => {
|
||||
access_token
|
||||
}
|
||||
SendAccessToken::None => {
|
||||
return Err(ruma_api::error::IntoHttpError::NeedsAuthentication)
|
||||
}
|
||||
};
|
||||
|
||||
http::Request::builder()
|
||||
.method(http::Method::GET)
|
||||
.uri(url)
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.header(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!(
|
||||
"Bearer {}",
|
||||
access_token.ok_or(ruma_api::error::IntoHttpError::NeedsAuthentication)?
|
||||
))?,
|
||||
)
|
||||
.header(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", access_token))?)
|
||||
.body(Vec::new())
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
@ -777,7 +777,7 @@ mod tests {
|
||||
mod client_tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use ruma_api::OutgoingRequest as _;
|
||||
use ruma_api::{OutgoingRequest as _, SendAccessToken};
|
||||
|
||||
use super::{Filter, PresenceState, Request};
|
||||
|
||||
@ -790,7 +790,7 @@ mod client_tests {
|
||||
set_presence: &PresenceState::Offline,
|
||||
timeout: Some(Duration::from_millis(30000)),
|
||||
}
|
||||
.try_into_http_request("https://homeserver.tld", Some("auth_tok"))
|
||||
.try_into_http_request("https://homeserver.tld", SendAccessToken::IfRequired("auth_tok"))
|
||||
.unwrap();
|
||||
|
||||
let uri = req.uri();
|
||||
|
@ -113,7 +113,7 @@ use async_stream::try_stream;
|
||||
use futures_core::stream::Stream;
|
||||
use http::{uri::Uri, Response as HttpResponse};
|
||||
use hyper::client::{Client as HyperClient, HttpConnector};
|
||||
use ruma_api::{AuthScheme, OutgoingRequest};
|
||||
use ruma_api::{AuthScheme, OutgoingRequest, SendAccessToken};
|
||||
use ruma_client_api::r0::sync::sync_events::{
|
||||
Filter as SyncFilter, Request as SyncRequest, Response as SyncResponse,
|
||||
};
|
||||
@ -349,12 +349,12 @@ impl Client {
|
||||
let access_token = if Request::METADATA.authentication == AuthScheme::AccessToken {
|
||||
session = client.session.lock().unwrap();
|
||||
if let Some(s) = &*session {
|
||||
Some(s.access_token.as_str())
|
||||
SendAccessToken::IfRequired(s.access_token.as_str())
|
||||
} else {
|
||||
return Err(Error::AuthenticationRequired);
|
||||
}
|
||||
} else {
|
||||
None
|
||||
SendAccessToken::None
|
||||
};
|
||||
|
||||
request.try_into_http_request(&client.homeserver_url.to_string(), access_token)?
|
||||
|
Loading…
x
Reference in New Issue
Block a user