diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs index 6f4d99c2..a39a3bff 100644 --- a/src/dtls12/client.rs +++ b/src/dtls12/client.rs @@ -187,9 +187,15 @@ impl Client { } pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> { - self.engine.parse_packet(packet)?; - self.make_progress()?; - Ok(()) + match self + .engine + .parse_packet(packet) + .and_then(|_| self.make_progress()) + { + Ok(()) => Ok(()), + Err(e) if e.is_transient() => Ok(()), + Err(e) => Err(e), + } } pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> { diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index 52d9e91c..449d2e06 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -189,9 +189,15 @@ impl Server { } pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> { - self.engine.parse_packet(packet)?; - self.make_progress()?; - Ok(()) + match self + .engine + .parse_packet(packet) + .and_then(|_| self.make_progress()) + { + Ok(()) => Ok(()), + Err(e) if e.is_transient() => Ok(()), + Err(e) => Err(e), + } } pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> { diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index 4e93c8d7..7c87632d 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -219,9 +219,15 @@ impl Client { } pub fn handle_packet(&mut self, packet: &[u8]) -> Result<(), Error> { - self.engine.parse_packet(packet)?; - self.make_progress()?; - Ok(()) + match self + .engine + .parse_packet(packet) + .and_then(|_| self.make_progress()) + { + Ok(()) => Ok(()), + Err(e) if e.is_transient() => Ok(()), + Err(e) => Err(e), + } } pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> { diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 58689c87..1fb7f964 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -248,8 +248,15 @@ impl Server { self.retained_hello.push_back(packet.to_buf()); } - self.engine.parse_packet(packet)?; - self.make_progress()?; + match self + .engine + .parse_packet(packet) + .and_then(|_| self.make_progress()) + { + Ok(()) => {} + Err(e) if e.is_transient() => return Ok(()), + Err(e) => return Err(e), + } // Once past AwaitClientHello, DTLS 1.3 is committed — free the buffer. if self.auto_mode && self.state != State::AwaitClientHello { diff --git a/src/error.rs b/src/error.rs index 35a0582f..e9146263 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,10 +4,6 @@ #[non_exhaustive] /// Errors returned by DTLS processing functions. pub enum Error { - /// Parser requested more data - ParseIncomplete, - /// Parser encountered an error kind from nom - ParseError(nom::error::ErrorKind), /// Unexpected DTLS message UnexpectedMessage(String), /// Local state was missing data required for the requested operation @@ -30,8 +26,6 @@ pub enum Error { Timeout(&'static str), /// Configuration error (e.g., invalid crypto provider) ConfigError(String), - /// Too many records in a single packet - TooManyRecords, /// Peer attempted renegotiation (not supported) RenegotiationAttempt, /// Application data cannot be sent because the handshake is not yet complete. @@ -54,6 +48,24 @@ pub enum Error { /// value to communicate from dtls13/server.rs to lib.rs #[doc(hidden)] Dtls12Fallback, + /// Parser requested more data + #[doc(hidden)] + ParseIncomplete, + /// Parser encountered an error kind from nom + #[doc(hidden)] + ParseError(nom::error::ErrorKind), + /// Too many records in a single packet + #[doc(hidden)] + TooManyRecords, +} + +impl Error { + pub(crate) fn is_transient(&self) -> bool { + matches!( + self, + Error::ParseIncomplete | Error::ParseError(_) | Error::TooManyRecords + ) + } } impl<'a> From>> for Error { @@ -71,8 +83,6 @@ impl std::error::Error for Error {} impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Error::ParseIncomplete => write!(f, "parse incomplete"), - Error::ParseError(kind) => write!(f, "parse error: {:?}", kind), Error::UnexpectedMessage(msg) => write!(f, "unexpected message: {}", msg), Error::InvalidState(msg) => write!(f, "invalid state: {}", msg), Error::CryptoError(msg) => write!(f, "crypto error: {}", msg), @@ -84,7 +94,6 @@ impl std::fmt::Display for Error { Error::IncompleteServerHello => write!(f, "incomplete ServerHello"), Error::Timeout(what) => write!(f, "timeout: {}", what), Error::ConfigError(msg) => write!(f, "config error: {}", msg), - Error::TooManyRecords => write!(f, "too many records in packet"), Error::RenegotiationAttempt => write!(f, "peer attempted renegotiation"), Error::HandshakePending => { write!(f, "handshake pending: cannot send application data yet") @@ -94,6 +103,9 @@ impl std::fmt::Display for Error { Error::Dtls12Fallback => { write!(f, "dtls 1.2 fallback (internal)") } + Error::ParseIncomplete => write!(f, "parse incomplete"), + Error::ParseError(kind) => write!(f, "parse error: {:?}", kind), + Error::TooManyRecords => write!(f, "too many records in packet"), } } } diff --git a/src/lib.rs b/src/lib.rs index 60540976..5c70f8a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1281,7 +1281,7 @@ mod test { } #[test] - fn auto_server_rejects_ch_shaped_malformed_packet_without_fallback() { + fn auto_server_discards_ch_shaped_malformed_packet_without_fallback() { let body = ch_shaped_malformed_body(); let len = body.len() as u32; let hs = make_handshake(0x01, len, 0, len, &body); @@ -1292,14 +1292,8 @@ mod test { ); let mut dtls = new_instance_auto(); - let err = dtls - .handle_packet(&pkt) - .expect_err("malformed ClientHello-shaped packet must not force fallback"); - - assert!( - matches!(err, Error::ParseIncomplete | Error::ParseError(_)), - "expected parser error, got {err:?}" - ); + dtls.handle_packet(&pkt) + .expect("malformed ClientHello-shaped packet should be discarded"); assert!(matches!( dtls.inner, Some(Inner::Server13(ref server)) if server.is_auto_mode() @@ -1370,7 +1364,7 @@ mod test { } #[test] - fn dtls12_server_rejects_oversized_ec_point_formats_extension() { + fn dtls12_server_discards_oversized_ec_point_formats_extension() { let mut extensions = Vec::new(); extensions.extend_from_slice(&[0x00, 0x0B]); // ec_point_formats extensions.extend_from_slice(&[0x00, 0x05]); // extension body length @@ -1390,15 +1384,12 @@ mod test { let pkt = make_record(0x16, &hs); let mut dtls = new_instance_12_no_cookie(); - let err = dtls.handle_packet(&pkt).unwrap_err(); - assert!(matches!( - err, - Error::ParseError(nom::error::ErrorKind::LengthValue) - )); + dtls.handle_packet(&pkt) + .expect("malformed extension should be discarded"); } #[test] - fn dtls12_server_rejects_trailing_ec_point_formats_extension() { + fn dtls12_server_discards_trailing_ec_point_formats_extension() { let mut extensions = Vec::new(); extensions.extend_from_slice(&[0x00, 0x0B]); // ec_point_formats extensions.extend_from_slice(&[0x00, 0x03]); // extension body length @@ -1416,11 +1407,8 @@ mod test { let pkt = make_record(0x16, &hs); let mut dtls = new_instance_12_no_cookie(); - let err = dtls.handle_packet(&pkt).unwrap_err(); - assert!(matches!( - err, - Error::ParseError(nom::error::ErrorKind::LengthValue) - )); + dtls.handle_packet(&pkt) + .expect("malformed extension should be discarded"); } #[test] diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 7640007c..97bcecf3 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -64,7 +64,7 @@ fn dtls12_min_protected_fragment_len(suite: Dtls12CipherSuite) -> usize { #[test] #[cfg(feature = "rcgen")] -fn dtls12_malformed_datagram_does_not_process_alerts_before_parse_completes() { +fn dtls12_malformed_datagram_is_discarded_without_processing_alerts() { let _ = env_logger::try_init(); let server_cert = generate_self_signed_certificate().expect("gen server cert"); @@ -77,19 +77,17 @@ fn dtls12_malformed_datagram_does_not_process_alerts_before_parse_completes() { let mut packet = dtls12_alert_record(1, 2, 40); packet.push(0xFF); // trailing truncated record header - let err = server + server .handle_packet(&packet) - .expect_err("malformed datagram should fail atomically"); + .expect("malformed datagram should be discarded"); - assert!( - matches!(err, dimpl::Error::ParseIncomplete), - "expected ParseIncomplete, got {err:?}" - ); + let mut buf = [0; 1500]; + assert!(!matches!(server.poll_output(&mut buf), Output::CloseNotify)); } #[test] #[cfg(feature = "rcgen")] -fn dtls12_too_many_control_records_still_fail_before_filtering() { +fn dtls12_too_many_control_records_are_discarded() { let _ = env_logger::try_init(); let server_cert = generate_self_signed_certificate().expect("gen server cert"); @@ -104,14 +102,9 @@ fn dtls12_too_many_control_records_still_fail_before_filtering() { packet.extend_from_slice(&dtls12_alert_record(seq, 1, 0)); } - let err = server + server .handle_packet(&packet) - .expect_err("control-only datagram should still trip TooManyRecords"); - - assert!( - matches!(err, dimpl::Error::TooManyRecords), - "expected TooManyRecords, got {err:?}" - ); + .expect("too many records should be discarded"); } #[test] diff --git a/tests/dtls13/edge.rs b/tests/dtls13/edge.rs index 9e0664d0..c716964d 100644 --- a/tests/dtls13/edge.rs +++ b/tests/dtls13/edge.rs @@ -52,7 +52,7 @@ fn dtls13_ack_record_for_records(seq: u64, records: &[(u64, u64)]) -> Vec { #[test] #[cfg(feature = "rcgen")] -fn dtls13_malformed_datagram_does_not_process_alerts_before_parse_completes() { +fn dtls13_malformed_datagram_is_discarded_without_processing_alerts() { let _ = env_logger::try_init(); let server_cert = generate_self_signed_certificate().expect("gen server cert"); @@ -65,19 +65,17 @@ fn dtls13_malformed_datagram_does_not_process_alerts_before_parse_completes() { let mut packet = dtls13_alert_record(1, 2, 40); packet.push(0xFF); // trailing truncated record header - let err = server + server .handle_packet(&packet) - .expect_err("malformed datagram should fail atomically"); + .expect("malformed datagram should be discarded"); - assert!( - matches!(err, dimpl::Error::ParseIncomplete), - "expected ParseIncomplete, got {err:?}" - ); + let mut buf = [0; 1500]; + assert!(!matches!(server.poll_output(&mut buf), Output::CloseNotify)); } #[test] #[cfg(feature = "rcgen")] -fn dtls13_too_many_control_records_still_fail_before_filtering() { +fn dtls13_too_many_control_records_are_discarded() { let _ = env_logger::try_init(); let server_cert = generate_self_signed_certificate().expect("gen server cert"); @@ -92,14 +90,9 @@ fn dtls13_too_many_control_records_still_fail_before_filtering() { packet.extend_from_slice(&dtls13_ack_record(seq)); } - let err = server + server .handle_packet(&packet) - .expect_err("control-only datagram should still trip TooManyRecords"); - - assert!( - matches!(err, dimpl::Error::TooManyRecords), - "expected TooManyRecords, got {err:?}" - ); + .expect("too many records should be discarded"); } #[test] diff --git a/tests/dtls13_cookie.rs b/tests/dtls13_cookie.rs index 65e44633..ac1391f1 100644 --- a/tests/dtls13_cookie.rs +++ b/tests/dtls13_cookie.rs @@ -128,13 +128,9 @@ fn dtls13_client_rejects_hrr_cookie_extension_trailing_bytes() { "fixture should contain a Cookie extension" ); - let err = client + client .handle_packet(&hrr) - .expect_err("malformed HRR Cookie extension must be rejected"); - assert!(matches!( - err, - dimpl::Error::ParseError(nom::error::ErrorKind::LengthValue) - )); + .expect("malformed HRR Cookie extension should be discarded"); client .handle_timeout(now) @@ -192,11 +188,7 @@ fn dtls13_server_rejects_clienthello_cookie_extension_trailing_bytes() { "fixture should contain a Cookie extension" ); - let err = server + server .handle_packet(&ch2) - .expect_err("malformed ClientHello Cookie extension must be rejected"); - assert!(matches!( - err, - dimpl::Error::ParseError(nom::error::ErrorKind::LengthValue) - )); + .expect("malformed ClientHello Cookie extension should be discarded"); }