diff --git a/src/lib.rs b/src/lib.rs index d020fe9..9488a8c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,10 +6,7 @@ extern crate alloc; #[cfg(any(feature = "alloc", feature = "std", test))] -use alloc::{ - string::{String, ToString}, - vec::Vec, -}; +use alloc::string::String; #[cfg(feature = "zeroize")] use zeroize::Zeroize; @@ -36,8 +33,8 @@ pub const BASE_64: GeneralPurpose = GeneralPurpose::new(&BCRYPT, NO_PAD); /// A bcrypt hash result before concatenating pub struct HashParts { cost: u32, - salt: String, - hash: String, + salt: [u8; 16], + hash: [u8; 23], } #[derive(Clone, Debug)] @@ -52,9 +49,28 @@ pub enum Version { #[cfg(any(feature = "alloc", feature = "std"))] impl HashParts { - /// Creates the bcrypt hash string from all its parts - fn format(&self) -> String { - self.format_for_version(Version::TwoB) + /// Creates the bcrypt hash string (version 2b) into a fixed-size stack buffer. + /// The full bcrypt hash string is always exactly 60 bytes. + fn format(&self) -> [u8; 60] { + struct ByteBuf { + buf: [u8; N], + pos: usize, + } + impl fmt::Write for ByteBuf { + fn write_str(&mut self, s: &str) -> fmt::Result { + let bytes = s.as_bytes(); + self.buf[self.pos..self.pos + bytes.len()].copy_from_slice(bytes); + self.pos += bytes.len(); + Ok(()) + } + } + let mut w = ByteBuf { + buf: [0u8; 60], + pos: 0, + }; + self.write_for_version(Version::TwoB, &mut w) + .expect("writing into a correctly sized buffer is infallible"); + w.buf } /// Get the bcrypt hash cost @@ -62,15 +78,43 @@ impl HashParts { self.cost } - /// Get the bcrypt hash salt + /// Get the bcrypt hash salt as a base64-encoded string pub fn get_salt(&self) -> String { - self.salt.clone() + BASE_64.encode(self.salt) + } + + /// Get the raw salt bytes + pub fn get_salt_raw(&self) -> [u8; 16] { + self.salt } - /// Creates the bcrypt hash string from all its part, allowing to customize the version. + /// Creates the bcrypt hash string from all its parts, allowing to customize the version. pub fn format_for_version(&self, version: Version) -> String { - // Cost need to have a length of 2 so padding with a 0 if cost < 10 - alloc::format!("${}${:02}${}{}", version, self.cost, self.salt, self.hash) + let mut s = String::with_capacity(60); + self.write_for_version(version, &mut s) + .expect("writing into a String is infallible"); + s + } + + /// Writes the bcrypt hash string into any `fmt::Write` sink without allocating. + /// Useful for writing into stack buffers (e.g. `arrayvec`, `heapless::String`). + pub fn write_for_version(&self, version: Version, w: &mut W) -> fmt::Result { + let mut salt_buf = [0u8; 22]; + let mut hash_buf = [0u8; 31]; + BASE_64 + .encode_slice(self.salt, &mut salt_buf) + .expect("salt encoding into correctly sized buffer is infallible"); + BASE_64 + .encode_slice(self.hash, &mut hash_buf) + .expect("hash encoding into correctly sized buffer is infallible"); + write!( + w, + "${}${:02}${}{}", + version, + self.cost, + core::str::from_utf8(&salt_buf).expect("base64 output is always valid UTF-8"), + core::str::from_utf8(&hash_buf).expect("base64 output is always valid UTF-8") + ) } } @@ -86,7 +130,7 @@ impl FromStr for HashParts { #[cfg(any(feature = "alloc", feature = "std"))] impl fmt::Display for HashParts { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.format()) + self.write_for_version(Version::TwoB, f) } } @@ -116,30 +160,29 @@ fn _hash_password( return Err(BcryptError::CostNotAllowed(cost)); } - // Passwords need to be null terminated - let mut vec = Vec::with_capacity(password.len() + 1); - vec.extend_from_slice(password); - vec.push(0); - // We only consider the first 72 chars; truncate if necessary. - // `bcrypt` below will panic if len > 72 - let truncated = if vec.len() > 72 { - if err_on_truncation { - return Err(BcryptError::Truncation(vec.len())); - } - &vec[..72] - } else { - &vec - }; + let password_len = password.len(); + if err_on_truncation && password_len >= 72 { + return Err(BcryptError::Truncation(password_len + 1)); + } + + // The bcrypt spec specifies that passwords should be null terminated + // strings, but if longer than 72 bytes, are truncated at 72 bytes (thereby + // losing the null byte at the end). + let copy_len = password_len.min(72); + let mut pass = [0u8; 72]; + pass[..copy_len].copy_from_slice(&password[..copy_len]); + let used = (copy_len + 1).min(72); + let truncated = &pass[..used]; let output = bcrypt::bcrypt(cost, salt, truncated); #[cfg(feature = "zeroize")] - vec.zeroize(); + pass.zeroize(); Ok(HashParts { cost, - salt: BASE_64.encode(salt), - hash: BASE_64.encode(&output[..23]), // remember to remove the last byte + salt, + hash: output[..23].try_into().unwrap(), // infallible: output is [u8; 24] }) } @@ -147,47 +190,55 @@ fn _hash_password( /// cost, salt and hash #[cfg(any(feature = "alloc", feature = "std"))] fn split_hash(hash: &str) -> BcryptResult { - let mut parts = HashParts { - cost: 0, - salt: "".to_string(), - hash: "".to_string(), - }; - - // Should be [prefix, cost, hash] - let raw_parts: Vec<_> = hash.split('$').filter(|s| !s.is_empty()).collect(); + // A valid bcrypt hash is always exactly 60 bytes: + if hash.len() != 60 { + return Err(BcryptError::InvalidHash( + "the hash format is malformed; expected 60 bytes", + )); + } - if raw_parts.len() != 3 { + let bytes = hash.as_bytes(); + if bytes[0] != b'$' || bytes[3] != b'$' || bytes[6] != b'$' { return Err(BcryptError::InvalidHash("the hash format is malformed")); } - if raw_parts[0] != "2y" && raw_parts[0] != "2b" && raw_parts[0] != "2a" && raw_parts[0] != "2x" - { + let version = &hash[1..3]; + if version != "2y" && version != "2b" && version != "2a" && version != "2x" { return Err(BcryptError::InvalidHash( "the hash prefix is not a bcrypt prefix", )); } - if let Ok(c) = raw_parts[1].parse::() { - parts.cost = c; - } else { - return Err(BcryptError::InvalidHash("the cost value is not a number")); - } + let cost = hash[4..6] + .parse::() + .map_err(|_| BcryptError::InvalidHash("the cost value is not a number"))?; - if raw_parts[2].len() == 53 && raw_parts[2].is_char_boundary(22) { - parts.salt = raw_parts[2][..22].chars().collect(); - parts.hash = raw_parts[2][22..].chars().collect(); - } else { - return Err(BcryptError::InvalidHash("the hash format is malformed")); - } + let salt_and_hash = &hash[7..]; + let mut salt = [0u8; 16]; + let mut hash_bytes = [0u8; 23]; + BASE_64 + .decode_slice(&salt_and_hash[..22], &mut salt) + .map_err(|_| BcryptError::InvalidHash("the salt part is not valid base64"))?; + BASE_64 + .decode_slice(&salt_and_hash[22..], &mut hash_bytes) + .map_err(|_| BcryptError::InvalidHash("the hash part is not valid base64"))?; - Ok(parts) + Ok(HashParts { + cost, + salt, + hash: hash_bytes, + }) } /// Generates a password hash using the cost given. /// The salt is generated randomly using the OS randomness #[cfg(any(feature = "alloc", feature = "std"))] pub fn hash>(password: P, cost: u32) -> BcryptResult { - hash_with_result(password, cost).map(|r| r.format()) + hash_with_result(password, cost).map(|r| { + String::from( + core::str::from_utf8(&r.format()).expect("base64 output is always valid UTF-8"), + ) + }) } /// Generates a password hash using the cost given. @@ -195,6 +246,27 @@ pub fn hash>(password: P, cost: u32) -> BcryptResult { /// Will return BcryptError::Truncation if password is longer than 72 bytes #[cfg(any(feature = "alloc", feature = "std"))] pub fn non_truncating_hash>(password: P, cost: u32) -> BcryptResult { + non_truncating_hash_with_result(password, cost).map(|r| { + String::from( + core::str::from_utf8(&r.format()).expect("base64 output is always valid UTF-8"), + ) + }) +} + +/// Generates a password hash using the cost given, returning a fixed-size stack buffer. +/// The salt is generated randomly using the OS randomness. +/// The returned buffer is always exactly 60 bytes of valid UTF-8 (version 2b format). +#[cfg(any(feature = "alloc", feature = "std"))] +pub fn hash_bytes>(password: P, cost: u32) -> BcryptResult<[u8; 60]> { + hash_with_result(password, cost).map(|r| r.format()) +} + +/// Generates a password hash using the cost given, returning a fixed-size stack buffer. +/// The salt is generated randomly using the OS randomness. +/// The returned buffer is always exactly 60 bytes of valid UTF-8 (version 2b format). +/// Will return BcryptError::Truncation if password is longer than 72 bytes +#[cfg(any(feature = "alloc", feature = "std"))] +pub fn non_truncating_hash_bytes>(password: P, cost: u32) -> BcryptResult<[u8; 60]> { non_truncating_hash_with_result(password, cost).map(|r| r.format()) } @@ -239,6 +311,17 @@ pub fn hash_with_salt>( _hash_password(password.as_ref(), cost, salt, false) } +/// Generates a password given a hash and a cost, returning a fixed-size stack buffer. +/// The returned buffer is always exactly 60 bytes of valid UTF-8 (version 2b format). +#[cfg(any(feature = "alloc", feature = "std"))] +pub fn hash_with_salt_bytes>( + password: P, + cost: u32, + salt: [u8; 16], +) -> BcryptResult<[u8; 60]> { + _hash_password(password.as_ref(), cost, salt, false).map(|r| r.format()) +} + /// Generates a password given a hash and a cost. /// The function returns a result structure and allows to format the hash in different versions. /// Will return BcryptError::Truncation if password is longer than 72 bytes @@ -251,6 +334,18 @@ pub fn non_truncating_hash_with_salt>( _hash_password(password.as_ref(), cost, salt, true) } +/// Generates a password given a hash and a cost, returning a fixed-size stack buffer. +/// The returned buffer is always exactly 60 bytes of valid UTF-8 (version 2b format). +/// Will return BcryptError::Truncation if password is longer than 72 bytes +#[cfg(any(feature = "alloc", feature = "std"))] +pub fn non_truncating_hash_with_salt_bytes>( + password: P, + cost: u32, + salt: [u8; 16], +) -> BcryptResult<[u8; 60]> { + _hash_password(password.as_ref(), cost, salt, true).map(|r| r.format()) +} + /// Verify the password against the hash by extracting the salt from the hash and recomputing the /// hash from the password. If `err_on_truncation` is set to true, then this method will return /// `BcryptError::Truncation`. @@ -259,24 +354,9 @@ fn _verify>(password: P, hash: &str, err_on_truncation: bool) -> use subtle::ConstantTimeEq; let parts = split_hash(hash)?; - let salt = BASE_64 - .decode(&parts.salt) - .map_err(|_| BcryptError::InvalidHash("the salt part is not valid base64"))?; - let generated = _hash_password( - password.as_ref(), - parts.cost, - salt.try_into() - .map_err(|_| BcryptError::InvalidHash("the salt length is not 16 bytes"))?, - err_on_truncation, - )?; - let source_decoded = BASE_64 - .decode(parts.hash) - .map_err(|_| BcryptError::InvalidHash("the hash to verify against is not valid base64"))?; - let generated_decoded = BASE_64.decode(generated.hash).map_err(|_| { - BcryptError::InvalidHash("the generated hash for the password is not valid base64") - })?; - - Ok(source_decoded.ct_eq(&generated_decoded).into()) + let generated = _hash_password(password.as_ref(), parts.cost, parts.salt, err_on_truncation)?; + + Ok(parts.hash.ct_eq(&generated.hash).into()) } /// Verify that a password is equivalent to the hash provided @@ -304,8 +384,10 @@ mod tests { vec, vec::Vec, }, - hash, hash_with_salt, non_truncating_verify, split_hash, verify, + hash, hash_bytes, hash_with_salt, hash_with_salt_bytes, non_truncating_hash_bytes, + non_truncating_hash_with_salt_bytes, non_truncating_verify, split_hash, verify, }; + use base64::Engine as _; use core::convert::TryInto; use core::iter; use core::str::FromStr; @@ -315,12 +397,12 @@ mod tests { fn can_split_hash() { let hash = "$2y$12$L6Bc/AlTQHyd9liGgGEZyOFLPHNgyxeEPfgYfBCVxJ7JIlwxyVU3u"; let output = split_hash(hash).unwrap(); - let expected = HashParts { - cost: 12, - salt: "L6Bc/AlTQHyd9liGgGEZyO".to_string(), - hash: "FLPHNgyxeEPfgYfBCVxJ7JIlwxyVU3u".to_string(), - }; - assert_eq!(output, expected); + assert_eq!(output.get_cost(), 12); + assert_eq!(output.get_salt(), "L6Bc/AlTQHyd9liGgGEZyO"); + assert_eq!( + output.format_for_version(Version::TwoY), + "$2y$12$L6Bc/AlTQHyd9liGgGEZyOFLPHNgyxeEPfgYfBCVxJ7JIlwxyVU3u" + ); } #[test] @@ -331,6 +413,38 @@ mod tests { assert_eq!(parsed.get_salt(), "L6Bc/AlTQHyd9liGgGEZyO".to_string()); } + #[test] + fn can_get_raw_salt_from_parsed_hash() { + let hash = "$2y$12$L6Bc/AlTQHyd9liGgGEZyOFLPHNgyxeEPfgYfBCVxJ7JIlwxyVU3u"; + let parsed = HashParts::from_str(hash).unwrap(); + // Raw salt must round-trip back to the same base64 string + assert_eq!( + super::BASE_64.encode(parsed.get_salt_raw()), + "L6Bc/AlTQHyd9liGgGEZyO" + ); + } + + #[test] + fn can_write_hash_for_version_without_allocating() { + let hash = "$2y$12$L6Bc/AlTQHyd9liGgGEZyOFLPHNgyxeEPfgYfBCVxJ7JIlwxyVU3u"; + let parsed = HashParts::from_str(hash).unwrap(); + let mut buf = String::new(); + parsed.write_for_version(Version::TwoY, &mut buf).unwrap(); + assert_eq!(buf, hash); + } + + #[test] + fn write_for_version_matches_format_for_version() { + let salt = [0u8; 16]; + let result = _hash_password("hunter2".as_bytes(), DEFAULT_COST, salt, false).unwrap(); + let formatted = result.format_for_version(Version::TwoA); + let mut written = String::new(); + result + .write_for_version(Version::TwoA, &mut written) + .unwrap(); + assert_eq!(formatted, written); + } + #[test] fn returns_an_error_if_a_parsed_hash_is_baddly_formated() { let hash1 = "$2y$12$L6Bc/AlTQHyd9lGEZyOFLPHNgyxeEPfgYfBCVxJ7JIlwxyVU3u"; @@ -545,6 +659,66 @@ mod tests { ); } + #[test] + fn hash_bytes_returns_valid_utf8_bcrypt_string() { + let result = hash_bytes("hunter2", 4).unwrap(); + let s = core::str::from_utf8(&result).unwrap(); + assert!(s.starts_with("$2b$04$")); + assert_eq!(s.len(), 60); + assert!(verify("hunter2", s).unwrap()); + } + + #[test] + fn non_truncating_hash_bytes_returns_valid_utf8_bcrypt_string() { + let result = non_truncating_hash_bytes("hunter2", 4).unwrap(); + let s = core::str::from_utf8(&result).unwrap(); + assert!(s.starts_with("$2b$04$")); + assert_eq!(s.len(), 60); + assert!(verify("hunter2", s).unwrap()); + } + + #[test] + fn non_truncating_hash_bytes_errors_on_long_password() { + use core::iter; + let result = non_truncating_hash_bytes(iter::repeat("x").take(72).collect::(), 4); + assert!(matches!(result, Err(BcryptError::Truncation(73)))); + } + + #[test] + fn hash_with_salt_bytes_matches_hash_with_salt() { + let salt = [ + 38, 113, 212, 141, 108, 213, 195, 166, 201, 38, 20, 13, 47, 40, 104, 18, + ]; + let expected = hash_with_salt("My S3cre7 P@55w0rd!", 5, salt) + .unwrap() + .to_string(); + let result = hash_with_salt_bytes("My S3cre7 P@55w0rd!", 5, salt).unwrap(); + let s = core::str::from_utf8(&result).unwrap(); + assert_eq!(expected, s); + } + + #[test] + fn non_truncating_hash_with_salt_bytes_errors_on_long_password() { + use core::iter; + let salt = [0u8; 16]; + let result = non_truncating_hash_with_salt_bytes( + iter::repeat("x").take(72).collect::(), + 4, + salt, + ); + assert!(matches!(result, Err(BcryptError::Truncation(73)))); + } + + #[test] + fn hash_bytes_matches_hash_string() { + let salt = [0u8; 16]; + let result_parts = _hash_password("hunter2".as_bytes(), 4, salt, false).unwrap(); + let from_parts = result_parts.format_for_version(Version::TwoB); + let bytes_result = hash_with_salt_bytes("hunter2", 4, salt).unwrap(); + let from_bytes = core::str::from_utf8(&bytes_result).unwrap(); + assert_eq!(from_parts, from_bytes); + } + quickcheck! { fn can_verify_arbitrary_own_generated(pass: Vec) -> BcryptResult { let mut pass = pass;