Update ruma-identifiers validation logic

* Allow empty localparts
* Simplify some code
This commit is contained in:
Jonas Platte 2020-09-21 22:34:56 +02:00
parent 85e3df7c76
commit 22ec1710b5
No known key found for this signature in database
GPG Key ID: 7D261D771D915378
10 changed files with 37 additions and 44 deletions

View File

@ -5,27 +5,25 @@ use std::fmt::{self, Display, Formatter};
/// An error encountered when trying to parse an invalid ID string.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum Error {
/// The room version ID is empty.
EmptyRoomVersionId,
/// The ID's localpart contains invalid characters.
///
/// Only relevant for user IDs.
InvalidCharacters,
/// The key version contains outside of [a-zA-Z0-9_].
InvalidKeyVersion,
/// The localpart of the ID string is not valid (because it is empty).
InvalidLocalPart,
/// The server name part of the the ID string is not a valid server name.
InvalidServerName,
/// The ID exceeds 255 bytes (or 32 codepoints for a room version ID).
MaximumLengthExceeded,
/// The ID is less than 4 characters (or is an empty room version ID).
MinimumLengthNotSatisfied,
/// The ID is missing the colon delimiter between localpart and server name.
MissingDelimiter,
/// The ID is missing the colon delimiter between key algorithm and device ID.
MissingDeviceKeyDelimiter,
/// The ID is missing the colon delimiter between key algorithm and version.
MissingServerKeyDelimiter,
/// The ID is missing the leading sigil.
/// The ID is missing the correct leading sigil.
MissingSigil,
/// The key algorithm is not recognized.
UnknownKeyAlgorithm,
@ -34,16 +32,15 @@ pub enum Error {
impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let message = match self {
Error::EmptyRoomVersionId => "room version ID is empty",
Error::InvalidCharacters => "localpart contains invalid characters",
Error::InvalidKeyVersion => "key id version contains invalid characters",
Error::InvalidLocalPart => "localpart is empty",
Error::InvalidKeyVersion => "key ID version contains invalid characters",
Error::InvalidServerName => "server name is not a valid IP address or domain name",
Error::MaximumLengthExceeded => "ID exceeds 255 bytes",
Error::MinimumLengthNotSatisfied => "ID must be at least 4 characters",
Error::MissingDelimiter => "colon is required between localpart and server name",
Error::MissingDeviceKeyDelimiter => "colon is required between algorithm and device ID",
Error::MissingServerKeyDelimiter => "colon is required between algorithm and version",
Error::MissingSigil => "leading sigil is missing",
Error::MissingSigil => "leading sigil is incorrect or missing",
Error::UnknownKeyAlgorithm => "unknown key algorithm specified",
};

View File

@ -1,12 +1,15 @@
use std::num::NonZeroU8;
use crate::{parse_id, validate_id, Error};
use crate::{parse_id, Error};
pub fn validate(s: &str) -> Result<Option<NonZeroU8>, Error> {
Ok(match s.contains(':') {
true => Some(parse_id(s, &['$'])?),
false => {
validate_id(s, &['$'])?;
if !s.starts_with('$') {
return Err(Error::MissingSigil);
}
None
}
})

View File

@ -17,23 +17,13 @@ pub use error::Error;
/// All identifiers must be 255 bytes or less.
const MAX_BYTES: usize = 255;
/// The minimum number of characters an ID can be.
///
/// This is an optimization and not required by the spec. The shortest possible valid ID is a sigil
/// + a single character local ID + a colon + a single character hostname.
const MIN_CHARS: usize = 4;
/// Checks if an identifier is valid.
fn validate_id(id: &str, valid_sigils: &[char]) -> Result<(), Error> {
if id.len() > MAX_BYTES {
return Err(Error::MaximumLengthExceeded);
}
if id.len() < MIN_CHARS {
return Err(Error::MinimumLengthNotSatisfied);
}
if !valid_sigils.contains(&id.chars().next().unwrap()) {
if !id.starts_with(valid_sigils) {
return Err(Error::MissingSigil);
}
@ -44,13 +34,7 @@ fn validate_id(id: &str, valid_sigils: &[char]) -> Result<(), Error> {
/// and returns the index of the colon that separates the two.
fn parse_id(id: &str, valid_sigils: &[char]) -> Result<NonZeroU8, Error> {
validate_id(id, valid_sigils)?;
let colon_idx = id.find(':').ok_or(Error::MissingDelimiter)?;
if colon_idx < 2 {
return Err(Error::InvalidLocalPart);
}
server_name::validate(&id[colon_idx + 1..])?;
Ok(NonZeroU8::new(colon_idx as u8).unwrap())
}

View File

@ -5,7 +5,7 @@ const MAX_CODE_POINTS: usize = 32;
pub fn validate(s: &str) -> Result<(), Error> {
if s.is_empty() {
Err(Error::MinimumLengthNotSatisfied)
Err(Error::EmptyRoomVersionId)
} else if s.chars().count() > MAX_CODE_POINTS {
Err(Error::MaximumLengthExceeded)
} else {

View File

@ -14,7 +14,7 @@ pub fn validate(s: &str) -> Result<NonZeroU8, Error> {
fn validate_version(version: &str) -> Result<(), Error> {
if version.is_empty() {
return Err(Error::MinimumLengthNotSatisfied);
return Err(Error::EmptyRoomVersionId);
} else if !version.chars().all(|c| c.is_alphanumeric() || c == '_') {
return Err(Error::InvalidCharacters);
}

View File

@ -15,10 +15,6 @@ pub fn validate(s: &str) -> Result<(NonZeroU8, bool), Error> {
/// Returns an `Err` for invalid user ID localparts, `Ok(false)` for historical user ID localparts
/// and `Ok(true)` for fully conforming user ID localparts.
pub fn localpart_is_fully_comforming(localpart: &str) -> Result<bool, Error> {
if localpart.is_empty() {
return Err(Error::InvalidLocalPart);
}
// See https://matrix.org/docs/spec/appendices#user-identifiers
let is_fully_conforming = localpart
.bytes()

View File

@ -68,6 +68,16 @@ mod tests {
);
}
#[test]
fn empty_localpart() {
assert_eq!(
RoomAliasId::try_from("#:myhomeserver.io")
.expect("Failed to create RoomAliasId.")
.as_ref(),
"#:myhomeserver.io"
);
}
#[cfg(feature = "serde")]
#[test]
fn serialize_valid_room_alias_id() {
@ -129,13 +139,13 @@ mod tests {
}
#[test]
fn missing_localpart() {
assert_eq!(RoomAliasId::try_from("#:example.com").unwrap_err(), Error::InvalidLocalPart);
fn missing_room_alias_id_delimiter() {
assert_eq!(RoomAliasId::try_from("#ruma").unwrap_err(), Error::MissingDelimiter);
}
#[test]
fn missing_room_alias_id_delimiter() {
assert_eq!(RoomAliasId::try_from("#ruma").unwrap_err(), Error::MissingDelimiter);
fn invalid_leading_sigil() {
assert_eq!(RoomAliasId::try_from("!room_id:foo.bar").unwrap_err(), Error::MissingSigil);
}
#[test]

View File

@ -84,6 +84,14 @@ mod tests {
);
}
#[test]
fn empty_localpart() {
assert_eq!(
RoomId::try_from("!:example.com").expect("Failed to create RoomId.").as_ref(),
"!:example.com"
);
}
#[cfg(feature = "rand")]
#[test]
fn generate_random_valid_room_id() {

View File

@ -376,7 +376,7 @@ mod tests {
#[test]
fn empty_room_version_id() {
assert_eq!(RoomVersionId::try_from(""), Err(Error::MinimumLengthNotSatisfied));
assert_eq!(RoomVersionId::try_from(""), Err(Error::EmptyRoomVersionId));
}
#[test]

View File

@ -241,11 +241,6 @@ mod tests {
assert_eq!(UserId::try_from("carl:example.com").unwrap_err(), Error::MissingSigil);
}
#[test]
fn missing_localpart() {
assert_eq!(UserId::try_from("@:example.com").unwrap_err(), Error::InvalidLocalPart);
}
#[test]
fn missing_user_id_delimiter() {
assert_eq!(UserId::try_from("@carl").unwrap_err(), Error::MissingDelimiter);