Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/dtls12/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,15 @@ impl Client {
}

pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> {
self.engine.parse_packet(packet)?;
self.make_progress()?;
Ok(())
match self
.engine
.parse_packet(packet)
.and_then(|_| self.make_progress())
{
Ok(()) => Ok(()),
Err(e) if e.is_transient() => Ok(()),
Err(e) => Err(e),
}
}

pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> {
Expand Down
12 changes: 9 additions & 3 deletions src/dtls12/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,15 @@ impl Server {
}

pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> {
self.engine.parse_packet(packet)?;
self.make_progress()?;
Ok(())
match self
.engine
.parse_packet(packet)
.and_then(|_| self.make_progress())
{
Ok(()) => Ok(()),
Err(e) if e.is_transient() => Ok(()),
Err(e) => Err(e),
}
}

pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> {
Expand Down
12 changes: 9 additions & 3 deletions src/dtls13/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,15 @@ impl Client {
}

pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> {
self.engine.parse_packet(packet)?;
self.make_progress()?;
Ok(())
match self
.engine
.parse_packet(packet)
.and_then(|_| self.make_progress())
{
Ok(()) => Ok(()),
Err(e) if e.is_transient() => Ok(()),
Err(e) => Err(e),
}
}

pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> {
Expand Down
11 changes: 9 additions & 2 deletions src/dtls13/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,15 @@ impl Server {
self.retained_hello.push_back(packet.to_buf());
}

self.engine.parse_packet(packet)?;
self.make_progress()?;
match self
.engine
.parse_packet(packet)
.and_then(|_| self.make_progress())
{
Ok(()) => {}
Err(e) if e.is_transient() => return Ok(()),
Err(e) => return Err(e),
}

// Once past AwaitClientHello, DTLS 1.3 is committed — free the buffer.
if self.auto_mode && self.state != State::AwaitClientHello {
Expand Down
30 changes: 21 additions & 9 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
#[non_exhaustive]
/// Errors returned by DTLS processing functions.
pub enum Error {
/// Parser requested more data
ParseIncomplete,
/// Parser encountered an error kind from nom
ParseError(nom::error::ErrorKind),
/// Unexpected DTLS message
UnexpectedMessage(String),
/// Local state was missing data required for the requested operation
Expand All @@ -30,8 +26,6 @@ pub enum Error {
Timeout(&'static str),
/// Configuration error (e.g., invalid crypto provider)
ConfigError(String),
/// Too many records in a single packet
TooManyRecords,
/// Peer attempted renegotiation (not supported)
RenegotiationAttempt,
/// Application data cannot be sent because the handshake is not yet complete.
Expand All @@ -54,6 +48,24 @@ pub enum Error {
/// value to communicate from dtls13/server.rs to lib.rs
#[doc(hidden)]
Dtls12Fallback,
/// Parser requested more data
#[doc(hidden)]
ParseIncomplete,
/// Parser encountered an error kind from nom
#[doc(hidden)]
ParseError(nom::error::ErrorKind),
/// Too many records in a single packet
#[doc(hidden)]
TooManyRecords,
}

impl Error {
pub(crate) fn is_transient(&self) -> bool {
matches!(
self,
Error::ParseIncomplete | Error::ParseError(_) | Error::TooManyRecords
)
}
}

impl<'a> From<nom::Err<nom::error::Error<&'a [u8]>>> for Error {
Expand All @@ -71,8 +83,6 @@ impl std::error::Error for Error {}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::ParseIncomplete => write!(f, "parse incomplete"),
Error::ParseError(kind) => write!(f, "parse error: {:?}", kind),
Error::UnexpectedMessage(msg) => write!(f, "unexpected message: {}", msg),
Error::InvalidState(msg) => write!(f, "invalid state: {}", msg),
Error::CryptoError(msg) => write!(f, "crypto error: {}", msg),
Expand All @@ -84,7 +94,6 @@ impl std::fmt::Display for Error {
Error::IncompleteServerHello => write!(f, "incomplete ServerHello"),
Error::Timeout(what) => write!(f, "timeout: {}", what),
Error::ConfigError(msg) => write!(f, "config error: {}", msg),
Error::TooManyRecords => write!(f, "too many records in packet"),
Error::RenegotiationAttempt => write!(f, "peer attempted renegotiation"),
Error::HandshakePending => {
write!(f, "handshake pending: cannot send application data yet")
Expand All @@ -94,6 +103,9 @@ impl std::fmt::Display for Error {
Error::Dtls12Fallback => {
write!(f, "dtls 1.2 fallback (internal)")
}
Error::ParseIncomplete => write!(f, "parse incomplete"),
Error::ParseError(kind) => write!(f, "parse error: {:?}", kind),
Error::TooManyRecords => write!(f, "too many records in packet"),
}
}
}
30 changes: 9 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1281,7 +1281,7 @@ mod test {
}

#[test]
fn auto_server_rejects_ch_shaped_malformed_packet_without_fallback() {
fn auto_server_discards_ch_shaped_malformed_packet_without_fallback() {
let body = ch_shaped_malformed_body();
let len = body.len() as u32;
let hs = make_handshake(0x01, len, 0, len, &body);
Expand All @@ -1292,14 +1292,8 @@ mod test {
);

let mut dtls = new_instance_auto();
let err = dtls
.handle_packet(&pkt)
.expect_err("malformed ClientHello-shaped packet must not force fallback");

assert!(
matches!(err, Error::ParseIncomplete | Error::ParseError(_)),
"expected parser error, got {err:?}"
);
dtls.handle_packet(&pkt)
.expect("malformed ClientHello-shaped packet should be discarded");
assert!(matches!(
dtls.inner,
Some(Inner::Server13(ref server)) if server.is_auto_mode()
Expand Down Expand Up @@ -1370,7 +1364,7 @@ mod test {
}

#[test]
fn dtls12_server_rejects_oversized_ec_point_formats_extension() {
fn dtls12_server_discards_oversized_ec_point_formats_extension() {
let mut extensions = Vec::new();
extensions.extend_from_slice(&[0x00, 0x0B]); // ec_point_formats
extensions.extend_from_slice(&[0x00, 0x05]); // extension body length
Expand All @@ -1390,15 +1384,12 @@ mod test {
let pkt = make_record(0x16, &hs);

let mut dtls = new_instance_12_no_cookie();
let err = dtls.handle_packet(&pkt).unwrap_err();
assert!(matches!(
err,
Error::ParseError(nom::error::ErrorKind::LengthValue)
));
dtls.handle_packet(&pkt)
.expect("malformed extension should be discarded");
}

#[test]
fn dtls12_server_rejects_trailing_ec_point_formats_extension() {
fn dtls12_server_discards_trailing_ec_point_formats_extension() {
let mut extensions = Vec::new();
extensions.extend_from_slice(&[0x00, 0x0B]); // ec_point_formats
extensions.extend_from_slice(&[0x00, 0x03]); // extension body length
Expand All @@ -1416,11 +1407,8 @@ mod test {
let pkt = make_record(0x16, &hs);

let mut dtls = new_instance_12_no_cookie();
let err = dtls.handle_packet(&pkt).unwrap_err();
assert!(matches!(
err,
Error::ParseError(nom::error::ErrorKind::LengthValue)
));
dtls.handle_packet(&pkt)
.expect("malformed extension should be discarded");
}

#[test]
Expand Down
23 changes: 8 additions & 15 deletions tests/dtls12/edge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ fn dtls12_min_protected_fragment_len(suite: Dtls12CipherSuite) -> usize {

#[test]
#[cfg(feature = "rcgen")]
fn dtls12_malformed_datagram_does_not_process_alerts_before_parse_completes() {
fn dtls12_malformed_datagram_is_discarded_without_processing_alerts() {
let _ = env_logger::try_init();

let server_cert = generate_self_signed_certificate().expect("gen server cert");
Expand All @@ -77,19 +77,17 @@ fn dtls12_malformed_datagram_does_not_process_alerts_before_parse_completes() {
let mut packet = dtls12_alert_record(1, 2, 40);
packet.push(0xFF); // trailing truncated record header

let err = server
server
.handle_packet(&packet)
.expect_err("malformed datagram should fail atomically");
.expect("malformed datagram should be discarded");

assert!(
matches!(err, dimpl::Error::ParseIncomplete),
"expected ParseIncomplete, got {err:?}"
);
let mut buf = [0; 1500];
assert!(!matches!(server.poll_output(&mut buf), Output::CloseNotify));
}

#[test]
#[cfg(feature = "rcgen")]
fn dtls12_too_many_control_records_still_fail_before_filtering() {
fn dtls12_too_many_control_records_are_discarded() {
let _ = env_logger::try_init();

let server_cert = generate_self_signed_certificate().expect("gen server cert");
Expand All @@ -104,14 +102,9 @@ fn dtls12_too_many_control_records_still_fail_before_filtering() {
packet.extend_from_slice(&dtls12_alert_record(seq, 1, 0));
}

let err = server
server
.handle_packet(&packet)
.expect_err("control-only datagram should still trip TooManyRecords");

assert!(
matches!(err, dimpl::Error::TooManyRecords),
"expected TooManyRecords, got {err:?}"
);
.expect("too many records should be discarded");
}

#[test]
Expand Down
23 changes: 8 additions & 15 deletions tests/dtls13/edge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn dtls13_ack_record_for_records(seq: u64, records: &[(u64, u64)]) -> Vec<u8> {

#[test]
#[cfg(feature = "rcgen")]
fn dtls13_malformed_datagram_does_not_process_alerts_before_parse_completes() {
fn dtls13_malformed_datagram_is_discarded_without_processing_alerts() {
let _ = env_logger::try_init();

let server_cert = generate_self_signed_certificate().expect("gen server cert");
Expand All @@ -65,19 +65,17 @@ fn dtls13_malformed_datagram_does_not_process_alerts_before_parse_completes() {
let mut packet = dtls13_alert_record(1, 2, 40);
packet.push(0xFF); // trailing truncated record header

let err = server
server
.handle_packet(&packet)
.expect_err("malformed datagram should fail atomically");
.expect("malformed datagram should be discarded");

assert!(
matches!(err, dimpl::Error::ParseIncomplete),
"expected ParseIncomplete, got {err:?}"
);
let mut buf = [0; 1500];
assert!(!matches!(server.poll_output(&mut buf), Output::CloseNotify));
}

#[test]
#[cfg(feature = "rcgen")]
fn dtls13_too_many_control_records_still_fail_before_filtering() {
fn dtls13_too_many_control_records_are_discarded() {
let _ = env_logger::try_init();

let server_cert = generate_self_signed_certificate().expect("gen server cert");
Expand All @@ -92,14 +90,9 @@ fn dtls13_too_many_control_records_still_fail_before_filtering() {
packet.extend_from_slice(&dtls13_ack_record(seq));
}

let err = server
server
.handle_packet(&packet)
.expect_err("control-only datagram should still trip TooManyRecords");

assert!(
matches!(err, dimpl::Error::TooManyRecords),
"expected TooManyRecords, got {err:?}"
);
.expect("too many records should be discarded");
}

#[test]
Expand Down
16 changes: 4 additions & 12 deletions tests/dtls13_cookie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,9 @@ fn dtls13_client_rejects_hrr_cookie_extension_trailing_bytes() {
"fixture should contain a Cookie extension"
);

let err = client
client
.handle_packet(&hrr)
.expect_err("malformed HRR Cookie extension must be rejected");
assert!(matches!(
err,
dimpl::Error::ParseError(nom::error::ErrorKind::LengthValue)
));
.expect("malformed HRR Cookie extension should be discarded");

client
.handle_timeout(now)
Expand Down Expand Up @@ -192,11 +188,7 @@ fn dtls13_server_rejects_clienthello_cookie_extension_trailing_bytes() {
"fixture should contain a Cookie extension"
);

let err = server
server
.handle_packet(&ch2)
.expect_err("malformed ClientHello Cookie extension must be rejected");
assert!(matches!(
err,
dimpl::Error::ParseError(nom::error::ErrorKind::LengthValue)
));
.expect("malformed ClientHello Cookie extension should be discarded");
}
Loading