Only use regex for checking valid characters in user localparts.

This commit is contained in:
Jimmy Cuadra 2016-07-22 21:55:28 -07:00
parent 9ad1e0fc69
commit 2a03702976

View File

@ -13,20 +13,36 @@ use std::fmt::{Display, Formatter, Result as FmtResult};
use regex::Regex; use regex::Regex;
use url::{Host, ParseError, Url}; use url::{Host, ParseError, Url};
/// The maximum length of an ID, in bytes. All IDs must be 255 bytes or less.
const MAX_BYTES: usize = 255;
/// The minimum number of characters an ID can be.
///
/// This is an optimization and not required by the spec. The shortest possible valid ID is a sigil
/// + a single character local ID + a colon + a single character hostname.
///
/// Note that `parse_id` compares this against the ID's length in bytes (`str::len`). A string's
/// byte length is never smaller than its character count, so this check can only reject IDs
/// that really are too short.
const MIN_CHARS: usize = 4;
/// The number of bytes in a valid sigil (a single ASCII character such as `@`).
const SIGIL_BYTES: usize = 1;
lazy_static! { lazy_static! {
static ref USER_ID_PATTERN: Regex = static ref USER_LOCALPART_PATTERN: Regex =
Regex::new(r"\A@(?P<localpart>[a-z0-9._=-]+):(?P<host>.+)\z") Regex::new(r"\A[a-z0-9._=-]+\z").expect("Failed to create user localpart regex.");
.expect("Failed to compile user ID regex.");
} }
/// An error encountered when trying to parse an invalid user ID string. /// An error encountered when trying to parse an invalid user ID string.
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub enum Error { pub enum Error {
/// The user ID string did not match the "@<localpart>:<domain>" format, or used invalid /// The ID's localpart contains invalid characters.
/// characters in its localpart. InvalidCharacters,
InvalidFormat,
/// The domain part of the user ID string was not a valid IP address or DNS name. /// The domain part of the user ID string was not a valid IP address or DNS name.
InvalidHost(ParseError), InvalidHost,
/// The ID exceeds 255 bytes.
MaximumLengthExceeded,
/// The ID is less than 4 characters.
MinimumLengthNotSatisfied,
/// The ID is missing the colon delimiter between localpart and server name.
MissingDelimiter,
/// The ID is missing the leading sigil.
MissingSigil,
} }
/// A Matrix user ID. /// A Matrix user ID.
@ -44,37 +60,59 @@ pub struct UserId {
port: u16, port: u16,
} }
/// Splits an ID of the form `<sigil><localpart>:<server name>` into its parts.
///
/// On success, returns the localpart (borrowed from `id`), the parsed host, and the port. The
/// port defaults to 443 when the server name does not specify one. The localpart's characters
/// are not validated here; callers apply their own rules.
///
/// # Errors
///
/// * `Error::MaximumLengthExceeded` if `id` is longer than `MAX_BYTES` bytes.
/// * `Error::MinimumLengthNotSatisfied` if `id` is shorter than `MIN_CHARS` bytes.
/// * `Error::MissingSigil` if `id` does not start with `required_sigil`.
/// * `Error::MissingDelimiter` if there is no `:` after the sigil.
/// * `Error::InvalidHost` if the text after the first `:` cannot be parsed as a host with an
///   optional port.
fn parse_id<'a>(required_sigil: char, id: &'a str) -> Result<(&'a str, Host, u16), Error> {
    if id.len() > MAX_BYTES {
        return Err(Error::MaximumLengthExceeded);
    }

    // A string's byte length is never smaller than its character count, so anything rejected
    // here really is too short. This also guarantees `id` is non-empty below.
    if id.len() < MIN_CHARS {
        return Err(Error::MinimumLengthNotSatisfied);
    }

    if !id.starts_with(required_sigil) {
        return Err(Error::MissingSigil);
    }

    // Every sigil is a single ASCII character; slicing by the sigil's real encoded width keeps
    // the slice on a char boundary even if that invariant were ever violated.
    debug_assert_eq!(required_sigil.len_utf8(), SIGIL_BYTES);
    let after_sigil = &id[required_sigil.len_utf8()..];

    // `str::find` yields a *byte* offset, which is what slicing requires. Counting characters
    // instead (e.g. via `chars().position()`) produces the wrong offset, or a panic on a
    // non-char-boundary, as soon as the localpart contains a multi-byte character.
    let delimiter_index = match after_sigil.find(':') {
        Some(index) => index,
        None => return Err(Error::MissingDelimiter),
    };

    let localpart = &after_sigil[..delimiter_index];
    let raw_host = &after_sigil[delimiter_index + ':'.len_utf8()..];

    // Let the URL parser validate the host and extract an optional port. Parse failures are
    // converted into `Error` by the `From<ParseError>` impl.
    let url_string = format!("https://{}", raw_host);
    let url = try!(Url::parse(&url_string));

    let host = match url.host() {
        Some(host) => host.to_owned(),
        None => return Err(Error::InvalidHost),
    };

    let port = url.port().unwrap_or(443);

    Ok((localpart, host, port))
}
impl UserId { impl UserId {
/// Create a new Matrix user ID from a string representation. /// Create a new Matrix user ID from a string representation.
/// ///
/// The string must include the leading @ sigil, the localpart, a literal colon, and a valid /// The string must include the leading @ sigil, the localpart, a literal colon, and a valid
/// server name. /// server name.
pub fn new(user_id: &str) -> Result<UserId, Error> { pub fn new(user_id: &str) -> Result<UserId, Error> {
let captures = match USER_ID_PATTERN.captures(user_id) { let (localpart, host, port) = try!(parse_id('@', user_id));
Some(captures) => captures,
None => return Err(Error::InvalidFormat),
};
let raw_host = captures.name("host").expect("Failed to extract hostname from regex."); if !USER_LOCALPART_PATTERN.is_match(localpart) {
return Err(Error::InvalidCharacters);
let url_string = format!("https://{}", raw_host); }
let url = try!(Url::parse(&url_string));
let host = match url.host() {
Some(host) => host,
None => return Err(Error::InvalidFormat),
};
let port = url.port().unwrap_or(443);
Ok(UserId { Ok(UserId {
hostname: host.to_owned(), hostname: host,
port: port, port: port,
localpart: captures localpart: localpart.to_owned(),
.name("localpart")
.expect("Failed to extract localpart from regex.")
.to_string(),
}) })
} }
@ -98,8 +136,8 @@ impl UserId {
} }
impl From<ParseError> for Error { impl From<ParseError> for Error {
fn from(error: ParseError) -> Error { fn from(_: ParseError) -> Error {
Error::InvalidHost(error) Error::InvalidHost
} }
} }
@ -115,7 +153,7 @@ impl Display for UserId {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::UserId; use super::{Error, UserId};
#[test] #[test]
fn valid_user_id() { fn valid_user_id() {
@ -149,26 +187,41 @@ mod tests {
#[test] #[test]
fn invalid_characters_in_localpart() { fn invalid_characters_in_localpart() {
assert!(UserId::new("@CARL:example.com").is_err()); assert_eq!(
UserId::new("@CARL:example.com").err().unwrap(),
Error::InvalidCharacters
);
} }
#[test] #[test]
fn missing_sigil() { fn missing_sigil() {
assert!(UserId::new("carl:example.com").is_err()); assert_eq!(
UserId::new("carl:example.com").err().unwrap(),
Error::MissingSigil
);
} }
#[test] #[test]
fn missing_domain() { fn missing_delimiter() {
assert!(UserId::new("carl").is_err()); assert_eq!(
UserId::new("@carl").err().unwrap(),
Error::MissingDelimiter
);
} }
#[test] #[test]
fn invalid_host() { fn invalid_host() {
assert!(UserId::new("@carl:-").is_err()); assert_eq!(
UserId::new("@carl:-").err().unwrap(),
Error::InvalidHost
);
} }
#[test] #[test]
fn invalid_port() { fn invalid_port() {
assert!(UserId::new("@carl:example.com:notaport").is_err()); assert_eq!(
UserId::new("@carl:example.com:notaport").err().unwrap(),
Error::InvalidHost
);
} }
} }