diff --git a/ruma-api-macros/src/api.rs b/ruma-api-macros/src/api.rs index ef8890d4..dfb92a1d 100644 --- a/ruma-api-macros/src/api.rs +++ b/ruma-api-macros/src/api.rs @@ -136,15 +136,17 @@ impl ToTokens for Api { let mut header_kvs = self.request.append_header_kvs(); if requires_authentication.value { header_kvs.push(quote! { - #ruma_api_import::exports::http::header::AUTHORIZATION, - #ruma_api_import::exports::http::header::HeaderValue::from_str( - &::std::format!( - "Bearer {}", - access_token.ok_or( - #ruma_api_import::error::IntoHttpError::NeedsAuthentication - )? - ) - )? + let req_builder = req_builder.header( + #ruma_api_import::exports::http::header::AUTHORIZATION, + #ruma_api_import::exports::http::header::HeaderValue::from_str( + &::std::format!( + "Bearer {}", + access_token.ok_or( + #ruma_api_import::error::IntoHttpError::NeedsAuthentication + )? + ) + )? + ); }); } @@ -274,13 +276,14 @@ impl ToTokens for Api { #[allow(unused_variables)] fn try_from(response: Response) -> ::std::result::Result { - let response = #ruma_api_import::exports::http::Response::builder() - .header(#ruma_api_import::exports::http::header::CONTENT_TYPE, "application/json") - #serialize_response_headers - .body(#body) - // Since we require header names to come from the `http::header` module, - // this cannot fail. - .unwrap(); + let mut resp_builder = #ruma_api_import::exports::http::Response::builder() + .header(#ruma_api_import::exports::http::header::CONTENT_TYPE, "application/json"); + + #serialize_response_headers + + // Since we require header names to come from the `http::header` module, + // this cannot fail. + let response = resp_builder.body(#body).unwrap(); Ok(response) } } @@ -341,16 +344,18 @@ impl ToTokens for Api { > { let metadata = ::METADATA; - let http_request = #ruma_api_import::exports::http::Request::builder() + let mut req_builder = #ruma_api_import::exports::http::Request::builder() .method(#ruma_api_import::exports::http::Method::#method) .uri(::std::format!( "{}{}{}", base_url.strip_suffix("/").unwrap_or(base_url), #request_path_string, #request_query_string, - )) - #( .header(#header_kvs) )* - .body(#request_body)?; + )); + + #( #header_kvs )* + + let http_request = req_builder.body(#request_body)?; Ok(http_request) } diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index 545b4058..c56d5c9a 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -49,9 +49,25 @@ impl Request { let field_name = &field.ident; - quote! { - #import_path::exports::http::header::#header_name, - #import_path::exports::http::header::HeaderValue::from_str(self.#field_name.as_ref())? + match &field.ty { + syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) + if segments.last().unwrap().ident == "Option" => + { + quote! { + if let Some(header_val) = self.#field_name.as_ref() { + req_builder = req_builder.header( + #import_path::exports::http::header::#header_name, + #import_path::exports::http::header::HeaderValue::from_str(header_val)?, + ); + } + } + } + _ => quote! { + req_builder = req_builder.header( + #import_path::exports::http::header::#header_name, + #import_path::exports::http::header::HeaderValue::from_str(self.#field_name.as_ref())?, + ); + }, } }).collect() } @@ -68,13 +84,15 @@ impl Request { let field_name = &field.ident; let header_name_string = header_name.to_string(); - quote! { - #field_name: match headers - .get(#import_path::exports::http::header::#header_name) - .and_then(|v| v.to_str().ok()) // FIXME: Should have a distinct error message + let (some_case, none_case) = match &field.ty { + syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) + if segments.last().unwrap().ident == "Option" => { - Some(header) => header.to_owned(), - None => { + (quote! { Some(header.to_owned()) }, quote! { None }) + } + _ => ( + quote! { header.to_owned() }, + quote! {{ use #import_path::exports::serde::de::Error as _; // FIXME: Not a missing json field, a missing header! @@ -85,7 +103,17 @@ impl Request { request, ) .into()); - } + }}, + ), + }; + + quote! { + #field_name: match headers + .get(#import_path::exports::http::header::#header_name) + .and_then(|v| v.to_str().ok()) // FIXME: Should have a distinct error message + { + Some(header) => #some_case, + None => #none_case, } } }); diff --git a/ruma-api-macros/src/api/response.rs b/ruma-api-macros/src/api/response.rs index 99568ab8..8ae75913 100644 --- a/ruma-api-macros/src/api/response.rs +++ b/ruma-api-macros/src/api/response.rs @@ -58,15 +58,30 @@ impl Response { } } ResponseField::Header(_, header_name) => { - quote_spanned! {span=> - #field_name: #import_path::try_deserialize!( - response, - headers.remove(#import_path::exports::http::header::#header_name) - .expect("response missing expected header") - .to_str() + let optional_header = match &field.ty { + syn::Type::Path(syn::TypePath { + path: syn::Path { segments, .. }, .. + }) if segments.last().unwrap().ident == "Option" => { + quote! { + #field_name: #import_path::try_deserialize!( + response, + headers.remove(#import_path::exports::http::header::#header_name) + .map(|h| h.to_str().map(|s| s.to_owned())) + .transpose() + ) + } + } + _ => quote! { + #field_name: #import_path::try_deserialize!( + response, + headers.remove(#import_path::exports::http::header::#header_name) + .expect("response missing expected header") + .to_str() ) .to_owned() - } + }, + }; + quote_spanned! {span=> #optional_header } } ResponseField::NewtypeBody(_) => { quote_spanned! {span=> @@ -102,8 +117,29 @@ impl Response { field.ident.as_ref().expect("expected field to have an identifier"); let span = field.span(); + let optional_header = match &field.ty { + syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) + if segments.last().unwrap().ident == "Option" => + { + quote! { + if let Some(header) = response.#field_name { + resp_builder = resp_builder.header( + #import_path::exports::http::header::#header_name, + header, + ); + } + } + } + _ => quote! { + resp_builder = resp_builder.header( + #import_path::exports::http::header::#header_name, + response.#field_name, + ); + }, + }; + Some(quote_spanned! {span=> - .header(#import_path::exports::http::header::#header_name, response.#field_name) + #optional_header }) } else { None diff --git a/ruma-api/tests/optional_headers.rs b/ruma-api/tests/optional_headers.rs new file mode 100644 index 00000000..58061e0f --- /dev/null +++ b/ruma-api/tests/optional_headers.rs @@ -0,0 +1,21 @@ +use ruma_api::ruma_api; + +ruma_api! { + metadata: { + description: "Does something.", + method: GET, + name: "no_fields", + path: "/_matrix/my/endpoint", + rate_limited: false, + requires_authentication: false, + } + + request: { + #[ruma_api(header = LOCATION)] + location: Option, + } + response: { + #[ruma_api(header = LOCATION)] + stuff: Option, + } +}