diff --git a/CHANGELOG.md b/CHANGELOG.md index e3bdafd..a811677 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased + * Discard DTLS handshake records with malformed same-record tails #139 * Represent DTLS wire-code identifiers as compact newtypes (breaking) #137 * Make public errors structured and fatal-only (breaking) #134 diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index 3bd86e9..7e8d743 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -1155,7 +1155,16 @@ impl Engine { for record in unhandled { let buf = record.into_buffer(); - self.parse_packet(&buf)?; + match self.parse_packet(&buf) { + Ok(()) => {} + Err(InternalError::Transient(err)) => { + trace!("Discarding buffered protected record after reparse failed: {err}"); + } + Err(err) => { + self.buffers_free.push(buf); + return Err(err); + } + } self.buffers_free.push(buf); } } diff --git a/src/dtls12/incoming.rs b/src/dtls12/incoming.rs index ca21b6c..6508a48 100644 --- a/src/dtls12/incoming.rs +++ b/src/dtls12/incoming.rs @@ -145,7 +145,7 @@ impl Record { // ONLY COPY: UDP packet slice -> pooled buffer let mut buffer = Buf::new(); buffer.extend_from_slice(record_slice); - let parsed = match ParsedRecord::parse(&buffer, cs, 0) { + let parsed = match ParsedRecord::parse(&buffer, cs, 0, true) { Ok(p) => p, Err(e) => { // RFC 6347 §4.1.2.7: Invalid records SHOULD be silently discarded. @@ -210,22 +210,23 @@ impl Record { buffer.len() }; - // Decryption succeeded — now commit the replay window update. - // RFC 6347 §4.1.2.6: "The receive window is updated only if the - // MAC verification succeeds." - decrypt.replay_update(sequence); - - // The record is now authenticated. Tell the handler so it can act on a - // confirmed-genuine record (e.g. mark the peer past its handshake). - decrypt.note_decrypted_record(content_type); - // Update the length of the record. buffer[11] = (new_len >> 8) as u8; buffer[12] = new_len as u8; - let parsed = ParsedRecord::parse(&buffer, cs, explicit_nonce_len)?; + let parsed = ParsedRecord::parse(&buffer, cs, explicit_nonce_len, false)?; let parsed = Box::new(parsed); + // Decryption and parsing both succeeded. Commit replay state only once + // the record is publishable, so a malformed protected handshake tail + // cannot consume the retransmission slot for a clean record. + decrypt.replay_update(sequence); + + // The record is now authenticated and accepted. Tell the handler so it + // can act on a confirmed-genuine record (e.g. mark the peer past its + // handshake). + decrypt.note_decrypted_record(content_type); + Ok(Some(Record { buffer, parsed })) } @@ -276,14 +277,24 @@ impl ParsedRecord { input: &[u8], cipher_suite: Option, offset: usize, + defer_protected_handshake_parse: bool, ) -> Result { let (_, record) = DTLSRecord::parse(input, 0, offset)?; let handshakes = if record.content_type == ContentType::Handshake { + if record.sequence.epoch != 0 && defer_protected_handshake_parse { + trace!("Deferring protected handshake parsing until after decryption"); + return Ok(ParsedRecord { + record, + handshakes: ArrayVec::new(), + handled: AtomicBool::new(false), + }); + } + // This will also return None on the encrypted Finished after ChangeCipherSpec. // However we will then decrypt and try again. let fragment_offset = record.fragment_range.start; - parse_handshakes(record.fragment(input), fragment_offset, cipher_suite) + parse_handshakes(record.fragment(input), fragment_offset, cipher_suite)? } else { ArrayVec::new() }; @@ -330,22 +341,18 @@ fn parse_handshakes( mut input: &[u8], mut base_offset: usize, cipher_suite: Option, -) -> ArrayVec { +) -> Result, InternalError> { let mut handshakes = ArrayVec::new(); while !input.is_empty() { - if let Ok((remaining, handshake)) = Handshake::parse(input, base_offset, cipher_suite, true) - { - let len = input.len() - remaining.len(); - base_offset += len; - input = remaining; - if handshakes.try_push(handshake).is_err() { - break; - } - } else { - break; + let (remaining, handshake) = Handshake::parse(input, base_offset, cipher_suite, true)?; + let len = input.len() - remaining.len(); + base_offset += len; + input = remaining; + if handshakes.try_push(handshake).is_err() { + return Err(InternalError::too_many_records()); } } - handshakes + Ok(handshakes) } impl fmt::Debug for Incoming { @@ -399,11 +406,16 @@ impl std::panic::UnwindSafe for Incoming {} #[cfg(test)] mod tests { use super::*; + use crate::dtls12::message::MessageType; #[derive(Default)] struct TestHandler { classify_calls: usize, dropped_alerts: usize, + peer_encryption_enabled: bool, + explicit_nonce_len: usize, + min_protected_fragment_len: usize, + replay_updates: usize, } impl RecordHandler for TestHandler { @@ -417,27 +429,29 @@ mod tests { } fn is_peer_encryption_enabled(&self) -> bool { - false + self.peer_encryption_enabled } fn replay_check(&self, _seq: Sequence) -> bool { - panic!("replay_check should not be called for plaintext tests"); + assert!(self.peer_encryption_enabled); + true } fn replay_update(&mut self, _seq: Sequence) { - panic!("replay_update should not be called for plaintext tests"); + self.replay_updates += 1; } fn decryption_aad_and_nonce(&self, _dtls: &DTLSRecord, _buf: &[u8]) -> (Aad, Nonce) { - panic!("decryption_aad_and_nonce should not be called for plaintext tests"); + assert!(self.peer_encryption_enabled); + (Aad(ArrayVec::new()), Nonce([0; 12])) } fn explicit_nonce_len(&self) -> usize { - panic!("explicit_nonce_len should not be called for plaintext tests"); + self.explicit_nonce_len } fn min_protected_fragment_len(&self) -> usize { - panic!("min_protected_fragment_len should not be called for plaintext tests"); + self.min_protected_fragment_len } fn decrypt_data( @@ -446,7 +460,8 @@ mod tests { _aad: Aad, _nonce: Nonce, ) -> Result<(), Error> { - panic!("decrypt_data should not be called for plaintext tests"); + assert!(self.peer_encryption_enabled); + Ok(()) } } @@ -461,6 +476,18 @@ mod tests { out } + fn handshake_fragment(msg_type: MessageType, message_seq: u16, fragment: &[u8]) -> Vec { + let mut out = Vec::new(); + let len = fragment.len() as u32; + out.push(msg_type.as_u8()); + out.extend_from_slice(&len.to_be_bytes()[1..]); + out.extend_from_slice(&message_seq.to_be_bytes()); + out.extend_from_slice(&0u32.to_be_bytes()[1..]); + out.extend_from_slice(&len.to_be_bytes()[1..]); + out.extend_from_slice(fragment); + out + } + #[test] fn parse_packet_filters_control_records_after_packet_validation() { let mut packet = Vec::new(); @@ -486,4 +513,59 @@ mod tests { ); assert_eq!(incoming.first().record().sequence.epoch, 1); } + + #[test] + fn parse_record_accepts_multiple_handshakes() { + let mut fragment = Vec::new(); + fragment.extend_from_slice(&handshake_fragment(MessageType::HelloRequest, 0, &[])); + fragment.extend_from_slice(&handshake_fragment(MessageType::ServerHelloDone, 1, &[])); + + let packet = build_record(ContentType::Handshake, 0, 1, &fragment); + let mut handler = TestHandler::default(); + let incoming = Incoming::parse_packet(&packet, &mut handler, None) + .unwrap() + .expect("handshake record should remain"); + + assert_eq!(incoming.first().handshakes().len(), 2); + } + + #[test] + fn pre_decrypt_protected_handshake_parsing_is_deferred() { + let fragment = handshake_fragment(MessageType::HelloRequest, 0, &[]); + let packet = build_record(ContentType::Handshake, 1, 1, &fragment); + let mut handler = TestHandler::default(); + + let incoming = Incoming::parse_packet(&packet, &mut handler, None) + .unwrap() + .expect("protected record should queue until peer encryption is enabled"); + + assert!( + incoming.first().handshakes().is_empty(), + "ciphertext bytes must not be parsed as plaintext handshakes" + ); + assert_eq!(handler.replay_updates, 0); + } + + #[test] + fn post_decrypt_malformed_handshake_tail_is_not_deferred() { + let mut fragment = handshake_fragment(MessageType::HelloRequest, 0, &[]); + fragment.push(0xff); + let packet = build_record(ContentType::Handshake, 1, 1, &fragment); + let mut handler = TestHandler { + peer_encryption_enabled: true, + explicit_nonce_len: 0, + min_protected_fragment_len: 0, + ..Default::default() + }; + + let err = Incoming::parse_packet( + &packet, + &mut handler, + Some(Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256), + ) + .expect_err("post-decrypt malformed handshake tail must be rejected"); + + assert!(matches!(err, InternalError::Transient(_))); + assert_eq!(handler.replay_updates, 0); + } } diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index d67ccc3..b7c835a 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -2118,7 +2118,16 @@ impl Engine { for record in unhandled { let buf = record.into_buffer(); - self.parse_packet(&buf)?; + match self.parse_packet(&buf) { + Ok(()) => {} + Err(InternalError::Transient(err)) => { + trace!("Discarding buffered protected record after reparse failed: {err}"); + } + Err(err) => { + self.buffers_free.push(buf); + return Err(err); + } + } self.buffers_free.push(buf); } } @@ -2511,6 +2520,29 @@ mod tests { struct PassthroughRecordHandler; + #[derive(Debug)] + struct PassthroughCipher; + + impl Cipher for PassthroughCipher { + fn encrypt( + &mut self, + _plaintext: &mut Buf, + _aad: Aad, + _nonce: Nonce, + ) -> Result<(), crate::CryptoError> { + Ok(()) + } + + fn decrypt( + &mut self, + _ciphertext: &mut TmpBuf, + _aad: Aad, + _nonce: Nonce, + ) -> Result<(), crate::CryptoError> { + Ok(()) + } + } + impl RecordHandler for PassthroughRecordHandler { fn classify_record(&mut self, record: Record) -> Result, Error> { Ok(Some(record)) @@ -2579,6 +2611,25 @@ mod tests { packet } + fn encrypted_malformed_handshake_tail_record(seq: u16) -> Vec { + let mut fragment = Vec::new(); + fragment.push(0xff); + fragment.push(ContentType::Handshake.as_u8()); + fragment.resize(17, 0); + + let mut packet = Vec::new(); + packet.push( + 0b0010_0000 + | 0b0000_1000 // 2-byte sequence number. + | 0b0000_0100 // explicit length. + | 0b0000_0010, // epoch bits. + ); + packet.extend_from_slice(&seq.to_be_bytes()); + packet.extend_from_slice(&(fragment.len() as u16).to_be_bytes()); + packet.extend_from_slice(&fragment); + packet + } + fn parsed_key_update(seq: u16) -> Incoming { Incoming::parse_packet( &encrypted_key_update_record(seq), @@ -2731,6 +2782,33 @@ mod tests { ); } + #[test] + #[cfg(feature = "rcgen")] + fn enable_peer_encryption_discards_malformed_buffered_protected_handshake() { + let mut engine = test_engine(); + engine.set_cipher_suite(Dtls13CipherSuite::AES_128_GCM_SHA256); + let mut sn_key = Buf::new(); + sn_key.extend_from_slice(&[0; 16]); + engine.hs_recv_keys = Some(EpochKeys { + cipher: Box::new(PassthroughCipher), + iv: [0; 12], + traffic_secret: Buf::new(), + sn_key, + }); + + engine + .parse_packet(&encrypted_malformed_handshake_tail_record(0)) + .expect("pre-encryption protected record should queue"); + assert_eq!(engine.queue_rx.len(), 1); + + engine + .enable_peer_encryption() + .expect("malformed queued protected record should be discarded"); + + assert!(engine.peer_encryption_enabled); + assert!(engine.queue_rx.is_empty()); + } + #[test] #[cfg(feature = "rcgen")] fn malformed_ack_record_number_vector_is_ignored() { diff --git a/src/dtls13/incoming.rs b/src/dtls13/incoming.rs index f617a41..3647894 100644 --- a/src/dtls13/incoming.rs +++ b/src/dtls13/incoming.rs @@ -284,11 +284,6 @@ impl Record { buffer.len() }; - // Decryption succeeded — now commit the replay window update. - // RFC 9147 §4.5.1: "The window MUST NOT be updated due to a received - // record until that record has been deprotected successfully." - decrypt.replay_update(full_sequence); - // Recover inner content type from DTLSInnerPlaintext let decrypted = &buffer[header_end..header_end + new_len]; let (inner_content_type, content_len) = match recover_inner_content_type(decrypted) { @@ -308,9 +303,14 @@ impl Record { }, &buffer, cs, - ); + )?; let parsed = Box::new(parsed); + // Decryption and parsing both succeeded. Commit replay state only once + // the record is publishable, so a malformed protected handshake tail + // cannot consume the retransmission slot for a clean record. + decrypt.replay_update(full_sequence); + Ok(Some(Record { buffer, parsed })) } @@ -365,7 +365,7 @@ impl ParsedRecord { let handshakes = if record.content_type == ContentType::Handshake { let fragment_offset = record.fragment_range.start; - parse_handshakes(record.fragment(input), fragment_offset, cipher_suite) + parse_handshakes(record.fragment(input), fragment_offset, cipher_suite)? } else { ArrayVec::new() }; @@ -382,19 +382,19 @@ impl ParsedRecord { record: Dtls13Record, input: &[u8], cipher_suite: Option, - ) -> ParsedRecord { + ) -> Result { let handshakes = if record.content_type == ContentType::Handshake { let fragment_offset = record.fragment_range.start; - parse_handshakes(record.fragment(input), fragment_offset, cipher_suite) + parse_handshakes(record.fragment(input), fragment_offset, cipher_suite)? } else { ArrayVec::new() }; - ParsedRecord { + Ok(ParsedRecord { record, handshakes, handled: AtomicBool::new(false), - } + }) } } @@ -442,22 +442,18 @@ fn parse_handshakes( mut input: &[u8], mut base_offset: usize, cipher_suite: Option, -) -> ArrayVec { +) -> Result, InternalError> { let mut handshakes = ArrayVec::new(); while !input.is_empty() { - if let Ok((remaining, handshake)) = Handshake::parse(input, base_offset, cipher_suite, true) - { - let len = input.len() - remaining.len(); - base_offset += len; - input = remaining; - if handshakes.try_push(handshake).is_err() { - break; - } - } else { - break; + let (remaining, handshake) = Handshake::parse(input, base_offset, cipher_suite, true)?; + let len = input.len() - remaining.len(); + base_offset += len; + input = remaining; + if handshakes.try_push(handshake).is_err() { + return Err(InternalError::too_many_records()); } } - handshakes + Ok(handshakes) } /// Recover the inner content type from a decrypted DTLSInnerPlaintext. @@ -531,11 +527,15 @@ impl std::panic::UnwindSafe for Incoming {} #[cfg(test)] mod tests { use super::*; + use crate::dtls13::message::MessageType; #[derive(Default)] struct TestHandler { classify_calls: usize, dropped_acks: usize, + peer_encryption_enabled: bool, + min_protected_fragment_len: usize, + replay_updates: usize, } impl RecordHandler for TestHandler { @@ -549,29 +549,30 @@ mod tests { } fn is_peer_encryption_enabled(&self) -> bool { - false + self.peer_encryption_enabled } - fn resolve_epoch(&self, _epoch_bits: u8) -> u16 { - panic!("resolve_epoch should not be called when peer encryption is disabled"); + fn resolve_epoch(&self, epoch_bits: u8) -> u16 { + assert!(self.peer_encryption_enabled); + epoch_bits as u16 } - fn resolve_sequence(&self, _epoch: u16, _seq_bits: u64, _s_flag: bool) -> u64 { - panic!("resolve_sequence should not be called when peer encryption is disabled"); + fn resolve_sequence(&self, _epoch: u16, seq_bits: u64, _s_flag: bool) -> u64 { + assert!(self.peer_encryption_enabled); + seq_bits } fn replay_check(&self, _seq: Sequence) -> bool { - panic!("replay_check should not be called when peer encryption is disabled"); + assert!(self.peer_encryption_enabled); + true } fn replay_update(&mut self, _seq: Sequence) { - panic!("replay_update should not be called when peer encryption is disabled"); + self.replay_updates += 1; } fn min_protected_fragment_len(&self) -> usize { - panic!( - "min_protected_fragment_len should not be called when peer encryption is disabled" - ); + self.min_protected_fragment_len } fn decrypt_record( @@ -580,7 +581,8 @@ mod tests { _seq: Sequence, _ciphertext: &mut TmpBuf, ) -> Result<(), Error> { - panic!("decrypt_record should not be called when peer encryption is disabled"); + assert!(self.peer_encryption_enabled); + Ok(()) } fn decrypt_sequence_number( @@ -589,7 +591,7 @@ mod tests { _seq_bytes: &mut [u8], _ciphertext_sample: &[u8; 16], ) { - panic!("decrypt_sequence_number should not be called when peer encryption is disabled"); + assert!(self.peer_encryption_enabled); } } @@ -604,6 +606,18 @@ mod tests { out } + fn handshake_fragment(msg_type: MessageType, message_seq: u16, fragment: &[u8]) -> Vec { + let mut out = Vec::new(); + let len = fragment.len() as u32; + out.push(msg_type.as_u8()); + out.extend_from_slice(&len.to_be_bytes()[1..]); + out.extend_from_slice(&message_seq.to_be_bytes()); + out.extend_from_slice(&0u32.to_be_bytes()[1..]); + out.extend_from_slice(&len.to_be_bytes()[1..]); + out.extend_from_slice(fragment); + out + } + fn build_ciphertext_record(epoch: u16, seq: u16, fragment: &[u8]) -> Vec { let mut out = Vec::new(); let flags = 0b0010_0000 | 0b0000_1000 | 0b0000_0100 | (epoch as u8 & 0x03); @@ -634,4 +648,46 @@ mod tests { ); assert_eq!(incoming.first().record().sequence.epoch, 2); } + + #[test] + fn parse_record_accepts_multiple_handshakes() { + let mut fragment = Vec::new(); + fragment.extend_from_slice(&handshake_fragment(MessageType::CertificateRequest, 0, &[])); + fragment.extend_from_slice(&handshake_fragment( + MessageType::EncryptedExtensions, + 1, + &[], + )); + + let packet = build_plaintext_record(ContentType::Handshake, 1, &fragment); + let mut handler = TestHandler::default(); + let incoming = Incoming::parse_packet(&packet, &mut handler, None) + .unwrap() + .expect("handshake record should remain"); + + assert_eq!(incoming.first().handshakes().len(), 2); + } + + #[test] + fn post_decrypt_malformed_handshake_tail_is_rejected() { + let mut fragment = handshake_fragment(MessageType::CertificateRequest, 0, &[]); + fragment.push(0xff); + fragment.push(ContentType::Handshake.as_u8()); + let packet = build_ciphertext_record(2, 1, &fragment); + let mut handler = TestHandler { + peer_encryption_enabled: true, + min_protected_fragment_len: 0, + ..Default::default() + }; + + let err = Incoming::parse_packet( + &packet, + &mut handler, + Some(Dtls13CipherSuite::AES_128_GCM_SHA256), + ) + .expect_err("post-decrypt malformed handshake tail must be rejected"); + + assert!(matches!(err, InternalError::Transient(_))); + assert_eq!(handler.replay_updates, 0); + } } diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 6fd95d3..3a87653 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -62,6 +62,56 @@ fn dtls12_min_protected_fragment_len(suite: Dtls12CipherSuite) -> usize { } } +fn append_truncated_handshake_tail(mut packet: Vec) -> Vec { + let len = u16::from_be_bytes([packet[11], packet[12]]); + let len = len.checked_add(1).expect("record length fits"); + packet[11..13].copy_from_slice(&len.to_be_bytes()); + packet.push(0xff); + packet +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_same_record_malformed_handshake_tail_is_discarded_atomically() { + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + let config = dtls12_config(); + let now = Instant::now(); + + let mut client = Dtls::new_12(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_12(config, server_cert, now); + server.set_active(false); + + client.handle_timeout(now).expect("client timeout"); + let client_hello = collect_packets(&mut client); + assert!(!client_hello.is_empty(), "client should emit ClientHello"); + + let malformed = append_truncated_handshake_tail(client_hello[0].clone()); + server + .handle_packet(&malformed) + .expect("malformed same-record tail should be discarded"); + server.handle_timeout(now).expect("server timeout"); + assert!( + collect_packets(&mut server).is_empty(), + "server must not process the valid ClientHello prefix" + ); + + for packet in &client_hello { + server + .handle_packet(packet) + .expect("clean retransmitted ClientHello should recover"); + } + server.handle_timeout(now).expect("server timeout"); + assert!( + !collect_packets(&mut server).is_empty(), + "server should respond to clean retransmission" + ); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_malformed_datagram_is_discarded_without_processing_alerts() { diff --git a/tests/dtls13/edge.rs b/tests/dtls13/edge.rs index 33e9442..becc88b 100644 --- a/tests/dtls13/edge.rs +++ b/tests/dtls13/edge.rs @@ -50,6 +50,56 @@ fn dtls13_ack_record_for_records(seq: u64, records: &[(u64, u64)]) -> Vec { out } +fn append_truncated_handshake_tail(mut packet: Vec) -> Vec { + let len = u16::from_be_bytes([packet[11], packet[12]]); + let len = len.checked_add(1).expect("record length fits"); + packet[11..13].copy_from_slice(&len.to_be_bytes()); + packet.push(0xff); + packet +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_same_record_malformed_handshake_tail_is_discarded_atomically() { + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + let config = dtls13_config(); + let now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + client.handle_timeout(now).expect("client timeout"); + let client_hello = collect_packets(&mut client); + assert!(!client_hello.is_empty(), "client should emit ClientHello"); + + let malformed = append_truncated_handshake_tail(client_hello[0].clone()); + server + .handle_packet(&malformed) + .expect("malformed same-record tail should be discarded"); + server.handle_timeout(now).expect("server timeout"); + assert!( + collect_packets(&mut server).is_empty(), + "server must not process the valid ClientHello prefix" + ); + + for packet in &client_hello { + server + .handle_packet(packet) + .expect("clean retransmitted ClientHello should recover"); + } + server.handle_timeout(now).expect("server timeout"); + assert!( + !collect_packets(&mut server).is_empty(), + "server should respond to clean retransmission" + ); +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_malformed_datagram_is_discarded_without_processing_alerts() {