diff --git a/CHANGELOG.md b/CHANGELOG.md index 757acd5f..6320241b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ Breaking changes: * Note that hashes are generally only guaranteed consistent in the lifetime of the program though, so do not persist them! * The `hostname` methods have been updated to return string slices instead of `&url::Host` +* `Error::InvalidHost` has been renamed to `Error::InvalidServerName`, because it also covers errors + in the port, not just the host part section of the server name Improvements: diff --git a/src/error.rs b/src/error.rs index 087a8200..0ea68c2f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,8 +11,8 @@ pub enum Error { InvalidCharacters, /// The localpart of the ID string is not valid (because it is empty). InvalidLocalPart, - /// The domain part of the the ID string is not a valid IP address or DNS name. - InvalidHost, + /// The server name part of the the ID string is not a valid server name. + InvalidServerName, /// The ID exceeds 255 bytes (or 32 codepoints for a room version ID.) MaximumLengthExceeded, /// The ID is less than 4 characters (or is an empty room version ID.) @@ -28,7 +28,7 @@ impl Display for Error { let message = match self { Error::InvalidCharacters => "localpart contains invalid characters", Error::InvalidLocalPart => "localpart is empty", - Error::InvalidHost => "server name is not a valid IP address or domain name", + Error::InvalidServerName => "server name is not a valid IP address or domain name", Error::MaximumLengthExceeded => "ID exceeds 255 bytes", Error::MinimumLengthNotSatisfied => "ID must be at least 4 characters", Error::MissingDelimiter => "colon is required between localpart and server name", diff --git a/src/event_id.rs b/src/event_id.rs index 97c1545a..cd536fd1 100644 --- a/src/event_id.rs +++ b/src/event_id.rs @@ -162,10 +162,10 @@ mod tests { assert_eq!(id_str.len(), 31); } - /*#[test] + #[test] fn generate_random_invalid_event_id() { assert!(EventId::new("").is_err()); - }*/ + } #[test] fn serialize_valid_original_event_id() { @@ -275,11 +275,11 @@ mod tests { ); } - /*#[test] + #[test] fn invalid_event_id_host() { assert_eq!( EventId::try_from("$39hvsi03hlne:/").unwrap_err(), - Error::InvalidHost + Error::InvalidServerName ); } @@ -287,7 +287,7 @@ mod tests { fn invalid_event_id_port() { assert_eq!( EventId::try_from("$39hvsi03hlne:example.com:notaport").unwrap_err(), - Error::InvalidHost + Error::InvalidServerName ); - }*/ + } } diff --git a/src/lib.rs b/src/lib.rs index 666799e0..6c5d75ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,8 @@ use serde::de::{self, Deserialize as _, Deserializer, Unexpected}; pub use crate::device_id::DeviceId; pub use crate::{ error::Error, event_id::EventId, room_alias_id::RoomAliasId, room_id::RoomId, - room_id_or_room_alias_id::RoomIdOrAliasId, room_version_id::RoomVersionId, user_id::UserId, + room_id_or_room_alias_id::RoomIdOrAliasId, room_version_id::RoomVersionId, + server_name::is_valid_server_name, user_id::UserId, }; #[macro_use] @@ -38,6 +39,7 @@ mod room_alias_id; mod room_id; mod room_id_or_room_alias_id; mod room_version_id; +mod server_name; mod user_id; /// All identifiers must be 255 bytes or less. @@ -79,8 +81,8 @@ fn parse_id(id: &str, valid_sigils: &[char]) -> Result { validate_id(id, valid_sigils)?; let colon_idx = id.find(':').ok_or(Error::MissingDelimiter)?; - if colon_idx == id.len() - 1 { - return Err(Error::InvalidHost); + if !is_valid_server_name(&id[colon_idx + 1..]) { + return Err(Error::InvalidServerName); } match NonZeroU8::new(colon_idx as u8) { diff --git a/src/room_alias_id.rs b/src/room_alias_id.rs index 54c05a37..8b3445b3 100644 --- a/src/room_alias_id.rs +++ b/src/room_alias_id.rs @@ -144,11 +144,11 @@ mod tests { ); } - /*#[test] + #[test] fn invalid_room_alias_id_host() { assert_eq!( RoomAliasId::try_from("#ruma:/").unwrap_err(), - Error::InvalidHost + Error::InvalidServerName ); } @@ -156,7 +156,7 @@ mod tests { fn invalid_room_alias_id_port() { assert_eq!( RoomAliasId::try_from("#ruma:example.com:notaport").unwrap_err(), - Error::InvalidHost + Error::InvalidServerName ); - }*/ + } } diff --git a/src/room_id.rs b/src/room_id.rs index c80d3c54..75aa62c9 100644 --- a/src/room_id.rs +++ b/src/room_id.rs @@ -101,10 +101,10 @@ mod tests { assert_eq!(id_str.len(), 31); } - /*#[test] + #[test] fn generate_random_invalid_room_id() { assert!(RoomId::new("").is_err()); - }*/ + } #[test] fn serialize_valid_room_id() { @@ -162,11 +162,11 @@ mod tests { ); } - /*#[test] + #[test] fn invalid_room_id_host() { assert_eq!( RoomId::try_from("!29fhd83h92h0:/").unwrap_err(), - Error::InvalidHost + Error::InvalidServerName ); } @@ -174,7 +174,7 @@ mod tests { fn invalid_room_id_port() { assert_eq!( RoomId::try_from("!29fhd83h92h0:example.com:notaport").unwrap_err(), - Error::InvalidHost + Error::InvalidServerName ); - }*/ + } } diff --git a/src/server_name.rs b/src/server_name.rs new file mode 100644 index 00000000..b0b2d15e --- /dev/null +++ b/src/server_name.rs @@ -0,0 +1,96 @@ +/// Check whether a given string is a valid server name according to [the specification][]. +/// +/// [the specification]: https://matrix.org/docs/spec/appendices#server-name +pub fn is_valid_server_name(name: &str) -> bool { + use std::net::Ipv6Addr; + + let end_of_host = if name.starts_with('[') { + let end_of_ipv6 = match name.find(']') { + Some(idx) => idx, + None => return false, + }; + + if name[1..end_of_ipv6].parse::().is_err() { + return false; + } + + end_of_ipv6 + 1 + } else { + let end_of_host = name.find(':').unwrap_or_else(|| name.len()); + + if name[..end_of_host] + .bytes() + .any(|byte| !(byte.is_ascii_alphanumeric() || byte == b'-' || byte == b'.')) + { + return false; + } + + end_of_host + }; + + if name.len() == end_of_host { + true + } else if name.as_bytes()[end_of_host] != b':' { + // hostname is followed by something other than ":port" + false + } else { + // are the remaining characters after ':' a valid port? + name[end_of_host + 1..].parse::().is_ok() + } +} + +#[cfg(test)] +mod tests { + use super::is_valid_server_name; + + #[test] + fn ipv4_host() { + assert!(is_valid_server_name("127.0.0.1")); + } + + #[test] + fn ipv4_host_and_port() { + assert!(is_valid_server_name("1.1.1.1:12000")); + } + + #[test] + fn ipv6() { + assert!(is_valid_server_name("[::1]")); + } + + #[test] + fn ipv6_with_port() { + assert!(is_valid_server_name("[1234:5678::abcd]:5678")); + } + + #[test] + fn dns_name() { + assert!(is_valid_server_name("example.com")); + } + + #[test] + fn dns_name_with_port() { + assert!(is_valid_server_name("ruma.io:8080")); + } + + #[test] + fn invalid_ipv6() { + assert!(!is_valid_server_name("[test::1]")); + } + + #[test] + fn ipv4_with_invalid_port() { + assert!(!is_valid_server_name("127.0.0.1:")); + } + + #[test] + fn ipv6_with_invalid_port() { + assert!(!is_valid_server_name("[fe80::1]:100000")); + assert!(!is_valid_server_name("[fe80::1]!")); + } + + #[test] + fn dns_name_with_invalid_port() { + assert!(!is_valid_server_name("matrix.org:hello")); + } +} diff --git a/src/user_id.rs b/src/user_id.rs index d0fb1ea5..5b3fd7d2 100644 --- a/src/user_id.rs +++ b/src/user_id.rs @@ -209,16 +209,19 @@ mod tests { ); } - /*#[test] + #[test] fn invalid_user_id_host() { - assert_eq!(UserId::try_from("@carl:/").unwrap_err(), Error::InvalidHost); + assert_eq!( + UserId::try_from("@carl:/").unwrap_err(), + Error::InvalidServerName + ); } #[test] fn invalid_user_id_port() { assert_eq!( UserId::try_from("@carl:example.com:notaport").unwrap_err(), - Error::InvalidHost + Error::InvalidServerName ); - }*/ + } }