From c15382ca41262058302959eac4029ab4a1ea5889 Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Sun, 1 Nov 2020 05:04:34 -0800 Subject: [PATCH] api-macros: Make Response header fields override any defaults --- ruma-api-macros/src/api.rs | 7 +++-- ruma-api-macros/src/api/response.rs | 18 +++++++------ ruma-api/tests/header_override.rs | 42 +++++++++++++++++++++++++++++ ruma-api/tests/optional_headers.rs | 1 + 4 files changed, 58 insertions(+), 10 deletions(-) create mode 100644 ruma-api/tests/header_override.rs diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index d23271c3..eec40244 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -289,10 +289,13 @@ impl ToTokens for Api { let mut resp_builder = #ruma_api_import::exports::http::Response::builder() .header(#ruma_api_import::exports::http::header::CONTENT_TYPE, "application/json"); + let mut headers = + resp_builder.headers_mut().expect("`http::ResponseBuilder` is in unusable state"); #serialize_response_headers - // Since we require header names to come from the `http::header` module, - // this cannot fail. + // This cannot fail because we parse each header value + // checking for errors as each value is inserted and + // we only allow keys from the `http::header` module. let response = resp_builder.body(#body).unwrap(); Ok(response) } diff --git a/ruma-api-macros/src/api/response.rs b/ruma-api-macros/src/api/response.rs index 4bab4868..4daa20eb 100644 --- a/ruma-api-macros/src/api/response.rs +++ b/ruma-api-macros/src/api/response.rs @@ -123,18 +123,20 @@ impl Response { { quote! { if let Some(header) = response.#field_name { - resp_builder = resp_builder.header( - #import_path::exports::http::header::#header_name, - header, - ); + headers + .insert( + #import_path::exports::http::header::#header_name, + header.parse()?, + ); } } } _ => quote! { - resp_builder = resp_builder.header( - #import_path::exports::http::header::#header_name, - response.#field_name, - ); + headers + .insert( + #import_path::exports::http::header::#header_name, + response.#field_name.parse()?, + ); }, }; diff --git a/ruma-api/tests/header_override.rs b/ruma-api/tests/header_override.rs new file mode 100644 index 00000000..db7c4a80 --- /dev/null +++ b/ruma-api/tests/header_override.rs @@ -0,0 +1,42 @@ +use std::convert::TryFrom; + +use http::header::{Entry, CONTENT_TYPE}; +use ruma_api::ruma_api; + +ruma_api! { + metadata: { + description: "Does something.", + method: GET, + name: "no_fields", + path: "/_matrix/my/endpoint", + rate_limited: false, + authentication: None, + } + + request: { + #[ruma_api(header = LOCATION)] + pub location: Option, + } + + response: { + #[ruma_api(header = CONTENT_TYPE)] + pub stuff: String, + } +} + +#[test] +fn content_type_override() { + let res = Response { stuff: "magic".into() }; + let mut http_res = http::Response::>::try_from(res).unwrap(); + + // Test that we correctly replaced the default content type, + // not adding another content-type header. + assert_eq!( + match http_res.headers_mut().entry(CONTENT_TYPE) { + Entry::Occupied(occ) => occ.iter().count(), + _ => 0, + }, + 1 + ); + assert_eq!(http_res.headers().get("content-type").unwrap(), "magic"); +} diff --git a/ruma-api/tests/optional_headers.rs b/ruma-api/tests/optional_headers.rs index 1a5adb26..84ca68d1 100644 --- a/ruma-api/tests/optional_headers.rs +++ b/ruma-api/tests/optional_headers.rs @@ -14,6 +14,7 @@ ruma_api! { #[ruma_api(header = LOCATION)] pub location: Option, } + response: { #[ruma_api(header = LOCATION)] pub stuff: Option,