From 5bc07df0cba1f351b58a9e764d8c5a2ffe945a89 Mon Sep 17 00:00:00 2001 From: Ronen Ulanovsky Date: Sat, 30 May 2026 18:00:38 +0300 Subject: [PATCH] dtls: filter pre-clienthello queue poison --- CHANGELOG.md | 1 + src/dtls12/engine.rs | 44 ++++++++++++++----- src/dtls12/incoming.rs | 38 +++++++++++++++- src/dtls12/server.rs | 24 +++++++--- src/dtls13/engine.rs | 60 +++++++++++++++++++++---- src/dtls13/incoming.rs | 46 ++++++++++++++++++- src/dtls13/server.rs | 65 ++++++++++++++++++++++----- tests/auto/server_fallback.rs | 41 +++++++++++++++++ tests/dtls12/edge.rs | 83 +++++++++++++++++++++++++++++++++++ tests/dtls13/edge.rs | 80 +++++++++++++++++++++++++++++++++ 10 files changed, 446 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e3bdafd..a51b0da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased + * Filter pre-ClientHello queue poison records #140 * Represent DTLS wire-code identifiers as compact newtypes (breaking) #137 * Make public errors structured and fatal-only (breaking) #134 diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index 3bd86e9..9e1ce64 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -229,22 +229,34 @@ impl Engine { Ok(()) } - /// Insert a parsed datagram into the receive queue. - fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> { - // Capacity guard before iterating records. - if self.queue_rx.len() >= self.config.max_queue_rx() { - warn!( - "Receive queue full (max {}): {:?}", - self.config.max_queue_rx(), - self.queue_rx - ); - return Err(Error::ReceiveQueueFull); + pub(crate) fn parse_packet_filtering_records( + &mut self, + packet: &[u8], + keep_record: impl FnMut(&Record) -> bool, + ) -> Result<(), InternalError> { + let cs = self.cipher_suite; + let incoming = Incoming::parse_packet_filtering_records(packet, self, cs, keep_record)?; + if let Some(incoming) = incoming { + self.insert_incoming(incoming)?; } + Ok(()) + } + + /// Insert a parsed datagram into the receive queue. + fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> { // Dispatch to specialized handlers if incoming.first().first_handshake().is_some() { self.insert_incoming_handshake(incoming) } else { + if self.queue_rx.len() >= self.config.max_queue_rx() { + warn!( + "Receive queue full (max {}): {:?}", + self.config.max_queue_rx(), + self.queue_rx + ); + return Err(Error::ReceiveQueueFull); + } self.insert_incoming_non_handshake(incoming) } } @@ -307,6 +319,14 @@ impl Engine { match search_result { Err(index) => { + if self.queue_rx.len() >= self.config.max_queue_rx() { + warn!( + "Receive queue full (max {}): {:?}", + self.config.max_queue_rx(), + self.queue_rx + ); + return Err(Error::ReceiveQueueFull); + } // Insert in order of handshake key self.queue_rx.insert(index, incoming); } @@ -590,6 +610,10 @@ impl Engine { self.has_complete_handshake_with_seq(wanted, self.peer_handshake_seq_no) } + pub(crate) fn expected_peer_handshake_seq_no(&self) -> u16 { + self.peer_handshake_seq_no + } + fn has_complete_handshake_with_seq(&mut self, wanted: MessageType, expected_seq: u16) -> bool { let mut skip_handled = self .queue_rx diff --git a/src/dtls12/incoming.rs b/src/dtls12/incoming.rs index ca21b6c..9ad798a 100644 --- a/src/dtls12/incoming.rs +++ b/src/dtls12/incoming.rs @@ -58,6 +58,23 @@ impl Incoming { Ok(Some(Incoming { records })) } + + pub(crate) fn parse_packet_filtering_records( + packet: &[u8], + decrypt: &mut dyn RecordHandler, + cs: Option, + keep_record: impl FnMut(&Record) -> bool, + ) -> Result, InternalError> { + let records = Records::parse_filtering_records(packet, decrypt, cs, keep_record)?; + + if records.records.is_empty() { + return Ok(None); + } + + let records = Box::new(records); + + Ok(Some(Incoming { records })) + } } /// A number of records parsed from a single UDP packet. @@ -68,11 +85,21 @@ pub struct Records { impl Records { pub fn parse( + packet: &[u8], + decrypt: &mut dyn RecordHandler, + cs: Option, + ) -> Result { + Self::parse_filtering_records(packet, decrypt, cs, |_| true) + } + + fn parse_filtering_records( mut packet: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, + mut keep_record: impl FnMut(&Record) -> bool, ) -> Result { let mut parsed_records: ArrayVec = ArrayVec::new(); + let mut parsed_record_count = 0usize; // Find record boundaries and copy each record ONCE from the packet while !packet.is_empty() { @@ -93,9 +120,18 @@ impl Records { match Record::parse(record_slice, decrypt, cs) { Ok(record) => { if let Some(record) = record { - if parsed_records.try_push(record).is_err() { + if parsed_record_count >= parsed_records.capacity() { return Err(InternalError::too_many_records()); } + parsed_record_count += 1; + + if !keep_record(&record) { + trace!("Discarding filtered rec"); + } else { + parsed_records + .try_push(record) + .expect("parsed record count is capacity-checked"); + } } else { trace!("Discarding replayed rec"); } diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index 48e22f1..8e83840 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -189,11 +189,25 @@ impl Server { } pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> { - match self - .engine - .parse_packet(packet) - .and_then(|_| self.make_progress()) - { + let result = if self.state == State::AwaitClientHello { + let expected_seq = self.engine.expected_peer_handshake_seq_no(); + self.engine + .parse_packet_filtering_records(packet, |record| { + record.record().sequence.epoch == 0 + && (record.record().content_type == ContentType::Alert + || (record.record().content_type == ContentType::Handshake + && record.first_handshake().is_some_and(|h| { + h.header.msg_type == MessageType::ClientHello + && (h.header.message_seq == expected_seq + || h.header.message_seq.saturating_add(1) + == expected_seq) + }))) + }) + } else { + self.engine.parse_packet(packet) + }; + + match result.and_then(|_| self.make_progress()) { Ok(()) => Ok(()), Err(e) => e.into_public_error().map_or(Ok(()), Err), } diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index d67ccc3..53001c4 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -55,6 +55,9 @@ pub struct Engine { /// Queue of incoming packets. queue_rx: QueueRx, + /// Last accepted pre-ClientHello packet after parser-side filtering. + last_parse_queueable_packet: Option, + /// Queue of outgoing packets. queue_tx: QueueTx, @@ -220,6 +223,7 @@ impl Engine { buffers_free: BufferPool::default(), sequence_epoch_0: Sequence::new(0), queue_rx: QueueRx::new(), + last_parse_queueable_packet: None, queue_tx: QueueTx::new(), cipher_suite: None, hs_send_keys: None, @@ -332,6 +336,7 @@ impl Engine { pub fn parse_packet(&mut self, packet: &[u8]) -> Result<(), InternalError> { let cs = self.cipher_suite; let incoming = Incoming::parse_packet(packet, self, cs)?; + self.last_parse_queueable_packet = None; if let Some(incoming) = incoming { self.insert_incoming(incoming)?; } @@ -339,19 +344,46 @@ impl Engine { Ok(()) } - fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> { - if self.queue_rx.len() >= self.config.max_queue_rx() { - warn!( - "Receive queue full (max {}): {:?}", - self.config.max_queue_rx(), - self.queue_rx - ); - return Err(Error::ReceiveQueueFull); + pub(crate) fn parse_packet_filtering_records( + &mut self, + packet: &[u8], + retain_queueable_packet: bool, + keep_record: impl FnMut(&Record) -> bool, + ) -> Result<(), InternalError> { + let cs = self.cipher_suite; + let incoming = Incoming::parse_packet_filtering_records(packet, self, cs, keep_record)?; + self.last_parse_queueable_packet = None; + if let Some(incoming) = incoming { + let queueable_packet = (retain_queueable_packet + && incoming.first().first_handshake().is_some()) + .then(|| incoming.to_datagram_buf()); + self.insert_incoming(incoming)?; + self.last_parse_queueable_packet = queueable_packet; } + Ok(()) + } + + pub(crate) fn take_last_parse_queueable_packet(&mut self) -> Option { + self.last_parse_queueable_packet.take() + } + + pub(crate) fn clear_last_parse_queueable_packet(&mut self) { + self.last_parse_queueable_packet = None; + } + + fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> { if incoming.first().first_handshake().is_some() { self.insert_incoming_handshake(incoming) } else { + if self.queue_rx.len() >= self.config.max_queue_rx() { + warn!( + "Receive queue full (max {}): {:?}", + self.config.max_queue_rx(), + self.queue_rx + ); + return Err(Error::ReceiveQueueFull); + } self.insert_incoming_non_handshake(incoming) } } @@ -406,6 +438,14 @@ impl Engine { match search_result { Err(index) => { + if self.queue_rx.len() >= self.config.max_queue_rx() { + warn!( + "Receive queue full (max {}): {:?}", + self.config.max_queue_rx(), + self.queue_rx + ); + return Err(Error::ReceiveQueueFull); + } // Track received record numbers for ACK generation for record in incoming.records().iter() { let seq = record.record().sequence; @@ -726,6 +766,10 @@ impl Engine { self.has_complete_handshake_with_seq(wanted, self.peer_handshake_seq_no) } + pub(crate) fn expected_peer_handshake_seq_no(&self) -> u16 { + self.peer_handshake_seq_no + } + fn has_complete_handshake_with_seq(&mut self, wanted: MessageType, expected_seq: u16) -> bool { let mut skip_handled = self .queue_rx diff --git a/src/dtls13/incoming.rs b/src/dtls13/incoming.rs index f617a41..a4c4f67 100644 --- a/src/dtls13/incoming.rs +++ b/src/dtls13/incoming.rs @@ -29,6 +29,14 @@ impl Incoming { pub fn into_records(self) -> impl Iterator { self.records.records.into_iter() } + + pub(crate) fn to_datagram_buf(&self) -> Buf { + let mut packet = Buf::new(); + for record in &self.records.records { + packet.extend_from_slice(record.buffer()); + } + packet + } } impl Incoming { @@ -57,6 +65,23 @@ impl Incoming { Ok(Some(Incoming { records })) } + + pub(crate) fn parse_packet_filtering_records( + packet: &[u8], + decrypt: &mut dyn RecordHandler, + cs: Option, + keep_record: impl FnMut(&Record) -> bool, + ) -> Result, InternalError> { + let records = Records::parse_filtering_records(packet, decrypt, cs, keep_record)?; + + if records.records.is_empty() { + return Ok(None); + } + + let records = Box::new(records); + + Ok(Some(Incoming { records })) + } } /// A number of records parsed from a single UDP packet. @@ -67,11 +92,21 @@ pub struct Records { impl Records { pub fn parse( + packet: &[u8], + decrypt: &mut dyn RecordHandler, + cs: Option, + ) -> Result { + Self::parse_filtering_records(packet, decrypt, cs, |_| true) + } + + fn parse_filtering_records( mut packet: &[u8], decrypt: &mut dyn RecordHandler, cs: Option, + mut keep_record: impl FnMut(&Record) -> bool, ) -> Result { let mut parsed_records: ArrayVec = ArrayVec::new(); + let mut parsed_record_count = 0usize; // Find record boundaries and copy each record ONCE from the packet while !packet.is_empty() { @@ -132,9 +167,18 @@ impl Records { match Record::parse(record_slice, decrypt, cs) { Ok(record) => { if let Some(record) = record { - if parsed_records.try_push(record).is_err() { + if parsed_record_count >= parsed_records.capacity() { return Err(InternalError::too_many_records()); } + parsed_record_count += 1; + + if !keep_record(&record) { + trace!("Discarding filtered rec"); + } else { + parsed_records + .try_push(record) + .expect("parsed record count is capacity-checked"); + } } else { trace!("Discarding replayed rec"); } diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 4711c8d..472d9c5 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -47,6 +47,7 @@ use crate::dtls13::message::CompressionMethod; use crate::dtls13::message::ContentType; use crate::dtls13::message::DistinguishedName; use crate::dtls13::message::Dtls13CipherSuite; +use crate::dtls13::message::Dtls13Record; use crate::dtls13::message::Extension; use crate::dtls13::message::ExtensionType; use crate::dtls13::message::KeyShareClientHello; @@ -238,21 +239,63 @@ impl Server { } pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> { - // In auto-sense mode, buffer raw packets while still waiting for - // the ClientHello so they can be replayed to Server12 on fallback. - if self.auto_mode && self.state == State::AwaitClientHello { - // Cap buffered fragments to prevent unbounded growth from malicious traffic - if self.retained_hello.len() >= MAX_RETAINED_CLIENT_HELLO { + let awaiting_client_hello = self.state == State::AwaitClientHello; + let mut queueable_packet = None; + + if awaiting_client_hello { + if self.auto_mode && self.retained_hello.len() >= MAX_RETAINED_CLIENT_HELLO { return Err(Error::TooManyClientHelloFragments); } - self.retained_hello.push_back(packet.to_buf()); + + let expected_seq = self.engine.expected_peer_handshake_seq_no(); + match self + .engine + .parse_packet_filtering_records(packet, self.auto_mode, |record| { + !Dtls13Record::is_ciphertext_header(record.buffer()[0]) + && (record.record().content_type == ContentType::Alert + || (record.record().content_type == ContentType::Handshake + && record.first_handshake().is_some_and(|h| { + h.header.msg_type == MessageType::ClientHello + && (h.header.message_seq == expected_seq + || h.header.message_seq.saturating_add(1) + == expected_seq) + }))) + }) { + Ok(()) => { + if self.auto_mode { + queueable_packet = self.engine.take_last_parse_queueable_packet(); + } else { + self.engine.clear_last_parse_queueable_packet(); + } + } + Err(e) => { + self.engine.clear_last_parse_queueable_packet(); + if let Some(err) = e.into_public_error() { + return Err(err); + } + return Ok(()); + } + } + } else { + match self.engine.parse_packet(packet) { + Ok(()) => {} + Err(e) => { + self.engine.clear_last_parse_queueable_packet(); + if let Some(err) = e.into_public_error() { + return Err(err); + } + return Ok(()); + } + } } - match self - .engine - .parse_packet(packet) - .and_then(|_| self.make_progress()) - { + if self.auto_mode && awaiting_client_hello { + if let Some(packet) = queueable_packet { + self.retained_hello.push_back(packet); + } + } + + match self.make_progress() { Ok(()) => {} Err(e) => { if let Some(err) = e.into_public_error() { diff --git a/tests/auto/server_fallback.rs b/tests/auto/server_fallback.rs index b694b24..ad2347f 100644 --- a/tests/auto/server_fallback.rs +++ b/tests/auto/server_fallback.rs @@ -53,6 +53,14 @@ fn run_handshake( ) } +fn dtls13_future_epoch_ciphertext(seq: u16) -> Vec { + let mut out = Vec::new(); + out.push(0x2E); // fixed bits, S=1, L=1, epoch_bits=2 + out.extend_from_slice(&seq.to_be_bytes()); + out.extend_from_slice(&0u16.to_be_bytes()); // empty ciphertext + out +} + // ============================================================================ // Auto server + explicit DTLS 1.3 client → DTLS 1.3 (no fallback) // ============================================================================ @@ -136,6 +144,39 @@ fn auto_server_protocol_version_pending() { assert_eq!(sv, Some(ProtocolVersion::DTLS1_2)); } +#[test] +#[cfg(feature = "rcgen")] +fn auto_server_fallback_ignores_prehandshake_dtls13_ciphertext_poison() { + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().unwrap(); + let server_cert = generate_self_signed_certificate().unwrap(); + let config = Arc::new( + Config::builder() + .max_queue_rx(1) + .build() + .expect("build config"), + ); + + let mut client = Dtls::new_12(Arc::clone(&config), client_cert, Instant::now()); + client.set_active(true); + + let mut server = Dtls::new_auto(config, server_cert, Instant::now()); + + server + .handle_packet(&dtls13_future_epoch_ciphertext(0)) + .expect("pre-ClientHello DTLS 1.3 ciphertext should not poison fallback"); + + let (cc, sc, cv, sv) = run_handshake(&mut client, &mut server); + + assert!(cc, "Client should connect"); + assert!(sc, "Server should connect"); + assert_eq!(cv, Some(ProtocolVersion::DTLS1_2)); + assert_eq!(sv, Some(ProtocolVersion::DTLS1_2)); +} + #[test] #[cfg(feature = "rcgen")] fn auto_server_with_dtls12_client() { diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 6fd95d3..5e68dd1 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -62,6 +62,89 @@ fn dtls12_min_protected_fragment_len(suite: Dtls12CipherSuite) -> usize { } } +fn dtls12_change_cipher_spec_record(seq: u64) -> Vec { + let mut out = Vec::new(); + out.push(20); // ChangeCipherSpec + out.extend_from_slice(&[0xFE, 0xFD]); // DTLS 1.2 + out.extend_from_slice(&0u16.to_be_bytes()); // epoch 0 + out.extend_from_slice(&seq.to_be_bytes()[2..]); // u48 sequence number + out.extend_from_slice(&1u16.to_be_bytes()); // payload length + out.push(1); // change_cipher_spec + out +} + +#[cfg(feature = "rcgen")] +fn dtls12_client_server_with_max_queue(max_queue_rx: usize) -> (Dtls, Dtls, Instant) { + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + let config = Arc::new( + Config::builder() + .max_queue_rx(max_queue_rx) + .build() + .expect("build config"), + ); + + let now = Instant::now(); + let mut client = Dtls::new_12(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_12(config, server_cert, now); + server.set_active(false); + + (client, server, now) +} + +#[cfg(feature = "rcgen")] +fn dtls12_client_hello(client: &mut Dtls, now: Instant) -> Vec> { + client.handle_timeout(now).expect("client timeout"); + let packets = collect_packets(client); + assert!(!packets.is_empty(), "client should emit ClientHello"); + packets +} + +#[cfg(feature = "rcgen")] +fn dtls12_server_responds_to_client_hello(client: &mut Dtls, server: &mut Dtls, now: Instant) { + let packets = dtls12_client_hello(client, now); + for packet in &packets { + server + .handle_packet(packet) + .expect("ClientHello should not be blocked"); + } + + server.handle_timeout(now).expect("server timeout"); + let server_out = collect_packets(server); + assert!( + !server_out.is_empty(), + "server should respond after ClientHello" + ); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_prehandshake_future_epoch_records_do_not_block_client_hello() { + let _ = env_logger::try_init(); + let (mut client, mut server, now) = dtls12_client_server_with_max_queue(1); + + server + .handle_packet(&dtls12_epoch1_record(0, 0)) + .expect("pre-handshake future-epoch record should be tolerated"); + + dtls12_server_responds_to_client_hello(&mut client, &mut server, now); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_prehandshake_plaintext_non_client_hello_records_do_not_block_client_hello() { + let _ = env_logger::try_init(); + let (mut client, mut server, now) = dtls12_client_server_with_max_queue(1); + + server + .handle_packet(&dtls12_change_cipher_spec_record(0)) + .expect("pre-handshake CCS should not occupy receive queue"); + + dtls12_server_responds_to_client_hello(&mut client, &mut server, now); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_malformed_datagram_is_discarded_without_processing_alerts() { diff --git a/tests/dtls13/edge.rs b/tests/dtls13/edge.rs index 33e9442..5594329 100644 --- a/tests/dtls13/edge.rs +++ b/tests/dtls13/edge.rs @@ -50,6 +50,86 @@ fn dtls13_ack_record_for_records(seq: u64, records: &[(u64, u64)]) -> Vec { out } +fn dtls13_future_epoch_ciphertext(seq: u16) -> Vec { + let mut out = Vec::new(); + out.push(0x2E); // fixed bits, S=1, L=1, epoch_bits=2 + out.extend_from_slice(&seq.to_be_bytes()); + out.extend_from_slice(&0u16.to_be_bytes()); // empty ciphertext + out +} + +#[cfg(feature = "rcgen")] +fn dtls13_client_server_with_max_queue(max_queue_rx: usize) -> (Dtls, Dtls, Instant) { + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + let config = Arc::new( + Config::builder() + .max_queue_rx(max_queue_rx) + .build() + .expect("build config"), + ); + + let now = Instant::now(); + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + (client, server, now) +} + +#[cfg(feature = "rcgen")] +fn dtls13_client_hello(client: &mut Dtls, now: Instant) -> Vec> { + client.handle_timeout(now).expect("client timeout"); + let packets = collect_packets(client); + assert!(!packets.is_empty(), "client should emit ClientHello"); + packets +} + +#[cfg(feature = "rcgen")] +fn dtls13_server_responds_to_client_hello(client: &mut Dtls, server: &mut Dtls, now: Instant) { + let packets = dtls13_client_hello(client, now); + for packet in &packets { + server + .handle_packet(packet) + .expect("ClientHello should not be blocked"); + } + + server.handle_timeout(now).expect("server timeout"); + let server_out = collect_packets(server); + assert!( + !server_out.is_empty(), + "server should respond after ClientHello" + ); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_prehandshake_future_epoch_records_do_not_block_client_hello() { + let _ = env_logger::try_init(); + let (mut client, mut server, now) = dtls13_client_server_with_max_queue(1); + + server + .handle_packet(&dtls13_future_epoch_ciphertext(0)) + .expect("pre-handshake future-epoch ciphertext should be tolerated"); + + dtls13_server_responds_to_client_hello(&mut client, &mut server, now); +} + +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_prehandshake_plaintext_non_client_hello_records_do_not_block_client_hello() { + let _ = env_logger::try_init(); + let (mut client, mut server, now) = dtls13_client_server_with_max_queue(1); + + server + .handle_packet(&dtls13_ack_record(0)) + .expect("pre-handshake ACK should not occupy receive queue"); + + dtls13_server_responds_to_client_hello(&mut client, &mut server, now); +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_malformed_datagram_is_discarded_without_processing_alerts() {