diff --git a/ruma-api-macros/Cargo.toml b/ruma-api-macros/Cargo.toml index f4d21feb..e18b4a6d 100644 --- a/ruma-api-macros/Cargo.toml +++ b/ruma-api-macros/Cargo.toml @@ -17,7 +17,7 @@ edition = "2018" [dependencies] proc-macro2 = "1.0.24" quote = "1.0.8" -syn = { version = "1.0.57", features = ["full", "extra-traits"] } +syn = { version = "1.0.57", features = ["full", "extra-traits", "visit"] } proc-macro-crate = "1.0.0" [lib] diff --git a/ruma-api-macros/src/api/request.rs b/ruma-api-macros/src/api/request.rs index 793e45fd..42761c8f 100644 --- a/ruma-api-macros/src/api/request.rs +++ b/ruma-api-macros/src/api/request.rs @@ -324,10 +324,10 @@ impl Request { }; let (request_path_string, parse_request_path) = - path_string_and_parse(self, metadata, &ruma_api); + self.path_string_and_parse(metadata, &ruma_api); - let request_query_string = build_query_string(self, &ruma_api); - let extract_request_query = extract_request_query(self, &ruma_api); + let request_query_string = self.build_query_string(&ruma_api); + let extract_request_query = self.extract_request_query(&ruma_api); let parse_request_query = if let Some(field) = self.query_map_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); @@ -403,8 +403,8 @@ impl Request { TokenStream::new() }; - let request_body = build_request_body(self, &ruma_api); - let parse_request_body = parse_request_body(self); + let request_body = self.build_request_body(&ruma_api); + let parse_request_body = self.parse_request_body(); let request_generics = self.combine_lifetimes(); @@ -517,86 +517,300 @@ impl Request { .collect(); quote! { - #[doc = #docs] - #[derive(Debug, Clone, #ruma_serde::Outgoing, #ruma_serde::_FakeDeriveSerde)] - #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] - #[incoming_derive(!Deserialize)] - #( #struct_attributes )* - pub struct Request #request_generics #request_def + #[doc = #docs] + #[derive(Debug, Clone, #ruma_serde::Outgoing, #ruma_serde::_FakeDeriveSerde)] + #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] + #[incoming_derive(!Deserialize)] + #( #struct_attributes )* + pub struct Request #request_generics #request_def - #non_auth_endpoint_impls + #non_auth_endpoint_impls - #request_body_struct - #request_query_struct + #request_body_struct + #request_query_struct - #[automatically_derived] - #[cfg(feature = "client")] - impl #request_lifetimes #ruma_api::OutgoingRequest for Request #request_lifetimes { - type EndpointError = #error_ty; - type IncomingResponse = ::Incoming; + #[automatically_derived] + #[cfg(feature = "client")] + impl #request_lifetimes #ruma_api::OutgoingRequest for Request #request_lifetimes { + type EndpointError = #error_ty; + type IncomingResponse = ::Incoming; - const METADATA: #ruma_api::Metadata = self::METADATA; + const METADATA: #ruma_api::Metadata = self::METADATA; - fn try_into_http_request( - self, - base_url: &::std::primitive::str, - access_token: ::std::option::Option<&str>, - ) -> ::std::result::Result<#http::Request>, #ruma_api::error::IntoHttpError> { - let metadata = self::METADATA; + fn try_into_http_request( + self, + base_url: &::std::primitive::str, + access_token: ::std::option::Option<&str>, + ) -> ::std::result::Result<#http::Request>, #ruma_api::error::IntoHttpError> { + let metadata = self::METADATA; - let mut req_builder = #http::Request::builder() - .method(#http::Method::#method) - .uri(::std::format!( - "{}{}{}", - base_url.strip_suffix('/').unwrap_or(base_url), - #request_path_string, - #request_query_string, - )) - .header(#ruma_api::exports::http::header::CONTENT_TYPE, "application/json"); + let mut req_builder = #http::Request::builder() + .method(#http::Method::#method) + .uri(::std::format!( + "{}{}{}", + base_url.strip_suffix('/').unwrap_or(base_url), + #request_path_string, + #request_query_string, + )) + .header( + #ruma_api::exports::http::header::CONTENT_TYPE, + "application/json", + ); - let mut req_headers = req_builder - .headers_mut() - .expect("`http::RequestBuilder` is in unusable state"); + let mut req_headers = req_builder + .headers_mut() + .expect("`http::RequestBuilder` is in unusable state"); - #header_kvs + #header_kvs - let http_request = req_builder.body(#request_body)?; + let http_request = req_builder.body(#request_body)?; - Ok(http_request) + Ok(http_request) } } - #[automatically_derived] - #[cfg(feature = "server")] - impl #ruma_api::IncomingRequest for #incoming_request_type { - type EndpointError = #error_ty; - type OutgoingResponse = Response; + #[automatically_derived] + #[cfg(feature = "server")] + impl #ruma_api::IncomingRequest for #incoming_request_type { + type EndpointError = #error_ty; + type OutgoingResponse = Response; - const METADATA: #ruma_api::Metadata = self::METADATA; + const METADATA: #ruma_api::Metadata = self::METADATA; - fn try_from_http_request( - request: #http::Request> - ) -> ::std::result::Result { - if request.method() != #http::Method::#method { - return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch { - expected: #http::Method::#method, - received: request.method().clone(), - }); - } - #extract_request_path - #extract_request_query - #extract_request_headers - #extract_request_body + fn try_from_http_request( + request: #http::Request> + ) -> ::std::result::Result { + if request.method() != #http::Method::#method { + return Err(#ruma_api::error::FromHttpRequestError::MethodMismatch { + expected: #http::Method::#method, + received: request.method().clone(), + }); + } - Ok(Self { - #parse_request_path - #parse_request_query - #parse_request_headers - #parse_request_body - }) + #extract_request_path + #extract_request_query + #extract_request_headers + #extract_request_body + + Ok(Self { + #parse_request_path + #parse_request_query + #parse_request_headers + #parse_request_body + }) + } + } } + } + + /// Deserialize the query string. + fn extract_request_query(&self, ruma_api: &TokenStream) -> TokenStream { + let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; + + if self.query_map_field().is_some() { + quote! { + let request_query = #ruma_api::try_deserialize!( + request, + #ruma_serde::urlencoded::from_str( + &request.uri().query().unwrap_or("") + ), + ); + } + } else if self.has_query_fields() { + quote! { + let request_query: ::Incoming = + #ruma_api::try_deserialize!( + request, + #ruma_serde::urlencoded::from_str( + &request.uri().query().unwrap_or("") + ), + ); + } + } else { + TokenStream::new() + } + } + + /// Generates the code to initialize a `Request`. + /// + /// Used to construct an `http::Request`s body. + fn build_request_body(&self, ruma_api: &TokenStream) -> TokenStream { + let serde_json = quote! { #ruma_api::exports::serde_json }; + + if let Some(field) = self.newtype_raw_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote!(self.#field_name) + } else if self.has_body_fields() || self.newtype_body_field().is_some() { + let request_body_initializers = if let Some(field) = self.newtype_body_field() { + let field_name = + field.ident.as_ref().expect("expected field to have an identifier"); + quote! { (self.#field_name) } + } else { + let initializers = self.request_body_init_fields(); + quote! { { #initializers } } + }; + + quote! { + { + let request_body = RequestBody #request_body_initializers; + #serde_json::to_vec(&request_body)? + } + } + } else { + quote!(Vec::new()) + } + } + + fn parse_request_body(&self) -> TokenStream { + if let Some(field) = self.newtype_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { + #field_name: request_body.0, + } + } else if let Some(field) = self.newtype_raw_body_field() { + let field_name = field.ident.as_ref().expect("expected field to have an identifier"); + quote! { + #field_name: request.into_body(), + } + } else { + self.request_init_body_fields() + } + } + + /// The function determines the type of query string that needs to be built + /// and then builds it using `ruma_serde::urlencoded::to_string`. + fn build_query_string(&self, ruma_api: &TokenStream) -> TokenStream { + let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; + + if let Some(field) = self.query_map_field() { + let field_name = field.ident.as_ref().expect("expected field to have identifier"); + + quote!({ + // This function exists so that the compiler will throw an + // error when the type of the field with the query_map + // attribute doesn't implement IntoIterator + // + // This is necessary because the ruma_serde::urlencoded::to_string + // call will result in a runtime error when the type cannot be + // encoded as a list key-value pairs (?key1=value1&key2=value2) + // + // By asserting that it implements the iterator trait, we can + // ensure that it won't fail. + fn assert_trait_impl(_: &T) + where + T: ::std::iter::IntoIterator< + Item = (::std::string::String, ::std::string::String) + >, + {} + + let request_query = RequestQuery(self.#field_name); + assert_trait_impl(&request_query.0); + + format_args!( + "?{}", + #ruma_serde::urlencoded::to_string(request_query)? + ) + }) + } else if self.has_query_fields() { + let request_query_init_fields = self.request_query_init_fields(); + + quote!({ + let request_query = RequestQuery { + #request_query_init_fields + }; + + format_args!( + "?{}", + #ruma_serde::urlencoded::to_string(request_query)? + ) + }) + } else { + quote! { "" } + } + } + + /// The first item in the tuple generates code for the request path from + /// the `Metadata` and `Request` structs. The second item in the returned tuple + /// is the code to generate a Request struct field created from any segments + /// of the path that start with ":". + /// + /// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is + /// used for implementing `TryFrom>>`, from path strings deserialized to Ruma + /// types. + pub(crate) fn path_string_and_parse( + &self, + metadata: &Metadata, + ruma_api: &TokenStream, + ) -> (TokenStream, TokenStream) { + let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; + + if self.has_path_fields() { + let path_string = metadata.path.value(); + + assert!(path_string.starts_with('/'), "path needs to start with '/'"); + assert!( + path_string.chars().filter(|c| *c == ':').count() == self.path_field_count(), + "number of declared path parameters needs to match amount of placeholders in path" + ); + + let format_call = { + let mut format_string = path_string.clone(); + let mut format_args = Vec::new(); + + while let Some(start_of_segment) = format_string.find(':') { + // ':' should only ever appear at the start of a segment + assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/"); + + let end_of_segment = match format_string[start_of_segment..].find('/') { + Some(rel_pos) => start_of_segment + rel_pos, + None => format_string.len(), + }; + + let path_var = Ident::new( + &format_string[start_of_segment + 1..end_of_segment], + Span::call_site(), + ); + format_args.push(quote! { + #percent_encoding::utf8_percent_encode( + &self.#path_var.to_string(), + #percent_encoding::NON_ALPHANUMERIC, + ) + }); + format_string.replace_range(start_of_segment..end_of_segment, "{}"); + } + + quote! { + format_args!(#format_string, #(#format_args),*) + } + }; + + let path_fields = + path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map( + |(i, segment)| { + let path_var = &segment[1..]; + let path_var_ident = Ident::new(path_var, Span::call_site()); + quote! { + #path_var_ident: { + let segment = path_segments[#i].as_bytes(); + let decoded = #ruma_api::try_deserialize!( + request, + #percent_encoding::percent_decode(segment) + .decode_utf8(), + ); + + #ruma_api::try_deserialize!( + request, + ::std::convert::TryFrom::try_from(&*decoded), + ) } } + }, + ); + + (format_call, quote! { #(#path_fields,)* }) + } else { + (quote! { metadata.path.to_owned() }, TokenStream::new()) + } } } @@ -750,210 +964,3 @@ pub(crate) enum RequestFieldKind { /// See the similarly named variant of `RequestField`. QueryMap, } - -/// The first item in the tuple generates code for the request path from -/// the `Metadata` and `Request` structs. The second item in the returned tuple -/// is the code to generate a Request struct field created from any segments -/// of the path that start with ":". -/// -/// The first `TokenStream` returned is the constructed url path. The second `TokenStream` is -/// used for implementing `TryFrom>>`, from path strings deserialized to Ruma -/// types. -pub(crate) fn path_string_and_parse( - request: &Request, - metadata: &Metadata, - ruma_api: &TokenStream, -) -> (TokenStream, TokenStream) { - let percent_encoding = quote! { #ruma_api::exports::percent_encoding }; - - if request.has_path_fields() { - let path_string = metadata.path.value(); - - assert!(path_string.starts_with('/'), "path needs to start with '/'"); - assert!( - path_string.chars().filter(|c| *c == ':').count() == request.path_field_count(), - "number of declared path parameters needs to match amount of placeholders in path" - ); - - let format_call = { - let mut format_string = path_string.clone(); - let mut format_args = Vec::new(); - - while let Some(start_of_segment) = format_string.find(':') { - // ':' should only ever appear at the start of a segment - assert_eq!(&format_string[start_of_segment - 1..start_of_segment], "/"); - - let end_of_segment = match format_string[start_of_segment..].find('/') { - Some(rel_pos) => start_of_segment + rel_pos, - None => format_string.len(), - }; - - let path_var = Ident::new( - &format_string[start_of_segment + 1..end_of_segment], - Span::call_site(), - ); - format_args.push(quote! { - #percent_encoding::utf8_percent_encode( - &self.#path_var.to_string(), - #percent_encoding::NON_ALPHANUMERIC, - ) - }); - format_string.replace_range(start_of_segment..end_of_segment, "{}"); - } - - quote! { - format_args!(#format_string, #(#format_args),*) - } - }; - - let path_fields = - path_string[1..].split('/').enumerate().filter(|(_, s)| s.starts_with(':')).map( - |(i, segment)| { - let path_var = &segment[1..]; - let path_var_ident = Ident::new(path_var, Span::call_site()); - quote! { - #path_var_ident: { - let segment = path_segments[#i].as_bytes(); - let decoded = #ruma_api::try_deserialize!( - request, - #percent_encoding::percent_decode(segment) - .decode_utf8(), - ); - - #ruma_api::try_deserialize!( - request, - ::std::convert::TryFrom::try_from(&*decoded), - ) - } - } - }, - ); - - (format_call, quote! { #(#path_fields,)* }) - } else { - (quote! { metadata.path.to_owned() }, TokenStream::new()) - } -} - -/// The function determines the type of query string that needs to be built -/// and then builds it using `ruma_serde::urlencoded::to_string`. -fn build_query_string(request: &Request, ruma_api: &TokenStream) -> TokenStream { - let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; - - if let Some(field) = request.query_map_field() { - let field_name = field.ident.as_ref().expect("expected field to have identifier"); - - quote!({ - // This function exists so that the compiler will throw an - // error when the type of the field with the query_map - // attribute doesn't implement IntoIterator - // - // This is necessary because the ruma_serde::urlencoded::to_string - // call will result in a runtime error when the type cannot be - // encoded as a list key-value pairs (?key1=value1&key2=value2) - // - // By asserting that it implements the iterator trait, we can - // ensure that it won't fail. - fn assert_trait_impl(_: &T) - where - T: ::std::iter::IntoIterator, - {} - - let request_query = RequestQuery(self.#field_name); - assert_trait_impl(&request_query.0); - - format_args!( - "?{}", - #ruma_serde::urlencoded::to_string(request_query)? - ) - }) - } else if request.has_query_fields() { - let request_query_init_fields = request.request_query_init_fields(); - - quote!({ - let request_query = RequestQuery { - #request_query_init_fields - }; - - format_args!( - "?{}", - #ruma_serde::urlencoded::to_string(request_query)? - ) - }) - } else { - quote! { "" } - } -} - -/// Deserialize the query string. -fn extract_request_query(request: &Request, ruma_api: &TokenStream) -> TokenStream { - let ruma_serde = quote! { #ruma_api::exports::ruma_serde }; - - if request.query_map_field().is_some() { - quote! { - let request_query = #ruma_api::try_deserialize!( - request, - #ruma_serde::urlencoded::from_str( - &request.uri().query().unwrap_or("") - ), - ); - } - } else if request.has_query_fields() { - quote! { - let request_query: ::Incoming = - #ruma_api::try_deserialize!( - request, - #ruma_serde::urlencoded::from_str( - &request.uri().query().unwrap_or("") - ), - ); - } - } else { - TokenStream::new() - } -} - -/// Generates the code to initialize a `Request`. -/// -/// Used to construct an `http::Request`s body. -fn build_request_body(request: &Request, ruma_api: &TokenStream) -> TokenStream { - let serde_json = quote! { #ruma_api::exports::serde_json }; - - if let Some(field) = request.newtype_raw_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote!(self.#field_name) - } else if request.has_body_fields() || request.newtype_body_field().is_some() { - let request_body_initializers = if let Some(field) = request.newtype_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { (self.#field_name) } - } else { - let initializers = request.request_body_init_fields(); - quote! { { #initializers } } - }; - - quote! { - { - let request_body = RequestBody #request_body_initializers; - #serde_json::to_vec(&request_body)? - } - } - } else { - quote!(Vec::new()) - } -} - -fn parse_request_body(request: &Request) -> TokenStream { - if let Some(field) = request.newtype_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { - #field_name: request_body.0, - } - } else if let Some(field) = request.newtype_raw_body_field() { - let field_name = field.ident.as_ref().expect("expected field to have an identifier"); - quote! { - #field_name: request.into_body(), - } - } else { - request.request_init_body_fields() - } -}