Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Unreleased

* Filter pre-ClientHello queue poison records #140
* Represent DTLS wire-code identifiers as compact newtypes (breaking) #137
* Make public errors structured and fatal-only (breaking) #134

Expand Down
44 changes: 34 additions & 10 deletions src/dtls12/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,22 +229,34 @@ impl Engine {
Ok(())
}

/// Insert a parsed datagram into the receive queue.
fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> {
// Capacity guard before iterating records.
if self.queue_rx.len() >= self.config.max_queue_rx() {
warn!(
"Receive queue full (max {}): {:?}",
self.config.max_queue_rx(),
self.queue_rx
);
return Err(Error::ReceiveQueueFull);
pub(crate) fn parse_packet_filtering_records(
&mut self,
packet: &[u8],
keep_record: impl FnMut(&Record) -> bool,
) -> Result<(), InternalError> {
let cs = self.cipher_suite;
let incoming = Incoming::parse_packet_filtering_records(packet, self, cs, keep_record)?;
if let Some(incoming) = incoming {
self.insert_incoming(incoming)?;
}

Ok(())
}

/// Insert a parsed datagram into the receive queue.
fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> {
// Dispatch to specialized handlers
if incoming.first().first_handshake().is_some() {
self.insert_incoming_handshake(incoming)
} else {
if self.queue_rx.len() >= self.config.max_queue_rx() {
warn!(
"Receive queue full (max {}): {:?}",
self.config.max_queue_rx(),
self.queue_rx
);
return Err(Error::ReceiveQueueFull);
}
self.insert_incoming_non_handshake(incoming)
}
}
Expand Down Expand Up @@ -307,6 +319,14 @@ impl Engine {

match search_result {
Err(index) => {
if self.queue_rx.len() >= self.config.max_queue_rx() {
warn!(
"Receive queue full (max {}): {:?}",
self.config.max_queue_rx(),
self.queue_rx
);
return Err(Error::ReceiveQueueFull);
}
// Insert in order of handshake key
self.queue_rx.insert(index, incoming);
}
Expand Down Expand Up @@ -590,6 +610,10 @@ impl Engine {
self.has_complete_handshake_with_seq(wanted, self.peer_handshake_seq_no)
}

pub(crate) fn expected_peer_handshake_seq_no(&self) -> u16 {
self.peer_handshake_seq_no
}

fn has_complete_handshake_with_seq(&mut self, wanted: MessageType, expected_seq: u16) -> bool {
let mut skip_handled = self
.queue_rx
Expand Down
38 changes: 37 additions & 1 deletion src/dtls12/incoming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,23 @@ impl Incoming {

Ok(Some(Incoming { records }))
}

pub(crate) fn parse_packet_filtering_records(
packet: &[u8],
decrypt: &mut dyn RecordHandler,
cs: Option<Dtls12CipherSuite>,
keep_record: impl FnMut(&Record) -> bool,
) -> Result<Option<Self>, InternalError> {
let records = Records::parse_filtering_records(packet, decrypt, cs, keep_record)?;

if records.records.is_empty() {
return Ok(None);
}

let records = Box::new(records);

Ok(Some(Incoming { records }))
}
}

/// A number of records parsed from a single UDP packet.
Expand All @@ -68,11 +85,21 @@ pub struct Records {

impl Records {
pub fn parse(
packet: &[u8],
decrypt: &mut dyn RecordHandler,
cs: Option<Dtls12CipherSuite>,
) -> Result<Records, InternalError> {
Self::parse_filtering_records(packet, decrypt, cs, |_| true)
}

fn parse_filtering_records(
mut packet: &[u8],
decrypt: &mut dyn RecordHandler,
cs: Option<Dtls12CipherSuite>,
mut keep_record: impl FnMut(&Record) -> bool,
) -> Result<Records, InternalError> {
let mut parsed_records: ArrayVec<Record, 8> = ArrayVec::new();
let mut parsed_record_count = 0usize;

// Find record boundaries and copy each record ONCE from the packet
while !packet.is_empty() {
Expand All @@ -93,9 +120,18 @@ impl Records {
match Record::parse(record_slice, decrypt, cs) {
Ok(record) => {
if let Some(record) = record {
if parsed_records.try_push(record).is_err() {
if parsed_record_count >= parsed_records.capacity() {
return Err(InternalError::too_many_records());
}
parsed_record_count += 1;

if !keep_record(&record) {
trace!("Discarding filtered rec");
} else {
parsed_records
.try_push(record)
.expect("parsed record count is capacity-checked");
}
} else {
trace!("Discarding replayed rec");
}
Expand Down
24 changes: 19 additions & 5 deletions src/dtls12/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,25 @@ impl Server {
}

pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> {
match self
.engine
.parse_packet(packet)
.and_then(|_| self.make_progress())
{
let result = if self.state == State::AwaitClientHello {
let expected_seq = self.engine.expected_peer_handshake_seq_no();
self.engine
.parse_packet_filtering_records(packet, |record| {
record.record().sequence.epoch == 0
&& (record.record().content_type == ContentType::Alert
|| (record.record().content_type == ContentType::Handshake
&& record.first_handshake().is_some_and(|h| {
h.header.msg_type == MessageType::ClientHello
&& (h.header.message_seq == expected_seq
|| h.header.message_seq.saturating_add(1)
== expected_seq)
})))
})
} else {
self.engine.parse_packet(packet)
};

match result.and_then(|_| self.make_progress()) {
Ok(()) => Ok(()),
Err(e) => e.into_public_error().map_or(Ok(()), Err),
}
Expand Down
60 changes: 52 additions & 8 deletions src/dtls13/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ pub struct Engine {
/// Queue of incoming packets.
queue_rx: QueueRx,

/// Last accepted pre-ClientHello packet after parser-side filtering.
last_parse_queueable_packet: Option<Buf>,

/// Queue of outgoing packets.
queue_tx: QueueTx,

Expand Down Expand Up @@ -220,6 +223,7 @@ impl Engine {
buffers_free: BufferPool::default(),
sequence_epoch_0: Sequence::new(0),
queue_rx: QueueRx::new(),
last_parse_queueable_packet: None,
queue_tx: QueueTx::new(),
cipher_suite: None,
hs_send_keys: None,
Expand Down Expand Up @@ -332,26 +336,54 @@ impl Engine {
pub fn parse_packet(&mut self, packet: &[u8]) -> Result<(), InternalError> {
let cs = self.cipher_suite;
let incoming = Incoming::parse_packet(packet, self, cs)?;
self.last_parse_queueable_packet = None;
if let Some(incoming) = incoming {
self.insert_incoming(incoming)?;
}

Ok(())
}

fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> {
if self.queue_rx.len() >= self.config.max_queue_rx() {
warn!(
"Receive queue full (max {}): {:?}",
self.config.max_queue_rx(),
self.queue_rx
);
return Err(Error::ReceiveQueueFull);
pub(crate) fn parse_packet_filtering_records(
&mut self,
packet: &[u8],
retain_queueable_packet: bool,
keep_record: impl FnMut(&Record) -> bool,
) -> Result<(), InternalError> {
let cs = self.cipher_suite;
let incoming = Incoming::parse_packet_filtering_records(packet, self, cs, keep_record)?;
self.last_parse_queueable_packet = None;
if let Some(incoming) = incoming {
let queueable_packet = (retain_queueable_packet
&& incoming.first().first_handshake().is_some())
.then(|| incoming.to_datagram_buf());
self.insert_incoming(incoming)?;
self.last_parse_queueable_packet = queueable_packet;
}

Ok(())
}

pub(crate) fn take_last_parse_queueable_packet(&mut self) -> Option<Buf> {
self.last_parse_queueable_packet.take()
}

pub(crate) fn clear_last_parse_queueable_packet(&mut self) {
self.last_parse_queueable_packet = None;
}

fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> {
if incoming.first().first_handshake().is_some() {
self.insert_incoming_handshake(incoming)
} else {
if self.queue_rx.len() >= self.config.max_queue_rx() {
warn!(
"Receive queue full (max {}): {:?}",
self.config.max_queue_rx(),
self.queue_rx
);
return Err(Error::ReceiveQueueFull);
}
self.insert_incoming_non_handshake(incoming)
}
}
Expand Down Expand Up @@ -406,6 +438,14 @@ impl Engine {

match search_result {
Err(index) => {
if self.queue_rx.len() >= self.config.max_queue_rx() {
warn!(
"Receive queue full (max {}): {:?}",
self.config.max_queue_rx(),
self.queue_rx
);
return Err(Error::ReceiveQueueFull);
}
// Track received record numbers for ACK generation
for record in incoming.records().iter() {
let seq = record.record().sequence;
Expand Down Expand Up @@ -726,6 +766,10 @@ impl Engine {
self.has_complete_handshake_with_seq(wanted, self.peer_handshake_seq_no)
}

pub(crate) fn expected_peer_handshake_seq_no(&self) -> u16 {
self.peer_handshake_seq_no
}

fn has_complete_handshake_with_seq(&mut self, wanted: MessageType, expected_seq: u16) -> bool {
let mut skip_handled = self
.queue_rx
Expand Down
46 changes: 45 additions & 1 deletion src/dtls13/incoming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ impl Incoming {
pub fn into_records(self) -> impl Iterator<Item = Record> {
self.records.records.into_iter()
}

pub(crate) fn to_datagram_buf(&self) -> Buf {
let mut packet = Buf::new();
for record in &self.records.records {
packet.extend_from_slice(record.buffer());
}
packet
}
}

impl Incoming {
Expand Down Expand Up @@ -57,6 +65,23 @@ impl Incoming {

Ok(Some(Incoming { records }))
}

pub(crate) fn parse_packet_filtering_records(
packet: &[u8],
decrypt: &mut dyn RecordHandler,
cs: Option<Dtls13CipherSuite>,
keep_record: impl FnMut(&Record) -> bool,
) -> Result<Option<Self>, InternalError> {
let records = Records::parse_filtering_records(packet, decrypt, cs, keep_record)?;

if records.records.is_empty() {
return Ok(None);
}

let records = Box::new(records);

Ok(Some(Incoming { records }))
}
}

/// A number of records parsed from a single UDP packet.
Expand All @@ -67,11 +92,21 @@ pub struct Records {

impl Records {
pub fn parse(
packet: &[u8],
decrypt: &mut dyn RecordHandler,
cs: Option<Dtls13CipherSuite>,
) -> Result<Records, InternalError> {
Self::parse_filtering_records(packet, decrypt, cs, |_| true)
}

fn parse_filtering_records(
mut packet: &[u8],
decrypt: &mut dyn RecordHandler,
cs: Option<Dtls13CipherSuite>,
mut keep_record: impl FnMut(&Record) -> bool,
) -> Result<Records, InternalError> {
let mut parsed_records: ArrayVec<Record, 16> = ArrayVec::new();
let mut parsed_record_count = 0usize;

// Find record boundaries and copy each record ONCE from the packet
while !packet.is_empty() {
Expand Down Expand Up @@ -132,9 +167,18 @@ impl Records {
match Record::parse(record_slice, decrypt, cs) {
Ok(record) => {
if let Some(record) = record {
if parsed_records.try_push(record).is_err() {
if parsed_record_count >= parsed_records.capacity() {
return Err(InternalError::too_many_records());
}
parsed_record_count += 1;

if !keep_record(&record) {
trace!("Discarding filtered rec");
} else {
parsed_records
.try_push(record)
.expect("parsed record count is capacity-checked");
}
} else {
trace!("Discarding replayed rec");
}
Expand Down
Loading
Loading