diff --git a/CHANGELOG.md b/CHANGELOG.md index 261bfb8a..561b81c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Unreleased + * Restore DTLS 1.3 saved flights on resend backpressure #124 + * Preserve overlapping DTLS 1.3 KeyUpdate flights #123 + * Preserve DTLS 1.3 app data following KeyUpdate in one datagram #122 + * Fix malformed datagrams consuming DTLS replay-window state #121 + * Bound DTLS 1.3 ACK tracking during handshake replacement #120 * Replace pending DTLS 1.2 handshake output on resend #116 * Discard bad protected DTLS 1.2 records after handshake #115 * Reject oversized DTLS certificate lists #113 diff --git a/src/dtls12/incoming.rs b/src/dtls12/incoming.rs index 2c41ebbd..373c424b 100644 --- a/src/dtls12/incoming.rs +++ b/src/dtls12/incoming.rs @@ -8,6 +8,7 @@ use crate::Error; use crate::buffer::{Buf, TmpBuf}; use crate::crypto::{Aad, Nonce}; use crate::dtls12::message::{ContentType, DTLSRecord, Dtls12CipherSuite, Handshake, Sequence}; +use crate::window::ReplayWindow; /// Holds both the UDP packet and the parsed result of that packet. pub struct Incoming { @@ -67,12 +68,14 @@ pub struct Records { } impl Records { - pub fn parse( + fn parse( mut packet: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, ) -> Result { let mut parsed_records: ArrayVec = ArrayVec::new(); + let mut replay_updates: ArrayVec = ArrayVec::new(); + let mut pending_replay = ReplayWindow::new(); // Find record boundaries and copy each record ONCE from the packet while !packet.is_empty() { @@ -91,14 +94,29 @@ impl Records { // This is the ONLY copy: packet -> record buffer let record_slice = &packet[..record_end]; match Record::parse(record_slice, decrypt, cs) { - Ok(record) => { - if let Some(record) = record { + Ok(parsed) => { + if let Some(sequence) = parsed.replay_sequence { + if !pending_replay.check(sequence.sequence_number) { + trace!("Discarding duplicate rec in same datagram"); + packet = &packet[record_end..]; + continue; + } + } + + if let Some(record) = parsed.record { if parsed_records.try_push(record).is_err() { return Err(Error::TooManyRecords); } - } else { + } else if parsed.replay_sequence.is_none() { trace!("Discarding replayed rec"); } + + if let Some(sequence) = parsed.replay_sequence { + pending_replay.update(sequence.sequence_number); + if replay_updates.try_push(sequence).is_err() { + return Err(Error::TooManyRecords); + } + } } Err(e) => return Err(e), } @@ -106,6 +124,13 @@ impl Records { packet = &packet[record_end..]; } + // Commit replay state only after the whole UDP datagram has parsed + // successfully. A malformed trailing record must not consume + // replay state for an earlier authenticated record in the same datagram. + for sequence in replay_updates { + decrypt.replay_update(sequence); + } + let mut records = ArrayVec::new(); for record in parsed_records { if let Some(record) = decrypt.classify_record(record)? { @@ -134,14 +159,19 @@ pub struct Record { parsed: Box, } +struct RecordParse { + record: Option, + replay_sequence: Option, +} + impl Record { /// The first parse pass only parses the DTLSRecord header which is unencrypted. /// Copies record data from UDP packet ONCE into a pooled buffer. - pub fn parse( + fn parse( record_slice: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result, Error> { + ) -> Result { // ONLY COPY: UDP packet slice -> pooled buffer let mut buffer = Buf::new(); buffer.extend_from_slice(record_slice); @@ -151,7 +181,10 @@ impl Record { // RFC 6347 §4.1.2.7: Invalid records SHOULD be silently discarded. // This includes epoch 0 records with invalid ContentType. trace!("Discarding record: parse failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } }; let parsed = Box::new(parsed); @@ -162,7 +195,10 @@ impl Record { // packet loss, we can end up seeing epoch 1 records before we can decrypt them. let is_epoch_0 = record.record().sequence.epoch == 0; if is_epoch_0 || !decrypt.is_peer_encryption_enabled() { - return Ok(Some(record)); + return Ok(RecordParse { + record: Some(record), + replay_sequence: None, + }); } // We need to decrypt the record and redo the parsing. @@ -171,12 +207,18 @@ impl Record { // Anti-replay check (read-only, does not update window) if !decrypt.replay_check(sequence) { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } let explicit_nonce_len = decrypt.explicit_nonce_len(); if (dtls.length as usize) < decrypt.min_protected_fragment_len() { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } // Get a reference to the buffer @@ -203,25 +245,35 @@ impl Record { } trace!("Discarding record: decrypt failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } buffer.len() }; - // Decryption succeeded — now commit the replay window update. - // RFC 6347 §4.1.2.6: "The receive window is updated only if the - // MAC verification succeeds." - decrypt.replay_update(sequence); - // Update the length of the record. buffer[11] = (new_len >> 8) as u8; buffer[12] = new_len as u8; - let parsed = ParsedRecord::parse(&buffer, cs, explicit_nonce_len)?; + let parsed = match ParsedRecord::parse(&buffer, cs, explicit_nonce_len) { + Ok(parsed) => parsed, + Err(e) => { + trace!("Discarding authenticated record: parse failed: {}", e); + return Ok(RecordParse { + record: None, + replay_sequence: Some(sequence), + }); + } + }; let parsed = Box::new(parsed); - Ok(Some(Record { buffer, parsed })) + Ok(RecordParse { + record: Some(Record { buffer, parsed }), + replay_sequence: Some(sequence), + }) } pub fn record(&self) -> &DTLSRecord { diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index b18408da..ddda80aa 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -107,6 +107,9 @@ pub struct Client { /// Whether we need to respond with our own KeyUpdate pending_key_update_response: bool, + /// Whether a peer KeyUpdate ACK needs to be retried. + pending_key_update_ack: bool, + /// Active key exchange state (ECDHE) active_key_exchange: Option>, @@ -153,6 +156,7 @@ impl Client { local_events: VecDeque::new(), queued_data: Vec::new(), pending_key_update_response: false, + pending_key_update_ack: false, active_key_exchange: None, hello_retry: false, hrr_selected_group: None, @@ -198,6 +202,7 @@ impl Client { local_events: VecDeque::new(), queued_data: Vec::new(), pending_key_update_response: false, + pending_key_update_ack: false, active_key_exchange: Some(hybrid.active_key_exchange), hello_retry: false, hrr_selected_group: None, @@ -219,7 +224,7 @@ impl Client { pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> { self.engine.parse_packet(packet)?; - self.make_progress()?; + self.make_progress_and_drain_deferred()?; Ok(()) } @@ -227,7 +232,13 @@ impl Client { if let Some(event) = self.local_events.pop_front() { return event.into_output(buf, &self.server_certificates); } - self.engine.poll_output(buf, self.last_now) + + match self.engine.poll_output(buf, self.last_now) { + Output::Timeout(_) if self.has_pending_local_progress() => { + Output::Timeout(self.last_now) + } + output => output, + } } /// Explicitly start the handshake process by sending a ClientHello @@ -237,15 +248,89 @@ impl Client { self.random = Some(self.engine.random()); } self.engine.handle_timeout(now)?; + self.make_progress_and_drain_deferred()?; + Ok(()) + } + + fn make_progress_and_drain_deferred(&mut self) -> Result<(), Error> { self.make_progress()?; + while self.engine.parse_next_deferred_packet()? { + self.make_progress()?; + } Ok(()) } + fn has_pending_local_progress(&self) -> bool { + self.pending_key_update_ack + || self.engine.has_pending_post_handshake_ack() + || (self.pending_key_update_response && !self.engine.is_key_update_in_flight()) + } + fn initiate_key_update(&mut self) -> Result<(), Error> { self.engine .create_key_update(KeyUpdateRequest::UpdateRequested) } + fn send_pending_key_update_response(&mut self) -> Result<(), Error> { + if self.pending_key_update_response && !self.engine.is_key_update_in_flight() { + self.engine.send_ack()?; + self.engine + .create_key_update(KeyUpdateRequest::UpdateNotRequested)?; + self.pending_key_update_response = false; + } + Ok(()) + } + + fn send_pending_key_update_ack(&mut self) -> Result<(), Error> { + if !self.pending_key_update_ack && !self.engine.has_pending_post_handshake_ack() { + return Ok(()); + } + + let result = if self.engine.is_key_update_in_flight() { + self.engine.send_ack_with_previous_app_epoch() + } else { + self.engine.send_ack() + }; + + result?; + self.pending_key_update_ack = false; + Ok(()) + } + + fn handle_incoming_key_update(&mut self) -> Result<(), Error> { + if self.engine.has_complete_handshake(MessageType::KeyUpdate) { + let maybe = self.engine.next_handshake_no_transcript( + MessageType::KeyUpdate, + &mut self.defragment_buffer, + )?; + + if let Some(handshake) = maybe { + let Body::KeyUpdate(request) = handshake.body else { + unreachable!() + }; + + // Install new recv keys + self.engine.update_recv_keys()?; + self.engine.advance_peer_handshake_seq(); + self.pending_key_update_ack = true; + + // If peer requested us to update, schedule our own KeyUpdate + if request == KeyUpdateRequest::UpdateRequested { + self.pending_key_update_response = true; + } + self.send_pending_key_update_ack()?; + + debug!("Received KeyUpdate (request={:?})", request); + + // Drain a fresh peer-requested response in the same progress + // pass when no local KeyUpdate is in flight. + self.send_pending_key_update_response()?; + } + } + + Ok(()) + } + /// Send application data when the client is connected. pub fn send_application_data(&mut self, data: &[u8]) -> Result<(), Error> { if self.state == State::Closed || self.state == State::HalfClosedLocal { @@ -1048,8 +1133,14 @@ impl State { } fn await_application_data(self, client: &mut Client) -> Result { + // Incoming peer requests require an update_not_requested response. They + // take priority over local AEAD-limit updates and queued app data. + client.handle_incoming_key_update()?; + client.send_pending_key_update_ack()?; + client.send_pending_key_update_response()?; + // Auto-trigger KeyUpdate when AEAD encryption limit is reached - if client.engine.needs_key_update() && !client.engine.is_key_update_in_flight() { + if !client.engine.is_key_update_in_flight() && client.engine.needs_key_update() { client.initiate_key_update()?; } @@ -1072,42 +1163,6 @@ impl State { } } - // Send pending KeyUpdate response before processing new KeyUpdates - if client.pending_key_update_response { - client - .engine - .create_key_update(KeyUpdateRequest::UpdateNotRequested)?; - client.pending_key_update_response = false; - } - - // Check for incoming KeyUpdate - if client.engine.has_complete_handshake(MessageType::KeyUpdate) { - let maybe = client.engine.next_handshake_no_transcript( - MessageType::KeyUpdate, - &mut client.defragment_buffer, - )?; - - if let Some(handshake) = maybe { - let Body::KeyUpdate(request) = handshake.body else { - unreachable!() - }; - - // Install new recv keys - client.engine.update_recv_keys()?; - - // ACK the KeyUpdate record - client.engine.send_ack()?; - - // If peer requested us to update, schedule our own KeyUpdate - if request == KeyUpdateRequest::UpdateRequested { - client.pending_key_update_response = true; - } - - client.engine.advance_peer_handshake_seq(); - debug!("Received KeyUpdate (request={:?})", request); - } - } - Ok(self) } diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index 984b06b5..67c750cd 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -36,6 +36,11 @@ const MAX_DEFRAGMENT_PACKETS: usize = 50; /// implementations MUST NOT allow the sequence number to wrap. const MAX_SEQUENCE_NUMBER: u64 = (1u64 << 48) - 1; +struct FlightResendError { + error: Error, + queued_any: bool, +} + pub struct Engine { /// Configuration options. config: Arc, @@ -58,6 +63,10 @@ pub struct Engine { /// Queue of outgoing packets. queue_tx: QueueTx, + /// Deferred tail datagrams split after KeyUpdate records. They are parsed + /// after the state machine installs the next receive keys. + deferred_packets: ArrayVec, + /// The cipher suite in use. Set by ServerHello. cipher_suite: Option, @@ -193,10 +202,28 @@ struct Entry { content_type: ContentType, epoch: u16, send_seq: u64, + send_seq_history: Vec, fragment: Buf, acked: bool, } +impl Entry { + fn remember_current_send_seq(&mut self, history_limit: usize) { + if self.send_seq_history.contains(&self.send_seq) { + return; + } + let history_limit = history_limit.max(1); + if self.send_seq_history.len() == history_limit { + self.send_seq_history.remove(0); + } + self.send_seq_history.push(self.send_seq); + } + + fn is_acked_by(&self, epoch: u64, seq: u64) -> bool { + self.epoch as u64 == epoch && (self.send_seq == seq || self.send_seq_history.contains(&seq)) + } +} + impl Engine { pub fn new(config: Arc, certificate: DtlsCertificate) -> Self { let mut rng = SeededRng::new(config.rng_seed()); @@ -221,6 +248,7 @@ impl Engine { sequence_epoch_0: Sequence::new(0), queue_rx: QueueRx::new(), queue_tx: QueueTx::new(), + deferred_packets: ArrayVec::new(), cipher_suite: None, hs_send_keys: None, hs_recv_keys: None, @@ -305,14 +333,9 @@ impl Engine { } /// Returns true if the AEAD encryption limit has been reached and a - /// KeyUpdate should be initiated. Clears the flag after returning true. - pub fn needs_key_update(&mut self) -> bool { - if self.needs_key_update { - self.needs_key_update = false; - true - } else { - false - } + /// KeyUpdate should be initiated. + pub fn needs_key_update(&self) -> bool { + self.needs_key_update } pub fn is_cipher_suite_allowed(&self, suite: Dtls13CipherSuite) -> bool { @@ -330,15 +353,35 @@ impl Engine { } pub fn parse_packet(&mut self, packet: &[u8]) -> Result<(), Error> { + self.parse_packet_once(packet) + } + + fn parse_packet_once(&mut self, packet: &[u8]) -> Result<(), Error> { let cs = self.cipher_suite; - let incoming = Incoming::parse_packet(packet, self, cs)?; - if let Some(incoming) = incoming { + let parsed = Incoming::parse_packet_defer_after_key_update(packet, self, cs)?; + if let Some(incoming) = parsed.incoming { self.insert_incoming(incoming)?; } + if let Some(deferred_tail) = parsed.deferred_tail { + self.deferred_packets + .try_push(deferred_tail) + .map_err(|_| Error::TooManyRecords)?; + } Ok(()) } + pub fn parse_next_deferred_packet(&mut self) -> Result { + if self.deferred_packets.is_empty() { + return Ok(false); + } + + let packet = self.deferred_packets.remove(0); + self.parse_packet_once(&packet)?; + + Ok(true) + } + fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> { if self.queue_rx.len() >= self.config.max_queue_rx() { warn!( @@ -376,7 +419,11 @@ impl Engine { if let Some(dupe_seq) = maybe_dupe_seq { if dupe_seq < self.peer_handshake_seq_no { + for seq in incoming.ackable_record_numbers() { + let _ = self.received_record_numbers.try_push(seq); + } self.flight_resend("dupe triggers resend")?; + self.send_ack()?; } } @@ -442,8 +489,9 @@ impl Engine { for record in incoming.records().iter() { let seq = record.record().sequence; if seq.epoch >= 2 { - self.received_record_numbers - .push((seq.epoch as u64, seq.sequence_number)); + let _ = self + .received_record_numbers + .try_push((seq.epoch as u64, seq.sequence_number)); } } self.queue_rx[index] = incoming; @@ -504,14 +552,15 @@ impl Engine { if now >= flight_timeout { if self.flight_backoff.can_retry() { - self.flight_backoff.attempt(&mut self.rng); - debug!( - "Re-arm flight timeout due to resend in {}", - self.flight_backoff.rto().as_secs_f32() - ); - let timeout = now + self.flight_backoff.rto(); - self.flight_timeout = Timeout::Armed(timeout); - self.flight_resend("flight timeout")?; + match self.flight_resend_with_progress("flight timeout") { + Ok(_) => self.rearm_flight_timeout_after_resend(now), + Err(error) + if error.queued_any && matches!(error.error, Error::TransmitQueueFull) => + { + self.rearm_flight_timeout_after_resend(now); + } + Err(error) => return Err(error.error), + } } else { return Err(Error::Timeout("handshake")); } @@ -524,6 +573,16 @@ impl Engine { Ok(()) } + fn rearm_flight_timeout_after_resend(&mut self, now: Instant) { + self.flight_backoff.attempt(&mut self.rng); + debug!( + "Re-arm flight timeout due to resend in {}", + self.flight_backoff.rto().as_secs_f32() + ); + let timeout = now + self.flight_backoff.rto(); + self.flight_timeout = Timeout::Armed(timeout); + } + pub fn poll_output<'a>(&mut self, buf: &'a mut [u8], now: Instant) -> Output<'a> { self.purge_handled_queue_rx(); @@ -548,6 +607,10 @@ impl Engine { Output::Timeout(next_timeout) } + pub(crate) fn has_pending_post_handshake_ack(&self) -> bool { + self.release_app_data && !self.received_record_numbers.is_empty() + } + fn poll_app_data<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], &'a mut [u8]> { if !self.release_app_data { return Err(buf); @@ -671,14 +734,22 @@ impl Engine { } pub fn flight_resend(&mut self, reason: &str) -> Result<(), Error> { + self.flight_resend_with_progress(reason) + .map(|_| ()) + .map_err(|error| error.error) + } + + fn flight_resend_with_progress(&mut self, reason: &str) -> Result { debug!("Resending flight due to {}", reason); let mut records = mem::take(&mut self.flight_saved_records); + let mut queued_any = false; // Mark the current last datagram as "sealed" so resent records // go into fresh datagrams. Without this, can_append would pack // resent records into the same datagram as the original flight, // which causes duplicate handshake fragments at the receiver. self.seal_current_datagram(); + let send_seq_history_limit = self.config.flight_retries().saturating_add(1); for entry in &mut records { // Selective retransmission: skip records that have been ACKed @@ -689,9 +760,15 @@ impl Engine { if entry.epoch == 0 { // Capture the sequence number the retransmitted record will use let new_seq = self.sequence_epoch_0.sequence_number; - self.create_plaintext_record(entry.content_type, false, |fragment| { + let result = self.create_plaintext_record(entry.content_type, false, |fragment| { fragment.extend_from_slice(&entry.fragment); - })?; + }); + if let Err(error) = result { + self.flight_saved_records = records; + return Err(FlightResendError { error, queued_any }); + } + queued_any = true; + entry.remember_current_send_seq(send_seq_history_limit); entry.send_seq = new_seq; } else { // Capture the sequence number the retransmitted record will use @@ -704,21 +781,27 @@ impl Engine { } else { self.app_send_seq }; - self.create_ciphertext_record( + let result = self.create_ciphertext_record( entry.content_type, entry.epoch, false, |fragment| { fragment.extend_from_slice(&entry.fragment); }, - )?; + ); + if let Err(error) = result { + self.flight_saved_records = records; + return Err(FlightResendError { error, queued_any }); + } + queued_any = true; + entry.remember_current_send_seq(send_seq_history_limit); entry.send_seq = new_seq; } } self.flight_saved_records = records; - Ok(()) + Ok(queued_any) } pub fn has_complete_handshake(&mut self, wanted: MessageType) -> bool { @@ -856,16 +939,9 @@ impl Engine { let current_seq = self.sequence_epoch_0.sequence_number; - if save_fragment { - let mut clone = self.buffers_free.pop(); - clone.extend_from_slice(&fragment); - self.flight_saved_records.push(Entry { - content_type, - epoch: 0, - send_seq: current_seq, - fragment: clone, - acked: false, - }); + if save_fragment && self.flight_saved_records.is_full() { + self.buffers_free.push(fragment); + return Err(Error::TransmitQueueFull); } let record_wire_len = Dtls13Record::PLAINTEXT_HEADER_LEN + fragment.len(); @@ -883,6 +959,7 @@ impl Engine { self.config.max_queue_tx(), self.queue_tx ); + self.buffers_free.push(fragment); return Err(Error::TransmitQueueFull); } @@ -896,10 +973,25 @@ impl Engine { }; if self.sequence_epoch_0.sequence_number >= MAX_SEQUENCE_NUMBER { + self.buffers_free.push(fragment); return Err(Error::CryptoError( "Epoch 0 sequence number exhausted".to_string(), )); } + + if save_fragment { + let mut clone = self.buffers_free.pop(); + clone.extend_from_slice(&fragment); + self.flight_saved_records.push(Entry { + content_type, + epoch: 0, + send_seq: current_seq, + send_seq_history: Vec::new(), + fragment: clone, + acked: false, + }); + } + self.sequence_epoch_0.sequence_number += 1; if can_append { @@ -935,6 +1027,11 @@ impl Engine { let mut fragment = self.buffers_free.pop(); f(&mut fragment); + if save_fragment && self.flight_saved_records.is_full() { + self.buffers_free.push(fragment); + return Err(Error::TransmitQueueFull); + } + // Determine sequence number for this record let seq = if epoch == 2 { self.hs_send_seq @@ -944,20 +1041,15 @@ impl Engine { self.app_send_seq }; + // Build DTLSInnerPlaintext: content || content_type(1) + // (no zero padding for now) + let mut saved_fragment = None; if save_fragment { let mut clone = self.buffers_free.pop(); clone.extend_from_slice(&fragment); - self.flight_saved_records.push(Entry { - content_type, - epoch, - send_seq: seq, - fragment: clone, - acked: false, - }); + saved_fragment = Some(clone); } - // Build DTLSInnerPlaintext: content || content_type(1) - // (no zero padding for now) fragment.push(content_type.as_u8()); let suite = self.suite_provider(); @@ -973,6 +1065,10 @@ impl Engine { }; let Some(keys) = keys else { + if let Some(clone) = saved_fragment { + self.buffers_free.push(clone); + } + self.buffers_free.push(fragment); return Err(Error::CryptoError(format!( "Send keys not available for epoch {}", epoch @@ -1005,9 +1101,13 @@ impl Engine { sn_key[..sn_key_len].copy_from_slice(&keys.sn_key); // Encrypt in place (appends tag) - keys.cipher - .encrypt(&mut fragment, aad, nonce) - .map_err(|e| Error::CryptoError(format!("Encryption failed: {}", e)))?; + if let Err(error) = keys.cipher.encrypt(&mut fragment, aad, nonce) { + if let Some(clone) = saved_fragment { + self.buffers_free.push(clone); + } + self.buffers_free.push(fragment); + return Err(Error::CryptoError(format!("Encryption failed: {}", error))); + } // Record number encryption (RFC 9147 Section 4.2.3): // mask = AES-ECB(sn_key, ciphertext_sample) @@ -1036,6 +1136,10 @@ impl Engine { self.config.max_queue_tx(), self.queue_tx ); + if let Some(clone) = saved_fragment { + self.buffers_free.push(clone); + } + self.buffers_free.push(fragment); return Err(Error::TransmitQueueFull); } @@ -1053,6 +1157,10 @@ impl Engine { // Increment send sequence, guarding against 48-bit overflow (RFC 9147 §4.2) if epoch == 2 { if self.hs_send_seq >= MAX_SEQUENCE_NUMBER { + if let Some(clone) = saved_fragment { + self.buffers_free.push(clone); + } + self.buffers_free.push(fragment); return Err(Error::CryptoError( "Handshake epoch sequence number exhausted".to_string(), )); @@ -1060,6 +1168,10 @@ impl Engine { self.hs_send_seq += 1; } else if self.prev_app_send_keys.is_some() && epoch == self.prev_app_send_epoch { if self.prev_app_send_seq >= MAX_SEQUENCE_NUMBER { + if let Some(clone) = saved_fragment { + self.buffers_free.push(clone); + } + self.buffers_free.push(fragment); return Err(Error::CryptoError( "Previous epoch sequence number exhausted".to_string(), )); @@ -1067,6 +1179,10 @@ impl Engine { self.prev_app_send_seq += 1; } else { if self.app_send_seq >= MAX_SEQUENCE_NUMBER { + if let Some(clone) = saved_fragment { + self.buffers_free.push(clone); + } + self.buffers_free.push(fragment); return Err(Error::CryptoError( "Application epoch sequence number exhausted".to_string(), )); @@ -1082,6 +1198,17 @@ impl Engine { } } + if let Some(fragment) = saved_fragment { + self.flight_saved_records.push(Entry { + content_type, + epoch, + send_seq: seq, + send_seq_history: Vec::new(), + fragment, + acked: false, + }); + } + if can_append { let last = self.queue_tx.back_mut().unwrap(); let header_start = last.len(); @@ -1252,17 +1379,30 @@ impl Engine { /// /// ACK format: record_numbers_length(2) + N * (epoch(8) + sequence(8)) pub fn send_ack(&mut self) -> Result<(), Error> { - if self.received_record_numbers.is_empty() { - return Ok(()); - } - - let entries = mem::take(&mut self.received_record_numbers); let epoch = if self.app_send_keys.is_some() { self.app_send_epoch } else { 2 }; + self.send_ack_with_epoch(epoch) + } + + pub(crate) fn send_ack_with_previous_app_epoch(&mut self) -> Result<(), Error> { + if self.prev_app_send_keys.is_some() { + self.send_ack_with_epoch(self.prev_app_send_epoch) + } else { + self.send_ack() + } + } + + fn send_ack_with_epoch(&mut self, epoch: u16) -> Result<(), Error> { + if self.received_record_numbers.is_empty() { + return Ok(()); + } + + let entries = self.received_record_numbers.clone(); + self.create_ciphertext_record(ContentType::Ack, epoch, false, |fragment| { // record_numbers_length: 2 bytes, value = entries.len() * 16 let len = (entries.len() * 16) as u16; @@ -1272,6 +1412,7 @@ impl Engine { fragment.extend_from_slice(&seq.to_be_bytes()); } })?; + self.received_record_numbers.clear(); Ok(()) } @@ -1306,7 +1447,7 @@ impl Engine { // Mark matching flight entries as acknowledged for entry in &mut self.flight_saved_records { - if entry.epoch as u64 == ack_epoch && entry.send_seq == ack_seq { + if entry.is_acked_by(ack_epoch, ack_seq) { entry.acked = true; } } @@ -1792,15 +1933,19 @@ impl Engine { Ok(next) } - /// Rotate send keys: move current app send keys → prev, derive new ones. - fn update_send_keys(&mut self) -> Result<(), Error> { - let current_keys = self.app_send_keys.take().ok_or_else(|| { + fn next_send_keys(&self) -> Result { + let current_keys = self.app_send_keys.as_ref().ok_or_else(|| { Error::CryptoError("No current app send keys for KeyUpdate".to_string()) })?; - let next_secret = self.derive_next_traffic_secret(¤t_keys.traffic_secret)?; - let new_keys = self.derive_epoch_keys(&next_secret)?; + self.derive_epoch_keys(&next_secret) + } + /// Rotate send keys: move current app send keys → prev, install prederived new keys. + fn update_send_keys(&mut self, new_keys: EpochKeys) -> Result<(), Error> { + let current_keys = self.app_send_keys.take().ok_or_else(|| { + Error::CryptoError("No current app send keys for KeyUpdate".to_string()) + })?; // Save old keys for retransmission self.prev_app_send_keys = Some(current_keys); self.prev_app_send_epoch = self.app_send_epoch; @@ -1856,31 +2001,49 @@ impl Engine { /// the current app epoch, then send keys are rotated (old keys saved /// in `prev_app_send_*` for retransmission). pub fn create_key_update(&mut self, request: KeyUpdateRequest) -> Result<(), Error> { - // Set up retransmission - self.flight_backoff.reset(&mut self.rng); - self.flight_clear_resends(); - self.flight_timeout = Timeout::Unarmed; + if self.is_key_update_in_flight() { + return Err(Error::CryptoError( + "KeyUpdate already in flight".to_string(), + )); + } + + let new_send_keys = self.next_send_keys()?; + let old_saved_records = mem::take(&mut self.flight_saved_records); let msg_seq = self.next_handshake_seq_no; - self.next_handshake_seq_no += 1; let epoch = self.app_send_epoch; // Build the handshake message manually (12-byte DTLS header + 1-byte body) - self.create_ciphertext_record(ContentType::Handshake, epoch, true, |fragment| { - // DTLS handshake header (12 bytes): - // msg_type(1) + length(3) + message_seq(2) + fragment_offset(3) + fragment_length(3) - fragment.push(MessageType::KeyUpdate.as_u8()); - fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); // length = 1 - fragment.extend_from_slice(&msg_seq.to_be_bytes()); // message_seq - fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); // fragment_offset = 0 - fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); // fragment_length = 1 - // Body: 1 byte - fragment.push(request.as_u8()); - })?; + let result = + self.create_ciphertext_record(ContentType::Handshake, epoch, true, |fragment| { + // DTLS handshake header (12 bytes): + // msg_type(1) + length(3) + message_seq(2) + fragment_offset(3) + fragment_length(3) + fragment.push(MessageType::KeyUpdate.as_u8()); + fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); // length = 1 + fragment.extend_from_slice(&msg_seq.to_be_bytes()); // message_seq + fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); // fragment_offset = 0 + fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); // fragment_length = 1 + // Body: 1 byte + fragment.push(request.as_u8()); + }); + if let Err(error) = result { + self.flight_saved_records = old_saved_records; + return Err(error); + } + + for entry in old_saved_records { + self.buffers_free.push(entry.fragment); + } + + // Set up retransmission for the newly saved KeyUpdate flight only + // after the packet has been queued successfully. + self.flight_backoff.reset(&mut self.rng); + self.flight_timeout = Timeout::Unarmed; + self.next_handshake_seq_no += 1; // Now rotate send keys (saves old keys for retransmission) - self.update_send_keys()?; + self.update_send_keys(new_send_keys)?; debug!( "KeyUpdate sent (request={:?}) on epoch {}, new send epoch {}", @@ -2338,7 +2501,13 @@ impl RecordHandler for Engine { epoch_bits } - fn resolve_sequence(&self, epoch: u16, seq_bits: u64, s_flag: bool) -> u64 { + fn resolve_sequence( + &self, + epoch: u16, + seq_bits: u64, + s_flag: bool, + expected_override: Option, + ) -> u64 { let expected = if epoch == 2 { self.hs_expected_recv_seq } else { @@ -2348,6 +2517,9 @@ impl RecordHandler for Engine { .map(|e| e.expected_recv_seq) .unwrap_or(0) }; + let expected = expected_override + .map(|override_expected| expected.max(override_expected)) + .unwrap_or(expected); let bits: u32 = if s_flag { 16 } else { 8 }; reconstruct_sequence(seq_bits, expected, bits) @@ -2473,6 +2645,162 @@ mod tests { Engine::new(config, cert) } + #[cfg(feature = "rcgen")] + fn test_engine_with_config(config: Config) -> Engine { + let cert = generate_self_signed_certificate().expect("gen cert"); + Engine::new(Arc::new(config), cert) + } + + #[cfg(feature = "rcgen")] + fn install_test_app_send_keys(engine: &mut Engine) { + engine.set_cipher_suite(Dtls13CipherSuite::AES_128_GCM_SHA256); + + let mut traffic_secret = Buf::new(); + traffic_secret.extend_from_slice(&vec![0x42; engine.hash_algorithm().output_len()]); + engine.app_send_keys = Some( + engine + .derive_epoch_keys(&traffic_secret) + .expect("derive app send keys"), + ); + } + + fn saved_record_snapshot(engine: &Engine) -> Vec<(ContentType, u16, u64, Vec, bool)> { + engine + .flight_saved_records + .iter() + .map(|entry| { + ( + entry.content_type, + entry.epoch, + entry.send_seq, + entry.fragment.as_ref().to_vec(), + entry.acked, + ) + }) + .collect() + } + + fn push_dummy_saved_record(engine: &mut Engine, index: usize) { + let mut fragment = Buf::new(); + fragment.extend_from_slice(&[index as u8]); + engine.flight_saved_records.push(Entry { + content_type: ContentType::Handshake, + epoch: 3, + send_seq: index as u64, + send_seq_history: Vec::new(), + fragment, + acked: false, + }); + } + + fn push_saved_app_record(engine: &mut Engine, send_seq: u64, byte: u8, len: usize) { + let mut fragment = Buf::new(); + fragment.extend_from_slice(&vec![byte; len]); + engine.flight_saved_records.push(Entry { + content_type: ContentType::ApplicationData, + epoch: engine.app_send_epoch, + send_seq, + send_seq_history: Vec::new(), + fragment, + acked: false, + }); + } + + fn ack_record_number(epoch: u64, sequence: u64) -> Vec { + let mut ack = Vec::new(); + ack.extend_from_slice(&16u16.to_be_bytes()); + ack.extend_from_slice(&epoch.to_be_bytes()); + ack.extend_from_slice(&sequence.to_be_bytes()); + ack + } + + struct PassthroughRecordHandler; + + impl RecordHandler for PassthroughRecordHandler { + fn classify_record(&mut self, record: Record) -> Result, Error> { + Ok(Some(record)) + } + + fn is_peer_encryption_enabled(&self) -> bool { + true + } + + fn resolve_epoch(&self, _epoch_bits: u8) -> u16 { + 2 + } + + fn resolve_sequence( + &self, + _epoch: u16, + seq_bits: u64, + _s_flag: bool, + _expected_override: Option, + ) -> u64 { + seq_bits + } + + fn replay_check(&self, _seq: Sequence) -> bool { + true + } + + fn replay_update(&mut self, _seq: Sequence) {} + + fn min_protected_fragment_len(&self) -> usize { + 0 + } + + fn decrypt_record( + &mut self, + _header: &[u8], + _seq: Sequence, + _ciphertext: &mut TmpBuf, + ) -> Result<(), Error> { + Ok(()) + } + + fn decrypt_sequence_number( + &self, + _epoch: u16, + _seq_bytes: &mut [u8], + _ciphertext_sample: &[u8; 16], + ) { + } + } + + fn encrypted_key_update_record(seq: u16) -> Vec { + let mut fragment = Vec::new(); + fragment.push(MessageType::KeyUpdate.as_u8()); + fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); + fragment.extend_from_slice(&0u16.to_be_bytes()); + fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); + fragment.extend_from_slice(&1u32.to_be_bytes()[1..]); + fragment.push(KeyUpdateRequest::UpdateRequested.as_u8()); + fragment.push(ContentType::Handshake.as_u8()); + + let mut packet = Vec::new(); + packet.push( + 0b0010_0000 + | 0b0000_1000 // 2-byte sequence number. + | 0b0000_0100 // explicit length. + | 0b0000_0010, // epoch bits resolved by PassthroughRecordHandler. + ); + packet.extend_from_slice(&seq.to_be_bytes()); + packet.extend_from_slice(&(fragment.len() as u16).to_be_bytes()); + packet.extend_from_slice(&fragment); + packet + } + + fn parsed_key_update(seq: u16) -> Incoming { + Incoming::parse_packet_defer_after_key_update( + &encrypted_key_update_record(seq), + &mut PassthroughRecordHandler, + Some(Dtls13CipherSuite::AES_128_GCM_SHA256), + ) + .expect("parse key update packet") + .incoming + .expect("packet contains a record") + } + /// Issue 2: Epoch-0 sequence number must have an overflow guard. /// /// Per RFC 9147 §4.2, implementations MUST NOT allow the sequence number @@ -2571,4 +2899,497 @@ mod tests { "derive_early_secret must use the buffer pool, returning a buffer with pooled capacity" ); } + + #[test] + #[cfg(feature = "rcgen")] + fn peer_key_update_response_does_not_replace_in_flight_local_update() { + let mut engine = test_engine(); + install_test_app_send_keys(&mut engine); + + engine + .create_key_update(KeyUpdateRequest::UpdateRequested) + .expect("create local KeyUpdate"); + + let prev_epoch = engine.prev_app_send_epoch; + let prev_seq = engine.prev_app_send_seq; + let saved_records: Vec<_> = engine + .flight_saved_records + .iter() + .map(|entry| { + ( + entry.content_type, + entry.epoch, + entry.send_seq, + entry.fragment.as_ref().to_vec(), + entry.acked, + ) + }) + .collect(); + let send_epoch = engine.app_send_epoch; + let send_seq = engine.app_send_seq; + + let result = engine.create_key_update(KeyUpdateRequest::UpdateNotRequested); + assert!( + result.is_err(), + "a peer-requested KeyUpdate response must wait while a local KeyUpdate is in flight" + ); + + assert_eq!(engine.prev_app_send_epoch, prev_epoch); + assert_eq!(engine.prev_app_send_seq, prev_seq); + assert_eq!(engine.app_send_epoch, send_epoch); + assert_eq!(engine.app_send_seq, send_seq); + assert_eq!(engine.flight_saved_records.len(), saved_records.len()); + for (entry, saved) in engine.flight_saved_records.iter().zip(saved_records) { + assert_eq!(entry.content_type, saved.0); + assert_eq!(entry.epoch, saved.1); + assert_eq!(entry.send_seq, saved.2); + assert_eq!(entry.fragment.as_ref(), saved.3.as_slice()); + assert_eq!(entry.acked, saved.4); + } + } + + #[test] + #[cfg(feature = "rcgen")] + fn ack_tracking_full_does_not_panic_on_handshake_replacement() { + let mut engine = test_engine(); + + let first = parsed_key_update(0); + engine + .insert_incoming(first) + .expect("insert initial key update"); + engine.queue_rx[0] + .first() + .first_handshake() + .expect("initial key update handshake") + .set_handled(); + + engine.received_record_numbers.clear(); + for sequence in 0..engine.received_record_numbers.capacity() { + engine.received_record_numbers.push((2, sequence as u64)); + } + + let replacement = parsed_key_update(1); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + engine + .insert_incoming(replacement) + .expect("replace handled key update") + })); + + assert!( + result.is_ok(), + "full ACK bookkeeping must not panic when a handled handshake is replaced" + ); + assert_eq!(engine.queue_rx.len(), 1); + assert_eq!( + engine.queue_rx[0].first().record().sequence.sequence_number, + 1 + ); + assert_eq!( + engine.received_record_numbers.len(), + engine.received_record_numbers.capacity(), + "overflowing ACK bookkeeping should keep existing entries and drop the extra one" + ); + } + + #[test] + #[cfg(feature = "rcgen")] + fn key_update_transmit_queue_full_does_not_save_unsent_flight() { + let config = Config::builder() + .max_queue_tx(1) + .build() + .expect("build config"); + let mut engine = test_engine_with_config(config); + install_test_app_send_keys(&mut engine); + + let mut occupied = Buf::new(); + occupied.resize(engine.config.mtu(), 0); + engine.queue_tx.push_back(occupied); + + let next_handshake_seq = engine.next_handshake_seq_no; + let app_send_epoch = engine.app_send_epoch; + let app_send_seq = engine.app_send_seq; + let result = engine.create_key_update(KeyUpdateRequest::UpdateRequested); + + assert!(matches!(result, Err(Error::TransmitQueueFull))); + assert!(engine.flight_saved_records.is_empty()); + assert!(engine.prev_app_send_keys.is_none()); + assert_eq!(engine.next_handshake_seq_no, next_handshake_seq); + assert_eq!(engine.app_send_epoch, app_send_epoch); + assert_eq!(engine.app_send_seq, app_send_seq); + } + + #[test] + #[cfg(feature = "rcgen")] + fn key_update_transmit_queue_full_preserves_existing_saved_flight() { + let config = Config::builder() + .max_queue_tx(1) + .build() + .expect("build config"); + let mut engine = test_engine_with_config(config); + install_test_app_send_keys(&mut engine); + + engine + .create_ciphertext_record( + ContentType::Handshake, + engine.app_send_epoch(), + true, + |fragment| { + fragment.extend_from_slice(b"old-flight"); + }, + ) + .expect("save existing flight"); + let saved_records = saved_record_snapshot(&engine); + + engine.queue_tx.clear(); + let mut occupied = Buf::new(); + occupied.resize(engine.config.mtu(), 0); + engine.queue_tx.push_back(occupied); + + engine.flight_backoff.attempt(&mut engine.rng); + let now = Instant::now(); + engine.flight_timeout = Timeout::Armed(now + Duration::from_secs(123)); + let flight_timeout = engine.flight_timeout; + let flight_rto = engine.flight_backoff.rto(); + let flight_can_retry = engine.flight_backoff.can_retry(); + let next_handshake_seq = engine.next_handshake_seq_no; + let app_send_epoch = engine.app_send_epoch; + let app_send_seq = engine.app_send_seq; + let result = engine.create_key_update(KeyUpdateRequest::UpdateRequested); + + assert!(matches!(result, Err(Error::TransmitQueueFull))); + assert_eq!(saved_record_snapshot(&engine), saved_records); + assert!(engine.prev_app_send_keys.is_none()); + assert_eq!(engine.next_handshake_seq_no, next_handshake_seq); + assert_eq!(engine.app_send_epoch, app_send_epoch); + assert_eq!(engine.app_send_seq, app_send_seq); + assert_eq!(engine.flight_timeout, flight_timeout); + assert_eq!(engine.flight_backoff.rto(), flight_rto); + assert_eq!(engine.flight_backoff.can_retry(), flight_can_retry); + } + + #[test] + #[cfg(feature = "rcgen")] + fn needs_key_update_survives_transmit_queue_full() { + let config = Config::builder() + .max_queue_tx(1) + .build() + .expect("build config"); + let mut engine = test_engine_with_config(config); + install_test_app_send_keys(&mut engine); + engine.needs_key_update = true; + + assert!(engine.needs_key_update()); + + let mut occupied = Buf::new(); + occupied.resize(engine.config.mtu(), 0); + engine.queue_tx.push_back(occupied); + + let result = engine.create_key_update(KeyUpdateRequest::UpdateRequested); + + assert!(matches!(result, Err(Error::TransmitQueueFull))); + assert!(engine.needs_key_update()); + } + + #[test] + #[cfg(feature = "rcgen")] + fn flight_timeout_transmit_queue_full_preserves_retry_state() { + let config = Config::builder() + .max_queue_tx(1) + .build() + .expect("build config"); + let mut engine = test_engine_with_config(config); + install_test_app_send_keys(&mut engine); + + engine + .create_key_update(KeyUpdateRequest::UpdateRequested) + .expect("initial KeyUpdate fits empty queue"); + assert_eq!(engine.queue_tx.len(), engine.config.max_queue_tx()); + + engine.flight_backoff.attempt(&mut engine.rng); + let now = Instant::now(); + let expired = now - Duration::from_millis(1); + engine.flight_timeout = Timeout::Armed(expired); + + let saved_records = saved_record_snapshot(&engine); + let flight_timeout = engine.flight_timeout; + let flight_rto = engine.flight_backoff.rto(); + let flight_can_retry = engine.flight_backoff.can_retry(); + let app_send_epoch = engine.app_send_epoch; + let app_send_seq = engine.app_send_seq; + let prev_app_send_epoch = engine.prev_app_send_epoch; + let prev_app_send_seq = engine.prev_app_send_seq; + + let result = engine.handle_timeout(now); + + assert!(matches!(result, Err(Error::TransmitQueueFull))); + assert_eq!(saved_record_snapshot(&engine), saved_records); + assert_eq!(engine.flight_timeout, flight_timeout); + assert_eq!(engine.flight_backoff.rto(), flight_rto); + assert_eq!(engine.flight_backoff.can_retry(), flight_can_retry); + assert_eq!(engine.app_send_epoch, app_send_epoch); + assert_eq!(engine.app_send_seq, app_send_seq); + assert_eq!(engine.prev_app_send_epoch, prev_app_send_epoch); + assert_eq!(engine.prev_app_send_seq, prev_app_send_seq); + + engine.queue_tx.clear(); + engine + .handle_timeout(now) + .expect("retry should succeed after transmit queue drains"); + assert_eq!(engine.queue_tx.len(), 1); + assert!( + engine.flight_backoff.rto() > flight_rto, + "successful retransmit should consume backoff only after queueing output" + ); + assert!(matches!(engine.flight_timeout, Timeout::Armed(deadline) if deadline > now)); + } + + #[test] + #[cfg(feature = "rcgen")] + fn send_ack_transmit_queue_full_preserves_ack_entries() { + let config = Config::builder() + .max_queue_tx(1) + .build() + .expect("build config"); + let mut engine = test_engine_with_config(config); + install_test_app_send_keys(&mut engine); + + engine.received_record_numbers.push((3, 7)); + engine.received_record_numbers.push((3, 8)); + let ack_entries = engine.received_record_numbers.clone(); + let app_send_seq = engine.app_send_seq; + + let mut occupied = Buf::new(); + occupied.resize(engine.config.mtu(), 0); + engine.queue_tx.push_back(occupied); + + let result = engine.send_ack(); + + assert!(matches!(result, Err(Error::TransmitQueueFull))); + assert_eq!(engine.received_record_numbers, ack_entries); + assert_eq!(engine.app_send_seq, app_send_seq); + + engine.queue_tx.clear(); + engine + .send_ack() + .expect("ACK should send after transmit queue drains"); + assert!(engine.received_record_numbers.is_empty()); + assert_eq!(engine.app_send_seq, app_send_seq + 1); + assert_eq!(engine.queue_tx.len(), 1); + } + + #[test] + #[cfg(feature = "rcgen")] + fn saved_flight_capacity_error_does_not_advance_ciphertext_sequence() { + let mut engine = test_engine(); + install_test_app_send_keys(&mut engine); + for index in 0..engine.flight_saved_records.capacity() { + push_dummy_saved_record(&mut engine, index); + } + + let app_send_seq = engine.app_send_seq; + let saved_records = saved_record_snapshot(&engine); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + engine.create_ciphertext_record( + ContentType::Handshake, + engine.app_send_epoch(), + true, + |fragment| { + fragment.extend_from_slice(b"overflow"); + }, + ) + })); + + assert!(result.is_ok(), "saved-flight capacity must not panic"); + assert!(matches!(result.unwrap(), Err(Error::TransmitQueueFull))); + assert_eq!(engine.app_send_seq, app_send_seq); + assert_eq!(saved_record_snapshot(&engine), saved_records); + } + + #[test] + #[cfg(feature = "rcgen")] + fn saved_flight_capacity_error_does_not_advance_plaintext_sequence() { + let mut engine = test_engine(); + for index in 0..engine.flight_saved_records.capacity() { + push_dummy_saved_record(&mut engine, index); + } + + let epoch0_seq = engine.sequence_epoch_0.sequence_number; + let saved_records = saved_record_snapshot(&engine); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + engine.create_plaintext_record(ContentType::Handshake, true, |fragment| { + fragment.extend_from_slice(b"overflow"); + }) + })); + + assert!(result.is_ok(), "saved-flight capacity must not panic"); + assert!(matches!(result.unwrap(), Err(Error::TransmitQueueFull))); + assert_eq!(engine.sequence_epoch_0.sequence_number, epoch0_seq); + assert_eq!(saved_record_snapshot(&engine), saved_records); + } + + #[test] + #[cfg(feature = "rcgen")] + fn flight_resend_transmit_queue_full_restores_saved_records() { + let config = Config::builder() + .max_queue_tx(1) + .build() + .expect("build config"); + let mut engine = test_engine_with_config(config); + install_test_app_send_keys(&mut engine); + + engine + .create_key_update(KeyUpdateRequest::UpdateRequested) + .expect("initial KeyUpdate fits empty queue"); + let saved_records: Vec<_> = engine + .flight_saved_records + .iter() + .map(|entry| { + ( + entry.content_type, + entry.epoch, + entry.send_seq, + entry.fragment.as_ref().to_vec(), + entry.acked, + ) + }) + .collect(); + + let result = engine.flight_resend("test queue full"); + + assert!(matches!(result, Err(Error::TransmitQueueFull))); + assert_eq!(engine.flight_saved_records.len(), saved_records.len()); + for (entry, saved) in engine.flight_saved_records.iter().zip(saved_records) { + assert_eq!(entry.content_type, saved.0); + assert_eq!(entry.epoch, saved.1); + assert_eq!(entry.send_seq, saved.2); + assert_eq!(entry.fragment.as_ref(), saved.3.as_slice()); + assert_eq!(entry.acked, saved.4); + } + } + + #[test] + #[cfg(feature = "rcgen")] + fn flight_resend_partial_queue_full_preserves_ackable_prefix_and_resumable_tail() { + let config = Config::builder() + .max_queue_tx(1) + .mtu(80) + .build() + .expect("build config"); + let mut engine = test_engine_with_config(config); + install_test_app_send_keys(&mut engine); + engine.app_send_seq = 100; + + push_saved_app_record(&mut engine, 7, 0xA7, 40); + push_saved_app_record(&mut engine, 8, 0xB8, 40); + + let result = engine.flight_resend("test partial queue full"); + + assert!(matches!(result, Err(Error::TransmitQueueFull))); + assert_eq!(engine.queue_tx.len(), 1); + assert_eq!(engine.app_send_seq, 101); + assert_eq!(engine.flight_saved_records.len(), 2); + assert_eq!(engine.flight_saved_records[0].send_seq, 100); + assert_eq!(engine.flight_saved_records[1].send_seq, 8); + assert!(!engine.flight_saved_records[0].acked); + assert!(!engine.flight_saved_records[1].acked); + + engine.queue_tx.clear(); + let result = engine.flight_resend("second partial retry before late ACK"); + + assert!(matches!(result, Err(Error::TransmitQueueFull))); + assert_eq!(engine.queue_tx.len(), 1); + assert_eq!(engine.app_send_seq, 102); + assert_eq!(engine.flight_saved_records[0].send_seq, 101); + assert_eq!(engine.flight_saved_records[1].send_seq, 8); + assert!(!engine.flight_saved_records[0].acked); + assert!(!engine.flight_saved_records[1].acked); + + engine.process_ack(&ack_record_number(3, 100)); + assert!(engine.flight_saved_records[0].acked); + assert!(!engine.flight_saved_records[1].acked); + + engine.queue_tx.clear(); + engine + .flight_resend("retry after partial queue full") + .expect("unsent tail should resend after queue drains"); + + assert_eq!(engine.queue_tx.len(), 1); + assert_eq!(engine.app_send_seq, 103); + assert_eq!(engine.flight_saved_records[0].send_seq, 101); + assert_eq!(engine.flight_saved_records[1].send_seq, 102); + assert!(engine.flight_saved_records[0].acked); + assert!(!engine.flight_saved_records[1].acked); + } + + #[test] + #[cfg(feature = "rcgen")] + fn flight_timeout_partial_resend_rearms_before_tail_retry() { + let config = Config::builder() + .max_queue_tx(1) + .mtu(80) + .build() + .expect("build config"); + let mut engine = test_engine_with_config(config); + install_test_app_send_keys(&mut engine); + engine.app_send_seq = 100; + + push_saved_app_record(&mut engine, 7, 0xA7, 40); + push_saved_app_record(&mut engine, 8, 0xB8, 40); + + let now = Instant::now(); + let expired = now - Duration::from_millis(1); + engine.flight_timeout = Timeout::Armed(expired); + let initial_rto = engine.flight_backoff.rto(); + + engine + .handle_timeout(now) + .expect("partial resend queues progress and rearms timeout"); + + assert_eq!(engine.queue_tx.len(), 1); + assert_eq!(engine.app_send_seq, 101); + assert_eq!(engine.flight_saved_records[0].send_seq, 100); + assert_eq!(engine.flight_saved_records[1].send_seq, 8); + assert!( + engine.flight_backoff.rto() > initial_rto, + "partial progress should still consume resend backoff" + ); + assert!(matches!(engine.flight_timeout, Timeout::Armed(deadline) if deadline > now)); + + engine.queue_tx.clear(); + engine + .handle_timeout(now) + .expect("same instant should not immediately retry the same prefix"); + assert!( + engine.queue_tx.is_empty(), + "rearmed timeout should prevent duplicate prefix output before the peer can ACK it" + ); + } + + #[test] + #[cfg(feature = "rcgen")] + fn flight_resend_late_ack_history_scales_with_configured_retry_budget() { + let config = Config::builder() + .flight_retries(12) + .build() + .expect("build config"); + let mut engine = test_engine_with_config(config); + install_test_app_send_keys(&mut engine); + engine.app_send_seq = 100; + + push_saved_app_record(&mut engine, 7, 0xA7, 16); + + for _ in 0..10 { + engine.queue_tx.clear(); + engine + .flight_resend("configured retry budget preserves old ACKs") + .expect("single saved record should resend"); + } + + assert_eq!(engine.flight_saved_records[0].send_seq, 109); + engine.process_ack(&ack_record_number(3, 100)); + assert!( + engine.flight_saved_records[0].acked, + "late ACK for first retransmitted copy should still match after more than eight retries" + ); + } } diff --git a/src/dtls13/incoming.rs b/src/dtls13/incoming.rs index 240cc830..12ada56f 100644 --- a/src/dtls13/incoming.rs +++ b/src/dtls13/incoming.rs @@ -6,7 +6,10 @@ use std::fmt; use crate::Error; use crate::buffer::{Buf, TmpBuf}; -use crate::dtls13::message::{ContentType, Dtls13CipherSuite, Dtls13Record, Handshake, Sequence}; +use crate::dtls13::message::Sequence; +use crate::dtls13::message::{ContentType, Dtls13CipherSuite}; +use crate::dtls13::message::{Dtls13Record, Handshake, MessageType}; +use crate::window::ReplayWindow; /// Holds both the UDP packet and the parsed result of that packet. pub struct Incoming { @@ -15,6 +18,11 @@ pub struct Incoming { records: Box, } +pub(crate) struct IncomingParse { + pub incoming: Option, + pub deferred_tail: Option, +} + impl Incoming { pub fn records(&self) -> &Records { &self.records @@ -29,33 +37,50 @@ impl Incoming { pub fn into_records(self) -> impl Iterator { self.records.records.into_iter() } + + pub(crate) fn ackable_record_numbers(&self) -> impl Iterator + '_ { + self.records + .records + .iter() + .map(|record| record.record().sequence) + .filter(|seq| seq.epoch >= 2) + .map(|seq| (seq.epoch as u64, seq.sequence_number)) + } } impl Incoming { - /// Parse an incoming UDP packet - /// - /// * `packet` is the data from the UDP socket. - /// * `decrypt` provides the decryption operations for encrypted records. - /// * `cs` is the negotiated cipher suite, if any. - /// - /// Will surface parser errors. - pub fn parse_packet( + pub(crate) fn parse_packet_defer_after_key_update( packet: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result, Error> { + ) -> Result { + Self::parse_packet_inner(packet, decrypt, cs, true) + } + + fn parse_packet_inner( + packet: &[u8], + decrypt: &mut dyn RecordHandler, + cs: Option, + defer_after_key_update: bool, + ) -> Result { // Parse records directly from packet, copying each record ONCE into its own buffer - let records = Records::parse(packet, decrypt, cs)?; + let parse = Records::parse_inner(packet, decrypt, cs, defer_after_key_update)?; // We need at least one Record to be valid. For replayed frames, we discard // the records, hence this might be None - if records.records.is_empty() { - return Ok(None); + if parse.records.records.is_empty() { + return Ok(IncomingParse { + incoming: None, + deferred_tail: parse.deferred_tail, + }); } - let records = Box::new(records); + let records = Box::new(parse.records); - Ok(Some(Incoming { records })) + Ok(IncomingParse { + incoming: Some(Incoming { records }), + deferred_tail: parse.deferred_tail, + }) } } @@ -65,62 +90,28 @@ pub struct Records { pub records: ArrayVec, } +struct RecordsParse { + records: Records, + deferred_tail: Option, +} + impl Records { - pub fn parse( + fn parse_inner( mut packet: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result { + defer_after_key_update: bool, + ) -> Result { let mut parsed_records: ArrayVec = ArrayVec::new(); + let mut replay_updates: ArrayVec = ArrayVec::new(); + let mut pending_replay: ArrayVec<(u16, ReplayWindow), 16> = ArrayVec::new(); + let mut pending_expected: ArrayVec<(u16, u64), 16> = ArrayVec::new(); + let mut deferred_tail = None; // Find record boundaries and copy each record ONCE from the packet while !packet.is_empty() { - let record_end = if Dtls13Record::is_ciphertext_header(packet[0]) { - // CID bit set means we can't determine record boundaries (unsupported). - // Discard the rest of the datagram. - if packet[0] & 0x10 != 0 { - break; - } - - // Unified header: variable length - if packet.len() < 2 { - return Err(Error::ParseIncomplete); - } - - let flags = packet[0]; - let s_flag = flags & 0b0000_1000 != 0; - let l_flag = flags & 0b0000_0100 != 0; - let seq_len = if s_flag { 2 } else { 1 }; - let len_len = if l_flag { 2 } else { 0 }; - let header_len = 1 + seq_len + len_len; - - if packet.len() < header_len { - return Err(Error::ParseIncomplete); - } - - if l_flag { - let len_offset = 1 + seq_len; - // unwrap: header_len check above ensures 2 bytes at len_offset - let length_bytes: [u8; 2] = - packet[len_offset..len_offset + 2].try_into().unwrap(); - let length = u16::from_be_bytes(length_bytes) as usize; - header_len + length - } else { - // No length field: record consumes the rest of the datagram - packet.len() - } - } else { - // Plaintext: fixed 13-byte header - if packet.len() < Dtls13Record::PLAINTEXT_HEADER_LEN { - return Err(Error::ParseIncomplete); - } - - // unwrap: PLAINTEXT_HEADER_LEN check above ensures 2 bytes at offset - let length_bytes: [u8; 2] = packet[Dtls13Record::PLAINTEXT_LENGTH_OFFSET] - .try_into() - .unwrap(); - let length = u16::from_be_bytes(length_bytes) as usize; - Dtls13Record::PLAINTEXT_HEADER_LEN + length + let Some(record_end) = record_end(packet)? else { + break; }; if packet.len() < record_end { @@ -129,15 +120,50 @@ impl Records { // This is the ONLY copy: packet -> record buffer let record_slice = &packet[..record_end]; - match Record::parse(record_slice, decrypt, cs) { - Ok(record) => { - if let Some(record) = record { + let tail = &packet[record_end..]; + match Record::parse(record_slice, decrypt, cs, &pending_expected) { + Ok(parsed) => { + let mut should_break_after_replay_update = false; + + if let Some(sequence) = parsed.replay_sequence { + if !pending_replay_check(&pending_replay, sequence) { + trace!("Discarding duplicate rec in same datagram"); + packet = &packet[record_end..]; + continue; + } + } + + if let Some(record) = parsed.record { + let should_defer_tail = defer_after_key_update + && !tail.is_empty() + && record.contains_complete_key_update(); + if parsed_records.try_push(record).is_err() { return Err(Error::TooManyRecords); } - } else { + + should_break_after_replay_update = should_defer_tail; + if should_defer_tail { + validate_record_boundaries(tail, parsed_records.len())?; + let mut tail_buffer = Buf::new(); + tail_buffer.extend_from_slice(tail); + deferred_tail = Some(tail_buffer); + } + } else if parsed.replay_sequence.is_none() { trace!("Discarding replayed rec"); } + + if let Some(sequence) = parsed.replay_sequence { + pending_replay_update(&mut pending_replay, sequence)?; + pending_expected_update(&mut pending_expected, sequence)?; + if replay_updates.try_push(sequence).is_err() { + return Err(Error::TooManyRecords); + } + } + + if should_break_after_replay_update { + break; + } } Err(e) => return Err(e), } @@ -145,6 +171,13 @@ impl Records { packet = &packet[record_end..]; } + // Commit replay state only after the whole UDP datagram has parsed + // successfully. A malformed trailing record must not consume + // replay state for an earlier authenticated record in the same datagram. + for sequence in replay_updates { + decrypt.replay_update(sequence); + } + let mut records = ArrayVec::new(); for record in parsed_records { if let Some(record) = decrypt.classify_record(record)? { @@ -154,10 +187,140 @@ impl Records { } } - Ok(Records { records }) + Ok(RecordsParse { + records: Records { records }, + deferred_tail, + }) } } +fn record_end(packet: &[u8]) -> Result, Error> { + if Dtls13Record::is_ciphertext_header(packet[0]) { + // CID bit set means we can't determine record boundaries (unsupported). + // Discard the rest of the datagram. + if packet[0] & 0x10 != 0 { + return Ok(None); + } + + // Unified header: variable length + if packet.len() < 2 { + return Err(Error::ParseIncomplete); + } + + let flags = packet[0]; + let s_flag = flags & 0b0000_1000 != 0; + let l_flag = flags & 0b0000_0100 != 0; + let seq_len = if s_flag { 2 } else { 1 }; + let len_len = if l_flag { 2 } else { 0 }; + let header_len = 1 + seq_len + len_len; + + if packet.len() < header_len { + return Err(Error::ParseIncomplete); + } + + if l_flag { + let len_offset = 1 + seq_len; + // unwrap: header_len check above ensures 2 bytes at len_offset. + let length_bytes: [u8; 2] = packet[len_offset..len_offset + 2].try_into().unwrap(); + let length = u16::from_be_bytes(length_bytes) as usize; + Ok(Some(header_len + length)) + } else { + Ok(Some(packet.len())) + } + } else { + if packet.len() < Dtls13Record::PLAINTEXT_HEADER_LEN { + return Err(Error::ParseIncomplete); + } + + // unwrap: PLAINTEXT_HEADER_LEN check above ensures 2 bytes at offset. + let length_bytes: [u8; 2] = packet[Dtls13Record::PLAINTEXT_LENGTH_OFFSET] + .try_into() + .unwrap(); + let length = u16::from_be_bytes(length_bytes) as usize; + Ok(Some(Dtls13Record::PLAINTEXT_HEADER_LEN + length)) + } +} + +fn validate_record_boundaries( + mut packet: &[u8], + mut parsed_record_count: usize, +) -> Result<(), Error> { + while !packet.is_empty() { + let Some(end) = record_end(packet)? else { + return Ok(()); + }; + + if parsed_record_count >= 16 { + return Err(Error::TooManyRecords); + } + parsed_record_count += 1; + + if packet.len() < end { + return Err(Error::ParseIncomplete); + } + + packet = &packet[end..]; + } + + Ok(()) +} + +fn pending_replay_check(pending_replay: &ArrayVec<(u16, ReplayWindow), 16>, seq: Sequence) -> bool { + match pending_replay.iter().find(|(epoch, _)| *epoch == seq.epoch) { + Some((_, window)) => window.check(seq.sequence_number), + None => true, + } +} + +fn pending_replay_update( + pending_replay: &mut ArrayVec<(u16, ReplayWindow), 16>, + seq: Sequence, +) -> Result<(), Error> { + if let Some((_, window)) = pending_replay + .iter_mut() + .find(|(epoch, _)| *epoch == seq.epoch) + { + window.update(seq.sequence_number); + return Ok(()); + } + + let mut window = ReplayWindow::new(); + window.update(seq.sequence_number); + pending_replay + .try_push((seq.epoch, window)) + .map_err(|_| Error::TooManyRecords) +} + +fn pending_expected_override( + pending_expected: &ArrayVec<(u16, u64), 16>, + epoch: u16, +) -> Option { + pending_expected + .iter() + .find(|(candidate_epoch, _)| *candidate_epoch == epoch) + .map(|(_, expected)| *expected) +} + +fn pending_expected_update( + pending_expected: &mut ArrayVec<(u16, u64), 16>, + seq: Sequence, +) -> Result<(), Error> { + let next = seq.sequence_number + 1; + if let Some((_, expected)) = pending_expected + .iter_mut() + .find(|(epoch, _)| *epoch == seq.epoch) + { + if next > *expected { + *expected = next; + } + return Ok(()); + } + + pending_expected + .try_push((seq.epoch, next)) + .map_err(|_| Error::TooManyRecords) +} + impl Deref for Records { type Target = [Record]; @@ -173,14 +336,20 @@ pub struct Record { parsed: Box, } +struct RecordParse { + record: Option, + replay_sequence: Option, +} + impl Record { /// The first parse pass only parses the record header which is unencrypted. /// Copies record data from UDP packet ONCE into a pooled buffer. - pub fn parse( + fn parse( record_slice: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, - ) -> Result, Error> { + pending_expected: &ArrayVec<(u16, u64), 16>, + ) -> Result { // ONLY COPY: UDP packet slice -> pooled buffer let mut buffer = Buf::new(); buffer.extend_from_slice(record_slice); @@ -218,7 +387,10 @@ impl Record { Ok(p) => p, Err(e) => { trace!("Discarding record: parse failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } }; let parsed = Box::new(parsed); @@ -226,7 +398,10 @@ impl Record { // Plaintext records (epoch 0) are not encrypted if !is_ciphertext || !decrypt.is_peer_encryption_enabled() { - return Ok(Some(record)); + return Ok(RecordParse { + record: Some(record), + replay_sequence: None, + }); } // Resolve the full epoch from the 2-bit value in the unified header @@ -236,7 +411,12 @@ impl Record { // Resolve the full sequence number from the (now decrypted) partial value let seq_bits = record.record().sequence.sequence_number; let s_flag = record_slice[0] & 0b0000_1000 != 0; - let full_seq = decrypt.resolve_sequence(full_epoch, seq_bits, s_flag); + let full_seq = decrypt.resolve_sequence( + full_epoch, + seq_bits, + s_flag, + pending_expected_override(pending_expected, full_epoch), + ); let full_sequence = Sequence { epoch: full_epoch, @@ -245,7 +425,10 @@ impl Record { // Anti-replay check (read-only, does not update window) if !decrypt.replay_check(full_sequence) { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } // Save the raw header bytes for AAD before mutating the buffer. @@ -257,7 +440,10 @@ impl Record { // so decryption would necessarily fail. Catching it here keeps the // cipher impls' bounds-checking from being the only line of defence. if record.buffer.len() - header_end < decrypt.min_protected_fragment_len() { - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } let mut header_buf = [0u8; 5]; header_buf[..header_end].copy_from_slice(&record.buffer[..header_end]); @@ -277,25 +463,26 @@ impl Record { Ok(()) => {} Err(e) => { trace!("Discarding ciphertext record: decryption failed: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: None, + }); } } buffer.len() }; - // Decryption succeeded — now commit the replay window update. - // RFC 9147 §4.5.1: "The window MUST NOT be updated due to a received - // record until that record has been deprotected successfully." - decrypt.replay_update(full_sequence); - // Recover inner content type from DTLSInnerPlaintext let decrypted = &buffer[header_end..header_end + new_len]; let (inner_content_type, content_len) = match recover_inner_content_type(decrypted) { Ok(v) => v, Err(e) => { trace!("Discarding record: invalid inner content type: {}", e); - return Ok(None); + return Ok(RecordParse { + record: None, + replay_sequence: Some(full_sequence), + }); } }; @@ -311,7 +498,10 @@ impl Record { ); let parsed = Box::new(parsed); - Ok(Some(Record { buffer, parsed })) + Ok(RecordParse { + record: Some(Record { buffer, parsed }), + replay_sequence: Some(full_sequence), + }) } pub fn record(&self) -> &Dtls13Record { @@ -326,6 +516,15 @@ impl Record { self.parsed.handshakes.first() } + fn contains_complete_key_update(&self) -> bool { + self.record().content_type == ContentType::Handshake + && self.handshakes().iter().any(|handshake| { + handshake.header.msg_type == MessageType::KeyUpdate + && handshake.header.fragment_offset == 0 + && handshake.header.fragment_length == handshake.header.length + }) + } + pub fn is_handled(&self) -> bool { if self.parsed.handshakes.is_empty() { self.parsed.handled.load(Ordering::Relaxed) @@ -407,7 +606,13 @@ pub trait RecordHandler { fn classify_record(&mut self, record: Record) -> Result, Error>; fn is_peer_encryption_enabled(&self) -> bool; fn resolve_epoch(&self, epoch_bits: u8) -> u16; - fn resolve_sequence(&self, epoch: u16, seq_bits: u64, s_flag: bool) -> u64; + fn resolve_sequence( + &self, + epoch: u16, + seq_bits: u64, + s_flag: bool, + expected_override: Option, + ) -> u64; fn replay_check(&self, seq: Sequence) -> bool; fn replay_update(&mut self, seq: Sequence); @@ -556,7 +761,13 @@ mod tests { panic!("resolve_epoch should not be called when peer encryption is disabled"); } - fn resolve_sequence(&self, _epoch: u16, _seq_bits: u64, _s_flag: bool) -> u64 { + fn resolve_sequence( + &self, + _epoch: u16, + _seq_bits: u64, + _s_flag: bool, + _expected_override: Option, + ) -> u64 { panic!("resolve_sequence should not be called when peer encryption is disabled"); } @@ -621,8 +832,9 @@ mod tests { packet.extend_from_slice(&build_ciphertext_record(2, 2, &[0x11, 0x22, 0x33])); let mut handler = TestHandler::default(); - let incoming = Incoming::parse_packet(&packet, &mut handler, None) + let incoming = Incoming::parse_packet_defer_after_key_update(&packet, &mut handler, None) .unwrap() + .incoming .expect("ciphertext application data record should remain"); assert_eq!(handler.classify_calls, 2); diff --git a/src/dtls13/message/handshake.rs b/src/dtls13/message/handshake.rs index 55c6e427..92cb03cd 100644 --- a/src/dtls13/message/handshake.rs +++ b/src/dtls13/message/handshake.rs @@ -223,7 +223,8 @@ impl Handshake { let qualifies = matches!( self.header.msg_type, - MessageType::ClientHello // flight 1 + MessageType::ClientHello | // flight 1 + MessageType::KeyUpdate // post-handshake flight ); qualifies.then_some(self.header.message_seq) diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 71f3dfd7..93319637 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -138,6 +138,9 @@ pub struct Server { /// Whether we need to respond with our own KeyUpdate. pending_key_update_response: bool, + /// Whether a peer KeyUpdate ACK needs to be retried. + pending_key_update_ack: bool, + /// When true, a ClientHello without DTLS 1.3 in `supported_versions` /// returns [`Error::Dtls12Fallback`] instead of a security error. /// Used by the auto-sense server path. @@ -207,6 +210,7 @@ impl Server { hello_retry: false, cookie_secret, pending_key_update_response: false, + pending_key_update_ack: false, auto_mode, retained_hello: VecDeque::with_capacity(10), } @@ -248,7 +252,7 @@ impl Server { } self.engine.parse_packet(packet)?; - self.make_progress()?; + self.make_progress_and_drain_deferred()?; // Once past AwaitClientHello, DTLS 1.3 is committed — free the buffer. if self.auto_mode && self.state != State::AwaitClientHello { @@ -263,7 +267,13 @@ impl Server { if let Some(event) = self.local_events.pop_front() { return event.into_output(buf, &self.client_certificates); } - self.engine.poll_output(buf, self.last_now) + + match self.engine.poll_output(buf, self.last_now) { + Output::Timeout(_) if self.has_pending_local_progress() => { + Output::Timeout(self.last_now) + } + output => output, + } } /// Handle a timeout event. @@ -273,15 +283,89 @@ impl Server { self.random = Some(self.engine.random()); } self.engine.handle_timeout(now)?; + self.make_progress_and_drain_deferred()?; + Ok(()) + } + + fn make_progress_and_drain_deferred(&mut self) -> Result<(), Error> { self.make_progress()?; + while self.engine.parse_next_deferred_packet()? { + self.make_progress()?; + } Ok(()) } + fn has_pending_local_progress(&self) -> bool { + self.pending_key_update_ack + || self.engine.has_pending_post_handshake_ack() + || (self.pending_key_update_response && !self.engine.is_key_update_in_flight()) + } + fn initiate_key_update(&mut self) -> Result<(), Error> { self.engine .create_key_update(KeyUpdateRequest::UpdateRequested) } + fn send_pending_key_update_response(&mut self) -> Result<(), Error> { + if self.pending_key_update_response && !self.engine.is_key_update_in_flight() { + self.engine.send_ack()?; + self.engine + .create_key_update(KeyUpdateRequest::UpdateNotRequested)?; + self.pending_key_update_response = false; + } + Ok(()) + } + + fn send_pending_key_update_ack(&mut self) -> Result<(), Error> { + if !self.pending_key_update_ack && !self.engine.has_pending_post_handshake_ack() { + return Ok(()); + } + + let result = if self.engine.is_key_update_in_flight() { + self.engine.send_ack_with_previous_app_epoch() + } else { + self.engine.send_ack() + }; + + result?; + self.pending_key_update_ack = false; + Ok(()) + } + + fn handle_incoming_key_update(&mut self) -> Result<(), Error> { + if self.engine.has_complete_handshake(MessageType::KeyUpdate) { + let maybe = self.engine.next_handshake_no_transcript( + MessageType::KeyUpdate, + &mut self.defragment_buffer, + )?; + + if let Some(handshake) = maybe { + let Body::KeyUpdate(request) = handshake.body else { + unreachable!() + }; + + // Install new recv keys + self.engine.update_recv_keys()?; + self.engine.advance_peer_handshake_seq(); + self.pending_key_update_ack = true; + + // If peer requested us to update, schedule our own KeyUpdate + if request == KeyUpdateRequest::UpdateRequested { + self.pending_key_update_response = true; + } + self.send_pending_key_update_ack()?; + + debug!("Received KeyUpdate (request={:?})", request); + + // Drain a fresh peer-requested response in the same progress + // pass when no local KeyUpdate is in flight. + self.send_pending_key_update_response()?; + } + } + + Ok(()) + } + /// Send application data when the server is connected. pub fn send_application_data(&mut self, data: &[u8]) -> Result<(), Error> { if self.state == State::Closed || self.state == State::HalfClosedLocal { @@ -1119,8 +1203,14 @@ impl State { } fn await_application_data(self, server: &mut Server) -> Result { + // Incoming peer requests require an update_not_requested response. They + // take priority over local AEAD-limit updates and queued app data. + server.handle_incoming_key_update()?; + server.send_pending_key_update_ack()?; + server.send_pending_key_update_response()?; + // Auto-trigger KeyUpdate when AEAD encryption limit is reached - if server.engine.needs_key_update() && !server.engine.is_key_update_in_flight() { + if !server.engine.is_key_update_in_flight() && server.engine.needs_key_update() { server.initiate_key_update()?; } @@ -1143,42 +1233,6 @@ impl State { } } - // Send pending KeyUpdate response before processing new KeyUpdates - if server.pending_key_update_response { - server - .engine - .create_key_update(KeyUpdateRequest::UpdateNotRequested)?; - server.pending_key_update_response = false; - } - - // Check for incoming KeyUpdate - if server.engine.has_complete_handshake(MessageType::KeyUpdate) { - let maybe = server.engine.next_handshake_no_transcript( - MessageType::KeyUpdate, - &mut server.defragment_buffer, - )?; - - if let Some(handshake) = maybe { - let Body::KeyUpdate(request) = handshake.body else { - unreachable!() - }; - - // Install new recv keys - server.engine.update_recv_keys()?; - - // ACK the KeyUpdate record - server.engine.send_ack()?; - - // If peer requested us to update, schedule our own KeyUpdate - if request == KeyUpdateRequest::UpdateRequested { - server.pending_key_update_response = true; - } - - server.engine.advance_peer_handshake_seq(); - debug!("Received KeyUpdate (request={:?})", request); - } - } - Ok(self) } @@ -1452,3 +1506,126 @@ fn serialize_certificate_authorities( output.extend_from_slice(data); } } + +#[cfg(all(test, feature = "rcgen"))] +mod tests { + use std::sync::Arc; + use std::time::{Duration, Instant}; + + use super::*; + use crate::certificate::generate_self_signed_certificate; + use crate::dtls13::client::Client; + use crate::dtls13::engine::Engine; + + fn collect_client(client: &mut Client) -> (Vec>, bool) { + let mut packets = Vec::new(); + let mut connected = false; + let mut buf = vec![0u8; 2048]; + + loop { + match client.poll_output(&mut buf) { + Output::Packet(packet) => packets.push(packet.to_vec()), + Output::Connected => connected = true, + Output::Timeout(_) => break, + _ => {} + } + } + + (packets, connected) + } + + fn collect_server(server: &mut Server) -> (Vec>, bool) { + let mut packets = Vec::new(); + let mut connected = false; + let mut buf = vec![0u8; 2048]; + + loop { + match server.poll_output(&mut buf) { + Output::Packet(packet) => packets.push(packet.to_vec()), + Output::Connected => connected = true, + Output::Timeout(_) => break, + _ => {} + } + } + + (packets, connected) + } + + fn deliver_to_server(packets: &[Vec], server: &mut Server) { + for packet in packets { + server.handle_packet(packet).expect("server handles packet"); + } + } + + fn deliver_to_client(packets: &[Vec], client: &mut Client) { + for packet in packets { + client.handle_packet(packet).expect("client handles packet"); + } + } + + fn complete_handshake(client: &mut Client, server: &mut Server, mut now: Instant) -> Instant { + let mut client_connected = false; + let mut server_connected = false; + + for _ in 0..40 { + client.handle_timeout(now).expect("client timeout"); + server.handle_timeout(now).expect("server timeout"); + + let (client_packets, client_event) = collect_client(client); + let (server_packets, server_event) = collect_server(server); + + client_connected |= client_event; + server_connected |= server_event; + + deliver_to_server(&client_packets, server); + deliver_to_client(&server_packets, client); + + if client_connected && server_connected { + return now; + } + + now += Duration::from_millis(10); + } + + panic!("DTLS 1.3 handshake did not complete"); + } + + #[test] + fn pending_key_update_response_does_not_replace_server_local_key_update() { + let _ = env_logger::try_init(); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build config"), + ); + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + let now = Instant::now(); + + let client_engine = Engine::new(Arc::clone(&config), client_cert); + let mut client = Client::new_with_engine(client_engine, now); + let mut server = Server::new(config, server_cert, now); + + let _now = complete_handshake(&mut client, &mut server, now); + + server + .engine + .create_key_update(KeyUpdateRequest::UpdateRequested) + .expect("server creates local KeyUpdate"); + let (server_key_update, _) = collect_server(&mut server); + assert!(!server_key_update.is_empty()); + assert!(server.engine.is_key_update_in_flight()); + + server.pending_key_update_response = true; + server + .send_pending_key_update_response() + .expect("pending response check while server KeyUpdate is in flight"); + let (early_response, _) = collect_server(&mut server); + assert!( + early_response.is_empty(), + "server must not replace an in-flight local KeyUpdate with a peer-requested response" + ); + } +} diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 7640007c..6dcd02b5 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -680,6 +680,78 @@ fn dtls12_bad_encrypted_prefix_does_not_drop_valid_tail() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_malformed_trailing_record_does_not_consume_replay_window() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_12_pair(now); + + client + .send_application_data(b"replay-atomic") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut malformed_packet = valid_packet.clone(); + malformed_packet.push(0xff); + + let err = server + .handle_packet(&malformed_packet) + .expect_err("trailing partial record must reject the datagram"); + assert!( + matches!(err, dimpl::Error::ParseIncomplete), + "expected ParseIncomplete, got {err:?}" + ); + + let server_out = drain_outputs(&mut server); + assert!( + server_out.app_data.is_empty(), + "malformed datagram must not deliver application data" + ); + + server + .handle_packet(&valid_packet) + .expect("valid packet must still pass replay checks"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"replay-atomic".to_vec()]); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_same_datagram_duplicate_encrypted_record_delivers_once() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_12_pair(now); + + client + .send_application_data(b"duplicate-once") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut duplicate_datagram = valid_packet.clone(); + duplicate_datagram.extend_from_slice(&valid_packet); + + server + .handle_packet(&duplicate_datagram) + .expect("duplicate datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"duplicate-once".to_vec()]); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_relabelled_encrypted_handshake_failure_is_not_silently_discarded() { @@ -767,6 +839,39 @@ fn dtls12_relabelled_encrypted_handshake_failure_is_not_silently_discarded() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_same_datagram_window_shift_drops_now_too_old_record() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_12_pair(now); + + let mut packets = Vec::new(); + for i in 0..66 { + client + .send_application_data(format!("msg-{i}").as_bytes()) + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let out = drain_outputs(&mut client); + let packet = out + .packets + .first() + .expect("application data packet") + .clone(); + packets.push(packet); + } + + let mut shifted_datagram = packets[65].clone(); + shifted_datagram.extend_from_slice(&packets[0]); + + server + .handle_packet(&shifted_datagram) + .expect("window-shift datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"msg-65".to_vec()]); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_app_data_after_close_notify_is_ignored() { diff --git a/tests/dtls13/edge.rs b/tests/dtls13/edge.rs index 1813b8ca..d7c36d90 100644 --- a/tests/dtls13/edge.rs +++ b/tests/dtls13/edge.rs @@ -1267,6 +1267,111 @@ fn dtls13_mixed_datagram_valid_first_then_bogus() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_malformed_trailing_record_does_not_consume_replay_window() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_13_pair(now); + + client + .send_application_data(b"replay-atomic") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut malformed_packet = valid_packet.clone(); + malformed_packet.push(0xff); + + let err = server + .handle_packet(&malformed_packet) + .expect_err("trailing partial record must reject the datagram"); + assert!( + matches!(err, dimpl::Error::ParseIncomplete), + "expected ParseIncomplete, got {err:?}" + ); + + let server_out = drain_outputs(&mut server); + assert!( + server_out.app_data.is_empty(), + "malformed datagram must not deliver application data" + ); + + server + .handle_packet(&valid_packet) + .expect("valid packet must still pass replay checks"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"replay-atomic".to_vec()]); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_same_datagram_duplicate_encrypted_record_delivers_once() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_13_pair(now); + + client + .send_application_data(b"duplicate-once") + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let client_out = drain_outputs(&mut client); + let valid_packet = client_out + .packets + .first() + .expect("application data packet") + .clone(); + + let mut duplicate_datagram = valid_packet.clone(); + duplicate_datagram.extend_from_slice(&valid_packet); + + server + .handle_packet(&duplicate_datagram) + .expect("duplicate datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"duplicate-once".to_vec()]); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_same_datagram_window_shift_drops_now_too_old_record() { + let _ = env_logger::try_init(); + let now = Instant::now(); + let (mut client, mut server, now) = setup_connected_13_pair(now); + + let mut packets = Vec::new(); + for i in 0..66 { + client + .send_application_data(format!("msg-{i}").as_bytes()) + .expect("send application data"); + client.handle_timeout(now).expect("client timeout"); + let out = drain_outputs(&mut client); + let packet = out + .packets + .first() + .expect("application data packet") + .clone(); + packets.push(packet); + } + + let mut shifted_datagram = packets[65].clone(); + shifted_datagram.extend_from_slice(&packets[0]); + + server + .handle_packet(&shifted_datagram) + .expect("window-shift datagram should parse"); + let server_out = drain_outputs(&mut server); + + assert_eq!(server_out.app_data, vec![b"msg-65".to_vec()]); +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_half_close_send_then_close() { diff --git a/tests/dtls13/key_update.rs b/tests/dtls13/key_update.rs index d619f0f3..56ea771d 100644 --- a/tests/dtls13/key_update.rs +++ b/tests/dtls13/key_update.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; -use dimpl::{Config, Dtls}; +use dimpl::{Config, Dtls, Error}; use crate::common::*; @@ -197,6 +197,1114 @@ fn dtls13_key_update_bidirectional_after_limit() { ); } +#[cfg(feature = "rcgen")] +fn assert_key_update_and_app_data_same_datagram( + sender: &mut Dtls, + receiver: &mut Dtls, + now: &mut Instant, + priming: &'static [u8], + post_update: &'static [u8], +) { + sender + .send_application_data(priming) + .expect("send priming app data"); + let first_packets = collect_packets(sender); + assert_eq!(first_packets.len(), 1); + + deliver_packets(&first_packets, receiver); + receiver.handle_timeout(*now).expect("receiver timeout"); + let first_received = drain_outputs(receiver); + assert_eq!(first_received.app_data, vec![priming.to_vec()]); + deliver_packets(&first_received.packets, sender); + + *now += Duration::from_millis(10); + sender.handle_timeout(*now).expect("sender timeout"); + let key_update_packets = collect_packets(sender); + assert_eq!(key_update_packets.len(), 1); + + sender + .send_application_data(post_update) + .expect("send post-key-update app data"); + let app_packets = collect_packets(sender); + assert_eq!(app_packets.len(), 1); + + let mut combined = key_update_packets[0].clone(); + combined.extend_from_slice(&app_packets[0]); + + receiver + .handle_packet(&combined) + .expect("receiver should accept combined datagram"); + receiver.handle_timeout(*now).expect("receiver timeout"); + + let received = drain_outputs(receiver); + assert_eq!( + received.app_data, + vec![post_update.to_vec()], + "receiver should deliver new-epoch app data that follows KeyUpdate in the same datagram" + ); +} + +#[cfg(feature = "rcgen")] +fn capture_key_update_and_app_data_packets( + sender: &mut Dtls, + receiver: &mut Dtls, + now: &mut Instant, + priming: &'static [u8], + post_update: &'static [u8], +) -> (Vec, Vec) { + sender + .send_application_data(priming) + .expect("send priming app data"); + let first_packets = collect_packets(sender); + assert_eq!(first_packets.len(), 1); + + deliver_packets(&first_packets, receiver); + receiver.handle_timeout(*now).expect("receiver timeout"); + let first_received = drain_outputs(receiver); + assert_eq!(first_received.app_data, vec![priming.to_vec()]); + deliver_packets(&first_received.packets, sender); + + *now += Duration::from_millis(10); + sender.handle_timeout(*now).expect("sender timeout"); + let key_update_packets = collect_packets(sender); + assert_eq!(key_update_packets.len(), 1); + + sender + .send_application_data(post_update) + .expect("send post-key-update app data"); + let app_packets = collect_packets(sender); + assert_eq!(app_packets.len(), 1); + + (key_update_packets[0].clone(), app_packets[0].clone()) +} + +#[cfg(feature = "rcgen")] +fn dtls13_ack_record_with_entry(seq: u64, ack_epoch: u64, ack_seq: u64) -> Vec { + let mut out = Vec::new(); + out.push(26); // Ack + out.extend_from_slice(&[0xFE, 0xFD]); // legacy DTLS record version + out.extend_from_slice(&0u16.to_be_bytes()); // epoch 0 plaintext + out.extend_from_slice(&seq.to_be_bytes()[2..]); // u48 sequence number + out.extend_from_slice(&18u16.to_be_bytes()); // record_numbers_len + one entry + out.extend_from_slice(&16u16.to_be_bytes()); // record_numbers_len + out.extend_from_slice(&ack_epoch.to_be_bytes()); + out.extend_from_slice(&ack_seq.to_be_bytes()); + out +} + +/// Test that application data following a KeyUpdate in the same datagram is +/// delivered. The sender emits the KeyUpdate under the old application epoch, +/// rotates send keys, and the datagram then carries application data under the +/// new epoch. The receiver must process the KeyUpdate before trying to decrypt +/// the following record. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_client_key_update_and_new_epoch_app_data_in_same_datagram() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_key_update_and_app_data_same_datagram( + &mut client, + &mut server, + &mut now, + b"client-primes-key-update", + b"client-same-datagram-new-epoch", + ); +} + +/// If the deferred tail after a KeyUpdate is structurally malformed, the whole +/// UDP datagram must be rejected before the KeyUpdate is acted on. This keeps +/// the DIMP-007 datagram-atomic replay/state invariant for malformed tails. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_key_update_with_malformed_same_datagram_tail_is_atomic() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let (key_update_packet, app_packet) = capture_key_update_and_app_data_packets( + &mut client, + &mut server, + &mut now, + b"client-primes-malformed-tail-key-update", + b"client-valid-tail-after-malformed-attempt", + ); + + let mut malformed = key_update_packet.clone(); + malformed.push(0xff); + + let err = server + .handle_packet(&malformed) + .expect_err("malformed tail must reject the full datagram"); + assert!( + matches!(err, dimpl::Error::ParseIncomplete), + "expected ParseIncomplete, got {err:?}" + ); + + let after_malformed = drain_outputs(&mut server); + assert!( + after_malformed.packets.is_empty() && after_malformed.app_data.is_empty(), + "malformed datagram must not ACK, advance, or deliver anything" + ); + + let mut valid = key_update_packet; + valid.extend_from_slice(&app_packet); + + server + .handle_packet(&valid) + .expect("valid retry must still pass replay checks"); + server.handle_timeout(now).expect("server timeout"); + + let received = drain_outputs(&mut server); + assert_eq!( + received.app_data, + vec![b"client-valid-tail-after-malformed-attempt".to_vec()] + ); +} + +/// Deferring a post-KeyUpdate tail must not reset the per-datagram record +/// budget. Otherwise a `KeyUpdate || 16 records` datagram would be accepted as +/// two separately-budgeted parses and the KeyUpdate would take effect before the +/// over-capacity tail is rejected. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_key_update_with_over_capacity_same_datagram_tail_is_atomic() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let (key_update_packet, app_packet) = capture_key_update_and_app_data_packets( + &mut client, + &mut server, + &mut now, + b"client-primes-over-capacity-tail-key-update", + b"client-valid-tail-after-over-capacity-attempt", + ); + + let mut over_capacity = key_update_packet.clone(); + for seq in 0..16 { + over_capacity.extend_from_slice(&dtls13_ack_record_with_entry(0x200 + seq, 3, seq)); + } + + let err = server + .handle_packet(&over_capacity) + .expect_err("over-capacity tail must reject the full datagram"); + assert!( + matches!(err, dimpl::Error::TooManyRecords), + "expected TooManyRecords, got {err:?}" + ); + + let after_over_capacity = drain_outputs(&mut server); + assert!( + after_over_capacity.packets.is_empty() && after_over_capacity.app_data.is_empty(), + "over-capacity datagram must not ACK, advance, or deliver anything" + ); + + let mut valid = key_update_packet; + valid.extend_from_slice(&app_packet); + + server + .handle_packet(&valid) + .expect("valid retry must still pass replay checks"); + server.handle_timeout(now).expect("server timeout"); + + let received = drain_outputs(&mut server); + assert_eq!( + received.app_data, + vec![b"client-valid-tail-after-over-capacity-attempt".to_vec()] + ); +} + +/// Same as the client-sender case, but with the server as the KeyUpdate sender +/// so the client deferred-tail path is covered too. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_server_key_update_and_new_epoch_app_data_in_same_datagram() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_key_update_and_app_data_same_datagram( + &mut server, + &mut client, + &mut now, + b"server-primes-key-update", + b"server-same-datagram-new-epoch", + ); +} + +#[cfg(feature = "rcgen")] +fn trigger_key_update( + sender: &mut Dtls, + receiver: &mut Dtls, + now: &mut Instant, + label: &str, +) -> Vec> { + for i in 0..64 { + sender + .handle_timeout(*now) + .expect("sender checks existing KeyUpdate threshold"); + let pending_key_update = collect_packets(sender); + if !pending_key_update.is_empty() { + return pending_key_update; + } + + let msg = format!("{label}-{i}").into_bytes(); + sender + .send_application_data(&msg) + .expect("sender sends priming app data"); + let data_packets = collect_packets(sender); + assert_eq!( + data_packets.len(), + 1, + "sender should emit exactly one priming app-data packet before KeyUpdate" + ); + + deliver_packets(&data_packets, receiver); + receiver + .handle_timeout(*now) + .expect("receiver handles priming data"); + let receiver_out = drain_outputs(receiver); + assert!( + receiver_out + .app_data + .iter() + .any(|received| received == &msg), + "receiver should deliver priming app data before KeyUpdate" + ); + deliver_packets(&receiver_out.packets, sender); + + *now += Duration::from_millis(10); + sender + .handle_timeout(*now) + .expect("sender checks KeyUpdate threshold"); + let key_update = collect_packets(sender); + if !key_update.is_empty() { + return key_update; + } + } + + panic!("{label}: sender did not emit KeyUpdate within bounded attempts"); +} + +#[cfg(feature = "rcgen")] +fn dtls13_ciphertext_epoch_bits(packets: &[Vec]) -> Vec { + let mut epochs = Vec::new(); + + for packet in packets { + let mut rest = packet.as_slice(); + while !rest.is_empty() { + assert_eq!( + rest[0] & 0b1110_0000, + 0b0010_0000, + "expected DTLS 1.3 ciphertext unified header" + ); + assert_eq!( + rest[0] & 0b0001_0000, + 0, + "CID-bearing DTLS 1.3 records are not expected in these tests" + ); + + let s_flag = rest[0] & 0b0000_1000 != 0; + let l_flag = rest[0] & 0b0000_0100 != 0; + assert!(l_flag, "test records should carry explicit lengths"); + + let seq_len = if s_flag { 2 } else { 1 }; + let header_len = 1 + seq_len + 2; + assert!( + rest.len() >= header_len, + "truncated DTLS 1.3 ciphertext header" + ); + + let len_offset = 1 + seq_len; + let body_len = u16::from_be_bytes([rest[len_offset], rest[len_offset + 1]]) as usize; + let record_len = header_len + body_len; + assert!( + rest.len() >= record_len, + "truncated DTLS 1.3 ciphertext record" + ); + + epochs.push(rest[0] & 0x03); + rest = &rest[record_len..]; + } + } + + epochs +} + +#[cfg(feature = "rcgen")] +fn assert_peer_requested_response_waits_for_local_update_ack( + local: &mut Dtls, + peer: &mut Dtls, + now: &mut Instant, + local_label: &str, + peer_label: &str, + post_overlap: &[u8], +) { + let local_key_update = trigger_key_update(local, peer, now, local_label); + let peer_key_update = trigger_key_update(peer, local, now, peer_label); + + deliver_packets(&peer_key_update, local); + let local_ack_only = collect_packets(local); + assert!( + !local_ack_only.is_empty(), + "local endpoint should ACK peer KeyUpdate without sending the pending response" + ); + assert_eq!( + dtls13_ciphertext_epoch_bits(&local_ack_only), + vec![3], + concat!( + "local ACK must use the retained previous app epoch ", + "so peer can decrypt it before receiving local KeyUpdate" + ) + ); + + deliver_packets(&local_ack_only, peer); + let peer_after_local_ack = collect_packets(peer); + assert!( + peer_after_local_ack.is_empty(), + "local endpoint must not send its pending KeyUpdate response before its local update is ACKed" + ); + + let mut peer_after_local_ack_timeout = Vec::new(); + for _ in 0..20 { + *now += Duration::from_millis(500); + peer.handle_timeout(*now) + .expect("peer checks whether its KeyUpdate is still in flight"); + peer_after_local_ack_timeout.extend(collect_packets(peer)); + } + assert!( + peer_after_local_ack_timeout.is_empty(), + "peer should not retransmit its KeyUpdate after local endpoint ACKs it" + ); + + deliver_packets(&local_key_update, peer); + let peer_ack_for_local_update = collect_packets(peer); + assert!( + !peer_ack_for_local_update.is_empty(), + "peer should ACK local KeyUpdate" + ); + + deliver_packets(&peer_ack_for_local_update, local); + let local_response = collect_packets(local); + assert!( + !local_response.is_empty(), + "pending peer-requested ACK and response should drain once the local update is ACKed" + ); + + deliver_packets(&local_response, peer); + let peer_ack_for_local_response = collect_packets(peer); + assert!( + !peer_ack_for_local_response.is_empty(), + "peer should consume and ACK the delayed KeyUpdate response" + ); + deliver_packets(&peer_ack_for_local_response, local); + let local_after_response_ack = collect_packets(local); + assert!( + local_after_response_ack.is_empty(), + "delayed response must be update_not_requested and must not trigger a further response" + ); + + local + .send_application_data(post_overlap) + .expect("local sends post-overlap data"); + let local_post_overlap = collect_packets(local); + assert_eq!( + local_post_overlap.len(), + 1, + concat!( + "delayed KeyUpdate response must clear the peer-requested response ", + "before auto KeyUpdate can send another UpdateRequested" + ) + ); + deliver_packets(&local_post_overlap, peer); + peer.handle_timeout(*now) + .expect("peer handles post-overlap data"); + let peer_post_overlap = drain_outputs(peer); + assert!( + peer_post_overlap + .app_data + .iter() + .any(|received| received == post_overlap), + "peer should receive post-overlap app data" + ); +} + +#[cfg(feature = "rcgen")] +fn assert_peer_key_update_ack_retries_after_transmit_backpressure( + sender: &mut Dtls, + receiver: &mut Dtls, + now: &mut Instant, + label: &str, + receiver_occupant: &[u8], + sender_after_retry: &[u8], +) { + let key_update = trigger_key_update(sender, receiver, now, label); + assert_eq!(key_update.len(), 1, "test expects one KeyUpdate datagram"); + + receiver + .send_application_data(receiver_occupant) + .expect("receiver queues one app-data packet to occupy transmit queue"); + + let err = receiver + .handle_packet(&key_update[0]) + .expect_err("receiver KeyUpdate ACK should hit transmit backpressure"); + assert!( + matches!(err, Error::TransmitQueueFull), + "expected TransmitQueueFull, got {err:?}" + ); + + let receiver_after_backpressure = drain_outputs(receiver); + let occupied = receiver_after_backpressure.packets; + assert_eq!( + occupied.len(), + 1, + "only the pre-existing app-data packet should be queued after failed ACK" + ); + let retry_at = receiver_after_backpressure + .timeout + .expect("receiver should surface immediate retry timeout"); + assert!( + retry_at <= *now, + "pending KeyUpdate ACK retry should be immediately scheduled" + ); + deliver_packets(&occupied, sender); + sender + .handle_timeout(*now) + .expect("sender handles receiver occupant"); + let sender_after_occupant = drain_outputs(sender); + assert!( + sender_after_occupant + .app_data + .iter() + .any(|received| received == receiver_occupant), + "sender should receive the packet that occupied receiver output" + ); + + receiver + .handle_timeout(retry_at) + .expect("receiver retries pending KeyUpdate ACK from poll timeout"); + let retried_ack_and_response = collect_packets(receiver); + assert!( + !retried_ack_and_response.is_empty(), + "receiver should retry the ACK and peer-requested response" + ); + *now += Duration::from_secs(2); + sender + .handle_timeout(*now) + .expect("sender retransmits unacked KeyUpdate"); + let retransmitted_key_update = collect_packets(sender); + assert!( + !retransmitted_key_update.is_empty(), + "sender should retransmit KeyUpdate before receiving ACK" + ); + + let duplicate_occupant = b"receiver-duplicate-keyupdate-occupant"; + receiver + .send_application_data(duplicate_occupant) + .expect("receiver queues app data before duplicate KeyUpdate"); + let err = receiver + .handle_packet(&retransmitted_key_update[0]) + .expect_err("duplicate KeyUpdate ACK should hit transmit backpressure"); + assert!( + matches!(err, Error::TransmitQueueFull), + "expected duplicate KeyUpdate ACK TransmitQueueFull, got {err:?}" + ); + + let receiver_after_duplicate = drain_outputs(receiver); + assert_eq!( + receiver_after_duplicate.packets.len(), + 1, + "duplicate path should only leave the pre-existing app-data packet queued" + ); + let duplicate_retry_at = receiver_after_duplicate + .timeout + .expect("duplicate KeyUpdate ACK should surface retry timeout"); + assert!( + duplicate_retry_at <= *now, + "duplicate KeyUpdate ACK retry should be immediately scheduled" + ); + receiver + .handle_timeout(duplicate_retry_at) + .expect("receiver retries duplicate KeyUpdate ACK from poll timeout"); + let duplicate_ack = collect_packets(receiver); + assert!( + !duplicate_ack.is_empty(), + "receiver should retry duplicate KeyUpdate ACK after queue drains" + ); + deliver_packets(&duplicate_ack, sender); + sender + .handle_timeout(*now) + .expect("sender handles duplicate KeyUpdate ACK retry"); + let sender_after_duplicate_ack = drain_outputs(sender); + deliver_packets(&sender_after_duplicate_ack.packets, receiver); + + deliver_packets(&retried_ack_and_response, sender); + + sender + .handle_timeout(*now) + .expect("sender handles retried ACK and response"); + let sender_ack_for_response = collect_packets(sender); + deliver_packets(&sender_ack_for_response, receiver); + + receiver + .handle_packet(&key_update[0]) + .expect("duplicate KeyUpdate must not derive receive keys a second time"); + receiver + .handle_timeout(*now) + .expect("receiver handles duplicate after retry"); + let duplicate_output = collect_packets(receiver); + deliver_packets(&duplicate_output, sender); + + sender + .send_application_data(sender_after_retry) + .expect("sender sends after duplicate KeyUpdate"); + let post_retry = collect_packets(sender); + assert_eq!(post_retry.len(), 1); + deliver_packets(&post_retry, receiver); + receiver + .handle_timeout(*now) + .expect("receiver handles sender data after duplicate"); + let received = drain_outputs(receiver); + assert!( + received + .app_data + .iter() + .any(|data| data == sender_after_retry), + "receiver should decrypt later sender data after exactly one receive-key update" + ); +} + +#[cfg(feature = "rcgen")] +fn assert_deferred_tail_drains_after_ack_backpressure( + sender: &mut Dtls, + receiver: &mut Dtls, + now: &mut Instant, + priming: &'static [u8], + post_update: &'static [u8], + receiver_occupant: &[u8], +) { + let (key_update_packet, app_packet) = + capture_key_update_and_app_data_packets(sender, receiver, now, priming, post_update); + + receiver + .send_application_data(receiver_occupant) + .expect("receiver queues one app-data packet to occupy transmit queue"); + + let mut combined = key_update_packet; + combined.extend_from_slice(&app_packet); + let err = receiver + .handle_packet(&combined) + .expect_err("receiver KeyUpdate ACK should hit transmit backpressure"); + assert!( + matches!(err, Error::TransmitQueueFull), + "expected TransmitQueueFull, got {err:?}" + ); + + let receiver_after_backpressure = drain_outputs(receiver); + let occupied = receiver_after_backpressure.packets; + assert_eq!( + occupied.len(), + 1, + "only the pre-existing app-data packet should be queued after failed ACK" + ); + let retry_at = receiver_after_backpressure + .timeout + .expect("receiver should surface immediate retry timeout"); + assert!( + retry_at <= *now, + "pending deferred-tail retry should be immediately scheduled" + ); + deliver_packets(&occupied, sender); + sender + .handle_timeout(*now) + .expect("sender handles receiver occupant"); + let sender_after_occupant = drain_outputs(sender); + assert!( + sender_after_occupant + .app_data + .iter() + .any(|received| received == receiver_occupant), + "sender should receive the packet that occupied receiver output" + ); + + receiver + .handle_timeout(retry_at) + .expect("receiver retries pending ACK and drains deferred tail from poll timeout"); + let receiver_after_retry = drain_outputs(receiver); + assert!( + !receiver_after_retry.packets.is_empty(), + "receiver should retry the pending KeyUpdate ACK" + ); + assert_eq!( + receiver_after_retry.app_data, + vec![post_update.to_vec()], + "receiver should deliver deferred same-datagram app data after ACK retry" + ); +} + +/// A peer-requested KeyUpdate response should be emitted in the same progress +/// pass when no local KeyUpdate is in flight. Otherwise it can sit pending +/// until unrelated input or a timeout drives the state machine again. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_peer_requested_key_update_response_drains_immediately_without_local_update() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(2) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let server_key_update = + trigger_key_update(&mut server, &mut client, &mut now, "server-immediate"); + + deliver_packets(&server_key_update, &mut client); + let client_response = collect_packets(&mut client); + assert!( + !client_response.is_empty(), + "client should emit KeyUpdate ACK and response without waiting for another tick" + ); + + deliver_packets(&client_response, &mut server); + let server_ack_for_client_response = collect_packets(&mut server); + assert!( + !server_ack_for_client_response.is_empty(), + "server should consume and ACK the immediate KeyUpdate response" + ); + deliver_packets(&server_ack_for_client_response, &mut client); + let client_after_response_ack = collect_packets(&mut client); + assert!( + client_after_response_ack.is_empty(), + "immediate response must be update_not_requested and must not trigger a further response" + ); + + client + .send_application_data(b"client-after-immediate-response") + .expect("client sends post-response data"); + deliver_packets(&collect_packets(&mut client), &mut server); + server + .handle_timeout(now) + .expect("server handles post-response data"); + let server_after_response = drain_outputs(&mut server); + assert_eq!( + server_after_response.app_data, + vec![b"client-after-immediate-response".to_vec()] + ); +} + +/// If the client's transmit queue is full while ACKing a peer KeyUpdate, the +/// client must not reprocess the same KeyUpdate on retry. Receive keys and peer +/// handshake sequence advancement stay transactional with ACK retry state. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_client_peer_key_update_ack_backpressure_retries_without_rederiving() { + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .max_queue_tx(1) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(2) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_peer_key_update_ack_retries_after_transmit_backpressure( + &mut server, + &mut client, + &mut now, + "server-to-client-backpressure", + &[b'c'; 1150], + b"server-after-client-ack-retry", + ); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_client_deferred_tail_drains_after_ack_backpressure_retry() { + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .max_queue_tx(1) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_deferred_tail_drains_after_ack_backpressure( + &mut server, + &mut client, + &mut now, + b"server-primes-client-backpressure-tail", + b"server-tail-after-client-ack-retry", + &[b'c'; 1150], + ); +} + +/// A peer-requested KeyUpdate response must not replace an in-flight locally +/// initiated KeyUpdate. The response is delayed until the local KeyUpdate is +/// ACKed, then sent normally. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_peer_requested_key_update_response_waits_for_local_update_ack() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(2) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_peer_requested_response_waits_for_local_update_ack( + &mut client, + &mut server, + &mut now, + "client-local", + "server-peer", + b"client-post-overlap", + ); +} + +/// Same as the client immediate-response case, but with the server as the +/// responder so the server-side pending response drain is covered too. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_server_peer_requested_key_update_response_drains_immediately_without_local_update() { + use dimpl::Config; + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let client_key_update = trigger_key_update(&mut client, &mut server, &mut now, "client-peer"); + + deliver_packets(&client_key_update, &mut server); + let server_response = collect_packets(&mut server); + assert!( + !server_response.is_empty(), + "server should emit KeyUpdate ACK and response without waiting for another tick" + ); + + deliver_packets(&server_response, &mut client); + let client_ack_for_server_response = collect_packets(&mut client); + assert!( + !client_ack_for_server_response.is_empty(), + "client should consume and ACK the immediate KeyUpdate response" + ); + deliver_packets(&client_ack_for_server_response, &mut server); + let server_after_response_ack = collect_packets(&mut server); + assert!( + server_after_response_ack.is_empty(), + "immediate response must be update_not_requested and must not trigger a further response" + ); + + server + .send_application_data(b"server-after-immediate-response") + .expect("server sends post-response data"); + deliver_packets(&collect_packets(&mut server), &mut client); + client + .handle_timeout(now) + .expect("client handles post-response data"); + let client_after_response = drain_outputs(&mut client); + assert!( + client_after_response + .app_data + .iter() + .any(|received| received == b"server-after-immediate-response"), + "client should receive post-response app data" + ); +} + +/// Same as the client backpressure case, but with the server receiving the +/// peer KeyUpdate while its transmit queue is full. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_server_peer_key_update_ack_backpressure_retries_without_rederiving() { + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(2) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .max_queue_tx(1) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_peer_key_update_ack_retries_after_transmit_backpressure( + &mut client, + &mut server, + &mut now, + "client-to-server-backpressure", + &[b's'; 1150], + b"client-after-server-ack-retry", + ); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_server_deferred_tail_drains_after_ack_backpressure_retry() { + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(1) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .max_queue_tx(1) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + assert_deferred_tail_drains_after_ack_backpressure( + &mut client, + &mut server, + &mut now, + b"client-primes-server-backpressure-tail", + b"client-tail-after-server-ack-retry", + &[b's'; 1150], + ); +} + /// Test that a reordered packet captured before a KeyUpdate is accepted when /// delivered alongside other packets during the transition. The packet is from /// the same epoch and arrives before any new-epoch records, so the replay @@ -555,6 +1663,122 @@ fn dtls13_key_update_with_packet_loss() { ); } +/// Test that a sender can recover when the KeyUpdate reaches the peer but the +/// peer's ACK is lost. The peer must treat the retransmitted KeyUpdate as a +/// duplicate that triggers a fresh ACK; otherwise the sender remains stuck with +/// previous send keys retained forever. +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_key_update_lost_ack_retransmission_gets_acknowledged() { + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let client_config = Arc::new( + Config::builder() + .aead_encryption_limit(2) + .build() + .expect("build client config"), + ); + let server_config = Arc::new( + Config::builder() + .aead_encryption_limit(16) + .build() + .expect("build server config"), + ); + + let mut now = Instant::now(); + + let mut client = Dtls::new_13(client_config, client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(server_config, server_cert, now); + server.set_active(false); + + now = complete_dtls13_handshake(&mut client, &mut server, now); + + let client_key_update = trigger_key_update(&mut client, &mut server, &mut now, "client"); + + deliver_packets(&client_key_update, &mut server); + let lost_server_ack = collect_packets(&mut server); + assert!(!lost_server_ack.is_empty(), "server should ACK KeyUpdate"); + + trigger_timeout(&mut client, &mut now); + client + .handle_timeout(now) + .expect("arm KeyUpdate flight timer"); + trigger_timeout(&mut client, &mut now); + let retransmitted_key_update = collect_packets(&mut client); + assert!( + !retransmitted_key_update.is_empty(), + "client should retransmit unacked KeyUpdate" + ); + + deliver_packets(&retransmitted_key_update, &mut server); + let server_ack_for_retransmit = collect_packets(&mut server); + assert!( + !server_ack_for_retransmit.is_empty(), + "server should ACK duplicate retransmitted KeyUpdate after the original ACK was lost" + ); + assert_eq!( + dtls13_ciphertext_epoch_bits(&server_ack_for_retransmit), + vec![3, 0], + concat!( + "server should retransmit the old-epoch KeyUpdate response ", + "before appending the fresh current-epoch ACK" + ) + ); + + deliver_packets(&server_ack_for_retransmit, &mut client); + client + .handle_timeout(now) + .expect("client handles ACK for retransmitted KeyUpdate"); + let client_ack_for_server_response = collect_packets(&mut client); + assert!( + !client_ack_for_server_response.is_empty(), + "client should ACK the server's retransmitted KeyUpdate response" + ); + + let mut client_after_retransmit_ack = Vec::new(); + for _ in 0..20 { + now += Duration::from_millis(500); + client + .handle_timeout(now) + .expect("client checks whether KeyUpdate is still in flight"); + client_after_retransmit_ack.extend(collect_packets(&mut client)); + } + assert!( + client_after_retransmit_ack.is_empty(), + "client should not retransmit KeyUpdate after ACKing the duplicate retransmission" + ); + + deliver_packets(&client_ack_for_server_response, &mut server); + server + .handle_timeout(now) + .expect("server handles ACK for retransmitted response"); + let server_after_ack = collect_packets(&mut server); + assert!( + server_after_ack.is_empty(), + "server should clear its retransmitted KeyUpdate response after ACK" + ); + + client + .send_application_data(b"post-lost-ack") + .expect("client sends post-recovery data"); + deliver_packets(&collect_packets(&mut client), &mut server); + server + .handle_timeout(now) + .expect("server handles post-recovery data"); + let server_after_recovery = drain_outputs(&mut server); + assert_eq!( + server_after_recovery.app_data, + vec![b"post-lost-ack".to_vec()] + ); +} + /// Test that high-frequency KeyUpdates work correctly. With the minimum /// AEAD limit of 2, nearly every message triggers a KeyUpdate. This stress- /// tests the key rotation machinery and epoch tracking under extreme churn,