reduce excessive cloning for verify_json

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-10-10 00:58:48 +00:00
parent eb93c641ab
commit 90fb81eabe
2 changed files with 27 additions and 39 deletions

View File

@ -12,7 +12,7 @@ use ruma_common::{
serde::{base64::Standard, Base64}, serde::{base64::Standard, Base64},
CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedServerName, RoomVersionId, UserId, CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedServerName, RoomVersionId, UserId,
}; };
use serde_json::{from_str as from_json_str, to_string as to_json_string}; use serde_json::to_string as to_json_string;
use sha2::{digest::Digest, Sha256}; use sha2::{digest::Digest, Sha256};
use crate::{ use crate::{
@ -154,7 +154,7 @@ where
/// ///
/// assert_eq!(canonical, r#"{"日":1,"本":2}"#); /// assert_eq!(canonical, r#"{"日":1,"本":2}"#);
/// ``` /// ```
pub fn canonical_json(object: &CanonicalJsonObject) -> Result<String, Error> { pub fn canonical_json(object: CanonicalJsonObject) -> Result<String, Error> {
canonical_json_with_fields_to_remove(object, CANONICAL_JSON_FIELDS_TO_REMOVE) canonical_json_with_fields_to_remove(object, CANONICAL_JSON_FIELDS_TO_REMOVE)
} }
@ -207,28 +207,27 @@ pub fn canonical_json(object: &CanonicalJsonObject) -> Result<String, Error> {
/// ``` /// ```
pub fn verify_json( pub fn verify_json(
public_key_map: &PublicKeyMap, public_key_map: &PublicKeyMap,
object: &CanonicalJsonObject, mut object: CanonicalJsonObject,
) -> Result<(), Error> { ) -> Result<(), Error> {
let signature_map = match object.get("signatures") { let signature_map = match object.remove("signatures") {
Some(CanonicalJsonValue::Object(signatures)) => signatures.clone(), Some(CanonicalJsonValue::Object(signatures)) => signatures,
Some(_) => return Err(JsonError::not_of_type("signatures", JsonType::Object)), Some(_) => return Err(JsonError::not_of_type("signatures", JsonType::Object)),
None => return Err(JsonError::field_missing_from_object("signatures")), None => return Err(JsonError::field_missing_from_object("signatures")),
}; };
for (entity_id, signature_set) in signature_map { let object = canonical_json(object)?;
for (entity_id, signature_set) in &signature_map {
let signature_set = match signature_set { let signature_set = match signature_set {
CanonicalJsonValue::Object(set) => set, CanonicalJsonValue::Object(set) => set,
_ => return Err(JsonError::not_multiples_of_type("signature sets", JsonType::Object)), _ => return Err(JsonError::not_multiples_of_type("signature sets", JsonType::Object)),
}; };
let public_keys = match public_key_map.get(&entity_id) { let public_keys = match public_key_map.get(entity_id) {
Some(keys) => keys, Some(keys) => keys,
None => { None => return Err(JsonError::key_missing("public_key_map", "public_keys", entity_id)),
return Err(JsonError::key_missing("public_key_map", "public_keys", &entity_id))
}
}; };
for (key_id, signature) in &signature_set { for (key_id, signature) in signature_set {
let signature = match signature { let signature = match signature {
CanonicalJsonValue::String(s) => s, CanonicalJsonValue::String(s) => s,
_ => return Err(JsonError::not_of_type("signature", JsonType::String)), _ => return Err(JsonError::not_of_type("signature", JsonType::String)),
@ -245,12 +244,7 @@ pub fn verify_json(
let signature = Base64::<Standard>::parse(signature) let signature = Base64::<Standard>::parse(signature)
.map_err(|e| ParseError::base64("signature", signature, e))?; .map_err(|e| ParseError::base64("signature", signature, e))?;
verify_json_with( verify_json_with(&Ed25519Verifier, &public_key, &signature, &object)?;
&Ed25519Verifier,
public_key.as_bytes(),
signature.as_bytes(),
object,
)?;
} }
} }
@ -271,14 +265,14 @@ pub fn verify_json(
/// Returns an error if verification fails. /// Returns an error if verification fails.
fn verify_json_with<V>( fn verify_json_with<V>(
verifier: &V, verifier: &V,
public_key: &[u8], public_key: &Base64,
signature: &[u8], signature: &Base64,
object: &CanonicalJsonObject, object: &str,
) -> Result<(), Error> ) -> Result<(), Error>
where where
V: Verifier, V: Verifier,
{ {
verifier.verify_json(public_key, signature, canonical_json(object)?.as_bytes()) verifier.verify_json(public_key.as_bytes(), signature.as_bytes(), object.as_bytes())
} }
/// Creates a *content hash* for an event. /// Creates a *content hash* for an event.
@ -294,7 +288,7 @@ where
/// ///
/// Returns an error if the event is too large. /// Returns an error if the event is too large.
pub fn content_hash(object: &CanonicalJsonObject) -> Result<Base64<Standard, [u8; 32]>, Error> { pub fn content_hash(object: &CanonicalJsonObject) -> Result<Base64<Standard, [u8; 32]>, Error> {
let json = canonical_json_with_fields_to_remove(object, CONTENT_HASH_FIELDS_TO_REMOVE)?; let json = canonical_json_with_fields_to_remove(object.clone(), CONTENT_HASH_FIELDS_TO_REMOVE)?;
if json.len() > MAX_PDU_BYTES { if json.len() > MAX_PDU_BYTES {
return Err(Error::PduSize); return Err(Error::PduSize);
} }
@ -326,7 +320,8 @@ pub fn reference_hash(
let redacted_value = redact(value.clone(), version, None)?; let redacted_value = redact(value.clone(), version, None)?;
let json = let json =
canonical_json_with_fields_to_remove(&redacted_value, REFERENCE_HASH_FIELDS_TO_REMOVE)?; canonical_json_with_fields_to_remove(redacted_value, REFERENCE_HASH_FIELDS_TO_REMOVE)?;
if json.len() > MAX_PDU_BYTES { if json.len() > MAX_PDU_BYTES {
return Err(Error::PduSize); return Err(Error::PduSize);
} }
@ -567,7 +562,7 @@ pub fn verify_event(
}; };
let servers_to_check = servers_to_check_signatures(object, version)?; let servers_to_check = servers_to_check_signatures(object, version)?;
let canonical_json = from_json_str(&canonical_json(&redacted)?).map_err(JsonError::from)?; let canonical_json = canonical_json(redacted)?;
for entity_id in servers_to_check { for entity_id in servers_to_check {
let signature_set = match signature_map.get(entity_id.as_str()) { let signature_set = match signature_map.get(entity_id.as_str()) {
@ -603,12 +598,7 @@ pub fn verify_event(
let signature = Base64::<Standard>::parse(signature) let signature = Base64::<Standard>::parse(signature)
.map_err(|e| ParseError::base64("signature", signature, e))?; .map_err(|e| ParseError::base64("signature", signature, e))?;
verify_json_with( verify_json_with(&Ed25519Verifier, &public_key, &signature, &canonical_json)?;
&Ed25519Verifier,
public_key.as_bytes(),
signature.as_bytes(),
&canonical_json,
)?;
checked = true; checked = true;
} }
@ -632,11 +622,9 @@ pub fn verify_event(
/// ///
/// Allows customization of the fields that will be removed before serializing. /// Allows customization of the fields that will be removed before serializing.
fn canonical_json_with_fields_to_remove( fn canonical_json_with_fields_to_remove(
object: &CanonicalJsonObject, mut owned_object: CanonicalJsonObject,
fields: &[&str], fields: &[&str],
) -> Result<String, Error> { ) -> Result<String, Error> {
let mut owned_object = object.clone();
for field in fields { for field in fields {
owned_object.remove(*field); owned_object.remove(*field);
} }
@ -766,7 +754,7 @@ mod tests {
_ => unreachable!(), _ => unreachable!(),
}; };
assert_eq!(canonical_json(&object).unwrap(), canonical); assert_eq!(canonical_json(object).unwrap(), canonical);
} }
#[test] #[test]

View File

@ -137,7 +137,7 @@ mod tests {
/// Convenience for converting a string of JSON into its canonical form. /// Convenience for converting a string of JSON into its canonical form.
fn test_canonical_json(input: &str) -> String { fn test_canonical_json(input: &str) -> String {
let object = from_json_str(input).unwrap(); let object = from_json_str(input).unwrap();
canonical_json(&object).unwrap() canonical_json(object).unwrap()
} }
#[test] #[test]
@ -253,7 +253,7 @@ mod tests {
let mut public_key_map = BTreeMap::new(); let mut public_key_map = BTreeMap::new();
public_key_map.insert("domain".into(), signature_set); public_key_map.insert("domain".into(), signature_set);
verify_json(&public_key_map, &value).unwrap(); verify_json(&public_key_map, value).unwrap();
} }
#[test] #[test]
@ -290,13 +290,13 @@ mod tests {
let mut public_key_map = BTreeMap::new(); let mut public_key_map = BTreeMap::new();
public_key_map.insert("domain".into(), signature_set); public_key_map.insert("domain".into(), signature_set);
verify_json(&public_key_map, &value).unwrap(); verify_json(&public_key_map, value).unwrap();
let reverse_value = from_json_str( let reverse_value = from_json_str(
r#"{"two":"Two","signatures":{"domain":{"ed25519:1":"t6Ehmh6XTDz7qNWI0QI5tNPSliWLPQP/+Fzz3LpdCS7q1k2G2/5b5Embs2j4uG3ZeivejrzqSVoBcdocRpa+AQ"}},"one":1}"# r#"{"two":"Two","signatures":{"domain":{"ed25519:1":"t6Ehmh6XTDz7qNWI0QI5tNPSliWLPQP/+Fzz3LpdCS7q1k2G2/5b5Embs2j4uG3ZeivejrzqSVoBcdocRpa+AQ"}},"one":1}"#
).unwrap(); ).unwrap();
verify_json(&public_key_map, &reverse_value).unwrap(); verify_json(&public_key_map, reverse_value).unwrap();
} }
#[test] #[test]
@ -309,7 +309,7 @@ mod tests {
let mut public_key_map = BTreeMap::new(); let mut public_key_map = BTreeMap::new();
public_key_map.insert("domain".into(), signature_set); public_key_map.insert("domain".into(), signature_set);
verify_json(&public_key_map, &value).unwrap_err(); verify_json(&public_key_map, value).unwrap_err();
} }
#[test] #[test]