Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Unreleased

* Discard DTLS handshake records with malformed same-record tails #139
* Represent DTLS wire-code identifiers as compact newtypes (breaking) #137
* Make public errors structured and fatal-only (breaking) #134

Expand Down
11 changes: 10 additions & 1 deletion src/dtls12/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,16 @@ impl Engine {

for record in unhandled {
let buf = record.into_buffer();
self.parse_packet(&buf)?;
match self.parse_packet(&buf) {
Ok(()) => {}
Err(InternalError::Transient(err)) => {
trace!("Discarding buffered protected record after reparse failed: {err}");
}
Err(err) => {
self.buffers_free.push(buf);
return Err(err);
}
}
self.buffers_free.push(buf);
}
}
Expand Down
144 changes: 113 additions & 31 deletions src/dtls12/incoming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ impl Record {
// ONLY COPY: UDP packet slice -> pooled buffer
let mut buffer = Buf::new();
buffer.extend_from_slice(record_slice);
let parsed = match ParsedRecord::parse(&buffer, cs, 0) {
let parsed = match ParsedRecord::parse(&buffer, cs, 0, true) {
Ok(p) => p,
Err(e) => {
// RFC 6347 §4.1.2.7: Invalid records SHOULD be silently discarded.
Expand Down Expand Up @@ -210,22 +210,23 @@ impl Record {
buffer.len()
};

// Decryption succeeded — now commit the replay window update.
// RFC 6347 §4.1.2.6: "The receive window is updated only if the
// MAC verification succeeds."
decrypt.replay_update(sequence);

// The record is now authenticated. Tell the handler so it can act on a
// confirmed-genuine record (e.g. mark the peer past its handshake).
decrypt.note_decrypted_record(content_type);

// Update the length of the record.
buffer[11] = (new_len >> 8) as u8;
buffer[12] = new_len as u8;

let parsed = ParsedRecord::parse(&buffer, cs, explicit_nonce_len)?;
let parsed = ParsedRecord::parse(&buffer, cs, explicit_nonce_len, false)?;
let parsed = Box::new(parsed);

// Decryption and parsing both succeeded. Commit replay state only once
// the record is publishable, so a malformed protected handshake tail
// cannot consume the retransmission slot for a clean record.
decrypt.replay_update(sequence);

// The record is now authenticated and accepted. Tell the handler so it
// can act on a confirmed-genuine record (e.g. mark the peer past its
// handshake).
decrypt.note_decrypted_record(content_type);

Ok(Some(Record { buffer, parsed }))
}

Expand Down Expand Up @@ -276,14 +277,24 @@ impl ParsedRecord {
input: &[u8],
cipher_suite: Option<Dtls12CipherSuite>,
offset: usize,
defer_protected_handshake_parse: bool,
) -> Result<ParsedRecord, InternalError> {
let (_, record) = DTLSRecord::parse(input, 0, offset)?;

let handshakes = if record.content_type == ContentType::Handshake {
if record.sequence.epoch != 0 && defer_protected_handshake_parse {
trace!("Deferring protected handshake parsing until after decryption");
return Ok(ParsedRecord {
record,
handshakes: ArrayVec::new(),
handled: AtomicBool::new(false),
});
}

// This will also return None on the encrypted Finished after ChangeCipherSpec.
// However we will then decrypt and try again.
let fragment_offset = record.fragment_range.start;
parse_handshakes(record.fragment(input), fragment_offset, cipher_suite)
parse_handshakes(record.fragment(input), fragment_offset, cipher_suite)?
} else {
ArrayVec::new()
};
Expand Down Expand Up @@ -330,22 +341,18 @@ fn parse_handshakes(
mut input: &[u8],
mut base_offset: usize,
cipher_suite: Option<Dtls12CipherSuite>,
) -> ArrayVec<Handshake, 8> {
) -> Result<ArrayVec<Handshake, 8>, InternalError> {
let mut handshakes = ArrayVec::new();
while !input.is_empty() {
if let Ok((remaining, handshake)) = Handshake::parse(input, base_offset, cipher_suite, true)
{
let len = input.len() - remaining.len();
base_offset += len;
input = remaining;
if handshakes.try_push(handshake).is_err() {
break;
}
} else {
break;
let (remaining, handshake) = Handshake::parse(input, base_offset, cipher_suite, true)?;
let len = input.len() - remaining.len();
base_offset += len;
input = remaining;
if handshakes.try_push(handshake).is_err() {
return Err(InternalError::too_many_records());
}
}
handshakes
Ok(handshakes)
}

impl fmt::Debug for Incoming {
Expand Down Expand Up @@ -399,11 +406,16 @@ impl std::panic::UnwindSafe for Incoming {}
#[cfg(test)]
mod tests {
use super::*;
use crate::dtls12::message::MessageType;

#[derive(Default)]
struct TestHandler {
classify_calls: usize,
dropped_alerts: usize,
peer_encryption_enabled: bool,
explicit_nonce_len: usize,
min_protected_fragment_len: usize,
replay_updates: usize,
}

impl RecordHandler for TestHandler {
Expand All @@ -417,27 +429,29 @@ mod tests {
}

fn is_peer_encryption_enabled(&self) -> bool {
false
self.peer_encryption_enabled
}

fn replay_check(&self, _seq: Sequence) -> bool {
panic!("replay_check should not be called for plaintext tests");
assert!(self.peer_encryption_enabled);
true
}

fn replay_update(&mut self, _seq: Sequence) {
panic!("replay_update should not be called for plaintext tests");
self.replay_updates += 1;
}

fn decryption_aad_and_nonce(&self, _dtls: &DTLSRecord, _buf: &[u8]) -> (Aad, Nonce) {
panic!("decryption_aad_and_nonce should not be called for plaintext tests");
assert!(self.peer_encryption_enabled);
(Aad(ArrayVec::new()), Nonce([0; 12]))
}

fn explicit_nonce_len(&self) -> usize {
panic!("explicit_nonce_len should not be called for plaintext tests");
self.explicit_nonce_len
}

fn min_protected_fragment_len(&self) -> usize {
panic!("min_protected_fragment_len should not be called for plaintext tests");
self.min_protected_fragment_len
}

fn decrypt_data(
Expand All @@ -446,7 +460,8 @@ mod tests {
_aad: Aad,
_nonce: Nonce,
) -> Result<(), Error> {
panic!("decrypt_data should not be called for plaintext tests");
assert!(self.peer_encryption_enabled);
Ok(())
}
}

Expand All @@ -461,6 +476,18 @@ mod tests {
out
}

fn handshake_fragment(msg_type: MessageType, message_seq: u16, fragment: &[u8]) -> Vec<u8> {
let mut out = Vec::new();
let len = fragment.len() as u32;
out.push(msg_type.as_u8());
out.extend_from_slice(&len.to_be_bytes()[1..]);
out.extend_from_slice(&message_seq.to_be_bytes());
out.extend_from_slice(&0u32.to_be_bytes()[1..]);
out.extend_from_slice(&len.to_be_bytes()[1..]);
out.extend_from_slice(fragment);
out
}

#[test]
fn parse_packet_filters_control_records_after_packet_validation() {
let mut packet = Vec::new();
Expand All @@ -486,4 +513,59 @@ mod tests {
);
assert_eq!(incoming.first().record().sequence.epoch, 1);
}

#[test]
fn parse_record_accepts_multiple_handshakes() {
let mut fragment = Vec::new();
fragment.extend_from_slice(&handshake_fragment(MessageType::HelloRequest, 0, &[]));
fragment.extend_from_slice(&handshake_fragment(MessageType::ServerHelloDone, 1, &[]));

let packet = build_record(ContentType::Handshake, 0, 1, &fragment);
let mut handler = TestHandler::default();
let incoming = Incoming::parse_packet(&packet, &mut handler, None)
.unwrap()
.expect("handshake record should remain");

assert_eq!(incoming.first().handshakes().len(), 2);
}

#[test]
fn pre_decrypt_protected_handshake_parsing_is_deferred() {
let fragment = handshake_fragment(MessageType::HelloRequest, 0, &[]);
let packet = build_record(ContentType::Handshake, 1, 1, &fragment);
let mut handler = TestHandler::default();

let incoming = Incoming::parse_packet(&packet, &mut handler, None)
.unwrap()
.expect("protected record should queue until peer encryption is enabled");

assert!(
incoming.first().handshakes().is_empty(),
"ciphertext bytes must not be parsed as plaintext handshakes"
);
assert_eq!(handler.replay_updates, 0);
}

#[test]
fn post_decrypt_malformed_handshake_tail_is_not_deferred() {
let mut fragment = handshake_fragment(MessageType::HelloRequest, 0, &[]);
fragment.push(0xff);
let packet = build_record(ContentType::Handshake, 1, 1, &fragment);
let mut handler = TestHandler {
peer_encryption_enabled: true,
explicit_nonce_len: 0,
min_protected_fragment_len: 0,
..Default::default()
};

let err = Incoming::parse_packet(
&packet,
&mut handler,
Some(Dtls12CipherSuite::ECDHE_ECDSA_CHACHA20_POLY1305_SHA256),
)
.expect_err("post-decrypt malformed handshake tail must be rejected");

assert!(matches!(err, InternalError::Transient(_)));
assert_eq!(handler.replay_updates, 0);
}
}
80 changes: 79 additions & 1 deletion src/dtls13/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2118,7 +2118,16 @@ impl Engine {

for record in unhandled {
let buf = record.into_buffer();
self.parse_packet(&buf)?;
match self.parse_packet(&buf) {
Ok(()) => {}
Err(InternalError::Transient(err)) => {
trace!("Discarding buffered protected record after reparse failed: {err}");
}
Err(err) => {
self.buffers_free.push(buf);
return Err(err);
}
}
self.buffers_free.push(buf);
}
}
Expand Down Expand Up @@ -2511,6 +2520,29 @@ mod tests {

struct PassthroughRecordHandler;

#[derive(Debug)]
struct PassthroughCipher;

impl Cipher for PassthroughCipher {
fn encrypt(
&mut self,
_plaintext: &mut Buf,
_aad: Aad,
_nonce: Nonce,
) -> Result<(), crate::CryptoError> {
Ok(())
}

fn decrypt(
&mut self,
_ciphertext: &mut TmpBuf,
_aad: Aad,
_nonce: Nonce,
) -> Result<(), crate::CryptoError> {
Ok(())
}
}

impl RecordHandler for PassthroughRecordHandler {
fn classify_record(&mut self, record: Record) -> Result<Option<Record>, Error> {
Ok(Some(record))
Expand Down Expand Up @@ -2579,6 +2611,25 @@ mod tests {
packet
}

fn encrypted_malformed_handshake_tail_record(seq: u16) -> Vec<u8> {
let mut fragment = Vec::new();
fragment.push(0xff);
fragment.push(ContentType::Handshake.as_u8());
fragment.resize(17, 0);

let mut packet = Vec::new();
packet.push(
0b0010_0000
| 0b0000_1000 // 2-byte sequence number.
| 0b0000_0100 // explicit length.
| 0b0000_0010, // epoch bits.
);
packet.extend_from_slice(&seq.to_be_bytes());
packet.extend_from_slice(&(fragment.len() as u16).to_be_bytes());
packet.extend_from_slice(&fragment);
packet
}

fn parsed_key_update(seq: u16) -> Incoming {
Incoming::parse_packet(
&encrypted_key_update_record(seq),
Expand Down Expand Up @@ -2731,6 +2782,33 @@ mod tests {
);
}

#[test]
#[cfg(feature = "rcgen")]
fn enable_peer_encryption_discards_malformed_buffered_protected_handshake() {
let mut engine = test_engine();
engine.set_cipher_suite(Dtls13CipherSuite::AES_128_GCM_SHA256);
let mut sn_key = Buf::new();
sn_key.extend_from_slice(&[0; 16]);
engine.hs_recv_keys = Some(EpochKeys {
cipher: Box::new(PassthroughCipher),
iv: [0; 12],
traffic_secret: Buf::new(),
sn_key,
});

engine
.parse_packet(&encrypted_malformed_handshake_tail_record(0))
.expect("pre-encryption protected record should queue");
assert_eq!(engine.queue_rx.len(), 1);

engine
.enable_peer_encryption()
.expect("malformed queued protected record should be discarded");

assert!(engine.peer_encryption_enabled);
assert!(engine.queue_rx.is_empty());
}

#[test]
#[cfg(feature = "rcgen")]
fn malformed_ack_record_number_vector_is_ignored() {
Expand Down
Loading
Loading