Only use regex for checking valid characters in user localparts.

This commit is contained in:
Jimmy Cuadra 2016-07-22 21:55:28 -07:00
parent 9ad1e0fc69
commit 2a03702976

View File

@ -13,20 +13,36 @@ use std::fmt::{Display, Formatter, Result as FmtResult};
use regex::Regex;
use url::{Host, ParseError, Url};
/// All identifiers must be 255 bytes or less.
const MAX_BYTES: usize = 255;
/// The minimum number of characters an ID can be.
///
/// This is an optimization and not required by the spec. The shortest possible valid ID is a sigil
/// + a single character local ID + a colon + a single character hostname.
///
/// The limit is compared against the byte length of the ID, which is never smaller than its
/// character count.
const MIN_CHARS: usize = 4;
/// The number of bytes in a valid sigil.
///
/// Also used as the width of the `:` delimiter when slicing the server name out of an ID.
const SIGIL_BYTES: usize = 1;
lazy_static! {
// Matches a whole user ID: a leading '@', a non-empty localpart made of lowercase ASCII letters,
// digits, '.', '_', '=' or '-', a ':' and a non-empty host. The two parts are exposed as the
// named capture groups "localpart" and "host".
static ref USER_ID_PATTERN: Regex =
Regex::new(r"\A@(?P<localpart>[a-z0-9._=-]+):(?P<host>.+)\z")
.expect("Failed to compile user ID regex.");
// Matches a string that consists solely of one or more valid localpart characters: lowercase
// ASCII letters, digits, '.', '_', '=' and '-'.
static ref USER_LOCALPART_PATTERN: Regex =
Regex::new(r"\A[a-z0-9._=-]+\z").expect("Failed to create user localpart regex.");
}
/// An error encountered when trying to parse an invalid user ID string.
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub enum Error {
/// The user ID string did not match the "@<localpart>:<domain>" format, or used invalid
/// characters in its localpart.
InvalidFormat,
/// The ID's localpart contains invalid characters.
///
/// Valid characters are lowercase ASCII letters, digits, `.`, `_`, `=` and `-`.
InvalidCharacters,
/// The domain part of the user ID string was not a valid IP address or DNS name.
InvalidHost(ParseError),
InvalidHost,
/// The ID exceeds 255 bytes.
MaximumLengthExceeded,
/// The ID is shorter than 4 bytes, the length of the shortest possible valid ID.
MinimumLengthNotSatisfied,
/// The ID is missing the colon delimiter between localpart and server name.
MissingDelimiter,
/// The ID is missing the leading sigil.
MissingSigil,
}
/// A Matrix user ID.
@ -44,37 +60,59 @@ pub struct UserId {
port: u16,
}
/// Splits a Matrix identifier into its localpart, host and port.
///
/// `required_sigil` is the character the ID must start with and `id` is the complete identifier
/// string. On success the returned localpart borrows from `id`, and the port falls back to 443
/// when the server name does not specify one.
///
/// Only the overall shape of the ID is checked here. The characters of the localpart are not
/// validated; that is left to the caller.
///
/// # Errors
///
/// * `Error::MaximumLengthExceeded` if `id` is longer than `MAX_BYTES` bytes.
/// * `Error::MinimumLengthNotSatisfied` if `id` is shorter than `MIN_CHARS` bytes.
/// * `Error::MissingSigil` if the first character is not `required_sigil`.
/// * `Error::MissingDelimiter` if no `:` follows the sigil.
/// * `Error::InvalidHost` if the text after the first `:` cannot be parsed as a host with an
///   optional port.
fn parse_id<'a>(required_sigil: char, id: &'a str) -> Result<(&'a str, Host, u16), Error> {
    if id.len() > MAX_BYTES {
        return Err(Error::MaximumLengthExceeded);
    }

    if id.len() < MIN_CHARS {
        return Err(Error::MinimumLengthNotSatisfied);
    }

    // The minimum length check above guarantees that `id` is not empty.
    let sigil = id.chars().next().expect("ID missing first character.");

    if sigil != required_sigil {
        return Err(Error::MissingSigil);
    }

    // Every offset from here on is a byte offset, because that is what slicing a `str` needs.
    // Counting characters instead would pick the wrong boundaries, or panic in the middle of a
    // code point, as soon as the localpart contains a multi-byte character.
    let localpart_start = sigil.len_utf8();

    let delimiter_index = match id[localpart_start..].find(':') {
        Some(offset) => localpart_start + offset,
        None => return Err(Error::MissingDelimiter),
    };

    let localpart = &id[localpart_start..delimiter_index];
    // The `:` delimiter is a single byte wide, the same width as a sigil.
    let raw_host = &id[delimiter_index + SIGIL_BYTES..];

    // Let the URL parser validate the server name and split off an optional port.
    let url_string = format!("https://{}", raw_host);
    let url = try!(Url::parse(&url_string));

    let host = match url.host() {
        Some(host) => host.to_owned(),
        None => return Err(Error::InvalidHost),
    };

    let port = url.port().unwrap_or(443);

    Ok((localpart, host, port))
}
impl UserId {
/// Create a new Matrix user ID from a string representation.
///
/// The string must include the leading @ sigil, the localpart, a literal colon, and a valid
/// server name.
pub fn new(user_id: &str) -> Result<UserId, Error> {
let captures = match USER_ID_PATTERN.captures(user_id) {
Some(captures) => captures,
None => return Err(Error::InvalidFormat),
};
// Split the ID into localpart, host and port, requiring the '@' sigil.
let (localpart, host, port) = try!(parse_id('@', user_id));
let raw_host = captures.name("host").expect("Failed to extract hostname from regex.");
let url_string = format!("https://{}", raw_host);
let url = try!(Url::parse(&url_string));
let host = match url.host() {
Some(host) => host,
None => return Err(Error::InvalidFormat),
};
// Fall back to port 443 when the server name does not include a port.
let port = url.port().unwrap_or(443);
// parse_id does not inspect the localpart, so its characters are validated here.
if !USER_LOCALPART_PATTERN.is_match(localpart) {
return Err(Error::InvalidCharacters);
}
Ok(UserId {
hostname: host.to_owned(),
hostname: host,
port: port,
localpart: captures
.name("localpart")
.expect("Failed to extract localpart from regex.")
.to_string(),
localpart: localpart.to_owned(),
})
}
@ -98,8 +136,8 @@ impl UserId {
}
impl From<ParseError> for Error {
// Lets `try!` turn a URL parsing failure into an invalid-host error.
fn from(error: ParseError) -> Error {
Error::InvalidHost(error)
fn from(_: ParseError) -> Error {
Error::InvalidHost
}
}
@ -115,7 +153,7 @@ impl Display for UserId {
#[cfg(test)]
mod tests {
use super::UserId;
use super::{Error, UserId};
#[test]
fn valid_user_id() {
@ -149,26 +187,41 @@ mod tests {
#[test]
fn invalid_characters_in_localpart() {
// Uppercase letters are not part of the allowed localpart character set.
assert!(UserId::new("@CARL:example.com").is_err());
assert_eq!(
UserId::new("@CARL:example.com").err().unwrap(),
Error::InvalidCharacters
);
}
#[test]
fn missing_sigil() {
// The ID does not start with the '@' sigil.
assert!(UserId::new("carl:example.com").is_err());
assert_eq!(
UserId::new("carl:example.com").err().unwrap(),
Error::MissingSigil
);
}
#[test]
fn missing_domain() {
assert!(UserId::new("carl").is_err());
fn missing_delimiter() {
// The sigil is present, but no ':' separates the localpart from a server name.
assert_eq!(
UserId::new("@carl").err().unwrap(),
Error::MissingDelimiter
);
}
#[test]
fn invalid_host() {
// NOTE(review): presumably "-" is rejected by the URL parser as a host name; confirm.
assert!(UserId::new("@carl:-").is_err());
assert_eq!(
UserId::new("@carl:-").err().unwrap(),
Error::InvalidHost
);
}
#[test]
fn invalid_port() {
// "notaport" is not a number; the test expects parsing of the server name to fail.
assert!(UserId::new("@carl:example.com:notaport").is_err());
assert_eq!(
UserId::new("@carl:example.com:notaport").err().unwrap(),
Error::InvalidHost
);
}
}