From 1b56b2f561f539a816707b4f52ada5f6dd0a3eb1 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Sun, 24 May 2026 08:40:05 +0300 Subject: [PATCH 1/3] dtls: defer replay commits until datagram parse succeeds --- CHANGELOG.md | 1 + src/dtls12/incoming.rs | 88 +++++++++++++++++----- src/dtls13/engine.rs | 11 ++- src/dtls13/incoming.rs | 164 +++++++++++++++++++++++++++++++++++------ tests/dtls12/edge.rs | 105 ++++++++++++++++++++++++++ tests/dtls13/edge.rs | 105 ++++++++++++++++++++++++++ 6 files changed, 434 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 261bfb8a..5fc3ec96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased + * Fix malformed datagrams consuming DTLS replay-window state #121 * Replace pending DTLS 1.2 handshake output on resend #116 * Discard bad protected DTLS 1.2 records after handshake #115 * Reject oversized DTLS certificate lists #113 diff --git a/src/dtls12/incoming.rs b/src/dtls12/incoming.rs index 2c41ebbd..373c424b 100644 --- a/src/dtls12/incoming.rs +++ b/src/dtls12/incoming.rs @@ -8,6 +8,7 @@ use crate::Error; use crate::buffer::{Buf, TmpBuf}; use crate::crypto::{Aad, Nonce}; use crate::dtls12::message::{ContentType, DTLSRecord, Dtls12CipherSuite, Handshake, Sequence}; +use crate::window::ReplayWindow; /// Holds both the UDP packet and the parsed result of that packet. pub struct Incoming { @@ -67,12 +68,14 @@ pub struct Records { } impl Records { - pub fn parse( + fn parse( mut packet: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, ) -> Result { let mut parsed_records: ArrayVec = ArrayVec::new(); + let mut replay_updates: ArrayVec = ArrayVec::new(); + let mut pending_replay = ReplayWindow::new(); // Find record boundaries and copy each record ONCE from the packet while !packet.is_empty() { @@ -91,14 +94,29 @@ impl Records { // This is the ONLY copy: packet -> record buffer let record_slice = &packet[..record_end]; match Record::parse(record_slice, decrypt, cs) { - Ok(record) => { - if let Some(record) = record { + Ok(parsed) => { + if let Some(sequence) = parsed.replay_sequence { + if !pending_replay.check(sequence.sequence_number) { + trace!("Discarding duplicate rec in same datagram"); + packet = &packet[record_end..]; + continue; + } + } + + if let Some(record) = parsed.record { if parsed_records.try_push(record).is_err() { return Err(Error::TooManyRecords); } - } else { + } else if parsed.replay_sequence.is_none() { trace!("Discarding replayed rec"); } + + if let Some(sequence) = parsed.replay_sequence { + pending_replay.update(sequence.sequence_number); + if replay_updates.try_push(sequence).is_err() { + return Err(Error::TooManyRecords); + } + } } Err(e) => return Err(e), } @@ -106,6 +124,13 @@ impl Records { packet = &packet[record_end..]; } + // Commit replay state only after the whole UDP datagram has parsed + // successfully. A malformed trailing record must not consume + // replay state for an earlier authenticated record in the same datagram. + for sequence in replay_updates { + decrypt.replay_update(sequence); + } + let mut records = ArrayVec::new(); for record in parsed_records { if let Some(record) = decrypt.classify_record(record)? { @@ -134,14 +159,19 @@ pub struct Record { parsed: Box, } +struct RecordParse { + record: Option, + replay_sequence: Option, +} + impl Record { /// The first parse pass only parses the DTLSRecord header which is unencrypted. /// Copies record data from UDP packet ONCE into a pooled buffer. - pub fn parse( + fn parse( record_slice: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result, Error> { + ) -> Result { // ONLY COPY: UDP packet slice -> pooled buffer let mut buffer = Buf::new(); buffer.extend_from_slice(record_slice); @@ -151,7 +181,10 @@ impl Record { // RFC 6347 §4.1.2.7: Invalid records SHOULD be silently discarded. // This includes epoch 0 records with invalid ContentType. trace!("Discarding record: parse failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } }; let parsed = Box::new(parsed); @@ -162,7 +195,10 @@ impl Record { // packet loss, we can end up seeing epoch 1 records before we can decrypt them. let is_epoch_0 = record.record().sequence.epoch == 0; if is_epoch_0 || !decrypt.is_peer_encryption_enabled() { - return Ok(Some(record)); + return Ok(RecordParse { + record: Some(record), + replay_sequence: None, + }); } // We need to decrypt the record and redo the parsing. @@ -171,12 +207,18 @@ impl Record { // Anti-replay check (read-only, does not update window) if !decrypt.replay_check(sequence) { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } let explicit_nonce_len = decrypt.explicit_nonce_len(); if (dtls.length as usize) < decrypt.min_protected_fragment_len() { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } // Get a reference to the buffer @@ -203,25 +245,35 @@ impl Record { } trace!("Discarding record: decrypt failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } buffer.len() }; - // Decryption succeeded — now commit the replay window update. - // RFC 6347 §4.1.2.6: "The receive window is updated only if the - // MAC verification succeeds." - decrypt.replay_update(sequence); - // Update the length of the record. buffer[11] = (new_len >> 8) as u8; buffer[12] = new_len as u8; - let parsed = ParsedRecord::parse(&buffer, cs, explicit_nonce_len)?; + let parsed = match ParsedRecord::parse(&buffer, cs, explicit_nonce_len) { + Ok(parsed) => parsed, + Err(e) => { + trace!("Discarding authenticated record: parse failed: {}", e); + return Ok(RecordParse { + record: None, + replay_sequence: Some(sequence), + }); + } + }; let parsed = Box::new(parsed); - Ok(Some(Record { buffer, parsed })) + Ok(RecordParse { + record: Some(Record { buffer, parsed }), + replay_sequence: Some(sequence), + }) } pub fn record(&self) -> &DTLSRecord { diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index 984b06b5..67b44422 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -2338,7 +2338,13 @@ impl RecordHandler for Engine { epoch_bits } - fn resolve_sequence(&self, epoch: u16, seq_bits: u64, s_flag: bool) -> u64 { + fn resolve_sequence( + &self, + epoch: u16, + seq_bits: u64, + s_flag: bool, + expected_override: Option, + ) -> u64 { let expected = if epoch == 2 { self.hs_expected_recv_seq } else { @@ -2348,6 +2354,9 @@ impl RecordHandler for Engine { .map(|e| e.expected_recv_seq) .unwrap_or(0) }; + let expected = expected_override + .map(|override_expected| expected.max(override_expected)) + .unwrap_or(expected); let bits: u32 = if s_flag { 16 } else { 8 }; reconstruct_sequence(seq_bits, expected, bits) diff --git a/src/dtls13/incoming.rs b/src/dtls13/incoming.rs index 240cc830..1e83ecd1 100644 --- a/src/dtls13/incoming.rs +++ b/src/dtls13/incoming.rs @@ -7,6 +7,7 @@ use std::fmt; use crate::Error; use crate::buffer::{Buf, TmpBuf}; use crate::dtls13::message::{ContentType, Dtls13CipherSuite, Dtls13Record, Handshake, Sequence}; +use crate::window::ReplayWindow; /// Holds both the UDP packet and the parsed result of that packet. pub struct Incoming { @@ -72,6 +73,9 @@ impl Records { cs: Option, ) -> Result { let mut parsed_records: ArrayVec = ArrayVec::new(); + let mut replay_updates: ArrayVec = ArrayVec::new(); + let mut pending_replay: ArrayVec<(u16, ReplayWindow), 16> = ArrayVec::new(); + let mut pending_expected: ArrayVec<(u16, u64), 16> = ArrayVec::new(); // Find record boundaries and copy each record ONCE from the packet while !packet.is_empty() { @@ -129,15 +133,31 @@ impl Records { // This is the ONLY copy: packet -> record buffer let record_slice = &packet[..record_end]; - match Record::parse(record_slice, decrypt, cs) { - Ok(record) => { - if let Some(record) = record { + match Record::parse(record_slice, decrypt, cs, &pending_expected) { + Ok(parsed) => { + if let Some(sequence) = parsed.replay_sequence { + if !pending_replay_check(&pending_replay, sequence) { + trace!("Discarding duplicate rec in same datagram"); + packet = &packet[record_end..]; + continue; + } + } + + if let Some(record) = parsed.record { if parsed_records.try_push(record).is_err() { return Err(Error::TooManyRecords); } - } else { + } else if parsed.replay_sequence.is_none() { trace!("Discarding replayed rec"); } + + if let Some(sequence) = parsed.replay_sequence { + pending_replay_update(&mut pending_replay, sequence)?; + pending_expected_update(&mut pending_expected, sequence)?; + if replay_updates.try_push(sequence).is_err() { + return Err(Error::TooManyRecords); + } + } } Err(e) => return Err(e), } @@ -145,6 +165,13 @@ impl Records { packet = &packet[record_end..]; } + // Commit replay state only after the whole UDP datagram has parsed + // successfully. A malformed trailing record must not consume + // replay state for an earlier authenticated record in the same datagram. + for sequence in replay_updates { + decrypt.replay_update(sequence); + } + let mut records = ArrayVec::new(); for record in parsed_records { if let Some(record) = decrypt.classify_record(record)? { @@ -158,6 +185,62 @@ impl Records { } } +fn pending_replay_check(pending_replay: &ArrayVec<(u16, ReplayWindow), 16>, seq: Sequence) -> bool { + match pending_replay.iter().find(|(epoch, _)| *epoch == seq.epoch) { + Some((_, window)) => window.check(seq.sequence_number), + None => true, + } +} + +fn pending_replay_update( + pending_replay: &mut ArrayVec<(u16, ReplayWindow), 16>, + seq: Sequence, +) -> Result<(), Error> { + if let Some((_, window)) = pending_replay + .iter_mut() + .find(|(epoch, _)| *epoch == seq.epoch) + { + window.update(seq.sequence_number); + return Ok(()); + } + + let mut window = ReplayWindow::new(); + window.update(seq.sequence_number); + pending_replay + .try_push((seq.epoch, window)) + .map_err(|_| Error::TooManyRecords) +} + +fn pending_expected_override( + pending_expected: &ArrayVec<(u16, u64), 16>, + epoch: u16, +) -> Option { + pending_expected + .iter() + .find(|(candidate_epoch, _)| *candidate_epoch == epoch) + .map(|(_, expected)| *expected) +} + +fn pending_expected_update( + pending_expected: &mut ArrayVec<(u16, u64), 16>, + seq: Sequence, +) -> Result<(), Error> { + let next = seq.sequence_number + 1; + if let Some((_, expected)) = pending_expected + .iter_mut() + .find(|(epoch, _)| *epoch == seq.epoch) + { + if next > *expected { + *expected = next; + } + return Ok(()); + } + + pending_expected + .try_push((seq.epoch, next)) + .map_err(|_| Error::TooManyRecords) +} + impl Deref for Records { type Target = [Record]; @@ -173,14 +256,20 @@ pub struct Record { parsed: Box, } +struct RecordParse { + record: Option, + replay_sequence: Option, +} + impl Record { /// The first parse pass only parses the record header which is unencrypted. /// Copies record data from UDP packet ONCE into a pooled buffer. - pub fn parse( + fn parse( record_slice: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result, Error> { + pending_expected: &ArrayVec<(u16, u64), 16>, + ) -> Result { // ONLY COPY: UDP packet slice -> pooled buffer let mut buffer = Buf::new(); buffer.extend_from_slice(record_slice); @@ -218,7 +307,10 @@ impl Record { Ok(p) => p, Err(e) => { trace!("Discarding record: parse failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } }; let parsed = Box::new(parsed); @@ -226,7 +318,10 @@ impl Record { // Plaintext records (epoch 0) are not encrypted if !is_ciphertext || !decrypt.is_peer_encryption_enabled() { - return Ok(Some(record)); + return Ok(RecordParse { + record: Some(record), + replay_sequence: None, + }); } // Resolve the full epoch from the 2-bit value in the unified header @@ -236,7 +331,12 @@ impl Record { // Resolve the full sequence number from the (now decrypted) partial value let seq_bits = record.record().sequence.sequence_number; let s_flag = record_slice[0] & 0b0000_1000 != 0; - let full_seq = decrypt.resolve_sequence(full_epoch, seq_bits, s_flag); + let full_seq = decrypt.resolve_sequence( + full_epoch, + seq_bits, + s_flag, + pending_expected_override(pending_expected, full_epoch), + ); let full_sequence = Sequence { epoch: full_epoch, @@ -245,7 +345,10 @@ impl Record { // Anti-replay check (read-only, does not update window) if !decrypt.replay_check(full_sequence) { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } // Save the raw header bytes for AAD before mutating the buffer. @@ -257,7 +360,10 @@ impl Record { // so decryption would necessarily fail. Catching it here keeps the // cipher impls' bounds-checking from being the only line of defence. if record.buffer.len() - header_end < decrypt.min_protected_fragment_len() { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } let mut header_buf = [0u8; 5]; header_buf[..header_end].copy_from_slice(&record.buffer[..header_end]); @@ -277,25 +383,26 @@ impl Record { Ok(()) => {} Err(e) => { trace!("Discarding ciphertext record: decryption failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } } buffer.len() }; - // Decryption succeeded — now commit the replay window update. - // RFC 9147 §4.5.1: "The window MUST NOT be updated due to a received - // record until that record has been deprotected successfully." - decrypt.replay_update(full_sequence); - // Recover inner content type from DTLSInnerPlaintext let decrypted = &buffer[header_end..header_end + new_len]; let (inner_content_type, content_len) = match recover_inner_content_type(decrypted) { Ok(v) => v, Err(e) => { trace!("Discarding record: invalid inner content type: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: Some(full_sequence), + }); } }; @@ -311,7 +418,10 @@ impl Record { ); let parsed = Box::new(parsed); - Ok(Some(Record { buffer, parsed })) + Ok(RecordParse { + record: Some(Record { buffer, parsed }), + replay_sequence: Some(full_sequence), + }) } pub fn record(&self) -> &Dtls13Record { @@ -407,7 +517,13 @@ pub trait RecordHandler { fn classify_record(&mut self, record: Record) -> Result, Error>; fn is_peer_encryption_enabled(&self) -> bool; fn resolve_epoch(&self, epoch_bits: u8) -> u16; - fn resolve_sequence(&self, epoch: u16, seq_bits: u64, s_flag: bool) -> u64; + fn resolve_sequence( + &self, + epoch: u16, + seq_bits: u64, + s_flag: bool, + expected_override: Option, + ) -> u64; fn replay_check(&self, seq: Sequence) -> bool; fn replay_update(&mut self, seq: Sequence); @@ -556,7 +672,13 @@ mod tests { panic!("resolve_epoch should not be called when peer encryption is disabled"); } - fn resolve_sequence(&self, _epoch: u16, _seq_bits: u64, _s_flag: bool) -> u64 { + fn resolve_sequence( + &self, + _epoch: u16, + _seq_bits: u64, + _s_flag: bool, + _expected_override: Option, + ) -> u64 { panic!("resolve_sequence should not be called when peer encryption is disabled"); } diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 7640007c..6dcd02b5 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -680,6 +680,78 @@ fn dtls12_bad_encrypted_prefix_does_not_drop_valid_tail() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_malformed_trailing_record_does_not_consume_replay_window() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_12_pair(now); + + client + .send_application_data(b"replay-atomic") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut malformed_packet = valid_packet.clone(); + malformed_packet.push(0xff); + + let err = server + .handle_packet(&malformed_packet) + .expect_err("trailing partial record must reject the datagram"); + assert!( + matches!(err, dimpl::Error::ParseIncomplete), + "expected ParseIncomplete, got {err:?}" + ); + + let server_out = drain_outputs(&mut server); + assert!( + server_out.app_data.is_empty(), + "malformed datagram must not deliver application data" + ); + + server + .handle_packet(&valid_packet) + .expect("valid packet must still pass replay checks"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"replay-atomic".to_vec()]); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_same_datagram_duplicate_encrypted_record_delivers_once() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_12_pair(now); + + client + .send_application_data(b"duplicate-once") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut duplicate_datagram = valid_packet.clone(); + duplicate_datagram.extend_from_slice(&valid_packet); + + server + .handle_packet(&duplicate_datagram) + .expect("duplicate datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"duplicate-once".to_vec()]); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_relabelled_encrypted_handshake_failure_is_not_silently_discarded() { @@ -767,6 +839,39 @@ fn dtls12_relabelled_encrypted_handshake_failure_is_not_silently_discarded() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_same_datagram_window_shift_drops_now_too_old_record() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_12_pair(now); + + let mut packets = Vec::new(); + for i in 0..66 { + client + .send_application_data(format!("msg-{i}").as_bytes()) + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let out = drain_outputs(&mut client); + let packet = out + .packets + .first() + .expect("application data packet") + .clone(); + packets.push(packet); + } + + let mut shifted_datagram = packets[65].clone(); + shifted_datagram.extend_from_slice(&packets[0]); + + server + .handle_packet(&shifted_datagram) + .expect("window-shift datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"msg-65".to_vec()]); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_app_data_after_close_notify_is_ignored() { diff --git a/tests/dtls13/edge.rs b/tests/dtls13/edge.rs index 1813b8ca..d7c36d90 100644 --- a/tests/dtls13/edge.rs +++ b/tests/dtls13/edge.rs @@ -1267,6 +1267,111 @@ fn dtls13_mixed_datagram_valid_first_then_bogus() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_malformed_trailing_record_does_not_consume_replay_window() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_13_pair(now); + + client + .send_application_data(b"replay-atomic") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut malformed_packet = valid_packet.clone(); + malformed_packet.push(0xff); + + let err = server + .handle_packet(&malformed_packet) + .expect_err("trailing partial record must reject the datagram"); + assert!( + matches!(err, dimpl::Error::ParseIncomplete), + "expected ParseIncomplete, got {err:?}" + ); + + let server_out = drain_outputs(&mut server); + assert!( + server_out.app_data.is_empty(), + "malformed datagram must not deliver application data" + ); + + server + .handle_packet(&valid_packet) + .expect("valid packet must still pass replay checks"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"replay-atomic".to_vec()]); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_same_datagram_duplicate_encrypted_record_delivers_once() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_13_pair(now); + + client + .send_application_data(b"duplicate-once") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut duplicate_datagram = valid_packet.clone(); + duplicate_datagram.extend_from_slice(&valid_packet); + + server + .handle_packet(&duplicate_datagram) + .expect("duplicate datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"duplicate-once".to_vec()]); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_same_datagram_window_shift_drops_now_too_old_record() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_13_pair(now); + + let mut packets = Vec::new(); + for i in 0..66 { + client + .send_application_data(format!("msg-{i}").as_bytes()) + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let out = drain_outputs(&mut client); + let packet = out + .packets + .first() + .expect("application data packet") + .clone(); + packets.push(packet); + } + + let mut shifted_datagram = packets[65].clone(); + shifted_datagram.extend_from_slice(&packets[0]); + + server + .handle_packet(&shifted_datagram) + .expect("window-shift datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"msg-65".to_vec()]); +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_half_close_send_then_close() { From a2ad7ad519df65cf9fa3efca706d0d100f0217d2 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Sun, 24 May 2026 08:43:48 +0300 Subject: [PATCH 2/3] dtls13: preserve app data after same-datagram KeyUpdate --- CHANGELOG.md | 1 + src/dtls13/client.rs | 3 + src/dtls13/engine.rs | 29 +++- src/dtls13/incoming.rs | 211 ++++++++++++++++-------- src/dtls13/server.rs | 3 + tests/dtls13/key_update.rs | 321 +++++++++++++++++++++++++++++++++++++ 6 files changed, 501 insertions(+), 67 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fc3ec96..f8743d0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased + * Preserve DTLS 1.3 app data following KeyUpdate in one datagram #122 * Fix malformed datagrams consuming DTLS replay-window state #121 * Replace pending DTLS 1.2 handshake output on resend #116 * Discard bad protected DTLS 1.2 records after handshake #115 diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index b18408da..3bfb88a1 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -220,6 +220,9 @@ impl Client { pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> { self.engine.parse_packet(packet)?; self.make_progress()?; + while self.engine.parse_next_deferred_packet()? { + self.make_progress()?; + } Ok(()) } diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index 67b44422..3c904a95 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -58,6 +58,10 @@ pub struct Engine { /// Queue of outgoing packets. queue_tx: QueueTx, + /// Deferred tail datagrams split after KeyUpdate records. They are parsed + /// after the state machine installs the next receive keys. + deferred_packets: ArrayVec, + /// The cipher suite in use. Set by ServerHello. cipher_suite: Option, @@ -221,6 +225,7 @@ impl Engine { sequence_epoch_0: Sequence::new(0), queue_rx: QueueRx::new(), queue_tx: QueueTx::new(), + deferred_packets: ArrayVec::new(), cipher_suite: None, hs_send_keys: None, hs_recv_keys: None, @@ -330,15 +335,35 @@ impl Engine { } pub fn parse_packet(&mut self, packet: &[u8]) -> Result<(), Error> { + self.parse_packet_once(packet) + } + + fn parse_packet_once(&mut self, packet: &[u8]) -> Result<(), Error> { let cs = self.cipher_suite; - let incoming = Incoming::parse_packet(packet, self, cs)?; - if let Some(incoming) = incoming { + let parsed = Incoming::parse_packet_defer_after_key_update(packet, self, cs)?; + if let Some(incoming) = parsed.incoming { self.insert_incoming(incoming)?; } + if let Some(deferred_tail) = parsed.deferred_tail { + self.deferred_packets + .try_push(deferred_tail) + .map_err(|_| Error::TooManyRecords)?; + } Ok(()) } + pub fn parse_next_deferred_packet(&mut self) -> Result { + if self.deferred_packets.is_empty() { + return Ok(false); + } + + let packet = self.deferred_packets.remove(0); + self.parse_packet_once(&packet)?; + + Ok(true) + } + fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> { if self.queue_rx.len() >= self.config.max_queue_rx() { warn!( diff --git a/src/dtls13/incoming.rs b/src/dtls13/incoming.rs index 1e83ecd1..604e2d72 100644 --- a/src/dtls13/incoming.rs +++ b/src/dtls13/incoming.rs @@ -6,7 +6,9 @@ use std::fmt; use crate::Error; use crate::buffer::{Buf, TmpBuf}; -use crate::dtls13::message::{ContentType, Dtls13CipherSuite, Dtls13Record, Handshake, Sequence}; +use crate::dtls13::message::Sequence; +use crate::dtls13::message::{ContentType, Dtls13CipherSuite}; +use crate::dtls13::message::{Dtls13Record, Handshake, MessageType}; use crate::window::ReplayWindow; /// Holds both the UDP packet and the parsed result of that packet. @@ -16,6 +18,11 @@ pub struct Incoming { records: Box, } +pub(crate) struct IncomingParse { + pub incoming: Option, + pub deferred_tail: Option, +} + impl Incoming { pub fn records(&self) -> &Records { &self.records @@ -33,30 +40,38 @@ impl Incoming { } impl Incoming { - /// Parse an incoming UDP packet - /// - /// * `packet` is the data from the UDP socket. - /// * `decrypt` provides the decryption operations for encrypted records. - /// * `cs` is the negotiated cipher suite, if any. - /// - /// Will surface parser errors. - pub fn parse_packet( + pub(crate) fn parse_packet_defer_after_key_update( + packet: &[u8], + decrypt: &mut dyn RecordHandler, + cs: Option, + ) -> Result { + Self::parse_packet_inner(packet, decrypt, cs, true) + } + + fn parse_packet_inner( packet: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result, Error> { + defer_after_key_update: bool, + ) -> Result { // Parse records directly from packet, copying each record ONCE into its own buffer - let records = Records::parse(packet, decrypt, cs)?; + let parse = Records::parse_inner(packet, decrypt, cs, defer_after_key_update)?; // We need at least one Record to be valid. For replayed frames, we discard // the records, hence this might be None - if records.records.is_empty() { - return Ok(None); + if parse.records.records.is_empty() { + return Ok(IncomingParse { + incoming: None, + deferred_tail: parse.deferred_tail, + }); } - let records = Box::new(records); + let records = Box::new(parse.records); - Ok(Some(Incoming { records })) + Ok(IncomingParse { + incoming: Some(Incoming { records }), + deferred_tail: parse.deferred_tail, + }) } } @@ -66,65 +81,28 @@ pub struct Records { pub records: ArrayVec, } +struct RecordsParse { + records: Records, + deferred_tail: Option, +} + impl Records { - pub fn parse( + fn parse_inner( mut packet: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result { + defer_after_key_update: bool, + ) -> Result { let mut parsed_records: ArrayVec = ArrayVec::new(); let mut replay_updates: ArrayVec = ArrayVec::new(); let mut pending_replay: ArrayVec<(u16, ReplayWindow), 16> = ArrayVec::new(); let mut pending_expected: ArrayVec<(u16, u64), 16> = ArrayVec::new(); + let mut deferred_tail = None; // Find record boundaries and copy each record ONCE from the packet while !packet.is_empty() { - let record_end = if Dtls13Record::is_ciphertext_header(packet[0]) { - // CID bit set means we can't determine record boundaries (unsupported). - // Discard the rest of the datagram. - if packet[0] & 0x10 != 0 { - break; - } - - // Unified header: variable length - if packet.len() < 2 { - return Err(Error::ParseIncomplete); - } - - let flags = packet[0]; - let s_flag = flags & 0b0000_1000 != 0; - let l_flag = flags & 0b0000_0100 != 0; - let seq_len = if s_flag { 2 } else { 1 }; - let len_len = if l_flag { 2 } else { 0 }; - let header_len = 1 + seq_len + len_len; - - if packet.len() < header_len { - return Err(Error::ParseIncomplete); - } - - if l_flag { - let len_offset = 1 + seq_len; - // unwrap: header_len check above ensures 2 bytes at len_offset - let length_bytes: [u8; 2] = - packet[len_offset..len_offset + 2].try_into().unwrap(); - let length = u16::from_be_bytes(length_bytes) as usize; - header_len + length - } else { - // No length field: record consumes the rest of the datagram - packet.len() - } - } else { - // Plaintext: fixed 13-byte header - if packet.len() < Dtls13Record::PLAINTEXT_HEADER_LEN { - return Err(Error::ParseIncomplete); - } - - // unwrap: PLAINTEXT_HEADER_LEN check above ensures 2 bytes at offset - let length_bytes: [u8; 2] = packet[Dtls13Record::PLAINTEXT_LENGTH_OFFSET] - .try_into() - .unwrap(); - let length = u16::from_be_bytes(length_bytes) as usize; - Dtls13Record::PLAINTEXT_HEADER_LEN + length + let Some(record_end) = record_end(packet)? else { + break; }; if packet.len() < record_end { @@ -133,8 +111,11 @@ impl Records { // This is the ONLY copy: packet -> record buffer let record_slice = &packet[..record_end]; + let tail = &packet[record_end..]; match Record::parse(record_slice, decrypt, cs, &pending_expected) { Ok(parsed) => { + let mut should_break_after_replay_update = false; + if let Some(sequence) = parsed.replay_sequence { if !pending_replay_check(&pending_replay, sequence) { trace!("Discarding duplicate rec in same datagram"); @@ -144,9 +125,21 @@ impl Records { } if let Some(record) = parsed.record { + let should_defer_tail = defer_after_key_update + && !tail.is_empty() + && record.contains_complete_key_update(); + if parsed_records.try_push(record).is_err() { return Err(Error::TooManyRecords); } + + should_break_after_replay_update = should_defer_tail; + if should_defer_tail { + validate_record_boundaries(tail, parsed_records.len())?; + let mut tail_buffer = Buf::new(); + tail_buffer.extend_from_slice(tail); + deferred_tail = Some(tail_buffer); + } } else if parsed.replay_sequence.is_none() { trace!("Discarding replayed rec"); } @@ -158,6 +151,10 @@ impl Records { return Err(Error::TooManyRecords); } } + + if should_break_after_replay_update { + break; + } } Err(e) => return Err(e), } @@ -181,10 +178,84 @@ impl Records { } } - Ok(Records { records }) + Ok(RecordsParse { + records: Records { records }, + deferred_tail, + }) + } +} + +fn record_end(packet: &[u8]) -> Result, Error> { + if Dtls13Record::is_ciphertext_header(packet[0]) { + // CID bit set means we can't determine record boundaries (unsupported). + // Discard the rest of the datagram. + if packet[0] & 0x10 != 0 { + return Ok(None); + } + + // Unified header: variable length + if packet.len() < 2 { + return Err(Error::ParseIncomplete); + } + + let flags = packet[0]; + let s_flag = flags & 0b0000_1000 != 0; + let l_flag = flags & 0b0000_0100 != 0; + let seq_len = if s_flag { 2 } else { 1 }; + let len_len = if l_flag { 2 } else { 0 }; + let header_len = 1 + seq_len + len_len; + + if packet.len() < header_len { + return Err(Error::ParseIncomplete); + } + + if l_flag { + let len_offset = 1 + seq_len; + // unwrap: header_len check above ensures 2 bytes at len_offset. + let length_bytes: [u8; 2] = packet[len_offset..len_offset + 2].try_into().unwrap(); + let length = u16::from_be_bytes(length_bytes) as usize; + Ok(Some(header_len + length)) + } else { + Ok(Some(packet.len())) + } + } else { + if packet.len() < Dtls13Record::PLAINTEXT_HEADER_LEN { + return Err(Error::ParseIncomplete); + } + + // unwrap: PLAINTEXT_HEADER_LEN check above ensures 2 bytes at offset. + let length_bytes: [u8; 2] = packet[Dtls13Record::PLAINTEXT_LENGTH_OFFSET] + .try_into() + .unwrap(); + let length = u16::from_be_bytes(length_bytes) as usize; + Ok(Some(Dtls13Record::PLAINTEXT_HEADER_LEN + length)) } } +fn validate_record_boundaries( + mut packet: &[u8], + mut parsed_record_count: usize, +) -> Result<(), Error> { + while !packet.is_empty() { + let Some(end) = record_end(packet)? else { + return Ok(()); + }; + + if parsed_record_count >= 16 { + return Err(Error::TooManyRecords); + } + parsed_record_count += 1; + + if packet.len() < end { + return Err(Error::ParseIncomplete); + } + + packet = &packet[end..]; + } + + Ok(()) +} + fn pending_replay_check(pending_replay: &ArrayVec<(u16, ReplayWindow), 16>, seq: Sequence) -> bool { match pending_replay.iter().find(|(epoch, _)| *epoch == seq.epoch) { Some((_, window)) => window.check(seq.sequence_number), @@ -436,6 +507,15 @@ impl Record { self.parsed.handshakes.first() } + fn contains_complete_key_update(&self) -> bool { + self.record().content_type == ContentType::Handshake + && self.handshakes().iter().any(|handshake| { + handshake.header.msg_type == MessageType::KeyUpdate + && handshake.header.fragment_offset == 0 + && handshake.header.fragment_length == handshake.header.length + }) + } + pub fn is_handled(&self) -> bool { if self.parsed.handshakes.is_empty() { self.parsed.handled.load(Ordering::Relaxed) @@ -743,8 +823,9 @@ mod tests { packet.extend_from_slice(&build_ciphertext_record(2, 2, &[0x11, 0x22, 0x33])); let mut handler = TestHandler::default(); - let incoming = Incoming::parse_packet(&packet, &mut handler, None) + let incoming = Incoming::parse_packet_defer_after_key_update(&packet, &mut handler, None) .unwrap() + .incoming .expect("ciphertext application data record should remain"); assert_eq!(handler.classify_calls, 2); diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 71f3dfd7..64f7e25a 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -249,6 +249,9 @@ impl Server { self.engine.parse_packet(packet)?; self.make_progress()?; + while self.engine.parse_next_deferred_packet()? { + self.make_progress()?; + } // Once past AwaitClientHello, DTLS 1.3 is committed — free the buffer. if self.auto_mode && self.state != State::AwaitClientHello { diff --git a/tests/dtls13/key_update.rs b/tests/dtls13/key_update.rs index d619f0f3..b9b8ccf9 100644 --- a/tests/dtls13/key_update.rs +++ b/tests/dtls13/key_update.rs @@ -197,6 +197,327 @@ fn dtls13_key_update_bidirectional_after_limit() { ); } +#[cfg(feature = "rcgen")] +fn assert_key_update_and_app_data_same_datagram( + sender: &mut Dtls, + receiver: &mut Dtls, + now: &mut Instant, + priming: &'static [u8], + post_update: &'static [u8], +) { + sender + .send_application_data(priming) + .expect("send priming app data"); + let first_packets = collect_packets(sender); + assert_eq!(first_packets.len(), 1); + + deliver_packets(&first_packets, receiver); + receiver.handle_timeout(*now).expect("receiver timeout"); + let first_received = drain_outputs(receiver); + assert_eq!(first_received.app_data, vec![priming.to_vec()]); + deliver_packets(&first_received.packets, sender); + + *now += Duration::from_millis(10); + sender.handle_timeout(*now).expect("sender timeout"); + let key_update_packets = collect_packets(sender); + assert_eq!(key_update_packets.len(), 1); + + sender + .send_application_data(post_update) + .expect("send post-key-update app data"); + let app_packets = collect_packets(sender); + assert_eq!(app_packets.len(), 1); + + let mut combined = key_update_packets[0].clone(); + combined.extend_from_slice(&app_packets[0]); + + receiver + .handle_packet(&combined) + .expect("receiver should accept combined datagram"); + receiver.handle_timeout(*now).expect("receiver timeout"); + + let received = drain_outputs(receiver); + assert_eq!( + received.app_data, + vec![post_update.to_vec()], + "receiver should deliver new-epoch app data that follows KeyUpdate in the same datagram" + ); +} + +#[cfg(feature = "rcgen")] +fn capture_key_update_and_app_data_packets( + sender: &mut Dtls, + receiver: &mut Dtls, + now: &mut Instant, + priming: &'static [u8], + post_update: &'static [u8], +) -> (Vec, Vec) { + sender + .send_application_data(priming) + .expect("send priming app data"); + let first_packets = collect_packets(sender); + assert_eq!(first_packets.len(), 1); + + deliver_packets(&first_packets, receiver); + receiver.handle_timeout(*now).expect("receiver timeout"); + let first_received = drain_outputs(receiver); + assert_eq!(first_received.app_data, vec![priming.to_vec()]); + deliver_packets(&first_received.packets, sender); + + *now += Duration::from_millis(10); + sender.handle_timeout(*now).expect("sender timeout"); + let key_update_packets = collect_packets(sender); + assert_eq!(key_update_packets.len(), 1); + + sender + .send_application_data(post_update) + .expect("send post-key-update app data"); + let app_packets = collect_packets(sender); + assert_eq!(app_packets.len(), 1); + + (key_update_packets[0].clone(), app_packets[0].clone()) +} + +#[cfg(feature = "rcgen")] +fn dtls13_ack_record_with_entry(seq: u64, ack_epoch: u64, ack_seq: u64) -> Vec { + let mut out = Vec::new(); + out.push(26); // Ack + out.extend_from_slice(&[0xFE, 0xFD]); // legacy DTLS record version + out.extend_from_slice(&0u16.to_be_bytes()); // epoch 0 plaintext + out.extend_from_slice(&seq.to_be_bytes()[2..]); // u48 sequence number + out.extend_from_slice(&18u16.to_be_bytes()); // record_numbers_len + one entry + out.extend_from_slice(&16u16.to_be_bytes()); // record_numbers_len + out.extend_from_slice(&ack_epoch.to_be_bytes()); + out.extend_from_slice(&ack_seq.to_be_bytes()); + out +} + +/// Test that application data following a KeyUpdate in the same datagram is +/// delivered. The sender emits the KeyUpdate under the old application epoch, +/// rotates send keys, and the datagram then carries application data under the +/// new epoch. The receiver must process the KeyUpdate before trying to decrypt +/// the following record. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_client_key_update_and_new_epoch_app_data_in_same_datagram() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_key_update_and_app_data_same_datagram( + &mut client, + &mut server, + &mut now, + b"client-primes-key-update", + b"client-same-datagram-new-epoch", + ); +} + +/// If the deferred tail after a KeyUpdate is structurally malformed, the whole +/// UDP datagram must be rejected before the KeyUpdate is acted on. This keeps +/// the DIMP-007 datagram-atomic replay/state invariant for malformed tails. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_key_update_with_malformed_same_datagram_tail_is_atomic() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let (key_update_packet, app_packet) = capture_key_update_and_app_data_packets( + &mut client, + &mut server, + &mut now, + b"client-primes-malformed-tail-key-update", + b"client-valid-tail-after-malformed-attempt", + ); + + let mut malformed = key_update_packet.clone(); + malformed.push(0xff); + + let err = server + .handle_packet(&malformed) + .expect_err("malformed tail must reject the full datagram"); + assert!( + matches!(err, dimpl::Error::ParseIncomplete), + "expected ParseIncomplete, got {err:?}" + ); + + let after_malformed = drain_outputs(&mut server); + assert!( + after_malformed.packets.is_empty() && after_malformed.app_data.is_empty(), + "malformed datagram must not ACK, advance, or deliver anything" + ); + + let mut valid = key_update_packet; + valid.extend_from_slice(&app_packet); + + server + .handle_packet(&valid) + .expect("valid retry must still pass replay checks"); + server.handle_timeout(now).expect("server timeout"); + + let received = drain_outputs(&mut server); + assert_eq!( + received.app_data, + vec![b"client-valid-tail-after-malformed-attempt".to_vec()] + ); +} + +/// Deferring a post-KeyUpdate tail must not reset the per-datagram record +/// budget. Otherwise a `KeyUpdate || 16 records` datagram would be accepted as +/// two separately-budgeted parses and the KeyUpdate would take effect before the +/// over-capacity tail is rejected. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_key_update_with_over_capacity_same_datagram_tail_is_atomic() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let (key_update_packet, app_packet) = capture_key_update_and_app_data_packets( + &mut client, + &mut server, + &mut now, + b"client-primes-over-capacity-tail-key-update", + b"client-valid-tail-after-over-capacity-attempt", + ); + + let mut over_capacity = key_update_packet.clone(); + for seq in 0..16 { + over_capacity.extend_from_slice(&dtls13_ack_record_with_entry(0x200 + seq, 3, seq)); + } + + let err = server + .handle_packet(&over_capacity) + .expect_err("over-capacity tail must reject the full datagram"); + assert!( + matches!(err, dimpl::Error::TooManyRecords), + "expected TooManyRecords, got {err:?}" + ); + + let after_over_capacity = drain_outputs(&mut server); + assert!( + after_over_capacity.packets.is_empty() && after_over_capacity.app_data.is_empty(), + "over-capacity datagram must not ACK, advance, or deliver anything" + ); + + let mut valid = key_update_packet; + valid.extend_from_slice(&app_packet); + + server + .handle_packet(&valid) + .expect("valid retry must still pass replay checks"); + server.handle_timeout(now).expect("server timeout"); + + let received = drain_outputs(&mut server); + assert_eq!( + received.app_data, + vec![b"client-valid-tail-after-over-capacity-attempt".to_vec()] + ); +} + +/// Same as the client-sender case, but with the server as the KeyUpdate sender +/// so the client deferred-tail path is covered too. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_server_key_update_and_new_epoch_app_data_in_same_datagram() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_key_update_and_app_data_same_datagram( + &mut server, + &mut client, + &mut now, + b"server-primes-key-update", + b"server-same-datagram-new-epoch", + ); +} + /// Test that a reordered packet captured before a KeyUpdate is accepted when /// delivered alongside other packets during the transition. The packet is from /// the same epoch and arrives before any new-epoch records, so the replay From 1cbeac335f0b73c1b5fbc159a363eda5ba91b23f Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Sun, 24 May 2026 08:44:00 +0300 Subject: [PATCH 3/3] dtls13: preserve overlapping KeyUpdate flights --- CHANGELOG.md | 1 + src/dtls13/client.rs | 93 +++--- src/dtls13/engine.rs | 94 +++++- src/dtls13/incoming.rs | 9 + src/dtls13/message/handshake.rs | 3 +- src/dtls13/server.rs | 216 +++++++++++--- tests/dtls13/key_update.rs | 509 ++++++++++++++++++++++++++++++++ 7 files changed, 845 insertions(+), 80 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f8743d0c..073e0c2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased + * Preserve overlapping DTLS 1.3 KeyUpdate flights #123 * Preserve DTLS 1.3 app data following KeyUpdate in one datagram #122 * Fix malformed datagrams consuming DTLS replay-window state #121 * Replace pending DTLS 1.2 handshake output on resend #116 diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index 3bfb88a1..01cc450c 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -249,6 +249,56 @@ impl Client { .create_key_update(KeyUpdateRequest::UpdateRequested) } + fn send_pending_key_update_response(&mut self) -> Result<(), Error> { + if self.pending_key_update_response && !self.engine.is_key_update_in_flight() { + self.engine.send_ack()?; + self.engine + .create_key_update(KeyUpdateRequest::UpdateNotRequested)?; + self.pending_key_update_response = false; + } + Ok(()) + } + + fn handle_incoming_key_update(&mut self) -> Result<(), Error> { + if self.engine.has_complete_handshake(MessageType::KeyUpdate) { + let maybe = self.engine.next_handshake_no_transcript( + MessageType::KeyUpdate, + &mut self.defragment_buffer, + )?; + + if let Some(handshake) = maybe { + let Body::KeyUpdate(request) = handshake.body else { + unreachable!() + }; + + // Install new recv keys + self.engine.update_recv_keys()?; + + // If peer requested us to update, schedule our own KeyUpdate + let local_key_update_in_flight = self.engine.is_key_update_in_flight(); + if request == KeyUpdateRequest::UpdateRequested { + self.pending_key_update_response = true; + if local_key_update_in_flight { + self.engine.send_ack_with_previous_app_epoch()?; + } else { + self.engine.send_ack()?; + } + } else { + self.engine.send_ack()?; + } + + self.engine.advance_peer_handshake_seq(); + debug!("Received KeyUpdate (request={:?})", request); + + // Drain a fresh peer-requested response in the same progress + // pass when no local KeyUpdate is in flight. + self.send_pending_key_update_response()?; + } + } + + Ok(()) + } + /// Send application data when the client is connected. pub fn send_application_data(&mut self, data: &[u8]) -> Result<(), Error> { if self.state == State::Closed || self.state == State::HalfClosedLocal { @@ -1051,8 +1101,13 @@ impl State { } fn await_application_data(self, client: &mut Client) -> Result { + // Incoming peer requests require an update_not_requested response. They + // take priority over local AEAD-limit updates and queued app data. + client.handle_incoming_key_update()?; + client.send_pending_key_update_response()?; + // Auto-trigger KeyUpdate when AEAD encryption limit is reached - if client.engine.needs_key_update() && !client.engine.is_key_update_in_flight() { + if !client.engine.is_key_update_in_flight() && client.engine.needs_key_update() { client.initiate_key_update()?; } @@ -1075,42 +1130,6 @@ impl State { } } - // Send pending KeyUpdate response before processing new KeyUpdates - if client.pending_key_update_response { - client - .engine - .create_key_update(KeyUpdateRequest::UpdateNotRequested)?; - client.pending_key_update_response = false; - } - - // Check for incoming KeyUpdate - if client.engine.has_complete_handshake(MessageType::KeyUpdate) { - let maybe = client.engine.next_handshake_no_transcript( - MessageType::KeyUpdate, - &mut client.defragment_buffer, - )?; - - if let Some(handshake) = maybe { - let Body::KeyUpdate(request) = handshake.body else { - unreachable!() - }; - - // Install new recv keys - client.engine.update_recv_keys()?; - - // ACK the KeyUpdate record - client.engine.send_ack()?; - - // If peer requested us to update, schedule our own KeyUpdate - if request == KeyUpdateRequest::UpdateRequested { - client.pending_key_update_response = true; - } - - client.engine.advance_peer_handshake_seq(); - debug!("Received KeyUpdate (request={:?})", request); - } - } - Ok(self) } diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index 3c904a95..d92b291b 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -401,7 +401,11 @@ impl Engine { if let Some(dupe_seq) = maybe_dupe_seq { if dupe_seq < self.peer_handshake_seq_no { + for seq in incoming.ackable_record_numbers() { + let _ = self.received_record_numbers.try_push(seq); + } self.flight_resend("dupe triggers resend")?; + self.send_ack()?; } } @@ -1277,17 +1281,30 @@ impl Engine { /// /// ACK format: record_numbers_length(2) + N * (epoch(8) + sequence(8)) pub fn send_ack(&mut self) -> Result<(), Error> { - if self.received_record_numbers.is_empty() { - return Ok(()); - } - - let entries = mem::take(&mut self.received_record_numbers); let epoch = if self.app_send_keys.is_some() { self.app_send_epoch } else { 2 }; + self.send_ack_with_epoch(epoch) + } + + pub(crate) fn send_ack_with_previous_app_epoch(&mut self) -> Result<(), Error> { + if self.prev_app_send_keys.is_some() { + self.send_ack_with_epoch(self.prev_app_send_epoch) + } else { + self.send_ack() + } + } + + fn send_ack_with_epoch(&mut self, epoch: u16) -> Result<(), Error> { + if self.received_record_numbers.is_empty() { + return Ok(()); + } + + let entries = mem::take(&mut self.received_record_numbers); + self.create_ciphertext_record(ContentType::Ack, epoch, false, |fragment| { // record_numbers_length: 2 bytes, value = entries.len() * 16 let len = (entries.len() * 16) as u16; @@ -1881,6 +1898,12 @@ impl Engine { /// the current app epoch, then send keys are rotated (old keys saved /// in `prev_app_send_*` for retransmission). pub fn create_key_update(&mut self, request: KeyUpdateRequest) -> Result<(), Error> { + if self.is_key_update_in_flight() { + return Err(Error::CryptoError( + "KeyUpdate already in flight".to_string(), + )); + } + // Set up retransmission self.flight_backoff.reset(&mut self.rng); self.flight_clear_resends(); @@ -2507,6 +2530,19 @@ mod tests { Engine::new(config, cert) } + #[cfg(feature = "rcgen")] + fn install_test_app_send_keys(engine: &mut Engine) { + engine.set_cipher_suite(Dtls13CipherSuite::AES_128_GCM_SHA256); + + let mut traffic_secret = Buf::new(); + traffic_secret.extend_from_slice(&vec![0x42; engine.hash_algorithm().output_len()]); + engine.app_send_keys = Some( + engine + .derive_epoch_keys(&traffic_secret) + .expect("derive app send keys"), + ); + } + /// Issue 2: Epoch-0 sequence number must have an overflow guard. /// /// Per RFC 9147 §4.2, implementations MUST NOT allow the sequence number @@ -2605,4 +2641,52 @@ mod tests { "derive_early_secret must use the buffer pool, returning a buffer with pooled capacity" ); } + + #[test] + #[cfg(feature = "rcgen")] + fn peer_key_update_response_does_not_replace_in_flight_local_update() { + let mut engine = test_engine(); + install_test_app_send_keys(&mut engine); + + engine + .create_key_update(KeyUpdateRequest::UpdateRequested) + .expect("create local KeyUpdate"); + + let prev_epoch = engine.prev_app_send_epoch; + let prev_seq = engine.prev_app_send_seq; + let saved_records: Vec<_> = engine + .flight_saved_records + .iter() + .map(|entry| { + ( + entry.content_type, + entry.epoch, + entry.send_seq, + entry.fragment.as_ref().to_vec(), + entry.acked, + ) + }) + .collect(); + let send_epoch = engine.app_send_epoch; + let send_seq = engine.app_send_seq; + + let result = engine.create_key_update(KeyUpdateRequest::UpdateNotRequested); + assert!( + result.is_err(), + "a peer-requested KeyUpdate response must wait while a local KeyUpdate is in flight" + ); + + assert_eq!(engine.prev_app_send_epoch, prev_epoch); + assert_eq!(engine.prev_app_send_seq, prev_seq); + assert_eq!(engine.app_send_epoch, send_epoch); + assert_eq!(engine.app_send_seq, send_seq); + assert_eq!(engine.flight_saved_records.len(), saved_records.len()); + for (entry, saved) in engine.flight_saved_records.iter().zip(saved_records) { + assert_eq!(entry.content_type, saved.0); + assert_eq!(entry.epoch, saved.1); + assert_eq!(entry.send_seq, saved.2); + assert_eq!(entry.fragment.as_ref(), saved.3.as_slice()); + assert_eq!(entry.acked, saved.4); + } + } } diff --git a/src/dtls13/incoming.rs b/src/dtls13/incoming.rs index 604e2d72..12ada56f 100644 --- a/src/dtls13/incoming.rs +++ b/src/dtls13/incoming.rs @@ -37,6 +37,15 @@ impl Incoming { pub fn into_records(self) -> impl Iterator { self.records.records.into_iter() } + + pub(crate) fn ackable_record_numbers(&self) -> impl Iterator + '_ { + self.records + .records + .iter() + .map(|record| record.record().sequence) + .filter(|seq| seq.epoch >= 2) + .map(|seq| (seq.epoch as u64, seq.sequence_number)) + } } impl Incoming { diff --git a/src/dtls13/message/handshake.rs b/src/dtls13/message/handshake.rs index 55c6e427..92cb03cd 100644 --- a/src/dtls13/message/handshake.rs +++ b/src/dtls13/message/handshake.rs @@ -223,7 +223,8 @@ impl Handshake { let qualifies = matches!( self.header.msg_type, - MessageType::ClientHello // flight 1 + MessageType::ClientHello | // flight 1 + MessageType::KeyUpdate // post-handshake flight ); qualifies.then_some(self.header.message_seq) diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 64f7e25a..dc642ec5 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -285,6 +285,56 @@ impl Server { .create_key_update(KeyUpdateRequest::UpdateRequested) } + fn send_pending_key_update_response(&mut self) -> Result<(), Error> { + if self.pending_key_update_response && !self.engine.is_key_update_in_flight() { + self.engine.send_ack()?; + self.engine + .create_key_update(KeyUpdateRequest::UpdateNotRequested)?; + self.pending_key_update_response = false; + } + Ok(()) + } + + fn handle_incoming_key_update(&mut self) -> Result<(), Error> { + if self.engine.has_complete_handshake(MessageType::KeyUpdate) { + let maybe = self.engine.next_handshake_no_transcript( + MessageType::KeyUpdate, + &mut self.defragment_buffer, + )?; + + if let Some(handshake) = maybe { + let Body::KeyUpdate(request) = handshake.body else { + unreachable!() + }; + + // Install new recv keys + self.engine.update_recv_keys()?; + + // If peer requested us to update, schedule our own KeyUpdate + let local_key_update_in_flight = self.engine.is_key_update_in_flight(); + if request == KeyUpdateRequest::UpdateRequested { + self.pending_key_update_response = true; + if local_key_update_in_flight { + self.engine.send_ack_with_previous_app_epoch()?; + } else { + self.engine.send_ack()?; + } + } else { + self.engine.send_ack()?; + } + + self.engine.advance_peer_handshake_seq(); + debug!("Received KeyUpdate (request={:?})", request); + + // Drain a fresh peer-requested response in the same progress + // pass when no local KeyUpdate is in flight. + self.send_pending_key_update_response()?; + } + } + + Ok(()) + } + /// Send application data when the server is connected. pub fn send_application_data(&mut self, data: &[u8]) -> Result<(), Error> { if self.state == State::Closed || self.state == State::HalfClosedLocal { @@ -1122,8 +1172,13 @@ impl State { } fn await_application_data(self, server: &mut Server) -> Result { + // Incoming peer requests require an update_not_requested response. They + // take priority over local AEAD-limit updates and queued app data. + server.handle_incoming_key_update()?; + server.send_pending_key_update_response()?; + // Auto-trigger KeyUpdate when AEAD encryption limit is reached - if server.engine.needs_key_update() && !server.engine.is_key_update_in_flight() { + if !server.engine.is_key_update_in_flight() && server.engine.needs_key_update() { server.initiate_key_update()?; } @@ -1146,42 +1201,6 @@ impl State { } } - // Send pending KeyUpdate response before processing new KeyUpdates - if server.pending_key_update_response { - server - .engine - .create_key_update(KeyUpdateRequest::UpdateNotRequested)?; - server.pending_key_update_response = false; - } - - // Check for incoming KeyUpdate - if server.engine.has_complete_handshake(MessageType::KeyUpdate) { - let maybe = server.engine.next_handshake_no_transcript( - MessageType::KeyUpdate, - &mut server.defragment_buffer, - )?; - - if let Some(handshake) = maybe { - let Body::KeyUpdate(request) = handshake.body else { - unreachable!() - }; - - // Install new recv keys - server.engine.update_recv_keys()?; - - // ACK the KeyUpdate record - server.engine.send_ack()?; - - // If peer requested us to update, schedule our own KeyUpdate - if request == KeyUpdateRequest::UpdateRequested { - server.pending_key_update_response = true; - } - - server.engine.advance_peer_handshake_seq(); - debug!("Received KeyUpdate (request={:?})", request); - } - } - Ok(self) } @@ -1455,3 +1474,126 @@ fn serialize_certificate_authorities( output.extend_from_slice(data); } } + +#[cfg(all(test, feature = "rcgen"))] +mod tests { + use std::sync::Arc; + use std::time::{Duration, Instant}; + + use super::*; + use crate::certificate::generate_self_signed_certificate; + use crate::dtls13::client::Client; + use crate::dtls13::engine::Engine; + + fn collect_client(client: &mut Client) -> (Vec>, bool) { + let mut packets = Vec::new(); + let mut connected = false; + let mut buf = vec![0u8; 2048]; + + loop { + match client.poll_output(&mut buf) { + Output::Packet(packet) => packets.push(packet.to_vec()), + Output::Connected => connected = true, + Output::Timeout(_) => break, + _ => {} + } + } + + (packets, connected) + } + + fn collect_server(server: &mut Server) -> (Vec>, bool) { + let mut packets = Vec::new(); + let mut connected = false; + let mut buf = vec![0u8; 2048]; + + loop { + match server.poll_output(&mut buf) { + Output::Packet(packet) => packets.push(packet.to_vec()), + Output::Connected => connected = true, + Output::Timeout(_) => break, + _ => {} + } + } + + (packets, connected) + } + + fn deliver_to_server(packets: &[Vec], server: &mut Server) { + for packet in packets { + server.handle_packet(packet).expect("server handles packet"); + } + } + + fn deliver_to_client(packets: &[Vec], client: &mut Client) { + for packet in packets { + client.handle_packet(packet).expect("client handles packet"); + } + } + + fn complete_handshake(client: &mut Client, server: &mut Server, mut now: Instant) -> Instant { + let mut client_connected = false; + let mut server_connected = false; + + for _ in 0..40 { + client.handle_timeout(now).expect("client timeout"); + server.handle_timeout(now).expect("server timeout"); + + let (client_packets, client_event) = collect_client(client); + let (server_packets, server_event) = collect_server(server); + + client_connected |= client_event; + server_connected |= server_event; + + deliver_to_server(&client_packets, server); + deliver_to_client(&server_packets, client); + + if client_connected && server_connected { + return now; + } + + now += Duration::from_millis(10); + } + + panic!("DTLS 1.3 handshake did not complete"); + } + + #[test] + fn pending_key_update_response_does_not_replace_server_local_key_update() { + let _ = env_logger::try_init(); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build config"), + ); + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + let now = Instant::now(); + + let client_engine = Engine::new(Arc::clone(&config), client_cert); + let mut client = Client::new_with_engine(client_engine, now); + let mut server = Server::new(config, server_cert, now); + + let _now = complete_handshake(&mut client, &mut server, now); + + server + .engine + .create_key_update(KeyUpdateRequest::UpdateRequested) + .expect("server creates local KeyUpdate"); + let (server_key_update, _) = collect_server(&mut server); + assert!(!server_key_update.is_empty()); + assert!(server.engine.is_key_update_in_flight()); + + server.pending_key_update_response = true; + server + .send_pending_key_update_response() + .expect("pending response check while server KeyUpdate is in flight"); + let (early_response, _) = collect_server(&mut server); + assert!( + early_response.is_empty(), + "server must not replace an in-flight local KeyUpdate with a peer-requested response" + ); + } +} diff --git a/tests/dtls13/key_update.rs b/tests/dtls13/key_update.rs index b9b8ccf9..164ed47a 100644 --- a/tests/dtls13/key_update.rs +++ b/tests/dtls13/key_update.rs @@ -518,6 +518,399 @@ fn dtls13_server_key_update_and_new_epoch_app_data_in_same_datagram() { ); } +#[cfg(feature = "rcgen")] +fn trigger_key_update( + sender: &mut Dtls, + receiver: &mut Dtls, + now: &mut Instant, + label: &str, +) -> Vec> { + for i in 0..64 { + sender + .handle_timeout(*now) + .expect("sender checks existing KeyUpdate threshold"); + let pending_key_update = collect_packets(sender); + if !pending_key_update.is_empty() { + return pending_key_update; + } + + let msg = format!("{label}-{i}").into_bytes(); + sender + .send_application_data(&msg) + .expect("sender sends priming app data"); + let data_packets = collect_packets(sender); + assert_eq!( + data_packets.len(), + 1, + "sender should emit exactly one priming app-data packet before KeyUpdate" + ); + + deliver_packets(&data_packets, receiver); + receiver + .handle_timeout(*now) + .expect("receiver handles priming data"); + let receiver_out = drain_outputs(receiver); + assert!( + receiver_out + .app_data + .iter() + .any(|received| received == &msg), + "receiver should deliver priming app data before KeyUpdate" + ); + deliver_packets(&receiver_out.packets, sender); + + *now += Duration::from_millis(10); + sender + .handle_timeout(*now) + .expect("sender checks KeyUpdate threshold"); + let key_update = collect_packets(sender); + if !key_update.is_empty() { + return key_update; + } + } + + panic!("{label}: sender did not emit KeyUpdate within bounded attempts"); +} + +#[cfg(feature = "rcgen")] +fn dtls13_ciphertext_epoch_bits(packets: &[Vec]) -> Vec { + let mut epochs = Vec::new(); + + for packet in packets { + let mut rest = packet.as_slice(); + while !rest.is_empty() { + assert_eq!( + rest[0] & 0b1110_0000, + 0b0010_0000, + "expected DTLS 1.3 ciphertext unified header" + ); + assert_eq!( + rest[0] & 0b0001_0000, + 0, + "CID-bearing DTLS 1.3 records are not expected in these tests" + ); + + let s_flag = rest[0] & 0b0000_1000 != 0; + let l_flag = rest[0] & 0b0000_0100 != 0; + assert!(l_flag, "test records should carry explicit lengths"); + + let seq_len = if s_flag { 2 } else { 1 }; + let header_len = 1 + seq_len + 2; + assert!( + rest.len() >= header_len, + "truncated DTLS 1.3 ciphertext header" + ); + + let len_offset = 1 + seq_len; + let body_len = u16::from_be_bytes([rest[len_offset], rest[len_offset + 1]]) as usize; + let record_len = header_len + body_len; + assert!( + rest.len() >= record_len, + "truncated DTLS 1.3 ciphertext record" + ); + + epochs.push(rest[0] & 0x03); + rest = &rest[record_len..]; + } + } + + epochs +} + +#[cfg(feature = "rcgen")] +fn assert_peer_requested_response_waits_for_local_update_ack( + local: &mut Dtls, + peer: &mut Dtls, + now: &mut Instant, + local_label: &str, + peer_label: &str, + post_overlap: &[u8], +) { + let local_key_update = trigger_key_update(local, peer, now, local_label); + let peer_key_update = trigger_key_update(peer, local, now, peer_label); + + deliver_packets(&peer_key_update, local); + let local_ack_only = collect_packets(local); + assert!( + !local_ack_only.is_empty(), + "local endpoint should ACK peer KeyUpdate without sending the pending response" + ); + assert_eq!( + dtls13_ciphertext_epoch_bits(&local_ack_only), + vec![3], + concat!( + "local ACK must use the retained previous app epoch ", + "so peer can decrypt it before receiving local KeyUpdate" + ) + ); + + deliver_packets(&local_ack_only, peer); + let peer_after_local_ack = collect_packets(peer); + assert!( + peer_after_local_ack.is_empty(), + "local endpoint must not send its pending KeyUpdate response before its local update is ACKed" + ); + + let mut peer_after_local_ack_timeout = Vec::new(); + for _ in 0..20 { + *now += Duration::from_millis(500); + peer.handle_timeout(*now) + .expect("peer checks whether its KeyUpdate is still in flight"); + peer_after_local_ack_timeout.extend(collect_packets(peer)); + } + assert!( + peer_after_local_ack_timeout.is_empty(), + "peer should not retransmit its KeyUpdate after local endpoint ACKs it" + ); + + deliver_packets(&local_key_update, peer); + let peer_ack_for_local_update = collect_packets(peer); + assert!( + !peer_ack_for_local_update.is_empty(), + "peer should ACK local KeyUpdate" + ); + + deliver_packets(&peer_ack_for_local_update, local); + let local_response = collect_packets(local); + assert!( + !local_response.is_empty(), + "pending peer-requested ACK and response should drain once the local update is ACKed" + ); + + deliver_packets(&local_response, peer); + let peer_ack_for_local_response = collect_packets(peer); + assert!( + !peer_ack_for_local_response.is_empty(), + "peer should consume and ACK the delayed KeyUpdate response" + ); + deliver_packets(&peer_ack_for_local_response, local); + let local_after_response_ack = collect_packets(local); + assert!( + local_after_response_ack.is_empty(), + "delayed response must be update_not_requested and must not trigger a further response" + ); + + local + .send_application_data(post_overlap) + .expect("local sends post-overlap data"); + let local_post_overlap = collect_packets(local); + assert_eq!( + local_post_overlap.len(), + 1, + concat!( + "delayed KeyUpdate response must clear the peer-requested response ", + "before auto KeyUpdate can send another UpdateRequested" + ) + ); + deliver_packets(&local_post_overlap, peer); + peer.handle_timeout(*now) + .expect("peer handles post-overlap data"); + let peer_post_overlap = drain_outputs(peer); + assert!( + peer_post_overlap + .app_data + .iter() + .any(|received| received == post_overlap), + "peer should receive post-overlap app data" + ); +} + +/// A peer-requested KeyUpdate response should be emitted in the same progress +/// pass when no local KeyUpdate is in flight. Otherwise it can sit pending +/// until unrelated input or a timeout drives the state machine again. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_peer_requested_key_update_response_drains_immediately_without_local_update() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(2) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let server_key_update = + trigger_key_update(&mut server, &mut client, &mut now, "server-immediate"); + + deliver_packets(&server_key_update, &mut client); + let client_response = collect_packets(&mut client); + assert!( + !client_response.is_empty(), + "client should emit KeyUpdate ACK and response without waiting for another tick" + ); + + deliver_packets(&client_response, &mut server); + let server_ack_for_client_response = collect_packets(&mut server); + assert!( + !server_ack_for_client_response.is_empty(), + "server should consume and ACK the immediate KeyUpdate response" + ); + deliver_packets(&server_ack_for_client_response, &mut client); + let client_after_response_ack = collect_packets(&mut client); + assert!( + client_after_response_ack.is_empty(), + "immediate response must be update_not_requested and must not trigger a further response" + ); + + client + .send_application_data(b"client-after-immediate-response") + .expect("client sends post-response data"); + deliver_packets(&collect_packets(&mut client), &mut server); + server + .handle_timeout(now) + .expect("server handles post-response data"); + let server_after_response = drain_outputs(&mut server); + assert_eq!( + server_after_response.app_data, + vec![b"client-after-immediate-response".to_vec()] + ); +} + +/// A peer-requested KeyUpdate response must not replace an in-flight locally +/// initiated KeyUpdate. The response is delayed until the local KeyUpdate is +/// ACKed, then sent normally. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_peer_requested_key_update_response_waits_for_local_update_ack() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(2) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_peer_requested_response_waits_for_local_update_ack( + &mut client, + &mut server, + &mut now, + "client-local", + "server-peer", + b"client-post-overlap", + ); +} + +/// Same as the client immediate-response case, but with the server as the +/// responder so the server-side pending response drain is covered too. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_server_peer_requested_key_update_response_drains_immediately_without_local_update() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let client_key_update = trigger_key_update(&mut client, &mut server, &mut now, "client-peer"); + + deliver_packets(&client_key_update, &mut server); + let server_response = collect_packets(&mut server); + assert!( + !server_response.is_empty(), + "server should emit KeyUpdate ACK and response without waiting for another tick" + ); + + deliver_packets(&server_response, &mut client); + let client_ack_for_server_response = collect_packets(&mut client); + assert!( + !client_ack_for_server_response.is_empty(), + "client should consume and ACK the immediate KeyUpdate response" + ); + deliver_packets(&client_ack_for_server_response, &mut server); + let server_after_response_ack = collect_packets(&mut server); + assert!( + server_after_response_ack.is_empty(), + "immediate response must be update_not_requested and must not trigger a further response" + ); + + server + .send_application_data(b"server-after-immediate-response") + .expect("server sends post-response data"); + deliver_packets(&collect_packets(&mut server), &mut client); + client + .handle_timeout(now) + .expect("client handles post-response data"); + let client_after_response = drain_outputs(&mut client); + assert!( + client_after_response + .app_data + .iter() + .any(|received| received == b"server-after-immediate-response"), + "client should receive post-response app data" + ); +} + /// Test that a reordered packet captured before a KeyUpdate is accepted when /// delivered alongside other packets during the transition. The packet is from /// the same epoch and arrives before any new-epoch records, so the replay @@ -876,6 +1269,122 @@ fn dtls13_key_update_with_packet_loss() { ); } +/// Test that a sender can recover when the KeyUpdate reaches the peer but the +/// peer's ACK is lost. The peer must treat the retransmitted KeyUpdate as a +/// duplicate that triggers a fresh ACK; otherwise the sender remains stuck with +/// previous send keys retained forever. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_key_update_lost_ack_retransmission_gets_acknowledged() { + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(2) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let client_key_update = trigger_key_update(&mut client, &mut server, &mut now, "client"); + + deliver_packets(&client_key_update, &mut server); + let lost_server_ack = collect_packets(&mut server); + assert!(!lost_server_ack.is_empty(), "server should ACK KeyUpdate"); + + trigger_timeout(&mut client, &mut now); + client + .handle_timeout(now) + .expect("arm KeyUpdate flight timer"); + trigger_timeout(&mut client, &mut now); + let retransmitted_key_update = collect_packets(&mut client); + assert!( + !retransmitted_key_update.is_empty(), + "client should retransmit unacked KeyUpdate" + ); + + deliver_packets(&retransmitted_key_update, &mut server); + let server_ack_for_retransmit = collect_packets(&mut server); + assert!( + !server_ack_for_retransmit.is_empty(), + "server should ACK duplicate retransmitted KeyUpdate after the original ACK was lost" + ); + assert_eq!( + dtls13_ciphertext_epoch_bits(&server_ack_for_retransmit), + vec![3, 0], + concat!( + "server should retransmit the old-epoch KeyUpdate response ", + "before appending the fresh current-epoch ACK" + ) + ); + + deliver_packets(&server_ack_for_retransmit, &mut client); + client + .handle_timeout(now) + .expect("client handles ACK for retransmitted KeyUpdate"); + let client_ack_for_server_response = collect_packets(&mut client); + assert!( + !client_ack_for_server_response.is_empty(), + "client should ACK the server's retransmitted KeyUpdate response" + ); + + let mut client_after_retransmit_ack = Vec::new(); + for _ in 0..20 { + now += Duration::from_millis(500); + client + .handle_timeout(now) + .expect("client checks whether KeyUpdate is still in flight"); + client_after_retransmit_ack.extend(collect_packets(&mut client)); + } + assert!( + client_after_retransmit_ack.is_empty(), + "client should not retransmit KeyUpdate after ACKing the duplicate retransmission" + ); + + deliver_packets(&client_ack_for_server_response, &mut server); + server + .handle_timeout(now) + .expect("server handles ACK for retransmitted response"); + let server_after_ack = collect_packets(&mut server); + assert!( + server_after_ack.is_empty(), + "server should clear its retransmitted KeyUpdate response after ACK" + ); + + client + .send_application_data(b"post-lost-ack") + .expect("client sends post-recovery data"); + deliver_packets(&collect_packets(&mut client), &mut server); + server + .handle_timeout(now) + .expect("server handles post-recovery data"); + let server_after_recovery = drain_outputs(&mut server); + assert_eq!( + server_after_recovery.app_data, + vec![b"post-lost-ack".to_vec()] + ); +} + /// Test that high-frequency KeyUpdates work correctly. With the minimum /// AEAD limit of 2, nearly every message triggers a KeyUpdate. This stress- /// tests the key rotation machinery and epoch tracking under extreme churn,