From f9f76f318bf81a072ec628fe28c712148dc1e512 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 18:39:13 +0300 Subject: [PATCH 01/18] types: make hash algorithm a newtype --- src/dtls12/message/mod.rs | 2 +- src/types.rs | 155 ++++++++++++++++++++++++++------------ 2 files changed, 108 insertions(+), 49 deletions(-) diff --git a/src/dtls12/message/mod.rs b/src/dtls12/message/mod.rs index 6eb8ca24..53a97712 100644 --- a/src/dtls12/message/mod.rs +++ b/src/dtls12/message/mod.rs @@ -200,7 +200,7 @@ impl Dtls12CipherSuite { Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 | Dtls12CipherSuite::PSK_AES128_CCM_8 => HashAlgorithm::SHA256, - Dtls12CipherSuite::Unknown(_) => HashAlgorithm::Unknown(0), + Dtls12CipherSuite::Unknown(_) => HashAlgorithm::UNKNOWN_DERIVED, } } diff --git a/src/types.rs b/src/types.rs index 48a9c2f1..8f6aecf7 100644 --- a/src/types.rs +++ b/src/types.rs @@ -274,60 +274,56 @@ impl NamedGroup { /// /// Specifies the hash algorithm to be used in digital signatures, /// PRF/HKDF operations, and transcript hashing. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub enum HashAlgorithm { +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct HashAlgorithm(u8); + +impl Default for HashAlgorithm { + fn default() -> Self { + Self::NONE + } +} + +impl HashAlgorithm { /// No hash (not typically used). - None, + pub const NONE: Self = Self(0); /// MD5 hash (deprecated, not supported). - MD5, + pub const MD5: Self = Self(1); /// SHA-1 hash (deprecated, not supported). - SHA1, + pub const SHA1: Self = Self(2); /// SHA-224 hash. - SHA224, + pub const SHA224: Self = Self(3); /// SHA-256 hash (supported by dimpl). - SHA256, + pub const SHA256: Self = Self(4); /// SHA-384 hash (supported by dimpl). - SHA384, + pub const SHA384: Self = Self(5); /// SHA-512 hash. - SHA512, - /// Unknown or unsupported hash algorithm. - Unknown(u8), -} + pub const SHA512: Self = Self(6); -impl Default for HashAlgorithm { - fn default() -> Self { - Self::Unknown(0) - } -} + pub(crate) const UNKNOWN_DERIVED: Self = Self(u8::MAX); -impl HashAlgorithm { /// Convert a wire format u8 value to a `HashAlgorithm`. - pub fn from_u8(value: u8) -> Self { - match value { - 0 => HashAlgorithm::None, - 1 => HashAlgorithm::MD5, - 2 => HashAlgorithm::SHA1, - 3 => HashAlgorithm::SHA224, - 4 => HashAlgorithm::SHA256, - 5 => HashAlgorithm::SHA384, - 6 => HashAlgorithm::SHA512, - _ => HashAlgorithm::Unknown(value), - } + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Convert this `HashAlgorithm` to its wire format u8 value. - pub fn as_u8(&self) -> u8 { - match self { - HashAlgorithm::None => 0, - HashAlgorithm::MD5 => 1, - HashAlgorithm::SHA1 => 2, - HashAlgorithm::SHA224 => 3, - HashAlgorithm::SHA256 => 4, - HashAlgorithm::SHA384 => 5, - HashAlgorithm::SHA512 => 6, - HashAlgorithm::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + /// Returns true if this is not a known DTLS hash algorithm wire value. + pub const fn is_unknown(&self) -> bool { + !matches!( + *self, + HashAlgorithm::NONE + | HashAlgorithm::MD5 + | HashAlgorithm::SHA1 + | HashAlgorithm::SHA224 + | HashAlgorithm::SHA256 + | HashAlgorithm::SHA384 + | HashAlgorithm::SHA512 + ) } /// Parse a `HashAlgorithm` from wire format. @@ -337,16 +333,31 @@ impl HashAlgorithm { } /// Returns the output length in bytes for this hash algorithm. - pub fn output_len(&self) -> usize { - match self { - HashAlgorithm::None => 0, + pub const fn output_len(&self) -> usize { + match *self { + HashAlgorithm::NONE => 0, HashAlgorithm::MD5 => 16, HashAlgorithm::SHA1 => 20, HashAlgorithm::SHA224 => 28, HashAlgorithm::SHA256 => 32, HashAlgorithm::SHA384 => 48, HashAlgorithm::SHA512 => 64, - HashAlgorithm::Unknown(_) => 0, + _ => 0, + } + } +} + +impl fmt::Debug for HashAlgorithm { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + HashAlgorithm::NONE => f.write_str("None"), + HashAlgorithm::MD5 => f.write_str("MD5"), + HashAlgorithm::SHA1 => f.write_str("SHA1"), + HashAlgorithm::SHA224 => f.write_str("SHA224"), + HashAlgorithm::SHA256 => f.write_str("SHA256"), + HashAlgorithm::SHA384 => f.write_str("SHA384"), + HashAlgorithm::SHA512 => f.write_str("SHA512"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), } } } @@ -686,8 +697,8 @@ impl SignatureScheme { | SignatureScheme::RSA_PSS_PSS_SHA512 | SignatureScheme::RSA_PKCS1_SHA512 => HashAlgorithm::SHA512, // Ed25519 and Ed448 have intrinsic hash algorithms - SignatureScheme::ED25519 | SignatureScheme::ED448 => HashAlgorithm::None, - SignatureScheme::Unknown(_) => HashAlgorithm::Unknown(0), + SignatureScheme::ED25519 | SignatureScheme::ED448 => HashAlgorithm::NONE, + SignatureScheme::Unknown(_) => HashAlgorithm::UNKNOWN_DERIVED, } } } @@ -757,7 +768,7 @@ impl Dtls13CipherSuite { | Dtls13CipherSuite::AES_128_CCM_SHA256 | Dtls13CipherSuite::AES_128_CCM_8_SHA256 => HashAlgorithm::SHA256, Dtls13CipherSuite::AES_256_GCM_SHA384 => HashAlgorithm::SHA384, - Dtls13CipherSuite::Unknown(_) => HashAlgorithm::Unknown(0), + Dtls13CipherSuite::Unknown(_) => HashAlgorithm::UNKNOWN_DERIVED, } } @@ -919,6 +930,54 @@ impl CompressionMethod { mod tests { use super::*; + #[test] + fn hash_algorithm_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert_eq!(HashAlgorithm::default(), HashAlgorithm::NONE); + } + + #[test] + fn hash_algorithm_wire_roundtrip() { + let known = [ + (0, HashAlgorithm::NONE), + (1, HashAlgorithm::MD5), + (2, HashAlgorithm::SHA1), + (3, HashAlgorithm::SHA224), + (4, HashAlgorithm::SHA256), + (5, HashAlgorithm::SHA384), + (6, HashAlgorithm::SHA512), + ]; + + for (wire, algorithm) in known { + assert_eq!(HashAlgorithm::from_u8(wire), algorithm); + assert_eq!(algorithm.as_u8(), wire); + assert!(!algorithm.is_unknown()); + } + + let unknown = HashAlgorithm::from_u8(7); + assert_eq!(unknown.as_u8(), 7); + assert!(unknown.is_unknown()); + } + + #[test] + fn hash_algorithm_output_len() { + assert_eq!(HashAlgorithm::NONE.output_len(), 0); + assert_eq!(HashAlgorithm::MD5.output_len(), 16); + assert_eq!(HashAlgorithm::SHA1.output_len(), 20); + assert_eq!(HashAlgorithm::SHA224.output_len(), 28); + assert_eq!(HashAlgorithm::SHA256.output_len(), 32); + assert_eq!(HashAlgorithm::SHA384.output_len(), 48); + assert_eq!(HashAlgorithm::SHA512.output_len(), 64); + assert_eq!(HashAlgorithm::from_u8(7).output_len(), 0); + } + + #[test] + fn hash_algorithm_debug_stays_enum_like() { + assert_eq!(format!("{:?}", HashAlgorithm::NONE), "None"); + assert_eq!(format!("{:?}", HashAlgorithm::SHA256), "SHA256"); + assert_eq!(format!("{:?}", HashAlgorithm::from_u8(7)), "Unknown(7)"); + } + #[test] fn random_parse() { let data = [ From 526f9d714fd650e46715e4b802c9d0caf4dfc0ec Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 18:41:22 +0300 Subject: [PATCH 02/18] types: make signature algorithm a newtype --- src/dtls12/message/mod.rs | 2 +- src/types.rs | 109 +++++++++++++++++++++++++++----------- 2 files changed, 79 insertions(+), 32 deletions(-) diff --git a/src/dtls12/message/mod.rs b/src/dtls12/message/mod.rs index 53a97712..6e83119d 100644 --- a/src/dtls12/message/mod.rs +++ b/src/dtls12/message/mod.rs @@ -215,7 +215,7 @@ impl Dtls12CipherSuite { Some(SignatureAlgorithm::ECDSA) } Dtls12CipherSuite::PSK_AES128_CCM_8 => None, - Dtls12CipherSuite::Unknown(_) => Some(SignatureAlgorithm::Unknown(0)), + Dtls12CipherSuite::Unknown(_) => Some(SignatureAlgorithm::UNKNOWN_DERIVED), } } diff --git a/src/types.rs b/src/types.rs index 8f6aecf7..b4d5b67f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -370,48 +370,47 @@ impl fmt::Debug for HashAlgorithm { /// /// Represents the underlying signature primitive (RSA, ECDSA, etc.). /// Used internally for signing operations across both DTLS versions. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub enum SignatureAlgorithm { - /// Anonymous (no certificate). - Anonymous, - /// RSA signatures. - RSA, - /// DSA signatures. - DSA, - /// ECDSA signatures. - ECDSA, - /// Unknown or unsupported signature algorithm. - Unknown(u8), -} +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct SignatureAlgorithm(u8); impl Default for SignatureAlgorithm { fn default() -> Self { - Self::Unknown(0) + Self::ANONYMOUS } } impl SignatureAlgorithm { + /// Anonymous (no certificate). + pub const ANONYMOUS: Self = Self(0); + /// RSA signatures. + pub const RSA: Self = Self(1); + /// DSA signatures. + pub const DSA: Self = Self(2); + /// ECDSA signatures. + pub const ECDSA: Self = Self(3); + + pub(crate) const UNKNOWN_DERIVED: Self = Self(u8::MAX); + /// Convert an 8-bit value into a `SignatureAlgorithm`. - pub fn from_u8(value: u8) -> Self { - match value { - 0 => SignatureAlgorithm::Anonymous, - 1 => SignatureAlgorithm::RSA, - 2 => SignatureAlgorithm::DSA, - 3 => SignatureAlgorithm::ECDSA, - _ => SignatureAlgorithm::Unknown(value), - } + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Convert this `SignatureAlgorithm` into its 8-bit representation. - pub fn as_u8(&self) -> u8 { - match self { - SignatureAlgorithm::Anonymous => 0, - SignatureAlgorithm::RSA => 1, - SignatureAlgorithm::DSA => 2, - SignatureAlgorithm::ECDSA => 3, - SignatureAlgorithm::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + /// Returns true if this is not a known DTLS signature algorithm wire value. + pub const fn is_unknown(&self) -> bool { + !matches!( + *self, + SignatureAlgorithm::ANONYMOUS + | SignatureAlgorithm::RSA + | SignatureAlgorithm::DSA + | SignatureAlgorithm::ECDSA + ) } /// Parse a `SignatureAlgorithm` from network bytes. @@ -421,6 +420,18 @@ impl SignatureAlgorithm { } } +impl fmt::Debug for SignatureAlgorithm { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + SignatureAlgorithm::ANONYMOUS => f.write_str("Anonymous"), + SignatureAlgorithm::RSA => f.write_str("RSA"), + SignatureAlgorithm::DSA => f.write_str("DSA"), + SignatureAlgorithm::ECDSA => f.write_str("ECDSA"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + // ============================================================================ // Content Type // ============================================================================ @@ -978,6 +989,42 @@ mod tests { assert_eq!(format!("{:?}", HashAlgorithm::from_u8(7)), "Unknown(7)"); } + #[test] + fn signature_algorithm_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert_eq!(SignatureAlgorithm::default(), SignatureAlgorithm::ANONYMOUS); + } + + #[test] + fn signature_algorithm_wire_roundtrip() { + let known = [ + (0, SignatureAlgorithm::ANONYMOUS), + (1, SignatureAlgorithm::RSA), + (2, SignatureAlgorithm::DSA), + (3, SignatureAlgorithm::ECDSA), + ]; + + for (wire, algorithm) in known { + assert_eq!(SignatureAlgorithm::from_u8(wire), algorithm); + assert_eq!(algorithm.as_u8(), wire); + assert!(!algorithm.is_unknown()); + } + + let unknown = SignatureAlgorithm::from_u8(4); + assert_eq!(unknown.as_u8(), 4); + assert!(unknown.is_unknown()); + } + + #[test] + fn signature_algorithm_debug_stays_enum_like() { + assert_eq!(format!("{:?}", SignatureAlgorithm::ANONYMOUS), "Anonymous"); + assert_eq!(format!("{:?}", SignatureAlgorithm::ECDSA), "ECDSA"); + assert_eq!( + format!("{:?}", SignatureAlgorithm::from_u8(4)), + "Unknown(4)" + ); + } + #[test] fn random_parse() { let data = [ From b9ee25463739d8b233acfd6f656ffe8eb41cb7e3 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 18:42:36 +0300 Subject: [PATCH 03/18] types: make compression method a newtype --- src/dtls12/message/client_hello.rs | 4 +- src/dtls12/message/handshake.rs | 6 +- src/dtls12/message/server_hello.rs | 2 +- src/dtls13/client.rs | 6 +- src/dtls13/message/client_hello.rs | 2 +- src/dtls13/message/handshake.rs | 2 +- src/dtls13/message/server_hello.rs | 4 +- src/dtls13/server.rs | 6 +- src/types.rs | 90 +++++++++++++++++++++--------- 9 files changed, 81 insertions(+), 41 deletions(-) diff --git a/src/dtls12/message/client_hello.rs b/src/dtls12/message/client_hello.rs index 2747a1fa..0067dddc 100644 --- a/src/dtls12/message/client_hello.rs +++ b/src/dtls12/message/client_hello.rs @@ -274,7 +274,7 @@ mod tests { 0xC0, 0x2B, // Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 0xC0, 0x2C, // Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 0x01, // CompressionMethods length - 0x00, // CompressionMethod::Null + 0x00, // CompressionMethod::NULL ]; #[test] @@ -286,7 +286,7 @@ mod tests { cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256); cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::Null); + compression_methods.push(CompressionMethod::NULL); let client_hello = ClientHello::new( ProtocolVersion::DTLS1_2, diff --git a/src/dtls12/message/handshake.rs b/src/dtls12/message/handshake.rs index 65521159..4b9f8f46 100644 --- a/src/dtls12/message/handshake.rs +++ b/src/dtls12/message/handshake.rs @@ -532,7 +532,7 @@ mod tests { 0xC0, 0x2B, // Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 0xC0, 0x2C, // Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 0x01, // CompressionMethods length - 0x00, // CompressionMethod::Null + 0x00, // CompressionMethod::NULL ]; #[test] @@ -564,7 +564,7 @@ mod tests { cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256); cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::Null); + compression_methods.push(CompressionMethod::NULL); let client_hello = ClientHello::new( ProtocolVersion::DTLS1_2, @@ -607,7 +607,7 @@ mod tests { cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256); cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::Null); + compression_methods.push(CompressionMethod::NULL); let client_hello = ClientHello::new( ProtocolVersion::DTLS1_2, diff --git a/src/dtls12/message/server_hello.rs b/src/dtls12/message/server_hello.rs index 5ae0c6cf..b6496547 100644 --- a/src/dtls12/message/server_hello.rs +++ b/src/dtls12/message/server_hello.rs @@ -200,7 +200,7 @@ mod test { 0x01, // SessionId length 0xAA, // SessionId 0xC0, 0x2B, // Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 - 0x00, // CompressionMethod::Null + 0x00, // CompressionMethod::NULL 0x00, 0x0C, // Extensions length (12 bytes total: 2 type + 2 length + 8 data) 0x00, 0x0A, // ExtensionType::SupportedGroups 0x00, 0x08, // Extension data length (8 bytes) diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index 092be3bb..66bc40ad 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -551,7 +551,7 @@ impl State { } // Validate legacy_compression_method (must be null) - if server_hello.legacy_compression_method != CompressionMethod::Null { + if server_hello.legacy_compression_method != CompressionMethod::NULL { return Err((Error::SecurityError( crate::SecurityError::ServerHelloCompressionMustBeNull, )) @@ -1196,7 +1196,7 @@ fn handshake_create_client_hello( ); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::Null); + compression_methods.push(CompressionMethod::NULL); // Build extensions let mut extensions: ArrayVec = ArrayVec::new(); @@ -1636,7 +1636,7 @@ mod tests { body.extend_from_slice(&[7; 32]); body.push(0); // legacy_session_id body.extend_from_slice(&Dtls13CipherSuite::AES_128_GCM_SHA256.as_u16().to_be_bytes()); - body.push(CompressionMethod::Null.as_u8()); + body.push(CompressionMethod::NULL.as_u8()); body.extend_from_slice(&(extensions.len() as u16).to_be_bytes()); body.extend_from_slice(&extensions); body diff --git a/src/dtls13/message/client_hello.rs b/src/dtls13/message/client_hello.rs index b6ceedad..56d37f07 100644 --- a/src/dtls13/message/client_hello.rs +++ b/src/dtls13/message/client_hello.rs @@ -214,7 +214,7 @@ mod tests { cipher_suites.push(Dtls13CipherSuite::AES_128_GCM_SHA256); cipher_suites.push(Dtls13CipherSuite::AES_256_GCM_SHA384); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::Null); + compression_methods.push(CompressionMethod::NULL); let client_hello = ClientHello::new( ProtocolVersion::DTLS1_2, diff --git a/src/dtls13/message/handshake.rs b/src/dtls13/message/handshake.rs index 612e195d..006fae64 100644 --- a/src/dtls13/message/handshake.rs +++ b/src/dtls13/message/handshake.rs @@ -539,7 +539,7 @@ mod tests { cipher_suites.push(Dtls13CipherSuite::AES_128_GCM_SHA256); cipher_suites.push(Dtls13CipherSuite::AES_256_GCM_SHA384); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::Null); + compression_methods.push(CompressionMethod::NULL); let client_hello = ClientHello::new( ProtocolVersion::DTLS1_2, diff --git a/src/dtls13/message/server_hello.rs b/src/dtls13/message/server_hello.rs index de705117..24d796dc 100644 --- a/src/dtls13/message/server_hello.rs +++ b/src/dtls13/message/server_hello.rs @@ -156,7 +156,7 @@ mod test { 0x01, // SessionId length 0xAA, // SessionId 0x13, 0x01, // Dtls13CipherSuite::AES_128_GCM_SHA256 - 0x00, // CompressionMethod::Null + 0x00, // CompressionMethod::NULL 0x00, 0x0C, // Extensions length (12 bytes) 0x00, 0x0A, // ExtensionType::SupportedGroups 0x00, 0x08, // Extension data length (8 bytes) @@ -194,7 +194,7 @@ mod test { hrr_random, SessionId::empty(), Dtls13CipherSuite::AES_128_GCM_SHA256, - CompressionMethod::Null, + CompressionMethod::NULL, None, ); diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 4711c8d4..3b7040e2 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -428,7 +428,7 @@ impl State { // Validate null compression is offered let has_null_compression = client_hello .legacy_compression_methods - .contains(&CompressionMethod::Null); + .contains(&CompressionMethod::NULL); if !has_null_compression { return Err(Error::SecurityError( crate::SecurityError::ClientHelloMustOfferNullCompression, @@ -1322,7 +1322,7 @@ fn send_hello_retry_request( hrr_random, client_session_id, cipher_suite, - CompressionMethod::Null, + CompressionMethod::NULL, Some(extensions), ); @@ -1384,7 +1384,7 @@ fn handshake_create_server_hello( random, client_session_id, cipher_suite, - CompressionMethod::Null, + CompressionMethod::NULL, Some(extensions), ); diff --git a/src/types.rs b/src/types.rs index b4d5b67f..0dd3277c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -876,30 +876,25 @@ impl ProtocolVersion { /// /// Used in ClientHello/ServerHello for both DTLS 1.2 and 1.3. /// TLS 1.3 only uses Null compression but includes it for compatibility. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CompressionMethod { - /// No compression. - Null, - /// DEFLATE compression. - Deflate, - /// Unknown compression method. - Unknown(u8), -} +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct CompressionMethod(u8); impl Default for CompressionMethod { fn default() -> Self { - Self::Unknown(0) + Self::NULL } } impl CompressionMethod { + /// No compression. + pub const NULL: Self = Self(0x00); + /// DEFLATE compression. + pub const DEFLATE: Self = Self(0x01); + /// Convert a u8 value to a `CompressionMethod`. - pub fn from_u8(value: u8) -> Self { - match value { - 0x00 => CompressionMethod::Null, - 0x01 => CompressionMethod::Deflate, - _ => CompressionMethod::Unknown(value), - } + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Returns true if this compression method is supported by this implementation. @@ -909,7 +904,7 @@ impl CompressionMethod { /// All recognized compression methods (every non-`Unknown` variant). pub const fn all() -> &'static [CompressionMethod; 2] { - &[CompressionMethod::Null, CompressionMethod::Deflate] + &[CompressionMethod::NULL, CompressionMethod::DEFLATE] } /// Supported compression methods. @@ -918,16 +913,17 @@ impl CompressionMethod { /// §4.1.2) mandates exactly one compression method (null). DEFLATE /// is recognized by parsing but not accepted. pub const fn supported() -> &'static [CompressionMethod; 1] { - &[CompressionMethod::Null] + &[CompressionMethod::NULL] } /// Convert this `CompressionMethod` to its u8 value. - pub fn as_u8(&self) -> u8 { - match self { - CompressionMethod::Null => 0x00, - CompressionMethod::Deflate => 0x01, - CompressionMethod::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + /// Returns true if this is not a known TLS compression method wire value. + pub const fn is_unknown(&self) -> bool { + !matches!(*self, CompressionMethod::NULL | CompressionMethod::DEFLATE) } /// Parse a `CompressionMethod` from wire format. @@ -937,6 +933,16 @@ impl CompressionMethod { } } +impl fmt::Debug for CompressionMethod { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + CompressionMethod::NULL => f.write_str("Null"), + CompressionMethod::DEFLATE => f.write_str("Deflate"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -1025,6 +1031,40 @@ mod tests { ); } + #[test] + fn compression_method_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert_eq!(CompressionMethod::default(), CompressionMethod::NULL); + } + + #[test] + fn compression_method_wire_roundtrip() { + let known = [ + (0x00, CompressionMethod::NULL), + (0x01, CompressionMethod::DEFLATE), + ]; + + for (wire, method) in known { + assert_eq!(CompressionMethod::from_u8(wire), method); + assert_eq!(method.as_u8(), wire); + assert!(!method.is_unknown()); + } + + let unknown = CompressionMethod::from_u8(0x02); + assert_eq!(unknown.as_u8(), 0x02); + assert!(unknown.is_unknown()); + } + + #[test] + fn compression_method_debug_stays_enum_like() { + assert_eq!(format!("{:?}", CompressionMethod::NULL), "Null"); + assert_eq!(format!("{:?}", CompressionMethod::DEFLATE), "Deflate"); + assert_eq!( + format!("{:?}", CompressionMethod::from_u8(0x02)), + "Unknown(2)" + ); + } + #[test] fn random_parse() { let data = [ @@ -1060,7 +1100,7 @@ mod tests { let supported = CompressionMethod::supported(); assert_eq!( supported, - &[CompressionMethod::Null], + &[CompressionMethod::NULL], "Only Null compression should be supported" ); } From 142ce29983dca089a9ddc59982458e3c93bff93f Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 18:45:38 +0300 Subject: [PATCH 04/18] types: make content type a newtype --- src/dtls12/client.rs | 18 +++--- src/dtls12/engine.rs | 20 +++---- src/dtls12/incoming.rs | 10 ++-- src/dtls12/message/record.rs | 4 +- src/dtls12/queue.rs | 10 ++-- src/dtls12/server.rs | 16 ++--- src/dtls13/client.rs | 8 +-- src/dtls13/engine.rs | 30 +++++----- src/dtls13/incoming.rs | 10 ++-- src/dtls13/message/record.rs | 4 +- src/dtls13/queue.rs | 10 ++-- src/dtls13/server.rs | 6 +- src/types.rs | 111 ++++++++++++++++++++++++----------- tests/dtls12/retransmit.rs | 2 +- 14 files changed, 149 insertions(+), 110 deletions(-) diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs index 7f962b6d..b1194cec 100644 --- a/src/dtls12/client.rs +++ b/src/dtls12/client.rs @@ -234,7 +234,7 @@ impl Client { // Use the engine's create_record to send application data // The encryption is now handled in the engine self.engine - .create_record(ContentType::ApplicationData, 1, false, |body| { + .create_record(ContentType::APPLICATION_DATA, 1, false, |body| { body.extend_from_slice(data); })?; @@ -252,7 +252,7 @@ impl Client { return Ok(()); } self.engine - .create_record(ContentType::Alert, 1, false, |body| { + .create_record(ContentType::ALERT, 1, false, |body| { body.push(1); // level: warning body.push(0); // description: close_notify })?; @@ -456,7 +456,7 @@ impl State { } // Enforce Null compression only - if server_hello.compression_method != CompressionMethod::Null { + if server_hello.compression_method != CompressionMethod::NULL { return Err( Error::SecurityError(crate::SecurityError::UnsupportedServerCompression( server_hello.compression_method, @@ -950,7 +950,7 @@ impl State { trace!("Sending ChangeCipherSpec"); client .engine - .create_record(ContentType::ChangeCipherSpec, 0, true, |body| { + .create_record(ContentType::CHANGE_CIPHER_SPEC, 0, true, |body| { // Change cipher spec is just a single byte with value 1 body.push(1); })?; @@ -1046,7 +1046,7 @@ impl State { } fn await_change_cipher_spec(self, client: &mut Client) -> Result { - let maybe = client.engine.next_record(ContentType::ChangeCipherSpec); + let maybe = client.engine.next_record(ContentType::CHANGE_CIPHER_SPEC); let Some(_) = maybe else { // Stay in same state @@ -1190,7 +1190,7 @@ impl State { client.engine.discard_pending_writes(); client .engine - .create_record(ContentType::Alert, 1, false, |body| { + .create_record(ContentType::ALERT, 1, false, |body| { body.push(1); // level: warning body.push(0); // description: close_notify })?; @@ -1205,7 +1205,7 @@ impl State { for data in client.queued_data.drain(..) { client .engine - .create_record(ContentType::ApplicationData, 1, false, |body| { + .create_record(ContentType::APPLICATION_DATA, 1, false, |body| { body.extend_from_slice(&data); })?; } @@ -1242,7 +1242,7 @@ fn handshake_create_client_hello( ); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::Null); + compression_methods.push(CompressionMethod::NULL); // Create ClientHello with all required extensions let client_hello = ClientHello::new( @@ -1409,7 +1409,7 @@ mod tests { fn epoch0_handshake_packet(msg_type: MessageType, message_seq: u16, body: &[u8]) -> Vec { let handshake_len = 12 + body.len(); let mut packet = Vec::new(); - packet.push(ContentType::Handshake.as_u8()); + packet.push(ContentType::HANDSHAKE.as_u8()); packet.extend_from_slice(&[0xfe, 0xfd]); packet.extend_from_slice(&0u16.to_be_bytes()); packet.extend_from_slice(&0u64.to_be_bytes()[2..]); diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index 3bd86e99..80acf3d9 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -324,7 +324,7 @@ impl Engine { if self.peer_encryption_enabled && seq_current.epoch == 0 - && first.record().content_type == ContentType::Handshake + && first.record().content_type == ContentType::HANDSHAKE { return Ok(()); } @@ -332,7 +332,7 @@ impl Engine { if self.peer_encryption_enabled { for record in incoming.records().iter() { if record.record().sequence.epoch == 0 - && record.record().content_type == ContentType::Handshake + && record.record().content_type == ContentType::HANDSHAKE { if record.handshakes().is_empty() { record.set_handled(); @@ -443,7 +443,7 @@ impl Engine { .queue_rx .iter() .flat_map(|i| i.records().iter()) - .filter(|r| r.record().content_type == ContentType::ApplicationData) + .filter(|r| r.record().content_type == ContentType::APPLICATION_DATA) .skip_while(|r| r.is_handled()); let Some(next) = unhandled.next() else { @@ -697,7 +697,7 @@ impl Engine { pub fn drop_pending_ccs(&mut self) { for incoming in self.queue_rx.iter() { for record in incoming.records().iter() { - if record.record().content_type == ContentType::ChangeCipherSpec { + if record.record().content_type == ContentType::CHANGE_CIPHER_SPEC { record.set_handled(); } } @@ -983,7 +983,7 @@ impl Engine { }; // Emit the record; packing into current datagram happens inside create_record - self.create_record(ContentType::Handshake, epoch, true, |fragment| { + self.create_record(ContentType::HANDSHAKE, epoch, true, |fragment| { // Serialize with body_buffer as source frag_handshake.serialize(&body_buffer, fragment); })?; @@ -1237,7 +1237,7 @@ impl RecordHandler for Engine { fn classify_record(&mut self, record: Record) -> Result, Error> { let epoch = record.record().sequence.epoch; - if record.record().content_type == ContentType::ChangeCipherSpec + if record.record().content_type == ContentType::CHANGE_CIPHER_SPEC && epoch == 0 && self.peer_encryption_enabled { @@ -1250,7 +1250,7 @@ impl RecordHandler for Engine { return Ok(None); } - if record.record().content_type == ContentType::Handshake + if record.record().content_type == ContentType::HANDSHAKE && epoch == 0 && self.peer_encryption_enabled && record @@ -1266,7 +1266,7 @@ impl RecordHandler for Engine { return Ok(None); } - if record.record().content_type == ContentType::Alert { + if record.record().content_type == ContentType::ALERT { if epoch == 0 { if self.peer_encryption_enabled { // Post-handshake: epoch 0 alerts are unauthenticated, discard. @@ -1318,7 +1318,7 @@ impl RecordHandler for Engine { } if self.close_notify_received - && record.record().content_type == ContentType::ApplicationData + && record.record().content_type == ContentType::APPLICATION_DATA { self.push_buffer(record.into_buffer()); return Ok(None); @@ -1347,7 +1347,7 @@ impl RecordHandler for Engine { // that's known, a stale plaintext handshake (unauthenticated, replayable) // must no longer drive a courtesy flight retransmission. The client // confirms separately at its own completion (flight_stop_resend_timers). - if content_type == ContentType::ApplicationData { + if content_type == ContentType::APPLICATION_DATA { self.peer_handshake_confirmed = true; } } diff --git a/src/dtls12/incoming.rs b/src/dtls12/incoming.rs index ca21b6c0..0e2c4bc8 100644 --- a/src/dtls12/incoming.rs +++ b/src/dtls12/incoming.rs @@ -279,7 +279,7 @@ impl ParsedRecord { ) -> Result { let (_, record) = DTLSRecord::parse(input, 0, offset)?; - let handshakes = if record.content_type == ContentType::Handshake { + let handshakes = if record.content_type == ContentType::HANDSHAKE { // This will also return None on the encrypted Finished after ChangeCipherSpec. // However we will then decrypt and try again. let fragment_offset = record.fragment_range.start; @@ -409,7 +409,7 @@ mod tests { impl RecordHandler for TestHandler { fn classify_record(&mut self, record: Record) -> Result, Error> { self.classify_calls += 1; - if record.record().content_type == ContentType::Alert { + if record.record().content_type == ContentType::ALERT { self.dropped_alerts += 1; return Ok(None); } @@ -464,9 +464,9 @@ mod tests { #[test] fn parse_packet_filters_control_records_after_packet_validation() { let mut packet = Vec::new(); - packet.extend_from_slice(&build_record(ContentType::Alert, 0, 1, &[0x01, 0x00])); + packet.extend_from_slice(&build_record(ContentType::ALERT, 0, 1, &[0x01, 0x00])); packet.extend_from_slice(&build_record( - ContentType::ApplicationData, + ContentType::APPLICATION_DATA, 1, 2, &[0xAA, 0xBB], @@ -482,7 +482,7 @@ mod tests { assert_eq!(incoming.records().len(), 1); assert_eq!( incoming.first().record().content_type, - ContentType::ApplicationData + ContentType::APPLICATION_DATA ); assert_eq!(incoming.first().record().sequence.epoch, 1); } diff --git a/src/dtls12/message/record.rs b/src/dtls12/message/record.rs index 896e0af3..5e74c624 100644 --- a/src/dtls12/message/record.rs +++ b/src/dtls12/message/record.rs @@ -70,7 +70,7 @@ impl DTLSRecord { // the epoch-0 content types this implementation supports. if epoch == 0 { match content_type { - ContentType::ChangeCipherSpec | ContentType::Alert | ContentType::Handshake => {} + ContentType::CHANGE_CIPHER_SPEC | ContentType::ALERT | ContentType::HANDSHAKE => {} _ => { return Err(Err::Failure(nom::error::Error::new( input, @@ -156,7 +156,7 @@ mod tests { use crate::buffer::Buf; const RECORD: &[u8] = &[ - 0x16, // ContentType::Handshake + 0x16, // ContentType::HANDSHAKE 0xFE, 0xFD, // ProtocolVersion::DTLS1_2 0x00, 0x01, // epoch 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // sequence_number diff --git a/src/dtls12/queue.rs b/src/dtls12/queue.rs index 91994c01..8d1dddd2 100644 --- a/src/dtls12/queue.rs +++ b/src/dtls12/queue.rs @@ -50,11 +50,11 @@ impl fmt::Debug for QueueRx { for item in &self.0 { let record = item.first().record(); match record.content_type { - ContentType::Handshake => handshake += 1, - ContentType::ApplicationData => app_data += 1, - ContentType::Alert => alert += 1, - ContentType::ChangeCipherSpec => ccs += 1, - ContentType::Unknown(_) | ContentType::Ack => other += 1, + ContentType::HANDSHAKE => handshake += 1, + ContentType::APPLICATION_DATA => app_data += 1, + ContentType::ALERT => alert += 1, + ContentType::CHANGE_CIPHER_SPEC => ccs += 1, + _ => other += 1, } let seq = (record.sequence.epoch, record.sequence.sequence_number); diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index 48e22f1a..276e4f57 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -232,7 +232,7 @@ impl Server { // Use the engine's create_record to send application data // The encryption is now handled in the engine self.engine - .create_record(ContentType::ApplicationData, 1, false, |body| { + .create_record(ContentType::APPLICATION_DATA, 1, false, |body| { body.extend_from_slice(data); })?; @@ -250,7 +250,7 @@ impl Server { return Ok(()); } self.engine - .create_record(ContentType::Alert, 1, false, |body| { + .create_record(ContentType::ALERT, 1, false, |body| { body.push(1); // level: warning body.push(0); // description: close_notify })?; @@ -340,7 +340,7 @@ impl State { } // Enforce Null compression only (client must offer it) - let has_null = ch.compression_methods.contains(&CompressionMethod::Null); + let has_null = ch.compression_methods.contains(&CompressionMethod::NULL); if !has_null { return Err( Error::SecurityError(crate::SecurityError::UnsupportedClientCompression).into(), @@ -948,7 +948,7 @@ impl State { } fn await_change_cipher_spec(self, server: &mut Server) -> Result { - let maybe = server.engine.next_record(ContentType::ChangeCipherSpec); + let maybe = server.engine.next_record(ContentType::CHANGE_CIPHER_SPEC); let Some(_) = maybe else { // Stay in same state @@ -1048,7 +1048,7 @@ impl State { // Send ChangeCipherSpec server .engine - .create_record(ContentType::ChangeCipherSpec, 0, true, |body| { + .create_record(ContentType::CHANGE_CIPHER_SPEC, 0, true, |body| { body.push(1); })?; @@ -1118,7 +1118,7 @@ impl State { server.engine.discard_pending_writes(); server .engine - .create_record(ContentType::Alert, 1, false, |body| { + .create_record(ContentType::ALERT, 1, false, |body| { body.push(1); // level: warning body.push(0); // description: close_notify })?; @@ -1134,7 +1134,7 @@ impl State { for data in server.queued_data.drain(..) { server .engine - .create_record(ContentType::ApplicationData, 1, false, |body| { + .create_record(ContentType::APPLICATION_DATA, 1, false, |body| { body.extend_from_slice(&data); })?; } @@ -1210,7 +1210,7 @@ fn handshake_create_server_hello( random, session_id, cs, - CompressionMethod::Null, + CompressionMethod::NULL, None, ) .with_extensions(extension_data, srtp_pid); diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index 66bc40ad..49e8b733 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -267,7 +267,7 @@ impl Client { let epoch = self.engine.app_send_epoch(); self.engine.create_ciphertext_record( - ContentType::ApplicationData, + ContentType::APPLICATION_DATA, epoch, false, |body| { @@ -290,7 +290,7 @@ impl Client { } let epoch = self.engine.app_send_epoch(); self.engine - .create_ciphertext_record(ContentType::Alert, epoch, false, |body| { + .create_ciphertext_record(ContentType::ALERT, epoch, false, |body| { body.push(1); // level: legacy (ignored in DTLS 1.3) body.push(0); // description: close_notify })?; @@ -1090,7 +1090,7 @@ impl State { ); for data in client.queued_data.drain(..) { client.engine.create_ciphertext_record( - ContentType::ApplicationData, + ContentType::APPLICATION_DATA, epoch, false, |body| { @@ -1603,7 +1603,7 @@ mod tests { fn epoch0_handshake_packet(msg_type: MessageType, message_seq: u16, body: &[u8]) -> Vec { let handshake_len = 12 + body.len(); let mut packet = Vec::new(); - packet.push(ContentType::Handshake.as_u8()); + packet.push(ContentType::HANDSHAKE.as_u8()); packet.extend_from_slice(&[0xfe, 0xfd]); packet.extend_from_slice(&0u16.to_be_bytes()); packet.extend_from_slice(&0u64.to_be_bytes()[2..]); diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index d67ccc34..f1d36fa0 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -558,7 +558,7 @@ impl Engine { .queue_rx .iter() .flat_map(|i| i.records().iter()) - .filter(|r| r.record().content_type == ContentType::ApplicationData) + .filter(|r| r.record().content_type == ContentType::APPLICATION_DATA) .skip_while(|r| r.is_handled()); let Some(next) = unhandled.next() else { @@ -1066,7 +1066,7 @@ impl Engine { // Build the record for serialization let record = Dtls13Record { - content_type: ContentType::ApplicationData, + content_type: ContentType::APPLICATION_DATA, sequence: Sequence { epoch, sequence_number: seq, @@ -1216,11 +1216,11 @@ impl Engine { }; if epoch == 0 { - self.create_plaintext_record(ContentType::Handshake, true, |fragment| { + self.create_plaintext_record(ContentType::HANDSHAKE, true, |fragment| { frag_handshake.serialize(&body_buffer, fragment); })?; } else { - self.create_ciphertext_record(ContentType::Handshake, epoch, true, |fragment| { + self.create_ciphertext_record(ContentType::HANDSHAKE, epoch, true, |fragment| { frag_handshake.serialize(&body_buffer, fragment); })?; } @@ -1288,7 +1288,7 @@ impl Engine { 2 }; - self.create_ciphertext_record(ContentType::Ack, epoch, false, |fragment| { + self.create_ciphertext_record(ContentType::ACK, epoch, false, |fragment| { // record_numbers_length: 2 bytes, value = entries.len() * 16 let len = (entries.len() * 16) as u16; fragment.extend_from_slice(&len.to_be_bytes()); @@ -1515,7 +1515,7 @@ impl Engine { for incoming in self.queue_rx.iter() { for r in incoming.records().iter() { if r.record().sequence.epoch == 2 - && r.record().content_type == ContentType::Handshake + && r.record().content_type == ContentType::HANDSHAKE { let seq = r.record().sequence; let _ = record_numbers.try_push((seq.epoch as u64, seq.sequence_number)); @@ -1538,7 +1538,7 @@ impl Engine { return Ok(()); } - self.create_ciphertext_record(ContentType::Ack, 2, false, |fragment| { + self.create_ciphertext_record(ContentType::ACK, 2, false, |fragment| { let len = (record_numbers.len() * 16) as u16; fragment.extend_from_slice(&len.to_be_bytes()); for &(epoch, seq) in record_numbers { @@ -1890,7 +1890,7 @@ impl Engine { let epoch = self.app_send_epoch; // Build the handshake message manually (12-byte DTLS header + 1-byte body) - self.create_ciphertext_record(ContentType::Handshake, epoch, true, |fragment| { + self.create_ciphertext_record(ContentType::HANDSHAKE, epoch, true, |fragment| { // DTLS handshake header (12 bytes): // msg_type(1) + length(3) + message_seq(2) + fragment_offset(3) + fragment_length(3) fragment.push(MessageType::KeyUpdate.as_u8()); @@ -2294,7 +2294,7 @@ impl RecordHandler for Engine { && self.peer_encryption_enabled && matches!( record.record().content_type, - ContentType::Ack | ContentType::Alert + ContentType::ACK | ContentType::ALERT ) { // Plaintext ACKs and alerts after peer encryption is enabled are @@ -2304,13 +2304,13 @@ impl RecordHandler for Engine { } match record.record().content_type { - ContentType::Ack => { + ContentType::ACK => { let fragment = record.record().fragment(record.buffer()); self.process_ack(fragment); self.push_buffer(record.into_buffer()); Ok(None) } - ContentType::Alert => { + ContentType::ALERT => { // RFC 8446 §6: TLS 1.3 ignores the AlertLevel byte; severity is // implicit in the description (only close_notify and user_canceled // are non-fatal). @@ -2335,7 +2335,7 @@ impl RecordHandler for Engine { None => Ok(None), } } - ContentType::ChangeCipherSpec => { + ContentType::CHANGE_CIPHER_SPEC => { trace!("Discarding CCS record"); self.push_buffer(record.into_buffer()); Ok(None) @@ -2564,7 +2564,7 @@ mod tests { fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); fragment.push(KeyUpdateRequest::UpdateRequested.as_u8()); - fragment.push(ContentType::Handshake.as_u8()); + fragment.push(ContentType::HANDSHAKE.as_u8()); let mut packet = Vec::new(); packet.push( @@ -2602,7 +2602,7 @@ mod tests { // Set epoch-0 sequence to MAX — the next increment should be rejected engine.sequence_epoch_0.sequence_number = MAX_SEQUENCE_NUMBER; - let result = engine.create_plaintext_record(ContentType::Handshake, false, |buf| { + let result = engine.create_plaintext_record(ContentType::HANDSHAKE, false, |buf| { buf.extend_from_slice(b"test") }); assert!( @@ -2736,7 +2736,7 @@ mod tests { fn malformed_ack_record_number_vector_is_ignored() { let mut engine = test_engine(); engine.flight_saved_records.push(Entry { - content_type: ContentType::Handshake, + content_type: ContentType::HANDSHAKE, epoch: 2, send_seq: 7, fragment: Buf::new(), diff --git a/src/dtls13/incoming.rs b/src/dtls13/incoming.rs index f617a416..c1b42e1f 100644 --- a/src/dtls13/incoming.rs +++ b/src/dtls13/incoming.rs @@ -363,7 +363,7 @@ impl ParsedRecord { ) -> Result { let (_, record) = Dtls13Record::parse(input, 0)?; - let handshakes = if record.content_type == ContentType::Handshake { + let handshakes = if record.content_type == ContentType::HANDSHAKE { let fragment_offset = record.fragment_range.start; parse_handshakes(record.fragment(input), fragment_offset, cipher_suite) } else { @@ -383,7 +383,7 @@ impl ParsedRecord { input: &[u8], cipher_suite: Option, ) -> ParsedRecord { - let handshakes = if record.content_type == ContentType::Handshake { + let handshakes = if record.content_type == ContentType::HANDSHAKE { let fragment_offset = record.fragment_range.start; parse_handshakes(record.fragment(input), fragment_offset, cipher_suite) } else { @@ -541,7 +541,7 @@ mod tests { impl RecordHandler for TestHandler { fn classify_record(&mut self, record: Record) -> Result, Error> { self.classify_calls += 1; - if record.record().content_type == ContentType::Ack { + if record.record().content_type == ContentType::ACK { self.dropped_acks += 1; return Ok(None); } @@ -617,7 +617,7 @@ mod tests { #[test] fn parse_packet_filters_control_records_after_packet_validation() { let mut packet = Vec::new(); - packet.extend_from_slice(&build_plaintext_record(ContentType::Ack, 1, &[0xAA, 0xBB])); + packet.extend_from_slice(&build_plaintext_record(ContentType::ACK, 1, &[0xAA, 0xBB])); packet.extend_from_slice(&build_ciphertext_record(2, 2, &[0x11, 0x22, 0x33])); let mut handler = TestHandler::default(); @@ -630,7 +630,7 @@ mod tests { assert_eq!(incoming.records().len(), 1); assert_eq!( incoming.first().record().content_type, - ContentType::ApplicationData + ContentType::APPLICATION_DATA ); assert_eq!(incoming.first().record().sequence.epoch, 2); } diff --git a/src/dtls13/message/record.rs b/src/dtls13/message/record.rs index cbbaf734..1a7ccebb 100644 --- a/src/dtls13/message/record.rs +++ b/src/dtls13/message/record.rs @@ -75,7 +75,7 @@ impl Dtls13Record { // RFC 9147 §4.1: Only alert(21), handshake(22), and ack(26) are valid // plaintext content types in DTLS 1.3. Reject all others. match content_type { - ContentType::Alert | ContentType::Handshake | ContentType::Ack => {} + ContentType::ALERT | ContentType::HANDSHAKE | ContentType::ACK => {} _ => { return Err(Err::Failure(nom::error::Error::new( input, @@ -190,7 +190,7 @@ impl Dtls13Record { Ok(( rest, Dtls13Record { - content_type: ContentType::ApplicationData, + content_type: ContentType::APPLICATION_DATA, sequence, length, fragment_range: start..end, diff --git a/src/dtls13/queue.rs b/src/dtls13/queue.rs index eeeea1c8..cfc07241 100644 --- a/src/dtls13/queue.rs +++ b/src/dtls13/queue.rs @@ -49,12 +49,10 @@ impl fmt::Debug for QueueRx { for item in &self.0 { let record = item.first().record(); match record.content_type { - ContentType::Handshake => handshake += 1, - ContentType::ApplicationData => app_data += 1, - ContentType::Alert => alert += 1, - ContentType::Unknown(_) | ContentType::ChangeCipherSpec | ContentType::Ack => { - other += 1 - } + ContentType::HANDSHAKE => handshake += 1, + ContentType::APPLICATION_DATA => app_data += 1, + ContentType::ALERT => alert += 1, + _ => other += 1, } let seq = (record.sequence.epoch, record.sequence.sequence_number); diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 3b7040e2..35caecaa 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -309,7 +309,7 @@ impl Server { let epoch = self.engine.app_send_epoch(); self.engine.create_ciphertext_record( - ContentType::ApplicationData, + ContentType::APPLICATION_DATA, epoch, false, |body| { @@ -332,7 +332,7 @@ impl Server { } let epoch = self.engine.app_send_epoch(); self.engine - .create_ciphertext_record(ContentType::Alert, epoch, false, |body| { + .create_ciphertext_record(ContentType::ALERT, epoch, false, |body| { body.push(1); // level: legacy (ignored in DTLS 1.3) body.push(0); // description: close_notify })?; @@ -1172,7 +1172,7 @@ impl State { ); for data in server.queued_data.drain(..) { server.engine.create_ciphertext_record( - ContentType::ApplicationData, + ContentType::APPLICATION_DATA, epoch, false, |body| { diff --git a/src/types.rs b/src/types.rs index 0dd3277c..42f4097b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -440,51 +440,42 @@ impl fmt::Debug for SignatureAlgorithm { /// /// Identifies the type of data in a DTLS record. These values are the same /// for both DTLS 1.2 and DTLS 1.3. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ContentType { +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct ContentType(u8); + +impl ContentType { /// Change Cipher Spec (used in DTLS 1.2, compatibility-only in 1.3). - ChangeCipherSpec, + pub const CHANGE_CIPHER_SPEC: Self = Self(20); /// Alert message. - Alert, + pub const ALERT: Self = Self(21); /// Handshake message. - Handshake, + pub const HANDSHAKE: Self = Self(22); /// Application data. - ApplicationData, + pub const APPLICATION_DATA: Self = Self(23); /// ACK (DTLS 1.3 only, RFC 9147 Section 7). - Ack, - /// Unknown content type. - Unknown(u8), -} + pub const ACK: Self = Self(26); -impl Default for ContentType { - fn default() -> Self { - Self::Unknown(0) - } -} - -impl ContentType { /// Convert a u8 value to a `ContentType`. - pub fn from_u8(value: u8) -> Self { - match value { - 20 => ContentType::ChangeCipherSpec, - 21 => ContentType::Alert, - 22 => ContentType::Handshake, - 23 => ContentType::ApplicationData, - 26 => ContentType::Ack, - _ => ContentType::Unknown(value), - } + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Convert this `ContentType` to its u8 value. - pub fn as_u8(&self) -> u8 { - match self { - ContentType::ChangeCipherSpec => 20, - ContentType::Alert => 21, - ContentType::Handshake => 22, - ContentType::ApplicationData => 23, - ContentType::Ack => 26, - ContentType::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + /// Returns true if this is not a known DTLS record content type. + pub const fn is_unknown(&self) -> bool { + !matches!( + *self, + ContentType::CHANGE_CIPHER_SPEC + | ContentType::ALERT + | ContentType::HANDSHAKE + | ContentType::APPLICATION_DATA + | ContentType::ACK + ) } /// Parse a `ContentType` from wire format. @@ -494,6 +485,19 @@ impl ContentType { } } +impl fmt::Debug for ContentType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + ContentType::CHANGE_CIPHER_SPEC => f.write_str("ChangeCipherSpec"), + ContentType::ALERT => f.write_str("Alert"), + ContentType::HANDSHAKE => f.write_str("Handshake"), + ContentType::APPLICATION_DATA => f.write_str("ApplicationData"), + ContentType::ACK => f.write_str("Ack"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + // ============================================================================ // Sequence Number // ============================================================================ @@ -1065,6 +1069,43 @@ mod tests { ); } + #[test] + fn content_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert!(ContentType::default().is_unknown()); + } + + #[test] + fn content_type_wire_roundtrip() { + let known = [ + (20, ContentType::CHANGE_CIPHER_SPEC), + (21, ContentType::ALERT), + (22, ContentType::HANDSHAKE), + (23, ContentType::APPLICATION_DATA), + (26, ContentType::ACK), + ]; + + for (wire, content_type) in known { + assert_eq!(ContentType::from_u8(wire), content_type); + assert_eq!(content_type.as_u8(), wire); + assert!(!content_type.is_unknown()); + } + + let unknown = ContentType::from_u8(24); + assert_eq!(unknown.as_u8(), 24); + assert!(unknown.is_unknown()); + } + + #[test] + fn content_type_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", ContentType::CHANGE_CIPHER_SPEC), + "ChangeCipherSpec" + ); + assert_eq!(format!("{:?}", ContentType::HANDSHAKE), "Handshake"); + assert_eq!(format!("{:?}", ContentType::from_u8(24)), "Unknown(24)"); + } + #[test] fn random_parse() { let data = [ diff --git a/tests/dtls12/retransmit.rs b/tests/dtls12/retransmit.rs index eb7a6445..a0a5b944 100644 --- a/tests/dtls12/retransmit.rs +++ b/tests/dtls12/retransmit.rs @@ -1130,7 +1130,7 @@ fn forged_epoch1_app_data() -> Vec { // garbage (undecryptable) body. The content type lives in the cleartext // header, so this looks like app data before anything is decrypted. let mut rec = Vec::new(); - rec.push(23); // ContentType::ApplicationData + rec.push(23); // ContentType::APPLICATION_DATA rec.extend_from_slice(&[0xFE, 0xFD]); // DTLS 1.2 rec.extend_from_slice(&1u16.to_be_bytes()); // epoch 1 rec.extend_from_slice(&[0, 0, 0, 0, 0, 1]); // 48-bit sequence number From 98ce7c7b74a87ea0e3a20939773299cd3d11f4a3 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 18:48:28 +0300 Subject: [PATCH 05/18] types: make named group a newtype --- src/config.rs | 8 +- src/crypto/aws_lc_rs/kx_group.rs | 16 +- src/crypto/aws_lc_rs/sign.rs | 14 +- src/crypto/provider.rs | 16 +- src/crypto/rust_crypto/kx_group.rs | 20 +- src/crypto/rust_crypto/sign.rs | 18 +- src/crypto/validation/mod.rs | 2 +- src/dtls12/message/client_key_exchange.rs | 2 +- .../message/extensions/supported_groups.rs | 6 +- src/dtls12/message/server_hello.rs | 6 +- src/dtls12/server.rs | 18 +- src/dtls13/client.rs | 4 +- src/dtls13/message/extensions/key_share.rs | 2 +- .../message/extensions/supported_groups.rs | 6 +- src/dtls13/message/server_hello.rs | 6 +- src/types.rs | 293 ++++++++++-------- tests/dtls12/handshake.rs | 2 +- tests/dtls13/handshake.rs | 16 +- 18 files changed, 246 insertions(+), 209 deletions(-) diff --git a/src/config.rs b/src/config.rs index b84dc122..b8a52a92 100644 --- a/src/config.rs +++ b/src/config.rs @@ -765,11 +765,11 @@ mod tests { #[test] fn filter_kx_groups() { let config = Config::builder() - .kx_groups(&[NamedGroup::Secp256r1]) + .kx_groups(&[NamedGroup::SECP256R1]) .build() .expect("should accept single kx group"); let groups: Vec<_> = config.kx_groups().map(|g| g.name()).collect(); - assert_eq!(groups, &[NamedGroup::Secp256r1]); + assert_eq!(groups, &[NamedGroup::SECP256R1]); } #[test] @@ -1003,7 +1003,7 @@ mod tests { .with_crypto_provider(aws_lc_rs::default_provider()) .dtls12_cipher_suites(&[Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384]) .dtls13_cipher_suites(&[Dtls13CipherSuite::AES_128_GCM_SHA256]) - .kx_groups(&[NamedGroup::X25519, NamedGroup::Secp256r1]) + .kx_groups(&[NamedGroup::X25519, NamedGroup::SECP256R1]) .build() .expect("should accept filtered config with explicit provider"); let suites12: Vec<_> = config.dtls12_cipher_suites().map(|cs| cs.suite()).collect(); @@ -1014,7 +1014,7 @@ mod tests { let suites13: Vec<_> = config.dtls13_cipher_suites().map(|cs| cs.suite()).collect(); assert_eq!(suites13, &[Dtls13CipherSuite::AES_128_GCM_SHA256]); let groups: Vec<_> = config.kx_groups().map(|g| g.name()).collect(); - assert_eq!(groups, &[NamedGroup::X25519, NamedGroup::Secp256r1]); + assert_eq!(groups, &[NamedGroup::X25519, NamedGroup::SECP256R1]); } } } diff --git a/src/crypto/aws_lc_rs/kx_group.rs b/src/crypto/aws_lc_rs/kx_group.rs index 1a79cce9..88276895 100644 --- a/src/crypto/aws_lc_rs/kx_group.rs +++ b/src/crypto/aws_lc_rs/kx_group.rs @@ -28,8 +28,8 @@ impl EcdhKeyExchange { fn new(group: NamedGroup, mut buf: Buf) -> Result { let algorithm = match group { NamedGroup::X25519 => &X25519, - NamedGroup::Secp256r1 => &ECDH_P256, - NamedGroup::Secp384r1 => &ECDH_P384, + NamedGroup::SECP256R1 => &ECDH_P256, + NamedGroup::SECP384R1 => &ECDH_P384, _ => return Err(CryptoError::UnsupportedKeyExchangeGroup(group)), }; @@ -54,8 +54,8 @@ impl EcdhKeyExchange { fn algorithm(&self) -> &'static aws_lc_rs::agreement::Algorithm { match self.group { NamedGroup::X25519 => &X25519, - NamedGroup::Secp256r1 => &ECDH_P256, - NamedGroup::Secp384r1 => &ECDH_P384, + NamedGroup::SECP256R1 => &ECDH_P256, + NamedGroup::SECP384R1 => &ECDH_P384, _ => unreachable!("Unsupported group"), } } @@ -109,11 +109,11 @@ struct P256; impl SupportedKxGroup for P256 { fn name(&self) -> NamedGroup { - NamedGroup::Secp256r1 + NamedGroup::SECP256R1 } fn start_exchange(&self, buf: Buf) -> Result, CryptoError> { - Ok(Box::new(EcdhKeyExchange::new(NamedGroup::Secp256r1, buf)?)) + Ok(Box::new(EcdhKeyExchange::new(NamedGroup::SECP256R1, buf)?)) } } @@ -123,11 +123,11 @@ struct P384; impl SupportedKxGroup for P384 { fn name(&self) -> NamedGroup { - NamedGroup::Secp384r1 + NamedGroup::SECP384R1 } fn start_exchange(&self, buf: Buf) -> Result, CryptoError> { - Ok(Box::new(EcdhKeyExchange::new(NamedGroup::Secp384r1, buf)?)) + Ok(Box::new(EcdhKeyExchange::new(NamedGroup::SECP384R1, buf)?)) } } diff --git a/src/crypto/aws_lc_rs/sign.rs b/src/crypto/aws_lc_rs/sign.rs index 614202d5..f6e27c5b 100644 --- a/src/crypto/aws_lc_rs/sign.rs +++ b/src/crypto/aws_lc_rs/sign.rs @@ -221,18 +221,18 @@ impl SignatureVerifier for AwsLcSignatureVerifier { .map_err(|_| CryptoError::InvalidEcCurveParameter)?; let group = match curve_oid { - OID_P256 => NamedGroup::Secp256r1, - OID_P384 => NamedGroup::Secp384r1, + OID_P256 => NamedGroup::SECP256R1, + OID_P384 => NamedGroup::SECP384R1, _ => return Err(CryptoError::UnsupportedEcCurve), }; check_verify_scheme(sig_alg, hash_alg, group)?; let algorithm: &EcdsaVerificationAlgorithm = match (group, hash_alg) { - (NamedGroup::Secp256r1, HashAlgorithm::SHA256) => &ECDSA_P256_SHA256_ASN1, - (NamedGroup::Secp256r1, HashAlgorithm::SHA384) => &ECDSA_P256_SHA384_ASN1, - (NamedGroup::Secp384r1, HashAlgorithm::SHA256) => &ECDSA_P384_SHA256_ASN1, - (NamedGroup::Secp384r1, HashAlgorithm::SHA384) => &ECDSA_P384_SHA384_ASN1, + (NamedGroup::SECP256R1, HashAlgorithm::SHA256) => &ECDSA_P256_SHA256_ASN1, + (NamedGroup::SECP256R1, HashAlgorithm::SHA384) => &ECDSA_P256_SHA384_ASN1, + (NamedGroup::SECP384R1, HashAlgorithm::SHA256) => &ECDSA_P384_SHA256_ASN1, + (NamedGroup::SECP384R1, HashAlgorithm::SHA384) => &ECDSA_P384_SHA384_ASN1, // unreachable: check_verify_scheme already validated _ => unreachable!(), }; @@ -288,7 +288,7 @@ mod tests { CryptoError::SignatureVerificationFailed { signature: SignatureAlgorithm::ECDSA, hash: HashAlgorithm::SHA256, - group: NamedGroup::Secp256r1, + group: NamedGroup::SECP256R1, } ); } diff --git a/src/crypto/provider.rs b/src/crypto/provider.rs index a4f0af08..bde0a0f5 100644 --- a/src/crypto/provider.rs +++ b/src/crypto/provider.rs @@ -319,22 +319,22 @@ const SUPPORTED_VERIFY_SCHEMES: &[(SignatureAlgorithm, HashAlgorithm, NamedGroup ( SignatureAlgorithm::ECDSA, HashAlgorithm::SHA256, - NamedGroup::Secp256r1, + NamedGroup::SECP256R1, ), ( SignatureAlgorithm::ECDSA, HashAlgorithm::SHA256, - NamedGroup::Secp384r1, + NamedGroup::SECP384R1, ), ( SignatureAlgorithm::ECDSA, HashAlgorithm::SHA384, - NamedGroup::Secp256r1, + NamedGroup::SECP256R1, ), ( SignatureAlgorithm::ECDSA, HashAlgorithm::SHA384, - NamedGroup::Secp384r1, + NamedGroup::SECP384R1, ), ]; @@ -380,8 +380,8 @@ pub fn cert_named_group(cert_der: &[u8]) -> Result .map_err(|_| CertificateError::InvalidEcCurveParameter)?; match curve_oid { - OID_P256 => Ok(NamedGroup::Secp256r1), - OID_P384 => Ok(NamedGroup::Secp384r1), + OID_P256 => Ok(NamedGroup::SECP256R1), + OID_P384 => Ok(NamedGroup::SECP384R1), _ => Err(CertificateError::UnsupportedEcCurve), } } @@ -640,7 +640,7 @@ mod tests { let cert = params.self_signed(&key_pair).unwrap(); let group = cert_named_group(cert.der()).unwrap(); - assert_eq!(group, NamedGroup::Secp256r1); + assert_eq!(group, NamedGroup::SECP256R1); } #[test] @@ -653,7 +653,7 @@ mod tests { let cert = params.self_signed(&key_pair).unwrap(); let group = cert_named_group(cert.der()).unwrap(); - assert_eq!(group, NamedGroup::Secp384r1); + assert_eq!(group, NamedGroup::SECP384R1); } #[test] diff --git a/src/crypto/rust_crypto/kx_group.rs b/src/crypto/rust_crypto/kx_group.rs index 6593b0dc..0d0d44d3 100644 --- a/src/crypto/rust_crypto/kx_group.rs +++ b/src/crypto/rust_crypto/kx_group.rs @@ -57,7 +57,7 @@ impl EcdhKeyExchange { public_key: buf, }) } - NamedGroup::Secp256r1 => { + NamedGroup::SECP256R1 => { use rand_core::OsRng; let secret = EphemeralSecret::random(&mut OsRng); let public_key_obj = P256PublicKey::from(&secret); @@ -69,7 +69,7 @@ impl EcdhKeyExchange { public_key: buf, }) } - NamedGroup::Secp384r1 => { + NamedGroup::SECP384R1 => { use rand_core::OsRng; let secret = P384EphemeralSecret::random(&mut OsRng); let public_key_obj = P384PublicKey::from(&secret); @@ -113,7 +113,7 @@ impl ActiveKeyExchange for EcdhKeyExchange { } EcdhKeyExchange::P256 { secret, .. } => { let peer_key = P256PublicKey::from_sec1_bytes(peer_pub) - .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::Secp256r1))?; + .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::SECP256R1))?; let shared_secret = secret.diffie_hellman(&peer_key); out.clear(); out.extend_from_slice(shared_secret.raw_secret_bytes().as_slice()); @@ -121,7 +121,7 @@ impl ActiveKeyExchange for EcdhKeyExchange { } EcdhKeyExchange::P384 { secret, .. } => { let peer_key = P384PublicKey::from_sec1_bytes(peer_pub) - .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::Secp384r1))?; + .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::SECP384R1))?; let shared_secret = secret.diffie_hellman(&peer_key); out.clear(); out.extend_from_slice(shared_secret.raw_secret_bytes().as_slice()); @@ -133,8 +133,8 @@ impl ActiveKeyExchange for EcdhKeyExchange { fn group(&self) -> NamedGroup { match self { EcdhKeyExchange::X25519 { .. } => NamedGroup::X25519, - EcdhKeyExchange::P256 { .. } => NamedGroup::Secp256r1, - EcdhKeyExchange::P384 { .. } => NamedGroup::Secp384r1, + EcdhKeyExchange::P256 { .. } => NamedGroup::SECP256R1, + EcdhKeyExchange::P384 { .. } => NamedGroup::SECP384R1, } } } @@ -159,11 +159,11 @@ struct P256; impl SupportedKxGroup for P256 { fn name(&self) -> NamedGroup { - NamedGroup::Secp256r1 + NamedGroup::SECP256R1 } fn start_exchange(&self, buf: Buf) -> Result, CryptoError> { - Ok(Box::new(EcdhKeyExchange::new(NamedGroup::Secp256r1, buf)?)) + Ok(Box::new(EcdhKeyExchange::new(NamedGroup::SECP256R1, buf)?)) } } @@ -173,11 +173,11 @@ struct P384; impl SupportedKxGroup for P384 { fn name(&self) -> NamedGroup { - NamedGroup::Secp384r1 + NamedGroup::SECP384R1 } fn start_exchange(&self, buf: Buf) -> Result, CryptoError> { - Ok(Box::new(EcdhKeyExchange::new(NamedGroup::Secp384r1, buf)?)) + Ok(Box::new(EcdhKeyExchange::new(NamedGroup::SECP384R1, buf)?)) } } diff --git a/src/crypto/rust_crypto/sign.rs b/src/crypto/rust_crypto/sign.rs index 1bf80084..fa56a723 100644 --- a/src/crypto/rust_crypto/sign.rs +++ b/src/crypto/rust_crypto/sign.rs @@ -55,7 +55,7 @@ impl SigningKeyTrait for EcdsaSigningKey { } _ => { return Err(CryptoError::SigningKeyUnsupportedHash { - group: NamedGroup::Secp256r1, + group: NamedGroup::SECP256R1, hash: hash_alg, }); } @@ -77,7 +77,7 @@ impl SigningKeyTrait for EcdsaSigningKey { } _ => { return Err(CryptoError::SigningKeyUnsupportedHash { - group: NamedGroup::Secp384r1, + group: NamedGroup::SECP384R1, hash: hash_alg, }); } @@ -231,8 +231,8 @@ impl SignatureVerifier for RustCryptoSignatureVerifier { .map_err(|_| CryptoError::InvalidEcCurveParameter)?; let group = match curve_oid { - OID_P256 => NamedGroup::Secp256r1, - OID_P384 => NamedGroup::Secp384r1, + OID_P256 => NamedGroup::SECP256R1, + OID_P384 => NamedGroup::SECP384R1, _ => return Err(CryptoError::UnsupportedEcCurve), }; @@ -250,9 +250,9 @@ impl SignatureVerifier for RustCryptoSignatureVerifier { }; match group { - NamedGroup::Secp256r1 => { + NamedGroup::SECP256R1 => { let verifying_key = VerifyingKey::::from_sec1_bytes(pubkey_bytes) - .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::Secp256r1))?; + .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::SECP256R1))?; let sig = Signature::::from_der(signature) .map_err(|_| CryptoError::InvalidSignatureFormat)?; verifying_key.verify_prehash(&hash, &sig).map_err(|_| { @@ -263,9 +263,9 @@ impl SignatureVerifier for RustCryptoSignatureVerifier { } }) } - NamedGroup::Secp384r1 => { + NamedGroup::SECP384R1 => { let verifying_key = VerifyingKey::::from_sec1_bytes(pubkey_bytes) - .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::Secp384r1))?; + .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::SECP384R1))?; let sig = Signature::::from_der(signature) .map_err(|_| CryptoError::InvalidSignatureFormat)?; verifying_key.verify_prehash(&hash, &sig).map_err(|_| { @@ -322,7 +322,7 @@ mod tests { CryptoError::SignatureVerificationFailed { signature: SignatureAlgorithm::ECDSA, hash: HashAlgorithm::SHA256, - group: NamedGroup::Secp256r1, + group: NamedGroup::SECP256R1, } ); } diff --git a/src/crypto/validation/mod.rs b/src/crypto/validation/mod.rs index 3b19acce..2ce2fc3b 100644 --- a/src/crypto/validation/mod.rs +++ b/src/crypto/validation/mod.rs @@ -40,7 +40,7 @@ impl CryptoProvider { self.kx_groups.iter().copied().filter(|kx| { matches!( kx.name(), - NamedGroup::X25519 | NamedGroup::Secp256r1 | NamedGroup::Secp384r1 + NamedGroup::X25519 | NamedGroup::SECP256R1 | NamedGroup::SECP384R1 ) }) } diff --git a/src/dtls12/message/client_key_exchange.rs b/src/dtls12/message/client_key_exchange.rs index 522a6d59..26f68907 100644 --- a/src/dtls12/message/client_key_exchange.rs +++ b/src/dtls12/message/client_key_exchange.rs @@ -43,7 +43,7 @@ impl ClientEcdhKeys { // In ClientKeyExchange, we don't include curve_type and named_group // since they're already established during ServerKeyExchange curve_type: CurveType::NamedCurve, // Default - named_group: NamedGroup::Secp256r1, // Default + named_group: NamedGroup::SECP256R1, // Default public_key_range: start..end, }, )) diff --git a/src/dtls12/message/extensions/supported_groups.rs b/src/dtls12/message/extensions/supported_groups.rs index 12b9fb44..09d2a984 100644 --- a/src/dtls12/message/extensions/supported_groups.rs +++ b/src/dtls12/message/extensions/supported_groups.rs @@ -61,7 +61,7 @@ mod tests { fn test_supported_groups_extension() { let mut groups = NamedGroupVec::new(); groups.push(NamedGroup::X25519); - groups.push(NamedGroup::Secp256r1); + groups.push(NamedGroup::SECP256R1); let ext = SupportedGroupsExtension { groups }; @@ -98,8 +98,8 @@ mod tests { parsed.groups.as_slice(), &[ NamedGroup::X25519, - NamedGroup::Secp256r1, - NamedGroup::Secp384r1 + NamedGroup::SECP256R1, + NamedGroup::SECP384R1 ] ); } diff --git a/src/dtls12/message/server_hello.rs b/src/dtls12/message/server_hello.rs index b6496547..9f1cd980 100644 --- a/src/dtls12/message/server_hello.rs +++ b/src/dtls12/message/server_hello.rs @@ -205,9 +205,9 @@ mod test { 0x00, 0x0A, // ExtensionType::SupportedGroups 0x00, 0x08, // Extension data length (8 bytes) 0x00, 0x06, // Extension data - 0x00, 0x17, // NamedGroup::Secp256r1 - 0x00, 0x18, // NamedGroup::Secp384r1 - 0x00, 0x19, // NamedGroup::Secp521r1 + 0x00, 0x17, // NamedGroup::SECP256R1 + 0x00, 0x18, // NamedGroup::SECP384R1 + 0x00, 0x19, // NamedGroup::SECP521R1 ]; #[test] diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index 276e4f57..15681bbc 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -1418,11 +1418,11 @@ mod tests { #[test] fn select_named_group_prefers_x25519_when_available() { let client = named_group_vec(&[ - NamedGroup::Secp256r1, + NamedGroup::SECP256R1, NamedGroup::X25519, - NamedGroup::Secp384r1, + NamedGroup::SECP384R1, ]); - let provider = [NamedGroup::X25519, NamedGroup::Secp256r1]; + let provider = [NamedGroup::X25519, NamedGroup::SECP256R1]; let selected = select_named_group(Some(&client), &provider); @@ -1431,27 +1431,27 @@ mod tests { #[test] fn select_named_group_respects_provider_capabilities() { - let client = named_group_vec(&[NamedGroup::X25519, NamedGroup::Secp256r1]); - let provider = [NamedGroup::Secp256r1]; + let client = named_group_vec(&[NamedGroup::X25519, NamedGroup::SECP256R1]); + let provider = [NamedGroup::SECP256R1]; let selected = select_named_group(Some(&client), &provider); - assert_eq!(selected, Some(NamedGroup::Secp256r1)); + assert_eq!(selected, Some(NamedGroup::SECP256R1)); } #[test] fn select_named_group_falls_back_to_provider_when_client_missing() { - let provider = [NamedGroup::Secp384r1]; + let provider = [NamedGroup::SECP384R1]; let selected = select_named_group(None, &provider); - assert_eq!(selected, Some(NamedGroup::Secp384r1)); + assert_eq!(selected, Some(NamedGroup::SECP384R1)); } #[test] fn select_named_group_rejects_when_client_has_no_overlap() { let client = named_group_vec(&[NamedGroup::X25519]); - let provider = [NamedGroup::Secp256r1]; + let provider = [NamedGroup::SECP256R1]; let selected = select_named_group(Some(&client), &provider); diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index 49e8b733..cf30cdef 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -1679,7 +1679,7 @@ mod tests { .parse_packet(&epoch0_handshake_packet( MessageType::ServerHello, 0, - &server_hello_with_key_share(NamedGroup::Secp256r1), + &server_hello_with_key_share(NamedGroup::SECP256R1), )) .expect("queue mismatched ServerHello"); @@ -1692,7 +1692,7 @@ mod tests { crate::InternalError::Fatal(Error::SecurityError( SecurityError::ServerKeyShareGroupMismatch { expected: NamedGroup::X25519, - actual: NamedGroup::Secp256r1, + actual: NamedGroup::SECP256R1, } )) )); diff --git a/src/dtls13/message/extensions/key_share.rs b/src/dtls13/message/extensions/key_share.rs index 04256dce..3d8b8cf8 100644 --- a/src/dtls13/message/extensions/key_share.rs +++ b/src/dtls13/message/extensions/key_share.rs @@ -180,7 +180,7 @@ mod tests { #[test] fn key_share_hrr_roundtrip() { let message: &[u8] = &[ - 0x00, 0x17, // NamedGroup::Secp256r1 + 0x00, 0x17, // NamedGroup::SECP256R1 ]; let (rest, parsed) = KeyShareHelloRetryRequest::parse(message).unwrap(); diff --git a/src/dtls13/message/extensions/supported_groups.rs b/src/dtls13/message/extensions/supported_groups.rs index 2bcdff1c..7b4d712c 100644 --- a/src/dtls13/message/extensions/supported_groups.rs +++ b/src/dtls13/message/extensions/supported_groups.rs @@ -62,7 +62,7 @@ mod tests { fn test_supported_groups_extension() { let mut groups = ArrayVec::new(); groups.push(NamedGroup::X25519); - groups.push(NamedGroup::Secp256r1); + groups.push(NamedGroup::SECP256R1); let ext = SupportedGroupsExtension { groups }; @@ -94,8 +94,8 @@ mod tests { parsed.groups.as_slice(), &[ NamedGroup::X25519, - NamedGroup::Secp256r1, - NamedGroup::Secp384r1 + NamedGroup::SECP256R1, + NamedGroup::SECP384R1 ] ); } diff --git a/src/dtls13/message/server_hello.rs b/src/dtls13/message/server_hello.rs index 24d796dc..2886ca8d 100644 --- a/src/dtls13/message/server_hello.rs +++ b/src/dtls13/message/server_hello.rs @@ -161,9 +161,9 @@ mod test { 0x00, 0x0A, // ExtensionType::SupportedGroups 0x00, 0x08, // Extension data length (8 bytes) 0x00, 0x06, // Extension data - 0x00, 0x17, // NamedGroup::Secp256r1 - 0x00, 0x18, // NamedGroup::Secp384r1 - 0x00, 0x19, // NamedGroup::Secp521r1 + 0x00, 0x17, // NamedGroup::SECP256R1 + 0x00, 0x18, // NamedGroup::SECP384R1 + 0x00, 0x19, // NamedGroup::SECP521R1 ]; #[test] diff --git a/src/types.rs b/src/types.rs index 42f4097b..ffad1604 100644 --- a/src/types.rs +++ b/src/types.rs @@ -81,134 +81,108 @@ impl Random { /// /// Used for Elliptic Curve Diffie-Hellman Ephemeral (ECDHE) key exchange. /// The same named groups are used in both DTLS 1.2 and DTLS 1.3. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[non_exhaustive] -pub enum NamedGroup { +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct NamedGroup(u16); + +impl NamedGroup { /// sect163k1 (deprecated). - Sect163k1, + pub const SECT163K1: Self = Self(1); /// sect163r1 (deprecated). - Sect163r1, + pub const SECT163R1: Self = Self(2); /// sect163r2 (deprecated). - Sect163r2, + pub const SECT163R2: Self = Self(3); /// sect193r1 (deprecated). - Sect193r1, + pub const SECT193R1: Self = Self(4); /// sect193r2 (deprecated). - Sect193r2, + pub const SECT193R2: Self = Self(5); /// sect233k1 (deprecated). - Sect233k1, + pub const SECT233K1: Self = Self(6); /// sect233r1 (deprecated). - Sect233r1, + pub const SECT233R1: Self = Self(7); /// sect239k1 (deprecated). - Sect239k1, + pub const SECT239K1: Self = Self(8); /// sect283k1 (deprecated). - Sect283k1, + pub const SECT283K1: Self = Self(9); /// sect283r1 (deprecated). - Sect283r1, + pub const SECT283R1: Self = Self(10); /// sect409k1 (deprecated). - Sect409k1, + pub const SECT409K1: Self = Self(11); /// sect409r1 (deprecated). - Sect409r1, + pub const SECT409R1: Self = Self(12); /// sect571k1 (deprecated). - Sect571k1, + pub const SECT571K1: Self = Self(13); /// sect571r1 (deprecated). - Sect571r1, + pub const SECT571R1: Self = Self(14); /// secp160k1 (deprecated). - Secp160k1, + pub const SECP160K1: Self = Self(15); /// secp160r1 (deprecated). - Secp160r1, + pub const SECP160R1: Self = Self(16); /// secp160r2 (deprecated). - Secp160r2, + pub const SECP160R2: Self = Self(17); /// secp192k1 (deprecated). - Secp192k1, + pub const SECP192K1: Self = Self(18); /// secp192r1 (deprecated). - Secp192r1, + pub const SECP192R1: Self = Self(19); /// secp224k1. - Secp224k1, + pub const SECP224K1: Self = Self(20); /// secp224r1. - Secp224r1, + pub const SECP224R1: Self = Self(21); /// secp256k1. - Secp256k1, + pub const SECP256K1: Self = Self(22); /// secp256r1 / P-256 (supported by dimpl). - Secp256r1, + pub const SECP256R1: Self = Self(23); /// secp384r1 / P-384 (supported by dimpl). - Secp384r1, + pub const SECP384R1: Self = Self(24); /// secp521r1 / P-521. - Secp521r1, + pub const SECP521R1: Self = Self(25); /// X25519 (Curve25519 for ECDHE). - X25519, + pub const X25519: Self = Self(29); /// X448 (Curve448 for ECDHE). - X448, - /// Unknown or unsupported group. - Unknown(u16), -} + pub const X448: Self = Self(30); -impl NamedGroup { /// Convert a wire format u16 value to a `NamedGroup`. - pub fn from_u16(value: u16) -> Self { - match value { - 1 => NamedGroup::Sect163k1, - 2 => NamedGroup::Sect163r1, - 3 => NamedGroup::Sect163r2, - 4 => NamedGroup::Sect193r1, - 5 => NamedGroup::Sect193r2, - 6 => NamedGroup::Sect233k1, - 7 => NamedGroup::Sect233r1, - 8 => NamedGroup::Sect239k1, - 9 => NamedGroup::Sect283k1, - 10 => NamedGroup::Sect283r1, - 11 => NamedGroup::Sect409k1, - 12 => NamedGroup::Sect409r1, - 13 => NamedGroup::Sect571k1, - 14 => NamedGroup::Sect571r1, - 15 => NamedGroup::Secp160k1, - 16 => NamedGroup::Secp160r1, - 17 => NamedGroup::Secp160r2, - 18 => NamedGroup::Secp192k1, - 19 => NamedGroup::Secp192r1, - 20 => NamedGroup::Secp224k1, - 21 => NamedGroup::Secp224r1, - 22 => NamedGroup::Secp256k1, - 23 => NamedGroup::Secp256r1, - 24 => NamedGroup::Secp384r1, - 25 => NamedGroup::Secp521r1, - 29 => NamedGroup::X25519, - 30 => NamedGroup::X448, - _ => NamedGroup::Unknown(value), - } + pub const fn from_u16(value: u16) -> Self { + Self(value) } /// Convert this `NamedGroup` to its wire format u16 value. - pub fn as_u16(&self) -> u16 { - match self { - NamedGroup::Sect163k1 => 1, - NamedGroup::Sect163r1 => 2, - NamedGroup::Sect163r2 => 3, - NamedGroup::Sect193r1 => 4, - NamedGroup::Sect193r2 => 5, - NamedGroup::Sect233k1 => 6, - NamedGroup::Sect233r1 => 7, - NamedGroup::Sect239k1 => 8, - NamedGroup::Sect283k1 => 9, - NamedGroup::Sect283r1 => 10, - NamedGroup::Sect409k1 => 11, - NamedGroup::Sect409r1 => 12, - NamedGroup::Sect571k1 => 13, - NamedGroup::Sect571r1 => 14, - NamedGroup::Secp160k1 => 15, - NamedGroup::Secp160r1 => 16, - NamedGroup::Secp160r2 => 17, - NamedGroup::Secp192k1 => 18, - NamedGroup::Secp192r1 => 19, - NamedGroup::Secp224k1 => 20, - NamedGroup::Secp224r1 => 21, - NamedGroup::Secp256k1 => 22, - NamedGroup::Secp256r1 => 23, - NamedGroup::Secp384r1 => 24, - NamedGroup::Secp521r1 => 25, - NamedGroup::X25519 => 29, - NamedGroup::X448 => 30, - NamedGroup::Unknown(value) => *value, - } + pub const fn as_u16(&self) -> u16 { + self.0 + } + + /// Returns true if this is not a known TLS named group wire value. + pub const fn is_unknown(&self) -> bool { + !matches!( + *self, + NamedGroup::SECT163K1 + | NamedGroup::SECT163R1 + | NamedGroup::SECT163R2 + | NamedGroup::SECT193R1 + | NamedGroup::SECT193R2 + | NamedGroup::SECT233K1 + | NamedGroup::SECT233R1 + | NamedGroup::SECT239K1 + | NamedGroup::SECT283K1 + | NamedGroup::SECT283R1 + | NamedGroup::SECT409K1 + | NamedGroup::SECT409R1 + | NamedGroup::SECT571K1 + | NamedGroup::SECT571R1 + | NamedGroup::SECP160K1 + | NamedGroup::SECP160R1 + | NamedGroup::SECP160R2 + | NamedGroup::SECP192K1 + | NamedGroup::SECP192R1 + | NamedGroup::SECP224K1 + | NamedGroup::SECP224R1 + | NamedGroup::SECP256K1 + | NamedGroup::SECP256R1 + | NamedGroup::SECP384R1 + | NamedGroup::SECP521R1 + | NamedGroup::X25519 + | NamedGroup::X448 + ) } /// Parse a `NamedGroup` from wire format. @@ -225,31 +199,31 @@ impl NamedGroup { /// All recognized named groups (every non-`Unknown` variant). pub const fn all() -> &'static [NamedGroup; 27] { &[ - NamedGroup::Sect163k1, - NamedGroup::Sect163r1, - NamedGroup::Sect163r2, - NamedGroup::Sect193r1, - NamedGroup::Sect193r2, - NamedGroup::Sect233k1, - NamedGroup::Sect233r1, - NamedGroup::Sect239k1, - NamedGroup::Sect283k1, - NamedGroup::Sect283r1, - NamedGroup::Sect409k1, - NamedGroup::Sect409r1, - NamedGroup::Sect571k1, - NamedGroup::Sect571r1, - NamedGroup::Secp160k1, - NamedGroup::Secp160r1, - NamedGroup::Secp160r2, - NamedGroup::Secp192k1, - NamedGroup::Secp192r1, - NamedGroup::Secp224k1, - NamedGroup::Secp224r1, - NamedGroup::Secp256k1, - NamedGroup::Secp256r1, - NamedGroup::Secp384r1, - NamedGroup::Secp521r1, + NamedGroup::SECT163K1, + NamedGroup::SECT163R1, + NamedGroup::SECT163R2, + NamedGroup::SECT193R1, + NamedGroup::SECT193R2, + NamedGroup::SECT233K1, + NamedGroup::SECT233R1, + NamedGroup::SECT239K1, + NamedGroup::SECT283K1, + NamedGroup::SECT283R1, + NamedGroup::SECT409K1, + NamedGroup::SECT409R1, + NamedGroup::SECT571K1, + NamedGroup::SECT571R1, + NamedGroup::SECP160K1, + NamedGroup::SECP160R1, + NamedGroup::SECP160R2, + NamedGroup::SECP192K1, + NamedGroup::SECP192R1, + NamedGroup::SECP224K1, + NamedGroup::SECP224R1, + NamedGroup::SECP256K1, + NamedGroup::SECP256R1, + NamedGroup::SECP384R1, + NamedGroup::SECP521R1, NamedGroup::X25519, NamedGroup::X448, ] @@ -259,13 +233,48 @@ impl NamedGroup { pub const fn supported() -> &'static [NamedGroup; 4] { &[ NamedGroup::X25519, - NamedGroup::Secp256r1, - NamedGroup::Secp384r1, - NamedGroup::Secp521r1, + NamedGroup::SECP256R1, + NamedGroup::SECP384R1, + NamedGroup::SECP521R1, ] } } +impl fmt::Debug for NamedGroup { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + NamedGroup::SECT163K1 => f.write_str("Sect163k1"), + NamedGroup::SECT163R1 => f.write_str("Sect163r1"), + NamedGroup::SECT163R2 => f.write_str("Sect163r2"), + NamedGroup::SECT193R1 => f.write_str("Sect193r1"), + NamedGroup::SECT193R2 => f.write_str("Sect193r2"), + NamedGroup::SECT233K1 => f.write_str("Sect233k1"), + NamedGroup::SECT233R1 => f.write_str("Sect233r1"), + NamedGroup::SECT239K1 => f.write_str("Sect239k1"), + NamedGroup::SECT283K1 => f.write_str("Sect283k1"), + NamedGroup::SECT283R1 => f.write_str("Sect283r1"), + NamedGroup::SECT409K1 => f.write_str("Sect409k1"), + NamedGroup::SECT409R1 => f.write_str("Sect409r1"), + NamedGroup::SECT571K1 => f.write_str("Sect571k1"), + NamedGroup::SECT571R1 => f.write_str("Sect571r1"), + NamedGroup::SECP160K1 => f.write_str("Secp160k1"), + NamedGroup::SECP160R1 => f.write_str("Secp160r1"), + NamedGroup::SECP160R2 => f.write_str("Secp160r2"), + NamedGroup::SECP192K1 => f.write_str("Secp192k1"), + NamedGroup::SECP192R1 => f.write_str("Secp192r1"), + NamedGroup::SECP224K1 => f.write_str("Secp224k1"), + NamedGroup::SECP224R1 => f.write_str("Secp224r1"), + NamedGroup::SECP256K1 => f.write_str("Secp256k1"), + NamedGroup::SECP256R1 => f.write_str("Secp256r1"), + NamedGroup::SECP384R1 => f.write_str("Secp384r1"), + NamedGroup::SECP521R1 => f.write_str("Secp521r1"), + NamedGroup::X25519 => f.write_str("X25519"), + NamedGroup::X448 => f.write_str("X448"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + // ============================================================================ // Hash Algorithms // ============================================================================ @@ -690,8 +699,8 @@ impl SignatureScheme { /// Returns `None` for non-ECDSA schemes. pub fn named_group(&self) -> Option { match self { - SignatureScheme::ECDSA_SECP256R1_SHA256 => Some(NamedGroup::Secp256r1), - SignatureScheme::ECDSA_SECP384R1_SHA384 => Some(NamedGroup::Secp384r1), + SignatureScheme::ECDSA_SECP256R1_SHA256 => Some(NamedGroup::SECP256R1), + SignatureScheme::ECDSA_SECP384R1_SHA384 => Some(NamedGroup::SECP384R1), _ => None, } } @@ -951,6 +960,34 @@ impl fmt::Debug for CompressionMethod { mod tests { use super::*; + #[test] + fn named_group_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert!(NamedGroup::default().is_unknown()); + } + + #[test] + fn named_group_wire_roundtrip() { + for group in NamedGroup::all() { + assert_eq!(NamedGroup::from_u16(group.as_u16()), *group); + assert!(!group.is_unknown()); + } + + let unknown = NamedGroup::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn named_group_debug_stays_enum_like() { + assert_eq!(format!("{:?}", NamedGroup::SECP256R1), "Secp256r1"); + assert_eq!(format!("{:?}", NamedGroup::X25519), "X25519"); + assert_eq!( + format!("{:?}", NamedGroup::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + #[test] fn hash_algorithm_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); @@ -1150,11 +1187,11 @@ mod tests { fn signature_scheme_named_group_ecdsa() { assert_eq!( SignatureScheme::ECDSA_SECP256R1_SHA256.named_group(), - Some(NamedGroup::Secp256r1) + Some(NamedGroup::SECP256R1) ); assert_eq!( SignatureScheme::ECDSA_SECP384R1_SHA384.named_group(), - Some(NamedGroup::Secp384r1) + Some(NamedGroup::SECP384R1) ); } diff --git a/tests/dtls12/handshake.rs b/tests/dtls12/handshake.rs index a6f19907..757f37a4 100644 --- a/tests/dtls12/handshake.rs +++ b/tests/dtls12/handshake.rs @@ -672,7 +672,7 @@ fn dtls12_handshake_secp384r1_key_exchange() { .kx_groups .iter() .copied() - .filter(|g| g.name() == dimpl::NamedGroup::Secp384r1) + .filter(|g| g.name() == dimpl::NamedGroup::SECP384R1) .collect(); // leak: intentional leak to produce a &'static slice for the provider field let p384_only: &'static [&'static dyn dimpl::crypto::SupportedKxGroup] = diff --git a/tests/dtls13/handshake.rs b/tests/dtls13/handshake.rs index 1799bf7f..a2610749 100644 --- a/tests/dtls13/handshake.rs +++ b/tests/dtls13/handshake.rs @@ -515,7 +515,7 @@ fn dtls13_handshake_secp256r1_key_exchange() { .kx_groups .iter() .copied() - .filter(|g| g.name() == NamedGroup::Secp256r1) + .filter(|g| g.name() == NamedGroup::SECP256R1) .collect(); assert!(!p256_only.is_empty(), "Provider must have P-256"); @@ -806,12 +806,12 @@ fn dtls13_hrr_with_p256_then_x25519() { .kx_groups .iter() .copied() - .filter(|g| g.name() == NamedGroup::Secp256r1 || g.name() == NamedGroup::X25519) + .filter(|g| g.name() == NamedGroup::SECP256R1 || g.name() == NamedGroup::X25519) .collect(); // Ensure P-256 is first let mut client_groups_sorted: Vec<_> = client_groups; client_groups_sorted.sort_by_key(|g| { - if g.name() == NamedGroup::Secp256r1 { + if g.name() == NamedGroup::SECP256R1 { 0 } else { 1 @@ -829,7 +829,7 @@ fn dtls13_hrr_with_p256_then_x25519() { .kx_groups .iter() .copied() - .filter(|g| g.name() == NamedGroup::Secp256r1 || g.name() == NamedGroup::X25519) + .filter(|g| g.name() == NamedGroup::SECP256R1 || g.name() == NamedGroup::X25519) .collect(); let mut server_groups_sorted: Vec<_> = server_groups; server_groups_sorted.sort_by_key(|g| if g.name() == NamedGroup::X25519 { 0 } else { 1 }); @@ -935,11 +935,11 @@ fn dtls13_hrr_handshake_completes_after_packet_loss() { .kx_groups .iter() .copied() - .filter(|g| g.name() == NamedGroup::Secp256r1 || g.name() == NamedGroup::Secp384r1) + .filter(|g| g.name() == NamedGroup::SECP256R1 || g.name() == NamedGroup::SECP384R1) .collect(); let mut client_groups_sorted: Vec<_> = client_groups; client_groups_sorted.sort_by_key(|g| { - if g.name() == NamedGroup::Secp256r1 { + if g.name() == NamedGroup::SECP256R1 { 0 } else { 1 @@ -957,11 +957,11 @@ fn dtls13_hrr_handshake_completes_after_packet_loss() { .kx_groups .iter() .copied() - .filter(|g| g.name() == NamedGroup::Secp256r1 || g.name() == NamedGroup::Secp384r1) + .filter(|g| g.name() == NamedGroup::SECP256R1 || g.name() == NamedGroup::SECP384R1) .collect(); let mut server_groups_sorted: Vec<_> = server_groups; server_groups_sorted.sort_by_key(|g| { - if g.name() == NamedGroup::Secp384r1 { + if g.name() == NamedGroup::SECP384R1 { 0 } else { 1 From 9a393de4f7ec92ddb6cf7a648c42225edf5be3ac Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 18:50:34 +0300 Subject: [PATCH 06/18] types: make signature scheme a newtype --- src/types.rs | 164 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 101 insertions(+), 63 deletions(-) diff --git a/src/types.rs b/src/types.rs index ffad1604..0384ae1e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -570,83 +570,69 @@ impl PartialOrd for Sequence { /// In TLS 1.3, signature schemes combine the signature algorithm with the /// hash algorithm into a single identifier, unlike TLS 1.2 where they were /// separate. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -#[non_exhaustive] -pub enum SignatureScheme { +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct SignatureScheme(u16); + +impl SignatureScheme { /// ECDSA with P-256 and SHA-256. - ECDSA_SECP256R1_SHA256, + pub const ECDSA_SECP256R1_SHA256: Self = Self(0x0403); /// ECDSA with P-384 and SHA-384. - ECDSA_SECP384R1_SHA384, + pub const ECDSA_SECP384R1_SHA384: Self = Self(0x0503); /// ECDSA with P-521 and SHA-512. - ECDSA_SECP521R1_SHA512, + pub const ECDSA_SECP521R1_SHA512: Self = Self(0x0603); /// Ed25519. - ED25519, + pub const ED25519: Self = Self(0x0807); /// Ed448. - ED448, + pub const ED448: Self = Self(0x0808); /// RSA-PSS with SHA-256 (rsaEncryption OID). - RSA_PSS_RSAE_SHA256, + pub const RSA_PSS_RSAE_SHA256: Self = Self(0x0804); /// RSA-PSS with SHA-384 (rsaEncryption OID). - RSA_PSS_RSAE_SHA384, + pub const RSA_PSS_RSAE_SHA384: Self = Self(0x0805); /// RSA-PSS with SHA-512 (rsaEncryption OID). - RSA_PSS_RSAE_SHA512, + pub const RSA_PSS_RSAE_SHA512: Self = Self(0x0806); /// RSA-PSS with SHA-256 (id-rsassa-pss OID). - RSA_PSS_PSS_SHA256, + pub const RSA_PSS_PSS_SHA256: Self = Self(0x0809); /// RSA-PSS with SHA-384 (id-rsassa-pss OID). - RSA_PSS_PSS_SHA384, + pub const RSA_PSS_PSS_SHA384: Self = Self(0x080a); /// RSA-PSS with SHA-512 (id-rsassa-pss OID). - RSA_PSS_PSS_SHA512, + pub const RSA_PSS_PSS_SHA512: Self = Self(0x080b); /// RSA PKCS#1 v1.5 with SHA-256 (legacy). - RSA_PKCS1_SHA256, + pub const RSA_PKCS1_SHA256: Self = Self(0x0401); /// RSA PKCS#1 v1.5 with SHA-384 (legacy). - RSA_PKCS1_SHA384, + pub const RSA_PKCS1_SHA384: Self = Self(0x0501); /// RSA PKCS#1 v1.5 with SHA-512 (legacy). - RSA_PKCS1_SHA512, - /// Unknown or unsupported signature scheme. - Unknown(u16), -} + pub const RSA_PKCS1_SHA512: Self = Self(0x0601); -impl SignatureScheme { /// Convert a wire format u16 value to a `SignatureScheme`. - pub fn from_u16(value: u16) -> Self { - match value { - 0x0403 => SignatureScheme::ECDSA_SECP256R1_SHA256, - 0x0503 => SignatureScheme::ECDSA_SECP384R1_SHA384, - 0x0603 => SignatureScheme::ECDSA_SECP521R1_SHA512, - 0x0807 => SignatureScheme::ED25519, - 0x0808 => SignatureScheme::ED448, - 0x0804 => SignatureScheme::RSA_PSS_RSAE_SHA256, - 0x0805 => SignatureScheme::RSA_PSS_RSAE_SHA384, - 0x0806 => SignatureScheme::RSA_PSS_RSAE_SHA512, - 0x0809 => SignatureScheme::RSA_PSS_PSS_SHA256, - 0x080a => SignatureScheme::RSA_PSS_PSS_SHA384, - 0x080b => SignatureScheme::RSA_PSS_PSS_SHA512, - 0x0401 => SignatureScheme::RSA_PKCS1_SHA256, - 0x0501 => SignatureScheme::RSA_PKCS1_SHA384, - 0x0601 => SignatureScheme::RSA_PKCS1_SHA512, - _ => SignatureScheme::Unknown(value), - } + pub const fn from_u16(value: u16) -> Self { + Self(value) } /// Convert this `SignatureScheme` to its wire format u16 value. - pub fn as_u16(&self) -> u16 { - match self { - SignatureScheme::ECDSA_SECP256R1_SHA256 => 0x0403, - SignatureScheme::ECDSA_SECP384R1_SHA384 => 0x0503, - SignatureScheme::ECDSA_SECP521R1_SHA512 => 0x0603, - SignatureScheme::ED25519 => 0x0807, - SignatureScheme::ED448 => 0x0808, - SignatureScheme::RSA_PSS_RSAE_SHA256 => 0x0804, - SignatureScheme::RSA_PSS_RSAE_SHA384 => 0x0805, - SignatureScheme::RSA_PSS_RSAE_SHA512 => 0x0806, - SignatureScheme::RSA_PSS_PSS_SHA256 => 0x0809, - SignatureScheme::RSA_PSS_PSS_SHA384 => 0x080a, - SignatureScheme::RSA_PSS_PSS_SHA512 => 0x080b, - SignatureScheme::RSA_PKCS1_SHA256 => 0x0401, - SignatureScheme::RSA_PKCS1_SHA384 => 0x0501, - SignatureScheme::RSA_PKCS1_SHA512 => 0x0601, - SignatureScheme::Unknown(value) => *value, - } + pub const fn as_u16(&self) -> u16 { + self.0 + } + + /// Returns true if this is not a known TLS signature scheme wire value. + pub const fn is_unknown(&self) -> bool { + !matches!( + *self, + SignatureScheme::ECDSA_SECP256R1_SHA256 + | SignatureScheme::ECDSA_SECP384R1_SHA384 + | SignatureScheme::ECDSA_SECP521R1_SHA512 + | SignatureScheme::ED25519 + | SignatureScheme::ED448 + | SignatureScheme::RSA_PSS_RSAE_SHA256 + | SignatureScheme::RSA_PSS_RSAE_SHA384 + | SignatureScheme::RSA_PSS_RSAE_SHA512 + | SignatureScheme::RSA_PSS_PSS_SHA256 + | SignatureScheme::RSA_PSS_PSS_SHA384 + | SignatureScheme::RSA_PSS_PSS_SHA512 + | SignatureScheme::RSA_PKCS1_SHA256 + | SignatureScheme::RSA_PKCS1_SHA384 + | SignatureScheme::RSA_PKCS1_SHA512 + ) } /// Parse a `SignatureScheme` from wire format. @@ -661,7 +647,7 @@ impl SignatureScheme { } /// All recognized signature schemes (every non-`Unknown` variant). - pub fn all() -> &'static [SignatureScheme] { + pub const fn all() -> &'static [SignatureScheme] { &[ SignatureScheme::ECDSA_SECP256R1_SHA256, SignatureScheme::ECDSA_SECP384R1_SHA384, @@ -698,7 +684,7 @@ impl SignatureScheme { /// In DTLS 1.3, ECDSA signature schemes encode the expected curve. /// Returns `None` for non-ECDSA schemes. pub fn named_group(&self) -> Option { - match self { + match *self { SignatureScheme::ECDSA_SECP256R1_SHA256 => Some(NamedGroup::SECP256R1), SignatureScheme::ECDSA_SECP384R1_SHA384 => Some(NamedGroup::SECP384R1), _ => None, @@ -707,7 +693,7 @@ impl SignatureScheme { /// Returns the hash algorithm associated with this signature scheme. pub fn hash_algorithm(&self) -> HashAlgorithm { - match self { + match *self { SignatureScheme::ECDSA_SECP256R1_SHA256 | SignatureScheme::RSA_PSS_RSAE_SHA256 | SignatureScheme::RSA_PSS_PSS_SHA256 @@ -722,7 +708,29 @@ impl SignatureScheme { | SignatureScheme::RSA_PKCS1_SHA512 => HashAlgorithm::SHA512, // Ed25519 and Ed448 have intrinsic hash algorithms SignatureScheme::ED25519 | SignatureScheme::ED448 => HashAlgorithm::NONE, - SignatureScheme::Unknown(_) => HashAlgorithm::UNKNOWN_DERIVED, + _ => HashAlgorithm::UNKNOWN_DERIVED, + } + } +} + +impl fmt::Debug for SignatureScheme { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + SignatureScheme::ECDSA_SECP256R1_SHA256 => f.write_str("ECDSA_SECP256R1_SHA256"), + SignatureScheme::ECDSA_SECP384R1_SHA384 => f.write_str("ECDSA_SECP384R1_SHA384"), + SignatureScheme::ECDSA_SECP521R1_SHA512 => f.write_str("ECDSA_SECP521R1_SHA512"), + SignatureScheme::ED25519 => f.write_str("ED25519"), + SignatureScheme::ED448 => f.write_str("ED448"), + SignatureScheme::RSA_PSS_RSAE_SHA256 => f.write_str("RSA_PSS_RSAE_SHA256"), + SignatureScheme::RSA_PSS_RSAE_SHA384 => f.write_str("RSA_PSS_RSAE_SHA384"), + SignatureScheme::RSA_PSS_RSAE_SHA512 => f.write_str("RSA_PSS_RSAE_SHA512"), + SignatureScheme::RSA_PSS_PSS_SHA256 => f.write_str("RSA_PSS_PSS_SHA256"), + SignatureScheme::RSA_PSS_PSS_SHA384 => f.write_str("RSA_PSS_PSS_SHA384"), + SignatureScheme::RSA_PSS_PSS_SHA512 => f.write_str("RSA_PSS_PSS_SHA512"), + SignatureScheme::RSA_PKCS1_SHA256 => f.write_str("RSA_PKCS1_SHA256"), + SignatureScheme::RSA_PKCS1_SHA384 => f.write_str("RSA_PKCS1_SHA384"), + SignatureScheme::RSA_PKCS1_SHA512 => f.write_str("RSA_PKCS1_SHA512"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), } } } @@ -1143,6 +1151,36 @@ mod tests { assert_eq!(format!("{:?}", ContentType::from_u8(24)), "Unknown(24)"); } + #[test] + fn signature_scheme_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert!(SignatureScheme::default().is_unknown()); + } + + #[test] + fn signature_scheme_wire_roundtrip() { + for scheme in SignatureScheme::all() { + assert_eq!(SignatureScheme::from_u16(scheme.as_u16()), *scheme); + assert!(!scheme.is_unknown()); + } + + let unknown = SignatureScheme::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn signature_scheme_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", SignatureScheme::ECDSA_SECP256R1_SHA256), + "ECDSA_SECP256R1_SHA256" + ); + assert_eq!( + format!("{:?}", SignatureScheme::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + #[test] fn random_parse() { let data = [ @@ -1200,7 +1238,7 @@ mod tests { assert_eq!(SignatureScheme::RSA_PSS_RSAE_SHA256.named_group(), None); assert_eq!(SignatureScheme::ED25519.named_group(), None); assert_eq!(SignatureScheme::ECDSA_SECP521R1_SHA512.named_group(), None); - assert_eq!(SignatureScheme::Unknown(0xFFFF).named_group(), None); + assert_eq!(SignatureScheme::from_u16(0xFFFF).named_group(), None); } #[test] From ba7af655d033a1c9d74769d2a5c2c5d1a73f313e Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 18:52:03 +0300 Subject: [PATCH 07/18] types: make dtls13 cipher suite a newtype --- src/dtls13/client.rs | 2 +- src/types.rs | 108 +++++++++++++++++++++++++++++-------------- 2 files changed, 74 insertions(+), 36 deletions(-) diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index cf30cdef..4eb9134e 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -565,7 +565,7 @@ impl State { // Validate cipher suite let cs = server_hello.cipher_suite; - if matches!(cs, Dtls13CipherSuite::Unknown(_)) { + if cs.is_unknown() { return Err((Error::SecurityError( crate::SecurityError::ServerSelectedUnknownCipherSuite, )) diff --git a/src/types.rs b/src/types.rs index 0384ae1e..b0cfda3f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -743,47 +743,42 @@ impl fmt::Debug for SignatureScheme { /// /// Unlike DTLS 1.2, TLS 1.3 cipher suites only specify the AEAD algorithm /// and hash function. Key exchange is negotiated separately via key_share. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -#[non_exhaustive] -pub enum Dtls13CipherSuite { +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct Dtls13CipherSuite(u16); + +impl Dtls13CipherSuite { /// TLS_AES_128_GCM_SHA256. - AES_128_GCM_SHA256, + pub const AES_128_GCM_SHA256: Self = Self(0x1301); /// TLS_AES_256_GCM_SHA384. - AES_256_GCM_SHA384, + pub const AES_256_GCM_SHA384: Self = Self(0x1302); /// TLS_CHACHA20_POLY1305_SHA256. - CHACHA20_POLY1305_SHA256, + pub const CHACHA20_POLY1305_SHA256: Self = Self(0x1303); /// TLS_AES_128_CCM_SHA256. - AES_128_CCM_SHA256, + pub const AES_128_CCM_SHA256: Self = Self(0x1304); /// TLS_AES_128_CCM_8_SHA256 (shorter tag, for constrained devices). - AES_128_CCM_8_SHA256, - /// Unknown or unsupported cipher suite. - Unknown(u16), -} + pub const AES_128_CCM_8_SHA256: Self = Self(0x1305); -impl Dtls13CipherSuite { /// Convert a wire format u16 value to a `Dtls13CipherSuite`. - pub fn from_u16(value: u16) -> Self { - match value { - 0x1301 => Dtls13CipherSuite::AES_128_GCM_SHA256, - 0x1302 => Dtls13CipherSuite::AES_256_GCM_SHA384, - 0x1303 => Dtls13CipherSuite::CHACHA20_POLY1305_SHA256, - 0x1304 => Dtls13CipherSuite::AES_128_CCM_SHA256, - 0x1305 => Dtls13CipherSuite::AES_128_CCM_8_SHA256, - _ => Dtls13CipherSuite::Unknown(value), - } + pub const fn from_u16(value: u16) -> Self { + Self(value) } /// Convert this `Dtls13CipherSuite` to its wire format u16 value. - pub fn as_u16(&self) -> u16 { - match self { - Dtls13CipherSuite::AES_128_GCM_SHA256 => 0x1301, - Dtls13CipherSuite::AES_256_GCM_SHA384 => 0x1302, - Dtls13CipherSuite::CHACHA20_POLY1305_SHA256 => 0x1303, - Dtls13CipherSuite::AES_128_CCM_SHA256 => 0x1304, - Dtls13CipherSuite::AES_128_CCM_8_SHA256 => 0x1305, - Dtls13CipherSuite::Unknown(value) => *value, - } + pub const fn as_u16(&self) -> u16 { + self.0 + } + + /// Returns true if this is not a known DTLS 1.3 cipher suite wire value. + pub const fn is_unknown(&self) -> bool { + !matches!( + *self, + Dtls13CipherSuite::AES_128_GCM_SHA256 + | Dtls13CipherSuite::AES_256_GCM_SHA384 + | Dtls13CipherSuite::CHACHA20_POLY1305_SHA256 + | Dtls13CipherSuite::AES_128_CCM_SHA256 + | Dtls13CipherSuite::AES_128_CCM_8_SHA256 + ) } /// Parse a `Dtls13CipherSuite` from wire format. @@ -794,13 +789,13 @@ impl Dtls13CipherSuite { /// Returns the hash algorithm used by this cipher suite. pub fn hash_algorithm(&self) -> HashAlgorithm { - match self { + match *self { Dtls13CipherSuite::AES_128_GCM_SHA256 | Dtls13CipherSuite::CHACHA20_POLY1305_SHA256 | Dtls13CipherSuite::AES_128_CCM_SHA256 | Dtls13CipherSuite::AES_128_CCM_8_SHA256 => HashAlgorithm::SHA256, Dtls13CipherSuite::AES_256_GCM_SHA384 => HashAlgorithm::SHA384, - Dtls13CipherSuite::Unknown(_) => HashAlgorithm::UNKNOWN_DERIVED, + _ => HashAlgorithm::UNKNOWN_DERIVED, } } @@ -810,7 +805,7 @@ impl Dtls13CipherSuite { } /// All recognized DTLS 1.3 cipher suites (every non-`Unknown` variant). - pub fn all() -> &'static [Dtls13CipherSuite] { + pub const fn all() -> &'static [Dtls13CipherSuite] { &[ Dtls13CipherSuite::AES_128_GCM_SHA256, Dtls13CipherSuite::AES_256_GCM_SHA384, @@ -821,7 +816,7 @@ impl Dtls13CipherSuite { } /// Supported DTLS 1.3 cipher suites in preference order. - pub fn supported() -> &'static [Dtls13CipherSuite] { + pub const fn supported() -> &'static [Dtls13CipherSuite] { &[ Dtls13CipherSuite::AES_128_GCM_SHA256, Dtls13CipherSuite::AES_256_GCM_SHA384, @@ -835,6 +830,19 @@ impl Dtls13CipherSuite { } } +impl fmt::Debug for Dtls13CipherSuite { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Dtls13CipherSuite::AES_128_GCM_SHA256 => f.write_str("AES_128_GCM_SHA256"), + Dtls13CipherSuite::AES_256_GCM_SHA384 => f.write_str("AES_256_GCM_SHA384"), + Dtls13CipherSuite::CHACHA20_POLY1305_SHA256 => f.write_str("CHACHA20_POLY1305_SHA256"), + Dtls13CipherSuite::AES_128_CCM_SHA256 => f.write_str("AES_128_CCM_SHA256"), + Dtls13CipherSuite::AES_128_CCM_8_SHA256 => f.write_str("AES_128_CCM_8_SHA256"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + // ============================================================================ // Protocol Version // ============================================================================ @@ -1181,6 +1189,36 @@ mod tests { ); } + #[test] + fn dtls13_cipher_suite_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert!(Dtls13CipherSuite::default().is_unknown()); + } + + #[test] + fn dtls13_cipher_suite_wire_roundtrip() { + for suite in Dtls13CipherSuite::all() { + assert_eq!(Dtls13CipherSuite::from_u16(suite.as_u16()), *suite); + assert!(!suite.is_unknown()); + } + + let unknown = Dtls13CipherSuite::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn dtls13_cipher_suite_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", Dtls13CipherSuite::AES_128_GCM_SHA256), + "AES_128_GCM_SHA256" + ); + assert_eq!( + format!("{:?}", Dtls13CipherSuite::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + #[test] fn random_parse() { let data = [ From 051ebb91329a37b7a0e1b145bee16f89200ef2de Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 18:53:13 +0300 Subject: [PATCH 08/18] types: make protocol version a newtype --- .../message/extensions/supported_versions.rs | 2 +- src/types.rs | 94 +++++++++++++------ 2 files changed, 68 insertions(+), 28 deletions(-) diff --git a/src/dtls13/message/extensions/supported_versions.rs b/src/dtls13/message/extensions/supported_versions.rs index 13e43c83..cb33a704 100644 --- a/src/dtls13/message/extensions/supported_versions.rs +++ b/src/dtls13/message/extensions/supported_versions.rs @@ -29,7 +29,7 @@ impl SupportedVersionsClientHello { let mut rest = versions_data; while !rest.is_empty() { let (r, version) = ProtocolVersion::parse(rest)?; - if !matches!(version, ProtocolVersion::Unknown(_)) { + if !version.is_unknown() { versions .try_push(version) .map_err(|_| Err::Failure(Error::new(rest, ErrorKind::LengthValue)))?; diff --git a/src/types.rs b/src/types.rs index b0cfda3f..1843adad 100644 --- a/src/types.rs +++ b/src/types.rs @@ -850,45 +850,40 @@ impl fmt::Debug for Dtls13CipherSuite { /// DTLS protocol version identifiers. /// /// Used in record headers and handshake messages for both DTLS 1.2 and 1.3. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ProtocolVersion { +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct ProtocolVersion(u16); + +impl ProtocolVersion { /// DTLS 1.0. - DTLS1_0, + pub const DTLS1_0: Self = Self(0xFEFF); /// DTLS 1.2. - DTLS1_2, + pub const DTLS1_2: Self = Self(0xFEFD); /// DTLS 1.3. - DTLS1_3, - /// Unknown protocol version. - Unknown(u16), -} + pub const DTLS1_3: Self = Self(0xFEFC); -impl Default for ProtocolVersion { - fn default() -> Self { - Self::Unknown(0) + /// Convert a wire format u16 value to a `ProtocolVersion`. + pub const fn from_u16(value: u16) -> Self { + Self(value) } -} -impl ProtocolVersion { /// Convert this `ProtocolVersion` to its wire format u16 value. - pub fn as_u16(&self) -> u16 { - match self { - ProtocolVersion::DTLS1_0 => 0xFEFF, - ProtocolVersion::DTLS1_2 => 0xFEFD, - ProtocolVersion::DTLS1_3 => 0xFEFC, - ProtocolVersion::Unknown(value) => *value, - } + pub const fn as_u16(&self) -> u16 { + self.0 + } + + /// Returns true if this is not a known DTLS protocol version wire value. + pub const fn is_unknown(&self) -> bool { + !matches!( + *self, + ProtocolVersion::DTLS1_0 | ProtocolVersion::DTLS1_2 | ProtocolVersion::DTLS1_3 + ) } /// Parse a `ProtocolVersion` from wire format. pub fn parse(input: &[u8]) -> IResult<&[u8], ProtocolVersion> { let (input, version) = be_u16(input)?; - let protocol_version = match version { - 0xFEFF => ProtocolVersion::DTLS1_0, - 0xFEFD => ProtocolVersion::DTLS1_2, - 0xFEFC => ProtocolVersion::DTLS1_3, - _ => ProtocolVersion::Unknown(version), - }; - Ok((input, protocol_version)) + Ok((input, ProtocolVersion::from_u16(version))) } /// Serialize this `ProtocolVersion` to wire format. @@ -897,6 +892,17 @@ impl ProtocolVersion { } } +impl fmt::Debug for ProtocolVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + ProtocolVersion::DTLS1_0 => f.write_str("DTLS1_0"), + ProtocolVersion::DTLS1_2 => f.write_str("DTLS1_2"), + ProtocolVersion::DTLS1_3 => f.write_str("DTLS1_3"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + // ============================================================================ // Compression Method // ============================================================================ @@ -1219,6 +1225,40 @@ mod tests { ); } + #[test] + fn protocol_version_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert!(ProtocolVersion::default().is_unknown()); + } + + #[test] + fn protocol_version_wire_roundtrip() { + let known = [ + (0xFEFF, ProtocolVersion::DTLS1_0), + (0xFEFD, ProtocolVersion::DTLS1_2), + (0xFEFC, ProtocolVersion::DTLS1_3), + ]; + + for (wire, version) in known { + assert_eq!(ProtocolVersion::from_u16(wire), version); + assert_eq!(version.as_u16(), wire); + assert!(!version.is_unknown()); + } + + let unknown = ProtocolVersion::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn protocol_version_debug_stays_enum_like() { + assert_eq!(format!("{:?}", ProtocolVersion::DTLS1_2), "DTLS1_2"); + assert_eq!( + format!("{:?}", ProtocolVersion::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + #[test] fn random_parse() { let data = [ From 66d7f0b41da2971c0561a797f9f1938bf086d67a Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 18:55:46 +0300 Subject: [PATCH 09/18] dtls12: make cipher suite a newtype --- src/dtls12/client.rs | 2 +- src/dtls12/message/mod.rs | 158 ++++++++++++++++++++++++-------------- tests/dtls12/edge.rs | 2 +- 3 files changed, 104 insertions(+), 58 deletions(-) diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs index b1194cec..338ad55a 100644 --- a/src/dtls12/client.rs +++ b/src/dtls12/client.rs @@ -467,7 +467,7 @@ impl State { // Enforce cipher suite is known and allowed let cs = server_hello.cipher_suite; - if matches!(cs, Dtls12CipherSuite::Unknown(_)) { + if cs.is_unknown() { return Err((Error::SecurityError( crate::SecurityError::ServerSelectedUnknownCipherSuite, )) diff --git a/src/dtls12/message/mod.rs b/src/dtls12/message/mod.rs index 6e83119d..5c3d8036 100644 --- a/src/dtls12/message/mod.rs +++ b/src/dtls12/message/mod.rs @@ -22,6 +22,8 @@ mod server_hello; mod server_key_exchange; mod wrapped; +use std::fmt; + use arrayvec::ArrayVec; pub use certificate::Certificate; pub use certificate_request::CertificateRequest; @@ -55,60 +57,40 @@ use nom::number::complete::{be_u8, be_u16}; pub type CipherSuiteVec = ArrayVec; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] /// Supported TLS 1.2 cipher suites for DTLS. -pub enum Dtls12CipherSuite { - // ECDHE with AES-GCM - /// ECDHE with ECDSA authentication, AES-256-GCM, SHA-384 - ECDHE_ECDSA_AES256_GCM_SHA384, // 0xC02C - /// ECDHE with ECDSA authentication, AES-128-GCM, SHA-256 - ECDHE_ECDSA_AES128_GCM_SHA256, // 0xC02B - /// ECDHE with ECDSA authentication, ChaCha20-Poly1305, SHA-256 - ECDHE_ECDSA_CHACHA20_POLY1305_SHA256, // 0xCCA9 - - // PSK cipher suites (no certificate authentication) - /// PSK with AES-128-CCM-8 (8-byte tag), SHA-256 - PSK_AES128_CCM_8, // 0xC0A8 - - /// Unknown or unsupported cipher suite by its IANA value - Unknown(u16), -} - -impl Default for Dtls12CipherSuite { - fn default() -> Self { - Self::Unknown(0) - } -} +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct Dtls12CipherSuite(u16); impl Dtls12CipherSuite { - /// Convert the 16-bit IANA value to a `Dtls12CipherSuite`. - pub fn from_u16(value: u16) -> Self { - match value { - // ECDHE with AES-GCM - 0xC02C => Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384, - 0xC02B => Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256, - 0xCCA9 => Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256, + /// ECDHE with ECDSA authentication, AES-256-GCM, SHA-384. + pub const ECDHE_ECDSA_AES256_GCM_SHA384: Self = Self(0xC02C); + /// ECDHE with ECDSA authentication, AES-128-GCM, SHA-256. + pub const ECDHE_ECDSA_AES128_GCM_SHA256: Self = Self(0xC02B); + /// ECDHE with ECDSA authentication, ChaCha20-Poly1305, SHA-256. + pub const ECDHE_ECDSA_CHACHA20_POLY1305_SHA256: Self = Self(0xCCA9); + /// PSK with AES-128-CCM-8 (8-byte tag), SHA-256. + pub const PSK_AES128_CCM_8: Self = Self(0xC0A8); - // PSK - 0xC0A8 => Dtls12CipherSuite::PSK_AES128_CCM_8, - - _ => Dtls12CipherSuite::Unknown(value), - } + /// Convert the 16-bit IANA value to a `Dtls12CipherSuite`. + pub const fn from_u16(value: u16) -> Self { + Self(value) } /// Return the 16-bit IANA value for this cipher suite. - pub fn as_u16(&self) -> u16 { - match self { - // ECDHE with AES-GCM - Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 => 0xC02C, - Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 => 0xC02B, - Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => 0xCCA9, - - Dtls12CipherSuite::PSK_AES128_CCM_8 => 0xC0A8, + pub const fn as_u16(&self) -> u16 { + self.0 + } - Dtls12CipherSuite::Unknown(value) => *value, - } + /// Returns true if this is not a known DTLS 1.2 cipher suite wire value. + pub const fn is_unknown(&self) -> bool { + !matches!( + *self, + Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 + | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 + | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 + | Dtls12CipherSuite::PSK_AES128_CCM_8 + ) } /// Parse a `Dtls12CipherSuite` from network byte order. @@ -119,20 +101,20 @@ impl Dtls12CipherSuite { /// Length in bytes of verify_data for Finished MACs. pub fn verify_data_length(&self) -> usize { - match self { + match *self { // AES-GCM suites Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 | Dtls12CipherSuite::PSK_AES128_CCM_8 => 12, - Dtls12CipherSuite::Unknown(_) => 12, // Default length for unknown cipher suites + _ => 12, // Default length for unknown cipher suites } } /// The key exchange algorithm family for this cipher suite. pub fn as_key_exchange_algorithm(&self) -> KeyExchangeAlgorithm { - match self { + match *self { // All ECDHE ciphers Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 @@ -142,14 +124,14 @@ impl Dtls12CipherSuite { Dtls12CipherSuite::PSK_AES128_CCM_8 => KeyExchangeAlgorithm::PSK, - Dtls12CipherSuite::Unknown(_) => KeyExchangeAlgorithm::Unknown, + _ => KeyExchangeAlgorithm::Unknown, } } /// Whether this cipher suite uses ECC-based key exchange. pub fn has_ecc(&self) -> bool { matches!( - self, + *self, Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 @@ -158,7 +140,7 @@ impl Dtls12CipherSuite { /// Whether this cipher suite uses PSK (Pre-Shared Key) key exchange. pub fn is_psk(&self) -> bool { - matches!(self, Dtls12CipherSuite::PSK_AES128_CCM_8) + matches!(*self, Dtls12CipherSuite::PSK_AES128_CCM_8) } /// All supported cipher suites in server preference order. @@ -195,12 +177,12 @@ impl Dtls12CipherSuite { /// The hash algorithm used by this cipher suite. pub fn hash_algorithm(&self) -> HashAlgorithm { - match self { + match *self { Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 => HashAlgorithm::SHA384, Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 | Dtls12CipherSuite::PSK_AES128_CCM_8 => HashAlgorithm::SHA256, - Dtls12CipherSuite::Unknown(_) => HashAlgorithm::UNKNOWN_DERIVED, + _ => HashAlgorithm::UNKNOWN_DERIVED, } } @@ -208,14 +190,14 @@ impl Dtls12CipherSuite { /// /// Returns `None` for PSK cipher suites (no signature authentication). pub fn signature_algorithm(&self) -> Option { - match self { + match *self { Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => { Some(SignatureAlgorithm::ECDSA) } Dtls12CipherSuite::PSK_AES128_CCM_8 => None, - Dtls12CipherSuite::Unknown(_) => Some(SignatureAlgorithm::UNKNOWN_DERIVED), + _ => Some(SignatureAlgorithm::UNKNOWN_DERIVED), } } @@ -230,6 +212,24 @@ impl Dtls12CipherSuite { } } +impl fmt::Debug for Dtls12CipherSuite { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 => { + f.write_str("ECDHE_ECDSA_AES256_GCM_SHA384") + } + Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 => { + f.write_str("ECDHE_ECDSA_AES128_GCM_SHA256") + } + Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => { + f.write_str("ECDHE_ECDSA_CHACHA20_POLY1305_SHA256") + } + Dtls12CipherSuite::PSK_AES128_CCM_8 => f.write_str("PSK_AES128_CCM_8"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + pub type CompressionMethodVec = ArrayVec; @@ -365,3 +365,49 @@ impl SignatureAndHashAlgorithm { Self::supported().contains(self) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dtls12_cipher_suite_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert!(Dtls12CipherSuite::default().is_unknown()); + } + + #[test] + fn dtls12_cipher_suite_wire_roundtrip() { + for suite in Dtls12CipherSuite::all() { + assert_eq!(Dtls12CipherSuite::from_u16(suite.as_u16()), *suite); + assert!(!suite.is_unknown()); + } + + let unknown = Dtls12CipherSuite::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn dtls12_cipher_suite_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256), + "ECDHE_ECDSA_AES128_GCM_SHA256" + ); + assert_eq!( + format!("{:?}", Dtls12CipherSuite::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + + #[test] + fn unknown_dtls12_cipher_suite_uses_internal_derived_markers() { + let unknown = Dtls12CipherSuite::from_u16(0xFFFF); + assert!(unknown.hash_algorithm().is_unknown()); + assert!( + unknown + .signature_algorithm() + .is_some_and(|s| s.is_unknown()) + ); + } +} diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 7624c612..6fd95d38 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -58,7 +58,7 @@ fn dtls12_min_protected_fragment_len(suite: Dtls12CipherSuite) -> usize { | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 => 24, Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => 16, Dtls12CipherSuite::PSK_AES128_CCM_8 => 16, - Dtls12CipherSuite::Unknown(_) => panic!("unknown cipher suite"), + _ => panic!("unknown cipher suite"), } } From 9d3cafaedc4df61ff93227af93ea1bb9a02c83e5 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 19:03:26 +0300 Subject: [PATCH 10/18] dtls12: make extension type a newtype --- src/dtls12/client.rs | 4 +- src/dtls12/message/client_hello.rs | 19 +- src/dtls12/message/extension.rs | 288 +++++++++++++++-------------- src/dtls12/message/server_hello.rs | 13 +- src/dtls12/server.rs | 10 +- 5 files changed, 175 insertions(+), 159 deletions(-) diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs index 338ad55a..e5199ac8 100644 --- a/src/dtls12/client.rs +++ b/src/dtls12/client.rs @@ -507,7 +507,7 @@ impl State { }; for extension in extensions { - if extension.extension_type == ExtensionType::UseSrtp { + if extension.extension_type == ExtensionType::USE_SRTP { // Parse the use_srtp extension to get the selected profile let extension_data = extension.extension_data(&client.defragment_buffer); let (_, use_srtp) = @@ -523,7 +523,7 @@ impl State { } // We are to use extended master secret - if extension.extension_type == ExtensionType::ExtendedMasterSecret { + if extension.extension_type == ExtensionType::EXTENDED_MASTER_SECRET { extended_master_secret = true; trace!("Server negotiated Extended Master Secret"); } diff --git a/src/dtls12/message/client_hello.rs b/src/dtls12/message/client_hello.rs index 0067dddc..f0178f5e 100644 --- a/src/dtls12/message/client_hello.rs +++ b/src/dtls12/message/client_hello.rs @@ -65,31 +65,31 @@ impl ClientHello { let supported_groups = SupportedGroupsExtension { groups }; let start_pos = buf.len(); supported_groups.serialize(buf); - ranges.push((ExtensionType::SupportedGroups, start_pos, buf.len())); + ranges.push((ExtensionType::SUPPORTED_GROUPS, start_pos, buf.len())); // Add EC point formats extension let ec_point_formats = ECPointFormatsExtension::default(); let start_pos = buf.len(); ec_point_formats.serialize(buf); - ranges.push((ExtensionType::EcPointFormats, start_pos, buf.len())); + ranges.push((ExtensionType::EC_POINT_FORMATS, start_pos, buf.len())); } // Add signature algorithms extension (required for TLS 1.2+) let signature_algorithms = SignatureAlgorithmsExtension::default(); let start_pos = buf.len(); signature_algorithms.serialize(buf); - ranges.push((ExtensionType::SignatureAlgorithms, start_pos, buf.len())); + ranges.push((ExtensionType::SIGNATURE_ALGORITHMS, start_pos, buf.len())); // Add use_srtp extension for DTLS-SRTP support let use_srtp = UseSrtpExtension::default(); let start_pos = buf.len(); use_srtp.serialize(buf); - ranges.push((ExtensionType::UseSrtp, start_pos, buf.len())); + ranges.push((ExtensionType::USE_SRTP, start_pos, buf.len())); // // Add session_ticket extension (empty) // let start_pos = buf.len(); // buf.extend_from_slice(&[0x00]); // Empty extension data - // ranges.push((ExtensionType::SessionTicket, start_pos, buf.len())); + // ranges.push((ExtensionType::SESSION_TICKET, start_pos, buf.len())); let need_etm = self .cipher_suites @@ -99,12 +99,12 @@ impl ClientHello { // Add encrypt_then_mac extension (empty) let start_pos = buf.len(); buf.extend_from_slice(&[0x00]); // Empty extension data - ranges.push((ExtensionType::EncryptThenMac, start_pos, buf.len())); + ranges.push((ExtensionType::ENCRYPT_THEN_MAC, start_pos, buf.len())); } let start_pos = buf.len(); ranges.push(( - ExtensionType::ExtendedMasterSecret, + ExtensionType::EXTENDED_MASTER_SECRET, start_pos, start_pos, // No data at all )); @@ -333,8 +333,9 @@ mod tests { let mut message = MESSAGE.to_vec(); message.extend_from_slice(&(count as u16 * 4).to_be_bytes()); for _ in 0..count { - message - .extend_from_slice(&ExtensionType::ExtendedMasterSecret.as_u16().to_be_bytes()); + message.extend_from_slice( + &ExtensionType::EXTENDED_MASTER_SECRET.as_u16().to_be_bytes(), + ); message.extend_from_slice(&0u16.to_be_bytes()); } diff --git a/src/dtls12/message/extension.rs b/src/dtls12/message/extension.rs index 59da8920..57404543 100644 --- a/src/dtls12/message/extension.rs +++ b/src/dtls12/message/extension.rs @@ -1,7 +1,7 @@ use crate::buffer::Buf; use arrayvec::ArrayVec; use nom::{IResult, bytes::complete::take, number::complete::be_u16}; -use std::ops::Range; +use std::{fmt, ops::Range}; pub type ExtensionVec = ArrayVec; @@ -51,142 +51,69 @@ impl Extension { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ExtensionType { - ServerName, - MaxFragmentLength, - ClientCertificateUrl, - TrustedCaKeys, - TruncatedHmac, - StatusRequest, - UserMapping, - ClientAuthz, - ServerAuthz, - CertType, - SupportedGroups, - EcPointFormats, - Srp, - SignatureAlgorithms, - UseSrtp, - Heartbeat, - ApplicationLayerProtocolNegotiation, - StatusRequestV2, - SignedCertificateTimestamp, - ClientCertificateType, - ServerCertificateType, - Padding, - EncryptThenMac, - ExtendedMasterSecret, - TokenBinding, - CachedInfo, - SessionTicket, - PreSharedKey, - EarlyData, - SupportedVersions, - Cookie, - PskKeyExchangeModes, - CertificateAuthorities, - OidFilters, - PostHandshakeAuth, - SignatureAlgorithmsCert, - KeyShare, - RenegotiationInfo, - Unknown(u16), -} +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct ExtensionType(u16); impl Default for ExtensionType { fn default() -> Self { - Self::Unknown(0) + Self(u16::MAX) } } impl ExtensionType { - pub fn from_u16(value: u16) -> Self { - match value { - 0x0000 => ExtensionType::ServerName, - 0x0001 => ExtensionType::MaxFragmentLength, - 0x0002 => ExtensionType::ClientCertificateUrl, - 0x0003 => ExtensionType::TrustedCaKeys, - 0x0004 => ExtensionType::TruncatedHmac, - 0x0005 => ExtensionType::StatusRequest, - 0x0006 => ExtensionType::UserMapping, - 0x0007 => ExtensionType::ClientAuthz, - 0x0008 => ExtensionType::ServerAuthz, - 0x0009 => ExtensionType::CertType, - 0x000A => ExtensionType::SupportedGroups, - 0x000B => ExtensionType::EcPointFormats, - 0x000C => ExtensionType::Srp, - 0x000D => ExtensionType::SignatureAlgorithms, - 0x000E => ExtensionType::UseSrtp, - 0x000F => ExtensionType::Heartbeat, - 0x0010 => ExtensionType::ApplicationLayerProtocolNegotiation, - 0x0011 => ExtensionType::StatusRequestV2, - 0x0012 => ExtensionType::SignedCertificateTimestamp, - 0x0013 => ExtensionType::ClientCertificateType, - 0x0014 => ExtensionType::ServerCertificateType, - 0x0015 => ExtensionType::Padding, - 0x0016 => ExtensionType::EncryptThenMac, - 0x0017 => ExtensionType::ExtendedMasterSecret, - 0x0018 => ExtensionType::TokenBinding, - 0x0019 => ExtensionType::CachedInfo, - 0x0023 => ExtensionType::SessionTicket, - 0x0029 => ExtensionType::PreSharedKey, - 0x002A => ExtensionType::EarlyData, - 0x002B => ExtensionType::SupportedVersions, - 0x002C => ExtensionType::Cookie, - 0x002D => ExtensionType::PskKeyExchangeModes, - 0x002F => ExtensionType::CertificateAuthorities, - 0x0030 => ExtensionType::OidFilters, - 0x0031 => ExtensionType::PostHandshakeAuth, - 0x0032 => ExtensionType::SignatureAlgorithmsCert, - 0x0033 => ExtensionType::KeyShare, - 0xFF01 => ExtensionType::RenegotiationInfo, - _ => ExtensionType::Unknown(value), - } + pub const SERVER_NAME: Self = Self(0x0000); + pub const MAX_FRAGMENT_LENGTH: Self = Self(0x0001); + pub const CLIENT_CERTIFICATE_URL: Self = Self(0x0002); + pub const TRUSTED_CA_KEYS: Self = Self(0x0003); + pub const TRUNCATED_HMAC: Self = Self(0x0004); + pub const STATUS_REQUEST: Self = Self(0x0005); + pub const USER_MAPPING: Self = Self(0x0006); + pub const CLIENT_AUTHZ: Self = Self(0x0007); + pub const SERVER_AUTHZ: Self = Self(0x0008); + pub const CERT_TYPE: Self = Self(0x0009); + pub const SUPPORTED_GROUPS: Self = Self(0x000A); + pub const EC_POINT_FORMATS: Self = Self(0x000B); + pub const SRP: Self = Self(0x000C); + pub const SIGNATURE_ALGORITHMS: Self = Self(0x000D); + pub const USE_SRTP: Self = Self(0x000E); + pub const HEARTBEAT: Self = Self(0x000F); + pub const APPLICATION_LAYER_PROTOCOL_NEGOTIATION: Self = Self(0x0010); + pub const STATUS_REQUEST_V2: Self = Self(0x0011); + pub const SIGNED_CERTIFICATE_TIMESTAMP: Self = Self(0x0012); + pub const CLIENT_CERTIFICATE_TYPE: Self = Self(0x0013); + pub const SERVER_CERTIFICATE_TYPE: Self = Self(0x0014); + pub const PADDING: Self = Self(0x0015); + pub const ENCRYPT_THEN_MAC: Self = Self(0x0016); + pub const EXTENDED_MASTER_SECRET: Self = Self(0x0017); + pub const TOKEN_BINDING: Self = Self(0x0018); + pub const CACHED_INFO: Self = Self(0x0019); + pub const SESSION_TICKET: Self = Self(0x0023); + pub const PRE_SHARED_KEY: Self = Self(0x0029); + pub const EARLY_DATA: Self = Self(0x002A); + pub const SUPPORTED_VERSIONS: Self = Self(0x002B); + pub const COOKIE: Self = Self(0x002C); + pub const PSK_KEY_EXCHANGE_MODES: Self = Self(0x002D); + pub const CERTIFICATE_AUTHORITIES: Self = Self(0x002F); + pub const OID_FILTERS: Self = Self(0x0030); + pub const POST_HANDSHAKE_AUTH: Self = Self(0x0031); + pub const SIGNATURE_ALGORITHMS_CERT: Self = Self(0x0032); + pub const KEY_SHARE: Self = Self(0x0033); + pub const RENEGOTIATION_INFO: Self = Self(0xFF01); + + pub const fn from_u16(value: u16) -> Self { + Self(value) } - pub fn as_u16(&self) -> u16 { - match self { - ExtensionType::ServerName => 0x0000, - ExtensionType::MaxFragmentLength => 0x0001, - ExtensionType::ClientCertificateUrl => 0x0002, - ExtensionType::TrustedCaKeys => 0x0003, - ExtensionType::TruncatedHmac => 0x0004, - ExtensionType::StatusRequest => 0x0005, - ExtensionType::UserMapping => 0x0006, - ExtensionType::ClientAuthz => 0x0007, - ExtensionType::ServerAuthz => 0x0008, - ExtensionType::CertType => 0x0009, - ExtensionType::SupportedGroups => 0x000A, - ExtensionType::EcPointFormats => 0x000B, - ExtensionType::Srp => 0x000C, - ExtensionType::SignatureAlgorithms => 0x000D, - ExtensionType::UseSrtp => 0x000E, - ExtensionType::Heartbeat => 0x000F, - ExtensionType::ApplicationLayerProtocolNegotiation => 0x0010, - ExtensionType::StatusRequestV2 => 0x0011, - ExtensionType::SignedCertificateTimestamp => 0x0012, - ExtensionType::ClientCertificateType => 0x0013, - ExtensionType::ServerCertificateType => 0x0014, - ExtensionType::Padding => 0x0015, - ExtensionType::EncryptThenMac => 0x0016, - ExtensionType::ExtendedMasterSecret => 0x0017, - ExtensionType::TokenBinding => 0x0018, - ExtensionType::CachedInfo => 0x0019, - ExtensionType::SessionTicket => 0x0023, - ExtensionType::PreSharedKey => 0x0029, - ExtensionType::EarlyData => 0x002A, - ExtensionType::SupportedVersions => 0x002B, - ExtensionType::Cookie => 0x002C, - ExtensionType::PskKeyExchangeModes => 0x002D, - ExtensionType::CertificateAuthorities => 0x002F, - ExtensionType::OidFilters => 0x0030, - ExtensionType::PostHandshakeAuth => 0x0031, - ExtensionType::SignatureAlgorithmsCert => 0x0032, - ExtensionType::KeyShare => 0x0033, - ExtensionType::RenegotiationInfo => 0xFF01, - ExtensionType::Unknown(value) => *value, - } + pub const fn as_u16(&self) -> u16 { + self.0 + } + + const fn is_unknown(&self) -> bool { + !matches!( + *self, + Self(0x0000..=0x0019 | 0x0023 | 0x0029..=0x002D | 0x002F..=0x0033 | 0xFF01) + ) } pub fn parse(input: &[u8]) -> IResult<&[u8], ExtensionType> { @@ -202,29 +129,116 @@ impl ExtensionType { /// Supported extension types that this implementation handles. pub const fn supported() -> &'static [ExtensionType; 8] { &[ - ExtensionType::SupportedGroups, - ExtensionType::EcPointFormats, - ExtensionType::SignatureAlgorithms, - ExtensionType::UseSrtp, - ExtensionType::EncryptThenMac, - ExtensionType::ExtendedMasterSecret, - ExtensionType::RenegotiationInfo, - ExtensionType::SessionTicket, + ExtensionType::SUPPORTED_GROUPS, + ExtensionType::EC_POINT_FORMATS, + ExtensionType::SIGNATURE_ALGORITHMS, + ExtensionType::USE_SRTP, + ExtensionType::ENCRYPT_THEN_MAC, + ExtensionType::EXTENDED_MASTER_SECRET, + ExtensionType::RENEGOTIATION_INFO, + ExtensionType::SESSION_TICKET, ] } } +impl fmt::Debug for ExtensionType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + ExtensionType::SERVER_NAME => "ServerName", + ExtensionType::MAX_FRAGMENT_LENGTH => "MaxFragmentLength", + ExtensionType::CLIENT_CERTIFICATE_URL => "ClientCertificateUrl", + ExtensionType::TRUSTED_CA_KEYS => "TrustedCaKeys", + ExtensionType::TRUNCATED_HMAC => "TruncatedHmac", + ExtensionType::STATUS_REQUEST => "StatusRequest", + ExtensionType::USER_MAPPING => "UserMapping", + ExtensionType::CLIENT_AUTHZ => "ClientAuthz", + ExtensionType::SERVER_AUTHZ => "ServerAuthz", + ExtensionType::CERT_TYPE => "CertType", + ExtensionType::SUPPORTED_GROUPS => "SupportedGroups", + ExtensionType::EC_POINT_FORMATS => "EcPointFormats", + ExtensionType::SRP => "Srp", + ExtensionType::SIGNATURE_ALGORITHMS => "SignatureAlgorithms", + ExtensionType::USE_SRTP => "UseSrtp", + ExtensionType::HEARTBEAT => "Heartbeat", + ExtensionType::APPLICATION_LAYER_PROTOCOL_NEGOTIATION => { + "ApplicationLayerProtocolNegotiation" + } + ExtensionType::STATUS_REQUEST_V2 => "StatusRequestV2", + ExtensionType::SIGNED_CERTIFICATE_TIMESTAMP => "SignedCertificateTimestamp", + ExtensionType::CLIENT_CERTIFICATE_TYPE => "ClientCertificateType", + ExtensionType::SERVER_CERTIFICATE_TYPE => "ServerCertificateType", + ExtensionType::PADDING => "Padding", + ExtensionType::ENCRYPT_THEN_MAC => "EncryptThenMac", + ExtensionType::EXTENDED_MASTER_SECRET => "ExtendedMasterSecret", + ExtensionType::TOKEN_BINDING => "TokenBinding", + ExtensionType::CACHED_INFO => "CachedInfo", + ExtensionType::SESSION_TICKET => "SessionTicket", + ExtensionType::PRE_SHARED_KEY => "PreSharedKey", + ExtensionType::EARLY_DATA => "EarlyData", + ExtensionType::SUPPORTED_VERSIONS => "SupportedVersions", + ExtensionType::COOKIE => "Cookie", + ExtensionType::PSK_KEY_EXCHANGE_MODES => "PskKeyExchangeModes", + ExtensionType::CERTIFICATE_AUTHORITIES => "CertificateAuthorities", + ExtensionType::OID_FILTERS => "OidFilters", + ExtensionType::POST_HANDSHAKE_AUTH => "PostHandshakeAuth", + ExtensionType::SIGNATURE_ALGORITHMS_CERT => "SignatureAlgorithmsCert", + ExtensionType::KEY_SHARE => "KeyShare", + ExtensionType::RENEGOTIATION_INFO => "RenegotiationInfo", + _ => unreachable!("known DTLS 1.2 extension type missing Debug label"), + }; + + f.write_str(name) + } +} + #[cfg(test)] mod tests { use super::*; use crate::buffer::Buf; const MESSAGE: &[u8] = &[ - 0x00, 0x0A, // ExtensionType::SupportedGroups + 0x00, 0x0A, // ExtensionType::SUPPORTED_GROUPS 0x00, 0x08, // Extension length 0x00, 0x06, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, // Extension data ]; + #[test] + fn extension_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert!(ExtensionType::default().is_unknown()); + } + + #[test] + fn extension_type_wire_roundtrip() { + for extension_type in ExtensionType::supported() { + assert_eq!( + ExtensionType::from_u16(extension_type.as_u16()), + *extension_type + ); + assert!(!extension_type.is_unknown()); + } + + let unknown = ExtensionType::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn extension_type_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", ExtensionType::SUPPORTED_GROUPS), + "SupportedGroups" + ); + assert_eq!( + format!("{:?}", ExtensionType::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + #[test] fn roundtrip() { // Parse the message with base_offset 0 diff --git a/src/dtls12/message/server_hello.rs b/src/dtls12/message/server_hello.rs index 9f1cd980..72fb249b 100644 --- a/src/dtls12/message/server_hello.rs +++ b/src/dtls12/message/server_hello.rs @@ -58,17 +58,17 @@ impl ServerHello { profiles.push(pid); let ext = UseSrtpExtension::new(profiles, ArrayVec::new()); ext.serialize(buf); - ranges.push((ExtensionType::UseSrtp, start, buf.len())); + ranges.push((ExtensionType::USE_SRTP, start, buf.len())); } // Extended Master Secret (mandatory) let start = buf.len(); - ranges.push((ExtensionType::ExtendedMasterSecret, start, start)); + ranges.push((ExtensionType::EXTENDED_MASTER_SECRET, start, start)); // Renegotiation Info (RFC 5746) - empty for initial handshake let start = buf.len(); buf.push(0); // renegotiated_connection length = 0 - ranges.push((ExtensionType::RenegotiationInfo, start, buf.len())); + ranges.push((ExtensionType::RENEGOTIATION_INFO, start, buf.len())); let mut extensions = ExtensionVec::new(); for (t, s, e) in ranges { @@ -202,7 +202,7 @@ mod test { 0xC0, 0x2B, // Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 0x00, // CompressionMethod::NULL 0x00, 0x0C, // Extensions length (12 bytes total: 2 type + 2 length + 8 data) - 0x00, 0x0A, // ExtensionType::SupportedGroups + 0x00, 0x0A, // ExtensionType::SUPPORTED_GROUPS 0x00, 0x08, // Extension data length (8 bytes) 0x00, 0x06, // Extension data 0x00, 0x17, // NamedGroup::SECP256R1 @@ -237,8 +237,9 @@ mod test { let mut message = MESSAGE[..39].to_vec(); message.extend_from_slice(&(count as u16 * 4).to_be_bytes()); for _ in 0..count { - message - .extend_from_slice(&ExtensionType::ExtendedMasterSecret.as_u16().to_be_bytes()); + message.extend_from_slice( + &ExtensionType::EXTENDED_MASTER_SECRET.as_u16().to_be_bytes(), + ); message.extend_from_slice(&0u16.to_be_bytes()); } diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index 15681bbc..2a41f4d1 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -421,27 +421,27 @@ impl State { let mut client_signature_algorithms: Option = None; for ext in ch.extensions { match ext.extension_type { - ExtensionType::UseSrtp => { + ExtensionType::USE_SRTP => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, use_srtp) = UseSrtpExtension::parse(ext_data).map_err(InternalError::from)?; client_srtp_profiles = Some(use_srtp.profiles); } - ExtensionType::ExtendedMasterSecret => { + ExtensionType::EXTENDED_MASTER_SECRET => { client_offers_ems = true; } - ExtensionType::SupportedGroups => { + ExtensionType::SUPPORTED_GROUPS => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, groups) = SupportedGroupsExtension::parse(ext_data).map_err(InternalError::from)?; client_supported_groups = Some(groups.groups); } - ExtensionType::EcPointFormats => { + ExtensionType::EC_POINT_FORMATS => { let ext_data = ext.extension_data(&server.defragment_buffer); let _ = ECPointFormatsExtension::parse(ext_data).map_err(InternalError::from)?; } - ExtensionType::SignatureAlgorithms => { + ExtensionType::SIGNATURE_ALGORITHMS => { let ext_data = ext.extension_data(&server.defragment_buffer); if let Ok((_, sigs)) = SignatureAlgorithmsExtension::parse(ext_data) { client_signature_algorithms = Some(sigs.supported_signature_algorithms); From 4fa3c8ad7c8d742e38f01a9476c98d69252f07ca Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 19:05:47 +0300 Subject: [PATCH 11/18] dtls13: make extension type a newtype --- src/dtls13/client.rs | 32 +-- src/dtls13/message/client_hello.rs | 6 +- src/dtls13/message/encrypted_extensions.rs | 4 +- src/dtls13/message/extension.rs | 284 +++++++++++---------- src/dtls13/message/server_hello.rs | 4 +- src/dtls13/server.rs | 28 +- 6 files changed, 186 insertions(+), 172 deletions(-) diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index 4eb9134e..728fd5d1 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -478,13 +478,13 @@ impl State { if let Some(ref extensions) = server_hello.extensions { for ext in extensions { match ext.extension_type { - ExtensionType::KeyShare => { + ExtensionType::KEY_SHARE => { let ext_data = ext.extension_data(&client.defragment_buffer); if let Ok((_, hrr_ks)) = KeyShareHelloRetryRequest::parse(ext_data) { client.hrr_selected_group = Some(hrr_ks.selected_group); } } - ExtensionType::Cookie => { + ExtensionType::COOKIE => { let ext_data = ext.extension_data(&client.defragment_buffer); parse_cookie_extension(ext_data).map_err(InternalError::from)?; let mut cookie = Buf::new(); @@ -512,7 +512,7 @@ impl State { let mut hrr_version_ok = false; if let Some(ref extensions) = server_hello.extensions { for ext in extensions { - if ext.extension_type == ExtensionType::SupportedVersions { + if ext.extension_type == ExtensionType::SUPPORTED_VERSIONS { let ext_data = ext.extension_data(&client.defragment_buffer); if let Ok((_, sv)) = SupportedVersionsServerHello::parse(ext_data) { hrr_version_ok = sv.selected_version == ProtocolVersion::DTLS1_3; @@ -592,7 +592,7 @@ impl State { for ext in extensions { match ext.extension_type { - ExtensionType::SupportedVersions => { + ExtensionType::SUPPORTED_VERSIONS => { let ext_data = ext.extension_data(&client.defragment_buffer); if let Ok((_, sv)) = SupportedVersionsServerHello::parse(ext_data) { if sv.selected_version == ProtocolVersion::DTLS1_3 { @@ -600,7 +600,7 @@ impl State { } } } - ExtensionType::KeyShare => { + ExtensionType::KEY_SHARE => { let ext_data = ext.extension_data(&client.defragment_buffer); if let Ok((_, ks)) = KeyShareServerHello::parse(ext_data, 0) { // The key_exchange data is at offset 0 within ext_data, but @@ -695,7 +695,7 @@ impl State { // Process extensions for ext in &ee.extensions { - if ext.extension_type == ExtensionType::UseSrtp { + if ext.extension_type == ExtensionType::USE_SRTP { let ext_data = ext.extension_data(&client.defragment_buffer); let (_, use_srtp) = UseSrtpExtension::parse(ext_data).map_err(InternalError::from)?; @@ -1210,7 +1210,7 @@ fn handshake_create_client_hello( sv.serialize(&mut ext_buf); let sv_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SupportedVersions, + extension_type: ExtensionType::SUPPORTED_VERSIONS, extension_data_range: sv_start..sv_end, }); @@ -1221,7 +1221,7 @@ fn handshake_create_client_hello( sg.serialize(&mut ext_buf); let sg_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SupportedGroups, + extension_type: ExtensionType::SUPPORTED_GROUPS, extension_data_range: sg_start..sg_end, }); @@ -1236,7 +1236,7 @@ fn handshake_create_client_hello( ks.serialize(extension_data, &mut ext_buf); let ks_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::KeyShare, + extension_type: ExtensionType::KEY_SHARE, extension_data_range: ks_start..ks_end, }); @@ -1246,7 +1246,7 @@ fn handshake_create_client_hello( sa.serialize(&mut ext_buf); let sa_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SignatureAlgorithms, + extension_type: ExtensionType::SIGNATURE_ALGORITHMS, extension_data_range: sa_start..sa_end, }); @@ -1256,7 +1256,7 @@ fn handshake_create_client_hello( use_srtp.serialize(&mut ext_buf); let srtp_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::UseSrtp, + extension_type: ExtensionType::USE_SRTP, extension_data_range: srtp_start..srtp_end, }); @@ -1267,7 +1267,7 @@ fn handshake_create_client_hello( ext_buf.extend_from_slice(&extension_data[cookie_range]); let cookie_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::Cookie, + extension_type: ExtensionType::COOKIE, extension_data_range: cookie_start..cookie_end, }); } @@ -1296,7 +1296,7 @@ fn handshake_create_client_hello( } let pad_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::Padding, + extension_type: ExtensionType::PADDING, extension_data_range: pad_start..pad_end, }); } @@ -1515,7 +1515,7 @@ fn parse_certificate_request(cr_data: &[u8], base_offset: usize) -> Result let ca_data = &cr_data[pos..pos + ext_data_len]; if ca_data.len() >= 2 { @@ -1624,10 +1624,10 @@ mod tests { key_share.push(0); let mut extensions = Vec::new(); - extensions.extend_from_slice(&ExtensionType::SupportedVersions.as_u16().to_be_bytes()); + extensions.extend_from_slice(&ExtensionType::SUPPORTED_VERSIONS.as_u16().to_be_bytes()); extensions.extend_from_slice(&2u16.to_be_bytes()); extensions.extend_from_slice(&ProtocolVersion::DTLS1_3.as_u16().to_be_bytes()); - extensions.extend_from_slice(&ExtensionType::KeyShare.as_u16().to_be_bytes()); + extensions.extend_from_slice(&ExtensionType::KEY_SHARE.as_u16().to_be_bytes()); extensions.extend_from_slice(&(key_share.len() as u16).to_be_bytes()); extensions.extend_from_slice(&key_share); diff --git a/src/dtls13/message/client_hello.rs b/src/dtls13/message/client_hello.rs index 56d37f07..e4ea3991 100644 --- a/src/dtls13/message/client_hello.rs +++ b/src/dtls13/message/client_hello.rs @@ -340,7 +340,7 @@ mod tests { let mut message = MESSAGE.to_vec(); message.extend_from_slice(&(count as u16 * 4).to_be_bytes()); for _ in 0..count { - message.extend_from_slice(&ExtensionType::Cookie.as_u16().to_be_bytes()); + message.extend_from_slice(&ExtensionType::COOKIE.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); } @@ -360,7 +360,7 @@ mod tests { fn zero_length_extension_vector_rejects_trailing_bytes() { let mut message = MESSAGE.to_vec(); message.extend_from_slice(&0u16.to_be_bytes()); - message.extend_from_slice(&ExtensionType::Cookie.as_u16().to_be_bytes()); + message.extend_from_slice(&ExtensionType::COOKIE.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); assert!( @@ -373,7 +373,7 @@ mod tests { fn underdeclared_extension_vector_rejects_trailing_bytes() { let mut message = MESSAGE.to_vec(); message.extend_from_slice(&4u16.to_be_bytes()); - message.extend_from_slice(&ExtensionType::Cookie.as_u16().to_be_bytes()); + message.extend_from_slice(&ExtensionType::COOKIE.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); message.push(0); diff --git a/src/dtls13/message/encrypted_extensions.rs b/src/dtls13/message/encrypted_extensions.rs index 176f16b8..daa3ead5 100644 --- a/src/dtls13/message/encrypted_extensions.rs +++ b/src/dtls13/message/encrypted_extensions.rs @@ -71,7 +71,7 @@ mod tests { const MESSAGE: &[u8] = &[ 0x00, 0x0C, // Extensions length (12) - 0x00, 0x0A, // ExtensionType::SupportedGroups + 0x00, 0x0A, // ExtensionType::SUPPORTED_GROUPS 0x00, 0x08, // Extension data length 0x00, 0x06, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, // Extension data ]; @@ -92,7 +92,7 @@ mod tests { let count = 2; message.extend_from_slice(&(count as u16 * 4).to_be_bytes()); for _ in 0..count { - message.extend_from_slice(&ExtensionType::Cookie.as_u16().to_be_bytes()); + message.extend_from_slice(&ExtensionType::COOKIE.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); } diff --git a/src/dtls13/message/extension.rs b/src/dtls13/message/extension.rs index 437927d5..ff799332 100644 --- a/src/dtls13/message/extension.rs +++ b/src/dtls13/message/extension.rs @@ -1,6 +1,6 @@ use crate::buffer::Buf; use nom::{IResult, bytes::complete::take, number::complete::be_u16}; -use std::ops::Range; +use std::{fmt, ops::Range}; #[derive(Debug, PartialEq, Eq, Default)] pub struct Extension { @@ -48,142 +48,69 @@ impl Extension { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ExtensionType { - ServerName, - MaxFragmentLength, - ClientCertificateUrl, - TrustedCaKeys, - TruncatedHmac, - StatusRequest, - UserMapping, - ClientAuthz, - ServerAuthz, - CertType, - SupportedGroups, - EcPointFormats, - Srp, - SignatureAlgorithms, - UseSrtp, - Heartbeat, - ApplicationLayerProtocolNegotiation, - StatusRequestV2, - SignedCertificateTimestamp, - ClientCertificateType, - ServerCertificateType, - Padding, - EncryptThenMac, - ExtendedMasterSecret, - TokenBinding, - CachedInfo, - SessionTicket, - PreSharedKey, - EarlyData, - SupportedVersions, - Cookie, - PskKeyExchangeModes, - CertificateAuthorities, - OidFilters, - PostHandshakeAuth, - SignatureAlgorithmsCert, - KeyShare, - RenegotiationInfo, - Unknown(u16), -} +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct ExtensionType(u16); impl Default for ExtensionType { fn default() -> Self { - Self::Unknown(0) + Self(u16::MAX) } } impl ExtensionType { - pub fn from_u16(value: u16) -> Self { - match value { - 0x0000 => ExtensionType::ServerName, - 0x0001 => ExtensionType::MaxFragmentLength, - 0x0002 => ExtensionType::ClientCertificateUrl, - 0x0003 => ExtensionType::TrustedCaKeys, - 0x0004 => ExtensionType::TruncatedHmac, - 0x0005 => ExtensionType::StatusRequest, - 0x0006 => ExtensionType::UserMapping, - 0x0007 => ExtensionType::ClientAuthz, - 0x0008 => ExtensionType::ServerAuthz, - 0x0009 => ExtensionType::CertType, - 0x000A => ExtensionType::SupportedGroups, - 0x000B => ExtensionType::EcPointFormats, - 0x000C => ExtensionType::Srp, - 0x000D => ExtensionType::SignatureAlgorithms, - 0x000E => ExtensionType::UseSrtp, - 0x000F => ExtensionType::Heartbeat, - 0x0010 => ExtensionType::ApplicationLayerProtocolNegotiation, - 0x0011 => ExtensionType::StatusRequestV2, - 0x0012 => ExtensionType::SignedCertificateTimestamp, - 0x0013 => ExtensionType::ClientCertificateType, - 0x0014 => ExtensionType::ServerCertificateType, - 0x0015 => ExtensionType::Padding, - 0x0016 => ExtensionType::EncryptThenMac, - 0x0017 => ExtensionType::ExtendedMasterSecret, - 0x0018 => ExtensionType::TokenBinding, - 0x0019 => ExtensionType::CachedInfo, - 0x0023 => ExtensionType::SessionTicket, - 0x0029 => ExtensionType::PreSharedKey, - 0x002A => ExtensionType::EarlyData, - 0x002B => ExtensionType::SupportedVersions, - 0x002C => ExtensionType::Cookie, - 0x002D => ExtensionType::PskKeyExchangeModes, - 0x002F => ExtensionType::CertificateAuthorities, - 0x0030 => ExtensionType::OidFilters, - 0x0031 => ExtensionType::PostHandshakeAuth, - 0x0032 => ExtensionType::SignatureAlgorithmsCert, - 0x0033 => ExtensionType::KeyShare, - 0xFF01 => ExtensionType::RenegotiationInfo, - _ => ExtensionType::Unknown(value), - } + pub const SERVER_NAME: Self = Self(0x0000); + pub const MAX_FRAGMENT_LENGTH: Self = Self(0x0001); + pub const CLIENT_CERTIFICATE_URL: Self = Self(0x0002); + pub const TRUSTED_CA_KEYS: Self = Self(0x0003); + pub const TRUNCATED_HMAC: Self = Self(0x0004); + pub const STATUS_REQUEST: Self = Self(0x0005); + pub const USER_MAPPING: Self = Self(0x0006); + pub const CLIENT_AUTHZ: Self = Self(0x0007); + pub const SERVER_AUTHZ: Self = Self(0x0008); + pub const CERT_TYPE: Self = Self(0x0009); + pub const SUPPORTED_GROUPS: Self = Self(0x000A); + pub const EC_POINT_FORMATS: Self = Self(0x000B); + pub const SRP: Self = Self(0x000C); + pub const SIGNATURE_ALGORITHMS: Self = Self(0x000D); + pub const USE_SRTP: Self = Self(0x000E); + pub const HEARTBEAT: Self = Self(0x000F); + pub const APPLICATION_LAYER_PROTOCOL_NEGOTIATION: Self = Self(0x0010); + pub const STATUS_REQUEST_V2: Self = Self(0x0011); + pub const SIGNED_CERTIFICATE_TIMESTAMP: Self = Self(0x0012); + pub const CLIENT_CERTIFICATE_TYPE: Self = Self(0x0013); + pub const SERVER_CERTIFICATE_TYPE: Self = Self(0x0014); + pub const PADDING: Self = Self(0x0015); + pub const ENCRYPT_THEN_MAC: Self = Self(0x0016); + pub const EXTENDED_MASTER_SECRET: Self = Self(0x0017); + pub const TOKEN_BINDING: Self = Self(0x0018); + pub const CACHED_INFO: Self = Self(0x0019); + pub const SESSION_TICKET: Self = Self(0x0023); + pub const PRE_SHARED_KEY: Self = Self(0x0029); + pub const EARLY_DATA: Self = Self(0x002A); + pub const SUPPORTED_VERSIONS: Self = Self(0x002B); + pub const COOKIE: Self = Self(0x002C); + pub const PSK_KEY_EXCHANGE_MODES: Self = Self(0x002D); + pub const CERTIFICATE_AUTHORITIES: Self = Self(0x002F); + pub const OID_FILTERS: Self = Self(0x0030); + pub const POST_HANDSHAKE_AUTH: Self = Self(0x0031); + pub const SIGNATURE_ALGORITHMS_CERT: Self = Self(0x0032); + pub const KEY_SHARE: Self = Self(0x0033); + pub const RENEGOTIATION_INFO: Self = Self(0xFF01); + + pub const fn from_u16(value: u16) -> Self { + Self(value) } - pub fn as_u16(&self) -> u16 { - match self { - ExtensionType::ServerName => 0x0000, - ExtensionType::MaxFragmentLength => 0x0001, - ExtensionType::ClientCertificateUrl => 0x0002, - ExtensionType::TrustedCaKeys => 0x0003, - ExtensionType::TruncatedHmac => 0x0004, - ExtensionType::StatusRequest => 0x0005, - ExtensionType::UserMapping => 0x0006, - ExtensionType::ClientAuthz => 0x0007, - ExtensionType::ServerAuthz => 0x0008, - ExtensionType::CertType => 0x0009, - ExtensionType::SupportedGroups => 0x000A, - ExtensionType::EcPointFormats => 0x000B, - ExtensionType::Srp => 0x000C, - ExtensionType::SignatureAlgorithms => 0x000D, - ExtensionType::UseSrtp => 0x000E, - ExtensionType::Heartbeat => 0x000F, - ExtensionType::ApplicationLayerProtocolNegotiation => 0x0010, - ExtensionType::StatusRequestV2 => 0x0011, - ExtensionType::SignedCertificateTimestamp => 0x0012, - ExtensionType::ClientCertificateType => 0x0013, - ExtensionType::ServerCertificateType => 0x0014, - ExtensionType::Padding => 0x0015, - ExtensionType::EncryptThenMac => 0x0016, - ExtensionType::ExtendedMasterSecret => 0x0017, - ExtensionType::TokenBinding => 0x0018, - ExtensionType::CachedInfo => 0x0019, - ExtensionType::SessionTicket => 0x0023, - ExtensionType::PreSharedKey => 0x0029, - ExtensionType::EarlyData => 0x002A, - ExtensionType::SupportedVersions => 0x002B, - ExtensionType::Cookie => 0x002C, - ExtensionType::PskKeyExchangeModes => 0x002D, - ExtensionType::CertificateAuthorities => 0x002F, - ExtensionType::OidFilters => 0x0030, - ExtensionType::PostHandshakeAuth => 0x0031, - ExtensionType::SignatureAlgorithmsCert => 0x0032, - ExtensionType::KeyShare => 0x0033, - ExtensionType::RenegotiationInfo => 0xFF01, - ExtensionType::Unknown(value) => *value, - } + pub const fn as_u16(&self) -> u16 { + self.0 + } + + const fn is_unknown(&self) -> bool { + !matches!( + *self, + Self(0x0000..=0x0019 | 0x0023 | 0x0029..=0x002D | 0x002F..=0x0033 | 0xFF01) + ) } pub fn parse(input: &[u8]) -> IResult<&[u8], ExtensionType> { @@ -199,27 +126,114 @@ impl ExtensionType { /// Supported extension types that this DTLS 1.3 implementation handles. pub const fn supported() -> &'static [ExtensionType; 6] { &[ - ExtensionType::SupportedVersions, - ExtensionType::SupportedGroups, - ExtensionType::SignatureAlgorithms, - ExtensionType::KeyShare, - ExtensionType::UseSrtp, - ExtensionType::Cookie, + ExtensionType::SUPPORTED_VERSIONS, + ExtensionType::SUPPORTED_GROUPS, + ExtensionType::SIGNATURE_ALGORITHMS, + ExtensionType::KEY_SHARE, + ExtensionType::USE_SRTP, + ExtensionType::COOKIE, ] } } +impl fmt::Debug for ExtensionType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + ExtensionType::SERVER_NAME => "ServerName", + ExtensionType::MAX_FRAGMENT_LENGTH => "MaxFragmentLength", + ExtensionType::CLIENT_CERTIFICATE_URL => "ClientCertificateUrl", + ExtensionType::TRUSTED_CA_KEYS => "TrustedCaKeys", + ExtensionType::TRUNCATED_HMAC => "TruncatedHmac", + ExtensionType::STATUS_REQUEST => "StatusRequest", + ExtensionType::USER_MAPPING => "UserMapping", + ExtensionType::CLIENT_AUTHZ => "ClientAuthz", + ExtensionType::SERVER_AUTHZ => "ServerAuthz", + ExtensionType::CERT_TYPE => "CertType", + ExtensionType::SUPPORTED_GROUPS => "SupportedGroups", + ExtensionType::EC_POINT_FORMATS => "EcPointFormats", + ExtensionType::SRP => "Srp", + ExtensionType::SIGNATURE_ALGORITHMS => "SignatureAlgorithms", + ExtensionType::USE_SRTP => "UseSrtp", + ExtensionType::HEARTBEAT => "Heartbeat", + ExtensionType::APPLICATION_LAYER_PROTOCOL_NEGOTIATION => { + "ApplicationLayerProtocolNegotiation" + } + ExtensionType::STATUS_REQUEST_V2 => "StatusRequestV2", + ExtensionType::SIGNED_CERTIFICATE_TIMESTAMP => "SignedCertificateTimestamp", + ExtensionType::CLIENT_CERTIFICATE_TYPE => "ClientCertificateType", + ExtensionType::SERVER_CERTIFICATE_TYPE => "ServerCertificateType", + ExtensionType::PADDING => "Padding", + ExtensionType::ENCRYPT_THEN_MAC => "EncryptThenMac", + ExtensionType::EXTENDED_MASTER_SECRET => "ExtendedMasterSecret", + ExtensionType::TOKEN_BINDING => "TokenBinding", + ExtensionType::CACHED_INFO => "CachedInfo", + ExtensionType::SESSION_TICKET => "SessionTicket", + ExtensionType::PRE_SHARED_KEY => "PreSharedKey", + ExtensionType::EARLY_DATA => "EarlyData", + ExtensionType::SUPPORTED_VERSIONS => "SupportedVersions", + ExtensionType::COOKIE => "Cookie", + ExtensionType::PSK_KEY_EXCHANGE_MODES => "PskKeyExchangeModes", + ExtensionType::CERTIFICATE_AUTHORITIES => "CertificateAuthorities", + ExtensionType::OID_FILTERS => "OidFilters", + ExtensionType::POST_HANDSHAKE_AUTH => "PostHandshakeAuth", + ExtensionType::SIGNATURE_ALGORITHMS_CERT => "SignatureAlgorithmsCert", + ExtensionType::KEY_SHARE => "KeyShare", + ExtensionType::RENEGOTIATION_INFO => "RenegotiationInfo", + _ => unreachable!("known DTLS 1.3 extension type missing Debug label"), + }; + + f.write_str(name) + } +} + #[cfg(test)] mod tests { use super::*; use crate::buffer::Buf; const MESSAGE: &[u8] = &[ - 0x00, 0x0A, // ExtensionType::SupportedGroups + 0x00, 0x0A, // ExtensionType::SUPPORTED_GROUPS 0x00, 0x08, // Extension length 0x00, 0x06, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, // Extension data ]; + #[test] + fn extension_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert!(ExtensionType::default().is_unknown()); + } + + #[test] + fn extension_type_wire_roundtrip() { + for extension_type in ExtensionType::supported() { + assert_eq!( + ExtensionType::from_u16(extension_type.as_u16()), + *extension_type + ); + assert!(!extension_type.is_unknown()); + } + + let unknown = ExtensionType::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn extension_type_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", ExtensionType::SUPPORTED_GROUPS), + "SupportedGroups" + ); + assert_eq!( + format!("{:?}", ExtensionType::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + #[test] fn roundtrip() { // Parse the message with base_offset 0 diff --git a/src/dtls13/message/server_hello.rs b/src/dtls13/message/server_hello.rs index 2886ca8d..a89a2fae 100644 --- a/src/dtls13/message/server_hello.rs +++ b/src/dtls13/message/server_hello.rs @@ -158,7 +158,7 @@ mod test { 0x13, 0x01, // Dtls13CipherSuite::AES_128_GCM_SHA256 0x00, // CompressionMethod::NULL 0x00, 0x0C, // Extensions length (12 bytes) - 0x00, 0x0A, // ExtensionType::SupportedGroups + 0x00, 0x0A, // ExtensionType::SUPPORTED_GROUPS 0x00, 0x08, // Extension data length (8 bytes) 0x00, 0x06, // Extension data 0x00, 0x17, // NamedGroup::SECP256R1 @@ -213,7 +213,7 @@ mod test { let count = 2; message.extend_from_slice(&(count as u16 * 4).to_be_bytes()); for _ in 0..count { - message.extend_from_slice(&ExtensionType::Cookie.as_u16().to_be_bytes()); + message.extend_from_slice(&ExtensionType::COOKIE.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); } diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 35caecaa..1af45f55 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -448,7 +448,7 @@ impl State { for ext in &client_hello.extensions { match ext.extension_type { - ExtensionType::SupportedVersions => { + ExtensionType::SUPPORTED_VERSIONS => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, sv) = SupportedVersionsClientHello::parse(ext_data) .map_err(InternalError::from)?; @@ -458,7 +458,7 @@ impl State { } } } - ExtensionType::KeyShare => { + ExtensionType::KEY_SHARE => { let ext_data = ext.extension_data(&server.defragment_buffer); let ext_data_start = ext.extension_data_range.start; let (_, ks) = KeyShareClientHello::parse(ext_data, ext_data_start) @@ -473,24 +473,24 @@ impl State { } client_key_shares = Some(entries); } - ExtensionType::SupportedGroups => { + ExtensionType::SUPPORTED_GROUPS => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, sg) = SupportedGroupsExtension::parse(ext_data).map_err(InternalError::from)?; client_supported_groups = Some(sg.groups); } - ExtensionType::SignatureAlgorithms => { + ExtensionType::SIGNATURE_ALGORITHMS => { let ext_data = ext.extension_data(&server.defragment_buffer); // Parse but we don't currently filter by signature algorithms let _ = SignatureAlgorithmsExtension::parse(ext_data); } - ExtensionType::UseSrtp => { + ExtensionType::USE_SRTP => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, use_srtp) = UseSrtpExtension::parse(ext_data).map_err(InternalError::from)?; client_srtp_profiles = Some(use_srtp.profiles); } - ExtensionType::Cookie => { + ExtensionType::COOKIE => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, cookie) = parse_cookie_extension(ext_data).map_err(InternalError::from)?; @@ -1288,7 +1288,7 @@ fn send_hello_retry_request( sv.serialize(&mut ext_buf); let sv_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SupportedVersions, + extension_type: ExtensionType::SUPPORTED_VERSIONS, extension_data_range: sv_start..sv_end, }); @@ -1301,7 +1301,7 @@ fn send_hello_retry_request( hrr_ks.serialize(&mut ext_buf); let ks_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::KeyShare, + extension_type: ExtensionType::KEY_SHARE, extension_data_range: ks_start..ks_end, }); } @@ -1313,7 +1313,7 @@ fn send_hello_retry_request( ext_buf.extend_from_slice(cookie); let cookie_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::Cookie, + extension_type: ExtensionType::COOKIE, extension_data_range: cookie_start..cookie_end, }); @@ -1360,7 +1360,7 @@ fn handshake_create_server_hello( sv.serialize(&mut ext_buf); let sv_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SupportedVersions, + extension_type: ExtensionType::SUPPORTED_VERSIONS, extension_data_range: sv_start..sv_end, }); @@ -1375,7 +1375,7 @@ fn handshake_create_server_hello( ks.serialize(extension_data, &mut ext_buf); let ks_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::KeyShare, + extension_type: ExtensionType::KEY_SHARE, extension_data_range: ks_start..ks_end, }); @@ -1409,7 +1409,7 @@ fn handshake_create_encrypted_extensions( use_srtp.serialize(&mut ext_buf); let srtp_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::UseSrtp, + extension_type: ExtensionType::USE_SRTP, extension_data_range: srtp_start..srtp_end, }); } @@ -1445,7 +1445,7 @@ fn handshake_create_certificate_request(body: &mut Buf) -> Result<(), Error> { sa.serialize(&mut ext_buf); let sa_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SignatureAlgorithms, + extension_type: ExtensionType::SIGNATURE_ALGORITHMS, extension_data_range: sa_start..sa_end, }); @@ -1455,7 +1455,7 @@ fn handshake_create_certificate_request(body: &mut Buf) -> Result<(), Error> { serialize_certificate_authorities(&cas, &[], &mut ext_buf); let ca_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::CertificateAuthorities, + extension_type: ExtensionType::CERTIFICATE_AUTHORITIES, extension_data_range: ca_start..ca_end, }); From 8d74b9aa7bca82dd7b274bbfb95dec6d73535f71 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 19:07:58 +0300 Subject: [PATCH 12/18] dtls12: make message type a newtype --- src/dtls12/client.rs | 48 ++++---- src/dtls12/message/handshake.rs | 187 ++++++++++++++++++++------------ src/dtls12/server.rs | 40 +++---- 3 files changed, 161 insertions(+), 114 deletions(-) diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs index e5199ac8..c16fcbee 100644 --- a/src/dtls12/client.rs +++ b/src/dtls12/client.rs @@ -354,7 +354,7 @@ impl State { client .engine - .create_handshake(MessageType::ClientHello, |body, engine| { + .create_handshake(MessageType::CLIENT_HELLO, |body, engine| { handshake_create_client_hello( body, engine, @@ -377,7 +377,7 @@ impl State { fn await_hello_verify_request(self, client: &mut Client) -> Result { let has_hello = client .engine - .has_complete_handshake(MessageType::ServerHello); + .has_complete_handshake(MessageType::SERVER_HELLO); // Got ServerHello, skip HelloVerifyRequest if has_hello { @@ -385,7 +385,7 @@ impl State { } let maybe = client.engine.next_handshake( - MessageType::HelloVerifyRequest, + MessageType::HELLO_VERIFY_REQUEST, &mut client.defragment_buffer, )?; @@ -429,7 +429,7 @@ impl State { fn await_server_hello(self, client: &mut Client) -> Result { let maybe = client .engine - .next_handshake(MessageType::ServerHello, &mut client.defragment_buffer)?; + .next_handshake(MessageType::SERVER_HELLO, &mut client.defragment_buffer)?; let Some(handshake) = maybe else { // Stay in same state @@ -554,7 +554,7 @@ impl State { fn await_certificate(self, client: &mut Client) -> Result { let maybe = client .engine - .next_handshake(MessageType::Certificate, &mut client.defragment_buffer)?; + .next_handshake(MessageType::CERTIFICATE, &mut client.defragment_buffer)?; let Some(ref handshake) = maybe else { // Stay in same state @@ -610,7 +610,7 @@ impl State { fn await_server_key_exchange_ecdhe(self, client: &mut Client) -> Result { let maybe = client.engine.next_handshake( - MessageType::ServerKeyExchange, + MessageType::SERVER_KEY_EXCHANGE, &mut client.defragment_buffer, )?; @@ -747,13 +747,13 @@ impl State { // If the server skipped ServerKeyExchange (no hint), go straight to ServerHelloDone let has_done = client .engine - .has_complete_handshake(MessageType::ServerHelloDone); + .has_complete_handshake(MessageType::SERVER_HELLO_DONE); if has_done { return Ok(Self::AwaitServerHelloDone); } let maybe = client.engine.next_handshake( - MessageType::ServerKeyExchange, + MessageType::SERVER_KEY_EXCHANGE, &mut client.defragment_buffer, )?; @@ -790,14 +790,14 @@ impl State { fn await_certificate_request(self, client: &mut Client) -> Result { let has_done = client .engine - .has_complete_handshake(MessageType::ServerHelloDone); + .has_complete_handshake(MessageType::SERVER_HELLO_DONE); if has_done { return Ok(Self::AwaitServerHelloDone); } let maybe = client.engine.next_handshake( - MessageType::CertificateRequest, + MessageType::CERTIFICATE_REQUEST, &mut client.defragment_buffer, )?; @@ -838,9 +838,10 @@ impl State { } fn await_server_hello_done(self, client: &mut Client) -> Result { - let maybe = client - .engine - .next_handshake(MessageType::ServerHelloDone, &mut client.defragment_buffer)?; + let maybe = client.engine.next_handshake( + MessageType::SERVER_HELLO_DONE, + &mut client.defragment_buffer, + )?; let Some(handshake) = maybe else { // stay in same state @@ -891,7 +892,7 @@ impl State { // Now use the engine with the stored data client .engine - .create_handshake(MessageType::Certificate, handshake_create_certificate)?; + .create_handshake(MessageType::CERTIFICATE, handshake_create_certificate)?; Ok(Self::SendClientKeyExchange) } @@ -906,7 +907,7 @@ impl State { // Send client key exchange message client.engine.create_handshake( - MessageType::ClientKeyExchange, + MessageType::CLIENT_KEY_EXCHANGE, handshake_create_client_key_exchange, )?; @@ -936,7 +937,7 @@ impl State { // Send the certificate verify message client.engine.create_handshake( - MessageType::CertificateVerify, + MessageType::CERTIFICATE_VERIFY, handshake_create_certificate_verify, )?; @@ -1031,7 +1032,7 @@ impl State { client .engine - .create_handshake(MessageType::Finished, |body, engine| { + .create_handshake(MessageType::FINISHED, |body, engine| { // Calculate verify data for Finished message using PRF let verify_data = engine.generate_verify_data(true)?; @@ -1065,15 +1066,16 @@ impl State { } fn await_new_session_ticket(self, client: &mut Client) -> Result { - let has_finished = client.engine.has_complete_handshake(MessageType::Finished); + let has_finished = client.engine.has_complete_handshake(MessageType::FINISHED); if has_finished { return Ok(Self::AwaitFinished); } - let maybe = client - .engine - .next_handshake(MessageType::NewSessionTicket, &mut client.defragment_buffer)?; + let maybe = client.engine.next_handshake( + MessageType::NEW_SESSION_TICKET, + &mut client.defragment_buffer, + )?; let Some(handshake) = maybe else { // Stay in same state @@ -1099,7 +1101,7 @@ impl State { let maybe = client .engine - .next_handshake(MessageType::Finished, &mut client.defragment_buffer)?; + .next_handshake(MessageType::FINISHED, &mut client.defragment_buffer)?; if maybe.is_none() { // stay in same state @@ -1429,7 +1431,7 @@ mod tests { client .engine .parse_packet(&epoch0_handshake_packet( - MessageType::Certificate, + MessageType::CERTIFICATE, 0, &[0, 0, 0], )) diff --git a/src/dtls12/message/handshake.rs b/src/dtls12/message/handshake.rs index 4b9f8f46..66ff43c1 100644 --- a/src/dtls12/message/handshake.rs +++ b/src/dtls12/message/handshake.rs @@ -1,3 +1,4 @@ +use std::fmt; use std::ops::Range; use std::sync::atomic::{AtomicBool, Ordering}; @@ -188,7 +189,7 @@ impl Handshake { let (rest, body) = Body::parse(buffer, 0, first_handshake.header.msg_type, cipher_suite)?; - if !rest.is_empty() && first_handshake.header.msg_type == MessageType::Finished { + if !rest.is_empty() && first_handshake.header.msg_type == MessageType::FINISHED { debug!("Defragmentation failed. Body::parse() did not consume the entire buffer"); return Err(crate::InternalError::parse_incomplete()); } @@ -268,10 +269,10 @@ impl Handshake { let qualifies = matches!( self.header.msg_type, - MessageType::ClientHello | // flight 1 and 3 - MessageType::HelloVerifyRequest | // flight 2 - MessageType::ServerHelloDone | // flight 4 - MessageType::ClientKeyExchange // flight 5 + MessageType::CLIENT_HELLO | // flight 1 and 3 + MessageType::HELLO_VERIFY_REQUEST | // flight 2 + MessageType::SERVER_HELLO_DONE | // flight 4 + MessageType::CLIENT_KEY_EXCHANGE // flight 5 ); qualifies.then_some(self.header.message_seq) @@ -286,64 +287,40 @@ impl Handshake { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MessageType { - HelloRequest, // empty - ClientHello, - HelloVerifyRequest, - ServerHello, - Certificate, - ServerKeyExchange, - CertificateRequest, - ServerHelloDone, // empty - CertificateVerify, - ClientKeyExchange, - NewSessionTicket, - Finished, - Unknown(u8), -} +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct MessageType(u8); impl Default for MessageType { fn default() -> Self { - Self::Unknown(0) + Self(u8::MAX) } } impl MessageType { - pub fn from_u8(value: u8) -> Self { - match value { - 0 => MessageType::HelloRequest, // empty - 1 => MessageType::ClientHello, - 3 => MessageType::HelloVerifyRequest, - 2 => MessageType::ServerHello, - 11 => MessageType::Certificate, - 12 => MessageType::ServerKeyExchange, - 13 => MessageType::CertificateRequest, - 14 => MessageType::ServerHelloDone, // empty - 15 => MessageType::CertificateVerify, - 16 => MessageType::ClientKeyExchange, - 4 => MessageType::NewSessionTicket, - 20 => MessageType::Finished, - _ => MessageType::Unknown(value), - } + pub const HELLO_REQUEST: Self = Self(0); + pub const CLIENT_HELLO: Self = Self(1); + pub const SERVER_HELLO: Self = Self(2); + pub const HELLO_VERIFY_REQUEST: Self = Self(3); + pub const NEW_SESSION_TICKET: Self = Self(4); + pub const CERTIFICATE: Self = Self(11); + pub const SERVER_KEY_EXCHANGE: Self = Self(12); + pub const CERTIFICATE_REQUEST: Self = Self(13); + pub const SERVER_HELLO_DONE: Self = Self(14); + pub const CERTIFICATE_VERIFY: Self = Self(15); + pub const CLIENT_KEY_EXCHANGE: Self = Self(16); + pub const FINISHED: Self = Self(20); + + pub const fn from_u8(value: u8) -> Self { + Self(value) } - pub fn as_u8(&self) -> u8 { - match self { - MessageType::HelloRequest => 0, - MessageType::ClientHello => 1, - MessageType::HelloVerifyRequest => 3, - MessageType::ServerHello => 2, - MessageType::Certificate => 11, - MessageType::ServerKeyExchange => 12, - MessageType::CertificateRequest => 13, - MessageType::ServerHelloDone => 14, - MessageType::CertificateVerify => 15, - MessageType::ClientKeyExchange => 16, - MessageType::NewSessionTicket => 4, - MessageType::Finished => 20, - MessageType::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + const fn is_unknown(&self) -> bool { + !matches!(*self, Self(0..=4 | 11..=16 | 20)) } pub fn parse(input: &[u8]) -> IResult<&[u8], MessageType> { @@ -352,7 +329,10 @@ impl MessageType { } pub fn epoch(&self) -> u16 { - if matches!(self, MessageType::NewSessionTicket | MessageType::Finished) { + if matches!( + *self, + MessageType::NEW_SESSION_TICKET | MessageType::FINISHED + ) { 1 } else { 0 @@ -360,6 +340,32 @@ impl MessageType { } } +impl fmt::Debug for MessageType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + MessageType::HELLO_REQUEST => "HelloRequest", + MessageType::CLIENT_HELLO => "ClientHello", + MessageType::HELLO_VERIFY_REQUEST => "HelloVerifyRequest", + MessageType::SERVER_HELLO => "ServerHello", + MessageType::CERTIFICATE => "Certificate", + MessageType::SERVER_KEY_EXCHANGE => "ServerKeyExchange", + MessageType::CERTIFICATE_REQUEST => "CertificateRequest", + MessageType::SERVER_HELLO_DONE => "ServerHelloDone", + MessageType::CERTIFICATE_VERIFY => "CertificateVerify", + MessageType::CLIENT_KEY_EXCHANGE => "ClientKeyExchange", + MessageType::NEW_SESSION_TICKET => "NewSessionTicket", + MessageType::FINISHED => "Finished", + _ => unreachable!("known DTLS 1.2 handshake message type missing Debug label"), + }; + + f.write_str(name) + } +} + #[derive(Debug, PartialEq, Eq)] #[allow(clippy::large_enum_variant)] pub enum Body { @@ -393,24 +399,24 @@ impl Body { c: Option, ) -> IResult<&[u8], Body> { match m { - MessageType::HelloRequest => Ok((input, Body::HelloRequest)), - MessageType::ClientHello => { + MessageType::HELLO_REQUEST => Ok((input, Body::HelloRequest)), + MessageType::CLIENT_HELLO => { let (input, client_hello) = ClientHello::parse(input, base_offset)?; Ok((input, Body::ClientHello(client_hello))) } - MessageType::HelloVerifyRequest => { + MessageType::HELLO_VERIFY_REQUEST => { let (input, hello_verify_request) = HelloVerifyRequest::parse(input)?; Ok((input, Body::HelloVerifyRequest(hello_verify_request))) } - MessageType::ServerHello => { + MessageType::SERVER_HELLO => { let (input, server_hello) = ServerHello::parse(input, base_offset)?; Ok((input, Body::ServerHello(server_hello))) } - MessageType::Certificate => { + MessageType::CERTIFICATE => { let (input, certificate) = Certificate::parse(input, base_offset)?; Ok((input, Body::Certificate(certificate))) } - MessageType::ServerKeyExchange => { + MessageType::SERVER_KEY_EXCHANGE => { let cipher_suite = c.ok_or_else(|| Err::Failure(Error::new(input, ErrorKind::Fail)))?; let algo = cipher_suite.as_key_exchange_algorithm(); @@ -418,16 +424,16 @@ impl Body { ServerKeyExchange::parse(input, base_offset, algo)?; Ok((input, Body::ServerKeyExchange(server_key_exchange))) } - MessageType::CertificateRequest => { + MessageType::CERTIFICATE_REQUEST => { let (input, certificate_request) = CertificateRequest::parse(input, base_offset)?; Ok((input, Body::CertificateRequest(certificate_request))) } - MessageType::ServerHelloDone => Ok((input, Body::ServerHelloDone)), - MessageType::CertificateVerify => { + MessageType::SERVER_HELLO_DONE => Ok((input, Body::ServerHelloDone)), + MessageType::CERTIFICATE_VERIFY => { let (input, certificate_verify) = CertificateVerify::parse(input, base_offset)?; Ok((input, Body::CertificateVerify(certificate_verify))) } - MessageType::ClientKeyExchange => { + MessageType::CLIENT_KEY_EXCHANGE => { let cipher_suite = c.ok_or_else(|| Err::Failure(Error::new(input, ErrorKind::Fail)))?; let algo = cipher_suite.as_key_exchange_algorithm(); @@ -435,18 +441,18 @@ impl Body { ClientKeyExchange::parse(input, base_offset, algo)?; Ok((input, Body::ClientKeyExchange(client_key_exchange))) } - MessageType::NewSessionTicket => { + MessageType::NEW_SESSION_TICKET => { // Treat ticket as opaque per RFC 5077: lifetime_hint(4) + ticket (opaque vector) let range = base_offset..(base_offset + input.len()); Ok((&[], Body::NewSessionTicket(range))) } - MessageType::Finished => { + MessageType::FINISHED => { let cipher_suite = c.ok_or_else(|| Err::Failure(Error::new(input, ErrorKind::Fail)))?; let (input, finished) = Finished::parse(input, cipher_suite)?; Ok((input, Body::Finished(finished))) } - MessageType::Unknown(value) => Ok((input, Body::Unknown(value))), + _ => Ok((input, Body::Unknown(m.as_u8()))), } } @@ -513,7 +519,7 @@ mod tests { use crate::dtls12::message::SessionId; const MESSAGE: &[u8] = &[ - 0x01, // MessageType::ClientHello + 0x01, // MessageType::CLIENT_HELLO 0x00, 0x00, 0x2E, // length 0x00, 0x00, // message_seq 0x00, 0x00, 0x00, // fragment_offset @@ -535,11 +541,48 @@ mod tests { 0x00, // CompressionMethod::NULL ]; + #[test] + fn message_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert!(MessageType::default().is_unknown()); + } + + #[test] + fn message_type_wire_roundtrip() { + for message_type in [ + MessageType::HELLO_REQUEST, + MessageType::CLIENT_HELLO, + MessageType::SERVER_HELLO, + MessageType::HELLO_VERIFY_REQUEST, + MessageType::NEW_SESSION_TICKET, + MessageType::CERTIFICATE, + MessageType::SERVER_KEY_EXCHANGE, + MessageType::CERTIFICATE_REQUEST, + MessageType::SERVER_HELLO_DONE, + MessageType::CERTIFICATE_VERIFY, + MessageType::CLIENT_KEY_EXCHANGE, + MessageType::FINISHED, + ] { + assert_eq!(MessageType::from_u8(message_type.as_u8()), message_type); + assert!(!message_type.is_unknown()); + } + + let unknown = MessageType::from_u8(0xFF); + assert_eq!(unknown.as_u8(), 0xFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn message_type_debug_stays_enum_like() { + assert_eq!(format!("{:?}", MessageType::CLIENT_HELLO), "ClientHello"); + assert_eq!(format!("{:?}", MessageType::from_u8(0xFF)), "Unknown(255)"); + } + #[test] fn handshake_size() { let h = Handshake::new( // ServerHelloDone has a 0 sized body. - MessageType::ServerHelloDone, + MessageType::SERVER_HELLO_DONE, 0, 0, 0, @@ -576,7 +619,7 @@ mod tests { ); let handshake = Handshake::new( - MessageType::ClientHello, + MessageType::CLIENT_HELLO, 0x2E, 0, 0, @@ -619,7 +662,7 @@ mod tests { ); let handshake = Handshake::new( - MessageType::ClientHello, + MessageType::CLIENT_HELLO, 46, 0, 0, diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index 2a41f4d1..19d615cb 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -318,7 +318,7 @@ impl State { fn await_client_hello(self, server: &mut Server) -> Result { let maybe = server .engine - .next_handshake(MessageType::ClientHello, &mut server.defragment_buffer)?; + .next_handshake(MessageType::CLIENT_HELLO, &mut server.defragment_buffer)?; let Some(handshake) = maybe else { // Stay in same state @@ -370,15 +370,16 @@ impl State { let cookie = compute_cookie(hmac_provider, &server.cookie_secret, client_random)?; // Start/restart flight timer for server Flight 2 (HelloVerifyRequest) server.engine.flight_begin(2); - server - .engine - .create_handshake(MessageType::HelloVerifyRequest, |body, _engine| { + server.engine.create_handshake( + MessageType::HELLO_VERIFY_REQUEST, + |body, _engine| { // RFC 6347 4.2.1: The server_version field in the HelloVerifyRequest // message MUST be set to DTLS 1.0 let hvr = HelloVerifyRequest::new(ProtocolVersion::DTLS1_0, cookie); hvr.serialize(body); Ok(()) - })?; + }, + )?; // The HelloVerifyRequest exchange is stateless per RFC 6347. // Reset all handshake state so the next ClientHello (with cookie) is processed fresh. @@ -505,7 +506,7 @@ impl State { // Send ServerHello server .engine - .create_handshake(MessageType::ServerHello, move |body, engine| { + .create_handshake(MessageType::SERVER_HELLO, move |body, engine| { handshake_create_server_hello( body, engine, @@ -533,7 +534,7 @@ impl State { server .engine - .create_handshake(MessageType::Certificate, handshake_create_certificate)?; + .create_handshake(MessageType::CERTIFICATE, handshake_create_certificate)?; Ok(Self::SendServerKeyExchange) } @@ -605,7 +606,7 @@ impl State { server .engine - .create_handshake(MessageType::ServerKeyExchange, |body, engine| { + .create_handshake(MessageType::SERVER_KEY_EXCHANGE, |body, engine| { handshake_create_server_key_exchange( body, engine, @@ -635,12 +636,13 @@ impl State { return Ok(Self::SendServerHelloDone); }; - server - .engine - .create_handshake(MessageType::ServerKeyExchange, move |body, _engine| { + server.engine.create_handshake( + MessageType::SERVER_KEY_EXCHANGE, + move |body, _engine| { PskParams::serialize_from_bytes(&hint, body); Ok(()) - })?; + }, + )?; // PSK never sends CertificateRequest Ok(Self::SendServerHelloDone) @@ -658,7 +660,7 @@ impl State { server .engine - .create_handshake(MessageType::CertificateRequest, move |body, _| { + .create_handshake(MessageType::CERTIFICATE_REQUEST, move |body, _| { handshake_serialize_certificate_request(body, &sig_algs) })?; @@ -670,7 +672,7 @@ impl State { server .engine - .create_handshake(MessageType::ServerHelloDone, |_, _| Ok(()))?; + .create_handshake(MessageType::SERVER_HELLO_DONE, |_, _| Ok(()))?; let cs = server.engine.cipher_suite().ok_or(Error::InvalidState( crate::InvalidStateError::NoCipherSuiteSelected, @@ -691,7 +693,7 @@ impl State { fn await_certificate(self, server: &mut Server) -> Result { let maybe = server .engine - .next_handshake(MessageType::Certificate, &mut server.defragment_buffer)?; + .next_handshake(MessageType::CERTIFICATE, &mut server.defragment_buffer)?; let Some(ref handshake) = maybe else { // Stay in same state @@ -737,7 +739,7 @@ impl State { fn await_client_key_exchange(self, server: &mut Server) -> Result { let maybe = server.engine.next_handshake( - MessageType::ClientKeyExchange, + MessageType::CLIENT_KEY_EXCHANGE, &mut server.defragment_buffer, )?; @@ -893,7 +895,7 @@ impl State { let data = server.engine.transcript().to_buf(); let maybe = server.engine.next_handshake( - MessageType::CertificateVerify, + MessageType::CERTIFICATE_VERIFY, &mut server.defragment_buffer, )?; @@ -974,7 +976,7 @@ impl State { let maybe = server .engine - .next_handshake(MessageType::Finished, &mut server.defragment_buffer)?; + .next_handshake(MessageType::FINISHED, &mut server.defragment_buffer)?; if maybe.is_none() { // stay in same state @@ -1060,7 +1062,7 @@ impl State { server .engine - .create_handshake(MessageType::Finished, |body, engine| { + .create_handshake(MessageType::FINISHED, |body, engine| { let verify_data = engine.generate_verify_data(false /* server */)?; trace!("Finished.verify_data length: {}", verify_data.len()); // Directly write the verify data without creating Finished struct From 8d87858927997af7a60e026361e6cd962e9aaa69 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 19:09:41 +0300 Subject: [PATCH 13/18] dtls13: make message type a newtype --- src/dtls13/client.rs | 42 ++++++---- src/dtls13/engine.rs | 10 +-- src/dtls13/message/handshake.rs | 144 ++++++++++++++++++++------------ src/dtls13/server.rs | 36 ++++---- 4 files changed, 142 insertions(+), 90 deletions(-) diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index 728fd5d1..bba42df6 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -428,7 +428,7 @@ impl State { client .engine - .create_handshake(MessageType::ClientHello, |body, engine| { + .create_handshake(MessageType::CLIENT_HELLO, |body, engine| { handshake_create_client_hello( body, engine, @@ -450,7 +450,7 @@ impl State { let maybe = client .engine - .next_handshake(MessageType::ServerHello, &mut client.defragment_buffer)?; + .next_handshake(MessageType::SERVER_HELLO, &mut client.defragment_buffer)?; let Some(handshake) = maybe else { return Ok(self); @@ -681,7 +681,7 @@ impl State { fn await_encrypted_extensions(self, client: &mut Client) -> Result { let maybe = client.engine.next_handshake( - MessageType::EncryptedExtensions, + MessageType::ENCRYPTED_EXTENSIONS, &mut client.defragment_buffer, )?; @@ -717,14 +717,14 @@ impl State { // CertificateRequest is optional. Check if Certificate is available instead. let has_cert = client .engine - .has_complete_handshake(MessageType::Certificate); + .has_complete_handshake(MessageType::CERTIFICATE); if has_cert { return Ok(Self::AwaitCertificate); } let maybe = client.engine.next_handshake( - MessageType::CertificateRequest, + MessageType::CERTIFICATE_REQUEST, &mut client.defragment_buffer, )?; @@ -755,7 +755,7 @@ impl State { fn await_certificate(self, client: &mut Client) -> Result { let maybe = client .engine - .next_handshake(MessageType::Certificate, &mut client.defragment_buffer)?; + .next_handshake(MessageType::CERTIFICATE, &mut client.defragment_buffer)?; let Some(ref handshake) = maybe else { return Ok(self); @@ -818,7 +818,7 @@ impl State { client.engine.transcript_hash(&mut transcript_hash); let maybe = client.engine.next_handshake( - MessageType::CertificateVerify, + MessageType::CERTIFICATE_VERIFY, &mut client.defragment_buffer, )?; @@ -898,7 +898,7 @@ impl State { let maybe = client .engine - .next_handshake(MessageType::Finished, &mut client.defragment_buffer)?; + .next_handshake(MessageType::FINISHED, &mut client.defragment_buffer)?; let Some(ref handshake) = maybe else { return Ok(self); @@ -972,7 +972,7 @@ impl State { client .engine - .create_handshake(MessageType::Certificate, |body, engine| { + .create_handshake(MessageType::CERTIFICATE, |body, engine| { handshake_create_certificate(body, engine, &context_copy) })?; @@ -986,7 +986,7 @@ impl State { client .engine - .create_handshake(MessageType::Certificate, |body, _engine| { + .create_handshake(MessageType::CERTIFICATE, |body, _engine| { // certificate_request_context body.push(context_copy.len() as u8); body.extend_from_slice(&context_copy); @@ -1004,7 +1004,7 @@ impl State { client .engine - .create_handshake(MessageType::CertificateVerify, |body, engine| { + .create_handshake(MessageType::CERTIFICATE_VERIFY, |body, engine| { handshake_create_certificate_verify( body, engine, @@ -1039,7 +1039,7 @@ impl State { client .engine - .create_handshake(MessageType::Finished, |body, engine| { + .create_handshake(MessageType::FINISHED, |body, engine| { let verify_data = engine.compute_verify_data(&client_hs_secret)?; body.extend_from_slice(&verify_data); Ok(()) @@ -1109,9 +1109,12 @@ impl State { } // Check for incoming KeyUpdate - if client.engine.has_complete_handshake(MessageType::KeyUpdate) { + if client + .engine + .has_complete_handshake(MessageType::KEY_UPDATE) + { let maybe = client.engine.next_handshake_no_transcript( - MessageType::KeyUpdate, + MessageType::KEY_UPDATE, &mut client.defragment_buffer, )?; @@ -1142,9 +1145,12 @@ impl State { fn half_closed_local(self, client: &mut Client) -> Result { // Write half is closed: drain incoming KeyUpdate to keep recv keys in sync, // but do not send our own KeyUpdate response. - if client.engine.has_complete_handshake(MessageType::KeyUpdate) { + if client + .engine + .has_complete_handshake(MessageType::KEY_UPDATE) + { let maybe = client.engine.next_handshake_no_transcript( - MessageType::KeyUpdate, + MessageType::KEY_UPDATE, &mut client.defragment_buffer, )?; if let Some(handshake) = maybe { @@ -1651,7 +1657,7 @@ mod tests { client .engine .parse_packet(&epoch0_handshake_packet( - MessageType::Certificate, + MessageType::CERTIFICATE, 2, &[0, 0, 0, 0], )) @@ -1677,7 +1683,7 @@ mod tests { client .engine .parse_packet(&epoch0_handshake_packet( - MessageType::ServerHello, + MessageType::SERVER_HELLO, 0, &server_hello_with_key_share(NamedGroup::SECP256R1), )) diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index f1d36fa0..80ec80c1 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -389,7 +389,7 @@ impl Engine { // but allow KeyUpdate (a post-handshake message). if self.release_app_data && handshake.header.message_seq >= self.peer_handshake_seq_no - && handshake.header.msg_type != MessageType::KeyUpdate + && handshake.header.msg_type != MessageType::KEY_UPDATE { return Err(Error::RenegotiationAttempt); } @@ -789,7 +789,7 @@ impl Engine { &mut self, defragment_buffer: &mut Buf, ) -> Result, InternalError> { - self.next_handshake_with_options(MessageType::ClientHello, defragment_buffer, true) + self.next_handshake_with_options(MessageType::CLIENT_HELLO, defragment_buffer, true) } fn next_handshake_with_options( @@ -1893,7 +1893,7 @@ impl Engine { self.create_ciphertext_record(ContentType::HANDSHAKE, epoch, true, |fragment| { // DTLS handshake header (12 bytes): // msg_type(1) + length(3) + message_seq(2) + fragment_offset(3) + fragment_length(3) - fragment.push(MessageType::KeyUpdate.as_u8()); + fragment.push(MessageType::KEY_UPDATE.as_u8()); fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); // length = 1 fragment.extend_from_slice(&msg_seq.to_be_bytes()); // message_seq fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); // fragment_offset = 0 @@ -2246,7 +2246,7 @@ fn jittered_aead_threshold(limit: u64, rng: &mut SeededRng) -> u64 { /// All other handshake messages are encrypted (epoch 2). fn epoch_for_message(msg_type: MessageType) -> u16 { match msg_type { - MessageType::ClientHello | MessageType::ServerHello => 0, + MessageType::CLIENT_HELLO | MessageType::SERVER_HELLO => 0, _ => 2, } } @@ -2558,7 +2558,7 @@ mod tests { fn encrypted_key_update_record(seq: u16) -> Vec { let mut fragment = Vec::new(); - fragment.push(MessageType::KeyUpdate.as_u8()); + fragment.push(MessageType::KEY_UPDATE.as_u8()); fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); fragment.extend_from_slice(&0u16.to_be_bytes()); fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); diff --git a/src/dtls13/message/handshake.rs b/src/dtls13/message/handshake.rs index 006fae64..ce85a050 100644 --- a/src/dtls13/message/handshake.rs +++ b/src/dtls13/message/handshake.rs @@ -1,3 +1,4 @@ +use std::fmt; use std::ops::Range; use std::sync::atomic::{AtomicBool, Ordering}; @@ -219,7 +220,7 @@ impl Handshake { Body::parse(buffer, 0, first_handshake.header.msg_type, cipher_suite)? }; - if !rest.is_empty() && first_handshake.header.msg_type == MessageType::Finished { + if !rest.is_empty() && first_handshake.header.msg_type == MessageType::FINISHED { debug!("Defragmentation failed. Body::parse() did not consume the entire buffer"); return Err(crate::InternalError::parse_incomplete()); } @@ -251,7 +252,7 @@ impl Handshake { let qualifies = matches!( self.header.msg_type, - MessageType::ClientHello // flight 1 + MessageType::CLIENT_HELLO // flight 1 ); qualifies.then_some(self.header.message_seq) @@ -266,52 +267,36 @@ impl Handshake { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MessageType { - ClientHello, - ServerHello, - EncryptedExtensions, - Certificate, - CertificateRequest, - CertificateVerify, - Finished, - KeyUpdate, - Unknown(u8), -} +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct MessageType(u8); impl Default for MessageType { fn default() -> Self { - Self::Unknown(0) + Self(u8::MAX) } } impl MessageType { - pub fn from_u8(value: u8) -> Self { - match value { - 1 => MessageType::ClientHello, - 2 => MessageType::ServerHello, - 8 => MessageType::EncryptedExtensions, - 11 => MessageType::Certificate, - 13 => MessageType::CertificateRequest, - 15 => MessageType::CertificateVerify, - 20 => MessageType::Finished, - 24 => MessageType::KeyUpdate, - _ => MessageType::Unknown(value), - } + pub const CLIENT_HELLO: Self = Self(1); + pub const SERVER_HELLO: Self = Self(2); + pub const ENCRYPTED_EXTENSIONS: Self = Self(8); + pub const CERTIFICATE: Self = Self(11); + pub const CERTIFICATE_REQUEST: Self = Self(13); + pub const CERTIFICATE_VERIFY: Self = Self(15); + pub const FINISHED: Self = Self(20); + pub const KEY_UPDATE: Self = Self(24); + + pub const fn from_u8(value: u8) -> Self { + Self(value) } - pub fn as_u8(&self) -> u8 { - match self { - MessageType::ClientHello => 1, - MessageType::ServerHello => 2, - MessageType::EncryptedExtensions => 8, - MessageType::Certificate => 11, - MessageType::CertificateRequest => 13, - MessageType::CertificateVerify => 15, - MessageType::Finished => 20, - MessageType::KeyUpdate => 24, - MessageType::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + const fn is_unknown(&self) -> bool { + !matches!(*self, Self(1..=2 | 8 | 11 | 13 | 15 | 20 | 24)) } pub fn parse(input: &[u8]) -> IResult<&[u8], MessageType> { @@ -320,6 +305,28 @@ impl MessageType { } } +impl fmt::Debug for MessageType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + MessageType::CLIENT_HELLO => "ClientHello", + MessageType::SERVER_HELLO => "ServerHello", + MessageType::ENCRYPTED_EXTENSIONS => "EncryptedExtensions", + MessageType::CERTIFICATE => "Certificate", + MessageType::CERTIFICATE_REQUEST => "CertificateRequest", + MessageType::CERTIFICATE_VERIFY => "CertificateVerify", + MessageType::FINISHED => "Finished", + MessageType::KEY_UPDATE => "KeyUpdate", + _ => unreachable!("known DTLS 1.3 handshake message type missing Debug label"), + }; + + f.write_str(name) + } +} + #[allow(clippy::large_enum_variant)] #[derive(Debug, PartialEq, Eq)] pub enum Body { @@ -393,7 +400,7 @@ impl Body { allow_unknown_client_hello_suites: bool, ) -> IResult<&[u8], Body> { match m { - MessageType::ClientHello => { + MessageType::CLIENT_HELLO => { let (input, client_hello) = if allow_unknown_client_hello_suites { ClientHello::parse_allow_unknown_suites(input, base_offset)? } else { @@ -401,33 +408,33 @@ impl Body { }; Ok((input, Body::ClientHello(client_hello))) } - MessageType::ServerHello => { + MessageType::SERVER_HELLO => { let (input, server_hello) = ServerHello::parse(input, base_offset)?; Ok((input, Body::ServerHello(server_hello))) } - MessageType::EncryptedExtensions => { + MessageType::ENCRYPTED_EXTENSIONS => { let (input, ee) = EncryptedExtensions::parse(input, base_offset)?; Ok((input, Body::EncryptedExtensions(ee))) } - MessageType::Certificate => { + MessageType::CERTIFICATE => { let (input, certificate) = Certificate::parse(input, base_offset)?; Ok((input, Body::Certificate(certificate))) } - MessageType::CertificateRequest => { + MessageType::CERTIFICATE_REQUEST => { let range = base_offset..(base_offset + input.len()); Ok((&[], Body::CertificateRequest(range))) } - MessageType::CertificateVerify => { + MessageType::CERTIFICATE_VERIFY => { let (input, cv) = CertificateVerify::parse(input, base_offset)?; Ok((input, Body::CertificateVerify(cv))) } - MessageType::Finished => { + MessageType::FINISHED => { let cipher_suite = c.ok_or_else(|| Err::Failure(Error::new(input, ErrorKind::Fail)))?; let (input, finished) = Finished::parse(input, cipher_suite)?; Ok((input, Body::Finished(finished))) } - MessageType::KeyUpdate => { + MessageType::KEY_UPDATE => { let (input, byte) = be_u8(input)?; if !input.is_empty() { return Err(Err::Failure(Error::new(input, ErrorKind::LengthValue))); @@ -436,7 +443,7 @@ impl Body { .ok_or_else(|| Err::Failure(Error::new(input, ErrorKind::Fail)))?; Ok((input, Body::KeyUpdate(request))) } - MessageType::Unknown(value) => Ok((input, Body::Unknown(value))), + _ => Ok((input, Body::Unknown(m.as_u8()))), } } @@ -486,7 +493,7 @@ mod tests { use crate::dtls13::message::{ProtocolVersion, Random, SessionId}; const MESSAGE: &[u8] = &[ - 0x01, // MessageType::ClientHello + 0x01, // MessageType::CLIENT_HELLO 0x00, 0x00, 0x2E, // length 0x00, 0x00, // message_seq 0x00, 0x00, 0x00, // fragment_offset @@ -508,10 +515,43 @@ mod tests { 0x00, // Null ]; + #[test] + fn message_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert!(MessageType::default().is_unknown()); + } + + #[test] + fn message_type_wire_roundtrip() { + for message_type in [ + MessageType::CLIENT_HELLO, + MessageType::SERVER_HELLO, + MessageType::ENCRYPTED_EXTENSIONS, + MessageType::CERTIFICATE, + MessageType::CERTIFICATE_REQUEST, + MessageType::CERTIFICATE_VERIFY, + MessageType::FINISHED, + MessageType::KEY_UPDATE, + ] { + assert_eq!(MessageType::from_u8(message_type.as_u8()), message_type); + assert!(!message_type.is_unknown()); + } + + let unknown = MessageType::from_u8(0xFF); + assert_eq!(unknown.as_u8(), 0xFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn message_type_debug_stays_enum_like() { + assert_eq!(format!("{:?}", MessageType::CLIENT_HELLO), "ClientHello"); + assert_eq!(format!("{:?}", MessageType::from_u8(0xFF)), "Unknown(255)"); + } + #[test] fn handshake_size() { let h = Handshake::new( - MessageType::EncryptedExtensions, + MessageType::ENCRYPTED_EXTENSIONS, 2, 0, 0, @@ -551,7 +591,7 @@ mod tests { ); let handshake = Handshake::new( - MessageType::ClientHello, + MessageType::CLIENT_HELLO, 0x2E, 0, 0, @@ -574,7 +614,7 @@ mod tests { fn key_update_body_rejects_trailing_bytes() { let source = [KeyUpdateRequest::UpdateRequested.as_u8(), 0]; let handshake = Handshake::new( - MessageType::KeyUpdate, + MessageType::KEY_UPDATE, source.len() as u32, 0, 0, diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 1af45f55..b256c516 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -406,7 +406,7 @@ impl State { } else { server .engine - .next_handshake(MessageType::ClientHello, &mut server.defragment_buffer)? + .next_handshake(MessageType::CLIENT_HELLO, &mut server.defragment_buffer)? }; let Some(handshake) = maybe else { @@ -798,7 +798,7 @@ impl State { server .engine - .create_handshake(MessageType::ServerHello, |body, engine| { + .create_handshake(MessageType::SERVER_HELLO, |body, engine| { handshake_create_server_hello( body, engine, @@ -845,7 +845,7 @@ impl State { server .engine - .create_handshake(MessageType::EncryptedExtensions, |body, _engine| { + .create_handshake(MessageType::ENCRYPTED_EXTENSIONS, |body, _engine| { handshake_create_encrypted_extensions(body, negotiated_srtp) })?; @@ -861,7 +861,7 @@ impl State { server .engine - .create_handshake(MessageType::CertificateRequest, |body, _engine| { + .create_handshake(MessageType::CERTIFICATE_REQUEST, |body, _engine| { handshake_create_certificate_request(body) })?; @@ -875,7 +875,7 @@ impl State { server .engine - .create_handshake(MessageType::Certificate, |body, engine| { + .create_handshake(MessageType::CERTIFICATE, |body, engine| { handshake_create_certificate(body, engine, &[]) })?; @@ -887,7 +887,7 @@ impl State { server .engine - .create_handshake(MessageType::CertificateVerify, |body, engine| { + .create_handshake(MessageType::CERTIFICATE_VERIFY, |body, engine| { handshake_create_certificate_verify( body, engine, @@ -914,7 +914,7 @@ impl State { server .engine - .create_handshake(MessageType::Finished, |body, engine| { + .create_handshake(MessageType::FINISHED, |body, engine| { let verify_data = engine.compute_verify_data(&server_hs_secret)?; body.extend_from_slice(&verify_data); Ok(()) @@ -946,7 +946,7 @@ impl State { fn await_certificate(self, server: &mut Server) -> Result { let maybe = server .engine - .next_handshake(MessageType::Certificate, &mut server.defragment_buffer)?; + .next_handshake(MessageType::CERTIFICATE, &mut server.defragment_buffer)?; let Some(ref handshake) = maybe else { return Ok(self); @@ -1012,7 +1012,7 @@ impl State { server.engine.transcript_hash(&mut transcript_hash); let maybe = server.engine.next_handshake( - MessageType::CertificateVerify, + MessageType::CERTIFICATE_VERIFY, &mut server.defragment_buffer, )?; @@ -1092,7 +1092,7 @@ impl State { let maybe = server .engine - .next_handshake(MessageType::Finished, &mut server.defragment_buffer)?; + .next_handshake(MessageType::FINISHED, &mut server.defragment_buffer)?; let Some(ref handshake) = maybe else { return Ok(self); @@ -1191,9 +1191,12 @@ impl State { } // Check for incoming KeyUpdate - if server.engine.has_complete_handshake(MessageType::KeyUpdate) { + if server + .engine + .has_complete_handshake(MessageType::KEY_UPDATE) + { let maybe = server.engine.next_handshake_no_transcript( - MessageType::KeyUpdate, + MessageType::KEY_UPDATE, &mut server.defragment_buffer, )?; @@ -1224,9 +1227,12 @@ impl State { fn half_closed_local(self, server: &mut Server) -> Result { // Write half is closed: drain incoming KeyUpdate to keep recv keys in sync, // but do not send our own KeyUpdate response. - if server.engine.has_complete_handshake(MessageType::KeyUpdate) { + if server + .engine + .has_complete_handshake(MessageType::KEY_UPDATE) + { let maybe = server.engine.next_handshake_no_transcript( - MessageType::KeyUpdate, + MessageType::KEY_UPDATE, &mut server.defragment_buffer, )?; if let Some(handshake) = maybe { @@ -1328,7 +1334,7 @@ fn send_hello_retry_request( server .engine - .create_handshake(MessageType::ServerHello, |body, _engine| { + .create_handshake(MessageType::SERVER_HELLO, |body, _engine| { server_hello.serialize(&ext_buf, body); Ok(()) })?; From 77257ee8a8c73d556963a1172fc5c8d33038561e Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 19:11:14 +0300 Subject: [PATCH 14/18] dtls12: make curve type a newtype --- src/dtls12/context.rs | 4 +- src/dtls12/message/client_key_exchange.rs | 2 +- src/dtls12/message/named_group.rs | 89 +++++++++++++++++------ src/dtls12/server.rs | 2 +- 4 files changed, 70 insertions(+), 27 deletions(-) diff --git a/src/dtls12/context.rs b/src/dtls12/context.rs index 30502c8f..4831ddc0 100644 --- a/src/dtls12/context.rs +++ b/src/dtls12/context.rs @@ -558,14 +558,14 @@ impl CryptoContext { pub fn get_key_exchange_group_info(&self) -> Option<(CurveType, NamedGroup)> { // Use stored group if available (after key exchange is consumed) if let Some(group) = self.key_exchange_group { - return Some((CurveType::NamedCurve, group)); + return Some((CurveType::NAMED_CURVE, group)); } // Otherwise get it from the active key exchange let Some(ke) = &self.key_exchange else { return None; }; - Some((CurveType::NamedCurve, ke.group())) + Some((CurveType::NAMED_CURVE, ke.group())) } /// Check if the client's private key is compatible with a given cipher suite. diff --git a/src/dtls12/message/client_key_exchange.rs b/src/dtls12/message/client_key_exchange.rs index 26f68907..147a24bd 100644 --- a/src/dtls12/message/client_key_exchange.rs +++ b/src/dtls12/message/client_key_exchange.rs @@ -42,7 +42,7 @@ impl ClientEcdhKeys { ClientEcdhKeys { // In ClientKeyExchange, we don't include curve_type and named_group // since they're already established during ServerKeyExchange - curve_type: CurveType::NamedCurve, // Default + curve_type: CurveType::NAMED_CURVE, // Default named_group: NamedGroup::SECP256R1, // Default public_key_range: start..end, }, diff --git a/src/dtls12/message/named_group.rs b/src/dtls12/message/named_group.rs index 12b9edda..c4dad1ae 100644 --- a/src/dtls12/message/named_group.rs +++ b/src/dtls12/message/named_group.rs @@ -5,41 +5,35 @@ use nom::IResult; use nom::number::complete::be_u8; +use std::fmt; /// Curve type for ECDH parameters in DTLS 1.2. /// /// This is specific to DTLS 1.2's ServerKeyExchange message format. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CurveType { +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct CurveType(u8); + +impl CurveType { /// Explicit prime curve parameters. - ExplicitPrime, + pub const EXPLICIT_PRIME: Self = Self(1); /// Explicit characteristic-2 curve parameters. - ExplicitChar2, + pub const EXPLICIT_CHAR2: Self = Self(2); /// Named curve (the common case). - NamedCurve, - /// Unknown curve type. - Unknown(u8), -} + pub const NAMED_CURVE: Self = Self(3); -impl CurveType { /// Convert a u8 value to a `CurveType`. - pub fn from_u8(value: u8) -> Self { - match value { - 1 => CurveType::ExplicitPrime, - 2 => CurveType::ExplicitChar2, - 3 => CurveType::NamedCurve, - _ => CurveType::Unknown(value), - } + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Convert this `CurveType` to its u8 value. - pub fn as_u8(&self) -> u8 { - match self { - CurveType::ExplicitPrime => 1, - CurveType::ExplicitChar2 => 2, - CurveType::NamedCurve => 3, - CurveType::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + const fn is_unknown(&self) -> bool { + !matches!(*self, Self(1..=3)) } /// Parse a `CurveType` from wire format. @@ -48,3 +42,52 @@ impl CurveType { Ok((input, CurveType::from_u8(value))) } } + +impl fmt::Debug for CurveType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + CurveType::EXPLICIT_PRIME => "ExplicitPrime", + CurveType::EXPLICIT_CHAR2 => "ExplicitChar2", + CurveType::NAMED_CURVE => "NamedCurve", + _ => unreachable!("known DTLS 1.2 curve type missing Debug label"), + }; + + f.write_str(name) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn curve_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + } + + #[test] + fn curve_type_wire_roundtrip() { + for curve_type in [ + CurveType::EXPLICIT_PRIME, + CurveType::EXPLICIT_CHAR2, + CurveType::NAMED_CURVE, + ] { + assert_eq!(CurveType::from_u8(curve_type.as_u8()), curve_type); + assert!(!curve_type.is_unknown()); + } + + let unknown = CurveType::from_u8(0xFF); + assert_eq!(unknown.as_u8(), 0xFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn curve_type_debug_stays_enum_like() { + assert_eq!(format!("{:?}", CurveType::NAMED_CURVE), "NamedCurve"); + assert_eq!(format!("{:?}", CurveType::from_u8(0xFF)), "Unknown(255)"); + } +} diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index 19d615cb..e0dbf3be 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -1243,7 +1243,7 @@ fn handshake_create_server_key_exchange( match key_exchange_algorithm { KeyExchangeAlgorithm::EECDH => { - let (curve_type, named_group) = (CurveType::NamedCurve, named_group); + let (curve_type, named_group) = (CurveType::NAMED_CURVE, named_group); let mut kx_buf = engine.pop_buffer(); let pubkey = engine .crypto_context_mut() From e859287faa060359deda93f55580e4bcc1ff20dc Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 19:13:28 +0300 Subject: [PATCH 15/18] dtls12: make client certificate type a newtype --- src/dtls12/message/mod.rs | 123 ++++++++++++++++++++++++++------------ 1 file changed, 85 insertions(+), 38 deletions(-) diff --git a/src/dtls12/message/mod.rs b/src/dtls12/message/mod.rs index 5c3d8036..dd2b157c 100644 --- a/src/dtls12/message/mod.rs +++ b/src/dtls12/message/mod.rs @@ -245,39 +245,28 @@ pub enum KeyExchangeAlgorithm { pub type CertificateTypeVec = ArrayVec; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub enum ClientCertificateType { - RSA_SIGN, - DSS_SIGN, - RSA_FIXED_DH, - DSS_FIXED_DH, - RSA_EPHEMERAL_DH, - DSS_EPHEMERAL_DH, - FORTEZZA_DMS, - ECDSA_SIGN, - Unknown(u8), -} +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct ClientCertificateType(u8); impl Default for ClientCertificateType { fn default() -> Self { - Self::Unknown(0) + Self(u8::MAX) } } impl ClientCertificateType { - pub fn from_u8(value: u8) -> Self { - match value { - 1 => ClientCertificateType::RSA_SIGN, - 2 => ClientCertificateType::DSS_SIGN, - 3 => ClientCertificateType::RSA_FIXED_DH, - 4 => ClientCertificateType::DSS_FIXED_DH, - 5 => ClientCertificateType::RSA_EPHEMERAL_DH, - 6 => ClientCertificateType::DSS_EPHEMERAL_DH, - 20 => ClientCertificateType::FORTEZZA_DMS, - 64 => ClientCertificateType::ECDSA_SIGN, - _ => ClientCertificateType::Unknown(value), - } + pub const RSA_SIGN: Self = Self(1); + pub const DSS_SIGN: Self = Self(2); + pub const RSA_FIXED_DH: Self = Self(3); + pub const DSS_FIXED_DH: Self = Self(4); + pub const RSA_EPHEMERAL_DH: Self = Self(5); + pub const DSS_EPHEMERAL_DH: Self = Self(6); + pub const FORTEZZA_DMS: Self = Self(20); + pub const ECDSA_SIGN: Self = Self(64); + + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Returns true if this certificate type is supported by this implementation. @@ -291,18 +280,12 @@ impl ClientCertificateType { &[ClientCertificateType::ECDSA_SIGN] } - pub fn as_u8(&self) -> u8 { - match self { - ClientCertificateType::RSA_SIGN => 1, - ClientCertificateType::DSS_SIGN => 2, - ClientCertificateType::RSA_FIXED_DH => 3, - ClientCertificateType::DSS_FIXED_DH => 4, - ClientCertificateType::RSA_EPHEMERAL_DH => 5, - ClientCertificateType::DSS_EPHEMERAL_DH => 6, - ClientCertificateType::FORTEZZA_DMS => 20, - ClientCertificateType::ECDSA_SIGN => 64, - ClientCertificateType::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + const fn is_unknown(&self) -> bool { + !matches!(*self, Self(1..=6 | 20 | 64)) } pub fn parse(input: &[u8]) -> IResult<&[u8], ClientCertificateType> { @@ -311,6 +294,28 @@ impl ClientCertificateType { } } +impl fmt::Debug for ClientCertificateType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + ClientCertificateType::RSA_SIGN => "RSA_SIGN", + ClientCertificateType::DSS_SIGN => "DSS_SIGN", + ClientCertificateType::RSA_FIXED_DH => "RSA_FIXED_DH", + ClientCertificateType::DSS_FIXED_DH => "DSS_FIXED_DH", + ClientCertificateType::RSA_EPHEMERAL_DH => "RSA_EPHEMERAL_DH", + ClientCertificateType::DSS_EPHEMERAL_DH => "DSS_EPHEMERAL_DH", + ClientCertificateType::FORTEZZA_DMS => "FORTEZZA_DMS", + ClientCertificateType::ECDSA_SIGN => "ECDSA_SIGN", + _ => unreachable!("known DTLS 1.2 client certificate type missing Debug label"), + }; + + f.write_str(name) + } +} + // SignatureAlgorithm and HashAlgorithm are now in crate::types pub type SignatureAndHashAlgorithmVec = @@ -400,6 +405,48 @@ mod tests { ); } + #[test] + fn client_certificate_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert!(ClientCertificateType::default().is_unknown()); + } + + #[test] + fn client_certificate_type_wire_roundtrip() { + for certificate_type in [ + ClientCertificateType::RSA_SIGN, + ClientCertificateType::DSS_SIGN, + ClientCertificateType::RSA_FIXED_DH, + ClientCertificateType::DSS_FIXED_DH, + ClientCertificateType::RSA_EPHEMERAL_DH, + ClientCertificateType::DSS_EPHEMERAL_DH, + ClientCertificateType::FORTEZZA_DMS, + ClientCertificateType::ECDSA_SIGN, + ] { + assert_eq!( + ClientCertificateType::from_u8(certificate_type.as_u8()), + certificate_type + ); + assert!(!certificate_type.is_unknown()); + } + + let unknown = ClientCertificateType::from_u8(0xFF); + assert_eq!(unknown.as_u8(), 0xFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn client_certificate_type_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", ClientCertificateType::ECDSA_SIGN), + "ECDSA_SIGN" + ); + assert_eq!( + format!("{:?}", ClientCertificateType::from_u8(0xFF)), + "Unknown(255)" + ); + } + #[test] fn unknown_dtls12_cipher_suite_uses_internal_derived_markers() { let unknown = Dtls12CipherSuite::from_u16(0xFFFF); From e126ea2f8e33349ebd33908a68ccca374b079b24 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 19:14:13 +0300 Subject: [PATCH 16/18] types: simplify known-code checks --- src/dtls12/message/mod.rs | 8 +--- src/types.rs | 90 ++++----------------------------------- 2 files changed, 9 insertions(+), 89 deletions(-) diff --git a/src/dtls12/message/mod.rs b/src/dtls12/message/mod.rs index dd2b157c..a70be8c9 100644 --- a/src/dtls12/message/mod.rs +++ b/src/dtls12/message/mod.rs @@ -84,13 +84,7 @@ impl Dtls12CipherSuite { /// Returns true if this is not a known DTLS 1.2 cipher suite wire value. pub const fn is_unknown(&self) -> bool { - !matches!( - *self, - Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 - | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 - | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 - | Dtls12CipherSuite::PSK_AES128_CCM_8 - ) + !matches!(*self, Self(0xC02B..=0xC02C | 0xC0A8 | 0xCCA9)) } /// Parse a `Dtls12CipherSuite` from network byte order. diff --git a/src/types.rs b/src/types.rs index 1843adad..115d9da3 100644 --- a/src/types.rs +++ b/src/types.rs @@ -153,36 +153,7 @@ impl NamedGroup { /// Returns true if this is not a known TLS named group wire value. pub const fn is_unknown(&self) -> bool { - !matches!( - *self, - NamedGroup::SECT163K1 - | NamedGroup::SECT163R1 - | NamedGroup::SECT163R2 - | NamedGroup::SECT193R1 - | NamedGroup::SECT193R2 - | NamedGroup::SECT233K1 - | NamedGroup::SECT233R1 - | NamedGroup::SECT239K1 - | NamedGroup::SECT283K1 - | NamedGroup::SECT283R1 - | NamedGroup::SECT409K1 - | NamedGroup::SECT409R1 - | NamedGroup::SECT571K1 - | NamedGroup::SECT571R1 - | NamedGroup::SECP160K1 - | NamedGroup::SECP160R1 - | NamedGroup::SECP160R2 - | NamedGroup::SECP192K1 - | NamedGroup::SECP192R1 - | NamedGroup::SECP224K1 - | NamedGroup::SECP224R1 - | NamedGroup::SECP256K1 - | NamedGroup::SECP256R1 - | NamedGroup::SECP384R1 - | NamedGroup::SECP521R1 - | NamedGroup::X25519 - | NamedGroup::X448 - ) + !matches!(*self, Self(1..=25 | 29..=30)) } /// Parse a `NamedGroup` from wire format. @@ -323,16 +294,7 @@ impl HashAlgorithm { /// Returns true if this is not a known DTLS hash algorithm wire value. pub const fn is_unknown(&self) -> bool { - !matches!( - *self, - HashAlgorithm::NONE - | HashAlgorithm::MD5 - | HashAlgorithm::SHA1 - | HashAlgorithm::SHA224 - | HashAlgorithm::SHA256 - | HashAlgorithm::SHA384 - | HashAlgorithm::SHA512 - ) + self.0 > Self::SHA512.0 } /// Parse a `HashAlgorithm` from wire format. @@ -413,13 +375,7 @@ impl SignatureAlgorithm { /// Returns true if this is not a known DTLS signature algorithm wire value. pub const fn is_unknown(&self) -> bool { - !matches!( - *self, - SignatureAlgorithm::ANONYMOUS - | SignatureAlgorithm::RSA - | SignatureAlgorithm::DSA - | SignatureAlgorithm::ECDSA - ) + self.0 > Self::ECDSA.0 } /// Parse a `SignatureAlgorithm` from network bytes. @@ -477,14 +433,7 @@ impl ContentType { /// Returns true if this is not a known DTLS record content type. pub const fn is_unknown(&self) -> bool { - !matches!( - *self, - ContentType::CHANGE_CIPHER_SPEC - | ContentType::ALERT - | ContentType::HANDSHAKE - | ContentType::APPLICATION_DATA - | ContentType::ACK - ) + !matches!(*self, Self(20..=23 | 26)) } /// Parse a `ContentType` from wire format. @@ -618,20 +567,7 @@ impl SignatureScheme { pub const fn is_unknown(&self) -> bool { !matches!( *self, - SignatureScheme::ECDSA_SECP256R1_SHA256 - | SignatureScheme::ECDSA_SECP384R1_SHA384 - | SignatureScheme::ECDSA_SECP521R1_SHA512 - | SignatureScheme::ED25519 - | SignatureScheme::ED448 - | SignatureScheme::RSA_PSS_RSAE_SHA256 - | SignatureScheme::RSA_PSS_RSAE_SHA384 - | SignatureScheme::RSA_PSS_RSAE_SHA512 - | SignatureScheme::RSA_PSS_PSS_SHA256 - | SignatureScheme::RSA_PSS_PSS_SHA384 - | SignatureScheme::RSA_PSS_PSS_SHA512 - | SignatureScheme::RSA_PKCS1_SHA256 - | SignatureScheme::RSA_PKCS1_SHA384 - | SignatureScheme::RSA_PKCS1_SHA512 + Self(0x0401 | 0x0403 | 0x0501 | 0x0503 | 0x0601 | 0x0603 | 0x0804..=0x080b) ) } @@ -771,14 +707,7 @@ impl Dtls13CipherSuite { /// Returns true if this is not a known DTLS 1.3 cipher suite wire value. pub const fn is_unknown(&self) -> bool { - !matches!( - *self, - Dtls13CipherSuite::AES_128_GCM_SHA256 - | Dtls13CipherSuite::AES_256_GCM_SHA384 - | Dtls13CipherSuite::CHACHA20_POLY1305_SHA256 - | Dtls13CipherSuite::AES_128_CCM_SHA256 - | Dtls13CipherSuite::AES_128_CCM_8_SHA256 - ) + !matches!(*self, Self(0x1301..=0x1305)) } /// Parse a `Dtls13CipherSuite` from wire format. @@ -874,10 +803,7 @@ impl ProtocolVersion { /// Returns true if this is not a known DTLS protocol version wire value. pub const fn is_unknown(&self) -> bool { - !matches!( - *self, - ProtocolVersion::DTLS1_0 | ProtocolVersion::DTLS1_2 | ProtocolVersion::DTLS1_3 - ) + !matches!(*self, Self(0xFEFF | 0xFEFD | 0xFEFC)) } /// Parse a `ProtocolVersion` from wire format. @@ -958,7 +884,7 @@ impl CompressionMethod { /// Returns true if this is not a known TLS compression method wire value. pub const fn is_unknown(&self) -> bool { - !matches!(*self, CompressionMethod::NULL | CompressionMethod::DEFLATE) + self.0 > Self::DEFLATE.0 } /// Parse a `CompressionMethod` from wire format. From 0fdd39580d473aa5324edd4515bff1676db9361f Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 19:25:42 +0300 Subject: [PATCH 17/18] types: preserve default wire-code values --- CHANGELOG.md | 1 + src/dtls12/message/extension.rs | 11 +++-------- src/dtls12/message/handshake.rs | 11 +++-------- src/dtls12/message/mod.rs | 10 +++------- src/dtls13/message/extension.rs | 11 +++-------- src/dtls13/message/handshake.rs | 9 ++------- src/types.rs | 8 ++++++++ 7 files changed, 23 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa3d3dee..0f861109 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased + * Represent DTLS wire-code identifiers as compact newtypes (breaking) #TBD * Make public errors structured and fatal-only (breaking) #134 # 0.6.2 diff --git a/src/dtls12/message/extension.rs b/src/dtls12/message/extension.rs index 57404543..51c82b2f 100644 --- a/src/dtls12/message/extension.rs +++ b/src/dtls12/message/extension.rs @@ -52,15 +52,9 @@ impl Extension { } #[repr(transparent)] -#[derive(Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct ExtensionType(u16); -impl Default for ExtensionType { - fn default() -> Self { - Self(u16::MAX) - } -} - impl ExtensionType { pub const SERVER_NAME: Self = Self(0x0000); pub const MAX_FRAGMENT_LENGTH: Self = Self(0x0001); @@ -209,7 +203,8 @@ mod tests { #[test] fn extension_type_newtype_shape() { assert_eq!(std::mem::size_of::(), 2); - assert!(ExtensionType::default().is_unknown()); + assert_eq!(ExtensionType::default().as_u16(), 0); + assert_eq!(ExtensionType::default(), ExtensionType::SERVER_NAME); } #[test] diff --git a/src/dtls12/message/handshake.rs b/src/dtls12/message/handshake.rs index 66ff43c1..742013f8 100644 --- a/src/dtls12/message/handshake.rs +++ b/src/dtls12/message/handshake.rs @@ -288,15 +288,9 @@ impl Handshake { } #[repr(transparent)] -#[derive(Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct MessageType(u8); -impl Default for MessageType { - fn default() -> Self { - Self(u8::MAX) - } -} - impl MessageType { pub const HELLO_REQUEST: Self = Self(0); pub const CLIENT_HELLO: Self = Self(1); @@ -544,7 +538,8 @@ mod tests { #[test] fn message_type_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); - assert!(MessageType::default().is_unknown()); + assert_eq!(MessageType::default().as_u8(), 0); + assert_eq!(MessageType::default(), MessageType::HELLO_REQUEST); } #[test] diff --git a/src/dtls12/message/mod.rs b/src/dtls12/message/mod.rs index a70be8c9..d4c3a513 100644 --- a/src/dtls12/message/mod.rs +++ b/src/dtls12/message/mod.rs @@ -240,15 +240,9 @@ pub type CertificateTypeVec = ArrayVec; #[repr(transparent)] -#[derive(Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct ClientCertificateType(u8); -impl Default for ClientCertificateType { - fn default() -> Self { - Self(u8::MAX) - } -} - impl ClientCertificateType { pub const RSA_SIGN: Self = Self(1); pub const DSS_SIGN: Self = Self(2); @@ -372,6 +366,7 @@ mod tests { #[test] fn dtls12_cipher_suite_newtype_shape() { assert_eq!(std::mem::size_of::(), 2); + assert_eq!(Dtls12CipherSuite::default().as_u16(), 0); assert!(Dtls12CipherSuite::default().is_unknown()); } @@ -402,6 +397,7 @@ mod tests { #[test] fn client_certificate_type_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); + assert_eq!(ClientCertificateType::default().as_u8(), 0); assert!(ClientCertificateType::default().is_unknown()); } diff --git a/src/dtls13/message/extension.rs b/src/dtls13/message/extension.rs index ff799332..ec2b22aa 100644 --- a/src/dtls13/message/extension.rs +++ b/src/dtls13/message/extension.rs @@ -49,15 +49,9 @@ impl Extension { } #[repr(transparent)] -#[derive(Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct ExtensionType(u16); -impl Default for ExtensionType { - fn default() -> Self { - Self(u16::MAX) - } -} - impl ExtensionType { pub const SERVER_NAME: Self = Self(0x0000); pub const MAX_FRAGMENT_LENGTH: Self = Self(0x0001); @@ -204,7 +198,8 @@ mod tests { #[test] fn extension_type_newtype_shape() { assert_eq!(std::mem::size_of::(), 2); - assert!(ExtensionType::default().is_unknown()); + assert_eq!(ExtensionType::default().as_u16(), 0); + assert_eq!(ExtensionType::default(), ExtensionType::SERVER_NAME); } #[test] diff --git a/src/dtls13/message/handshake.rs b/src/dtls13/message/handshake.rs index ce85a050..d6ca8c69 100644 --- a/src/dtls13/message/handshake.rs +++ b/src/dtls13/message/handshake.rs @@ -268,15 +268,9 @@ impl Handshake { } #[repr(transparent)] -#[derive(Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct MessageType(u8); -impl Default for MessageType { - fn default() -> Self { - Self(u8::MAX) - } -} - impl MessageType { pub const CLIENT_HELLO: Self = Self(1); pub const SERVER_HELLO: Self = Self(2); @@ -518,6 +512,7 @@ mod tests { #[test] fn message_type_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); + assert_eq!(MessageType::default().as_u8(), 0); assert!(MessageType::default().is_unknown()); } diff --git a/src/types.rs b/src/types.rs index 115d9da3..fadf6b69 100644 --- a/src/types.rs +++ b/src/types.rs @@ -911,6 +911,7 @@ mod tests { #[test] fn named_group_newtype_shape() { assert_eq!(std::mem::size_of::(), 2); + assert_eq!(NamedGroup::default().as_u16(), 0); assert!(NamedGroup::default().is_unknown()); } @@ -939,6 +940,7 @@ mod tests { #[test] fn hash_algorithm_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); + assert_eq!(HashAlgorithm::default().as_u8(), 0); assert_eq!(HashAlgorithm::default(), HashAlgorithm::NONE); } @@ -987,6 +989,7 @@ mod tests { #[test] fn signature_algorithm_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); + assert_eq!(SignatureAlgorithm::default().as_u8(), 0); assert_eq!(SignatureAlgorithm::default(), SignatureAlgorithm::ANONYMOUS); } @@ -1023,6 +1026,7 @@ mod tests { #[test] fn compression_method_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); + assert_eq!(CompressionMethod::default().as_u8(), 0); assert_eq!(CompressionMethod::default(), CompressionMethod::NULL); } @@ -1057,6 +1061,7 @@ mod tests { #[test] fn content_type_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); + assert_eq!(ContentType::default().as_u8(), 0); assert!(ContentType::default().is_unknown()); } @@ -1094,6 +1099,7 @@ mod tests { #[test] fn signature_scheme_newtype_shape() { assert_eq!(std::mem::size_of::(), 2); + assert_eq!(SignatureScheme::default().as_u16(), 0); assert!(SignatureScheme::default().is_unknown()); } @@ -1124,6 +1130,7 @@ mod tests { #[test] fn dtls13_cipher_suite_newtype_shape() { assert_eq!(std::mem::size_of::(), 2); + assert_eq!(Dtls13CipherSuite::default().as_u16(), 0); assert!(Dtls13CipherSuite::default().is_unknown()); } @@ -1154,6 +1161,7 @@ mod tests { #[test] fn protocol_version_newtype_shape() { assert_eq!(std::mem::size_of::(), 2); + assert_eq!(ProtocolVersion::default().as_u16(), 0); assert!(ProtocolVersion::default().is_unknown()); } From e8bec3634e6316eba6b87495d3103f787dffc5fc Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Fri, 29 May 2026 20:46:10 +0300 Subject: [PATCH 18/18] types: retain enum-style wire-code casing --- CHANGELOG.md | 2 +- src/config.rs | 8 +- src/crypto/aws_lc_rs/kx_group.rs | 16 +- src/crypto/aws_lc_rs/sign.rs | 14 +- src/crypto/provider.rs | 16 +- src/crypto/rust_crypto/kx_group.rs | 20 +- src/crypto/rust_crypto/sign.rs | 18 +- src/crypto/validation/mod.rs | 2 +- src/dtls12/client.rs | 70 +++-- src/dtls12/context.rs | 4 +- src/dtls12/engine.rs | 20 +- src/dtls12/incoming.rs | 10 +- src/dtls12/message/client_hello.rs | 23 +- src/dtls12/message/client_key_exchange.rs | 4 +- src/dtls12/message/extension.rs | 175 ++++++------ .../message/extensions/supported_groups.rs | 6 +- src/dtls12/message/handshake.rs | 130 +++++---- src/dtls12/message/named_group.rs | 21 +- src/dtls12/message/record.rs | 4 +- src/dtls12/message/server_hello.rs | 21 +- src/dtls12/queue.rs | 8 +- src/dtls12/server.rs | 86 +++--- src/dtls13/client.rs | 92 +++--- src/dtls13/engine.rs | 40 +-- src/dtls13/incoming.rs | 10 +- src/dtls13/message/client_hello.rs | 8 +- src/dtls13/message/encrypted_extensions.rs | 4 +- src/dtls13/message/extension.rs | 171 ++++++------ src/dtls13/message/extensions/key_share.rs | 2 +- .../message/extensions/supported_groups.rs | 6 +- src/dtls13/message/handshake.rs | 81 +++--- src/dtls13/message/record.rs | 4 +- src/dtls13/message/server_hello.rs | 14 +- src/dtls13/queue.rs | 6 +- src/dtls13/server.rs | 76 +++-- src/types.rs | 263 +++++++++--------- tests/dtls12/handshake.rs | 2 +- tests/dtls12/retransmit.rs | 2 +- tests/dtls13/handshake.rs | 16 +- 39 files changed, 732 insertions(+), 743 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f861109..e3bdafd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Unreleased - * Represent DTLS wire-code identifiers as compact newtypes (breaking) #TBD + * Represent DTLS wire-code identifiers as compact newtypes (breaking) #137 * Make public errors structured and fatal-only (breaking) #134 # 0.6.2 diff --git a/src/config.rs b/src/config.rs index b8a52a92..b84dc122 100644 --- a/src/config.rs +++ b/src/config.rs @@ -765,11 +765,11 @@ mod tests { #[test] fn filter_kx_groups() { let config = Config::builder() - .kx_groups(&[NamedGroup::SECP256R1]) + .kx_groups(&[NamedGroup::Secp256r1]) .build() .expect("should accept single kx group"); let groups: Vec<_> = config.kx_groups().map(|g| g.name()).collect(); - assert_eq!(groups, &[NamedGroup::SECP256R1]); + assert_eq!(groups, &[NamedGroup::Secp256r1]); } #[test] @@ -1003,7 +1003,7 @@ mod tests { .with_crypto_provider(aws_lc_rs::default_provider()) .dtls12_cipher_suites(&[Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384]) .dtls13_cipher_suites(&[Dtls13CipherSuite::AES_128_GCM_SHA256]) - .kx_groups(&[NamedGroup::X25519, NamedGroup::SECP256R1]) + .kx_groups(&[NamedGroup::X25519, NamedGroup::Secp256r1]) .build() .expect("should accept filtered config with explicit provider"); let suites12: Vec<_> = config.dtls12_cipher_suites().map(|cs| cs.suite()).collect(); @@ -1014,7 +1014,7 @@ mod tests { let suites13: Vec<_> = config.dtls13_cipher_suites().map(|cs| cs.suite()).collect(); assert_eq!(suites13, &[Dtls13CipherSuite::AES_128_GCM_SHA256]); let groups: Vec<_> = config.kx_groups().map(|g| g.name()).collect(); - assert_eq!(groups, &[NamedGroup::X25519, NamedGroup::SECP256R1]); + assert_eq!(groups, &[NamedGroup::X25519, NamedGroup::Secp256r1]); } } } diff --git a/src/crypto/aws_lc_rs/kx_group.rs b/src/crypto/aws_lc_rs/kx_group.rs index 88276895..1a79cce9 100644 --- a/src/crypto/aws_lc_rs/kx_group.rs +++ b/src/crypto/aws_lc_rs/kx_group.rs @@ -28,8 +28,8 @@ impl EcdhKeyExchange { fn new(group: NamedGroup, mut buf: Buf) -> Result { let algorithm = match group { NamedGroup::X25519 => &X25519, - NamedGroup::SECP256R1 => &ECDH_P256, - NamedGroup::SECP384R1 => &ECDH_P384, + NamedGroup::Secp256r1 => &ECDH_P256, + NamedGroup::Secp384r1 => &ECDH_P384, _ => return Err(CryptoError::UnsupportedKeyExchangeGroup(group)), }; @@ -54,8 +54,8 @@ impl EcdhKeyExchange { fn algorithm(&self) -> &'static aws_lc_rs::agreement::Algorithm { match self.group { NamedGroup::X25519 => &X25519, - NamedGroup::SECP256R1 => &ECDH_P256, - NamedGroup::SECP384R1 => &ECDH_P384, + NamedGroup::Secp256r1 => &ECDH_P256, + NamedGroup::Secp384r1 => &ECDH_P384, _ => unreachable!("Unsupported group"), } } @@ -109,11 +109,11 @@ struct P256; impl SupportedKxGroup for P256 { fn name(&self) -> NamedGroup { - NamedGroup::SECP256R1 + NamedGroup::Secp256r1 } fn start_exchange(&self, buf: Buf) -> Result, CryptoError> { - Ok(Box::new(EcdhKeyExchange::new(NamedGroup::SECP256R1, buf)?)) + Ok(Box::new(EcdhKeyExchange::new(NamedGroup::Secp256r1, buf)?)) } } @@ -123,11 +123,11 @@ struct P384; impl SupportedKxGroup for P384 { fn name(&self) -> NamedGroup { - NamedGroup::SECP384R1 + NamedGroup::Secp384r1 } fn start_exchange(&self, buf: Buf) -> Result, CryptoError> { - Ok(Box::new(EcdhKeyExchange::new(NamedGroup::SECP384R1, buf)?)) + Ok(Box::new(EcdhKeyExchange::new(NamedGroup::Secp384r1, buf)?)) } } diff --git a/src/crypto/aws_lc_rs/sign.rs b/src/crypto/aws_lc_rs/sign.rs index f6e27c5b..614202d5 100644 --- a/src/crypto/aws_lc_rs/sign.rs +++ b/src/crypto/aws_lc_rs/sign.rs @@ -221,18 +221,18 @@ impl SignatureVerifier for AwsLcSignatureVerifier { .map_err(|_| CryptoError::InvalidEcCurveParameter)?; let group = match curve_oid { - OID_P256 => NamedGroup::SECP256R1, - OID_P384 => NamedGroup::SECP384R1, + OID_P256 => NamedGroup::Secp256r1, + OID_P384 => NamedGroup::Secp384r1, _ => return Err(CryptoError::UnsupportedEcCurve), }; check_verify_scheme(sig_alg, hash_alg, group)?; let algorithm: &EcdsaVerificationAlgorithm = match (group, hash_alg) { - (NamedGroup::SECP256R1, HashAlgorithm::SHA256) => &ECDSA_P256_SHA256_ASN1, - (NamedGroup::SECP256R1, HashAlgorithm::SHA384) => &ECDSA_P256_SHA384_ASN1, - (NamedGroup::SECP384R1, HashAlgorithm::SHA256) => &ECDSA_P384_SHA256_ASN1, - (NamedGroup::SECP384R1, HashAlgorithm::SHA384) => &ECDSA_P384_SHA384_ASN1, + (NamedGroup::Secp256r1, HashAlgorithm::SHA256) => &ECDSA_P256_SHA256_ASN1, + (NamedGroup::Secp256r1, HashAlgorithm::SHA384) => &ECDSA_P256_SHA384_ASN1, + (NamedGroup::Secp384r1, HashAlgorithm::SHA256) => &ECDSA_P384_SHA256_ASN1, + (NamedGroup::Secp384r1, HashAlgorithm::SHA384) => &ECDSA_P384_SHA384_ASN1, // unreachable: check_verify_scheme already validated _ => unreachable!(), }; @@ -288,7 +288,7 @@ mod tests { CryptoError::SignatureVerificationFailed { signature: SignatureAlgorithm::ECDSA, hash: HashAlgorithm::SHA256, - group: NamedGroup::SECP256R1, + group: NamedGroup::Secp256r1, } ); } diff --git a/src/crypto/provider.rs b/src/crypto/provider.rs index bde0a0f5..a4f0af08 100644 --- a/src/crypto/provider.rs +++ b/src/crypto/provider.rs @@ -319,22 +319,22 @@ const SUPPORTED_VERIFY_SCHEMES: &[(SignatureAlgorithm, HashAlgorithm, NamedGroup ( SignatureAlgorithm::ECDSA, HashAlgorithm::SHA256, - NamedGroup::SECP256R1, + NamedGroup::Secp256r1, ), ( SignatureAlgorithm::ECDSA, HashAlgorithm::SHA256, - NamedGroup::SECP384R1, + NamedGroup::Secp384r1, ), ( SignatureAlgorithm::ECDSA, HashAlgorithm::SHA384, - NamedGroup::SECP256R1, + NamedGroup::Secp256r1, ), ( SignatureAlgorithm::ECDSA, HashAlgorithm::SHA384, - NamedGroup::SECP384R1, + NamedGroup::Secp384r1, ), ]; @@ -380,8 +380,8 @@ pub fn cert_named_group(cert_der: &[u8]) -> Result .map_err(|_| CertificateError::InvalidEcCurveParameter)?; match curve_oid { - OID_P256 => Ok(NamedGroup::SECP256R1), - OID_P384 => Ok(NamedGroup::SECP384R1), + OID_P256 => Ok(NamedGroup::Secp256r1), + OID_P384 => Ok(NamedGroup::Secp384r1), _ => Err(CertificateError::UnsupportedEcCurve), } } @@ -640,7 +640,7 @@ mod tests { let cert = params.self_signed(&key_pair).unwrap(); let group = cert_named_group(cert.der()).unwrap(); - assert_eq!(group, NamedGroup::SECP256R1); + assert_eq!(group, NamedGroup::Secp256r1); } #[test] @@ -653,7 +653,7 @@ mod tests { let cert = params.self_signed(&key_pair).unwrap(); let group = cert_named_group(cert.der()).unwrap(); - assert_eq!(group, NamedGroup::SECP384R1); + assert_eq!(group, NamedGroup::Secp384r1); } #[test] diff --git a/src/crypto/rust_crypto/kx_group.rs b/src/crypto/rust_crypto/kx_group.rs index 0d0d44d3..6593b0dc 100644 --- a/src/crypto/rust_crypto/kx_group.rs +++ b/src/crypto/rust_crypto/kx_group.rs @@ -57,7 +57,7 @@ impl EcdhKeyExchange { public_key: buf, }) } - NamedGroup::SECP256R1 => { + NamedGroup::Secp256r1 => { use rand_core::OsRng; let secret = EphemeralSecret::random(&mut OsRng); let public_key_obj = P256PublicKey::from(&secret); @@ -69,7 +69,7 @@ impl EcdhKeyExchange { public_key: buf, }) } - NamedGroup::SECP384R1 => { + NamedGroup::Secp384r1 => { use rand_core::OsRng; let secret = P384EphemeralSecret::random(&mut OsRng); let public_key_obj = P384PublicKey::from(&secret); @@ -113,7 +113,7 @@ impl ActiveKeyExchange for EcdhKeyExchange { } EcdhKeyExchange::P256 { secret, .. } => { let peer_key = P256PublicKey::from_sec1_bytes(peer_pub) - .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::SECP256R1))?; + .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::Secp256r1))?; let shared_secret = secret.diffie_hellman(&peer_key); out.clear(); out.extend_from_slice(shared_secret.raw_secret_bytes().as_slice()); @@ -121,7 +121,7 @@ impl ActiveKeyExchange for EcdhKeyExchange { } EcdhKeyExchange::P384 { secret, .. } => { let peer_key = P384PublicKey::from_sec1_bytes(peer_pub) - .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::SECP384R1))?; + .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::Secp384r1))?; let shared_secret = secret.diffie_hellman(&peer_key); out.clear(); out.extend_from_slice(shared_secret.raw_secret_bytes().as_slice()); @@ -133,8 +133,8 @@ impl ActiveKeyExchange for EcdhKeyExchange { fn group(&self) -> NamedGroup { match self { EcdhKeyExchange::X25519 { .. } => NamedGroup::X25519, - EcdhKeyExchange::P256 { .. } => NamedGroup::SECP256R1, - EcdhKeyExchange::P384 { .. } => NamedGroup::SECP384R1, + EcdhKeyExchange::P256 { .. } => NamedGroup::Secp256r1, + EcdhKeyExchange::P384 { .. } => NamedGroup::Secp384r1, } } } @@ -159,11 +159,11 @@ struct P256; impl SupportedKxGroup for P256 { fn name(&self) -> NamedGroup { - NamedGroup::SECP256R1 + NamedGroup::Secp256r1 } fn start_exchange(&self, buf: Buf) -> Result, CryptoError> { - Ok(Box::new(EcdhKeyExchange::new(NamedGroup::SECP256R1, buf)?)) + Ok(Box::new(EcdhKeyExchange::new(NamedGroup::Secp256r1, buf)?)) } } @@ -173,11 +173,11 @@ struct P384; impl SupportedKxGroup for P384 { fn name(&self) -> NamedGroup { - NamedGroup::SECP384R1 + NamedGroup::Secp384r1 } fn start_exchange(&self, buf: Buf) -> Result, CryptoError> { - Ok(Box::new(EcdhKeyExchange::new(NamedGroup::SECP384R1, buf)?)) + Ok(Box::new(EcdhKeyExchange::new(NamedGroup::Secp384r1, buf)?)) } } diff --git a/src/crypto/rust_crypto/sign.rs b/src/crypto/rust_crypto/sign.rs index fa56a723..1bf80084 100644 --- a/src/crypto/rust_crypto/sign.rs +++ b/src/crypto/rust_crypto/sign.rs @@ -55,7 +55,7 @@ impl SigningKeyTrait for EcdsaSigningKey { } _ => { return Err(CryptoError::SigningKeyUnsupportedHash { - group: NamedGroup::SECP256R1, + group: NamedGroup::Secp256r1, hash: hash_alg, }); } @@ -77,7 +77,7 @@ impl SigningKeyTrait for EcdsaSigningKey { } _ => { return Err(CryptoError::SigningKeyUnsupportedHash { - group: NamedGroup::SECP384R1, + group: NamedGroup::Secp384r1, hash: hash_alg, }); } @@ -231,8 +231,8 @@ impl SignatureVerifier for RustCryptoSignatureVerifier { .map_err(|_| CryptoError::InvalidEcCurveParameter)?; let group = match curve_oid { - OID_P256 => NamedGroup::SECP256R1, - OID_P384 => NamedGroup::SECP384R1, + OID_P256 => NamedGroup::Secp256r1, + OID_P384 => NamedGroup::Secp384r1, _ => return Err(CryptoError::UnsupportedEcCurve), }; @@ -250,9 +250,9 @@ impl SignatureVerifier for RustCryptoSignatureVerifier { }; match group { - NamedGroup::SECP256R1 => { + NamedGroup::Secp256r1 => { let verifying_key = VerifyingKey::::from_sec1_bytes(pubkey_bytes) - .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::SECP256R1))?; + .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::Secp256r1))?; let sig = Signature::::from_der(signature) .map_err(|_| CryptoError::InvalidSignatureFormat)?; verifying_key.verify_prehash(&hash, &sig).map_err(|_| { @@ -263,9 +263,9 @@ impl SignatureVerifier for RustCryptoSignatureVerifier { } }) } - NamedGroup::SECP384R1 => { + NamedGroup::Secp384r1 => { let verifying_key = VerifyingKey::::from_sec1_bytes(pubkey_bytes) - .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::SECP384R1))?; + .map_err(|_| CryptoError::InvalidPublicKey(NamedGroup::Secp384r1))?; let sig = Signature::::from_der(signature) .map_err(|_| CryptoError::InvalidSignatureFormat)?; verifying_key.verify_prehash(&hash, &sig).map_err(|_| { @@ -322,7 +322,7 @@ mod tests { CryptoError::SignatureVerificationFailed { signature: SignatureAlgorithm::ECDSA, hash: HashAlgorithm::SHA256, - group: NamedGroup::SECP256R1, + group: NamedGroup::Secp256r1, } ); } diff --git a/src/crypto/validation/mod.rs b/src/crypto/validation/mod.rs index 2ce2fc3b..3b19acce 100644 --- a/src/crypto/validation/mod.rs +++ b/src/crypto/validation/mod.rs @@ -40,7 +40,7 @@ impl CryptoProvider { self.kx_groups.iter().copied().filter(|kx| { matches!( kx.name(), - NamedGroup::X25519 | NamedGroup::SECP256R1 | NamedGroup::SECP384R1 + NamedGroup::X25519 | NamedGroup::Secp256r1 | NamedGroup::Secp384r1 ) }) } diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs index c16fcbee..71ac7a70 100644 --- a/src/dtls12/client.rs +++ b/src/dtls12/client.rs @@ -234,7 +234,7 @@ impl Client { // Use the engine's create_record to send application data // The encryption is now handled in the engine self.engine - .create_record(ContentType::APPLICATION_DATA, 1, false, |body| { + .create_record(ContentType::ApplicationData, 1, false, |body| { body.extend_from_slice(data); })?; @@ -252,7 +252,7 @@ impl Client { return Ok(()); } self.engine - .create_record(ContentType::ALERT, 1, false, |body| { + .create_record(ContentType::Alert, 1, false, |body| { body.push(1); // level: warning body.push(0); // description: close_notify })?; @@ -354,7 +354,7 @@ impl State { client .engine - .create_handshake(MessageType::CLIENT_HELLO, |body, engine| { + .create_handshake(MessageType::ClientHello, |body, engine| { handshake_create_client_hello( body, engine, @@ -377,7 +377,7 @@ impl State { fn await_hello_verify_request(self, client: &mut Client) -> Result { let has_hello = client .engine - .has_complete_handshake(MessageType::SERVER_HELLO); + .has_complete_handshake(MessageType::ServerHello); // Got ServerHello, skip HelloVerifyRequest if has_hello { @@ -385,7 +385,7 @@ impl State { } let maybe = client.engine.next_handshake( - MessageType::HELLO_VERIFY_REQUEST, + MessageType::HelloVerifyRequest, &mut client.defragment_buffer, )?; @@ -429,7 +429,7 @@ impl State { fn await_server_hello(self, client: &mut Client) -> Result { let maybe = client .engine - .next_handshake(MessageType::SERVER_HELLO, &mut client.defragment_buffer)?; + .next_handshake(MessageType::ServerHello, &mut client.defragment_buffer)?; let Some(handshake) = maybe else { // Stay in same state @@ -456,7 +456,7 @@ impl State { } // Enforce Null compression only - if server_hello.compression_method != CompressionMethod::NULL { + if server_hello.compression_method != CompressionMethod::Null { return Err( Error::SecurityError(crate::SecurityError::UnsupportedServerCompression( server_hello.compression_method, @@ -507,7 +507,7 @@ impl State { }; for extension in extensions { - if extension.extension_type == ExtensionType::USE_SRTP { + if extension.extension_type == ExtensionType::UseSrtp { // Parse the use_srtp extension to get the selected profile let extension_data = extension.extension_data(&client.defragment_buffer); let (_, use_srtp) = @@ -523,7 +523,7 @@ impl State { } // We are to use extended master secret - if extension.extension_type == ExtensionType::EXTENDED_MASTER_SECRET { + if extension.extension_type == ExtensionType::ExtendedMasterSecret { extended_master_secret = true; trace!("Server negotiated Extended Master Secret"); } @@ -554,7 +554,7 @@ impl State { fn await_certificate(self, client: &mut Client) -> Result { let maybe = client .engine - .next_handshake(MessageType::CERTIFICATE, &mut client.defragment_buffer)?; + .next_handshake(MessageType::Certificate, &mut client.defragment_buffer)?; let Some(ref handshake) = maybe else { // Stay in same state @@ -610,7 +610,7 @@ impl State { fn await_server_key_exchange_ecdhe(self, client: &mut Client) -> Result { let maybe = client.engine.next_handshake( - MessageType::SERVER_KEY_EXCHANGE, + MessageType::ServerKeyExchange, &mut client.defragment_buffer, )?; @@ -747,13 +747,13 @@ impl State { // If the server skipped ServerKeyExchange (no hint), go straight to ServerHelloDone let has_done = client .engine - .has_complete_handshake(MessageType::SERVER_HELLO_DONE); + .has_complete_handshake(MessageType::ServerHelloDone); if has_done { return Ok(Self::AwaitServerHelloDone); } let maybe = client.engine.next_handshake( - MessageType::SERVER_KEY_EXCHANGE, + MessageType::ServerKeyExchange, &mut client.defragment_buffer, )?; @@ -790,14 +790,14 @@ impl State { fn await_certificate_request(self, client: &mut Client) -> Result { let has_done = client .engine - .has_complete_handshake(MessageType::SERVER_HELLO_DONE); + .has_complete_handshake(MessageType::ServerHelloDone); if has_done { return Ok(Self::AwaitServerHelloDone); } let maybe = client.engine.next_handshake( - MessageType::CERTIFICATE_REQUEST, + MessageType::CertificateRequest, &mut client.defragment_buffer, )?; @@ -838,10 +838,9 @@ impl State { } fn await_server_hello_done(self, client: &mut Client) -> Result { - let maybe = client.engine.next_handshake( - MessageType::SERVER_HELLO_DONE, - &mut client.defragment_buffer, - )?; + let maybe = client + .engine + .next_handshake(MessageType::ServerHelloDone, &mut client.defragment_buffer)?; let Some(handshake) = maybe else { // stay in same state @@ -892,7 +891,7 @@ impl State { // Now use the engine with the stored data client .engine - .create_handshake(MessageType::CERTIFICATE, handshake_create_certificate)?; + .create_handshake(MessageType::Certificate, handshake_create_certificate)?; Ok(Self::SendClientKeyExchange) } @@ -907,7 +906,7 @@ impl State { // Send client key exchange message client.engine.create_handshake( - MessageType::CLIENT_KEY_EXCHANGE, + MessageType::ClientKeyExchange, handshake_create_client_key_exchange, )?; @@ -937,7 +936,7 @@ impl State { // Send the certificate verify message client.engine.create_handshake( - MessageType::CERTIFICATE_VERIFY, + MessageType::CertificateVerify, handshake_create_certificate_verify, )?; @@ -951,7 +950,7 @@ impl State { trace!("Sending ChangeCipherSpec"); client .engine - .create_record(ContentType::CHANGE_CIPHER_SPEC, 0, true, |body| { + .create_record(ContentType::ChangeCipherSpec, 0, true, |body| { // Change cipher spec is just a single byte with value 1 body.push(1); })?; @@ -1032,7 +1031,7 @@ impl State { client .engine - .create_handshake(MessageType::FINISHED, |body, engine| { + .create_handshake(MessageType::Finished, |body, engine| { // Calculate verify data for Finished message using PRF let verify_data = engine.generate_verify_data(true)?; @@ -1047,7 +1046,7 @@ impl State { } fn await_change_cipher_spec(self, client: &mut Client) -> Result { - let maybe = client.engine.next_record(ContentType::CHANGE_CIPHER_SPEC); + let maybe = client.engine.next_record(ContentType::ChangeCipherSpec); let Some(_) = maybe else { // Stay in same state @@ -1066,16 +1065,15 @@ impl State { } fn await_new_session_ticket(self, client: &mut Client) -> Result { - let has_finished = client.engine.has_complete_handshake(MessageType::FINISHED); + let has_finished = client.engine.has_complete_handshake(MessageType::Finished); if has_finished { return Ok(Self::AwaitFinished); } - let maybe = client.engine.next_handshake( - MessageType::NEW_SESSION_TICKET, - &mut client.defragment_buffer, - )?; + let maybe = client + .engine + .next_handshake(MessageType::NewSessionTicket, &mut client.defragment_buffer)?; let Some(handshake) = maybe else { // Stay in same state @@ -1101,7 +1099,7 @@ impl State { let maybe = client .engine - .next_handshake(MessageType::FINISHED, &mut client.defragment_buffer)?; + .next_handshake(MessageType::Finished, &mut client.defragment_buffer)?; if maybe.is_none() { // stay in same state @@ -1192,7 +1190,7 @@ impl State { client.engine.discard_pending_writes(); client .engine - .create_record(ContentType::ALERT, 1, false, |body| { + .create_record(ContentType::Alert, 1, false, |body| { body.push(1); // level: warning body.push(0); // description: close_notify })?; @@ -1207,7 +1205,7 @@ impl State { for data in client.queued_data.drain(..) { client .engine - .create_record(ContentType::APPLICATION_DATA, 1, false, |body| { + .create_record(ContentType::ApplicationData, 1, false, |body| { body.extend_from_slice(&data); })?; } @@ -1244,7 +1242,7 @@ fn handshake_create_client_hello( ); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::NULL); + compression_methods.push(CompressionMethod::Null); // Create ClientHello with all required extensions let client_hello = ClientHello::new( @@ -1411,7 +1409,7 @@ mod tests { fn epoch0_handshake_packet(msg_type: MessageType, message_seq: u16, body: &[u8]) -> Vec { let handshake_len = 12 + body.len(); let mut packet = Vec::new(); - packet.push(ContentType::HANDSHAKE.as_u8()); + packet.push(ContentType::Handshake.as_u8()); packet.extend_from_slice(&[0xfe, 0xfd]); packet.extend_from_slice(&0u16.to_be_bytes()); packet.extend_from_slice(&0u64.to_be_bytes()[2..]); @@ -1431,7 +1429,7 @@ mod tests { client .engine .parse_packet(&epoch0_handshake_packet( - MessageType::CERTIFICATE, + MessageType::Certificate, 0, &[0, 0, 0], )) diff --git a/src/dtls12/context.rs b/src/dtls12/context.rs index 4831ddc0..30502c8f 100644 --- a/src/dtls12/context.rs +++ b/src/dtls12/context.rs @@ -558,14 +558,14 @@ impl CryptoContext { pub fn get_key_exchange_group_info(&self) -> Option<(CurveType, NamedGroup)> { // Use stored group if available (after key exchange is consumed) if let Some(group) = self.key_exchange_group { - return Some((CurveType::NAMED_CURVE, group)); + return Some((CurveType::NamedCurve, group)); } // Otherwise get it from the active key exchange let Some(ke) = &self.key_exchange else { return None; }; - Some((CurveType::NAMED_CURVE, ke.group())) + Some((CurveType::NamedCurve, ke.group())) } /// Check if the client's private key is compatible with a given cipher suite. diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index 80acf3d9..3bd86e99 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -324,7 +324,7 @@ impl Engine { if self.peer_encryption_enabled && seq_current.epoch == 0 - && first.record().content_type == ContentType::HANDSHAKE + && first.record().content_type == ContentType::Handshake { return Ok(()); } @@ -332,7 +332,7 @@ impl Engine { if self.peer_encryption_enabled { for record in incoming.records().iter() { if record.record().sequence.epoch == 0 - && record.record().content_type == ContentType::HANDSHAKE + && record.record().content_type == ContentType::Handshake { if record.handshakes().is_empty() { record.set_handled(); @@ -443,7 +443,7 @@ impl Engine { .queue_rx .iter() .flat_map(|i| i.records().iter()) - .filter(|r| r.record().content_type == ContentType::APPLICATION_DATA) + .filter(|r| r.record().content_type == ContentType::ApplicationData) .skip_while(|r| r.is_handled()); let Some(next) = unhandled.next() else { @@ -697,7 +697,7 @@ impl Engine { pub fn drop_pending_ccs(&mut self) { for incoming in self.queue_rx.iter() { for record in incoming.records().iter() { - if record.record().content_type == ContentType::CHANGE_CIPHER_SPEC { + if record.record().content_type == ContentType::ChangeCipherSpec { record.set_handled(); } } @@ -983,7 +983,7 @@ impl Engine { }; // Emit the record; packing into current datagram happens inside create_record - self.create_record(ContentType::HANDSHAKE, epoch, true, |fragment| { + self.create_record(ContentType::Handshake, epoch, true, |fragment| { // Serialize with body_buffer as source frag_handshake.serialize(&body_buffer, fragment); })?; @@ -1237,7 +1237,7 @@ impl RecordHandler for Engine { fn classify_record(&mut self, record: Record) -> Result, Error> { let epoch = record.record().sequence.epoch; - if record.record().content_type == ContentType::CHANGE_CIPHER_SPEC + if record.record().content_type == ContentType::ChangeCipherSpec && epoch == 0 && self.peer_encryption_enabled { @@ -1250,7 +1250,7 @@ impl RecordHandler for Engine { return Ok(None); } - if record.record().content_type == ContentType::HANDSHAKE + if record.record().content_type == ContentType::Handshake && epoch == 0 && self.peer_encryption_enabled && record @@ -1266,7 +1266,7 @@ impl RecordHandler for Engine { return Ok(None); } - if record.record().content_type == ContentType::ALERT { + if record.record().content_type == ContentType::Alert { if epoch == 0 { if self.peer_encryption_enabled { // Post-handshake: epoch 0 alerts are unauthenticated, discard. @@ -1318,7 +1318,7 @@ impl RecordHandler for Engine { } if self.close_notify_received - && record.record().content_type == ContentType::APPLICATION_DATA + && record.record().content_type == ContentType::ApplicationData { self.push_buffer(record.into_buffer()); return Ok(None); @@ -1347,7 +1347,7 @@ impl RecordHandler for Engine { // that's known, a stale plaintext handshake (unauthenticated, replayable) // must no longer drive a courtesy flight retransmission. The client // confirms separately at its own completion (flight_stop_resend_timers). - if content_type == ContentType::APPLICATION_DATA { + if content_type == ContentType::ApplicationData { self.peer_handshake_confirmed = true; } } diff --git a/src/dtls12/incoming.rs b/src/dtls12/incoming.rs index 0e2c4bc8..ca21b6c0 100644 --- a/src/dtls12/incoming.rs +++ b/src/dtls12/incoming.rs @@ -279,7 +279,7 @@ impl ParsedRecord { ) -> Result { let (_, record) = DTLSRecord::parse(input, 0, offset)?; - let handshakes = if record.content_type == ContentType::HANDSHAKE { + let handshakes = if record.content_type == ContentType::Handshake { // This will also return None on the encrypted Finished after ChangeCipherSpec. // However we will then decrypt and try again. let fragment_offset = record.fragment_range.start; @@ -409,7 +409,7 @@ mod tests { impl RecordHandler for TestHandler { fn classify_record(&mut self, record: Record) -> Result, Error> { self.classify_calls += 1; - if record.record().content_type == ContentType::ALERT { + if record.record().content_type == ContentType::Alert { self.dropped_alerts += 1; return Ok(None); } @@ -464,9 +464,9 @@ mod tests { #[test] fn parse_packet_filters_control_records_after_packet_validation() { let mut packet = Vec::new(); - packet.extend_from_slice(&build_record(ContentType::ALERT, 0, 1, &[0x01, 0x00])); + packet.extend_from_slice(&build_record(ContentType::Alert, 0, 1, &[0x01, 0x00])); packet.extend_from_slice(&build_record( - ContentType::APPLICATION_DATA, + ContentType::ApplicationData, 1, 2, &[0xAA, 0xBB], @@ -482,7 +482,7 @@ mod tests { assert_eq!(incoming.records().len(), 1); assert_eq!( incoming.first().record().content_type, - ContentType::APPLICATION_DATA + ContentType::ApplicationData ); assert_eq!(incoming.first().record().sequence.epoch, 1); } diff --git a/src/dtls12/message/client_hello.rs b/src/dtls12/message/client_hello.rs index f0178f5e..2747a1fa 100644 --- a/src/dtls12/message/client_hello.rs +++ b/src/dtls12/message/client_hello.rs @@ -65,31 +65,31 @@ impl ClientHello { let supported_groups = SupportedGroupsExtension { groups }; let start_pos = buf.len(); supported_groups.serialize(buf); - ranges.push((ExtensionType::SUPPORTED_GROUPS, start_pos, buf.len())); + ranges.push((ExtensionType::SupportedGroups, start_pos, buf.len())); // Add EC point formats extension let ec_point_formats = ECPointFormatsExtension::default(); let start_pos = buf.len(); ec_point_formats.serialize(buf); - ranges.push((ExtensionType::EC_POINT_FORMATS, start_pos, buf.len())); + ranges.push((ExtensionType::EcPointFormats, start_pos, buf.len())); } // Add signature algorithms extension (required for TLS 1.2+) let signature_algorithms = SignatureAlgorithmsExtension::default(); let start_pos = buf.len(); signature_algorithms.serialize(buf); - ranges.push((ExtensionType::SIGNATURE_ALGORITHMS, start_pos, buf.len())); + ranges.push((ExtensionType::SignatureAlgorithms, start_pos, buf.len())); // Add use_srtp extension for DTLS-SRTP support let use_srtp = UseSrtpExtension::default(); let start_pos = buf.len(); use_srtp.serialize(buf); - ranges.push((ExtensionType::USE_SRTP, start_pos, buf.len())); + ranges.push((ExtensionType::UseSrtp, start_pos, buf.len())); // // Add session_ticket extension (empty) // let start_pos = buf.len(); // buf.extend_from_slice(&[0x00]); // Empty extension data - // ranges.push((ExtensionType::SESSION_TICKET, start_pos, buf.len())); + // ranges.push((ExtensionType::SessionTicket, start_pos, buf.len())); let need_etm = self .cipher_suites @@ -99,12 +99,12 @@ impl ClientHello { // Add encrypt_then_mac extension (empty) let start_pos = buf.len(); buf.extend_from_slice(&[0x00]); // Empty extension data - ranges.push((ExtensionType::ENCRYPT_THEN_MAC, start_pos, buf.len())); + ranges.push((ExtensionType::EncryptThenMac, start_pos, buf.len())); } let start_pos = buf.len(); ranges.push(( - ExtensionType::EXTENDED_MASTER_SECRET, + ExtensionType::ExtendedMasterSecret, start_pos, start_pos, // No data at all )); @@ -274,7 +274,7 @@ mod tests { 0xC0, 0x2B, // Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 0xC0, 0x2C, // Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 0x01, // CompressionMethods length - 0x00, // CompressionMethod::NULL + 0x00, // CompressionMethod::Null ]; #[test] @@ -286,7 +286,7 @@ mod tests { cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256); cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::NULL); + compression_methods.push(CompressionMethod::Null); let client_hello = ClientHello::new( ProtocolVersion::DTLS1_2, @@ -333,9 +333,8 @@ mod tests { let mut message = MESSAGE.to_vec(); message.extend_from_slice(&(count as u16 * 4).to_be_bytes()); for _ in 0..count { - message.extend_from_slice( - &ExtensionType::EXTENDED_MASTER_SECRET.as_u16().to_be_bytes(), - ); + message + .extend_from_slice(&ExtensionType::ExtendedMasterSecret.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); } diff --git a/src/dtls12/message/client_key_exchange.rs b/src/dtls12/message/client_key_exchange.rs index 147a24bd..522a6d59 100644 --- a/src/dtls12/message/client_key_exchange.rs +++ b/src/dtls12/message/client_key_exchange.rs @@ -42,8 +42,8 @@ impl ClientEcdhKeys { ClientEcdhKeys { // In ClientKeyExchange, we don't include curve_type and named_group // since they're already established during ServerKeyExchange - curve_type: CurveType::NAMED_CURVE, // Default - named_group: NamedGroup::SECP256R1, // Default + curve_type: CurveType::NamedCurve, // Default + named_group: NamedGroup::Secp256r1, // Default public_key_range: start..end, }, )) diff --git a/src/dtls12/message/extension.rs b/src/dtls12/message/extension.rs index 51c82b2f..3ea0c528 100644 --- a/src/dtls12/message/extension.rs +++ b/src/dtls12/message/extension.rs @@ -55,45 +55,46 @@ impl Extension { #[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct ExtensionType(u16); +#[allow(non_upper_case_globals)] impl ExtensionType { - pub const SERVER_NAME: Self = Self(0x0000); - pub const MAX_FRAGMENT_LENGTH: Self = Self(0x0001); - pub const CLIENT_CERTIFICATE_URL: Self = Self(0x0002); - pub const TRUSTED_CA_KEYS: Self = Self(0x0003); - pub const TRUNCATED_HMAC: Self = Self(0x0004); - pub const STATUS_REQUEST: Self = Self(0x0005); - pub const USER_MAPPING: Self = Self(0x0006); - pub const CLIENT_AUTHZ: Self = Self(0x0007); - pub const SERVER_AUTHZ: Self = Self(0x0008); - pub const CERT_TYPE: Self = Self(0x0009); - pub const SUPPORTED_GROUPS: Self = Self(0x000A); - pub const EC_POINT_FORMATS: Self = Self(0x000B); - pub const SRP: Self = Self(0x000C); - pub const SIGNATURE_ALGORITHMS: Self = Self(0x000D); - pub const USE_SRTP: Self = Self(0x000E); - pub const HEARTBEAT: Self = Self(0x000F); - pub const APPLICATION_LAYER_PROTOCOL_NEGOTIATION: Self = Self(0x0010); - pub const STATUS_REQUEST_V2: Self = Self(0x0011); - pub const SIGNED_CERTIFICATE_TIMESTAMP: Self = Self(0x0012); - pub const CLIENT_CERTIFICATE_TYPE: Self = Self(0x0013); - pub const SERVER_CERTIFICATE_TYPE: Self = Self(0x0014); - pub const PADDING: Self = Self(0x0015); - pub const ENCRYPT_THEN_MAC: Self = Self(0x0016); - pub const EXTENDED_MASTER_SECRET: Self = Self(0x0017); - pub const TOKEN_BINDING: Self = Self(0x0018); - pub const CACHED_INFO: Self = Self(0x0019); - pub const SESSION_TICKET: Self = Self(0x0023); - pub const PRE_SHARED_KEY: Self = Self(0x0029); - pub const EARLY_DATA: Self = Self(0x002A); - pub const SUPPORTED_VERSIONS: Self = Self(0x002B); - pub const COOKIE: Self = Self(0x002C); - pub const PSK_KEY_EXCHANGE_MODES: Self = Self(0x002D); - pub const CERTIFICATE_AUTHORITIES: Self = Self(0x002F); - pub const OID_FILTERS: Self = Self(0x0030); - pub const POST_HANDSHAKE_AUTH: Self = Self(0x0031); - pub const SIGNATURE_ALGORITHMS_CERT: Self = Self(0x0032); - pub const KEY_SHARE: Self = Self(0x0033); - pub const RENEGOTIATION_INFO: Self = Self(0xFF01); + pub const ServerName: Self = Self(0x0000); + pub const MaxFragmentLength: Self = Self(0x0001); + pub const ClientCertificateUrl: Self = Self(0x0002); + pub const TrustedCaKeys: Self = Self(0x0003); + pub const TruncatedHmac: Self = Self(0x0004); + pub const StatusRequest: Self = Self(0x0005); + pub const UserMapping: Self = Self(0x0006); + pub const ClientAuthz: Self = Self(0x0007); + pub const ServerAuthz: Self = Self(0x0008); + pub const CertType: Self = Self(0x0009); + pub const SupportedGroups: Self = Self(0x000A); + pub const EcPointFormats: Self = Self(0x000B); + pub const Srp: Self = Self(0x000C); + pub const SignatureAlgorithms: Self = Self(0x000D); + pub const UseSrtp: Self = Self(0x000E); + pub const Heartbeat: Self = Self(0x000F); + pub const ApplicationLayerProtocolNegotiation: Self = Self(0x0010); + pub const StatusRequestV2: Self = Self(0x0011); + pub const SignedCertificateTimestamp: Self = Self(0x0012); + pub const ClientCertificateType: Self = Self(0x0013); + pub const ServerCertificateType: Self = Self(0x0014); + pub const Padding: Self = Self(0x0015); + pub const EncryptThenMac: Self = Self(0x0016); + pub const ExtendedMasterSecret: Self = Self(0x0017); + pub const TokenBinding: Self = Self(0x0018); + pub const CachedInfo: Self = Self(0x0019); + pub const SessionTicket: Self = Self(0x0023); + pub const PreSharedKey: Self = Self(0x0029); + pub const EarlyData: Self = Self(0x002A); + pub const SupportedVersions: Self = Self(0x002B); + pub const Cookie: Self = Self(0x002C); + pub const PskKeyExchangeModes: Self = Self(0x002D); + pub const CertificateAuthorities: Self = Self(0x002F); + pub const OidFilters: Self = Self(0x0030); + pub const PostHandshakeAuth: Self = Self(0x0031); + pub const SignatureAlgorithmsCert: Self = Self(0x0032); + pub const KeyShare: Self = Self(0x0033); + pub const RenegotiationInfo: Self = Self(0xFF01); pub const fn from_u16(value: u16) -> Self { Self(value) @@ -123,14 +124,14 @@ impl ExtensionType { /// Supported extension types that this implementation handles. pub const fn supported() -> &'static [ExtensionType; 8] { &[ - ExtensionType::SUPPORTED_GROUPS, - ExtensionType::EC_POINT_FORMATS, - ExtensionType::SIGNATURE_ALGORITHMS, - ExtensionType::USE_SRTP, - ExtensionType::ENCRYPT_THEN_MAC, - ExtensionType::EXTENDED_MASTER_SECRET, - ExtensionType::RENEGOTIATION_INFO, - ExtensionType::SESSION_TICKET, + ExtensionType::SupportedGroups, + ExtensionType::EcPointFormats, + ExtensionType::SignatureAlgorithms, + ExtensionType::UseSrtp, + ExtensionType::EncryptThenMac, + ExtensionType::ExtendedMasterSecret, + ExtensionType::RenegotiationInfo, + ExtensionType::SessionTicket, ] } } @@ -142,46 +143,46 @@ impl fmt::Debug for ExtensionType { } let name = match *self { - ExtensionType::SERVER_NAME => "ServerName", - ExtensionType::MAX_FRAGMENT_LENGTH => "MaxFragmentLength", - ExtensionType::CLIENT_CERTIFICATE_URL => "ClientCertificateUrl", - ExtensionType::TRUSTED_CA_KEYS => "TrustedCaKeys", - ExtensionType::TRUNCATED_HMAC => "TruncatedHmac", - ExtensionType::STATUS_REQUEST => "StatusRequest", - ExtensionType::USER_MAPPING => "UserMapping", - ExtensionType::CLIENT_AUTHZ => "ClientAuthz", - ExtensionType::SERVER_AUTHZ => "ServerAuthz", - ExtensionType::CERT_TYPE => "CertType", - ExtensionType::SUPPORTED_GROUPS => "SupportedGroups", - ExtensionType::EC_POINT_FORMATS => "EcPointFormats", - ExtensionType::SRP => "Srp", - ExtensionType::SIGNATURE_ALGORITHMS => "SignatureAlgorithms", - ExtensionType::USE_SRTP => "UseSrtp", - ExtensionType::HEARTBEAT => "Heartbeat", - ExtensionType::APPLICATION_LAYER_PROTOCOL_NEGOTIATION => { + ExtensionType::ServerName => "ServerName", + ExtensionType::MaxFragmentLength => "MaxFragmentLength", + ExtensionType::ClientCertificateUrl => "ClientCertificateUrl", + ExtensionType::TrustedCaKeys => "TrustedCaKeys", + ExtensionType::TruncatedHmac => "TruncatedHmac", + ExtensionType::StatusRequest => "StatusRequest", + ExtensionType::UserMapping => "UserMapping", + ExtensionType::ClientAuthz => "ClientAuthz", + ExtensionType::ServerAuthz => "ServerAuthz", + ExtensionType::CertType => "CertType", + ExtensionType::SupportedGroups => "SupportedGroups", + ExtensionType::EcPointFormats => "EcPointFormats", + ExtensionType::Srp => "Srp", + ExtensionType::SignatureAlgorithms => "SignatureAlgorithms", + ExtensionType::UseSrtp => "UseSrtp", + ExtensionType::Heartbeat => "Heartbeat", + ExtensionType::ApplicationLayerProtocolNegotiation => { "ApplicationLayerProtocolNegotiation" } - ExtensionType::STATUS_REQUEST_V2 => "StatusRequestV2", - ExtensionType::SIGNED_CERTIFICATE_TIMESTAMP => "SignedCertificateTimestamp", - ExtensionType::CLIENT_CERTIFICATE_TYPE => "ClientCertificateType", - ExtensionType::SERVER_CERTIFICATE_TYPE => "ServerCertificateType", - ExtensionType::PADDING => "Padding", - ExtensionType::ENCRYPT_THEN_MAC => "EncryptThenMac", - ExtensionType::EXTENDED_MASTER_SECRET => "ExtendedMasterSecret", - ExtensionType::TOKEN_BINDING => "TokenBinding", - ExtensionType::CACHED_INFO => "CachedInfo", - ExtensionType::SESSION_TICKET => "SessionTicket", - ExtensionType::PRE_SHARED_KEY => "PreSharedKey", - ExtensionType::EARLY_DATA => "EarlyData", - ExtensionType::SUPPORTED_VERSIONS => "SupportedVersions", - ExtensionType::COOKIE => "Cookie", - ExtensionType::PSK_KEY_EXCHANGE_MODES => "PskKeyExchangeModes", - ExtensionType::CERTIFICATE_AUTHORITIES => "CertificateAuthorities", - ExtensionType::OID_FILTERS => "OidFilters", - ExtensionType::POST_HANDSHAKE_AUTH => "PostHandshakeAuth", - ExtensionType::SIGNATURE_ALGORITHMS_CERT => "SignatureAlgorithmsCert", - ExtensionType::KEY_SHARE => "KeyShare", - ExtensionType::RENEGOTIATION_INFO => "RenegotiationInfo", + ExtensionType::StatusRequestV2 => "StatusRequestV2", + ExtensionType::SignedCertificateTimestamp => "SignedCertificateTimestamp", + ExtensionType::ClientCertificateType => "ClientCertificateType", + ExtensionType::ServerCertificateType => "ServerCertificateType", + ExtensionType::Padding => "Padding", + ExtensionType::EncryptThenMac => "EncryptThenMac", + ExtensionType::ExtendedMasterSecret => "ExtendedMasterSecret", + ExtensionType::TokenBinding => "TokenBinding", + ExtensionType::CachedInfo => "CachedInfo", + ExtensionType::SessionTicket => "SessionTicket", + ExtensionType::PreSharedKey => "PreSharedKey", + ExtensionType::EarlyData => "EarlyData", + ExtensionType::SupportedVersions => "SupportedVersions", + ExtensionType::Cookie => "Cookie", + ExtensionType::PskKeyExchangeModes => "PskKeyExchangeModes", + ExtensionType::CertificateAuthorities => "CertificateAuthorities", + ExtensionType::OidFilters => "OidFilters", + ExtensionType::PostHandshakeAuth => "PostHandshakeAuth", + ExtensionType::SignatureAlgorithmsCert => "SignatureAlgorithmsCert", + ExtensionType::KeyShare => "KeyShare", + ExtensionType::RenegotiationInfo => "RenegotiationInfo", _ => unreachable!("known DTLS 1.2 extension type missing Debug label"), }; @@ -195,7 +196,7 @@ mod tests { use crate::buffer::Buf; const MESSAGE: &[u8] = &[ - 0x00, 0x0A, // ExtensionType::SUPPORTED_GROUPS + 0x00, 0x0A, // ExtensionType::SupportedGroups 0x00, 0x08, // Extension length 0x00, 0x06, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, // Extension data ]; @@ -204,7 +205,7 @@ mod tests { fn extension_type_newtype_shape() { assert_eq!(std::mem::size_of::(), 2); assert_eq!(ExtensionType::default().as_u16(), 0); - assert_eq!(ExtensionType::default(), ExtensionType::SERVER_NAME); + assert_eq!(ExtensionType::default(), ExtensionType::ServerName); } #[test] @@ -225,7 +226,7 @@ mod tests { #[test] fn extension_type_debug_stays_enum_like() { assert_eq!( - format!("{:?}", ExtensionType::SUPPORTED_GROUPS), + format!("{:?}", ExtensionType::SupportedGroups), "SupportedGroups" ); assert_eq!( diff --git a/src/dtls12/message/extensions/supported_groups.rs b/src/dtls12/message/extensions/supported_groups.rs index 09d2a984..12b9fb44 100644 --- a/src/dtls12/message/extensions/supported_groups.rs +++ b/src/dtls12/message/extensions/supported_groups.rs @@ -61,7 +61,7 @@ mod tests { fn test_supported_groups_extension() { let mut groups = NamedGroupVec::new(); groups.push(NamedGroup::X25519); - groups.push(NamedGroup::SECP256R1); + groups.push(NamedGroup::Secp256r1); let ext = SupportedGroupsExtension { groups }; @@ -98,8 +98,8 @@ mod tests { parsed.groups.as_slice(), &[ NamedGroup::X25519, - NamedGroup::SECP256R1, - NamedGroup::SECP384R1 + NamedGroup::Secp256r1, + NamedGroup::Secp384r1 ] ); } diff --git a/src/dtls12/message/handshake.rs b/src/dtls12/message/handshake.rs index 742013f8..ed555c3f 100644 --- a/src/dtls12/message/handshake.rs +++ b/src/dtls12/message/handshake.rs @@ -189,7 +189,7 @@ impl Handshake { let (rest, body) = Body::parse(buffer, 0, first_handshake.header.msg_type, cipher_suite)?; - if !rest.is_empty() && first_handshake.header.msg_type == MessageType::FINISHED { + if !rest.is_empty() && first_handshake.header.msg_type == MessageType::Finished { debug!("Defragmentation failed. Body::parse() did not consume the entire buffer"); return Err(crate::InternalError::parse_incomplete()); } @@ -269,10 +269,10 @@ impl Handshake { let qualifies = matches!( self.header.msg_type, - MessageType::CLIENT_HELLO | // flight 1 and 3 - MessageType::HELLO_VERIFY_REQUEST | // flight 2 - MessageType::SERVER_HELLO_DONE | // flight 4 - MessageType::CLIENT_KEY_EXCHANGE // flight 5 + MessageType::ClientHello | // flight 1 and 3 + MessageType::HelloVerifyRequest | // flight 2 + MessageType::ServerHelloDone | // flight 4 + MessageType::ClientKeyExchange // flight 5 ); qualifies.then_some(self.header.message_seq) @@ -291,19 +291,20 @@ impl Handshake { #[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct MessageType(u8); +#[allow(non_upper_case_globals)] impl MessageType { - pub const HELLO_REQUEST: Self = Self(0); - pub const CLIENT_HELLO: Self = Self(1); - pub const SERVER_HELLO: Self = Self(2); - pub const HELLO_VERIFY_REQUEST: Self = Self(3); - pub const NEW_SESSION_TICKET: Self = Self(4); - pub const CERTIFICATE: Self = Self(11); - pub const SERVER_KEY_EXCHANGE: Self = Self(12); - pub const CERTIFICATE_REQUEST: Self = Self(13); - pub const SERVER_HELLO_DONE: Self = Self(14); - pub const CERTIFICATE_VERIFY: Self = Self(15); - pub const CLIENT_KEY_EXCHANGE: Self = Self(16); - pub const FINISHED: Self = Self(20); + pub const HelloRequest: Self = Self(0); + pub const ClientHello: Self = Self(1); + pub const ServerHello: Self = Self(2); + pub const HelloVerifyRequest: Self = Self(3); + pub const NewSessionTicket: Self = Self(4); + pub const Certificate: Self = Self(11); + pub const ServerKeyExchange: Self = Self(12); + pub const CertificateRequest: Self = Self(13); + pub const ServerHelloDone: Self = Self(14); + pub const CertificateVerify: Self = Self(15); + pub const ClientKeyExchange: Self = Self(16); + pub const Finished: Self = Self(20); pub const fn from_u8(value: u8) -> Self { Self(value) @@ -323,10 +324,7 @@ impl MessageType { } pub fn epoch(&self) -> u16 { - if matches!( - *self, - MessageType::NEW_SESSION_TICKET | MessageType::FINISHED - ) { + if matches!(*self, MessageType::NewSessionTicket | MessageType::Finished) { 1 } else { 0 @@ -341,18 +339,18 @@ impl fmt::Debug for MessageType { } let name = match *self { - MessageType::HELLO_REQUEST => "HelloRequest", - MessageType::CLIENT_HELLO => "ClientHello", - MessageType::HELLO_VERIFY_REQUEST => "HelloVerifyRequest", - MessageType::SERVER_HELLO => "ServerHello", - MessageType::CERTIFICATE => "Certificate", - MessageType::SERVER_KEY_EXCHANGE => "ServerKeyExchange", - MessageType::CERTIFICATE_REQUEST => "CertificateRequest", - MessageType::SERVER_HELLO_DONE => "ServerHelloDone", - MessageType::CERTIFICATE_VERIFY => "CertificateVerify", - MessageType::CLIENT_KEY_EXCHANGE => "ClientKeyExchange", - MessageType::NEW_SESSION_TICKET => "NewSessionTicket", - MessageType::FINISHED => "Finished", + MessageType::HelloRequest => "HelloRequest", + MessageType::ClientHello => "ClientHello", + MessageType::HelloVerifyRequest => "HelloVerifyRequest", + MessageType::ServerHello => "ServerHello", + MessageType::Certificate => "Certificate", + MessageType::ServerKeyExchange => "ServerKeyExchange", + MessageType::CertificateRequest => "CertificateRequest", + MessageType::ServerHelloDone => "ServerHelloDone", + MessageType::CertificateVerify => "CertificateVerify", + MessageType::ClientKeyExchange => "ClientKeyExchange", + MessageType::NewSessionTicket => "NewSessionTicket", + MessageType::Finished => "Finished", _ => unreachable!("known DTLS 1.2 handshake message type missing Debug label"), }; @@ -393,24 +391,24 @@ impl Body { c: Option, ) -> IResult<&[u8], Body> { match m { - MessageType::HELLO_REQUEST => Ok((input, Body::HelloRequest)), - MessageType::CLIENT_HELLO => { + MessageType::HelloRequest => Ok((input, Body::HelloRequest)), + MessageType::ClientHello => { let (input, client_hello) = ClientHello::parse(input, base_offset)?; Ok((input, Body::ClientHello(client_hello))) } - MessageType::HELLO_VERIFY_REQUEST => { + MessageType::HelloVerifyRequest => { let (input, hello_verify_request) = HelloVerifyRequest::parse(input)?; Ok((input, Body::HelloVerifyRequest(hello_verify_request))) } - MessageType::SERVER_HELLO => { + MessageType::ServerHello => { let (input, server_hello) = ServerHello::parse(input, base_offset)?; Ok((input, Body::ServerHello(server_hello))) } - MessageType::CERTIFICATE => { + MessageType::Certificate => { let (input, certificate) = Certificate::parse(input, base_offset)?; Ok((input, Body::Certificate(certificate))) } - MessageType::SERVER_KEY_EXCHANGE => { + MessageType::ServerKeyExchange => { let cipher_suite = c.ok_or_else(|| Err::Failure(Error::new(input, ErrorKind::Fail)))?; let algo = cipher_suite.as_key_exchange_algorithm(); @@ -418,16 +416,16 @@ impl Body { ServerKeyExchange::parse(input, base_offset, algo)?; Ok((input, Body::ServerKeyExchange(server_key_exchange))) } - MessageType::CERTIFICATE_REQUEST => { + MessageType::CertificateRequest => { let (input, certificate_request) = CertificateRequest::parse(input, base_offset)?; Ok((input, Body::CertificateRequest(certificate_request))) } - MessageType::SERVER_HELLO_DONE => Ok((input, Body::ServerHelloDone)), - MessageType::CERTIFICATE_VERIFY => { + MessageType::ServerHelloDone => Ok((input, Body::ServerHelloDone)), + MessageType::CertificateVerify => { let (input, certificate_verify) = CertificateVerify::parse(input, base_offset)?; Ok((input, Body::CertificateVerify(certificate_verify))) } - MessageType::CLIENT_KEY_EXCHANGE => { + MessageType::ClientKeyExchange => { let cipher_suite = c.ok_or_else(|| Err::Failure(Error::new(input, ErrorKind::Fail)))?; let algo = cipher_suite.as_key_exchange_algorithm(); @@ -435,12 +433,12 @@ impl Body { ClientKeyExchange::parse(input, base_offset, algo)?; Ok((input, Body::ClientKeyExchange(client_key_exchange))) } - MessageType::NEW_SESSION_TICKET => { + MessageType::NewSessionTicket => { // Treat ticket as opaque per RFC 5077: lifetime_hint(4) + ticket (opaque vector) let range = base_offset..(base_offset + input.len()); Ok((&[], Body::NewSessionTicket(range))) } - MessageType::FINISHED => { + MessageType::Finished => { let cipher_suite = c.ok_or_else(|| Err::Failure(Error::new(input, ErrorKind::Fail)))?; let (input, finished) = Finished::parse(input, cipher_suite)?; @@ -513,7 +511,7 @@ mod tests { use crate::dtls12::message::SessionId; const MESSAGE: &[u8] = &[ - 0x01, // MessageType::CLIENT_HELLO + 0x01, // MessageType::ClientHello 0x00, 0x00, 0x2E, // length 0x00, 0x00, // message_seq 0x00, 0x00, 0x00, // fragment_offset @@ -532,31 +530,31 @@ mod tests { 0xC0, 0x2B, // Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 0xC0, 0x2C, // Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 0x01, // CompressionMethods length - 0x00, // CompressionMethod::NULL + 0x00, // CompressionMethod::Null ]; #[test] fn message_type_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); assert_eq!(MessageType::default().as_u8(), 0); - assert_eq!(MessageType::default(), MessageType::HELLO_REQUEST); + assert_eq!(MessageType::default(), MessageType::HelloRequest); } #[test] fn message_type_wire_roundtrip() { for message_type in [ - MessageType::HELLO_REQUEST, - MessageType::CLIENT_HELLO, - MessageType::SERVER_HELLO, - MessageType::HELLO_VERIFY_REQUEST, - MessageType::NEW_SESSION_TICKET, - MessageType::CERTIFICATE, - MessageType::SERVER_KEY_EXCHANGE, - MessageType::CERTIFICATE_REQUEST, - MessageType::SERVER_HELLO_DONE, - MessageType::CERTIFICATE_VERIFY, - MessageType::CLIENT_KEY_EXCHANGE, - MessageType::FINISHED, + MessageType::HelloRequest, + MessageType::ClientHello, + MessageType::ServerHello, + MessageType::HelloVerifyRequest, + MessageType::NewSessionTicket, + MessageType::Certificate, + MessageType::ServerKeyExchange, + MessageType::CertificateRequest, + MessageType::ServerHelloDone, + MessageType::CertificateVerify, + MessageType::ClientKeyExchange, + MessageType::Finished, ] { assert_eq!(MessageType::from_u8(message_type.as_u8()), message_type); assert!(!message_type.is_unknown()); @@ -569,7 +567,7 @@ mod tests { #[test] fn message_type_debug_stays_enum_like() { - assert_eq!(format!("{:?}", MessageType::CLIENT_HELLO), "ClientHello"); + assert_eq!(format!("{:?}", MessageType::ClientHello), "ClientHello"); assert_eq!(format!("{:?}", MessageType::from_u8(0xFF)), "Unknown(255)"); } @@ -577,7 +575,7 @@ mod tests { fn handshake_size() { let h = Handshake::new( // ServerHelloDone has a 0 sized body. - MessageType::SERVER_HELLO_DONE, + MessageType::ServerHelloDone, 0, 0, 0, @@ -602,7 +600,7 @@ mod tests { cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256); cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::NULL); + compression_methods.push(CompressionMethod::Null); let client_hello = ClientHello::new( ProtocolVersion::DTLS1_2, @@ -614,7 +612,7 @@ mod tests { ); let handshake = Handshake::new( - MessageType::CLIENT_HELLO, + MessageType::ClientHello, 0x2E, 0, 0, @@ -645,7 +643,7 @@ mod tests { cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256); cipher_suites.push(Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::NULL); + compression_methods.push(CompressionMethod::Null); let client_hello = ClientHello::new( ProtocolVersion::DTLS1_2, @@ -657,7 +655,7 @@ mod tests { ); let handshake = Handshake::new( - MessageType::CLIENT_HELLO, + MessageType::ClientHello, 46, 0, 0, diff --git a/src/dtls12/message/named_group.rs b/src/dtls12/message/named_group.rs index c4dad1ae..5e715dd1 100644 --- a/src/dtls12/message/named_group.rs +++ b/src/dtls12/message/named_group.rs @@ -14,13 +14,14 @@ use std::fmt; #[derive(Clone, Copy, PartialEq, Eq, Hash)] pub struct CurveType(u8); +#[allow(non_upper_case_globals)] impl CurveType { /// Explicit prime curve parameters. - pub const EXPLICIT_PRIME: Self = Self(1); + pub const ExplicitPrime: Self = Self(1); /// Explicit characteristic-2 curve parameters. - pub const EXPLICIT_CHAR2: Self = Self(2); + pub const ExplicitChar2: Self = Self(2); /// Named curve (the common case). - pub const NAMED_CURVE: Self = Self(3); + pub const NamedCurve: Self = Self(3); /// Convert a u8 value to a `CurveType`. pub const fn from_u8(value: u8) -> Self { @@ -50,9 +51,9 @@ impl fmt::Debug for CurveType { } let name = match *self { - CurveType::EXPLICIT_PRIME => "ExplicitPrime", - CurveType::EXPLICIT_CHAR2 => "ExplicitChar2", - CurveType::NAMED_CURVE => "NamedCurve", + CurveType::ExplicitPrime => "ExplicitPrime", + CurveType::ExplicitChar2 => "ExplicitChar2", + CurveType::NamedCurve => "NamedCurve", _ => unreachable!("known DTLS 1.2 curve type missing Debug label"), }; @@ -72,9 +73,9 @@ mod tests { #[test] fn curve_type_wire_roundtrip() { for curve_type in [ - CurveType::EXPLICIT_PRIME, - CurveType::EXPLICIT_CHAR2, - CurveType::NAMED_CURVE, + CurveType::ExplicitPrime, + CurveType::ExplicitChar2, + CurveType::NamedCurve, ] { assert_eq!(CurveType::from_u8(curve_type.as_u8()), curve_type); assert!(!curve_type.is_unknown()); @@ -87,7 +88,7 @@ mod tests { #[test] fn curve_type_debug_stays_enum_like() { - assert_eq!(format!("{:?}", CurveType::NAMED_CURVE), "NamedCurve"); + assert_eq!(format!("{:?}", CurveType::NamedCurve), "NamedCurve"); assert_eq!(format!("{:?}", CurveType::from_u8(0xFF)), "Unknown(255)"); } } diff --git a/src/dtls12/message/record.rs b/src/dtls12/message/record.rs index 5e74c624..896e0af3 100644 --- a/src/dtls12/message/record.rs +++ b/src/dtls12/message/record.rs @@ -70,7 +70,7 @@ impl DTLSRecord { // the epoch-0 content types this implementation supports. if epoch == 0 { match content_type { - ContentType::CHANGE_CIPHER_SPEC | ContentType::ALERT | ContentType::HANDSHAKE => {} + ContentType::ChangeCipherSpec | ContentType::Alert | ContentType::Handshake => {} _ => { return Err(Err::Failure(nom::error::Error::new( input, @@ -156,7 +156,7 @@ mod tests { use crate::buffer::Buf; const RECORD: &[u8] = &[ - 0x16, // ContentType::HANDSHAKE + 0x16, // ContentType::Handshake 0xFE, 0xFD, // ProtocolVersion::DTLS1_2 0x00, 0x01, // epoch 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // sequence_number diff --git a/src/dtls12/message/server_hello.rs b/src/dtls12/message/server_hello.rs index 72fb249b..5ae0c6cf 100644 --- a/src/dtls12/message/server_hello.rs +++ b/src/dtls12/message/server_hello.rs @@ -58,17 +58,17 @@ impl ServerHello { profiles.push(pid); let ext = UseSrtpExtension::new(profiles, ArrayVec::new()); ext.serialize(buf); - ranges.push((ExtensionType::USE_SRTP, start, buf.len())); + ranges.push((ExtensionType::UseSrtp, start, buf.len())); } // Extended Master Secret (mandatory) let start = buf.len(); - ranges.push((ExtensionType::EXTENDED_MASTER_SECRET, start, start)); + ranges.push((ExtensionType::ExtendedMasterSecret, start, start)); // Renegotiation Info (RFC 5746) - empty for initial handshake let start = buf.len(); buf.push(0); // renegotiated_connection length = 0 - ranges.push((ExtensionType::RENEGOTIATION_INFO, start, buf.len())); + ranges.push((ExtensionType::RenegotiationInfo, start, buf.len())); let mut extensions = ExtensionVec::new(); for (t, s, e) in ranges { @@ -200,14 +200,14 @@ mod test { 0x01, // SessionId length 0xAA, // SessionId 0xC0, 0x2B, // Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 - 0x00, // CompressionMethod::NULL + 0x00, // CompressionMethod::Null 0x00, 0x0C, // Extensions length (12 bytes total: 2 type + 2 length + 8 data) - 0x00, 0x0A, // ExtensionType::SUPPORTED_GROUPS + 0x00, 0x0A, // ExtensionType::SupportedGroups 0x00, 0x08, // Extension data length (8 bytes) 0x00, 0x06, // Extension data - 0x00, 0x17, // NamedGroup::SECP256R1 - 0x00, 0x18, // NamedGroup::SECP384R1 - 0x00, 0x19, // NamedGroup::SECP521R1 + 0x00, 0x17, // NamedGroup::Secp256r1 + 0x00, 0x18, // NamedGroup::Secp384r1 + 0x00, 0x19, // NamedGroup::Secp521r1 ]; #[test] @@ -237,9 +237,8 @@ mod test { let mut message = MESSAGE[..39].to_vec(); message.extend_from_slice(&(count as u16 * 4).to_be_bytes()); for _ in 0..count { - message.extend_from_slice( - &ExtensionType::EXTENDED_MASTER_SECRET.as_u16().to_be_bytes(), - ); + message + .extend_from_slice(&ExtensionType::ExtendedMasterSecret.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); } diff --git a/src/dtls12/queue.rs b/src/dtls12/queue.rs index 8d1dddd2..b67f4dbd 100644 --- a/src/dtls12/queue.rs +++ b/src/dtls12/queue.rs @@ -50,10 +50,10 @@ impl fmt::Debug for QueueRx { for item in &self.0 { let record = item.first().record(); match record.content_type { - ContentType::HANDSHAKE => handshake += 1, - ContentType::APPLICATION_DATA => app_data += 1, - ContentType::ALERT => alert += 1, - ContentType::CHANGE_CIPHER_SPEC => ccs += 1, + ContentType::Handshake => handshake += 1, + ContentType::ApplicationData => app_data += 1, + ContentType::Alert => alert += 1, + ContentType::ChangeCipherSpec => ccs += 1, _ => other += 1, } diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index e0dbf3be..48e22f1a 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -232,7 +232,7 @@ impl Server { // Use the engine's create_record to send application data // The encryption is now handled in the engine self.engine - .create_record(ContentType::APPLICATION_DATA, 1, false, |body| { + .create_record(ContentType::ApplicationData, 1, false, |body| { body.extend_from_slice(data); })?; @@ -250,7 +250,7 @@ impl Server { return Ok(()); } self.engine - .create_record(ContentType::ALERT, 1, false, |body| { + .create_record(ContentType::Alert, 1, false, |body| { body.push(1); // level: warning body.push(0); // description: close_notify })?; @@ -318,7 +318,7 @@ impl State { fn await_client_hello(self, server: &mut Server) -> Result { let maybe = server .engine - .next_handshake(MessageType::CLIENT_HELLO, &mut server.defragment_buffer)?; + .next_handshake(MessageType::ClientHello, &mut server.defragment_buffer)?; let Some(handshake) = maybe else { // Stay in same state @@ -340,7 +340,7 @@ impl State { } // Enforce Null compression only (client must offer it) - let has_null = ch.compression_methods.contains(&CompressionMethod::NULL); + let has_null = ch.compression_methods.contains(&CompressionMethod::Null); if !has_null { return Err( Error::SecurityError(crate::SecurityError::UnsupportedClientCompression).into(), @@ -370,16 +370,15 @@ impl State { let cookie = compute_cookie(hmac_provider, &server.cookie_secret, client_random)?; // Start/restart flight timer for server Flight 2 (HelloVerifyRequest) server.engine.flight_begin(2); - server.engine.create_handshake( - MessageType::HELLO_VERIFY_REQUEST, - |body, _engine| { + server + .engine + .create_handshake(MessageType::HelloVerifyRequest, |body, _engine| { // RFC 6347 4.2.1: The server_version field in the HelloVerifyRequest // message MUST be set to DTLS 1.0 let hvr = HelloVerifyRequest::new(ProtocolVersion::DTLS1_0, cookie); hvr.serialize(body); Ok(()) - }, - )?; + })?; // The HelloVerifyRequest exchange is stateless per RFC 6347. // Reset all handshake state so the next ClientHello (with cookie) is processed fresh. @@ -422,27 +421,27 @@ impl State { let mut client_signature_algorithms: Option = None; for ext in ch.extensions { match ext.extension_type { - ExtensionType::USE_SRTP => { + ExtensionType::UseSrtp => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, use_srtp) = UseSrtpExtension::parse(ext_data).map_err(InternalError::from)?; client_srtp_profiles = Some(use_srtp.profiles); } - ExtensionType::EXTENDED_MASTER_SECRET => { + ExtensionType::ExtendedMasterSecret => { client_offers_ems = true; } - ExtensionType::SUPPORTED_GROUPS => { + ExtensionType::SupportedGroups => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, groups) = SupportedGroupsExtension::parse(ext_data).map_err(InternalError::from)?; client_supported_groups = Some(groups.groups); } - ExtensionType::EC_POINT_FORMATS => { + ExtensionType::EcPointFormats => { let ext_data = ext.extension_data(&server.defragment_buffer); let _ = ECPointFormatsExtension::parse(ext_data).map_err(InternalError::from)?; } - ExtensionType::SIGNATURE_ALGORITHMS => { + ExtensionType::SignatureAlgorithms => { let ext_data = ext.extension_data(&server.defragment_buffer); if let Ok((_, sigs)) = SignatureAlgorithmsExtension::parse(ext_data) { client_signature_algorithms = Some(sigs.supported_signature_algorithms); @@ -506,7 +505,7 @@ impl State { // Send ServerHello server .engine - .create_handshake(MessageType::SERVER_HELLO, move |body, engine| { + .create_handshake(MessageType::ServerHello, move |body, engine| { handshake_create_server_hello( body, engine, @@ -534,7 +533,7 @@ impl State { server .engine - .create_handshake(MessageType::CERTIFICATE, handshake_create_certificate)?; + .create_handshake(MessageType::Certificate, handshake_create_certificate)?; Ok(Self::SendServerKeyExchange) } @@ -606,7 +605,7 @@ impl State { server .engine - .create_handshake(MessageType::SERVER_KEY_EXCHANGE, |body, engine| { + .create_handshake(MessageType::ServerKeyExchange, |body, engine| { handshake_create_server_key_exchange( body, engine, @@ -636,13 +635,12 @@ impl State { return Ok(Self::SendServerHelloDone); }; - server.engine.create_handshake( - MessageType::SERVER_KEY_EXCHANGE, - move |body, _engine| { + server + .engine + .create_handshake(MessageType::ServerKeyExchange, move |body, _engine| { PskParams::serialize_from_bytes(&hint, body); Ok(()) - }, - )?; + })?; // PSK never sends CertificateRequest Ok(Self::SendServerHelloDone) @@ -660,7 +658,7 @@ impl State { server .engine - .create_handshake(MessageType::CERTIFICATE_REQUEST, move |body, _| { + .create_handshake(MessageType::CertificateRequest, move |body, _| { handshake_serialize_certificate_request(body, &sig_algs) })?; @@ -672,7 +670,7 @@ impl State { server .engine - .create_handshake(MessageType::SERVER_HELLO_DONE, |_, _| Ok(()))?; + .create_handshake(MessageType::ServerHelloDone, |_, _| Ok(()))?; let cs = server.engine.cipher_suite().ok_or(Error::InvalidState( crate::InvalidStateError::NoCipherSuiteSelected, @@ -693,7 +691,7 @@ impl State { fn await_certificate(self, server: &mut Server) -> Result { let maybe = server .engine - .next_handshake(MessageType::CERTIFICATE, &mut server.defragment_buffer)?; + .next_handshake(MessageType::Certificate, &mut server.defragment_buffer)?; let Some(ref handshake) = maybe else { // Stay in same state @@ -739,7 +737,7 @@ impl State { fn await_client_key_exchange(self, server: &mut Server) -> Result { let maybe = server.engine.next_handshake( - MessageType::CLIENT_KEY_EXCHANGE, + MessageType::ClientKeyExchange, &mut server.defragment_buffer, )?; @@ -895,7 +893,7 @@ impl State { let data = server.engine.transcript().to_buf(); let maybe = server.engine.next_handshake( - MessageType::CERTIFICATE_VERIFY, + MessageType::CertificateVerify, &mut server.defragment_buffer, )?; @@ -950,7 +948,7 @@ impl State { } fn await_change_cipher_spec(self, server: &mut Server) -> Result { - let maybe = server.engine.next_record(ContentType::CHANGE_CIPHER_SPEC); + let maybe = server.engine.next_record(ContentType::ChangeCipherSpec); let Some(_) = maybe else { // Stay in same state @@ -976,7 +974,7 @@ impl State { let maybe = server .engine - .next_handshake(MessageType::FINISHED, &mut server.defragment_buffer)?; + .next_handshake(MessageType::Finished, &mut server.defragment_buffer)?; if maybe.is_none() { // stay in same state @@ -1050,7 +1048,7 @@ impl State { // Send ChangeCipherSpec server .engine - .create_record(ContentType::CHANGE_CIPHER_SPEC, 0, true, |body| { + .create_record(ContentType::ChangeCipherSpec, 0, true, |body| { body.push(1); })?; @@ -1062,7 +1060,7 @@ impl State { server .engine - .create_handshake(MessageType::FINISHED, |body, engine| { + .create_handshake(MessageType::Finished, |body, engine| { let verify_data = engine.generate_verify_data(false /* server */)?; trace!("Finished.verify_data length: {}", verify_data.len()); // Directly write the verify data without creating Finished struct @@ -1120,7 +1118,7 @@ impl State { server.engine.discard_pending_writes(); server .engine - .create_record(ContentType::ALERT, 1, false, |body| { + .create_record(ContentType::Alert, 1, false, |body| { body.push(1); // level: warning body.push(0); // description: close_notify })?; @@ -1136,7 +1134,7 @@ impl State { for data in server.queued_data.drain(..) { server .engine - .create_record(ContentType::APPLICATION_DATA, 1, false, |body| { + .create_record(ContentType::ApplicationData, 1, false, |body| { body.extend_from_slice(&data); })?; } @@ -1212,7 +1210,7 @@ fn handshake_create_server_hello( random, session_id, cs, - CompressionMethod::NULL, + CompressionMethod::Null, None, ) .with_extensions(extension_data, srtp_pid); @@ -1243,7 +1241,7 @@ fn handshake_create_server_key_exchange( match key_exchange_algorithm { KeyExchangeAlgorithm::EECDH => { - let (curve_type, named_group) = (CurveType::NAMED_CURVE, named_group); + let (curve_type, named_group) = (CurveType::NamedCurve, named_group); let mut kx_buf = engine.pop_buffer(); let pubkey = engine .crypto_context_mut() @@ -1420,11 +1418,11 @@ mod tests { #[test] fn select_named_group_prefers_x25519_when_available() { let client = named_group_vec(&[ - NamedGroup::SECP256R1, + NamedGroup::Secp256r1, NamedGroup::X25519, - NamedGroup::SECP384R1, + NamedGroup::Secp384r1, ]); - let provider = [NamedGroup::X25519, NamedGroup::SECP256R1]; + let provider = [NamedGroup::X25519, NamedGroup::Secp256r1]; let selected = select_named_group(Some(&client), &provider); @@ -1433,27 +1431,27 @@ mod tests { #[test] fn select_named_group_respects_provider_capabilities() { - let client = named_group_vec(&[NamedGroup::X25519, NamedGroup::SECP256R1]); - let provider = [NamedGroup::SECP256R1]; + let client = named_group_vec(&[NamedGroup::X25519, NamedGroup::Secp256r1]); + let provider = [NamedGroup::Secp256r1]; let selected = select_named_group(Some(&client), &provider); - assert_eq!(selected, Some(NamedGroup::SECP256R1)); + assert_eq!(selected, Some(NamedGroup::Secp256r1)); } #[test] fn select_named_group_falls_back_to_provider_when_client_missing() { - let provider = [NamedGroup::SECP384R1]; + let provider = [NamedGroup::Secp384r1]; let selected = select_named_group(None, &provider); - assert_eq!(selected, Some(NamedGroup::SECP384R1)); + assert_eq!(selected, Some(NamedGroup::Secp384r1)); } #[test] fn select_named_group_rejects_when_client_has_no_overlap() { let client = named_group_vec(&[NamedGroup::X25519]); - let provider = [NamedGroup::SECP256R1]; + let provider = [NamedGroup::Secp256r1]; let selected = select_named_group(Some(&client), &provider); diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index bba42df6..a6c60639 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -267,7 +267,7 @@ impl Client { let epoch = self.engine.app_send_epoch(); self.engine.create_ciphertext_record( - ContentType::APPLICATION_DATA, + ContentType::ApplicationData, epoch, false, |body| { @@ -290,7 +290,7 @@ impl Client { } let epoch = self.engine.app_send_epoch(); self.engine - .create_ciphertext_record(ContentType::ALERT, epoch, false, |body| { + .create_ciphertext_record(ContentType::Alert, epoch, false, |body| { body.push(1); // level: legacy (ignored in DTLS 1.3) body.push(0); // description: close_notify })?; @@ -428,7 +428,7 @@ impl State { client .engine - .create_handshake(MessageType::CLIENT_HELLO, |body, engine| { + .create_handshake(MessageType::ClientHello, |body, engine| { handshake_create_client_hello( body, engine, @@ -450,7 +450,7 @@ impl State { let maybe = client .engine - .next_handshake(MessageType::SERVER_HELLO, &mut client.defragment_buffer)?; + .next_handshake(MessageType::ServerHello, &mut client.defragment_buffer)?; let Some(handshake) = maybe else { return Ok(self); @@ -478,13 +478,13 @@ impl State { if let Some(ref extensions) = server_hello.extensions { for ext in extensions { match ext.extension_type { - ExtensionType::KEY_SHARE => { + ExtensionType::KeyShare => { let ext_data = ext.extension_data(&client.defragment_buffer); if let Ok((_, hrr_ks)) = KeyShareHelloRetryRequest::parse(ext_data) { client.hrr_selected_group = Some(hrr_ks.selected_group); } } - ExtensionType::COOKIE => { + ExtensionType::Cookie => { let ext_data = ext.extension_data(&client.defragment_buffer); parse_cookie_extension(ext_data).map_err(InternalError::from)?; let mut cookie = Buf::new(); @@ -512,7 +512,7 @@ impl State { let mut hrr_version_ok = false; if let Some(ref extensions) = server_hello.extensions { for ext in extensions { - if ext.extension_type == ExtensionType::SUPPORTED_VERSIONS { + if ext.extension_type == ExtensionType::SupportedVersions { let ext_data = ext.extension_data(&client.defragment_buffer); if let Ok((_, sv)) = SupportedVersionsServerHello::parse(ext_data) { hrr_version_ok = sv.selected_version == ProtocolVersion::DTLS1_3; @@ -551,7 +551,7 @@ impl State { } // Validate legacy_compression_method (must be null) - if server_hello.legacy_compression_method != CompressionMethod::NULL { + if server_hello.legacy_compression_method != CompressionMethod::Null { return Err((Error::SecurityError( crate::SecurityError::ServerHelloCompressionMustBeNull, )) @@ -592,7 +592,7 @@ impl State { for ext in extensions { match ext.extension_type { - ExtensionType::SUPPORTED_VERSIONS => { + ExtensionType::SupportedVersions => { let ext_data = ext.extension_data(&client.defragment_buffer); if let Ok((_, sv)) = SupportedVersionsServerHello::parse(ext_data) { if sv.selected_version == ProtocolVersion::DTLS1_3 { @@ -600,7 +600,7 @@ impl State { } } } - ExtensionType::KEY_SHARE => { + ExtensionType::KeyShare => { let ext_data = ext.extension_data(&client.defragment_buffer); if let Ok((_, ks)) = KeyShareServerHello::parse(ext_data, 0) { // The key_exchange data is at offset 0 within ext_data, but @@ -681,7 +681,7 @@ impl State { fn await_encrypted_extensions(self, client: &mut Client) -> Result { let maybe = client.engine.next_handshake( - MessageType::ENCRYPTED_EXTENSIONS, + MessageType::EncryptedExtensions, &mut client.defragment_buffer, )?; @@ -695,7 +695,7 @@ impl State { // Process extensions for ext in &ee.extensions { - if ext.extension_type == ExtensionType::USE_SRTP { + if ext.extension_type == ExtensionType::UseSrtp { let ext_data = ext.extension_data(&client.defragment_buffer); let (_, use_srtp) = UseSrtpExtension::parse(ext_data).map_err(InternalError::from)?; @@ -717,14 +717,14 @@ impl State { // CertificateRequest is optional. Check if Certificate is available instead. let has_cert = client .engine - .has_complete_handshake(MessageType::CERTIFICATE); + .has_complete_handshake(MessageType::Certificate); if has_cert { return Ok(Self::AwaitCertificate); } let maybe = client.engine.next_handshake( - MessageType::CERTIFICATE_REQUEST, + MessageType::CertificateRequest, &mut client.defragment_buffer, )?; @@ -755,7 +755,7 @@ impl State { fn await_certificate(self, client: &mut Client) -> Result { let maybe = client .engine - .next_handshake(MessageType::CERTIFICATE, &mut client.defragment_buffer)?; + .next_handshake(MessageType::Certificate, &mut client.defragment_buffer)?; let Some(ref handshake) = maybe else { return Ok(self); @@ -818,7 +818,7 @@ impl State { client.engine.transcript_hash(&mut transcript_hash); let maybe = client.engine.next_handshake( - MessageType::CERTIFICATE_VERIFY, + MessageType::CertificateVerify, &mut client.defragment_buffer, )?; @@ -898,7 +898,7 @@ impl State { let maybe = client .engine - .next_handshake(MessageType::FINISHED, &mut client.defragment_buffer)?; + .next_handshake(MessageType::Finished, &mut client.defragment_buffer)?; let Some(ref handshake) = maybe else { return Ok(self); @@ -972,7 +972,7 @@ impl State { client .engine - .create_handshake(MessageType::CERTIFICATE, |body, engine| { + .create_handshake(MessageType::Certificate, |body, engine| { handshake_create_certificate(body, engine, &context_copy) })?; @@ -986,7 +986,7 @@ impl State { client .engine - .create_handshake(MessageType::CERTIFICATE, |body, _engine| { + .create_handshake(MessageType::Certificate, |body, _engine| { // certificate_request_context body.push(context_copy.len() as u8); body.extend_from_slice(&context_copy); @@ -1004,7 +1004,7 @@ impl State { client .engine - .create_handshake(MessageType::CERTIFICATE_VERIFY, |body, engine| { + .create_handshake(MessageType::CertificateVerify, |body, engine| { handshake_create_certificate_verify( body, engine, @@ -1039,7 +1039,7 @@ impl State { client .engine - .create_handshake(MessageType::FINISHED, |body, engine| { + .create_handshake(MessageType::Finished, |body, engine| { let verify_data = engine.compute_verify_data(&client_hs_secret)?; body.extend_from_slice(&verify_data); Ok(()) @@ -1090,7 +1090,7 @@ impl State { ); for data in client.queued_data.drain(..) { client.engine.create_ciphertext_record( - ContentType::APPLICATION_DATA, + ContentType::ApplicationData, epoch, false, |body| { @@ -1109,12 +1109,9 @@ impl State { } // Check for incoming KeyUpdate - if client - .engine - .has_complete_handshake(MessageType::KEY_UPDATE) - { + if client.engine.has_complete_handshake(MessageType::KeyUpdate) { let maybe = client.engine.next_handshake_no_transcript( - MessageType::KEY_UPDATE, + MessageType::KeyUpdate, &mut client.defragment_buffer, )?; @@ -1145,12 +1142,9 @@ impl State { fn half_closed_local(self, client: &mut Client) -> Result { // Write half is closed: drain incoming KeyUpdate to keep recv keys in sync, // but do not send our own KeyUpdate response. - if client - .engine - .has_complete_handshake(MessageType::KEY_UPDATE) - { + if client.engine.has_complete_handshake(MessageType::KeyUpdate) { let maybe = client.engine.next_handshake_no_transcript( - MessageType::KEY_UPDATE, + MessageType::KeyUpdate, &mut client.defragment_buffer, )?; if let Some(handshake) = maybe { @@ -1202,7 +1196,7 @@ fn handshake_create_client_hello( ); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::NULL); + compression_methods.push(CompressionMethod::Null); // Build extensions let mut extensions: ArrayVec = ArrayVec::new(); @@ -1216,7 +1210,7 @@ fn handshake_create_client_hello( sv.serialize(&mut ext_buf); let sv_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SUPPORTED_VERSIONS, + extension_type: ExtensionType::SupportedVersions, extension_data_range: sv_start..sv_end, }); @@ -1227,7 +1221,7 @@ fn handshake_create_client_hello( sg.serialize(&mut ext_buf); let sg_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SUPPORTED_GROUPS, + extension_type: ExtensionType::SupportedGroups, extension_data_range: sg_start..sg_end, }); @@ -1242,7 +1236,7 @@ fn handshake_create_client_hello( ks.serialize(extension_data, &mut ext_buf); let ks_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::KEY_SHARE, + extension_type: ExtensionType::KeyShare, extension_data_range: ks_start..ks_end, }); @@ -1252,7 +1246,7 @@ fn handshake_create_client_hello( sa.serialize(&mut ext_buf); let sa_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SIGNATURE_ALGORITHMS, + extension_type: ExtensionType::SignatureAlgorithms, extension_data_range: sa_start..sa_end, }); @@ -1262,7 +1256,7 @@ fn handshake_create_client_hello( use_srtp.serialize(&mut ext_buf); let srtp_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::USE_SRTP, + extension_type: ExtensionType::UseSrtp, extension_data_range: srtp_start..srtp_end, }); @@ -1273,7 +1267,7 @@ fn handshake_create_client_hello( ext_buf.extend_from_slice(&extension_data[cookie_range]); let cookie_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::COOKIE, + extension_type: ExtensionType::Cookie, extension_data_range: cookie_start..cookie_end, }); } @@ -1302,7 +1296,7 @@ fn handshake_create_client_hello( } let pad_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::PADDING, + extension_type: ExtensionType::Padding, extension_data_range: pad_start..pad_end, }); } @@ -1521,7 +1515,7 @@ fn parse_certificate_request(cr_data: &[u8], base_offset: usize) -> Result let ca_data = &cr_data[pos..pos + ext_data_len]; if ca_data.len() >= 2 { @@ -1609,7 +1603,7 @@ mod tests { fn epoch0_handshake_packet(msg_type: MessageType, message_seq: u16, body: &[u8]) -> Vec { let handshake_len = 12 + body.len(); let mut packet = Vec::new(); - packet.push(ContentType::HANDSHAKE.as_u8()); + packet.push(ContentType::Handshake.as_u8()); packet.extend_from_slice(&[0xfe, 0xfd]); packet.extend_from_slice(&0u16.to_be_bytes()); packet.extend_from_slice(&0u64.to_be_bytes()[2..]); @@ -1630,10 +1624,10 @@ mod tests { key_share.push(0); let mut extensions = Vec::new(); - extensions.extend_from_slice(&ExtensionType::SUPPORTED_VERSIONS.as_u16().to_be_bytes()); + extensions.extend_from_slice(&ExtensionType::SupportedVersions.as_u16().to_be_bytes()); extensions.extend_from_slice(&2u16.to_be_bytes()); extensions.extend_from_slice(&ProtocolVersion::DTLS1_3.as_u16().to_be_bytes()); - extensions.extend_from_slice(&ExtensionType::KEY_SHARE.as_u16().to_be_bytes()); + extensions.extend_from_slice(&ExtensionType::KeyShare.as_u16().to_be_bytes()); extensions.extend_from_slice(&(key_share.len() as u16).to_be_bytes()); extensions.extend_from_slice(&key_share); @@ -1642,7 +1636,7 @@ mod tests { body.extend_from_slice(&[7; 32]); body.push(0); // legacy_session_id body.extend_from_slice(&Dtls13CipherSuite::AES_128_GCM_SHA256.as_u16().to_be_bytes()); - body.push(CompressionMethod::NULL.as_u8()); + body.push(CompressionMethod::Null.as_u8()); body.extend_from_slice(&(extensions.len() as u16).to_be_bytes()); body.extend_from_slice(&extensions); body @@ -1657,7 +1651,7 @@ mod tests { client .engine .parse_packet(&epoch0_handshake_packet( - MessageType::CERTIFICATE, + MessageType::Certificate, 2, &[0, 0, 0, 0], )) @@ -1683,9 +1677,9 @@ mod tests { client .engine .parse_packet(&epoch0_handshake_packet( - MessageType::SERVER_HELLO, + MessageType::ServerHello, 0, - &server_hello_with_key_share(NamedGroup::SECP256R1), + &server_hello_with_key_share(NamedGroup::Secp256r1), )) .expect("queue mismatched ServerHello"); @@ -1698,7 +1692,7 @@ mod tests { crate::InternalError::Fatal(Error::SecurityError( SecurityError::ServerKeyShareGroupMismatch { expected: NamedGroup::X25519, - actual: NamedGroup::SECP256R1, + actual: NamedGroup::Secp256r1, } )) )); diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index 80ec80c1..d67ccc34 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -389,7 +389,7 @@ impl Engine { // but allow KeyUpdate (a post-handshake message). if self.release_app_data && handshake.header.message_seq >= self.peer_handshake_seq_no - && handshake.header.msg_type != MessageType::KEY_UPDATE + && handshake.header.msg_type != MessageType::KeyUpdate { return Err(Error::RenegotiationAttempt); } @@ -558,7 +558,7 @@ impl Engine { .queue_rx .iter() .flat_map(|i| i.records().iter()) - .filter(|r| r.record().content_type == ContentType::APPLICATION_DATA) + .filter(|r| r.record().content_type == ContentType::ApplicationData) .skip_while(|r| r.is_handled()); let Some(next) = unhandled.next() else { @@ -789,7 +789,7 @@ impl Engine { &mut self, defragment_buffer: &mut Buf, ) -> Result, InternalError> { - self.next_handshake_with_options(MessageType::CLIENT_HELLO, defragment_buffer, true) + self.next_handshake_with_options(MessageType::ClientHello, defragment_buffer, true) } fn next_handshake_with_options( @@ -1066,7 +1066,7 @@ impl Engine { // Build the record for serialization let record = Dtls13Record { - content_type: ContentType::APPLICATION_DATA, + content_type: ContentType::ApplicationData, sequence: Sequence { epoch, sequence_number: seq, @@ -1216,11 +1216,11 @@ impl Engine { }; if epoch == 0 { - self.create_plaintext_record(ContentType::HANDSHAKE, true, |fragment| { + self.create_plaintext_record(ContentType::Handshake, true, |fragment| { frag_handshake.serialize(&body_buffer, fragment); })?; } else { - self.create_ciphertext_record(ContentType::HANDSHAKE, epoch, true, |fragment| { + self.create_ciphertext_record(ContentType::Handshake, epoch, true, |fragment| { frag_handshake.serialize(&body_buffer, fragment); })?; } @@ -1288,7 +1288,7 @@ impl Engine { 2 }; - self.create_ciphertext_record(ContentType::ACK, epoch, false, |fragment| { + self.create_ciphertext_record(ContentType::Ack, epoch, false, |fragment| { // record_numbers_length: 2 bytes, value = entries.len() * 16 let len = (entries.len() * 16) as u16; fragment.extend_from_slice(&len.to_be_bytes()); @@ -1515,7 +1515,7 @@ impl Engine { for incoming in self.queue_rx.iter() { for r in incoming.records().iter() { if r.record().sequence.epoch == 2 - && r.record().content_type == ContentType::HANDSHAKE + && r.record().content_type == ContentType::Handshake { let seq = r.record().sequence; let _ = record_numbers.try_push((seq.epoch as u64, seq.sequence_number)); @@ -1538,7 +1538,7 @@ impl Engine { return Ok(()); } - self.create_ciphertext_record(ContentType::ACK, 2, false, |fragment| { + self.create_ciphertext_record(ContentType::Ack, 2, false, |fragment| { let len = (record_numbers.len() * 16) as u16; fragment.extend_from_slice(&len.to_be_bytes()); for &(epoch, seq) in record_numbers { @@ -1890,10 +1890,10 @@ impl Engine { let epoch = self.app_send_epoch; // Build the handshake message manually (12-byte DTLS header + 1-byte body) - self.create_ciphertext_record(ContentType::HANDSHAKE, epoch, true, |fragment| { + self.create_ciphertext_record(ContentType::Handshake, epoch, true, |fragment| { // DTLS handshake header (12 bytes): // msg_type(1) + length(3) + message_seq(2) + fragment_offset(3) + fragment_length(3) - fragment.push(MessageType::KEY_UPDATE.as_u8()); + fragment.push(MessageType::KeyUpdate.as_u8()); fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); // length = 1 fragment.extend_from_slice(&msg_seq.to_be_bytes()); // message_seq fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); // fragment_offset = 0 @@ -2246,7 +2246,7 @@ fn jittered_aead_threshold(limit: u64, rng: &mut SeededRng) -> u64 { /// All other handshake messages are encrypted (epoch 2). fn epoch_for_message(msg_type: MessageType) -> u16 { match msg_type { - MessageType::CLIENT_HELLO | MessageType::SERVER_HELLO => 0, + MessageType::ClientHello | MessageType::ServerHello => 0, _ => 2, } } @@ -2294,7 +2294,7 @@ impl RecordHandler for Engine { && self.peer_encryption_enabled && matches!( record.record().content_type, - ContentType::ACK | ContentType::ALERT + ContentType::Ack | ContentType::Alert ) { // Plaintext ACKs and alerts after peer encryption is enabled are @@ -2304,13 +2304,13 @@ impl RecordHandler for Engine { } match record.record().content_type { - ContentType::ACK => { + ContentType::Ack => { let fragment = record.record().fragment(record.buffer()); self.process_ack(fragment); self.push_buffer(record.into_buffer()); Ok(None) } - ContentType::ALERT => { + ContentType::Alert => { // RFC 8446 §6: TLS 1.3 ignores the AlertLevel byte; severity is // implicit in the description (only close_notify and user_canceled // are non-fatal). @@ -2335,7 +2335,7 @@ impl RecordHandler for Engine { None => Ok(None), } } - ContentType::CHANGE_CIPHER_SPEC => { + ContentType::ChangeCipherSpec => { trace!("Discarding CCS record"); self.push_buffer(record.into_buffer()); Ok(None) @@ -2558,13 +2558,13 @@ mod tests { fn encrypted_key_update_record(seq: u16) -> Vec { let mut fragment = Vec::new(); - fragment.push(MessageType::KEY_UPDATE.as_u8()); + fragment.push(MessageType::KeyUpdate.as_u8()); fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); fragment.extend_from_slice(&0u16.to_be_bytes()); fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); fragment.push(KeyUpdateRequest::UpdateRequested.as_u8()); - fragment.push(ContentType::HANDSHAKE.as_u8()); + fragment.push(ContentType::Handshake.as_u8()); let mut packet = Vec::new(); packet.push( @@ -2602,7 +2602,7 @@ mod tests { // Set epoch-0 sequence to MAX — the next increment should be rejected engine.sequence_epoch_0.sequence_number = MAX_SEQUENCE_NUMBER; - let result = engine.create_plaintext_record(ContentType::HANDSHAKE, false, |buf| { + let result = engine.create_plaintext_record(ContentType::Handshake, false, |buf| { buf.extend_from_slice(b"test") }); assert!( @@ -2736,7 +2736,7 @@ mod tests { fn malformed_ack_record_number_vector_is_ignored() { let mut engine = test_engine(); engine.flight_saved_records.push(Entry { - content_type: ContentType::HANDSHAKE, + content_type: ContentType::Handshake, epoch: 2, send_seq: 7, fragment: Buf::new(), diff --git a/src/dtls13/incoming.rs b/src/dtls13/incoming.rs index c1b42e1f..f617a416 100644 --- a/src/dtls13/incoming.rs +++ b/src/dtls13/incoming.rs @@ -363,7 +363,7 @@ impl ParsedRecord { ) -> Result { let (_, record) = Dtls13Record::parse(input, 0)?; - let handshakes = if record.content_type == ContentType::HANDSHAKE { + let handshakes = if record.content_type == ContentType::Handshake { let fragment_offset = record.fragment_range.start; parse_handshakes(record.fragment(input), fragment_offset, cipher_suite) } else { @@ -383,7 +383,7 @@ impl ParsedRecord { input: &[u8], cipher_suite: Option, ) -> ParsedRecord { - let handshakes = if record.content_type == ContentType::HANDSHAKE { + let handshakes = if record.content_type == ContentType::Handshake { let fragment_offset = record.fragment_range.start; parse_handshakes(record.fragment(input), fragment_offset, cipher_suite) } else { @@ -541,7 +541,7 @@ mod tests { impl RecordHandler for TestHandler { fn classify_record(&mut self, record: Record) -> Result, Error> { self.classify_calls += 1; - if record.record().content_type == ContentType::ACK { + if record.record().content_type == ContentType::Ack { self.dropped_acks += 1; return Ok(None); } @@ -617,7 +617,7 @@ mod tests { #[test] fn parse_packet_filters_control_records_after_packet_validation() { let mut packet = Vec::new(); - packet.extend_from_slice(&build_plaintext_record(ContentType::ACK, 1, &[0xAA, 0xBB])); + packet.extend_from_slice(&build_plaintext_record(ContentType::Ack, 1, &[0xAA, 0xBB])); packet.extend_from_slice(&build_ciphertext_record(2, 2, &[0x11, 0x22, 0x33])); let mut handler = TestHandler::default(); @@ -630,7 +630,7 @@ mod tests { assert_eq!(incoming.records().len(), 1); assert_eq!( incoming.first().record().content_type, - ContentType::APPLICATION_DATA + ContentType::ApplicationData ); assert_eq!(incoming.first().record().sequence.epoch, 2); } diff --git a/src/dtls13/message/client_hello.rs b/src/dtls13/message/client_hello.rs index e4ea3991..b6ceedad 100644 --- a/src/dtls13/message/client_hello.rs +++ b/src/dtls13/message/client_hello.rs @@ -214,7 +214,7 @@ mod tests { cipher_suites.push(Dtls13CipherSuite::AES_128_GCM_SHA256); cipher_suites.push(Dtls13CipherSuite::AES_256_GCM_SHA384); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::NULL); + compression_methods.push(CompressionMethod::Null); let client_hello = ClientHello::new( ProtocolVersion::DTLS1_2, @@ -340,7 +340,7 @@ mod tests { let mut message = MESSAGE.to_vec(); message.extend_from_slice(&(count as u16 * 4).to_be_bytes()); for _ in 0..count { - message.extend_from_slice(&ExtensionType::COOKIE.as_u16().to_be_bytes()); + message.extend_from_slice(&ExtensionType::Cookie.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); } @@ -360,7 +360,7 @@ mod tests { fn zero_length_extension_vector_rejects_trailing_bytes() { let mut message = MESSAGE.to_vec(); message.extend_from_slice(&0u16.to_be_bytes()); - message.extend_from_slice(&ExtensionType::COOKIE.as_u16().to_be_bytes()); + message.extend_from_slice(&ExtensionType::Cookie.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); assert!( @@ -373,7 +373,7 @@ mod tests { fn underdeclared_extension_vector_rejects_trailing_bytes() { let mut message = MESSAGE.to_vec(); message.extend_from_slice(&4u16.to_be_bytes()); - message.extend_from_slice(&ExtensionType::COOKIE.as_u16().to_be_bytes()); + message.extend_from_slice(&ExtensionType::Cookie.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); message.push(0); diff --git a/src/dtls13/message/encrypted_extensions.rs b/src/dtls13/message/encrypted_extensions.rs index daa3ead5..176f16b8 100644 --- a/src/dtls13/message/encrypted_extensions.rs +++ b/src/dtls13/message/encrypted_extensions.rs @@ -71,7 +71,7 @@ mod tests { const MESSAGE: &[u8] = &[ 0x00, 0x0C, // Extensions length (12) - 0x00, 0x0A, // ExtensionType::SUPPORTED_GROUPS + 0x00, 0x0A, // ExtensionType::SupportedGroups 0x00, 0x08, // Extension data length 0x00, 0x06, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, // Extension data ]; @@ -92,7 +92,7 @@ mod tests { let count = 2; message.extend_from_slice(&(count as u16 * 4).to_be_bytes()); for _ in 0..count { - message.extend_from_slice(&ExtensionType::COOKIE.as_u16().to_be_bytes()); + message.extend_from_slice(&ExtensionType::Cookie.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); } diff --git a/src/dtls13/message/extension.rs b/src/dtls13/message/extension.rs index ec2b22aa..4c0d4571 100644 --- a/src/dtls13/message/extension.rs +++ b/src/dtls13/message/extension.rs @@ -52,45 +52,46 @@ impl Extension { #[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct ExtensionType(u16); +#[allow(non_upper_case_globals)] impl ExtensionType { - pub const SERVER_NAME: Self = Self(0x0000); - pub const MAX_FRAGMENT_LENGTH: Self = Self(0x0001); - pub const CLIENT_CERTIFICATE_URL: Self = Self(0x0002); - pub const TRUSTED_CA_KEYS: Self = Self(0x0003); - pub const TRUNCATED_HMAC: Self = Self(0x0004); - pub const STATUS_REQUEST: Self = Self(0x0005); - pub const USER_MAPPING: Self = Self(0x0006); - pub const CLIENT_AUTHZ: Self = Self(0x0007); - pub const SERVER_AUTHZ: Self = Self(0x0008); - pub const CERT_TYPE: Self = Self(0x0009); - pub const SUPPORTED_GROUPS: Self = Self(0x000A); - pub const EC_POINT_FORMATS: Self = Self(0x000B); - pub const SRP: Self = Self(0x000C); - pub const SIGNATURE_ALGORITHMS: Self = Self(0x000D); - pub const USE_SRTP: Self = Self(0x000E); - pub const HEARTBEAT: Self = Self(0x000F); - pub const APPLICATION_LAYER_PROTOCOL_NEGOTIATION: Self = Self(0x0010); - pub const STATUS_REQUEST_V2: Self = Self(0x0011); - pub const SIGNED_CERTIFICATE_TIMESTAMP: Self = Self(0x0012); - pub const CLIENT_CERTIFICATE_TYPE: Self = Self(0x0013); - pub const SERVER_CERTIFICATE_TYPE: Self = Self(0x0014); - pub const PADDING: Self = Self(0x0015); - pub const ENCRYPT_THEN_MAC: Self = Self(0x0016); - pub const EXTENDED_MASTER_SECRET: Self = Self(0x0017); - pub const TOKEN_BINDING: Self = Self(0x0018); - pub const CACHED_INFO: Self = Self(0x0019); - pub const SESSION_TICKET: Self = Self(0x0023); - pub const PRE_SHARED_KEY: Self = Self(0x0029); - pub const EARLY_DATA: Self = Self(0x002A); - pub const SUPPORTED_VERSIONS: Self = Self(0x002B); - pub const COOKIE: Self = Self(0x002C); - pub const PSK_KEY_EXCHANGE_MODES: Self = Self(0x002D); - pub const CERTIFICATE_AUTHORITIES: Self = Self(0x002F); - pub const OID_FILTERS: Self = Self(0x0030); - pub const POST_HANDSHAKE_AUTH: Self = Self(0x0031); - pub const SIGNATURE_ALGORITHMS_CERT: Self = Self(0x0032); - pub const KEY_SHARE: Self = Self(0x0033); - pub const RENEGOTIATION_INFO: Self = Self(0xFF01); + pub const ServerName: Self = Self(0x0000); + pub const MaxFragmentLength: Self = Self(0x0001); + pub const ClientCertificateUrl: Self = Self(0x0002); + pub const TrustedCaKeys: Self = Self(0x0003); + pub const TruncatedHmac: Self = Self(0x0004); + pub const StatusRequest: Self = Self(0x0005); + pub const UserMapping: Self = Self(0x0006); + pub const ClientAuthz: Self = Self(0x0007); + pub const ServerAuthz: Self = Self(0x0008); + pub const CertType: Self = Self(0x0009); + pub const SupportedGroups: Self = Self(0x000A); + pub const EcPointFormats: Self = Self(0x000B); + pub const Srp: Self = Self(0x000C); + pub const SignatureAlgorithms: Self = Self(0x000D); + pub const UseSrtp: Self = Self(0x000E); + pub const Heartbeat: Self = Self(0x000F); + pub const ApplicationLayerProtocolNegotiation: Self = Self(0x0010); + pub const StatusRequestV2: Self = Self(0x0011); + pub const SignedCertificateTimestamp: Self = Self(0x0012); + pub const ClientCertificateType: Self = Self(0x0013); + pub const ServerCertificateType: Self = Self(0x0014); + pub const Padding: Self = Self(0x0015); + pub const EncryptThenMac: Self = Self(0x0016); + pub const ExtendedMasterSecret: Self = Self(0x0017); + pub const TokenBinding: Self = Self(0x0018); + pub const CachedInfo: Self = Self(0x0019); + pub const SessionTicket: Self = Self(0x0023); + pub const PreSharedKey: Self = Self(0x0029); + pub const EarlyData: Self = Self(0x002A); + pub const SupportedVersions: Self = Self(0x002B); + pub const Cookie: Self = Self(0x002C); + pub const PskKeyExchangeModes: Self = Self(0x002D); + pub const CertificateAuthorities: Self = Self(0x002F); + pub const OidFilters: Self = Self(0x0030); + pub const PostHandshakeAuth: Self = Self(0x0031); + pub const SignatureAlgorithmsCert: Self = Self(0x0032); + pub const KeyShare: Self = Self(0x0033); + pub const RenegotiationInfo: Self = Self(0xFF01); pub const fn from_u16(value: u16) -> Self { Self(value) @@ -120,12 +121,12 @@ impl ExtensionType { /// Supported extension types that this DTLS 1.3 implementation handles. pub const fn supported() -> &'static [ExtensionType; 6] { &[ - ExtensionType::SUPPORTED_VERSIONS, - ExtensionType::SUPPORTED_GROUPS, - ExtensionType::SIGNATURE_ALGORITHMS, - ExtensionType::KEY_SHARE, - ExtensionType::USE_SRTP, - ExtensionType::COOKIE, + ExtensionType::SupportedVersions, + ExtensionType::SupportedGroups, + ExtensionType::SignatureAlgorithms, + ExtensionType::KeyShare, + ExtensionType::UseSrtp, + ExtensionType::Cookie, ] } } @@ -137,46 +138,46 @@ impl fmt::Debug for ExtensionType { } let name = match *self { - ExtensionType::SERVER_NAME => "ServerName", - ExtensionType::MAX_FRAGMENT_LENGTH => "MaxFragmentLength", - ExtensionType::CLIENT_CERTIFICATE_URL => "ClientCertificateUrl", - ExtensionType::TRUSTED_CA_KEYS => "TrustedCaKeys", - ExtensionType::TRUNCATED_HMAC => "TruncatedHmac", - ExtensionType::STATUS_REQUEST => "StatusRequest", - ExtensionType::USER_MAPPING => "UserMapping", - ExtensionType::CLIENT_AUTHZ => "ClientAuthz", - ExtensionType::SERVER_AUTHZ => "ServerAuthz", - ExtensionType::CERT_TYPE => "CertType", - ExtensionType::SUPPORTED_GROUPS => "SupportedGroups", - ExtensionType::EC_POINT_FORMATS => "EcPointFormats", - ExtensionType::SRP => "Srp", - ExtensionType::SIGNATURE_ALGORITHMS => "SignatureAlgorithms", - ExtensionType::USE_SRTP => "UseSrtp", - ExtensionType::HEARTBEAT => "Heartbeat", - ExtensionType::APPLICATION_LAYER_PROTOCOL_NEGOTIATION => { + ExtensionType::ServerName => "ServerName", + ExtensionType::MaxFragmentLength => "MaxFragmentLength", + ExtensionType::ClientCertificateUrl => "ClientCertificateUrl", + ExtensionType::TrustedCaKeys => "TrustedCaKeys", + ExtensionType::TruncatedHmac => "TruncatedHmac", + ExtensionType::StatusRequest => "StatusRequest", + ExtensionType::UserMapping => "UserMapping", + ExtensionType::ClientAuthz => "ClientAuthz", + ExtensionType::ServerAuthz => "ServerAuthz", + ExtensionType::CertType => "CertType", + ExtensionType::SupportedGroups => "SupportedGroups", + ExtensionType::EcPointFormats => "EcPointFormats", + ExtensionType::Srp => "Srp", + ExtensionType::SignatureAlgorithms => "SignatureAlgorithms", + ExtensionType::UseSrtp => "UseSrtp", + ExtensionType::Heartbeat => "Heartbeat", + ExtensionType::ApplicationLayerProtocolNegotiation => { "ApplicationLayerProtocolNegotiation" } - ExtensionType::STATUS_REQUEST_V2 => "StatusRequestV2", - ExtensionType::SIGNED_CERTIFICATE_TIMESTAMP => "SignedCertificateTimestamp", - ExtensionType::CLIENT_CERTIFICATE_TYPE => "ClientCertificateType", - ExtensionType::SERVER_CERTIFICATE_TYPE => "ServerCertificateType", - ExtensionType::PADDING => "Padding", - ExtensionType::ENCRYPT_THEN_MAC => "EncryptThenMac", - ExtensionType::EXTENDED_MASTER_SECRET => "ExtendedMasterSecret", - ExtensionType::TOKEN_BINDING => "TokenBinding", - ExtensionType::CACHED_INFO => "CachedInfo", - ExtensionType::SESSION_TICKET => "SessionTicket", - ExtensionType::PRE_SHARED_KEY => "PreSharedKey", - ExtensionType::EARLY_DATA => "EarlyData", - ExtensionType::SUPPORTED_VERSIONS => "SupportedVersions", - ExtensionType::COOKIE => "Cookie", - ExtensionType::PSK_KEY_EXCHANGE_MODES => "PskKeyExchangeModes", - ExtensionType::CERTIFICATE_AUTHORITIES => "CertificateAuthorities", - ExtensionType::OID_FILTERS => "OidFilters", - ExtensionType::POST_HANDSHAKE_AUTH => "PostHandshakeAuth", - ExtensionType::SIGNATURE_ALGORITHMS_CERT => "SignatureAlgorithmsCert", - ExtensionType::KEY_SHARE => "KeyShare", - ExtensionType::RENEGOTIATION_INFO => "RenegotiationInfo", + ExtensionType::StatusRequestV2 => "StatusRequestV2", + ExtensionType::SignedCertificateTimestamp => "SignedCertificateTimestamp", + ExtensionType::ClientCertificateType => "ClientCertificateType", + ExtensionType::ServerCertificateType => "ServerCertificateType", + ExtensionType::Padding => "Padding", + ExtensionType::EncryptThenMac => "EncryptThenMac", + ExtensionType::ExtendedMasterSecret => "ExtendedMasterSecret", + ExtensionType::TokenBinding => "TokenBinding", + ExtensionType::CachedInfo => "CachedInfo", + ExtensionType::SessionTicket => "SessionTicket", + ExtensionType::PreSharedKey => "PreSharedKey", + ExtensionType::EarlyData => "EarlyData", + ExtensionType::SupportedVersions => "SupportedVersions", + ExtensionType::Cookie => "Cookie", + ExtensionType::PskKeyExchangeModes => "PskKeyExchangeModes", + ExtensionType::CertificateAuthorities => "CertificateAuthorities", + ExtensionType::OidFilters => "OidFilters", + ExtensionType::PostHandshakeAuth => "PostHandshakeAuth", + ExtensionType::SignatureAlgorithmsCert => "SignatureAlgorithmsCert", + ExtensionType::KeyShare => "KeyShare", + ExtensionType::RenegotiationInfo => "RenegotiationInfo", _ => unreachable!("known DTLS 1.3 extension type missing Debug label"), }; @@ -190,7 +191,7 @@ mod tests { use crate::buffer::Buf; const MESSAGE: &[u8] = &[ - 0x00, 0x0A, // ExtensionType::SUPPORTED_GROUPS + 0x00, 0x0A, // ExtensionType::SupportedGroups 0x00, 0x08, // Extension length 0x00, 0x06, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, // Extension data ]; @@ -199,7 +200,7 @@ mod tests { fn extension_type_newtype_shape() { assert_eq!(std::mem::size_of::(), 2); assert_eq!(ExtensionType::default().as_u16(), 0); - assert_eq!(ExtensionType::default(), ExtensionType::SERVER_NAME); + assert_eq!(ExtensionType::default(), ExtensionType::ServerName); } #[test] @@ -220,7 +221,7 @@ mod tests { #[test] fn extension_type_debug_stays_enum_like() { assert_eq!( - format!("{:?}", ExtensionType::SUPPORTED_GROUPS), + format!("{:?}", ExtensionType::SupportedGroups), "SupportedGroups" ); assert_eq!( diff --git a/src/dtls13/message/extensions/key_share.rs b/src/dtls13/message/extensions/key_share.rs index 3d8b8cf8..04256dce 100644 --- a/src/dtls13/message/extensions/key_share.rs +++ b/src/dtls13/message/extensions/key_share.rs @@ -180,7 +180,7 @@ mod tests { #[test] fn key_share_hrr_roundtrip() { let message: &[u8] = &[ - 0x00, 0x17, // NamedGroup::SECP256R1 + 0x00, 0x17, // NamedGroup::Secp256r1 ]; let (rest, parsed) = KeyShareHelloRetryRequest::parse(message).unwrap(); diff --git a/src/dtls13/message/extensions/supported_groups.rs b/src/dtls13/message/extensions/supported_groups.rs index 7b4d712c..2bcdff1c 100644 --- a/src/dtls13/message/extensions/supported_groups.rs +++ b/src/dtls13/message/extensions/supported_groups.rs @@ -62,7 +62,7 @@ mod tests { fn test_supported_groups_extension() { let mut groups = ArrayVec::new(); groups.push(NamedGroup::X25519); - groups.push(NamedGroup::SECP256R1); + groups.push(NamedGroup::Secp256r1); let ext = SupportedGroupsExtension { groups }; @@ -94,8 +94,8 @@ mod tests { parsed.groups.as_slice(), &[ NamedGroup::X25519, - NamedGroup::SECP256R1, - NamedGroup::SECP384R1 + NamedGroup::Secp256r1, + NamedGroup::Secp384r1 ] ); } diff --git a/src/dtls13/message/handshake.rs b/src/dtls13/message/handshake.rs index d6ca8c69..7896bcf0 100644 --- a/src/dtls13/message/handshake.rs +++ b/src/dtls13/message/handshake.rs @@ -220,7 +220,7 @@ impl Handshake { Body::parse(buffer, 0, first_handshake.header.msg_type, cipher_suite)? }; - if !rest.is_empty() && first_handshake.header.msg_type == MessageType::FINISHED { + if !rest.is_empty() && first_handshake.header.msg_type == MessageType::Finished { debug!("Defragmentation failed. Body::parse() did not consume the entire buffer"); return Err(crate::InternalError::parse_incomplete()); } @@ -252,7 +252,7 @@ impl Handshake { let qualifies = matches!( self.header.msg_type, - MessageType::CLIENT_HELLO // flight 1 + MessageType::ClientHello // flight 1 ); qualifies.then_some(self.header.message_seq) @@ -271,15 +271,16 @@ impl Handshake { #[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct MessageType(u8); +#[allow(non_upper_case_globals)] impl MessageType { - pub const CLIENT_HELLO: Self = Self(1); - pub const SERVER_HELLO: Self = Self(2); - pub const ENCRYPTED_EXTENSIONS: Self = Self(8); - pub const CERTIFICATE: Self = Self(11); - pub const CERTIFICATE_REQUEST: Self = Self(13); - pub const CERTIFICATE_VERIFY: Self = Self(15); - pub const FINISHED: Self = Self(20); - pub const KEY_UPDATE: Self = Self(24); + pub const ClientHello: Self = Self(1); + pub const ServerHello: Self = Self(2); + pub const EncryptedExtensions: Self = Self(8); + pub const Certificate: Self = Self(11); + pub const CertificateRequest: Self = Self(13); + pub const CertificateVerify: Self = Self(15); + pub const Finished: Self = Self(20); + pub const KeyUpdate: Self = Self(24); pub const fn from_u8(value: u8) -> Self { Self(value) @@ -306,14 +307,14 @@ impl fmt::Debug for MessageType { } let name = match *self { - MessageType::CLIENT_HELLO => "ClientHello", - MessageType::SERVER_HELLO => "ServerHello", - MessageType::ENCRYPTED_EXTENSIONS => "EncryptedExtensions", - MessageType::CERTIFICATE => "Certificate", - MessageType::CERTIFICATE_REQUEST => "CertificateRequest", - MessageType::CERTIFICATE_VERIFY => "CertificateVerify", - MessageType::FINISHED => "Finished", - MessageType::KEY_UPDATE => "KeyUpdate", + MessageType::ClientHello => "ClientHello", + MessageType::ServerHello => "ServerHello", + MessageType::EncryptedExtensions => "EncryptedExtensions", + MessageType::Certificate => "Certificate", + MessageType::CertificateRequest => "CertificateRequest", + MessageType::CertificateVerify => "CertificateVerify", + MessageType::Finished => "Finished", + MessageType::KeyUpdate => "KeyUpdate", _ => unreachable!("known DTLS 1.3 handshake message type missing Debug label"), }; @@ -394,7 +395,7 @@ impl Body { allow_unknown_client_hello_suites: bool, ) -> IResult<&[u8], Body> { match m { - MessageType::CLIENT_HELLO => { + MessageType::ClientHello => { let (input, client_hello) = if allow_unknown_client_hello_suites { ClientHello::parse_allow_unknown_suites(input, base_offset)? } else { @@ -402,33 +403,33 @@ impl Body { }; Ok((input, Body::ClientHello(client_hello))) } - MessageType::SERVER_HELLO => { + MessageType::ServerHello => { let (input, server_hello) = ServerHello::parse(input, base_offset)?; Ok((input, Body::ServerHello(server_hello))) } - MessageType::ENCRYPTED_EXTENSIONS => { + MessageType::EncryptedExtensions => { let (input, ee) = EncryptedExtensions::parse(input, base_offset)?; Ok((input, Body::EncryptedExtensions(ee))) } - MessageType::CERTIFICATE => { + MessageType::Certificate => { let (input, certificate) = Certificate::parse(input, base_offset)?; Ok((input, Body::Certificate(certificate))) } - MessageType::CERTIFICATE_REQUEST => { + MessageType::CertificateRequest => { let range = base_offset..(base_offset + input.len()); Ok((&[], Body::CertificateRequest(range))) } - MessageType::CERTIFICATE_VERIFY => { + MessageType::CertificateVerify => { let (input, cv) = CertificateVerify::parse(input, base_offset)?; Ok((input, Body::CertificateVerify(cv))) } - MessageType::FINISHED => { + MessageType::Finished => { let cipher_suite = c.ok_or_else(|| Err::Failure(Error::new(input, ErrorKind::Fail)))?; let (input, finished) = Finished::parse(input, cipher_suite)?; Ok((input, Body::Finished(finished))) } - MessageType::KEY_UPDATE => { + MessageType::KeyUpdate => { let (input, byte) = be_u8(input)?; if !input.is_empty() { return Err(Err::Failure(Error::new(input, ErrorKind::LengthValue))); @@ -487,7 +488,7 @@ mod tests { use crate::dtls13::message::{ProtocolVersion, Random, SessionId}; const MESSAGE: &[u8] = &[ - 0x01, // MessageType::CLIENT_HELLO + 0x01, // MessageType::ClientHello 0x00, 0x00, 0x2E, // length 0x00, 0x00, // message_seq 0x00, 0x00, 0x00, // fragment_offset @@ -519,14 +520,14 @@ mod tests { #[test] fn message_type_wire_roundtrip() { for message_type in [ - MessageType::CLIENT_HELLO, - MessageType::SERVER_HELLO, - MessageType::ENCRYPTED_EXTENSIONS, - MessageType::CERTIFICATE, - MessageType::CERTIFICATE_REQUEST, - MessageType::CERTIFICATE_VERIFY, - MessageType::FINISHED, - MessageType::KEY_UPDATE, + MessageType::ClientHello, + MessageType::ServerHello, + MessageType::EncryptedExtensions, + MessageType::Certificate, + MessageType::CertificateRequest, + MessageType::CertificateVerify, + MessageType::Finished, + MessageType::KeyUpdate, ] { assert_eq!(MessageType::from_u8(message_type.as_u8()), message_type); assert!(!message_type.is_unknown()); @@ -539,14 +540,14 @@ mod tests { #[test] fn message_type_debug_stays_enum_like() { - assert_eq!(format!("{:?}", MessageType::CLIENT_HELLO), "ClientHello"); + assert_eq!(format!("{:?}", MessageType::ClientHello), "ClientHello"); assert_eq!(format!("{:?}", MessageType::from_u8(0xFF)), "Unknown(255)"); } #[test] fn handshake_size() { let h = Handshake::new( - MessageType::ENCRYPTED_EXTENSIONS, + MessageType::EncryptedExtensions, 2, 0, 0, @@ -574,7 +575,7 @@ mod tests { cipher_suites.push(Dtls13CipherSuite::AES_128_GCM_SHA256); cipher_suites.push(Dtls13CipherSuite::AES_256_GCM_SHA384); let mut compression_methods = ArrayVec::new(); - compression_methods.push(CompressionMethod::NULL); + compression_methods.push(CompressionMethod::Null); let client_hello = ClientHello::new( ProtocolVersion::DTLS1_2, @@ -586,7 +587,7 @@ mod tests { ); let handshake = Handshake::new( - MessageType::CLIENT_HELLO, + MessageType::ClientHello, 0x2E, 0, 0, @@ -609,7 +610,7 @@ mod tests { fn key_update_body_rejects_trailing_bytes() { let source = [KeyUpdateRequest::UpdateRequested.as_u8(), 0]; let handshake = Handshake::new( - MessageType::KEY_UPDATE, + MessageType::KeyUpdate, source.len() as u32, 0, 0, diff --git a/src/dtls13/message/record.rs b/src/dtls13/message/record.rs index 1a7ccebb..cbbaf734 100644 --- a/src/dtls13/message/record.rs +++ b/src/dtls13/message/record.rs @@ -75,7 +75,7 @@ impl Dtls13Record { // RFC 9147 §4.1: Only alert(21), handshake(22), and ack(26) are valid // plaintext content types in DTLS 1.3. Reject all others. match content_type { - ContentType::ALERT | ContentType::HANDSHAKE | ContentType::ACK => {} + ContentType::Alert | ContentType::Handshake | ContentType::Ack => {} _ => { return Err(Err::Failure(nom::error::Error::new( input, @@ -190,7 +190,7 @@ impl Dtls13Record { Ok(( rest, Dtls13Record { - content_type: ContentType::APPLICATION_DATA, + content_type: ContentType::ApplicationData, sequence, length, fragment_range: start..end, diff --git a/src/dtls13/message/server_hello.rs b/src/dtls13/message/server_hello.rs index a89a2fae..de705117 100644 --- a/src/dtls13/message/server_hello.rs +++ b/src/dtls13/message/server_hello.rs @@ -156,14 +156,14 @@ mod test { 0x01, // SessionId length 0xAA, // SessionId 0x13, 0x01, // Dtls13CipherSuite::AES_128_GCM_SHA256 - 0x00, // CompressionMethod::NULL + 0x00, // CompressionMethod::Null 0x00, 0x0C, // Extensions length (12 bytes) - 0x00, 0x0A, // ExtensionType::SUPPORTED_GROUPS + 0x00, 0x0A, // ExtensionType::SupportedGroups 0x00, 0x08, // Extension data length (8 bytes) 0x00, 0x06, // Extension data - 0x00, 0x17, // NamedGroup::SECP256R1 - 0x00, 0x18, // NamedGroup::SECP384R1 - 0x00, 0x19, // NamedGroup::SECP521R1 + 0x00, 0x17, // NamedGroup::Secp256r1 + 0x00, 0x18, // NamedGroup::Secp384r1 + 0x00, 0x19, // NamedGroup::Secp521r1 ]; #[test] @@ -194,7 +194,7 @@ mod test { hrr_random, SessionId::empty(), Dtls13CipherSuite::AES_128_GCM_SHA256, - CompressionMethod::NULL, + CompressionMethod::Null, None, ); @@ -213,7 +213,7 @@ mod test { let count = 2; message.extend_from_slice(&(count as u16 * 4).to_be_bytes()); for _ in 0..count { - message.extend_from_slice(&ExtensionType::COOKIE.as_u16().to_be_bytes()); + message.extend_from_slice(&ExtensionType::Cookie.as_u16().to_be_bytes()); message.extend_from_slice(&0u16.to_be_bytes()); } diff --git a/src/dtls13/queue.rs b/src/dtls13/queue.rs index cfc07241..8fd1ff98 100644 --- a/src/dtls13/queue.rs +++ b/src/dtls13/queue.rs @@ -49,9 +49,9 @@ impl fmt::Debug for QueueRx { for item in &self.0 { let record = item.first().record(); match record.content_type { - ContentType::HANDSHAKE => handshake += 1, - ContentType::APPLICATION_DATA => app_data += 1, - ContentType::ALERT => alert += 1, + ContentType::Handshake => handshake += 1, + ContentType::ApplicationData => app_data += 1, + ContentType::Alert => alert += 1, _ => other += 1, } diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index b256c516..4711c8d4 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -309,7 +309,7 @@ impl Server { let epoch = self.engine.app_send_epoch(); self.engine.create_ciphertext_record( - ContentType::APPLICATION_DATA, + ContentType::ApplicationData, epoch, false, |body| { @@ -332,7 +332,7 @@ impl Server { } let epoch = self.engine.app_send_epoch(); self.engine - .create_ciphertext_record(ContentType::ALERT, epoch, false, |body| { + .create_ciphertext_record(ContentType::Alert, epoch, false, |body| { body.push(1); // level: legacy (ignored in DTLS 1.3) body.push(0); // description: close_notify })?; @@ -406,7 +406,7 @@ impl State { } else { server .engine - .next_handshake(MessageType::CLIENT_HELLO, &mut server.defragment_buffer)? + .next_handshake(MessageType::ClientHello, &mut server.defragment_buffer)? }; let Some(handshake) = maybe else { @@ -428,7 +428,7 @@ impl State { // Validate null compression is offered let has_null_compression = client_hello .legacy_compression_methods - .contains(&CompressionMethod::NULL); + .contains(&CompressionMethod::Null); if !has_null_compression { return Err(Error::SecurityError( crate::SecurityError::ClientHelloMustOfferNullCompression, @@ -448,7 +448,7 @@ impl State { for ext in &client_hello.extensions { match ext.extension_type { - ExtensionType::SUPPORTED_VERSIONS => { + ExtensionType::SupportedVersions => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, sv) = SupportedVersionsClientHello::parse(ext_data) .map_err(InternalError::from)?; @@ -458,7 +458,7 @@ impl State { } } } - ExtensionType::KEY_SHARE => { + ExtensionType::KeyShare => { let ext_data = ext.extension_data(&server.defragment_buffer); let ext_data_start = ext.extension_data_range.start; let (_, ks) = KeyShareClientHello::parse(ext_data, ext_data_start) @@ -473,24 +473,24 @@ impl State { } client_key_shares = Some(entries); } - ExtensionType::SUPPORTED_GROUPS => { + ExtensionType::SupportedGroups => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, sg) = SupportedGroupsExtension::parse(ext_data).map_err(InternalError::from)?; client_supported_groups = Some(sg.groups); } - ExtensionType::SIGNATURE_ALGORITHMS => { + ExtensionType::SignatureAlgorithms => { let ext_data = ext.extension_data(&server.defragment_buffer); // Parse but we don't currently filter by signature algorithms let _ = SignatureAlgorithmsExtension::parse(ext_data); } - ExtensionType::USE_SRTP => { + ExtensionType::UseSrtp => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, use_srtp) = UseSrtpExtension::parse(ext_data).map_err(InternalError::from)?; client_srtp_profiles = Some(use_srtp.profiles); } - ExtensionType::COOKIE => { + ExtensionType::Cookie => { let ext_data = ext.extension_data(&server.defragment_buffer); let (_, cookie) = parse_cookie_extension(ext_data).map_err(InternalError::from)?; @@ -798,7 +798,7 @@ impl State { server .engine - .create_handshake(MessageType::SERVER_HELLO, |body, engine| { + .create_handshake(MessageType::ServerHello, |body, engine| { handshake_create_server_hello( body, engine, @@ -845,7 +845,7 @@ impl State { server .engine - .create_handshake(MessageType::ENCRYPTED_EXTENSIONS, |body, _engine| { + .create_handshake(MessageType::EncryptedExtensions, |body, _engine| { handshake_create_encrypted_extensions(body, negotiated_srtp) })?; @@ -861,7 +861,7 @@ impl State { server .engine - .create_handshake(MessageType::CERTIFICATE_REQUEST, |body, _engine| { + .create_handshake(MessageType::CertificateRequest, |body, _engine| { handshake_create_certificate_request(body) })?; @@ -875,7 +875,7 @@ impl State { server .engine - .create_handshake(MessageType::CERTIFICATE, |body, engine| { + .create_handshake(MessageType::Certificate, |body, engine| { handshake_create_certificate(body, engine, &[]) })?; @@ -887,7 +887,7 @@ impl State { server .engine - .create_handshake(MessageType::CERTIFICATE_VERIFY, |body, engine| { + .create_handshake(MessageType::CertificateVerify, |body, engine| { handshake_create_certificate_verify( body, engine, @@ -914,7 +914,7 @@ impl State { server .engine - .create_handshake(MessageType::FINISHED, |body, engine| { + .create_handshake(MessageType::Finished, |body, engine| { let verify_data = engine.compute_verify_data(&server_hs_secret)?; body.extend_from_slice(&verify_data); Ok(()) @@ -946,7 +946,7 @@ impl State { fn await_certificate(self, server: &mut Server) -> Result { let maybe = server .engine - .next_handshake(MessageType::CERTIFICATE, &mut server.defragment_buffer)?; + .next_handshake(MessageType::Certificate, &mut server.defragment_buffer)?; let Some(ref handshake) = maybe else { return Ok(self); @@ -1012,7 +1012,7 @@ impl State { server.engine.transcript_hash(&mut transcript_hash); let maybe = server.engine.next_handshake( - MessageType::CERTIFICATE_VERIFY, + MessageType::CertificateVerify, &mut server.defragment_buffer, )?; @@ -1092,7 +1092,7 @@ impl State { let maybe = server .engine - .next_handshake(MessageType::FINISHED, &mut server.defragment_buffer)?; + .next_handshake(MessageType::Finished, &mut server.defragment_buffer)?; let Some(ref handshake) = maybe else { return Ok(self); @@ -1172,7 +1172,7 @@ impl State { ); for data in server.queued_data.drain(..) { server.engine.create_ciphertext_record( - ContentType::APPLICATION_DATA, + ContentType::ApplicationData, epoch, false, |body| { @@ -1191,12 +1191,9 @@ impl State { } // Check for incoming KeyUpdate - if server - .engine - .has_complete_handshake(MessageType::KEY_UPDATE) - { + if server.engine.has_complete_handshake(MessageType::KeyUpdate) { let maybe = server.engine.next_handshake_no_transcript( - MessageType::KEY_UPDATE, + MessageType::KeyUpdate, &mut server.defragment_buffer, )?; @@ -1227,12 +1224,9 @@ impl State { fn half_closed_local(self, server: &mut Server) -> Result { // Write half is closed: drain incoming KeyUpdate to keep recv keys in sync, // but do not send our own KeyUpdate response. - if server - .engine - .has_complete_handshake(MessageType::KEY_UPDATE) - { + if server.engine.has_complete_handshake(MessageType::KeyUpdate) { let maybe = server.engine.next_handshake_no_transcript( - MessageType::KEY_UPDATE, + MessageType::KeyUpdate, &mut server.defragment_buffer, )?; if let Some(handshake) = maybe { @@ -1294,7 +1288,7 @@ fn send_hello_retry_request( sv.serialize(&mut ext_buf); let sv_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SUPPORTED_VERSIONS, + extension_type: ExtensionType::SupportedVersions, extension_data_range: sv_start..sv_end, }); @@ -1307,7 +1301,7 @@ fn send_hello_retry_request( hrr_ks.serialize(&mut ext_buf); let ks_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::KEY_SHARE, + extension_type: ExtensionType::KeyShare, extension_data_range: ks_start..ks_end, }); } @@ -1319,7 +1313,7 @@ fn send_hello_retry_request( ext_buf.extend_from_slice(cookie); let cookie_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::COOKIE, + extension_type: ExtensionType::Cookie, extension_data_range: cookie_start..cookie_end, }); @@ -1328,13 +1322,13 @@ fn send_hello_retry_request( hrr_random, client_session_id, cipher_suite, - CompressionMethod::NULL, + CompressionMethod::Null, Some(extensions), ); server .engine - .create_handshake(MessageType::SERVER_HELLO, |body, _engine| { + .create_handshake(MessageType::ServerHello, |body, _engine| { server_hello.serialize(&ext_buf, body); Ok(()) })?; @@ -1366,7 +1360,7 @@ fn handshake_create_server_hello( sv.serialize(&mut ext_buf); let sv_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SUPPORTED_VERSIONS, + extension_type: ExtensionType::SupportedVersions, extension_data_range: sv_start..sv_end, }); @@ -1381,7 +1375,7 @@ fn handshake_create_server_hello( ks.serialize(extension_data, &mut ext_buf); let ks_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::KEY_SHARE, + extension_type: ExtensionType::KeyShare, extension_data_range: ks_start..ks_end, }); @@ -1390,7 +1384,7 @@ fn handshake_create_server_hello( random, client_session_id, cipher_suite, - CompressionMethod::NULL, + CompressionMethod::Null, Some(extensions), ); @@ -1415,7 +1409,7 @@ fn handshake_create_encrypted_extensions( use_srtp.serialize(&mut ext_buf); let srtp_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::USE_SRTP, + extension_type: ExtensionType::UseSrtp, extension_data_range: srtp_start..srtp_end, }); } @@ -1451,7 +1445,7 @@ fn handshake_create_certificate_request(body: &mut Buf) -> Result<(), Error> { sa.serialize(&mut ext_buf); let sa_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::SIGNATURE_ALGORITHMS, + extension_type: ExtensionType::SignatureAlgorithms, extension_data_range: sa_start..sa_end, }); @@ -1461,7 +1455,7 @@ fn handshake_create_certificate_request(body: &mut Buf) -> Result<(), Error> { serialize_certificate_authorities(&cas, &[], &mut ext_buf); let ca_end = ext_buf.len(); extensions.push(Extension { - extension_type: ExtensionType::CERTIFICATE_AUTHORITIES, + extension_type: ExtensionType::CertificateAuthorities, extension_data_range: ca_start..ca_end, }); diff --git a/src/types.rs b/src/types.rs index fadf6b69..08a1a32b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -85,57 +85,58 @@ impl Random { #[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct NamedGroup(u16); +#[allow(non_upper_case_globals)] impl NamedGroup { /// sect163k1 (deprecated). - pub const SECT163K1: Self = Self(1); + pub const Sect163k1: Self = Self(1); /// sect163r1 (deprecated). - pub const SECT163R1: Self = Self(2); + pub const Sect163r1: Self = Self(2); /// sect163r2 (deprecated). - pub const SECT163R2: Self = Self(3); + pub const Sect163r2: Self = Self(3); /// sect193r1 (deprecated). - pub const SECT193R1: Self = Self(4); + pub const Sect193r1: Self = Self(4); /// sect193r2 (deprecated). - pub const SECT193R2: Self = Self(5); + pub const Sect193r2: Self = Self(5); /// sect233k1 (deprecated). - pub const SECT233K1: Self = Self(6); + pub const Sect233k1: Self = Self(6); /// sect233r1 (deprecated). - pub const SECT233R1: Self = Self(7); + pub const Sect233r1: Self = Self(7); /// sect239k1 (deprecated). - pub const SECT239K1: Self = Self(8); + pub const Sect239k1: Self = Self(8); /// sect283k1 (deprecated). - pub const SECT283K1: Self = Self(9); + pub const Sect283k1: Self = Self(9); /// sect283r1 (deprecated). - pub const SECT283R1: Self = Self(10); + pub const Sect283r1: Self = Self(10); /// sect409k1 (deprecated). - pub const SECT409K1: Self = Self(11); + pub const Sect409k1: Self = Self(11); /// sect409r1 (deprecated). - pub const SECT409R1: Self = Self(12); + pub const Sect409r1: Self = Self(12); /// sect571k1 (deprecated). - pub const SECT571K1: Self = Self(13); + pub const Sect571k1: Self = Self(13); /// sect571r1 (deprecated). - pub const SECT571R1: Self = Self(14); + pub const Sect571r1: Self = Self(14); /// secp160k1 (deprecated). - pub const SECP160K1: Self = Self(15); + pub const Secp160k1: Self = Self(15); /// secp160r1 (deprecated). - pub const SECP160R1: Self = Self(16); + pub const Secp160r1: Self = Self(16); /// secp160r2 (deprecated). - pub const SECP160R2: Self = Self(17); + pub const Secp160r2: Self = Self(17); /// secp192k1 (deprecated). - pub const SECP192K1: Self = Self(18); + pub const Secp192k1: Self = Self(18); /// secp192r1 (deprecated). - pub const SECP192R1: Self = Self(19); + pub const Secp192r1: Self = Self(19); /// secp224k1. - pub const SECP224K1: Self = Self(20); + pub const Secp224k1: Self = Self(20); /// secp224r1. - pub const SECP224R1: Self = Self(21); + pub const Secp224r1: Self = Self(21); /// secp256k1. - pub const SECP256K1: Self = Self(22); + pub const Secp256k1: Self = Self(22); /// secp256r1 / P-256 (supported by dimpl). - pub const SECP256R1: Self = Self(23); + pub const Secp256r1: Self = Self(23); /// secp384r1 / P-384 (supported by dimpl). - pub const SECP384R1: Self = Self(24); + pub const Secp384r1: Self = Self(24); /// secp521r1 / P-521. - pub const SECP521R1: Self = Self(25); + pub const Secp521r1: Self = Self(25); /// X25519 (Curve25519 for ECDHE). pub const X25519: Self = Self(29); /// X448 (Curve448 for ECDHE). @@ -170,31 +171,31 @@ impl NamedGroup { /// All recognized named groups (every non-`Unknown` variant). pub const fn all() -> &'static [NamedGroup; 27] { &[ - NamedGroup::SECT163K1, - NamedGroup::SECT163R1, - NamedGroup::SECT163R2, - NamedGroup::SECT193R1, - NamedGroup::SECT193R2, - NamedGroup::SECT233K1, - NamedGroup::SECT233R1, - NamedGroup::SECT239K1, - NamedGroup::SECT283K1, - NamedGroup::SECT283R1, - NamedGroup::SECT409K1, - NamedGroup::SECT409R1, - NamedGroup::SECT571K1, - NamedGroup::SECT571R1, - NamedGroup::SECP160K1, - NamedGroup::SECP160R1, - NamedGroup::SECP160R2, - NamedGroup::SECP192K1, - NamedGroup::SECP192R1, - NamedGroup::SECP224K1, - NamedGroup::SECP224R1, - NamedGroup::SECP256K1, - NamedGroup::SECP256R1, - NamedGroup::SECP384R1, - NamedGroup::SECP521R1, + NamedGroup::Sect163k1, + NamedGroup::Sect163r1, + NamedGroup::Sect163r2, + NamedGroup::Sect193r1, + NamedGroup::Sect193r2, + NamedGroup::Sect233k1, + NamedGroup::Sect233r1, + NamedGroup::Sect239k1, + NamedGroup::Sect283k1, + NamedGroup::Sect283r1, + NamedGroup::Sect409k1, + NamedGroup::Sect409r1, + NamedGroup::Sect571k1, + NamedGroup::Sect571r1, + NamedGroup::Secp160k1, + NamedGroup::Secp160r1, + NamedGroup::Secp160r2, + NamedGroup::Secp192k1, + NamedGroup::Secp192r1, + NamedGroup::Secp224k1, + NamedGroup::Secp224r1, + NamedGroup::Secp256k1, + NamedGroup::Secp256r1, + NamedGroup::Secp384r1, + NamedGroup::Secp521r1, NamedGroup::X25519, NamedGroup::X448, ] @@ -204,9 +205,9 @@ impl NamedGroup { pub const fn supported() -> &'static [NamedGroup; 4] { &[ NamedGroup::X25519, - NamedGroup::SECP256R1, - NamedGroup::SECP384R1, - NamedGroup::SECP521R1, + NamedGroup::Secp256r1, + NamedGroup::Secp384r1, + NamedGroup::Secp521r1, ] } } @@ -214,31 +215,31 @@ impl NamedGroup { impl fmt::Debug for NamedGroup { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - NamedGroup::SECT163K1 => f.write_str("Sect163k1"), - NamedGroup::SECT163R1 => f.write_str("Sect163r1"), - NamedGroup::SECT163R2 => f.write_str("Sect163r2"), - NamedGroup::SECT193R1 => f.write_str("Sect193r1"), - NamedGroup::SECT193R2 => f.write_str("Sect193r2"), - NamedGroup::SECT233K1 => f.write_str("Sect233k1"), - NamedGroup::SECT233R1 => f.write_str("Sect233r1"), - NamedGroup::SECT239K1 => f.write_str("Sect239k1"), - NamedGroup::SECT283K1 => f.write_str("Sect283k1"), - NamedGroup::SECT283R1 => f.write_str("Sect283r1"), - NamedGroup::SECT409K1 => f.write_str("Sect409k1"), - NamedGroup::SECT409R1 => f.write_str("Sect409r1"), - NamedGroup::SECT571K1 => f.write_str("Sect571k1"), - NamedGroup::SECT571R1 => f.write_str("Sect571r1"), - NamedGroup::SECP160K1 => f.write_str("Secp160k1"), - NamedGroup::SECP160R1 => f.write_str("Secp160r1"), - NamedGroup::SECP160R2 => f.write_str("Secp160r2"), - NamedGroup::SECP192K1 => f.write_str("Secp192k1"), - NamedGroup::SECP192R1 => f.write_str("Secp192r1"), - NamedGroup::SECP224K1 => f.write_str("Secp224k1"), - NamedGroup::SECP224R1 => f.write_str("Secp224r1"), - NamedGroup::SECP256K1 => f.write_str("Secp256k1"), - NamedGroup::SECP256R1 => f.write_str("Secp256r1"), - NamedGroup::SECP384R1 => f.write_str("Secp384r1"), - NamedGroup::SECP521R1 => f.write_str("Secp521r1"), + NamedGroup::Sect163k1 => f.write_str("Sect163k1"), + NamedGroup::Sect163r1 => f.write_str("Sect163r1"), + NamedGroup::Sect163r2 => f.write_str("Sect163r2"), + NamedGroup::Sect193r1 => f.write_str("Sect193r1"), + NamedGroup::Sect193r2 => f.write_str("Sect193r2"), + NamedGroup::Sect233k1 => f.write_str("Sect233k1"), + NamedGroup::Sect233r1 => f.write_str("Sect233r1"), + NamedGroup::Sect239k1 => f.write_str("Sect239k1"), + NamedGroup::Sect283k1 => f.write_str("Sect283k1"), + NamedGroup::Sect283r1 => f.write_str("Sect283r1"), + NamedGroup::Sect409k1 => f.write_str("Sect409k1"), + NamedGroup::Sect409r1 => f.write_str("Sect409r1"), + NamedGroup::Sect571k1 => f.write_str("Sect571k1"), + NamedGroup::Sect571r1 => f.write_str("Sect571r1"), + NamedGroup::Secp160k1 => f.write_str("Secp160k1"), + NamedGroup::Secp160r1 => f.write_str("Secp160r1"), + NamedGroup::Secp160r2 => f.write_str("Secp160r2"), + NamedGroup::Secp192k1 => f.write_str("Secp192k1"), + NamedGroup::Secp192r1 => f.write_str("Secp192r1"), + NamedGroup::Secp224k1 => f.write_str("Secp224k1"), + NamedGroup::Secp224r1 => f.write_str("Secp224r1"), + NamedGroup::Secp256k1 => f.write_str("Secp256k1"), + NamedGroup::Secp256r1 => f.write_str("Secp256r1"), + NamedGroup::Secp384r1 => f.write_str("Secp384r1"), + NamedGroup::Secp521r1 => f.write_str("Secp521r1"), NamedGroup::X25519 => f.write_str("X25519"), NamedGroup::X448 => f.write_str("X448"), _ => f.debug_tuple("Unknown").field(&self.0).finish(), @@ -260,13 +261,14 @@ pub struct HashAlgorithm(u8); impl Default for HashAlgorithm { fn default() -> Self { - Self::NONE + Self::None } } +#[allow(non_upper_case_globals)] impl HashAlgorithm { /// No hash (not typically used). - pub const NONE: Self = Self(0); + pub const None: Self = Self(0); /// MD5 hash (deprecated, not supported). pub const MD5: Self = Self(1); /// SHA-1 hash (deprecated, not supported). @@ -306,7 +308,7 @@ impl HashAlgorithm { /// Returns the output length in bytes for this hash algorithm. pub const fn output_len(&self) -> usize { match *self { - HashAlgorithm::NONE => 0, + HashAlgorithm::None => 0, HashAlgorithm::MD5 => 16, HashAlgorithm::SHA1 => 20, HashAlgorithm::SHA224 => 28, @@ -321,7 +323,7 @@ impl HashAlgorithm { impl fmt::Debug for HashAlgorithm { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - HashAlgorithm::NONE => f.write_str("None"), + HashAlgorithm::None => f.write_str("None"), HashAlgorithm::MD5 => f.write_str("MD5"), HashAlgorithm::SHA1 => f.write_str("SHA1"), HashAlgorithm::SHA224 => f.write_str("SHA224"), @@ -347,13 +349,14 @@ pub struct SignatureAlgorithm(u8); impl Default for SignatureAlgorithm { fn default() -> Self { - Self::ANONYMOUS + Self::Anonymous } } +#[allow(non_upper_case_globals)] impl SignatureAlgorithm { /// Anonymous (no certificate). - pub const ANONYMOUS: Self = Self(0); + pub const Anonymous: Self = Self(0); /// RSA signatures. pub const RSA: Self = Self(1); /// DSA signatures. @@ -388,7 +391,7 @@ impl SignatureAlgorithm { impl fmt::Debug for SignatureAlgorithm { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - SignatureAlgorithm::ANONYMOUS => f.write_str("Anonymous"), + SignatureAlgorithm::Anonymous => f.write_str("Anonymous"), SignatureAlgorithm::RSA => f.write_str("RSA"), SignatureAlgorithm::DSA => f.write_str("DSA"), SignatureAlgorithm::ECDSA => f.write_str("ECDSA"), @@ -409,17 +412,18 @@ impl fmt::Debug for SignatureAlgorithm { #[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct ContentType(u8); +#[allow(non_upper_case_globals)] impl ContentType { /// Change Cipher Spec (used in DTLS 1.2, compatibility-only in 1.3). - pub const CHANGE_CIPHER_SPEC: Self = Self(20); + pub const ChangeCipherSpec: Self = Self(20); /// Alert message. - pub const ALERT: Self = Self(21); + pub const Alert: Self = Self(21); /// Handshake message. - pub const HANDSHAKE: Self = Self(22); + pub const Handshake: Self = Self(22); /// Application data. - pub const APPLICATION_DATA: Self = Self(23); + pub const ApplicationData: Self = Self(23); /// ACK (DTLS 1.3 only, RFC 9147 Section 7). - pub const ACK: Self = Self(26); + pub const Ack: Self = Self(26); /// Convert a u8 value to a `ContentType`. pub const fn from_u8(value: u8) -> Self { @@ -446,11 +450,11 @@ impl ContentType { impl fmt::Debug for ContentType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - ContentType::CHANGE_CIPHER_SPEC => f.write_str("ChangeCipherSpec"), - ContentType::ALERT => f.write_str("Alert"), - ContentType::HANDSHAKE => f.write_str("Handshake"), - ContentType::APPLICATION_DATA => f.write_str("ApplicationData"), - ContentType::ACK => f.write_str("Ack"), + ContentType::ChangeCipherSpec => f.write_str("ChangeCipherSpec"), + ContentType::Alert => f.write_str("Alert"), + ContentType::Handshake => f.write_str("Handshake"), + ContentType::ApplicationData => f.write_str("ApplicationData"), + ContentType::Ack => f.write_str("Ack"), _ => f.debug_tuple("Unknown").field(&self.0).finish(), } } @@ -621,8 +625,8 @@ impl SignatureScheme { /// Returns `None` for non-ECDSA schemes. pub fn named_group(&self) -> Option { match *self { - SignatureScheme::ECDSA_SECP256R1_SHA256 => Some(NamedGroup::SECP256R1), - SignatureScheme::ECDSA_SECP384R1_SHA384 => Some(NamedGroup::SECP384R1), + SignatureScheme::ECDSA_SECP256R1_SHA256 => Some(NamedGroup::Secp256r1), + SignatureScheme::ECDSA_SECP384R1_SHA384 => Some(NamedGroup::Secp384r1), _ => None, } } @@ -643,7 +647,7 @@ impl SignatureScheme { | SignatureScheme::RSA_PSS_PSS_SHA512 | SignatureScheme::RSA_PKCS1_SHA512 => HashAlgorithm::SHA512, // Ed25519 and Ed448 have intrinsic hash algorithms - SignatureScheme::ED25519 | SignatureScheme::ED448 => HashAlgorithm::NONE, + SignatureScheme::ED25519 | SignatureScheme::ED448 => HashAlgorithm::None, _ => HashAlgorithm::UNKNOWN_DERIVED, } } @@ -843,15 +847,16 @@ pub struct CompressionMethod(u8); impl Default for CompressionMethod { fn default() -> Self { - Self::NULL + Self::Null } } +#[allow(non_upper_case_globals)] impl CompressionMethod { /// No compression. - pub const NULL: Self = Self(0x00); + pub const Null: Self = Self(0x00); /// DEFLATE compression. - pub const DEFLATE: Self = Self(0x01); + pub const Deflate: Self = Self(0x01); /// Convert a u8 value to a `CompressionMethod`. pub const fn from_u8(value: u8) -> Self { @@ -865,7 +870,7 @@ impl CompressionMethod { /// All recognized compression methods (every non-`Unknown` variant). pub const fn all() -> &'static [CompressionMethod; 2] { - &[CompressionMethod::NULL, CompressionMethod::DEFLATE] + &[CompressionMethod::Null, CompressionMethod::Deflate] } /// Supported compression methods. @@ -874,7 +879,7 @@ impl CompressionMethod { /// §4.1.2) mandates exactly one compression method (null). DEFLATE /// is recognized by parsing but not accepted. pub const fn supported() -> &'static [CompressionMethod; 1] { - &[CompressionMethod::NULL] + &[CompressionMethod::Null] } /// Convert this `CompressionMethod` to its u8 value. @@ -884,7 +889,7 @@ impl CompressionMethod { /// Returns true if this is not a known TLS compression method wire value. pub const fn is_unknown(&self) -> bool { - self.0 > Self::DEFLATE.0 + self.0 > Self::Deflate.0 } /// Parse a `CompressionMethod` from wire format. @@ -897,8 +902,8 @@ impl CompressionMethod { impl fmt::Debug for CompressionMethod { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - CompressionMethod::NULL => f.write_str("Null"), - CompressionMethod::DEFLATE => f.write_str("Deflate"), + CompressionMethod::Null => f.write_str("Null"), + CompressionMethod::Deflate => f.write_str("Deflate"), _ => f.debug_tuple("Unknown").field(&self.0).finish(), } } @@ -929,7 +934,7 @@ mod tests { #[test] fn named_group_debug_stays_enum_like() { - assert_eq!(format!("{:?}", NamedGroup::SECP256R1), "Secp256r1"); + assert_eq!(format!("{:?}", NamedGroup::Secp256r1), "Secp256r1"); assert_eq!(format!("{:?}", NamedGroup::X25519), "X25519"); assert_eq!( format!("{:?}", NamedGroup::from_u16(0xFFFF)), @@ -941,13 +946,13 @@ mod tests { fn hash_algorithm_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); assert_eq!(HashAlgorithm::default().as_u8(), 0); - assert_eq!(HashAlgorithm::default(), HashAlgorithm::NONE); + assert_eq!(HashAlgorithm::default(), HashAlgorithm::None); } #[test] fn hash_algorithm_wire_roundtrip() { let known = [ - (0, HashAlgorithm::NONE), + (0, HashAlgorithm::None), (1, HashAlgorithm::MD5), (2, HashAlgorithm::SHA1), (3, HashAlgorithm::SHA224), @@ -969,7 +974,7 @@ mod tests { #[test] fn hash_algorithm_output_len() { - assert_eq!(HashAlgorithm::NONE.output_len(), 0); + assert_eq!(HashAlgorithm::None.output_len(), 0); assert_eq!(HashAlgorithm::MD5.output_len(), 16); assert_eq!(HashAlgorithm::SHA1.output_len(), 20); assert_eq!(HashAlgorithm::SHA224.output_len(), 28); @@ -981,7 +986,7 @@ mod tests { #[test] fn hash_algorithm_debug_stays_enum_like() { - assert_eq!(format!("{:?}", HashAlgorithm::NONE), "None"); + assert_eq!(format!("{:?}", HashAlgorithm::None), "None"); assert_eq!(format!("{:?}", HashAlgorithm::SHA256), "SHA256"); assert_eq!(format!("{:?}", HashAlgorithm::from_u8(7)), "Unknown(7)"); } @@ -990,13 +995,13 @@ mod tests { fn signature_algorithm_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); assert_eq!(SignatureAlgorithm::default().as_u8(), 0); - assert_eq!(SignatureAlgorithm::default(), SignatureAlgorithm::ANONYMOUS); + assert_eq!(SignatureAlgorithm::default(), SignatureAlgorithm::Anonymous); } #[test] fn signature_algorithm_wire_roundtrip() { let known = [ - (0, SignatureAlgorithm::ANONYMOUS), + (0, SignatureAlgorithm::Anonymous), (1, SignatureAlgorithm::RSA), (2, SignatureAlgorithm::DSA), (3, SignatureAlgorithm::ECDSA), @@ -1015,7 +1020,7 @@ mod tests { #[test] fn signature_algorithm_debug_stays_enum_like() { - assert_eq!(format!("{:?}", SignatureAlgorithm::ANONYMOUS), "Anonymous"); + assert_eq!(format!("{:?}", SignatureAlgorithm::Anonymous), "Anonymous"); assert_eq!(format!("{:?}", SignatureAlgorithm::ECDSA), "ECDSA"); assert_eq!( format!("{:?}", SignatureAlgorithm::from_u8(4)), @@ -1027,14 +1032,14 @@ mod tests { fn compression_method_newtype_shape() { assert_eq!(std::mem::size_of::(), 1); assert_eq!(CompressionMethod::default().as_u8(), 0); - assert_eq!(CompressionMethod::default(), CompressionMethod::NULL); + assert_eq!(CompressionMethod::default(), CompressionMethod::Null); } #[test] fn compression_method_wire_roundtrip() { let known = [ - (0x00, CompressionMethod::NULL), - (0x01, CompressionMethod::DEFLATE), + (0x00, CompressionMethod::Null), + (0x01, CompressionMethod::Deflate), ]; for (wire, method) in known { @@ -1050,8 +1055,8 @@ mod tests { #[test] fn compression_method_debug_stays_enum_like() { - assert_eq!(format!("{:?}", CompressionMethod::NULL), "Null"); - assert_eq!(format!("{:?}", CompressionMethod::DEFLATE), "Deflate"); + assert_eq!(format!("{:?}", CompressionMethod::Null), "Null"); + assert_eq!(format!("{:?}", CompressionMethod::Deflate), "Deflate"); assert_eq!( format!("{:?}", CompressionMethod::from_u8(0x02)), "Unknown(2)" @@ -1068,11 +1073,11 @@ mod tests { #[test] fn content_type_wire_roundtrip() { let known = [ - (20, ContentType::CHANGE_CIPHER_SPEC), - (21, ContentType::ALERT), - (22, ContentType::HANDSHAKE), - (23, ContentType::APPLICATION_DATA), - (26, ContentType::ACK), + (20, ContentType::ChangeCipherSpec), + (21, ContentType::Alert), + (22, ContentType::Handshake), + (23, ContentType::ApplicationData), + (26, ContentType::Ack), ]; for (wire, content_type) in known { @@ -1089,10 +1094,10 @@ mod tests { #[test] fn content_type_debug_stays_enum_like() { assert_eq!( - format!("{:?}", ContentType::CHANGE_CIPHER_SPEC), + format!("{:?}", ContentType::ChangeCipherSpec), "ChangeCipherSpec" ); - assert_eq!(format!("{:?}", ContentType::HANDSHAKE), "Handshake"); + assert_eq!(format!("{:?}", ContentType::Handshake), "Handshake"); assert_eq!(format!("{:?}", ContentType::from_u8(24)), "Unknown(24)"); } @@ -1228,7 +1233,7 @@ mod tests { let supported = CompressionMethod::supported(); assert_eq!( supported, - &[CompressionMethod::NULL], + &[CompressionMethod::Null], "Only Null compression should be supported" ); } @@ -1237,11 +1242,11 @@ mod tests { fn signature_scheme_named_group_ecdsa() { assert_eq!( SignatureScheme::ECDSA_SECP256R1_SHA256.named_group(), - Some(NamedGroup::SECP256R1) + Some(NamedGroup::Secp256r1) ); assert_eq!( SignatureScheme::ECDSA_SECP384R1_SHA384.named_group(), - Some(NamedGroup::SECP384R1) + Some(NamedGroup::Secp384r1) ); } diff --git a/tests/dtls12/handshake.rs b/tests/dtls12/handshake.rs index 757f37a4..a6f19907 100644 --- a/tests/dtls12/handshake.rs +++ b/tests/dtls12/handshake.rs @@ -672,7 +672,7 @@ fn dtls12_handshake_secp384r1_key_exchange() { .kx_groups .iter() .copied() - .filter(|g| g.name() == dimpl::NamedGroup::SECP384R1) + .filter(|g| g.name() == dimpl::NamedGroup::Secp384r1) .collect(); // leak: intentional leak to produce a &'static slice for the provider field let p384_only: &'static [&'static dyn dimpl::crypto::SupportedKxGroup] = diff --git a/tests/dtls12/retransmit.rs b/tests/dtls12/retransmit.rs index a0a5b944..eb7a6445 100644 --- a/tests/dtls12/retransmit.rs +++ b/tests/dtls12/retransmit.rs @@ -1130,7 +1130,7 @@ fn forged_epoch1_app_data() -> Vec { // garbage (undecryptable) body. The content type lives in the cleartext // header, so this looks like app data before anything is decrypted. let mut rec = Vec::new(); - rec.push(23); // ContentType::APPLICATION_DATA + rec.push(23); // ContentType::ApplicationData rec.extend_from_slice(&[0xFE, 0xFD]); // DTLS 1.2 rec.extend_from_slice(&1u16.to_be_bytes()); // epoch 1 rec.extend_from_slice(&[0, 0, 0, 0, 0, 1]); // 48-bit sequence number diff --git a/tests/dtls13/handshake.rs b/tests/dtls13/handshake.rs index a2610749..1799bf7f 100644 --- a/tests/dtls13/handshake.rs +++ b/tests/dtls13/handshake.rs @@ -515,7 +515,7 @@ fn dtls13_handshake_secp256r1_key_exchange() { .kx_groups .iter() .copied() - .filter(|g| g.name() == NamedGroup::SECP256R1) + .filter(|g| g.name() == NamedGroup::Secp256r1) .collect(); assert!(!p256_only.is_empty(), "Provider must have P-256"); @@ -806,12 +806,12 @@ fn dtls13_hrr_with_p256_then_x25519() { .kx_groups .iter() .copied() - .filter(|g| g.name() == NamedGroup::SECP256R1 || g.name() == NamedGroup::X25519) + .filter(|g| g.name() == NamedGroup::Secp256r1 || g.name() == NamedGroup::X25519) .collect(); // Ensure P-256 is first let mut client_groups_sorted: Vec<_> = client_groups; client_groups_sorted.sort_by_key(|g| { - if g.name() == NamedGroup::SECP256R1 { + if g.name() == NamedGroup::Secp256r1 { 0 } else { 1 @@ -829,7 +829,7 @@ fn dtls13_hrr_with_p256_then_x25519() { .kx_groups .iter() .copied() - .filter(|g| g.name() == NamedGroup::SECP256R1 || g.name() == NamedGroup::X25519) + .filter(|g| g.name() == NamedGroup::Secp256r1 || g.name() == NamedGroup::X25519) .collect(); let mut server_groups_sorted: Vec<_> = server_groups; server_groups_sorted.sort_by_key(|g| if g.name() == NamedGroup::X25519 { 0 } else { 1 }); @@ -935,11 +935,11 @@ fn dtls13_hrr_handshake_completes_after_packet_loss() { .kx_groups .iter() .copied() - .filter(|g| g.name() == NamedGroup::SECP256R1 || g.name() == NamedGroup::SECP384R1) + .filter(|g| g.name() == NamedGroup::Secp256r1 || g.name() == NamedGroup::Secp384r1) .collect(); let mut client_groups_sorted: Vec<_> = client_groups; client_groups_sorted.sort_by_key(|g| { - if g.name() == NamedGroup::SECP256R1 { + if g.name() == NamedGroup::Secp256r1 { 0 } else { 1 @@ -957,11 +957,11 @@ fn dtls13_hrr_handshake_completes_after_packet_loss() { .kx_groups .iter() .copied() - .filter(|g| g.name() == NamedGroup::SECP256R1 || g.name() == NamedGroup::SECP384R1) + .filter(|g| g.name() == NamedGroup::Secp256r1 || g.name() == NamedGroup::Secp384r1) .collect(); let mut server_groups_sorted: Vec<_> = server_groups; server_groups_sorted.sort_by_key(|g| { - if g.name() == NamedGroup::SECP384R1 { + if g.name() == NamedGroup::Secp384r1 { 0 } else { 1