From 1b56b2f561f539a816707b4f52ada5f6dd0a3eb1 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Sun, 24 May 2026 08:40:05 +0300 Subject: [PATCH 1/2] dtls: defer replay commits until datagram parse succeeds --- CHANGELOG.md | 1 + src/dtls12/incoming.rs | 88 +++++++++++++++++----- src/dtls13/engine.rs | 11 ++- src/dtls13/incoming.rs | 164 +++++++++++++++++++++++++++++++++++------ tests/dtls12/edge.rs | 105 ++++++++++++++++++++++++++ tests/dtls13/edge.rs | 105 ++++++++++++++++++++++++++ 6 files changed, 434 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 261bfb8a..5fc3ec96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased + * Fix malformed datagrams consuming DTLS replay-window state #121 * Replace pending DTLS 1.2 handshake output on resend #116 * Discard bad protected DTLS 1.2 records after handshake #115 * Reject oversized DTLS certificate lists #113 diff --git a/src/dtls12/incoming.rs b/src/dtls12/incoming.rs index 2c41ebbd..373c424b 100644 --- a/src/dtls12/incoming.rs +++ b/src/dtls12/incoming.rs @@ -8,6 +8,7 @@ use crate::Error; use crate::buffer::{Buf, TmpBuf}; use crate::crypto::{Aad, Nonce}; use crate::dtls12::message::{ContentType, DTLSRecord, Dtls12CipherSuite, Handshake, Sequence}; +use crate::window::ReplayWindow; /// Holds both the UDP packet and the parsed result of that packet. pub struct Incoming { @@ -67,12 +68,14 @@ pub struct Records { } impl Records { - pub fn parse( + fn parse( mut packet: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, ) -> Result { let mut parsed_records: ArrayVec = ArrayVec::new(); + let mut replay_updates: ArrayVec = ArrayVec::new(); + let mut pending_replay = ReplayWindow::new(); // Find record boundaries and copy each record ONCE from the packet while !packet.is_empty() { @@ -91,14 +94,29 @@ impl Records { // This is the ONLY copy: packet -> record buffer let record_slice = &packet[..record_end]; match Record::parse(record_slice, decrypt, cs) { - Ok(record) => { - if let Some(record) = record { + Ok(parsed) => { + if let Some(sequence) = parsed.replay_sequence { + if !pending_replay.check(sequence.sequence_number) { + trace!("Discarding duplicate rec in same datagram"); + packet = &packet[record_end..]; + continue; + } + } + + if let Some(record) = parsed.record { if parsed_records.try_push(record).is_err() { return Err(Error::TooManyRecords); } - } else { + } else if parsed.replay_sequence.is_none() { trace!("Discarding replayed rec"); } + + if let Some(sequence) = parsed.replay_sequence { + pending_replay.update(sequence.sequence_number); + if replay_updates.try_push(sequence).is_err() { + return Err(Error::TooManyRecords); + } + } } Err(e) => return Err(e), } @@ -106,6 +124,13 @@ impl Records { packet = &packet[record_end..]; } + // Commit replay state only after the whole UDP datagram has parsed + // successfully. A malformed trailing record must not consume + // replay state for an earlier authenticated record in the same datagram. + for sequence in replay_updates { + decrypt.replay_update(sequence); + } + let mut records = ArrayVec::new(); for record in parsed_records { if let Some(record) = decrypt.classify_record(record)? { @@ -134,14 +159,19 @@ pub struct Record { parsed: Box, } +struct RecordParse { + record: Option, + replay_sequence: Option, +} + impl Record { /// The first parse pass only parses the DTLSRecord header which is unencrypted. /// Copies record data from UDP packet ONCE into a pooled buffer. - pub fn parse( + fn parse( record_slice: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result, Error> { + ) -> Result { // ONLY COPY: UDP packet slice -> pooled buffer let mut buffer = Buf::new(); buffer.extend_from_slice(record_slice); @@ -151,7 +181,10 @@ impl Record { // RFC 6347 §4.1.2.7: Invalid records SHOULD be silently discarded. // This includes epoch 0 records with invalid ContentType. trace!("Discarding record: parse failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } }; let parsed = Box::new(parsed); @@ -162,7 +195,10 @@ impl Record { // packet loss, we can end up seeing epoch 1 records before we can decrypt them. let is_epoch_0 = record.record().sequence.epoch == 0; if is_epoch_0 || !decrypt.is_peer_encryption_enabled() { - return Ok(Some(record)); + return Ok(RecordParse { + record: Some(record), + replay_sequence: None, + }); } // We need to decrypt the record and redo the parsing. @@ -171,12 +207,18 @@ impl Record { // Anti-replay check (read-only, does not update window) if !decrypt.replay_check(sequence) { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } let explicit_nonce_len = decrypt.explicit_nonce_len(); if (dtls.length as usize) < decrypt.min_protected_fragment_len() { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } // Get a reference to the buffer @@ -203,25 +245,35 @@ impl Record { } trace!("Discarding record: decrypt failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } buffer.len() }; - // Decryption succeeded — now commit the replay window update. - // RFC 6347 §4.1.2.6: "The receive window is updated only if the - // MAC verification succeeds." - decrypt.replay_update(sequence); - // Update the length of the record. buffer[11] = (new_len >> 8) as u8; buffer[12] = new_len as u8; - let parsed = ParsedRecord::parse(&buffer, cs, explicit_nonce_len)?; + let parsed = match ParsedRecord::parse(&buffer, cs, explicit_nonce_len) { + Ok(parsed) => parsed, + Err(e) => { + trace!("Discarding authenticated record: parse failed: {}", e); + return Ok(RecordParse { + record: None, + replay_sequence: Some(sequence), + }); + } + }; let parsed = Box::new(parsed); - Ok(Some(Record { buffer, parsed })) + Ok(RecordParse { + record: Some(Record { buffer, parsed }), + replay_sequence: Some(sequence), + }) } pub fn record(&self) -> &DTLSRecord { diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index 984b06b5..67b44422 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -2338,7 +2338,13 @@ impl RecordHandler for Engine { epoch_bits } - fn resolve_sequence(&self, epoch: u16, seq_bits: u64, s_flag: bool) -> u64 { + fn resolve_sequence( + &self, + epoch: u16, + seq_bits: u64, + s_flag: bool, + expected_override: Option, + ) -> u64 { let expected = if epoch == 2 { self.hs_expected_recv_seq } else { @@ -2348,6 +2354,9 @@ impl RecordHandler for Engine { .map(|e| e.expected_recv_seq) .unwrap_or(0) }; + let expected = expected_override + .map(|override_expected| expected.max(override_expected)) + .unwrap_or(expected); let bits: u32 = if s_flag { 16 } else { 8 }; reconstruct_sequence(seq_bits, expected, bits) diff --git a/src/dtls13/incoming.rs b/src/dtls13/incoming.rs index 240cc830..1e83ecd1 100644 --- a/src/dtls13/incoming.rs +++ b/src/dtls13/incoming.rs @@ -7,6 +7,7 @@ use std::fmt; use crate::Error; use crate::buffer::{Buf, TmpBuf}; use crate::dtls13::message::{ContentType, Dtls13CipherSuite, Dtls13Record, Handshake, Sequence}; +use crate::window::ReplayWindow; /// Holds both the UDP packet and the parsed result of that packet. pub struct Incoming { @@ -72,6 +73,9 @@ impl Records { cs: Option, ) -> Result { let mut parsed_records: ArrayVec = ArrayVec::new(); + let mut replay_updates: ArrayVec = ArrayVec::new(); + let mut pending_replay: ArrayVec<(u16, ReplayWindow), 16> = ArrayVec::new(); + let mut pending_expected: ArrayVec<(u16, u64), 16> = ArrayVec::new(); // Find record boundaries and copy each record ONCE from the packet while !packet.is_empty() { @@ -129,15 +133,31 @@ impl Records { // This is the ONLY copy: packet -> record buffer let record_slice = &packet[..record_end]; - match Record::parse(record_slice, decrypt, cs) { - Ok(record) => { - if let Some(record) = record { + match Record::parse(record_slice, decrypt, cs, &pending_expected) { + Ok(parsed) => { + if let Some(sequence) = parsed.replay_sequence { + if !pending_replay_check(&pending_replay, sequence) { + trace!("Discarding duplicate rec in same datagram"); + packet = &packet[record_end..]; + continue; + } + } + + if let Some(record) = parsed.record { if parsed_records.try_push(record).is_err() { return Err(Error::TooManyRecords); } - } else { + } else if parsed.replay_sequence.is_none() { trace!("Discarding replayed rec"); } + + if let Some(sequence) = parsed.replay_sequence { + pending_replay_update(&mut pending_replay, sequence)?; + pending_expected_update(&mut pending_expected, sequence)?; + if replay_updates.try_push(sequence).is_err() { + return Err(Error::TooManyRecords); + } + } } Err(e) => return Err(e), } @@ -145,6 +165,13 @@ impl Records { packet = &packet[record_end..]; } + // Commit replay state only after the whole UDP datagram has parsed + // successfully. A malformed trailing record must not consume + // replay state for an earlier authenticated record in the same datagram. + for sequence in replay_updates { + decrypt.replay_update(sequence); + } + let mut records = ArrayVec::new(); for record in parsed_records { if let Some(record) = decrypt.classify_record(record)? { @@ -158,6 +185,62 @@ impl Records { } } +fn pending_replay_check(pending_replay: &ArrayVec<(u16, ReplayWindow), 16>, seq: Sequence) -> bool { + match pending_replay.iter().find(|(epoch, _)| *epoch == seq.epoch) { + Some((_, window)) => window.check(seq.sequence_number), + None => true, + } +} + +fn pending_replay_update( + pending_replay: &mut ArrayVec<(u16, ReplayWindow), 16>, + seq: Sequence, +) -> Result<(), Error> { + if let Some((_, window)) = pending_replay + .iter_mut() + .find(|(epoch, _)| *epoch == seq.epoch) + { + window.update(seq.sequence_number); + return Ok(()); + } + + let mut window = ReplayWindow::new(); + window.update(seq.sequence_number); + pending_replay + .try_push((seq.epoch, window)) + .map_err(|_| Error::TooManyRecords) +} + +fn pending_expected_override( + pending_expected: &ArrayVec<(u16, u64), 16>, + epoch: u16, +) -> Option { + pending_expected + .iter() + .find(|(candidate_epoch, _)| *candidate_epoch == epoch) + .map(|(_, expected)| *expected) +} + +fn pending_expected_update( + pending_expected: &mut ArrayVec<(u16, u64), 16>, + seq: Sequence, +) -> Result<(), Error> { + let next = seq.sequence_number + 1; + if let Some((_, expected)) = pending_expected + .iter_mut() + .find(|(epoch, _)| *epoch == seq.epoch) + { + if next > *expected { + *expected = next; + } + return Ok(()); + } + + pending_expected + .try_push((seq.epoch, next)) + .map_err(|_| Error::TooManyRecords) +} + impl Deref for Records { type Target = [Record]; @@ -173,14 +256,20 @@ pub struct Record { parsed: Box, } +struct RecordParse { + record: Option, + replay_sequence: Option, +} + impl Record { /// The first parse pass only parses the record header which is unencrypted. /// Copies record data from UDP packet ONCE into a pooled buffer. - pub fn parse( + fn parse( record_slice: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result, Error> { + pending_expected: &ArrayVec<(u16, u64), 16>, + ) -> Result { // ONLY COPY: UDP packet slice -> pooled buffer let mut buffer = Buf::new(); buffer.extend_from_slice(record_slice); @@ -218,7 +307,10 @@ impl Record { Ok(p) => p, Err(e) => { trace!("Discarding record: parse failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } }; let parsed = Box::new(parsed); @@ -226,7 +318,10 @@ impl Record { // Plaintext records (epoch 0) are not encrypted if !is_ciphertext || !decrypt.is_peer_encryption_enabled() { - return Ok(Some(record)); + return Ok(RecordParse { + record: Some(record), + replay_sequence: None, + }); } // Resolve the full epoch from the 2-bit value in the unified header @@ -236,7 +331,12 @@ impl Record { // Resolve the full sequence number from the (now decrypted) partial value let seq_bits = record.record().sequence.sequence_number; let s_flag = record_slice[0] & 0b0000_1000 != 0; - let full_seq = decrypt.resolve_sequence(full_epoch, seq_bits, s_flag); + let full_seq = decrypt.resolve_sequence( + full_epoch, + seq_bits, + s_flag, + pending_expected_override(pending_expected, full_epoch), + ); let full_sequence = Sequence { epoch: full_epoch, @@ -245,7 +345,10 @@ impl Record { // Anti-replay check (read-only, does not update window) if !decrypt.replay_check(full_sequence) { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } // Save the raw header bytes for AAD before mutating the buffer. @@ -257,7 +360,10 @@ impl Record { // so decryption would necessarily fail. Catching it here keeps the // cipher impls' bounds-checking from being the only line of defence. if record.buffer.len() - header_end < decrypt.min_protected_fragment_len() { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } let mut header_buf = [0u8; 5]; header_buf[..header_end].copy_from_slice(&record.buffer[..header_end]); @@ -277,25 +383,26 @@ impl Record { Ok(()) => {} Err(e) => { trace!("Discarding ciphertext record: decryption failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } } buffer.len() }; - // Decryption succeeded — now commit the replay window update. - // RFC 9147 §4.5.1: "The window MUST NOT be updated due to a received - // record until that record has been deprotected successfully." - decrypt.replay_update(full_sequence); - // Recover inner content type from DTLSInnerPlaintext let decrypted = &buffer[header_end..header_end + new_len]; let (inner_content_type, content_len) = match recover_inner_content_type(decrypted) { Ok(v) => v, Err(e) => { trace!("Discarding record: invalid inner content type: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: Some(full_sequence), + }); } }; @@ -311,7 +418,10 @@ impl Record { ); let parsed = Box::new(parsed); - Ok(Some(Record { buffer, parsed })) + Ok(RecordParse { + record: Some(Record { buffer, parsed }), + replay_sequence: Some(full_sequence), + }) } pub fn record(&self) -> &Dtls13Record { @@ -407,7 +517,13 @@ pub trait RecordHandler { fn classify_record(&mut self, record: Record) -> Result, Error>; fn is_peer_encryption_enabled(&self) -> bool; fn resolve_epoch(&self, epoch_bits: u8) -> u16; - fn resolve_sequence(&self, epoch: u16, seq_bits: u64, s_flag: bool) -> u64; + fn resolve_sequence( + &self, + epoch: u16, + seq_bits: u64, + s_flag: bool, + expected_override: Option, + ) -> u64; fn replay_check(&self, seq: Sequence) -> bool; fn replay_update(&mut self, seq: Sequence); @@ -556,7 +672,13 @@ mod tests { panic!("resolve_epoch should not be called when peer encryption is disabled"); } - fn resolve_sequence(&self, _epoch: u16, _seq_bits: u64, _s_flag: bool) -> u64 { + fn resolve_sequence( + &self, + _epoch: u16, + _seq_bits: u64, + _s_flag: bool, + _expected_override: Option, + ) -> u64 { panic!("resolve_sequence should not be called when peer encryption is disabled"); } diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 7640007c..6dcd02b5 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -680,6 +680,78 @@ fn dtls12_bad_encrypted_prefix_does_not_drop_valid_tail() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_malformed_trailing_record_does_not_consume_replay_window() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_12_pair(now); + + client + .send_application_data(b"replay-atomic") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut malformed_packet = valid_packet.clone(); + malformed_packet.push(0xff); + + let err = server + .handle_packet(&malformed_packet) + .expect_err("trailing partial record must reject the datagram"); + assert!( + matches!(err, dimpl::Error::ParseIncomplete), + "expected ParseIncomplete, got {err:?}" + ); + + let server_out = drain_outputs(&mut server); + assert!( + server_out.app_data.is_empty(), + "malformed datagram must not deliver application data" + ); + + server + .handle_packet(&valid_packet) + .expect("valid packet must still pass replay checks"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"replay-atomic".to_vec()]); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_same_datagram_duplicate_encrypted_record_delivers_once() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_12_pair(now); + + client + .send_application_data(b"duplicate-once") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut duplicate_datagram = valid_packet.clone(); + duplicate_datagram.extend_from_slice(&valid_packet); + + server + .handle_packet(&duplicate_datagram) + .expect("duplicate datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"duplicate-once".to_vec()]); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_relabelled_encrypted_handshake_failure_is_not_silently_discarded() { @@ -767,6 +839,39 @@ fn dtls12_relabelled_encrypted_handshake_failure_is_not_silently_discarded() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_same_datagram_window_shift_drops_now_too_old_record() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_12_pair(now); + + let mut packets = Vec::new(); + for i in 0..66 { + client + .send_application_data(format!("msg-{i}").as_bytes()) + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let out = drain_outputs(&mut client); + let packet = out + .packets + .first() + .expect("application data packet") + .clone(); + packets.push(packet); + } + + let mut shifted_datagram = packets[65].clone(); + shifted_datagram.extend_from_slice(&packets[0]); + + server + .handle_packet(&shifted_datagram) + .expect("window-shift datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"msg-65".to_vec()]); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_app_data_after_close_notify_is_ignored() { diff --git a/tests/dtls13/edge.rs b/tests/dtls13/edge.rs index 1813b8ca..d7c36d90 100644 --- a/tests/dtls13/edge.rs +++ b/tests/dtls13/edge.rs @@ -1267,6 +1267,111 @@ fn dtls13_mixed_datagram_valid_first_then_bogus() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_malformed_trailing_record_does_not_consume_replay_window() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_13_pair(now); + + client + .send_application_data(b"replay-atomic") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut malformed_packet = valid_packet.clone(); + malformed_packet.push(0xff); + + let err = server + .handle_packet(&malformed_packet) + .expect_err("trailing partial record must reject the datagram"); + assert!( + matches!(err, dimpl::Error::ParseIncomplete), + "expected ParseIncomplete, got {err:?}" + ); + + let server_out = drain_outputs(&mut server); + assert!( + server_out.app_data.is_empty(), + "malformed datagram must not deliver application data" + ); + + server + .handle_packet(&valid_packet) + .expect("valid packet must still pass replay checks"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"replay-atomic".to_vec()]); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_same_datagram_duplicate_encrypted_record_delivers_once() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_13_pair(now); + + client + .send_application_data(b"duplicate-once") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut duplicate_datagram = valid_packet.clone(); + duplicate_datagram.extend_from_slice(&valid_packet); + + server + .handle_packet(&duplicate_datagram) + .expect("duplicate datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"duplicate-once".to_vec()]); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_same_datagram_window_shift_drops_now_too_old_record() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_13_pair(now); + + let mut packets = Vec::new(); + for i in 0..66 { + client + .send_application_data(format!("msg-{i}").as_bytes()) + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let out = drain_outputs(&mut client); + let packet = out + .packets + .first() + .expect("application data packet") + .clone(); + packets.push(packet); + } + + let mut shifted_datagram = packets[65].clone(); + shifted_datagram.extend_from_slice(&packets[0]); + + server + .handle_packet(&shifted_datagram) + .expect("window-shift datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"msg-65".to_vec()]); +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_half_close_send_then_close() { From a2ad7ad519df65cf9fa3efca706d0d100f0217d2 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Sun, 24 May 2026 08:43:48 +0300 Subject: [PATCH 2/2] dtls13: preserve app data after same-datagram KeyUpdate --- CHANGELOG.md | 1 + src/dtls13/client.rs | 3 + src/dtls13/engine.rs | 29 +++- src/dtls13/incoming.rs | 211 ++++++++++++++++-------- src/dtls13/server.rs | 3 + tests/dtls13/key_update.rs | 321 +++++++++++++++++++++++++++++++++++++ 6 files changed, 501 insertions(+), 67 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fc3ec96..f8743d0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased + * Preserve DTLS 1.3 app data following KeyUpdate in one datagram #122 * Fix malformed datagrams consuming DTLS replay-window state #121 * Replace pending DTLS 1.2 handshake output on resend #116 * Discard bad protected DTLS 1.2 records after handshake #115 diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index b18408da..3bfb88a1 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -220,6 +220,9 @@ impl Client { pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> { self.engine.parse_packet(packet)?; self.make_progress()?; + while self.engine.parse_next_deferred_packet()? { + self.make_progress()?; + } Ok(()) } diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index 67b44422..3c904a95 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -58,6 +58,10 @@ pub struct Engine { /// Queue of outgoing packets. queue_tx: QueueTx, + /// Deferred tail datagrams split after KeyUpdate records. They are parsed + /// after the state machine installs the next receive keys. + deferred_packets: ArrayVec, + /// The cipher suite in use. Set by ServerHello. cipher_suite: Option, @@ -221,6 +225,7 @@ impl Engine { sequence_epoch_0: Sequence::new(0), queue_rx: QueueRx::new(), queue_tx: QueueTx::new(), + deferred_packets: ArrayVec::new(), cipher_suite: None, hs_send_keys: None, hs_recv_keys: None, @@ -330,15 +335,35 @@ impl Engine { } pub fn parse_packet(&mut self, packet: &[u8]) -> Result<(), Error> { + self.parse_packet_once(packet) + } + + fn parse_packet_once(&mut self, packet: &[u8]) -> Result<(), Error> { let cs = self.cipher_suite; - let incoming = Incoming::parse_packet(packet, self, cs)?; - if let Some(incoming) = incoming { + let parsed = Incoming::parse_packet_defer_after_key_update(packet, self, cs)?; + if let Some(incoming) = parsed.incoming { self.insert_incoming(incoming)?; } + if let Some(deferred_tail) = parsed.deferred_tail { + self.deferred_packets + .try_push(deferred_tail) + .map_err(|_| Error::TooManyRecords)?; + } Ok(()) } + pub fn parse_next_deferred_packet(&mut self) -> Result { + if self.deferred_packets.is_empty() { + return Ok(false); + } + + let packet = self.deferred_packets.remove(0); + self.parse_packet_once(&packet)?; + + Ok(true) + } + fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> { if self.queue_rx.len() >= self.config.max_queue_rx() { warn!( diff --git a/src/dtls13/incoming.rs b/src/dtls13/incoming.rs index 1e83ecd1..604e2d72 100644 --- a/src/dtls13/incoming.rs +++ b/src/dtls13/incoming.rs @@ -6,7 +6,9 @@ use std::fmt; use crate::Error; use crate::buffer::{Buf, TmpBuf}; -use crate::dtls13::message::{ContentType, Dtls13CipherSuite, Dtls13Record, Handshake, Sequence}; +use crate::dtls13::message::Sequence; +use crate::dtls13::message::{ContentType, Dtls13CipherSuite}; +use crate::dtls13::message::{Dtls13Record, Handshake, MessageType}; use crate::window::ReplayWindow; /// Holds both the UDP packet and the parsed result of that packet. @@ -16,6 +18,11 @@ pub struct Incoming { records: Box, } +pub(crate) struct IncomingParse { + pub incoming: Option, + pub deferred_tail: Option, +} + impl Incoming { pub fn records(&self) -> &Records { &self.records @@ -33,30 +40,38 @@ impl Incoming { } impl Incoming { - /// Parse an incoming UDP packet - /// - /// * `packet` is the data from the UDP socket. - /// * `decrypt` provides the decryption operations for encrypted records. - /// * `cs` is the negotiated cipher suite, if any. - /// - /// Will surface parser errors. - pub fn parse_packet( + pub(crate) fn parse_packet_defer_after_key_update( + packet: &[u8], + decrypt: &mut dyn RecordHandler, + cs: Option, + ) -> Result { + Self::parse_packet_inner(packet, decrypt, cs, true) + } + + fn parse_packet_inner( packet: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result, Error> { + defer_after_key_update: bool, + ) -> Result { // Parse records directly from packet, copying each record ONCE into its own buffer - let records = Records::parse(packet, decrypt, cs)?; + let parse = Records::parse_inner(packet, decrypt, cs, defer_after_key_update)?; // We need at least one Record to be valid. For replayed frames, we discard // the records, hence this might be None - if records.records.is_empty() { - return Ok(None); + if parse.records.records.is_empty() { + return Ok(IncomingParse { + incoming: None, + deferred_tail: parse.deferred_tail, + }); } - let records = Box::new(records); + let records = Box::new(parse.records); - Ok(Some(Incoming { records })) + Ok(IncomingParse { + incoming: Some(Incoming { records }), + deferred_tail: parse.deferred_tail, + }) } } @@ -66,65 +81,28 @@ pub struct Records { pub records: ArrayVec, } +struct RecordsParse { + records: Records, + deferred_tail: Option, +} + impl Records { - pub fn parse( + fn parse_inner( mut packet: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result { + defer_after_key_update: bool, + ) -> Result { let mut parsed_records: ArrayVec = ArrayVec::new(); let mut replay_updates: ArrayVec = ArrayVec::new(); let mut pending_replay: ArrayVec<(u16, ReplayWindow), 16> = ArrayVec::new(); let mut pending_expected: ArrayVec<(u16, u64), 16> = ArrayVec::new(); + let mut deferred_tail = None; // Find record boundaries and copy each record ONCE from the packet while !packet.is_empty() { - let record_end = if Dtls13Record::is_ciphertext_header(packet[0]) { - // CID bit set means we can't determine record boundaries (unsupported). - // Discard the rest of the datagram. - if packet[0] & 0x10 != 0 { - break; - } - - // Unified header: variable length - if packet.len() < 2 { - return Err(Error::ParseIncomplete); - } - - let flags = packet[0]; - let s_flag = flags & 0b0000_1000 != 0; - let l_flag = flags & 0b0000_0100 != 0; - let seq_len = if s_flag { 2 } else { 1 }; - let len_len = if l_flag { 2 } else { 0 }; - let header_len = 1 + seq_len + len_len; - - if packet.len() < header_len { - return Err(Error::ParseIncomplete); - } - - if l_flag { - let len_offset = 1 + seq_len; - // unwrap: header_len check above ensures 2 bytes at len_offset - let length_bytes: [u8; 2] = - packet[len_offset..len_offset + 2].try_into().unwrap(); - let length = u16::from_be_bytes(length_bytes) as usize; - header_len + length - } else { - // No length field: record consumes the rest of the datagram - packet.len() - } - } else { - // Plaintext: fixed 13-byte header - if packet.len() < Dtls13Record::PLAINTEXT_HEADER_LEN { - return Err(Error::ParseIncomplete); - } - - // unwrap: PLAINTEXT_HEADER_LEN check above ensures 2 bytes at offset - let length_bytes: [u8; 2] = packet[Dtls13Record::PLAINTEXT_LENGTH_OFFSET] - .try_into() - .unwrap(); - let length = u16::from_be_bytes(length_bytes) as usize; - Dtls13Record::PLAINTEXT_HEADER_LEN + length + let Some(record_end) = record_end(packet)? else { + break; }; if packet.len() < record_end { @@ -133,8 +111,11 @@ impl Records { // This is the ONLY copy: packet -> record buffer let record_slice = &packet[..record_end]; + let tail = &packet[record_end..]; match Record::parse(record_slice, decrypt, cs, &pending_expected) { Ok(parsed) => { + let mut should_break_after_replay_update = false; + if let Some(sequence) = parsed.replay_sequence { if !pending_replay_check(&pending_replay, sequence) { trace!("Discarding duplicate rec in same datagram"); @@ -144,9 +125,21 @@ impl Records { } if let Some(record) = parsed.record { + let should_defer_tail = defer_after_key_update + && !tail.is_empty() + && record.contains_complete_key_update(); + if parsed_records.try_push(record).is_err() { return Err(Error::TooManyRecords); } + + should_break_after_replay_update = should_defer_tail; + if should_defer_tail { + validate_record_boundaries(tail, parsed_records.len())?; + let mut tail_buffer = Buf::new(); + tail_buffer.extend_from_slice(tail); + deferred_tail = Some(tail_buffer); + } } else if parsed.replay_sequence.is_none() { trace!("Discarding replayed rec"); } @@ -158,6 +151,10 @@ impl Records { return Err(Error::TooManyRecords); } } + + if should_break_after_replay_update { + break; + } } Err(e) => return Err(e), } @@ -181,10 +178,84 @@ impl Records { } } - Ok(Records { records }) + Ok(RecordsParse { + records: Records { records }, + deferred_tail, + }) + } +} + +fn record_end(packet: &[u8]) -> Result, Error> { + if Dtls13Record::is_ciphertext_header(packet[0]) { + // CID bit set means we can't determine record boundaries (unsupported). + // Discard the rest of the datagram. + if packet[0] & 0x10 != 0 { + return Ok(None); + } + + // Unified header: variable length + if packet.len() < 2 { + return Err(Error::ParseIncomplete); + } + + let flags = packet[0]; + let s_flag = flags & 0b0000_1000 != 0; + let l_flag = flags & 0b0000_0100 != 0; + let seq_len = if s_flag { 2 } else { 1 }; + let len_len = if l_flag { 2 } else { 0 }; + let header_len = 1 + seq_len + len_len; + + if packet.len() < header_len { + return Err(Error::ParseIncomplete); + } + + if l_flag { + let len_offset = 1 + seq_len; + // unwrap: header_len check above ensures 2 bytes at len_offset. + let length_bytes: [u8; 2] = packet[len_offset..len_offset + 2].try_into().unwrap(); + let length = u16::from_be_bytes(length_bytes) as usize; + Ok(Some(header_len + length)) + } else { + Ok(Some(packet.len())) + } + } else { + if packet.len() < Dtls13Record::PLAINTEXT_HEADER_LEN { + return Err(Error::ParseIncomplete); + } + + // unwrap: PLAINTEXT_HEADER_LEN check above ensures 2 bytes at offset. + let length_bytes: [u8; 2] = packet[Dtls13Record::PLAINTEXT_LENGTH_OFFSET] + .try_into() + .unwrap(); + let length = u16::from_be_bytes(length_bytes) as usize; + Ok(Some(Dtls13Record::PLAINTEXT_HEADER_LEN + length)) } } +fn validate_record_boundaries( + mut packet: &[u8], + mut parsed_record_count: usize, +) -> Result<(), Error> { + while !packet.is_empty() { + let Some(end) = record_end(packet)? else { + return Ok(()); + }; + + if parsed_record_count >= 16 { + return Err(Error::TooManyRecords); + } + parsed_record_count += 1; + + if packet.len() < end { + return Err(Error::ParseIncomplete); + } + + packet = &packet[end..]; + } + + Ok(()) +} + fn pending_replay_check(pending_replay: &ArrayVec<(u16, ReplayWindow), 16>, seq: Sequence) -> bool { match pending_replay.iter().find(|(epoch, _)| *epoch == seq.epoch) { Some((_, window)) => window.check(seq.sequence_number), @@ -436,6 +507,15 @@ impl Record { self.parsed.handshakes.first() } + fn contains_complete_key_update(&self) -> bool { + self.record().content_type == ContentType::Handshake + && self.handshakes().iter().any(|handshake| { + handshake.header.msg_type == MessageType::KeyUpdate + && handshake.header.fragment_offset == 0 + && handshake.header.fragment_length == handshake.header.length + }) + } + pub fn is_handled(&self) -> bool { if self.parsed.handshakes.is_empty() { self.parsed.handled.load(Ordering::Relaxed) @@ -743,8 +823,9 @@ mod tests { packet.extend_from_slice(&build_ciphertext_record(2, 2, &[0x11, 0x22, 0x33])); let mut handler = TestHandler::default(); - let incoming = Incoming::parse_packet(&packet, &mut handler, None) + let incoming = Incoming::parse_packet_defer_after_key_update(&packet, &mut handler, None) .unwrap() + .incoming .expect("ciphertext application data record should remain"); assert_eq!(handler.classify_calls, 2); diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 71f3dfd7..64f7e25a 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -249,6 +249,9 @@ impl Server { self.engine.parse_packet(packet)?; self.make_progress()?; + while self.engine.parse_next_deferred_packet()? { + self.make_progress()?; + } // Once past AwaitClientHello, DTLS 1.3 is committed — free the buffer. if self.auto_mode && self.state != State::AwaitClientHello { diff --git a/tests/dtls13/key_update.rs b/tests/dtls13/key_update.rs index d619f0f3..b9b8ccf9 100644 --- a/tests/dtls13/key_update.rs +++ b/tests/dtls13/key_update.rs @@ -197,6 +197,327 @@ fn dtls13_key_update_bidirectional_after_limit() { ); } +#[cfg(feature = "rcgen")] +fn assert_key_update_and_app_data_same_datagram( + sender: &mut Dtls, + receiver: &mut Dtls, + now: &mut Instant, + priming: &'static [u8], + post_update: &'static [u8], +) { + sender + .send_application_data(priming) + .expect("send priming app data"); + let first_packets = collect_packets(sender); + assert_eq!(first_packets.len(), 1); + + deliver_packets(&first_packets, receiver); + receiver.handle_timeout(*now).expect("receiver timeout"); + let first_received = drain_outputs(receiver); + assert_eq!(first_received.app_data, vec![priming.to_vec()]); + deliver_packets(&first_received.packets, sender); + + *now += Duration::from_millis(10); + sender.handle_timeout(*now).expect("sender timeout"); + let key_update_packets = collect_packets(sender); + assert_eq!(key_update_packets.len(), 1); + + sender + .send_application_data(post_update) + .expect("send post-key-update app data"); + let app_packets = collect_packets(sender); + assert_eq!(app_packets.len(), 1); + + let mut combined = key_update_packets[0].clone(); + combined.extend_from_slice(&app_packets[0]); + + receiver + .handle_packet(&combined) + .expect("receiver should accept combined datagram"); + receiver.handle_timeout(*now).expect("receiver timeout"); + + let received = drain_outputs(receiver); + assert_eq!( + received.app_data, + vec![post_update.to_vec()], + "receiver should deliver new-epoch app data that follows KeyUpdate in the same datagram" + ); +} + +#[cfg(feature = "rcgen")] +fn capture_key_update_and_app_data_packets( + sender: &mut Dtls, + receiver: &mut Dtls, + now: &mut Instant, + priming: &'static [u8], + post_update: &'static [u8], +) -> (Vec, Vec) { + sender + .send_application_data(priming) + .expect("send priming app data"); + let first_packets = collect_packets(sender); + assert_eq!(first_packets.len(), 1); + + deliver_packets(&first_packets, receiver); + receiver.handle_timeout(*now).expect("receiver timeout"); + let first_received = drain_outputs(receiver); + assert_eq!(first_received.app_data, vec![priming.to_vec()]); + deliver_packets(&first_received.packets, sender); + + *now += Duration::from_millis(10); + sender.handle_timeout(*now).expect("sender timeout"); + let key_update_packets = collect_packets(sender); + assert_eq!(key_update_packets.len(), 1); + + sender + .send_application_data(post_update) + .expect("send post-key-update app data"); + let app_packets = collect_packets(sender); + assert_eq!(app_packets.len(), 1); + + (key_update_packets[0].clone(), app_packets[0].clone()) +} + +#[cfg(feature = "rcgen")] +fn dtls13_ack_record_with_entry(seq: u64, ack_epoch: u64, ack_seq: u64) -> Vec { + let mut out = Vec::new(); + out.push(26); // Ack + out.extend_from_slice(&[0xFE, 0xFD]); // legacy DTLS record version + out.extend_from_slice(&0u16.to_be_bytes()); // epoch 0 plaintext + out.extend_from_slice(&seq.to_be_bytes()[2..]); // u48 sequence number + out.extend_from_slice(&18u16.to_be_bytes()); // record_numbers_len + one entry + out.extend_from_slice(&16u16.to_be_bytes()); // record_numbers_len + out.extend_from_slice(&ack_epoch.to_be_bytes()); + out.extend_from_slice(&ack_seq.to_be_bytes()); + out +} + +/// Test that application data following a KeyUpdate in the same datagram is +/// delivered. The sender emits the KeyUpdate under the old application epoch, +/// rotates send keys, and the datagram then carries application data under the +/// new epoch. The receiver must process the KeyUpdate before trying to decrypt +/// the following record. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_client_key_update_and_new_epoch_app_data_in_same_datagram() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_key_update_and_app_data_same_datagram( + &mut client, + &mut server, + &mut now, + b"client-primes-key-update", + b"client-same-datagram-new-epoch", + ); +} + +/// If the deferred tail after a KeyUpdate is structurally malformed, the whole +/// UDP datagram must be rejected before the KeyUpdate is acted on. This keeps +/// the DIMP-007 datagram-atomic replay/state invariant for malformed tails. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_key_update_with_malformed_same_datagram_tail_is_atomic() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let (key_update_packet, app_packet) = capture_key_update_and_app_data_packets( + &mut client, + &mut server, + &mut now, + b"client-primes-malformed-tail-key-update", + b"client-valid-tail-after-malformed-attempt", + ); + + let mut malformed = key_update_packet.clone(); + malformed.push(0xff); + + let err = server + .handle_packet(&malformed) + .expect_err("malformed tail must reject the full datagram"); + assert!( + matches!(err, dimpl::Error::ParseIncomplete), + "expected ParseIncomplete, got {err:?}" + ); + + let after_malformed = drain_outputs(&mut server); + assert!( + after_malformed.packets.is_empty() && after_malformed.app_data.is_empty(), + "malformed datagram must not ACK, advance, or deliver anything" + ); + + let mut valid = key_update_packet; + valid.extend_from_slice(&app_packet); + + server + .handle_packet(&valid) + .expect("valid retry must still pass replay checks"); + server.handle_timeout(now).expect("server timeout"); + + let received = drain_outputs(&mut server); + assert_eq!( + received.app_data, + vec![b"client-valid-tail-after-malformed-attempt".to_vec()] + ); +} + +/// Deferring a post-KeyUpdate tail must not reset the per-datagram record +/// budget. Otherwise a `KeyUpdate || 16 records` datagram would be accepted as +/// two separately-budgeted parses and the KeyUpdate would take effect before the +/// over-capacity tail is rejected. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_key_update_with_over_capacity_same_datagram_tail_is_atomic() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let (key_update_packet, app_packet) = capture_key_update_and_app_data_packets( + &mut client, + &mut server, + &mut now, + b"client-primes-over-capacity-tail-key-update", + b"client-valid-tail-after-over-capacity-attempt", + ); + + let mut over_capacity = key_update_packet.clone(); + for seq in 0..16 { + over_capacity.extend_from_slice(&dtls13_ack_record_with_entry(0x200 + seq, 3, seq)); + } + + let err = server + .handle_packet(&over_capacity) + .expect_err("over-capacity tail must reject the full datagram"); + assert!( + matches!(err, dimpl::Error::TooManyRecords), + "expected TooManyRecords, got {err:?}" + ); + + let after_over_capacity = drain_outputs(&mut server); + assert!( + after_over_capacity.packets.is_empty() && after_over_capacity.app_data.is_empty(), + "over-capacity datagram must not ACK, advance, or deliver anything" + ); + + let mut valid = key_update_packet; + valid.extend_from_slice(&app_packet); + + server + .handle_packet(&valid) + .expect("valid retry must still pass replay checks"); + server.handle_timeout(now).expect("server timeout"); + + let received = drain_outputs(&mut server); + assert_eq!( + received.app_data, + vec![b"client-valid-tail-after-over-capacity-attempt".to_vec()] + ); +} + +/// Same as the client-sender case, but with the server as the KeyUpdate sender +/// so the client deferred-tail path is covered too. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_server_key_update_and_new_epoch_app_data_in_same_datagram() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_key_update_and_app_data_same_datagram( + &mut server, + &mut client, + &mut now, + b"server-primes-key-update", + b"server-same-datagram-new-epoch", + ); +} + /// Test that a reordered packet captured before a KeyUpdate is accepted when /// delivered alongside other packets during the transition. The packet is from /// the same epoch and arrives before any new-epoch records, so the replay