diff --git a/CHANGELOG.md b/CHANGELOG.md index aa3d3dee..e3bdafd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased + * Represent DTLS wire-code identifiers as compact newtypes (breaking) #137 * Make public errors structured and fatal-only (breaking) #134 # 0.6.2 diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs index 7f962b6d..71ac7a70 100644 --- a/src/dtls12/client.rs +++ b/src/dtls12/client.rs @@ -467,7 +467,7 @@ impl State { // Enforce cipher suite is known and allowed let cs = server_hello.cipher_suite; - if matches!(cs, Dtls12CipherSuite::Unknown(_)) { + if cs.is_unknown() { return Err((Error::SecurityError( crate::SecurityError::ServerSelectedUnknownCipherSuite, )) diff --git a/src/dtls12/message/extension.rs b/src/dtls12/message/extension.rs index 59da8920..3ea0c528 100644 --- a/src/dtls12/message/extension.rs +++ b/src/dtls12/message/extension.rs @@ -1,7 +1,7 @@ use crate::buffer::Buf; use arrayvec::ArrayVec; use nom::{IResult, bytes::complete::take, number::complete::be_u16}; -use std::ops::Range; +use std::{fmt, ops::Range}; pub type ExtensionVec = ArrayVec; @@ -51,142 +51,64 @@ impl Extension { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ExtensionType { - ServerName, - MaxFragmentLength, - ClientCertificateUrl, - TrustedCaKeys, - TruncatedHmac, - StatusRequest, - UserMapping, - ClientAuthz, - ServerAuthz, - CertType, - SupportedGroups, - EcPointFormats, - Srp, - SignatureAlgorithms, - UseSrtp, - Heartbeat, - ApplicationLayerProtocolNegotiation, - StatusRequestV2, - SignedCertificateTimestamp, - ClientCertificateType, - ServerCertificateType, - Padding, - EncryptThenMac, - ExtendedMasterSecret, - TokenBinding, - CachedInfo, - SessionTicket, - PreSharedKey, - EarlyData, - SupportedVersions, - Cookie, - PskKeyExchangeModes, - CertificateAuthorities, - OidFilters, - PostHandshakeAuth, - SignatureAlgorithmsCert, - KeyShare, - RenegotiationInfo, - Unknown(u16), -} +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct ExtensionType(u16); -impl Default for ExtensionType { - fn default() -> Self { - Self::Unknown(0) +#[allow(non_upper_case_globals)] +impl ExtensionType { + pub const ServerName: Self = Self(0x0000); + pub const MaxFragmentLength: Self = Self(0x0001); + pub const ClientCertificateUrl: Self = Self(0x0002); + pub const TrustedCaKeys: Self = Self(0x0003); + pub const TruncatedHmac: Self = Self(0x0004); + pub const StatusRequest: Self = Self(0x0005); + pub const UserMapping: Self = Self(0x0006); + pub const ClientAuthz: Self = Self(0x0007); + pub const ServerAuthz: Self = Self(0x0008); + pub const CertType: Self = Self(0x0009); + pub const SupportedGroups: Self = Self(0x000A); + pub const EcPointFormats: Self = Self(0x000B); + pub const Srp: Self = Self(0x000C); + pub const SignatureAlgorithms: Self = Self(0x000D); + pub const UseSrtp: Self = Self(0x000E); + pub const Heartbeat: Self = Self(0x000F); + pub const ApplicationLayerProtocolNegotiation: Self = Self(0x0010); + pub const StatusRequestV2: Self = Self(0x0011); + pub const SignedCertificateTimestamp: Self = Self(0x0012); + pub const ClientCertificateType: Self = Self(0x0013); + pub const ServerCertificateType: Self = Self(0x0014); + pub const Padding: Self = Self(0x0015); + pub const EncryptThenMac: Self = Self(0x0016); + pub const ExtendedMasterSecret: Self = Self(0x0017); + pub const TokenBinding: Self = Self(0x0018); + pub const CachedInfo: Self = Self(0x0019); + pub const SessionTicket: Self = Self(0x0023); + pub const PreSharedKey: Self = Self(0x0029); + pub const EarlyData: Self = Self(0x002A); + pub const SupportedVersions: Self = Self(0x002B); + pub const Cookie: Self = Self(0x002C); + pub const PskKeyExchangeModes: Self = Self(0x002D); + pub const CertificateAuthorities: Self = Self(0x002F); + pub const OidFilters: Self = Self(0x0030); + pub const PostHandshakeAuth: Self = Self(0x0031); + pub const SignatureAlgorithmsCert: Self = Self(0x0032); + pub const KeyShare: Self = Self(0x0033); + pub const RenegotiationInfo: Self = Self(0xFF01); + + pub const fn from_u16(value: u16) -> Self { + Self(value) } -} -impl ExtensionType { - pub fn from_u16(value: u16) -> Self { - match value { - 0x0000 => ExtensionType::ServerName, - 0x0001 => ExtensionType::MaxFragmentLength, - 0x0002 => ExtensionType::ClientCertificateUrl, - 0x0003 => ExtensionType::TrustedCaKeys, - 0x0004 => ExtensionType::TruncatedHmac, - 0x0005 => ExtensionType::StatusRequest, - 0x0006 => ExtensionType::UserMapping, - 0x0007 => ExtensionType::ClientAuthz, - 0x0008 => ExtensionType::ServerAuthz, - 0x0009 => ExtensionType::CertType, - 0x000A => ExtensionType::SupportedGroups, - 0x000B => ExtensionType::EcPointFormats, - 0x000C => ExtensionType::Srp, - 0x000D => ExtensionType::SignatureAlgorithms, - 0x000E => ExtensionType::UseSrtp, - 0x000F => ExtensionType::Heartbeat, - 0x0010 => ExtensionType::ApplicationLayerProtocolNegotiation, - 0x0011 => ExtensionType::StatusRequestV2, - 0x0012 => ExtensionType::SignedCertificateTimestamp, - 0x0013 => ExtensionType::ClientCertificateType, - 0x0014 => ExtensionType::ServerCertificateType, - 0x0015 => ExtensionType::Padding, - 0x0016 => ExtensionType::EncryptThenMac, - 0x0017 => ExtensionType::ExtendedMasterSecret, - 0x0018 => ExtensionType::TokenBinding, - 0x0019 => ExtensionType::CachedInfo, - 0x0023 => ExtensionType::SessionTicket, - 0x0029 => ExtensionType::PreSharedKey, - 0x002A => ExtensionType::EarlyData, - 0x002B => ExtensionType::SupportedVersions, - 0x002C => ExtensionType::Cookie, - 0x002D => ExtensionType::PskKeyExchangeModes, - 0x002F => ExtensionType::CertificateAuthorities, - 0x0030 => ExtensionType::OidFilters, - 0x0031 => ExtensionType::PostHandshakeAuth, - 0x0032 => ExtensionType::SignatureAlgorithmsCert, - 0x0033 => ExtensionType::KeyShare, - 0xFF01 => ExtensionType::RenegotiationInfo, - _ => ExtensionType::Unknown(value), - } + pub const fn as_u16(&self) -> u16 { + self.0 } - pub fn as_u16(&self) -> u16 { - match self { - ExtensionType::ServerName => 0x0000, - ExtensionType::MaxFragmentLength => 0x0001, - ExtensionType::ClientCertificateUrl => 0x0002, - ExtensionType::TrustedCaKeys => 0x0003, - ExtensionType::TruncatedHmac => 0x0004, - ExtensionType::StatusRequest => 0x0005, - ExtensionType::UserMapping => 0x0006, - ExtensionType::ClientAuthz => 0x0007, - ExtensionType::ServerAuthz => 0x0008, - ExtensionType::CertType => 0x0009, - ExtensionType::SupportedGroups => 0x000A, - ExtensionType::EcPointFormats => 0x000B, - ExtensionType::Srp => 0x000C, - ExtensionType::SignatureAlgorithms => 0x000D, - ExtensionType::UseSrtp => 0x000E, - ExtensionType::Heartbeat => 0x000F, - ExtensionType::ApplicationLayerProtocolNegotiation => 0x0010, - ExtensionType::StatusRequestV2 => 0x0011, - ExtensionType::SignedCertificateTimestamp => 0x0012, - ExtensionType::ClientCertificateType => 0x0013, - ExtensionType::ServerCertificateType => 0x0014, - ExtensionType::Padding => 0x0015, - ExtensionType::EncryptThenMac => 0x0016, - ExtensionType::ExtendedMasterSecret => 0x0017, - ExtensionType::TokenBinding => 0x0018, - ExtensionType::CachedInfo => 0x0019, - ExtensionType::SessionTicket => 0x0023, - ExtensionType::PreSharedKey => 0x0029, - ExtensionType::EarlyData => 0x002A, - ExtensionType::SupportedVersions => 0x002B, - ExtensionType::Cookie => 0x002C, - ExtensionType::PskKeyExchangeModes => 0x002D, - ExtensionType::CertificateAuthorities => 0x002F, - ExtensionType::OidFilters => 0x0030, - ExtensionType::PostHandshakeAuth => 0x0031, - ExtensionType::SignatureAlgorithmsCert => 0x0032, - ExtensionType::KeyShare => 0x0033, - ExtensionType::RenegotiationInfo => 0xFF01, - ExtensionType::Unknown(value) => *value, - } + const fn is_unknown(&self) -> bool { + !matches!( + *self, + Self(0x0000..=0x0019 | 0x0023 | 0x0029..=0x002D | 0x002F..=0x0033 | 0xFF01) + ) } pub fn parse(input: &[u8]) -> IResult<&[u8], ExtensionType> { @@ -214,6 +136,60 @@ impl ExtensionType { } } +impl fmt::Debug for ExtensionType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + ExtensionType::ServerName => "ServerName", + ExtensionType::MaxFragmentLength => "MaxFragmentLength", + ExtensionType::ClientCertificateUrl => "ClientCertificateUrl", + ExtensionType::TrustedCaKeys => "TrustedCaKeys", + ExtensionType::TruncatedHmac => "TruncatedHmac", + ExtensionType::StatusRequest => "StatusRequest", + ExtensionType::UserMapping => "UserMapping", + ExtensionType::ClientAuthz => "ClientAuthz", + ExtensionType::ServerAuthz => "ServerAuthz", + ExtensionType::CertType => "CertType", + ExtensionType::SupportedGroups => "SupportedGroups", + ExtensionType::EcPointFormats => "EcPointFormats", + ExtensionType::Srp => "Srp", + ExtensionType::SignatureAlgorithms => "SignatureAlgorithms", + ExtensionType::UseSrtp => "UseSrtp", + ExtensionType::Heartbeat => "Heartbeat", + ExtensionType::ApplicationLayerProtocolNegotiation => { + "ApplicationLayerProtocolNegotiation" + } + ExtensionType::StatusRequestV2 => "StatusRequestV2", + ExtensionType::SignedCertificateTimestamp => "SignedCertificateTimestamp", + ExtensionType::ClientCertificateType => "ClientCertificateType", + ExtensionType::ServerCertificateType => "ServerCertificateType", + ExtensionType::Padding => "Padding", + ExtensionType::EncryptThenMac => "EncryptThenMac", + ExtensionType::ExtendedMasterSecret => "ExtendedMasterSecret", + ExtensionType::TokenBinding => "TokenBinding", + ExtensionType::CachedInfo => "CachedInfo", + ExtensionType::SessionTicket => "SessionTicket", + ExtensionType::PreSharedKey => "PreSharedKey", + ExtensionType::EarlyData => "EarlyData", + ExtensionType::SupportedVersions => "SupportedVersions", + ExtensionType::Cookie => "Cookie", + ExtensionType::PskKeyExchangeModes => "PskKeyExchangeModes", + ExtensionType::CertificateAuthorities => "CertificateAuthorities", + ExtensionType::OidFilters => "OidFilters", + ExtensionType::PostHandshakeAuth => "PostHandshakeAuth", + ExtensionType::SignatureAlgorithmsCert => "SignatureAlgorithmsCert", + ExtensionType::KeyShare => "KeyShare", + ExtensionType::RenegotiationInfo => "RenegotiationInfo", + _ => unreachable!("known DTLS 1.2 extension type missing Debug label"), + }; + + f.write_str(name) + } +} + #[cfg(test)] mod tests { use super::*; @@ -225,6 +201,40 @@ mod tests { 0x00, 0x06, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, // Extension data ]; + #[test] + fn extension_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert_eq!(ExtensionType::default().as_u16(), 0); + assert_eq!(ExtensionType::default(), ExtensionType::ServerName); + } + + #[test] + fn extension_type_wire_roundtrip() { + for extension_type in ExtensionType::supported() { + assert_eq!( + ExtensionType::from_u16(extension_type.as_u16()), + *extension_type + ); + assert!(!extension_type.is_unknown()); + } + + let unknown = ExtensionType::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn extension_type_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", ExtensionType::SupportedGroups), + "SupportedGroups" + ); + assert_eq!( + format!("{:?}", ExtensionType::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + #[test] fn roundtrip() { // Parse the message with base_offset 0 diff --git a/src/dtls12/message/handshake.rs b/src/dtls12/message/handshake.rs index 65521159..ed555c3f 100644 --- a/src/dtls12/message/handshake.rs +++ b/src/dtls12/message/handshake.rs @@ -1,3 +1,4 @@ +use std::fmt; use std::ops::Range; use std::sync::atomic::{AtomicBool, Ordering}; @@ -286,64 +287,35 @@ impl Handshake { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MessageType { - HelloRequest, // empty - ClientHello, - HelloVerifyRequest, - ServerHello, - Certificate, - ServerKeyExchange, - CertificateRequest, - ServerHelloDone, // empty - CertificateVerify, - ClientKeyExchange, - NewSessionTicket, - Finished, - Unknown(u8), -} +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct MessageType(u8); -impl Default for MessageType { - fn default() -> Self { - Self::Unknown(0) +#[allow(non_upper_case_globals)] +impl MessageType { + pub const HelloRequest: Self = Self(0); + pub const ClientHello: Self = Self(1); + pub const ServerHello: Self = Self(2); + pub const HelloVerifyRequest: Self = Self(3); + pub const NewSessionTicket: Self = Self(4); + pub const Certificate: Self = Self(11); + pub const ServerKeyExchange: Self = Self(12); + pub const CertificateRequest: Self = Self(13); + pub const ServerHelloDone: Self = Self(14); + pub const CertificateVerify: Self = Self(15); + pub const ClientKeyExchange: Self = Self(16); + pub const Finished: Self = Self(20); + + pub const fn from_u8(value: u8) -> Self { + Self(value) } -} -impl MessageType { - pub fn from_u8(value: u8) -> Self { - match value { - 0 => MessageType::HelloRequest, // empty - 1 => MessageType::ClientHello, - 3 => MessageType::HelloVerifyRequest, - 2 => MessageType::ServerHello, - 11 => MessageType::Certificate, - 12 => MessageType::ServerKeyExchange, - 13 => MessageType::CertificateRequest, - 14 => MessageType::ServerHelloDone, // empty - 15 => MessageType::CertificateVerify, - 16 => MessageType::ClientKeyExchange, - 4 => MessageType::NewSessionTicket, - 20 => MessageType::Finished, - _ => MessageType::Unknown(value), - } + pub const fn as_u8(&self) -> u8 { + self.0 } - pub fn as_u8(&self) -> u8 { - match self { - MessageType::HelloRequest => 0, - MessageType::ClientHello => 1, - MessageType::HelloVerifyRequest => 3, - MessageType::ServerHello => 2, - MessageType::Certificate => 11, - MessageType::ServerKeyExchange => 12, - MessageType::CertificateRequest => 13, - MessageType::ServerHelloDone => 14, - MessageType::CertificateVerify => 15, - MessageType::ClientKeyExchange => 16, - MessageType::NewSessionTicket => 4, - MessageType::Finished => 20, - MessageType::Unknown(value) => *value, - } + const fn is_unknown(&self) -> bool { + !matches!(*self, Self(0..=4 | 11..=16 | 20)) } pub fn parse(input: &[u8]) -> IResult<&[u8], MessageType> { @@ -352,7 +324,7 @@ impl MessageType { } pub fn epoch(&self) -> u16 { - if matches!(self, MessageType::NewSessionTicket | MessageType::Finished) { + if matches!(*self, MessageType::NewSessionTicket | MessageType::Finished) { 1 } else { 0 @@ -360,6 +332,32 @@ impl MessageType { } } +impl fmt::Debug for MessageType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + MessageType::HelloRequest => "HelloRequest", + MessageType::ClientHello => "ClientHello", + MessageType::HelloVerifyRequest => "HelloVerifyRequest", + MessageType::ServerHello => "ServerHello", + MessageType::Certificate => "Certificate", + MessageType::ServerKeyExchange => "ServerKeyExchange", + MessageType::CertificateRequest => "CertificateRequest", + MessageType::ServerHelloDone => "ServerHelloDone", + MessageType::CertificateVerify => "CertificateVerify", + MessageType::ClientKeyExchange => "ClientKeyExchange", + MessageType::NewSessionTicket => "NewSessionTicket", + MessageType::Finished => "Finished", + _ => unreachable!("known DTLS 1.2 handshake message type missing Debug label"), + }; + + f.write_str(name) + } +} + #[derive(Debug, PartialEq, Eq)] #[allow(clippy::large_enum_variant)] pub enum Body { @@ -446,7 +444,7 @@ impl Body { let (input, finished) = Finished::parse(input, cipher_suite)?; Ok((input, Body::Finished(finished))) } - MessageType::Unknown(value) => Ok((input, Body::Unknown(value))), + _ => Ok((input, Body::Unknown(m.as_u8()))), } } @@ -535,6 +533,44 @@ mod tests { 0x00, // CompressionMethod::Null ]; + #[test] + fn message_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert_eq!(MessageType::default().as_u8(), 0); + assert_eq!(MessageType::default(), MessageType::HelloRequest); + } + + #[test] + fn message_type_wire_roundtrip() { + for message_type in [ + MessageType::HelloRequest, + MessageType::ClientHello, + MessageType::ServerHello, + MessageType::HelloVerifyRequest, + MessageType::NewSessionTicket, + MessageType::Certificate, + MessageType::ServerKeyExchange, + MessageType::CertificateRequest, + MessageType::ServerHelloDone, + MessageType::CertificateVerify, + MessageType::ClientKeyExchange, + MessageType::Finished, + ] { + assert_eq!(MessageType::from_u8(message_type.as_u8()), message_type); + assert!(!message_type.is_unknown()); + } + + let unknown = MessageType::from_u8(0xFF); + assert_eq!(unknown.as_u8(), 0xFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn message_type_debug_stays_enum_like() { + assert_eq!(format!("{:?}", MessageType::ClientHello), "ClientHello"); + assert_eq!(format!("{:?}", MessageType::from_u8(0xFF)), "Unknown(255)"); + } + #[test] fn handshake_size() { let h = Handshake::new( diff --git a/src/dtls12/message/mod.rs b/src/dtls12/message/mod.rs index 6eb8ca24..d4c3a513 100644 --- a/src/dtls12/message/mod.rs +++ b/src/dtls12/message/mod.rs @@ -22,6 +22,8 @@ mod server_hello; mod server_key_exchange; mod wrapped; +use std::fmt; + use arrayvec::ArrayVec; pub use certificate::Certificate; pub use certificate_request::CertificateRequest; @@ -55,60 +57,34 @@ use nom::number::complete::{be_u8, be_u16}; pub type CipherSuiteVec = ArrayVec; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] /// Supported TLS 1.2 cipher suites for DTLS. -pub enum Dtls12CipherSuite { - // ECDHE with AES-GCM - /// ECDHE with ECDSA authentication, AES-256-GCM, SHA-384 - ECDHE_ECDSA_AES256_GCM_SHA384, // 0xC02C - /// ECDHE with ECDSA authentication, AES-128-GCM, SHA-256 - ECDHE_ECDSA_AES128_GCM_SHA256, // 0xC02B - /// ECDHE with ECDSA authentication, ChaCha20-Poly1305, SHA-256 - ECDHE_ECDSA_CHACHA20_POLY1305_SHA256, // 0xCCA9 - - // PSK cipher suites (no certificate authentication) - /// PSK with AES-128-CCM-8 (8-byte tag), SHA-256 - PSK_AES128_CCM_8, // 0xC0A8 - - /// Unknown or unsupported cipher suite by its IANA value - Unknown(u16), -} - -impl Default for Dtls12CipherSuite { - fn default() -> Self { - Self::Unknown(0) - } -} +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct Dtls12CipherSuite(u16); impl Dtls12CipherSuite { - /// Convert the 16-bit IANA value to a `Dtls12CipherSuite`. - pub fn from_u16(value: u16) -> Self { - match value { - // ECDHE with AES-GCM - 0xC02C => Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384, - 0xC02B => Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256, - 0xCCA9 => Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256, - - // PSK - 0xC0A8 => Dtls12CipherSuite::PSK_AES128_CCM_8, + /// ECDHE with ECDSA authentication, AES-256-GCM, SHA-384. + pub const ECDHE_ECDSA_AES256_GCM_SHA384: Self = Self(0xC02C); + /// ECDHE with ECDSA authentication, AES-128-GCM, SHA-256. + pub const ECDHE_ECDSA_AES128_GCM_SHA256: Self = Self(0xC02B); + /// ECDHE with ECDSA authentication, ChaCha20-Poly1305, SHA-256. + pub const ECDHE_ECDSA_CHACHA20_POLY1305_SHA256: Self = Self(0xCCA9); + /// PSK with AES-128-CCM-8 (8-byte tag), SHA-256. + pub const PSK_AES128_CCM_8: Self = Self(0xC0A8); - _ => Dtls12CipherSuite::Unknown(value), - } + /// Convert the 16-bit IANA value to a `Dtls12CipherSuite`. + pub const fn from_u16(value: u16) -> Self { + Self(value) } /// Return the 16-bit IANA value for this cipher suite. - pub fn as_u16(&self) -> u16 { - match self { - // ECDHE with AES-GCM - Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 => 0xC02C, - Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 => 0xC02B, - Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => 0xCCA9, - - Dtls12CipherSuite::PSK_AES128_CCM_8 => 0xC0A8, + pub const fn as_u16(&self) -> u16 { + self.0 + } - Dtls12CipherSuite::Unknown(value) => *value, - } + /// Returns true if this is not a known DTLS 1.2 cipher suite wire value. + pub const fn is_unknown(&self) -> bool { + !matches!(*self, Self(0xC02B..=0xC02C | 0xC0A8 | 0xCCA9)) } /// Parse a `Dtls12CipherSuite` from network byte order. @@ -119,20 +95,20 @@ impl Dtls12CipherSuite { /// Length in bytes of verify_data for Finished MACs. pub fn verify_data_length(&self) -> usize { - match self { + match *self { // AES-GCM suites Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 | Dtls12CipherSuite::PSK_AES128_CCM_8 => 12, - Dtls12CipherSuite::Unknown(_) => 12, // Default length for unknown cipher suites + _ => 12, // Default length for unknown cipher suites } } /// The key exchange algorithm family for this cipher suite. pub fn as_key_exchange_algorithm(&self) -> KeyExchangeAlgorithm { - match self { + match *self { // All ECDHE ciphers Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 @@ -142,14 +118,14 @@ impl Dtls12CipherSuite { Dtls12CipherSuite::PSK_AES128_CCM_8 => KeyExchangeAlgorithm::PSK, - Dtls12CipherSuite::Unknown(_) => KeyExchangeAlgorithm::Unknown, + _ => KeyExchangeAlgorithm::Unknown, } } /// Whether this cipher suite uses ECC-based key exchange. pub fn has_ecc(&self) -> bool { matches!( - self, + *self, Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 @@ -158,7 +134,7 @@ impl Dtls12CipherSuite { /// Whether this cipher suite uses PSK (Pre-Shared Key) key exchange. pub fn is_psk(&self) -> bool { - matches!(self, Dtls12CipherSuite::PSK_AES128_CCM_8) + matches!(*self, Dtls12CipherSuite::PSK_AES128_CCM_8) } /// All supported cipher suites in server preference order. @@ -195,12 +171,12 @@ impl Dtls12CipherSuite { /// The hash algorithm used by this cipher suite. pub fn hash_algorithm(&self) -> HashAlgorithm { - match self { + match *self { Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 => HashAlgorithm::SHA384, Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 | Dtls12CipherSuite::PSK_AES128_CCM_8 => HashAlgorithm::SHA256, - Dtls12CipherSuite::Unknown(_) => HashAlgorithm::Unknown(0), + _ => HashAlgorithm::UNKNOWN_DERIVED, } } @@ -208,14 +184,14 @@ impl Dtls12CipherSuite { /// /// Returns `None` for PSK cipher suites (no signature authentication). pub fn signature_algorithm(&self) -> Option { - match self { + match *self { Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 | Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => { Some(SignatureAlgorithm::ECDSA) } Dtls12CipherSuite::PSK_AES128_CCM_8 => None, - Dtls12CipherSuite::Unknown(_) => Some(SignatureAlgorithm::Unknown(0)), + _ => Some(SignatureAlgorithm::UNKNOWN_DERIVED), } } @@ -230,6 +206,24 @@ impl Dtls12CipherSuite { } } +impl fmt::Debug for Dtls12CipherSuite { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Dtls12CipherSuite::ECDHE_ECDSA_AES256_GCM_SHA384 => { + f.write_str("ECDHE_ECDSA_AES256_GCM_SHA384") + } + Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 => { + f.write_str("ECDHE_ECDSA_AES128_GCM_SHA256") + } + Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => { + f.write_str("ECDHE_ECDSA_CHACHA20_POLY1305_SHA256") + } + Dtls12CipherSuite::PSK_AES128_CCM_8 => f.write_str("PSK_AES128_CCM_8"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + pub type CompressionMethodVec = ArrayVec; @@ -245,39 +239,22 @@ pub enum KeyExchangeAlgorithm { pub type CertificateTypeVec = ArrayVec; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub enum ClientCertificateType { - RSA_SIGN, - DSS_SIGN, - RSA_FIXED_DH, - DSS_FIXED_DH, - RSA_EPHEMERAL_DH, - DSS_EPHEMERAL_DH, - FORTEZZA_DMS, - ECDSA_SIGN, - Unknown(u8), -} - -impl Default for ClientCertificateType { - fn default() -> Self { - Self::Unknown(0) - } -} +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct ClientCertificateType(u8); impl ClientCertificateType { - pub fn from_u8(value: u8) -> Self { - match value { - 1 => ClientCertificateType::RSA_SIGN, - 2 => ClientCertificateType::DSS_SIGN, - 3 => ClientCertificateType::RSA_FIXED_DH, - 4 => ClientCertificateType::DSS_FIXED_DH, - 5 => ClientCertificateType::RSA_EPHEMERAL_DH, - 6 => ClientCertificateType::DSS_EPHEMERAL_DH, - 20 => ClientCertificateType::FORTEZZA_DMS, - 64 => ClientCertificateType::ECDSA_SIGN, - _ => ClientCertificateType::Unknown(value), - } + pub const RSA_SIGN: Self = Self(1); + pub const DSS_SIGN: Self = Self(2); + pub const RSA_FIXED_DH: Self = Self(3); + pub const DSS_FIXED_DH: Self = Self(4); + pub const RSA_EPHEMERAL_DH: Self = Self(5); + pub const DSS_EPHEMERAL_DH: Self = Self(6); + pub const FORTEZZA_DMS: Self = Self(20); + pub const ECDSA_SIGN: Self = Self(64); + + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Returns true if this certificate type is supported by this implementation. @@ -291,18 +268,12 @@ impl ClientCertificateType { &[ClientCertificateType::ECDSA_SIGN] } - pub fn as_u8(&self) -> u8 { - match self { - ClientCertificateType::RSA_SIGN => 1, - ClientCertificateType::DSS_SIGN => 2, - ClientCertificateType::RSA_FIXED_DH => 3, - ClientCertificateType::DSS_FIXED_DH => 4, - ClientCertificateType::RSA_EPHEMERAL_DH => 5, - ClientCertificateType::DSS_EPHEMERAL_DH => 6, - ClientCertificateType::FORTEZZA_DMS => 20, - ClientCertificateType::ECDSA_SIGN => 64, - ClientCertificateType::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + const fn is_unknown(&self) -> bool { + !matches!(*self, Self(1..=6 | 20 | 64)) } pub fn parse(input: &[u8]) -> IResult<&[u8], ClientCertificateType> { @@ -311,6 +282,28 @@ impl ClientCertificateType { } } +impl fmt::Debug for ClientCertificateType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + ClientCertificateType::RSA_SIGN => "RSA_SIGN", + ClientCertificateType::DSS_SIGN => "DSS_SIGN", + ClientCertificateType::RSA_FIXED_DH => "RSA_FIXED_DH", + ClientCertificateType::DSS_FIXED_DH => "DSS_FIXED_DH", + ClientCertificateType::RSA_EPHEMERAL_DH => "RSA_EPHEMERAL_DH", + ClientCertificateType::DSS_EPHEMERAL_DH => "DSS_EPHEMERAL_DH", + ClientCertificateType::FORTEZZA_DMS => "FORTEZZA_DMS", + ClientCertificateType::ECDSA_SIGN => "ECDSA_SIGN", + _ => unreachable!("known DTLS 1.2 client certificate type missing Debug label"), + }; + + f.write_str(name) + } +} + // SignatureAlgorithm and HashAlgorithm are now in crate::types pub type SignatureAndHashAlgorithmVec = @@ -365,3 +358,93 @@ impl SignatureAndHashAlgorithm { Self::supported().contains(self) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dtls12_cipher_suite_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert_eq!(Dtls12CipherSuite::default().as_u16(), 0); + assert!(Dtls12CipherSuite::default().is_unknown()); + } + + #[test] + fn dtls12_cipher_suite_wire_roundtrip() { + for suite in Dtls12CipherSuite::all() { + assert_eq!(Dtls12CipherSuite::from_u16(suite.as_u16()), *suite); + assert!(!suite.is_unknown()); + } + + let unknown = Dtls12CipherSuite::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn dtls12_cipher_suite_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256), + "ECDHE_ECDSA_AES128_GCM_SHA256" + ); + assert_eq!( + format!("{:?}", Dtls12CipherSuite::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + + #[test] + fn client_certificate_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert_eq!(ClientCertificateType::default().as_u8(), 0); + assert!(ClientCertificateType::default().is_unknown()); + } + + #[test] + fn client_certificate_type_wire_roundtrip() { + for certificate_type in [ + ClientCertificateType::RSA_SIGN, + ClientCertificateType::DSS_SIGN, + ClientCertificateType::RSA_FIXED_DH, + ClientCertificateType::DSS_FIXED_DH, + ClientCertificateType::RSA_EPHEMERAL_DH, + ClientCertificateType::DSS_EPHEMERAL_DH, + ClientCertificateType::FORTEZZA_DMS, + ClientCertificateType::ECDSA_SIGN, + ] { + assert_eq!( + ClientCertificateType::from_u8(certificate_type.as_u8()), + certificate_type + ); + assert!(!certificate_type.is_unknown()); + } + + let unknown = ClientCertificateType::from_u8(0xFF); + assert_eq!(unknown.as_u8(), 0xFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn client_certificate_type_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", ClientCertificateType::ECDSA_SIGN), + "ECDSA_SIGN" + ); + assert_eq!( + format!("{:?}", ClientCertificateType::from_u8(0xFF)), + "Unknown(255)" + ); + } + + #[test] + fn unknown_dtls12_cipher_suite_uses_internal_derived_markers() { + let unknown = Dtls12CipherSuite::from_u16(0xFFFF); + assert!(unknown.hash_algorithm().is_unknown()); + assert!( + unknown + .signature_algorithm() + .is_some_and(|s| s.is_unknown()) + ); + } +} diff --git a/src/dtls12/message/named_group.rs b/src/dtls12/message/named_group.rs index 12b9edda..5e715dd1 100644 --- a/src/dtls12/message/named_group.rs +++ b/src/dtls12/message/named_group.rs @@ -5,41 +5,36 @@ use nom::IResult; use nom::number::complete::be_u8; +use std::fmt; /// Curve type for ECDH parameters in DTLS 1.2. /// /// This is specific to DTLS 1.2's ServerKeyExchange message format. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CurveType { +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct CurveType(u8); + +#[allow(non_upper_case_globals)] +impl CurveType { /// Explicit prime curve parameters. - ExplicitPrime, + pub const ExplicitPrime: Self = Self(1); /// Explicit characteristic-2 curve parameters. - ExplicitChar2, + pub const ExplicitChar2: Self = Self(2); /// Named curve (the common case). - NamedCurve, - /// Unknown curve type. - Unknown(u8), -} + pub const NamedCurve: Self = Self(3); -impl CurveType { /// Convert a u8 value to a `CurveType`. - pub fn from_u8(value: u8) -> Self { - match value { - 1 => CurveType::ExplicitPrime, - 2 => CurveType::ExplicitChar2, - 3 => CurveType::NamedCurve, - _ => CurveType::Unknown(value), - } + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Convert this `CurveType` to its u8 value. - pub fn as_u8(&self) -> u8 { - match self { - CurveType::ExplicitPrime => 1, - CurveType::ExplicitChar2 => 2, - CurveType::NamedCurve => 3, - CurveType::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + const fn is_unknown(&self) -> bool { + !matches!(*self, Self(1..=3)) } /// Parse a `CurveType` from wire format. @@ -48,3 +43,52 @@ impl CurveType { Ok((input, CurveType::from_u8(value))) } } + +impl fmt::Debug for CurveType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + CurveType::ExplicitPrime => "ExplicitPrime", + CurveType::ExplicitChar2 => "ExplicitChar2", + CurveType::NamedCurve => "NamedCurve", + _ => unreachable!("known DTLS 1.2 curve type missing Debug label"), + }; + + f.write_str(name) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn curve_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + } + + #[test] + fn curve_type_wire_roundtrip() { + for curve_type in [ + CurveType::ExplicitPrime, + CurveType::ExplicitChar2, + CurveType::NamedCurve, + ] { + assert_eq!(CurveType::from_u8(curve_type.as_u8()), curve_type); + assert!(!curve_type.is_unknown()); + } + + let unknown = CurveType::from_u8(0xFF); + assert_eq!(unknown.as_u8(), 0xFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn curve_type_debug_stays_enum_like() { + assert_eq!(format!("{:?}", CurveType::NamedCurve), "NamedCurve"); + assert_eq!(format!("{:?}", CurveType::from_u8(0xFF)), "Unknown(255)"); + } +} diff --git a/src/dtls12/queue.rs b/src/dtls12/queue.rs index 91994c01..b67f4dbd 100644 --- a/src/dtls12/queue.rs +++ b/src/dtls12/queue.rs @@ -54,7 +54,7 @@ impl fmt::Debug for QueueRx { ContentType::ApplicationData => app_data += 1, ContentType::Alert => alert += 1, ContentType::ChangeCipherSpec => ccs += 1, - ContentType::Unknown(_) | ContentType::Ack => other += 1, + _ => other += 1, } let seq = (record.sequence.epoch, record.sequence.sequence_number); diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index 092be3bb..a6c60639 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -565,7 +565,7 @@ impl State { // Validate cipher suite let cs = server_hello.cipher_suite; - if matches!(cs, Dtls13CipherSuite::Unknown(_)) { + if cs.is_unknown() { return Err((Error::SecurityError( crate::SecurityError::ServerSelectedUnknownCipherSuite, )) diff --git a/src/dtls13/message/extension.rs b/src/dtls13/message/extension.rs index 437927d5..4c0d4571 100644 --- a/src/dtls13/message/extension.rs +++ b/src/dtls13/message/extension.rs @@ -1,6 +1,6 @@ use crate::buffer::Buf; use nom::{IResult, bytes::complete::take, number::complete::be_u16}; -use std::ops::Range; +use std::{fmt, ops::Range}; #[derive(Debug, PartialEq, Eq, Default)] pub struct Extension { @@ -48,142 +48,64 @@ impl Extension { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ExtensionType { - ServerName, - MaxFragmentLength, - ClientCertificateUrl, - TrustedCaKeys, - TruncatedHmac, - StatusRequest, - UserMapping, - ClientAuthz, - ServerAuthz, - CertType, - SupportedGroups, - EcPointFormats, - Srp, - SignatureAlgorithms, - UseSrtp, - Heartbeat, - ApplicationLayerProtocolNegotiation, - StatusRequestV2, - SignedCertificateTimestamp, - ClientCertificateType, - ServerCertificateType, - Padding, - EncryptThenMac, - ExtendedMasterSecret, - TokenBinding, - CachedInfo, - SessionTicket, - PreSharedKey, - EarlyData, - SupportedVersions, - Cookie, - PskKeyExchangeModes, - CertificateAuthorities, - OidFilters, - PostHandshakeAuth, - SignatureAlgorithmsCert, - KeyShare, - RenegotiationInfo, - Unknown(u16), -} +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct ExtensionType(u16); -impl Default for ExtensionType { - fn default() -> Self { - Self::Unknown(0) +#[allow(non_upper_case_globals)] +impl ExtensionType { + pub const ServerName: Self = Self(0x0000); + pub const MaxFragmentLength: Self = Self(0x0001); + pub const ClientCertificateUrl: Self = Self(0x0002); + pub const TrustedCaKeys: Self = Self(0x0003); + pub const TruncatedHmac: Self = Self(0x0004); + pub const StatusRequest: Self = Self(0x0005); + pub const UserMapping: Self = Self(0x0006); + pub const ClientAuthz: Self = Self(0x0007); + pub const ServerAuthz: Self = Self(0x0008); + pub const CertType: Self = Self(0x0009); + pub const SupportedGroups: Self = Self(0x000A); + pub const EcPointFormats: Self = Self(0x000B); + pub const Srp: Self = Self(0x000C); + pub const SignatureAlgorithms: Self = Self(0x000D); + pub const UseSrtp: Self = Self(0x000E); + pub const Heartbeat: Self = Self(0x000F); + pub const ApplicationLayerProtocolNegotiation: Self = Self(0x0010); + pub const StatusRequestV2: Self = Self(0x0011); + pub const SignedCertificateTimestamp: Self = Self(0x0012); + pub const ClientCertificateType: Self = Self(0x0013); + pub const ServerCertificateType: Self = Self(0x0014); + pub const Padding: Self = Self(0x0015); + pub const EncryptThenMac: Self = Self(0x0016); + pub const ExtendedMasterSecret: Self = Self(0x0017); + pub const TokenBinding: Self = Self(0x0018); + pub const CachedInfo: Self = Self(0x0019); + pub const SessionTicket: Self = Self(0x0023); + pub const PreSharedKey: Self = Self(0x0029); + pub const EarlyData: Self = Self(0x002A); + pub const SupportedVersions: Self = Self(0x002B); + pub const Cookie: Self = Self(0x002C); + pub const PskKeyExchangeModes: Self = Self(0x002D); + pub const CertificateAuthorities: Self = Self(0x002F); + pub const OidFilters: Self = Self(0x0030); + pub const PostHandshakeAuth: Self = Self(0x0031); + pub const SignatureAlgorithmsCert: Self = Self(0x0032); + pub const KeyShare: Self = Self(0x0033); + pub const RenegotiationInfo: Self = Self(0xFF01); + + pub const fn from_u16(value: u16) -> Self { + Self(value) } -} -impl ExtensionType { - pub fn from_u16(value: u16) -> Self { - match value { - 0x0000 => ExtensionType::ServerName, - 0x0001 => ExtensionType::MaxFragmentLength, - 0x0002 => ExtensionType::ClientCertificateUrl, - 0x0003 => ExtensionType::TrustedCaKeys, - 0x0004 => ExtensionType::TruncatedHmac, - 0x0005 => ExtensionType::StatusRequest, - 0x0006 => ExtensionType::UserMapping, - 0x0007 => ExtensionType::ClientAuthz, - 0x0008 => ExtensionType::ServerAuthz, - 0x0009 => ExtensionType::CertType, - 0x000A => ExtensionType::SupportedGroups, - 0x000B => ExtensionType::EcPointFormats, - 0x000C => ExtensionType::Srp, - 0x000D => ExtensionType::SignatureAlgorithms, - 0x000E => ExtensionType::UseSrtp, - 0x000F => ExtensionType::Heartbeat, - 0x0010 => ExtensionType::ApplicationLayerProtocolNegotiation, - 0x0011 => ExtensionType::StatusRequestV2, - 0x0012 => ExtensionType::SignedCertificateTimestamp, - 0x0013 => ExtensionType::ClientCertificateType, - 0x0014 => ExtensionType::ServerCertificateType, - 0x0015 => ExtensionType::Padding, - 0x0016 => ExtensionType::EncryptThenMac, - 0x0017 => ExtensionType::ExtendedMasterSecret, - 0x0018 => ExtensionType::TokenBinding, - 0x0019 => ExtensionType::CachedInfo, - 0x0023 => ExtensionType::SessionTicket, - 0x0029 => ExtensionType::PreSharedKey, - 0x002A => ExtensionType::EarlyData, - 0x002B => ExtensionType::SupportedVersions, - 0x002C => ExtensionType::Cookie, - 0x002D => ExtensionType::PskKeyExchangeModes, - 0x002F => ExtensionType::CertificateAuthorities, - 0x0030 => ExtensionType::OidFilters, - 0x0031 => ExtensionType::PostHandshakeAuth, - 0x0032 => ExtensionType::SignatureAlgorithmsCert, - 0x0033 => ExtensionType::KeyShare, - 0xFF01 => ExtensionType::RenegotiationInfo, - _ => ExtensionType::Unknown(value), - } + pub const fn as_u16(&self) -> u16 { + self.0 } - pub fn as_u16(&self) -> u16 { - match self { - ExtensionType::ServerName => 0x0000, - ExtensionType::MaxFragmentLength => 0x0001, - ExtensionType::ClientCertificateUrl => 0x0002, - ExtensionType::TrustedCaKeys => 0x0003, - ExtensionType::TruncatedHmac => 0x0004, - ExtensionType::StatusRequest => 0x0005, - ExtensionType::UserMapping => 0x0006, - ExtensionType::ClientAuthz => 0x0007, - ExtensionType::ServerAuthz => 0x0008, - ExtensionType::CertType => 0x0009, - ExtensionType::SupportedGroups => 0x000A, - ExtensionType::EcPointFormats => 0x000B, - ExtensionType::Srp => 0x000C, - ExtensionType::SignatureAlgorithms => 0x000D, - ExtensionType::UseSrtp => 0x000E, - ExtensionType::Heartbeat => 0x000F, - ExtensionType::ApplicationLayerProtocolNegotiation => 0x0010, - ExtensionType::StatusRequestV2 => 0x0011, - ExtensionType::SignedCertificateTimestamp => 0x0012, - ExtensionType::ClientCertificateType => 0x0013, - ExtensionType::ServerCertificateType => 0x0014, - ExtensionType::Padding => 0x0015, - ExtensionType::EncryptThenMac => 0x0016, - ExtensionType::ExtendedMasterSecret => 0x0017, - ExtensionType::TokenBinding => 0x0018, - ExtensionType::CachedInfo => 0x0019, - ExtensionType::SessionTicket => 0x0023, - ExtensionType::PreSharedKey => 0x0029, - ExtensionType::EarlyData => 0x002A, - ExtensionType::SupportedVersions => 0x002B, - ExtensionType::Cookie => 0x002C, - ExtensionType::PskKeyExchangeModes => 0x002D, - ExtensionType::CertificateAuthorities => 0x002F, - ExtensionType::OidFilters => 0x0030, - ExtensionType::PostHandshakeAuth => 0x0031, - ExtensionType::SignatureAlgorithmsCert => 0x0032, - ExtensionType::KeyShare => 0x0033, - ExtensionType::RenegotiationInfo => 0xFF01, - ExtensionType::Unknown(value) => *value, - } + const fn is_unknown(&self) -> bool { + !matches!( + *self, + Self(0x0000..=0x0019 | 0x0023 | 0x0029..=0x002D | 0x002F..=0x0033 | 0xFF01) + ) } pub fn parse(input: &[u8]) -> IResult<&[u8], ExtensionType> { @@ -209,6 +131,60 @@ impl ExtensionType { } } +impl fmt::Debug for ExtensionType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + ExtensionType::ServerName => "ServerName", + ExtensionType::MaxFragmentLength => "MaxFragmentLength", + ExtensionType::ClientCertificateUrl => "ClientCertificateUrl", + ExtensionType::TrustedCaKeys => "TrustedCaKeys", + ExtensionType::TruncatedHmac => "TruncatedHmac", + ExtensionType::StatusRequest => "StatusRequest", + ExtensionType::UserMapping => "UserMapping", + ExtensionType::ClientAuthz => "ClientAuthz", + ExtensionType::ServerAuthz => "ServerAuthz", + ExtensionType::CertType => "CertType", + ExtensionType::SupportedGroups => "SupportedGroups", + ExtensionType::EcPointFormats => "EcPointFormats", + ExtensionType::Srp => "Srp", + ExtensionType::SignatureAlgorithms => "SignatureAlgorithms", + ExtensionType::UseSrtp => "UseSrtp", + ExtensionType::Heartbeat => "Heartbeat", + ExtensionType::ApplicationLayerProtocolNegotiation => { + "ApplicationLayerProtocolNegotiation" + } + ExtensionType::StatusRequestV2 => "StatusRequestV2", + ExtensionType::SignedCertificateTimestamp => "SignedCertificateTimestamp", + ExtensionType::ClientCertificateType => "ClientCertificateType", + ExtensionType::ServerCertificateType => "ServerCertificateType", + ExtensionType::Padding => "Padding", + ExtensionType::EncryptThenMac => "EncryptThenMac", + ExtensionType::ExtendedMasterSecret => "ExtendedMasterSecret", + ExtensionType::TokenBinding => "TokenBinding", + ExtensionType::CachedInfo => "CachedInfo", + ExtensionType::SessionTicket => "SessionTicket", + ExtensionType::PreSharedKey => "PreSharedKey", + ExtensionType::EarlyData => "EarlyData", + ExtensionType::SupportedVersions => "SupportedVersions", + ExtensionType::Cookie => "Cookie", + ExtensionType::PskKeyExchangeModes => "PskKeyExchangeModes", + ExtensionType::CertificateAuthorities => "CertificateAuthorities", + ExtensionType::OidFilters => "OidFilters", + ExtensionType::PostHandshakeAuth => "PostHandshakeAuth", + ExtensionType::SignatureAlgorithmsCert => "SignatureAlgorithmsCert", + ExtensionType::KeyShare => "KeyShare", + ExtensionType::RenegotiationInfo => "RenegotiationInfo", + _ => unreachable!("known DTLS 1.3 extension type missing Debug label"), + }; + + f.write_str(name) + } +} + #[cfg(test)] mod tests { use super::*; @@ -220,6 +196,40 @@ mod tests { 0x00, 0x06, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, // Extension data ]; + #[test] + fn extension_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert_eq!(ExtensionType::default().as_u16(), 0); + assert_eq!(ExtensionType::default(), ExtensionType::ServerName); + } + + #[test] + fn extension_type_wire_roundtrip() { + for extension_type in ExtensionType::supported() { + assert_eq!( + ExtensionType::from_u16(extension_type.as_u16()), + *extension_type + ); + assert!(!extension_type.is_unknown()); + } + + let unknown = ExtensionType::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn extension_type_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", ExtensionType::SupportedGroups), + "SupportedGroups" + ); + assert_eq!( + format!("{:?}", ExtensionType::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + #[test] fn roundtrip() { // Parse the message with base_offset 0 diff --git a/src/dtls13/message/extensions/supported_versions.rs b/src/dtls13/message/extensions/supported_versions.rs index 13e43c83..cb33a704 100644 --- a/src/dtls13/message/extensions/supported_versions.rs +++ b/src/dtls13/message/extensions/supported_versions.rs @@ -29,7 +29,7 @@ impl SupportedVersionsClientHello { let mut rest = versions_data; while !rest.is_empty() { let (r, version) = ProtocolVersion::parse(rest)?; - if !matches!(version, ProtocolVersion::Unknown(_)) { + if !version.is_unknown() { versions .try_push(version) .map_err(|_| Err::Failure(Error::new(rest, ErrorKind::LengthValue)))?; diff --git a/src/dtls13/message/handshake.rs b/src/dtls13/message/handshake.rs index 612e195d..7896bcf0 100644 --- a/src/dtls13/message/handshake.rs +++ b/src/dtls13/message/handshake.rs @@ -1,3 +1,4 @@ +use std::fmt; use std::ops::Range; use std::sync::atomic::{AtomicBool, Ordering}; @@ -266,52 +267,31 @@ impl Handshake { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MessageType { - ClientHello, - ServerHello, - EncryptedExtensions, - Certificate, - CertificateRequest, - CertificateVerify, - Finished, - KeyUpdate, - Unknown(u8), -} +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct MessageType(u8); -impl Default for MessageType { - fn default() -> Self { - Self::Unknown(0) +#[allow(non_upper_case_globals)] +impl MessageType { + pub const ClientHello: Self = Self(1); + pub const ServerHello: Self = Self(2); + pub const EncryptedExtensions: Self = Self(8); + pub const Certificate: Self = Self(11); + pub const CertificateRequest: Self = Self(13); + pub const CertificateVerify: Self = Self(15); + pub const Finished: Self = Self(20); + pub const KeyUpdate: Self = Self(24); + + pub const fn from_u8(value: u8) -> Self { + Self(value) } -} -impl MessageType { - pub fn from_u8(value: u8) -> Self { - match value { - 1 => MessageType::ClientHello, - 2 => MessageType::ServerHello, - 8 => MessageType::EncryptedExtensions, - 11 => MessageType::Certificate, - 13 => MessageType::CertificateRequest, - 15 => MessageType::CertificateVerify, - 20 => MessageType::Finished, - 24 => MessageType::KeyUpdate, - _ => MessageType::Unknown(value), - } + pub const fn as_u8(&self) -> u8 { + self.0 } - pub fn as_u8(&self) -> u8 { - match self { - MessageType::ClientHello => 1, - MessageType::ServerHello => 2, - MessageType::EncryptedExtensions => 8, - MessageType::Certificate => 11, - MessageType::CertificateRequest => 13, - MessageType::CertificateVerify => 15, - MessageType::Finished => 20, - MessageType::KeyUpdate => 24, - MessageType::Unknown(value) => *value, - } + const fn is_unknown(&self) -> bool { + !matches!(*self, Self(1..=2 | 8 | 11 | 13 | 15 | 20 | 24)) } pub fn parse(input: &[u8]) -> IResult<&[u8], MessageType> { @@ -320,6 +300,28 @@ impl MessageType { } } +impl fmt::Debug for MessageType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_unknown() { + return f.debug_tuple("Unknown").field(&self.0).finish(); + } + + let name = match *self { + MessageType::ClientHello => "ClientHello", + MessageType::ServerHello => "ServerHello", + MessageType::EncryptedExtensions => "EncryptedExtensions", + MessageType::Certificate => "Certificate", + MessageType::CertificateRequest => "CertificateRequest", + MessageType::CertificateVerify => "CertificateVerify", + MessageType::Finished => "Finished", + MessageType::KeyUpdate => "KeyUpdate", + _ => unreachable!("known DTLS 1.3 handshake message type missing Debug label"), + }; + + f.write_str(name) + } +} + #[allow(clippy::large_enum_variant)] #[derive(Debug, PartialEq, Eq)] pub enum Body { @@ -436,7 +438,7 @@ impl Body { .ok_or_else(|| Err::Failure(Error::new(input, ErrorKind::Fail)))?; Ok((input, Body::KeyUpdate(request))) } - MessageType::Unknown(value) => Ok((input, Body::Unknown(value))), + _ => Ok((input, Body::Unknown(m.as_u8()))), } } @@ -508,6 +510,40 @@ mod tests { 0x00, // Null ]; + #[test] + fn message_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert_eq!(MessageType::default().as_u8(), 0); + assert!(MessageType::default().is_unknown()); + } + + #[test] + fn message_type_wire_roundtrip() { + for message_type in [ + MessageType::ClientHello, + MessageType::ServerHello, + MessageType::EncryptedExtensions, + MessageType::Certificate, + MessageType::CertificateRequest, + MessageType::CertificateVerify, + MessageType::Finished, + MessageType::KeyUpdate, + ] { + assert_eq!(MessageType::from_u8(message_type.as_u8()), message_type); + assert!(!message_type.is_unknown()); + } + + let unknown = MessageType::from_u8(0xFF); + assert_eq!(unknown.as_u8(), 0xFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn message_type_debug_stays_enum_like() { + assert_eq!(format!("{:?}", MessageType::ClientHello), "ClientHello"); + assert_eq!(format!("{:?}", MessageType::from_u8(0xFF)), "Unknown(255)"); + } + #[test] fn handshake_size() { let h = Handshake::new( diff --git a/src/dtls13/queue.rs b/src/dtls13/queue.rs index eeeea1c8..8fd1ff98 100644 --- a/src/dtls13/queue.rs +++ b/src/dtls13/queue.rs @@ -52,9 +52,7 @@ impl fmt::Debug for QueueRx { ContentType::Handshake => handshake += 1, ContentType::ApplicationData => app_data += 1, ContentType::Alert => alert += 1, - ContentType::Unknown(_) | ContentType::ChangeCipherSpec | ContentType::Ack => { - other += 1 - } + _ => other += 1, } let seq = (record.sequence.epoch, record.sequence.sequence_number); diff --git a/src/types.rs b/src/types.rs index 48a9c2f1..08a1a32b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -81,134 +81,80 @@ impl Random { /// /// Used for Elliptic Curve Diffie-Hellman Ephemeral (ECDHE) key exchange. /// The same named groups are used in both DTLS 1.2 and DTLS 1.3. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[non_exhaustive] -pub enum NamedGroup { +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct NamedGroup(u16); + +#[allow(non_upper_case_globals)] +impl NamedGroup { /// sect163k1 (deprecated). - Sect163k1, + pub const Sect163k1: Self = Self(1); /// sect163r1 (deprecated). - Sect163r1, + pub const Sect163r1: Self = Self(2); /// sect163r2 (deprecated). - Sect163r2, + pub const Sect163r2: Self = Self(3); /// sect193r1 (deprecated). - Sect193r1, + pub const Sect193r1: Self = Self(4); /// sect193r2 (deprecated). - Sect193r2, + pub const Sect193r2: Self = Self(5); /// sect233k1 (deprecated). - Sect233k1, + pub const Sect233k1: Self = Self(6); /// sect233r1 (deprecated). - Sect233r1, + pub const Sect233r1: Self = Self(7); /// sect239k1 (deprecated). - Sect239k1, + pub const Sect239k1: Self = Self(8); /// sect283k1 (deprecated). - Sect283k1, + pub const Sect283k1: Self = Self(9); /// sect283r1 (deprecated). - Sect283r1, + pub const Sect283r1: Self = Self(10); /// sect409k1 (deprecated). - Sect409k1, + pub const Sect409k1: Self = Self(11); /// sect409r1 (deprecated). - Sect409r1, + pub const Sect409r1: Self = Self(12); /// sect571k1 (deprecated). - Sect571k1, + pub const Sect571k1: Self = Self(13); /// sect571r1 (deprecated). - Sect571r1, + pub const Sect571r1: Self = Self(14); /// secp160k1 (deprecated). - Secp160k1, + pub const Secp160k1: Self = Self(15); /// secp160r1 (deprecated). - Secp160r1, + pub const Secp160r1: Self = Self(16); /// secp160r2 (deprecated). - Secp160r2, + pub const Secp160r2: Self = Self(17); /// secp192k1 (deprecated). - Secp192k1, + pub const Secp192k1: Self = Self(18); /// secp192r1 (deprecated). - Secp192r1, + pub const Secp192r1: Self = Self(19); /// secp224k1. - Secp224k1, + pub const Secp224k1: Self = Self(20); /// secp224r1. - Secp224r1, + pub const Secp224r1: Self = Self(21); /// secp256k1. - Secp256k1, + pub const Secp256k1: Self = Self(22); /// secp256r1 / P-256 (supported by dimpl). - Secp256r1, + pub const Secp256r1: Self = Self(23); /// secp384r1 / P-384 (supported by dimpl). - Secp384r1, + pub const Secp384r1: Self = Self(24); /// secp521r1 / P-521. - Secp521r1, + pub const Secp521r1: Self = Self(25); /// X25519 (Curve25519 for ECDHE). - X25519, + pub const X25519: Self = Self(29); /// X448 (Curve448 for ECDHE). - X448, - /// Unknown or unsupported group. - Unknown(u16), -} + pub const X448: Self = Self(30); -impl NamedGroup { /// Convert a wire format u16 value to a `NamedGroup`. - pub fn from_u16(value: u16) -> Self { - match value { - 1 => NamedGroup::Sect163k1, - 2 => NamedGroup::Sect163r1, - 3 => NamedGroup::Sect163r2, - 4 => NamedGroup::Sect193r1, - 5 => NamedGroup::Sect193r2, - 6 => NamedGroup::Sect233k1, - 7 => NamedGroup::Sect233r1, - 8 => NamedGroup::Sect239k1, - 9 => NamedGroup::Sect283k1, - 10 => NamedGroup::Sect283r1, - 11 => NamedGroup::Sect409k1, - 12 => NamedGroup::Sect409r1, - 13 => NamedGroup::Sect571k1, - 14 => NamedGroup::Sect571r1, - 15 => NamedGroup::Secp160k1, - 16 => NamedGroup::Secp160r1, - 17 => NamedGroup::Secp160r2, - 18 => NamedGroup::Secp192k1, - 19 => NamedGroup::Secp192r1, - 20 => NamedGroup::Secp224k1, - 21 => NamedGroup::Secp224r1, - 22 => NamedGroup::Secp256k1, - 23 => NamedGroup::Secp256r1, - 24 => NamedGroup::Secp384r1, - 25 => NamedGroup::Secp521r1, - 29 => NamedGroup::X25519, - 30 => NamedGroup::X448, - _ => NamedGroup::Unknown(value), - } + pub const fn from_u16(value: u16) -> Self { + Self(value) } /// Convert this `NamedGroup` to its wire format u16 value. - pub fn as_u16(&self) -> u16 { - match self { - NamedGroup::Sect163k1 => 1, - NamedGroup::Sect163r1 => 2, - NamedGroup::Sect163r2 => 3, - NamedGroup::Sect193r1 => 4, - NamedGroup::Sect193r2 => 5, - NamedGroup::Sect233k1 => 6, - NamedGroup::Sect233r1 => 7, - NamedGroup::Sect239k1 => 8, - NamedGroup::Sect283k1 => 9, - NamedGroup::Sect283r1 => 10, - NamedGroup::Sect409k1 => 11, - NamedGroup::Sect409r1 => 12, - NamedGroup::Sect571k1 => 13, - NamedGroup::Sect571r1 => 14, - NamedGroup::Secp160k1 => 15, - NamedGroup::Secp160r1 => 16, - NamedGroup::Secp160r2 => 17, - NamedGroup::Secp192k1 => 18, - NamedGroup::Secp192r1 => 19, - NamedGroup::Secp224k1 => 20, - NamedGroup::Secp224r1 => 21, - NamedGroup::Secp256k1 => 22, - NamedGroup::Secp256r1 => 23, - NamedGroup::Secp384r1 => 24, - NamedGroup::Secp521r1 => 25, - NamedGroup::X25519 => 29, - NamedGroup::X448 => 30, - NamedGroup::Unknown(value) => *value, - } + pub const fn as_u16(&self) -> u16 { + self.0 + } + + /// Returns true if this is not a known TLS named group wire value. + pub const fn is_unknown(&self) -> bool { + !matches!(*self, Self(1..=25 | 29..=30)) } /// Parse a `NamedGroup` from wire format. @@ -266,6 +212,41 @@ impl NamedGroup { } } +impl fmt::Debug for NamedGroup { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + NamedGroup::Sect163k1 => f.write_str("Sect163k1"), + NamedGroup::Sect163r1 => f.write_str("Sect163r1"), + NamedGroup::Sect163r2 => f.write_str("Sect163r2"), + NamedGroup::Sect193r1 => f.write_str("Sect193r1"), + NamedGroup::Sect193r2 => f.write_str("Sect193r2"), + NamedGroup::Sect233k1 => f.write_str("Sect233k1"), + NamedGroup::Sect233r1 => f.write_str("Sect233r1"), + NamedGroup::Sect239k1 => f.write_str("Sect239k1"), + NamedGroup::Sect283k1 => f.write_str("Sect283k1"), + NamedGroup::Sect283r1 => f.write_str("Sect283r1"), + NamedGroup::Sect409k1 => f.write_str("Sect409k1"), + NamedGroup::Sect409r1 => f.write_str("Sect409r1"), + NamedGroup::Sect571k1 => f.write_str("Sect571k1"), + NamedGroup::Sect571r1 => f.write_str("Sect571r1"), + NamedGroup::Secp160k1 => f.write_str("Secp160k1"), + NamedGroup::Secp160r1 => f.write_str("Secp160r1"), + NamedGroup::Secp160r2 => f.write_str("Secp160r2"), + NamedGroup::Secp192k1 => f.write_str("Secp192k1"), + NamedGroup::Secp192r1 => f.write_str("Secp192r1"), + NamedGroup::Secp224k1 => f.write_str("Secp224k1"), + NamedGroup::Secp224r1 => f.write_str("Secp224r1"), + NamedGroup::Secp256k1 => f.write_str("Secp256k1"), + NamedGroup::Secp256r1 => f.write_str("Secp256r1"), + NamedGroup::Secp384r1 => f.write_str("Secp384r1"), + NamedGroup::Secp521r1 => f.write_str("Secp521r1"), + NamedGroup::X25519 => f.write_str("X25519"), + NamedGroup::X448 => f.write_str("X448"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + // ============================================================================ // Hash Algorithms // ============================================================================ @@ -274,60 +255,48 @@ impl NamedGroup { /// /// Specifies the hash algorithm to be used in digital signatures, /// PRF/HKDF operations, and transcript hashing. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub enum HashAlgorithm { +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct HashAlgorithm(u8); + +impl Default for HashAlgorithm { + fn default() -> Self { + Self::None + } +} + +#[allow(non_upper_case_globals)] +impl HashAlgorithm { /// No hash (not typically used). - None, + pub const None: Self = Self(0); /// MD5 hash (deprecated, not supported). - MD5, + pub const MD5: Self = Self(1); /// SHA-1 hash (deprecated, not supported). - SHA1, + pub const SHA1: Self = Self(2); /// SHA-224 hash. - SHA224, + pub const SHA224: Self = Self(3); /// SHA-256 hash (supported by dimpl). - SHA256, + pub const SHA256: Self = Self(4); /// SHA-384 hash (supported by dimpl). - SHA384, + pub const SHA384: Self = Self(5); /// SHA-512 hash. - SHA512, - /// Unknown or unsupported hash algorithm. - Unknown(u8), -} + pub const SHA512: Self = Self(6); -impl Default for HashAlgorithm { - fn default() -> Self { - Self::Unknown(0) - } -} + pub(crate) const UNKNOWN_DERIVED: Self = Self(u8::MAX); -impl HashAlgorithm { /// Convert a wire format u8 value to a `HashAlgorithm`. - pub fn from_u8(value: u8) -> Self { - match value { - 0 => HashAlgorithm::None, - 1 => HashAlgorithm::MD5, - 2 => HashAlgorithm::SHA1, - 3 => HashAlgorithm::SHA224, - 4 => HashAlgorithm::SHA256, - 5 => HashAlgorithm::SHA384, - 6 => HashAlgorithm::SHA512, - _ => HashAlgorithm::Unknown(value), - } + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Convert this `HashAlgorithm` to its wire format u8 value. - pub fn as_u8(&self) -> u8 { - match self { - HashAlgorithm::None => 0, - HashAlgorithm::MD5 => 1, - HashAlgorithm::SHA1 => 2, - HashAlgorithm::SHA224 => 3, - HashAlgorithm::SHA256 => 4, - HashAlgorithm::SHA384 => 5, - HashAlgorithm::SHA512 => 6, - HashAlgorithm::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + /// Returns true if this is not a known DTLS hash algorithm wire value. + pub const fn is_unknown(&self) -> bool { + self.0 > Self::SHA512.0 } /// Parse a `HashAlgorithm` from wire format. @@ -337,8 +306,8 @@ impl HashAlgorithm { } /// Returns the output length in bytes for this hash algorithm. - pub fn output_len(&self) -> usize { - match self { + pub const fn output_len(&self) -> usize { + match *self { HashAlgorithm::None => 0, HashAlgorithm::MD5 => 16, HashAlgorithm::SHA1 => 20, @@ -346,7 +315,22 @@ impl HashAlgorithm { HashAlgorithm::SHA256 => 32, HashAlgorithm::SHA384 => 48, HashAlgorithm::SHA512 => 64, - HashAlgorithm::Unknown(_) => 0, + _ => 0, + } + } +} + +impl fmt::Debug for HashAlgorithm { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + HashAlgorithm::None => f.write_str("None"), + HashAlgorithm::MD5 => f.write_str("MD5"), + HashAlgorithm::SHA1 => f.write_str("SHA1"), + HashAlgorithm::SHA224 => f.write_str("SHA224"), + HashAlgorithm::SHA256 => f.write_str("SHA256"), + HashAlgorithm::SHA384 => f.write_str("SHA384"), + HashAlgorithm::SHA512 => f.write_str("SHA512"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), } } } @@ -359,48 +343,42 @@ impl HashAlgorithm { /// /// Represents the underlying signature primitive (RSA, ECDSA, etc.). /// Used internally for signing operations across both DTLS versions. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub enum SignatureAlgorithm { - /// Anonymous (no certificate). - Anonymous, - /// RSA signatures. - RSA, - /// DSA signatures. - DSA, - /// ECDSA signatures. - ECDSA, - /// Unknown or unsupported signature algorithm. - Unknown(u8), -} +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct SignatureAlgorithm(u8); impl Default for SignatureAlgorithm { fn default() -> Self { - Self::Unknown(0) + Self::Anonymous } } +#[allow(non_upper_case_globals)] impl SignatureAlgorithm { + /// Anonymous (no certificate). + pub const Anonymous: Self = Self(0); + /// RSA signatures. + pub const RSA: Self = Self(1); + /// DSA signatures. + pub const DSA: Self = Self(2); + /// ECDSA signatures. + pub const ECDSA: Self = Self(3); + + pub(crate) const UNKNOWN_DERIVED: Self = Self(u8::MAX); + /// Convert an 8-bit value into a `SignatureAlgorithm`. - pub fn from_u8(value: u8) -> Self { - match value { - 0 => SignatureAlgorithm::Anonymous, - 1 => SignatureAlgorithm::RSA, - 2 => SignatureAlgorithm::DSA, - 3 => SignatureAlgorithm::ECDSA, - _ => SignatureAlgorithm::Unknown(value), - } + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Convert this `SignatureAlgorithm` into its 8-bit representation. - pub fn as_u8(&self) -> u8 { - match self { - SignatureAlgorithm::Anonymous => 0, - SignatureAlgorithm::RSA => 1, - SignatureAlgorithm::DSA => 2, - SignatureAlgorithm::ECDSA => 3, - SignatureAlgorithm::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + /// Returns true if this is not a known DTLS signature algorithm wire value. + pub const fn is_unknown(&self) -> bool { + self.0 > Self::ECDSA.0 } /// Parse a `SignatureAlgorithm` from network bytes. @@ -410,6 +388,18 @@ impl SignatureAlgorithm { } } +impl fmt::Debug for SignatureAlgorithm { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + SignatureAlgorithm::Anonymous => f.write_str("Anonymous"), + SignatureAlgorithm::RSA => f.write_str("RSA"), + SignatureAlgorithm::DSA => f.write_str("DSA"), + SignatureAlgorithm::ECDSA => f.write_str("ECDSA"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + // ============================================================================ // Content Type // ============================================================================ @@ -418,51 +408,36 @@ impl SignatureAlgorithm { /// /// Identifies the type of data in a DTLS record. These values are the same /// for both DTLS 1.2 and DTLS 1.3. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ContentType { +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct ContentType(u8); + +#[allow(non_upper_case_globals)] +impl ContentType { /// Change Cipher Spec (used in DTLS 1.2, compatibility-only in 1.3). - ChangeCipherSpec, + pub const ChangeCipherSpec: Self = Self(20); /// Alert message. - Alert, + pub const Alert: Self = Self(21); /// Handshake message. - Handshake, + pub const Handshake: Self = Self(22); /// Application data. - ApplicationData, + pub const ApplicationData: Self = Self(23); /// ACK (DTLS 1.3 only, RFC 9147 Section 7). - Ack, - /// Unknown content type. - Unknown(u8), -} + pub const Ack: Self = Self(26); -impl Default for ContentType { - fn default() -> Self { - Self::Unknown(0) - } -} - -impl ContentType { /// Convert a u8 value to a `ContentType`. - pub fn from_u8(value: u8) -> Self { - match value { - 20 => ContentType::ChangeCipherSpec, - 21 => ContentType::Alert, - 22 => ContentType::Handshake, - 23 => ContentType::ApplicationData, - 26 => ContentType::Ack, - _ => ContentType::Unknown(value), - } + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Convert this `ContentType` to its u8 value. - pub fn as_u8(&self) -> u8 { - match self { - ContentType::ChangeCipherSpec => 20, - ContentType::Alert => 21, - ContentType::Handshake => 22, - ContentType::ApplicationData => 23, - ContentType::Ack => 26, - ContentType::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + /// Returns true if this is not a known DTLS record content type. + pub const fn is_unknown(&self) -> bool { + !matches!(*self, Self(20..=23 | 26)) } /// Parse a `ContentType` from wire format. @@ -472,6 +447,19 @@ impl ContentType { } } +impl fmt::Debug for ContentType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + ContentType::ChangeCipherSpec => f.write_str("ChangeCipherSpec"), + ContentType::Alert => f.write_str("Alert"), + ContentType::Handshake => f.write_str("Handshake"), + ContentType::ApplicationData => f.write_str("ApplicationData"), + ContentType::Ack => f.write_str("Ack"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + // ============================================================================ // Sequence Number // ============================================================================ @@ -535,83 +523,56 @@ impl PartialOrd for Sequence { /// In TLS 1.3, signature schemes combine the signature algorithm with the /// hash algorithm into a single identifier, unlike TLS 1.2 where they were /// separate. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -#[non_exhaustive] -pub enum SignatureScheme { +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct SignatureScheme(u16); + +impl SignatureScheme { /// ECDSA with P-256 and SHA-256. - ECDSA_SECP256R1_SHA256, + pub const ECDSA_SECP256R1_SHA256: Self = Self(0x0403); /// ECDSA with P-384 and SHA-384. - ECDSA_SECP384R1_SHA384, + pub const ECDSA_SECP384R1_SHA384: Self = Self(0x0503); /// ECDSA with P-521 and SHA-512. - ECDSA_SECP521R1_SHA512, + pub const ECDSA_SECP521R1_SHA512: Self = Self(0x0603); /// Ed25519. - ED25519, + pub const ED25519: Self = Self(0x0807); /// Ed448. - ED448, + pub const ED448: Self = Self(0x0808); /// RSA-PSS with SHA-256 (rsaEncryption OID). - RSA_PSS_RSAE_SHA256, + pub const RSA_PSS_RSAE_SHA256: Self = Self(0x0804); /// RSA-PSS with SHA-384 (rsaEncryption OID). - RSA_PSS_RSAE_SHA384, + pub const RSA_PSS_RSAE_SHA384: Self = Self(0x0805); /// RSA-PSS with SHA-512 (rsaEncryption OID). - RSA_PSS_RSAE_SHA512, + pub const RSA_PSS_RSAE_SHA512: Self = Self(0x0806); /// RSA-PSS with SHA-256 (id-rsassa-pss OID). - RSA_PSS_PSS_SHA256, + pub const RSA_PSS_PSS_SHA256: Self = Self(0x0809); /// RSA-PSS with SHA-384 (id-rsassa-pss OID). - RSA_PSS_PSS_SHA384, + pub const RSA_PSS_PSS_SHA384: Self = Self(0x080a); /// RSA-PSS with SHA-512 (id-rsassa-pss OID). - RSA_PSS_PSS_SHA512, + pub const RSA_PSS_PSS_SHA512: Self = Self(0x080b); /// RSA PKCS#1 v1.5 with SHA-256 (legacy). - RSA_PKCS1_SHA256, + pub const RSA_PKCS1_SHA256: Self = Self(0x0401); /// RSA PKCS#1 v1.5 with SHA-384 (legacy). - RSA_PKCS1_SHA384, + pub const RSA_PKCS1_SHA384: Self = Self(0x0501); /// RSA PKCS#1 v1.5 with SHA-512 (legacy). - RSA_PKCS1_SHA512, - /// Unknown or unsupported signature scheme. - Unknown(u16), -} + pub const RSA_PKCS1_SHA512: Self = Self(0x0601); -impl SignatureScheme { /// Convert a wire format u16 value to a `SignatureScheme`. - pub fn from_u16(value: u16) -> Self { - match value { - 0x0403 => SignatureScheme::ECDSA_SECP256R1_SHA256, - 0x0503 => SignatureScheme::ECDSA_SECP384R1_SHA384, - 0x0603 => SignatureScheme::ECDSA_SECP521R1_SHA512, - 0x0807 => SignatureScheme::ED25519, - 0x0808 => SignatureScheme::ED448, - 0x0804 => SignatureScheme::RSA_PSS_RSAE_SHA256, - 0x0805 => SignatureScheme::RSA_PSS_RSAE_SHA384, - 0x0806 => SignatureScheme::RSA_PSS_RSAE_SHA512, - 0x0809 => SignatureScheme::RSA_PSS_PSS_SHA256, - 0x080a => SignatureScheme::RSA_PSS_PSS_SHA384, - 0x080b => SignatureScheme::RSA_PSS_PSS_SHA512, - 0x0401 => SignatureScheme::RSA_PKCS1_SHA256, - 0x0501 => SignatureScheme::RSA_PKCS1_SHA384, - 0x0601 => SignatureScheme::RSA_PKCS1_SHA512, - _ => SignatureScheme::Unknown(value), - } + pub const fn from_u16(value: u16) -> Self { + Self(value) } /// Convert this `SignatureScheme` to its wire format u16 value. - pub fn as_u16(&self) -> u16 { - match self { - SignatureScheme::ECDSA_SECP256R1_SHA256 => 0x0403, - SignatureScheme::ECDSA_SECP384R1_SHA384 => 0x0503, - SignatureScheme::ECDSA_SECP521R1_SHA512 => 0x0603, - SignatureScheme::ED25519 => 0x0807, - SignatureScheme::ED448 => 0x0808, - SignatureScheme::RSA_PSS_RSAE_SHA256 => 0x0804, - SignatureScheme::RSA_PSS_RSAE_SHA384 => 0x0805, - SignatureScheme::RSA_PSS_RSAE_SHA512 => 0x0806, - SignatureScheme::RSA_PSS_PSS_SHA256 => 0x0809, - SignatureScheme::RSA_PSS_PSS_SHA384 => 0x080a, - SignatureScheme::RSA_PSS_PSS_SHA512 => 0x080b, - SignatureScheme::RSA_PKCS1_SHA256 => 0x0401, - SignatureScheme::RSA_PKCS1_SHA384 => 0x0501, - SignatureScheme::RSA_PKCS1_SHA512 => 0x0601, - SignatureScheme::Unknown(value) => *value, - } + pub const fn as_u16(&self) -> u16 { + self.0 + } + + /// Returns true if this is not a known TLS signature scheme wire value. + pub const fn is_unknown(&self) -> bool { + !matches!( + *self, + Self(0x0401 | 0x0403 | 0x0501 | 0x0503 | 0x0601 | 0x0603 | 0x0804..=0x080b) + ) } /// Parse a `SignatureScheme` from wire format. @@ -626,7 +587,7 @@ impl SignatureScheme { } /// All recognized signature schemes (every non-`Unknown` variant). - pub fn all() -> &'static [SignatureScheme] { + pub const fn all() -> &'static [SignatureScheme] { &[ SignatureScheme::ECDSA_SECP256R1_SHA256, SignatureScheme::ECDSA_SECP384R1_SHA384, @@ -663,7 +624,7 @@ impl SignatureScheme { /// In DTLS 1.3, ECDSA signature schemes encode the expected curve. /// Returns `None` for non-ECDSA schemes. pub fn named_group(&self) -> Option { - match self { + match *self { SignatureScheme::ECDSA_SECP256R1_SHA256 => Some(NamedGroup::Secp256r1), SignatureScheme::ECDSA_SECP384R1_SHA384 => Some(NamedGroup::Secp384r1), _ => None, @@ -672,7 +633,7 @@ impl SignatureScheme { /// Returns the hash algorithm associated with this signature scheme. pub fn hash_algorithm(&self) -> HashAlgorithm { - match self { + match *self { SignatureScheme::ECDSA_SECP256R1_SHA256 | SignatureScheme::RSA_PSS_RSAE_SHA256 | SignatureScheme::RSA_PSS_PSS_SHA256 @@ -687,7 +648,29 @@ impl SignatureScheme { | SignatureScheme::RSA_PKCS1_SHA512 => HashAlgorithm::SHA512, // Ed25519 and Ed448 have intrinsic hash algorithms SignatureScheme::ED25519 | SignatureScheme::ED448 => HashAlgorithm::None, - SignatureScheme::Unknown(_) => HashAlgorithm::Unknown(0), + _ => HashAlgorithm::UNKNOWN_DERIVED, + } + } +} + +impl fmt::Debug for SignatureScheme { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + SignatureScheme::ECDSA_SECP256R1_SHA256 => f.write_str("ECDSA_SECP256R1_SHA256"), + SignatureScheme::ECDSA_SECP384R1_SHA384 => f.write_str("ECDSA_SECP384R1_SHA384"), + SignatureScheme::ECDSA_SECP521R1_SHA512 => f.write_str("ECDSA_SECP521R1_SHA512"), + SignatureScheme::ED25519 => f.write_str("ED25519"), + SignatureScheme::ED448 => f.write_str("ED448"), + SignatureScheme::RSA_PSS_RSAE_SHA256 => f.write_str("RSA_PSS_RSAE_SHA256"), + SignatureScheme::RSA_PSS_RSAE_SHA384 => f.write_str("RSA_PSS_RSAE_SHA384"), + SignatureScheme::RSA_PSS_RSAE_SHA512 => f.write_str("RSA_PSS_RSAE_SHA512"), + SignatureScheme::RSA_PSS_PSS_SHA256 => f.write_str("RSA_PSS_PSS_SHA256"), + SignatureScheme::RSA_PSS_PSS_SHA384 => f.write_str("RSA_PSS_PSS_SHA384"), + SignatureScheme::RSA_PSS_PSS_SHA512 => f.write_str("RSA_PSS_PSS_SHA512"), + SignatureScheme::RSA_PKCS1_SHA256 => f.write_str("RSA_PKCS1_SHA256"), + SignatureScheme::RSA_PKCS1_SHA384 => f.write_str("RSA_PKCS1_SHA384"), + SignatureScheme::RSA_PKCS1_SHA512 => f.write_str("RSA_PKCS1_SHA512"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), } } } @@ -700,47 +683,35 @@ impl SignatureScheme { /// /// Unlike DTLS 1.2, TLS 1.3 cipher suites only specify the AEAD algorithm /// and hash function. Key exchange is negotiated separately via key_share. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(non_camel_case_types)] -#[non_exhaustive] -pub enum Dtls13CipherSuite { +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct Dtls13CipherSuite(u16); + +impl Dtls13CipherSuite { /// TLS_AES_128_GCM_SHA256. - AES_128_GCM_SHA256, + pub const AES_128_GCM_SHA256: Self = Self(0x1301); /// TLS_AES_256_GCM_SHA384. - AES_256_GCM_SHA384, + pub const AES_256_GCM_SHA384: Self = Self(0x1302); /// TLS_CHACHA20_POLY1305_SHA256. - CHACHA20_POLY1305_SHA256, + pub const CHACHA20_POLY1305_SHA256: Self = Self(0x1303); /// TLS_AES_128_CCM_SHA256. - AES_128_CCM_SHA256, + pub const AES_128_CCM_SHA256: Self = Self(0x1304); /// TLS_AES_128_CCM_8_SHA256 (shorter tag, for constrained devices). - AES_128_CCM_8_SHA256, - /// Unknown or unsupported cipher suite. - Unknown(u16), -} + pub const AES_128_CCM_8_SHA256: Self = Self(0x1305); -impl Dtls13CipherSuite { /// Convert a wire format u16 value to a `Dtls13CipherSuite`. - pub fn from_u16(value: u16) -> Self { - match value { - 0x1301 => Dtls13CipherSuite::AES_128_GCM_SHA256, - 0x1302 => Dtls13CipherSuite::AES_256_GCM_SHA384, - 0x1303 => Dtls13CipherSuite::CHACHA20_POLY1305_SHA256, - 0x1304 => Dtls13CipherSuite::AES_128_CCM_SHA256, - 0x1305 => Dtls13CipherSuite::AES_128_CCM_8_SHA256, - _ => Dtls13CipherSuite::Unknown(value), - } + pub const fn from_u16(value: u16) -> Self { + Self(value) } /// Convert this `Dtls13CipherSuite` to its wire format u16 value. - pub fn as_u16(&self) -> u16 { - match self { - Dtls13CipherSuite::AES_128_GCM_SHA256 => 0x1301, - Dtls13CipherSuite::AES_256_GCM_SHA384 => 0x1302, - Dtls13CipherSuite::CHACHA20_POLY1305_SHA256 => 0x1303, - Dtls13CipherSuite::AES_128_CCM_SHA256 => 0x1304, - Dtls13CipherSuite::AES_128_CCM_8_SHA256 => 0x1305, - Dtls13CipherSuite::Unknown(value) => *value, - } + pub const fn as_u16(&self) -> u16 { + self.0 + } + + /// Returns true if this is not a known DTLS 1.3 cipher suite wire value. + pub const fn is_unknown(&self) -> bool { + !matches!(*self, Self(0x1301..=0x1305)) } /// Parse a `Dtls13CipherSuite` from wire format. @@ -751,13 +722,13 @@ impl Dtls13CipherSuite { /// Returns the hash algorithm used by this cipher suite. pub fn hash_algorithm(&self) -> HashAlgorithm { - match self { + match *self { Dtls13CipherSuite::AES_128_GCM_SHA256 | Dtls13CipherSuite::CHACHA20_POLY1305_SHA256 | Dtls13CipherSuite::AES_128_CCM_SHA256 | Dtls13CipherSuite::AES_128_CCM_8_SHA256 => HashAlgorithm::SHA256, Dtls13CipherSuite::AES_256_GCM_SHA384 => HashAlgorithm::SHA384, - Dtls13CipherSuite::Unknown(_) => HashAlgorithm::Unknown(0), + _ => HashAlgorithm::UNKNOWN_DERIVED, } } @@ -767,7 +738,7 @@ impl Dtls13CipherSuite { } /// All recognized DTLS 1.3 cipher suites (every non-`Unknown` variant). - pub fn all() -> &'static [Dtls13CipherSuite] { + pub const fn all() -> &'static [Dtls13CipherSuite] { &[ Dtls13CipherSuite::AES_128_GCM_SHA256, Dtls13CipherSuite::AES_256_GCM_SHA384, @@ -778,7 +749,7 @@ impl Dtls13CipherSuite { } /// Supported DTLS 1.3 cipher suites in preference order. - pub fn supported() -> &'static [Dtls13CipherSuite] { + pub const fn supported() -> &'static [Dtls13CipherSuite] { &[ Dtls13CipherSuite::AES_128_GCM_SHA256, Dtls13CipherSuite::AES_256_GCM_SHA384, @@ -792,6 +763,19 @@ impl Dtls13CipherSuite { } } +impl fmt::Debug for Dtls13CipherSuite { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Dtls13CipherSuite::AES_128_GCM_SHA256 => f.write_str("AES_128_GCM_SHA256"), + Dtls13CipherSuite::AES_256_GCM_SHA384 => f.write_str("AES_256_GCM_SHA384"), + Dtls13CipherSuite::CHACHA20_POLY1305_SHA256 => f.write_str("CHACHA20_POLY1305_SHA256"), + Dtls13CipherSuite::AES_128_CCM_SHA256 => f.write_str("AES_128_CCM_SHA256"), + Dtls13CipherSuite::AES_128_CCM_8_SHA256 => f.write_str("AES_128_CCM_8_SHA256"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + // ============================================================================ // Protocol Version // ============================================================================ @@ -799,45 +783,37 @@ impl Dtls13CipherSuite { /// DTLS protocol version identifiers. /// /// Used in record headers and handshake messages for both DTLS 1.2 and 1.3. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ProtocolVersion { +#[repr(transparent)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct ProtocolVersion(u16); + +impl ProtocolVersion { /// DTLS 1.0. - DTLS1_0, + pub const DTLS1_0: Self = Self(0xFEFF); /// DTLS 1.2. - DTLS1_2, + pub const DTLS1_2: Self = Self(0xFEFD); /// DTLS 1.3. - DTLS1_3, - /// Unknown protocol version. - Unknown(u16), -} + pub const DTLS1_3: Self = Self(0xFEFC); -impl Default for ProtocolVersion { - fn default() -> Self { - Self::Unknown(0) + /// Convert a wire format u16 value to a `ProtocolVersion`. + pub const fn from_u16(value: u16) -> Self { + Self(value) } -} -impl ProtocolVersion { /// Convert this `ProtocolVersion` to its wire format u16 value. - pub fn as_u16(&self) -> u16 { - match self { - ProtocolVersion::DTLS1_0 => 0xFEFF, - ProtocolVersion::DTLS1_2 => 0xFEFD, - ProtocolVersion::DTLS1_3 => 0xFEFC, - ProtocolVersion::Unknown(value) => *value, - } + pub const fn as_u16(&self) -> u16 { + self.0 + } + + /// Returns true if this is not a known DTLS protocol version wire value. + pub const fn is_unknown(&self) -> bool { + !matches!(*self, Self(0xFEFF | 0xFEFD | 0xFEFC)) } /// Parse a `ProtocolVersion` from wire format. pub fn parse(input: &[u8]) -> IResult<&[u8], ProtocolVersion> { let (input, version) = be_u16(input)?; - let protocol_version = match version { - 0xFEFF => ProtocolVersion::DTLS1_0, - 0xFEFD => ProtocolVersion::DTLS1_2, - 0xFEFC => ProtocolVersion::DTLS1_3, - _ => ProtocolVersion::Unknown(version), - }; - Ok((input, protocol_version)) + Ok((input, ProtocolVersion::from_u16(version))) } /// Serialize this `ProtocolVersion` to wire format. @@ -846,6 +822,17 @@ impl ProtocolVersion { } } +impl fmt::Debug for ProtocolVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + ProtocolVersion::DTLS1_0 => f.write_str("DTLS1_0"), + ProtocolVersion::DTLS1_2 => f.write_str("DTLS1_2"), + ProtocolVersion::DTLS1_3 => f.write_str("DTLS1_3"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + // ============================================================================ // Compression Method // ============================================================================ @@ -854,30 +841,26 @@ impl ProtocolVersion { /// /// Used in ClientHello/ServerHello for both DTLS 1.2 and 1.3. /// TLS 1.3 only uses Null compression but includes it for compatibility. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CompressionMethod { - /// No compression. - Null, - /// DEFLATE compression. - Deflate, - /// Unknown compression method. - Unknown(u8), -} +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct CompressionMethod(u8); impl Default for CompressionMethod { fn default() -> Self { - Self::Unknown(0) + Self::Null } } +#[allow(non_upper_case_globals)] impl CompressionMethod { + /// No compression. + pub const Null: Self = Self(0x00); + /// DEFLATE compression. + pub const Deflate: Self = Self(0x01); + /// Convert a u8 value to a `CompressionMethod`. - pub fn from_u8(value: u8) -> Self { - match value { - 0x00 => CompressionMethod::Null, - 0x01 => CompressionMethod::Deflate, - _ => CompressionMethod::Unknown(value), - } + pub const fn from_u8(value: u8) -> Self { + Self(value) } /// Returns true if this compression method is supported by this implementation. @@ -900,12 +883,13 @@ impl CompressionMethod { } /// Convert this `CompressionMethod` to its u8 value. - pub fn as_u8(&self) -> u8 { - match self { - CompressionMethod::Null => 0x00, - CompressionMethod::Deflate => 0x01, - CompressionMethod::Unknown(value) => *value, - } + pub const fn as_u8(&self) -> u8 { + self.0 + } + + /// Returns true if this is not a known TLS compression method wire value. + pub const fn is_unknown(&self) -> bool { + self.0 > Self::Deflate.0 } /// Parse a `CompressionMethod` from wire format. @@ -915,10 +899,305 @@ impl CompressionMethod { } } +impl fmt::Debug for CompressionMethod { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + CompressionMethod::Null => f.write_str("Null"), + CompressionMethod::Deflate => f.write_str("Deflate"), + _ => f.debug_tuple("Unknown").field(&self.0).finish(), + } + } +} + #[cfg(test)] mod tests { use super::*; + #[test] + fn named_group_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert_eq!(NamedGroup::default().as_u16(), 0); + assert!(NamedGroup::default().is_unknown()); + } + + #[test] + fn named_group_wire_roundtrip() { + for group in NamedGroup::all() { + assert_eq!(NamedGroup::from_u16(group.as_u16()), *group); + assert!(!group.is_unknown()); + } + + let unknown = NamedGroup::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn named_group_debug_stays_enum_like() { + assert_eq!(format!("{:?}", NamedGroup::Secp256r1), "Secp256r1"); + assert_eq!(format!("{:?}", NamedGroup::X25519), "X25519"); + assert_eq!( + format!("{:?}", NamedGroup::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + + #[test] + fn hash_algorithm_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert_eq!(HashAlgorithm::default().as_u8(), 0); + assert_eq!(HashAlgorithm::default(), HashAlgorithm::None); + } + + #[test] + fn hash_algorithm_wire_roundtrip() { + let known = [ + (0, HashAlgorithm::None), + (1, HashAlgorithm::MD5), + (2, HashAlgorithm::SHA1), + (3, HashAlgorithm::SHA224), + (4, HashAlgorithm::SHA256), + (5, HashAlgorithm::SHA384), + (6, HashAlgorithm::SHA512), + ]; + + for (wire, algorithm) in known { + assert_eq!(HashAlgorithm::from_u8(wire), algorithm); + assert_eq!(algorithm.as_u8(), wire); + assert!(!algorithm.is_unknown()); + } + + let unknown = HashAlgorithm::from_u8(7); + assert_eq!(unknown.as_u8(), 7); + assert!(unknown.is_unknown()); + } + + #[test] + fn hash_algorithm_output_len() { + assert_eq!(HashAlgorithm::None.output_len(), 0); + assert_eq!(HashAlgorithm::MD5.output_len(), 16); + assert_eq!(HashAlgorithm::SHA1.output_len(), 20); + assert_eq!(HashAlgorithm::SHA224.output_len(), 28); + assert_eq!(HashAlgorithm::SHA256.output_len(), 32); + assert_eq!(HashAlgorithm::SHA384.output_len(), 48); + assert_eq!(HashAlgorithm::SHA512.output_len(), 64); + assert_eq!(HashAlgorithm::from_u8(7).output_len(), 0); + } + + #[test] + fn hash_algorithm_debug_stays_enum_like() { + assert_eq!(format!("{:?}", HashAlgorithm::None), "None"); + assert_eq!(format!("{:?}", HashAlgorithm::SHA256), "SHA256"); + assert_eq!(format!("{:?}", HashAlgorithm::from_u8(7)), "Unknown(7)"); + } + + #[test] + fn signature_algorithm_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert_eq!(SignatureAlgorithm::default().as_u8(), 0); + assert_eq!(SignatureAlgorithm::default(), SignatureAlgorithm::Anonymous); + } + + #[test] + fn signature_algorithm_wire_roundtrip() { + let known = [ + (0, SignatureAlgorithm::Anonymous), + (1, SignatureAlgorithm::RSA), + (2, SignatureAlgorithm::DSA), + (3, SignatureAlgorithm::ECDSA), + ]; + + for (wire, algorithm) in known { + assert_eq!(SignatureAlgorithm::from_u8(wire), algorithm); + assert_eq!(algorithm.as_u8(), wire); + assert!(!algorithm.is_unknown()); + } + + let unknown = SignatureAlgorithm::from_u8(4); + assert_eq!(unknown.as_u8(), 4); + assert!(unknown.is_unknown()); + } + + #[test] + fn signature_algorithm_debug_stays_enum_like() { + assert_eq!(format!("{:?}", SignatureAlgorithm::Anonymous), "Anonymous"); + assert_eq!(format!("{:?}", SignatureAlgorithm::ECDSA), "ECDSA"); + assert_eq!( + format!("{:?}", SignatureAlgorithm::from_u8(4)), + "Unknown(4)" + ); + } + + #[test] + fn compression_method_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert_eq!(CompressionMethod::default().as_u8(), 0); + assert_eq!(CompressionMethod::default(), CompressionMethod::Null); + } + + #[test] + fn compression_method_wire_roundtrip() { + let known = [ + (0x00, CompressionMethod::Null), + (0x01, CompressionMethod::Deflate), + ]; + + for (wire, method) in known { + assert_eq!(CompressionMethod::from_u8(wire), method); + assert_eq!(method.as_u8(), wire); + assert!(!method.is_unknown()); + } + + let unknown = CompressionMethod::from_u8(0x02); + assert_eq!(unknown.as_u8(), 0x02); + assert!(unknown.is_unknown()); + } + + #[test] + fn compression_method_debug_stays_enum_like() { + assert_eq!(format!("{:?}", CompressionMethod::Null), "Null"); + assert_eq!(format!("{:?}", CompressionMethod::Deflate), "Deflate"); + assert_eq!( + format!("{:?}", CompressionMethod::from_u8(0x02)), + "Unknown(2)" + ); + } + + #[test] + fn content_type_newtype_shape() { + assert_eq!(std::mem::size_of::(), 1); + assert_eq!(ContentType::default().as_u8(), 0); + assert!(ContentType::default().is_unknown()); + } + + #[test] + fn content_type_wire_roundtrip() { + let known = [ + (20, ContentType::ChangeCipherSpec), + (21, ContentType::Alert), + (22, ContentType::Handshake), + (23, ContentType::ApplicationData), + (26, ContentType::Ack), + ]; + + for (wire, content_type) in known { + assert_eq!(ContentType::from_u8(wire), content_type); + assert_eq!(content_type.as_u8(), wire); + assert!(!content_type.is_unknown()); + } + + let unknown = ContentType::from_u8(24); + assert_eq!(unknown.as_u8(), 24); + assert!(unknown.is_unknown()); + } + + #[test] + fn content_type_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", ContentType::ChangeCipherSpec), + "ChangeCipherSpec" + ); + assert_eq!(format!("{:?}", ContentType::Handshake), "Handshake"); + assert_eq!(format!("{:?}", ContentType::from_u8(24)), "Unknown(24)"); + } + + #[test] + fn signature_scheme_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert_eq!(SignatureScheme::default().as_u16(), 0); + assert!(SignatureScheme::default().is_unknown()); + } + + #[test] + fn signature_scheme_wire_roundtrip() { + for scheme in SignatureScheme::all() { + assert_eq!(SignatureScheme::from_u16(scheme.as_u16()), *scheme); + assert!(!scheme.is_unknown()); + } + + let unknown = SignatureScheme::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn signature_scheme_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", SignatureScheme::ECDSA_SECP256R1_SHA256), + "ECDSA_SECP256R1_SHA256" + ); + assert_eq!( + format!("{:?}", SignatureScheme::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + + #[test] + fn dtls13_cipher_suite_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert_eq!(Dtls13CipherSuite::default().as_u16(), 0); + assert!(Dtls13CipherSuite::default().is_unknown()); + } + + #[test] + fn dtls13_cipher_suite_wire_roundtrip() { + for suite in Dtls13CipherSuite::all() { + assert_eq!(Dtls13CipherSuite::from_u16(suite.as_u16()), *suite); + assert!(!suite.is_unknown()); + } + + let unknown = Dtls13CipherSuite::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn dtls13_cipher_suite_debug_stays_enum_like() { + assert_eq!( + format!("{:?}", Dtls13CipherSuite::AES_128_GCM_SHA256), + "AES_128_GCM_SHA256" + ); + assert_eq!( + format!("{:?}", Dtls13CipherSuite::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + + #[test] + fn protocol_version_newtype_shape() { + assert_eq!(std::mem::size_of::(), 2); + assert_eq!(ProtocolVersion::default().as_u16(), 0); + assert!(ProtocolVersion::default().is_unknown()); + } + + #[test] + fn protocol_version_wire_roundtrip() { + let known = [ + (0xFEFF, ProtocolVersion::DTLS1_0), + (0xFEFD, ProtocolVersion::DTLS1_2), + (0xFEFC, ProtocolVersion::DTLS1_3), + ]; + + for (wire, version) in known { + assert_eq!(ProtocolVersion::from_u16(wire), version); + assert_eq!(version.as_u16(), wire); + assert!(!version.is_unknown()); + } + + let unknown = ProtocolVersion::from_u16(0xFFFF); + assert_eq!(unknown.as_u16(), 0xFFFF); + assert!(unknown.is_unknown()); + } + + #[test] + fn protocol_version_debug_stays_enum_like() { + assert_eq!(format!("{:?}", ProtocolVersion::DTLS1_2), "DTLS1_2"); + assert_eq!( + format!("{:?}", ProtocolVersion::from_u16(0xFFFF)), + "Unknown(65535)" + ); + } + #[test] fn random_parse() { let data = [ @@ -976,7 +1255,7 @@ mod tests { assert_eq!(SignatureScheme::RSA_PSS_RSAE_SHA256.named_group(), None); assert_eq!(SignatureScheme::ED25519.named_group(), None); assert_eq!(SignatureScheme::ECDSA_SECP521R1_SHA512.named_group(), None); - assert_eq!(SignatureScheme::Unknown(0xFFFF).named_group(), None); + assert_eq!(SignatureScheme::from_u16(0xFFFF).named_group(), None); } #[test] diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 7624c612..6fd95d38 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -58,7 +58,7 @@ fn dtls12_min_protected_fragment_len(suite: Dtls12CipherSuite) -> usize { | Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256 => 24, Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256 => 16, Dtls12CipherSuite::PSK_AES128_CCM_8 => 16, - Dtls12CipherSuite::Unknown(_) => panic!("unknown cipher suite"), + _ => panic!("unknown cipher suite"), } }