diff --git a/CHANGELOG.md b/CHANGELOG.md index e3bdafd..d91af1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased + * Defer oversized DTLS poll outputs instead of panicking #138 * Represent DTLS wire-code identifiers as compact newtypes (breaking) #137 * Make public errors structured and fatal-only (breaking) #134 diff --git a/README.md b/README.md index 88c16c6..94bdbc7 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ The output is an [`Output`][output] enum with borrowed references into your provided buffer: - `Packet(&[u8])`: send on your UDP socket - `Timeout(Instant)`: schedule a timer and call `handle_timeout` at/after it +- `BufferTooSmall { needed }`: grow the caller-owned poll buffer and retry - `Connected`: handshake complete - `PeerCert(&[u8])`: peer leaf certificate (DER) — validate in your app - `KeyingMaterial(KeyingMaterial, SrtpProfile)`: DTLS‑SRTP export @@ -114,6 +115,9 @@ fn example_event_loop(mut dtls: Dtls) -> Result<(), dimpl::Error> { match dtls.poll_output(&mut out_buf) { Output::Packet(p) => send_udp(p), Output::Timeout(t) => { next_wake = Some(t); break; } + Output::BufferTooSmall { needed } => { + out_buf.resize(needed, 0); + } Output::Connected => { // DTLS established — application may start sending } diff --git a/src/auto.rs b/src/auto.rs index 410b4bd..bc93180 100644 --- a/src/auto.rs +++ b/src/auto.rs @@ -320,10 +320,7 @@ impl ClientPending { if buf.len() < len { // Buffer too small; keep needs_send armed so the packet // is emitted on the next poll with a sufficiently large buffer. - let next = self - .retransmit_at - .unwrap_or(self.last_now + Duration::from_secs(1)); - return Output::Timeout(next); + return Output::BufferTooSmall { needed: len }; } self.needs_send = false; buf[..len].copy_from_slice(&self.wire_packet); diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs index 71ac7a7..eebfc93 100644 --- a/src/dtls12/client.rs +++ b/src/dtls12/client.rs @@ -199,7 +199,13 @@ impl Client { pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> { if let Some(event) = self.local_events.pop_front() { - return event.into_output(buf, &self.server_certificates); + return match event.into_output(buf, &self.server_certificates) { + Ok(output) => output, + Err((event, needed)) => { + self.local_events.push_front(event); + Output::BufferTooSmall { needed } + } + }; } self.engine.poll_output(buf, self.last_now) } @@ -1376,20 +1382,23 @@ fn handshake_create_certificate_verify(body: &mut Buf, engine: &mut Engine) -> R } impl LocalEvent { - pub fn into_output<'a>(self, buf: &'a mut [u8], peer_certs: &[Buf]) -> Output<'a> { + pub fn into_output<'a>( + self, + buf: &'a mut [u8], + peer_certs: &[Buf], + ) -> Result, (Self, usize)> { match self { LocalEvent::PeerCert => { let l = peer_certs[0].len(); - assert!( - l <= buf.len(), - "Output buffer too small for peer certificate" - ); + if l > buf.len() { + return Err((LocalEvent::PeerCert, l)); + } buf[..l].copy_from_slice(&peer_certs[0]); - Output::PeerCert(&buf[..l]) + Ok(Output::PeerCert(&buf[..l])) } - LocalEvent::Connected => Output::Connected, + LocalEvent::Connected => Ok(Output::Connected), LocalEvent::KeyingMaterial(m, profile) => { - Output::KeyingMaterial(KeyingMaterial::new(&m), profile) + Ok(Output::KeyingMaterial(KeyingMaterial::new(&m), profile)) } } } @@ -1466,4 +1475,31 @@ mod tests { assert!(matches!(err, Error::InvalidState(_))); } + + #[test] + fn peer_cert_output_is_deferred_when_buffer_is_too_small() { + let mut cert = Buf::new(); + cert.extend_from_slice(&[1, 2, 3, 4]); + let certs = [cert]; + + let mut out = [0u8; 2]; + let result = LocalEvent::PeerCert.into_output(&mut out, &certs); + + assert!(matches!(result, Err((LocalEvent::PeerCert, 4)))); + } + + #[test] + fn peer_cert_output_is_copied_when_buffer_fits() { + let mut cert = Buf::new(); + cert.extend_from_slice(&[1, 2, 3, 4]); + let certs = [cert]; + + let mut out = [0u8; 4]; + let result = LocalEvent::PeerCert.into_output(&mut out, &certs); + + match result { + Ok(Output::PeerCert(bytes)) => assert_eq!(bytes, &[1, 2, 3, 4]), + other => panic!("expected PeerCert output, got: {other:?}"), + } + } } diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index 3bd86e9..5cb6d0b 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -17,6 +17,12 @@ use crate::{Config, Error, InternalError, Output, SeededRng}; const MAX_DEFRAGMENT_PACKETS: usize = 50; +enum PollBuffer<'a> { + Ready(&'a [u8]), + Empty(&'a mut [u8]), + TooSmall { needed: usize }, +} + // Using debug_ignore_primary since CryptoContext doesn't implement Debug pub struct Engine { config: Arc, @@ -416,12 +422,15 @@ impl Engine { // First check if we have any decrypted app data. let buf = match self.poll_app_data(buf) { - Ok(p) => return Output::ApplicationData(p), - Err(b) => b, + PollBuffer::Ready(p) => return Output::ApplicationData(p), + PollBuffer::Empty(b) => b, + PollBuffer::TooSmall { needed } => return Output::BufferTooSmall { needed }, }; - if let Ok(p) = self.poll_packet_tx(buf) { - return Output::Packet(p); + match self.poll_packet_tx(buf) { + PollBuffer::Ready(p) => return Output::Packet(p), + PollBuffer::Empty(_) => {} + PollBuffer::TooSmall { needed } => return Output::BufferTooSmall { needed }, } if self.close_notify_received && !self.close_notify_reported { @@ -434,9 +443,9 @@ impl Engine { Output::Timeout(next_timeout) } - fn poll_app_data<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], &'a mut [u8]> { + fn poll_app_data<'a>(&mut self, buf: &'a mut [u8]) -> PollBuffer<'a> { if !self.release_app_data { - return Err(buf); + return PollBuffer::Empty(buf); } let mut unhandled = self @@ -447,24 +456,21 @@ impl Engine { .skip_while(|r| r.is_handled()); let Some(next) = unhandled.next() else { - return Err(buf); + return PollBuffer::Empty(buf); }; let record_buffer = next.buffer(); let fragment = next.record().fragment(record_buffer); let len = fragment.len(); - assert!( - len <= buf.len(), - "Output buffer too small for application data {} > {}", - len, - buf.len() - ); + if len > buf.len() { + return PollBuffer::TooSmall { needed: len }; + } buf[..len].copy_from_slice(fragment); next.set_handled(); - Ok(&buf[..len]) + PollBuffer::Ready(&buf[..len]) } fn purge_handled_queue_rx(&mut self) { @@ -482,22 +488,20 @@ impl Engine { } } - fn poll_packet_tx<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], &'a mut [u8]> { - let Some(p) = self.queue_tx.pop_front() else { - return Err(buf); + fn poll_packet_tx<'a>(&mut self, buf: &'a mut [u8]) -> PollBuffer<'a> { + let Some(p) = self.queue_tx.front() else { + return PollBuffer::Empty(buf); }; - assert!( - p.len() <= buf.len(), - "Output buffer too small for packet {} > {}", - p.len(), - buf.len() - ); - let len = p.len(); - buf[..len].copy_from_slice(&p); + if len > buf.len() { + return PollBuffer::TooSmall { needed: len }; + } + + buf[..len].copy_from_slice(p); + self.queue_tx.pop_front(); - Ok(&buf[..len]) + PollBuffer::Ready(&buf[..len]) } fn poll_timeout(&self, now: Instant) -> Instant { diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index 48e22f1..2fe3c5e 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -201,7 +201,13 @@ impl Server { pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> { if let Some(event) = self.local_events.pop_front() { - return event.into_output(buf, &self.client_certificates); + return match event.into_output(buf, &self.client_certificates) { + Ok(output) => output, + Err((event, needed)) => { + self.local_events.push_front(event); + Output::BufferTooSmall { needed } + } + }; } self.engine.poll_output(buf, self.last_now) } diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index a6c6063..2cbbded 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -231,7 +231,13 @@ impl Client { pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> { if let Some(event) = self.local_events.pop_front() { - return event.into_output(buf, &self.server_certificates); + return match event.into_output(buf, &self.server_certificates) { + Ok(output) => output, + Err((event, needed)) => { + self.local_events.push_front(event); + Output::BufferTooSmall { needed } + } + }; } self.engine.poll_output(buf, self.last_now) } @@ -1547,20 +1553,23 @@ fn parse_certificate_request(cr_data: &[u8], base_offset: usize) -> Result(self, buf: &'a mut [u8], peer_certs: &[Buf]) -> Output<'a> { + pub fn into_output<'a>( + self, + buf: &'a mut [u8], + peer_certs: &[Buf], + ) -> Result, (Self, usize)> { match self { LocalEvent::PeerCert => { let l = peer_certs[0].len(); - assert!( - l <= buf.len(), - "Output buffer too small for peer certificate" - ); + if l > buf.len() { + return Err((LocalEvent::PeerCert, l)); + } buf[..l].copy_from_slice(&peer_certs[0]); - Output::PeerCert(&buf[..l]) + Ok(Output::PeerCert(&buf[..l])) } - LocalEvent::Connected => Output::Connected, + LocalEvent::Connected => Ok(Output::Connected), LocalEvent::KeyingMaterial(m, profile) => { - Output::KeyingMaterial(KeyingMaterial::new(&m), profile) + Ok(Output::KeyingMaterial(KeyingMaterial::new(&m), profile)) } } } @@ -1697,4 +1706,31 @@ mod tests { )) )); } + + #[test] + fn peer_cert_output_is_deferred_when_buffer_is_too_small() { + let mut cert = Buf::new(); + cert.extend_from_slice(&[1, 2, 3, 4]); + let certs = [cert]; + + let mut out = [0u8; 2]; + let result = LocalEvent::PeerCert.into_output(&mut out, &certs); + + assert!(matches!(result, Err((LocalEvent::PeerCert, 4)))); + } + + #[test] + fn peer_cert_output_is_copied_when_buffer_fits() { + let mut cert = Buf::new(); + cert.extend_from_slice(&[1, 2, 3, 4]); + let certs = [cert]; + + let mut out = [0u8; 4]; + let result = LocalEvent::PeerCert.into_output(&mut out, &certs); + + match result { + Ok(Output::PeerCert(bytes)) => assert_eq!(bytes, &[1, 2, 3, 4]), + other => panic!("expected PeerCert output, got: {other:?}"), + } + } } diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index d67ccc3..0602e25 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -32,6 +32,12 @@ use crate::{Config, DtlsCertificate, Error, InternalError, Output, SeededRng}; const MAX_DEFRAGMENT_PACKETS: usize = 50; +enum PollBuffer<'a> { + Ready(&'a [u8]), + Empty(&'a mut [u8]), + TooSmall { needed: usize }, +} + /// Maximum DTLS sequence number (2^48 - 1). Per RFC 9147 §4.2, /// implementations MUST NOT allow the sequence number to wrap. const MAX_SEQUENCE_NUMBER: u64 = (1u64 << 48) - 1; @@ -529,14 +535,17 @@ impl Engine { self.purge_handled_queue_rx(); let buf = match self.poll_app_data(buf) { - Ok(p) => return Output::ApplicationData(p), - Err(b) => b, + PollBuffer::Ready(p) => return Output::ApplicationData(p), + PollBuffer::Empty(b) => b, + PollBuffer::TooSmall { needed } => return Output::BufferTooSmall { needed }, }; self.maybe_schedule_handshake_ack(now); - if let Ok(p) = self.poll_packet_tx(buf) { - return Output::Packet(p); + match self.poll_packet_tx(buf) { + PollBuffer::Ready(p) => return Output::Packet(p), + PollBuffer::Empty(_) => {} + PollBuffer::TooSmall { needed } => return Output::BufferTooSmall { needed }, } if self.close_notify_sequence.is_some() && !self.close_notify_reported { @@ -549,9 +558,9 @@ impl Engine { Output::Timeout(next_timeout) } - fn poll_app_data<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], &'a mut [u8]> { + fn poll_app_data<'a>(&mut self, buf: &'a mut [u8]) -> PollBuffer<'a> { if !self.release_app_data { - return Err(buf); + return PollBuffer::Empty(buf); } let mut unhandled = self @@ -562,24 +571,21 @@ impl Engine { .skip_while(|r| r.is_handled()); let Some(next) = unhandled.next() else { - return Err(buf); + return PollBuffer::Empty(buf); }; let record_buffer = next.buffer(); let fragment = next.record().fragment(record_buffer); let len = fragment.len(); - assert!( - len <= buf.len(), - "Output buffer too small for application data {} > {}", - len, - buf.len() - ); + if len > buf.len() { + return PollBuffer::TooSmall { needed: len }; + } buf[..len].copy_from_slice(fragment); next.set_handled(); - Ok(&buf[..len]) + PollBuffer::Ready(&buf[..len]) } fn purge_handled_queue_rx(&mut self) { @@ -597,22 +603,20 @@ impl Engine { } } - fn poll_packet_tx<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], &'a mut [u8]> { - let Some(p) = self.queue_tx.pop_front() else { - return Err(buf); + fn poll_packet_tx<'a>(&mut self, buf: &'a mut [u8]) -> PollBuffer<'a> { + let Some(p) = self.queue_tx.front() else { + return PollBuffer::Empty(buf); }; - assert!( - p.len() <= buf.len(), - "Output buffer too small for packet {} > {}", - p.len(), - buf.len() - ); - let len = p.len(); - buf[..len].copy_from_slice(&p); + if len > buf.len() { + return PollBuffer::TooSmall { needed: len }; + } + + buf[..len].copy_from_slice(p); + self.queue_tx.pop_front(); - Ok(&buf[..len]) + PollBuffer::Ready(&buf[..len]) } /// Prevent subsequent records from being appended to the current last diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 4711c8d..9ca7bc5 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -273,7 +273,13 @@ impl Server { pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> { if let Some(event) = self.local_events.pop_front() { - return event.into_output(buf, &self.client_certificates); + return match event.into_output(buf, &self.client_certificates) { + Ok(output) => output, + Err((event, needed)) => { + self.local_events.push_front(event); + Output::BufferTooSmall { needed } + } + }; } self.engine.poll_output(buf, self.last_now) } diff --git a/src/lib.rs b/src/lib.rs index e1680c0..6e369fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,6 +67,7 @@ //! references into your provided buffer: //! - `Packet(&[u8])`: send on your UDP socket //! - `Timeout(Instant)`: schedule a timer and call `handle_timeout` at/after it +//! - `BufferTooSmall { needed }`: grow the caller-owned poll buffer and retry //! - `Connected`: handshake complete //! - `PeerCert(&[u8])`: peer leaf certificate (DER) — validate in your app //! - `KeyingMaterial(KeyingMaterial, SrtpProfile)`: DTLS‑SRTP export @@ -114,6 +115,9 @@ //! match dtls.poll_output(&mut out_buf) { //! Output::Packet(p) => send_udp(p), //! Output::Timeout(t) => { next_wake = Some(t); break; } +//! Output::BufferTooSmall { needed } => { +//! out_buf.resize(needed, 0); +//! } //! Output::Connected => { //! // DTLS established — application may start sending //! } @@ -844,6 +848,14 @@ pub enum Output<'a> { /// This is always the last variant returned by a poll cycle. /// Internal state is only consistent after reaching `Timeout`. Timeout(Instant), + /// The caller-provided output buffer is too small for the next pending + /// output. Grow the buffer to at least `needed` bytes and call + /// [`Dtls::poll_output`] again without waiting for a timer. + BufferTooSmall { + /// Minimum caller-provided buffer length required to emit the pending + /// output. + needed: usize, + }, /// The handshake completed and the connection is established. Connected, /// The peer's leaf certificate in DER encoding. @@ -864,6 +876,7 @@ impl fmt::Debug for Output<'_> { match self { Self::Packet(v) => write!(f, "Packet({})", v.len()), Self::Timeout(v) => write!(f, "Timeout({:?})", v), + Self::BufferTooSmall { needed } => write!(f, "BufferTooSmall({})", needed), Self::Connected => write!(f, "Connected"), Self::PeerCert(v) => write!(f, "PeerCert({})", v.len()), Self::KeyingMaterial(v, p) => write!(f, "KeyingMaterial({}, {:?})", v.len(), p), diff --git a/tests/auto/common.rs b/tests/auto/common.rs index ef9babe..a7e82e6 100644 --- a/tests/auto/common.rs +++ b/tests/auto/common.rs @@ -25,7 +25,13 @@ pub fn collect_packets(endpoint: &mut Dtls) -> Vec> { let mut buf = vec![0u8; 2048]; loop { match endpoint.poll_output(&mut buf) { - Output::Packet(p) => out.push(p.to_vec()), + Output::Packet(p) => { + out.push(p.to_vec()); + buf.resize(2048, 0); + } + Output::BufferTooSmall { needed } => { + buf.resize(needed, 0); + } Output::Timeout(_) => break, _ => {} } @@ -39,14 +45,33 @@ pub fn drain_outputs(endpoint: &mut Dtls) -> DrainedOutputs { let mut buf = vec![0u8; 2048]; loop { match endpoint.poll_output(&mut buf) { - Output::Packet(p) => result.packets.push(p.to_vec()), - Output::Connected => result.connected = true, - Output::PeerCert(cert) => result.peer_cert = Some(cert.to_vec()), + Output::Packet(p) => { + result.packets.push(p.to_vec()); + buf.resize(2048, 0); + } + Output::Connected => { + result.connected = true; + buf.resize(2048, 0); + } + Output::PeerCert(cert) => { + result.peer_cert = Some(cert.to_vec()); + buf.resize(2048, 0); + } Output::KeyingMaterial(km, profile) => { result.keying_material = Some((km.to_vec(), profile)); + buf.resize(2048, 0); + } + Output::ApplicationData(data) => { + result.app_data.push(data.to_vec()); + buf.resize(2048, 0); + } + Output::CloseNotify => { + result.close_notify = true; + buf.resize(2048, 0); + } + Output::BufferTooSmall { needed } => { + buf.resize(needed, 0); } - Output::ApplicationData(data) => result.app_data.push(data.to_vec()), - Output::CloseNotify => result.close_notify = true, Output::Timeout(t) => { result.timeout = Some(t); break; diff --git a/tests/auto/handshake.rs b/tests/auto/handshake.rs index ea421e0..0e69b96 100644 --- a/tests/auto/handshake.rs +++ b/tests/auto/handshake.rs @@ -544,12 +544,13 @@ fn auto_client_poll_output_undersized_buffer() { // Poll with a buffer that is too small for the wire packet. // Before the fix this would panic with an index-out-of-bounds. let mut tiny_buf = [0u8; 4]; + let tiny_len = tiny_buf.len(); let output = client.poll_output(&mut tiny_buf); - // Should return Timeout (packet deferred), not a Packet. + // Should return BufferTooSmall (packet deferred), not a Packet or Timeout. assert!( - matches!(output, Output::Timeout(_)), - "undersized buffer should yield Timeout, got: {output:?}" + matches!(output, Output::BufferTooSmall { needed } if needed > tiny_len), + "undersized buffer should yield BufferTooSmall, got: {output:?}" ); // Now poll with a large buffer — the deferred packet should come through. diff --git a/tests/dtls12/common.rs b/tests/dtls12/common.rs index 7fc8710..f9b7934 100644 --- a/tests/dtls12/common.rs +++ b/tests/dtls12/common.rs @@ -112,6 +112,7 @@ pub struct DrainedOutputs { pub packets: Vec>, pub connected: bool, pub peer_cert: Option>, + pub peer_cert_deferred_for_small_buffer: bool, pub keying_material: Option<(Vec, SrtpProfile)>, pub app_data: Vec>, pub timeout: Option, @@ -120,18 +121,55 @@ pub struct DrainedOutputs { /// Poll until `Timeout`, collecting everything. pub fn drain_outputs(endpoint: &mut Dtls) -> DrainedOutputs { + drain_outputs_with_initial_buffer(endpoint, 2048) +} + +/// Poll until `Timeout`, collecting everything and growing the output buffer +/// when the engine reports that it is too small. +pub fn drain_outputs_with_initial_buffer( + endpoint: &mut Dtls, + initial_len: usize, +) -> DrainedOutputs { let mut result = DrainedOutputs::default(); - let mut buf = vec![0u8; 2048]; + let mut buf = vec![0u8; initial_len]; + let mut pending_too_small = None; loop { match endpoint.poll_output(&mut buf) { - Output::Packet(p) => result.packets.push(p.to_vec()), - Output::Connected => result.connected = true, - Output::PeerCert(cert) => result.peer_cert = Some(cert.to_vec()), + Output::Packet(p) => { + pending_too_small = None; + result.packets.push(p.to_vec()); + buf.resize(initial_len, 0); + } + Output::Connected => { + pending_too_small = None; + result.connected = true; + buf.resize(initial_len, 0); + } + Output::PeerCert(cert) => { + result.peer_cert_deferred_for_small_buffer |= pending_too_small == Some(cert.len()); + pending_too_small = None; + result.peer_cert = Some(cert.to_vec()); + buf.resize(initial_len, 0); + } Output::KeyingMaterial(km, profile) => { + pending_too_small = None; result.keying_material = Some((km.to_vec(), profile)); + buf.resize(initial_len, 0); + } + Output::ApplicationData(data) => { + pending_too_small = None; + result.app_data.push(data.to_vec()); + buf.resize(initial_len, 0); + } + Output::CloseNotify => { + pending_too_small = None; + result.close_notify = true; + buf.resize(initial_len, 0); + } + Output::BufferTooSmall { needed } => { + pending_too_small = Some(needed); + buf.resize(needed, 0); } - Output::ApplicationData(data) => result.app_data.push(data.to_vec()), - Output::CloseNotify => result.close_notify = true, Output::Timeout(t) => { result.timeout = Some(t); break; diff --git a/tests/dtls12/data.rs b/tests/dtls12/data.rs index 1ba23b5..2e96f13 100644 --- a/tests/dtls12/data.rs +++ b/tests/dtls12/data.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; -use dimpl::Dtls; +use dimpl::{Dtls, Output}; use crate::common::*; @@ -95,6 +95,46 @@ fn dtls12_application_data_exchange() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_poll_output_small_buffers_defer_packet_and_app_data() { + let _ = env_logger::try_init(); + + let now = Instant::now(); + let (mut client, mut server, _) = setup_connected_12_pair(now); + + client.send_application_data(b"hello").expect("client send"); + + let mut tiny_packet_buf = [0u8; 4]; + let tiny_packet_len = tiny_packet_buf.len(); + let output = client.poll_output(&mut tiny_packet_buf); + assert!( + matches!(output, Output::BufferTooSmall { needed } if needed > tiny_packet_len), + "undersized packet buffer should yield BufferTooSmall, got: {output:?}" + ); + + let mut packet_buf = vec![0u8; 2048]; + let packet = match client.poll_output(&mut packet_buf) { + Output::Packet(packet) => packet.to_vec(), + output => panic!("large buffer should yield Packet, got: {output:?}"), + }; + server.handle_packet(&packet).expect("server handle packet"); + + let mut tiny_app_buf = [0u8; 2]; + let expected_app_len = b"hello".len(); + let output = server.poll_output(&mut tiny_app_buf); + assert!( + matches!(output, Output::BufferTooSmall { needed } if needed == expected_app_len), + "undersized app-data buffer should yield BufferTooSmall, got: {output:?}" + ); + + let mut app_buf = [0u8; 64]; + match server.poll_output(&mut app_buf) { + Output::ApplicationData(data) => assert_eq!(data, b"hello"), + output => panic!("large buffer should yield ApplicationData, got: {output:?}"), + } +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_multiple_application_data_messages() { diff --git a/tests/dtls12/handshake.rs b/tests/dtls12/handshake.rs index a6f1990..7a25445 100644 --- a/tests/dtls12/handshake.rs +++ b/tests/dtls12/handshake.rs @@ -583,6 +583,70 @@ fn dtls12_peer_certificate_exchange() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_peer_certificate_output_retries_after_small_buffer() { + use dimpl::certificate::generate_self_signed_certificate; + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let expected_client_cert = client_cert.certificate.clone(); + let expected_server_cert = server_cert.certificate.clone(); + + let config = dtls12_config(); + let mut now = Instant::now(); + + let mut client = Dtls::new_12(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_12(config, server_cert, now); + server.set_active(false); + + let mut client_peer_cert = None; + let mut server_peer_cert = None; + let mut client_cert_deferred = false; + let mut server_cert_deferred = false; + + for _ in 0..30 { + client.handle_timeout(now).expect("client timeout"); + server.handle_timeout(now).expect("server timeout"); + + let client_out = drain_outputs_with_initial_buffer(&mut client, 1); + let server_out = drain_outputs_with_initial_buffer(&mut server, 1); + + client_cert_deferred |= client_out.peer_cert_deferred_for_small_buffer; + server_cert_deferred |= server_out.peer_cert_deferred_for_small_buffer; + + if let Some(cert) = client_out.peer_cert { + client_peer_cert = Some(cert); + } + if let Some(cert) = server_out.peer_cert { + server_peer_cert = Some(cert); + } + + deliver_packets(&client_out.packets, &mut server); + deliver_packets(&server_out.packets, &mut client); + + if client_peer_cert.is_some() && server_peer_cert.is_some() { + break; + } + + now += Duration::from_millis(10); + } + + assert_eq!(client_peer_cert, Some(expected_server_cert)); + assert_eq!(server_peer_cert, Some(expected_client_cert)); + assert!( + client_cert_deferred, + "client PeerCert should be deferred once" + ); + assert!( + server_cert_deferred, + "server PeerCert should be deferred once" + ); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_handshake_client_certificate_auth() { diff --git a/tests/dtls13/common.rs b/tests/dtls13/common.rs index c452df0..b486054 100644 --- a/tests/dtls13/common.rs +++ b/tests/dtls13/common.rs @@ -13,6 +13,7 @@ pub struct DrainedOutputs { pub packets: Vec>, pub connected: bool, pub peer_cert: Option>, + pub peer_cert_deferred_for_small_buffer: bool, pub keying_material: Option<(Vec, SrtpProfile)>, pub app_data: Vec>, pub timeout: Option, @@ -35,18 +36,55 @@ pub fn collect_packets(endpoint: &mut Dtls) -> Vec> { /// Poll until `Timeout`, collecting everything. pub fn drain_outputs(endpoint: &mut Dtls) -> DrainedOutputs { + drain_outputs_with_initial_buffer(endpoint, 2048) +} + +/// Poll until `Timeout`, collecting everything and growing the output buffer +/// when the engine reports that it is too small. +pub fn drain_outputs_with_initial_buffer( + endpoint: &mut Dtls, + initial_len: usize, +) -> DrainedOutputs { let mut result = DrainedOutputs::default(); - let mut buf = vec![0u8; 2048]; + let mut buf = vec![0u8; initial_len]; + let mut pending_too_small = None; loop { match endpoint.poll_output(&mut buf) { - Output::Packet(p) => result.packets.push(p.to_vec()), - Output::Connected => result.connected = true, - Output::PeerCert(cert) => result.peer_cert = Some(cert.to_vec()), + Output::Packet(p) => { + pending_too_small = None; + result.packets.push(p.to_vec()); + buf.resize(initial_len, 0); + } + Output::Connected => { + pending_too_small = None; + result.connected = true; + buf.resize(initial_len, 0); + } + Output::PeerCert(cert) => { + result.peer_cert_deferred_for_small_buffer |= pending_too_small == Some(cert.len()); + pending_too_small = None; + result.peer_cert = Some(cert.to_vec()); + buf.resize(initial_len, 0); + } Output::KeyingMaterial(km, profile) => { + pending_too_small = None; result.keying_material = Some((km.to_vec(), profile)); + buf.resize(initial_len, 0); + } + Output::ApplicationData(data) => { + pending_too_small = None; + result.app_data.push(data.to_vec()); + buf.resize(initial_len, 0); + } + Output::CloseNotify => { + pending_too_small = None; + result.close_notify = true; + buf.resize(initial_len, 0); + } + Output::BufferTooSmall { needed } => { + pending_too_small = Some(needed); + buf.resize(needed, 0); } - Output::ApplicationData(data) => result.app_data.push(data.to_vec()), - Output::CloseNotify => result.close_notify = true, Output::Timeout(t) => { result.timeout = Some(t); break; diff --git a/tests/dtls13/data.rs b/tests/dtls13/data.rs index 002ef64..167c840 100644 --- a/tests/dtls13/data.rs +++ b/tests/dtls13/data.rs @@ -95,6 +95,46 @@ fn dtls13_application_data_exchange() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_poll_output_small_buffers_defer_packet_and_app_data() { + let _ = env_logger::try_init(); + + let now = Instant::now(); + let (mut client, mut server, _) = setup_connected_13_pair(now); + + client.send_application_data(b"hello").expect("client send"); + + let mut tiny_packet_buf = [0u8; 4]; + let tiny_packet_len = tiny_packet_buf.len(); + let output = client.poll_output(&mut tiny_packet_buf); + assert!( + matches!(output, Output::BufferTooSmall { needed } if needed > tiny_packet_len), + "undersized packet buffer should yield BufferTooSmall, got: {output:?}" + ); + + let mut packet_buf = vec![0u8; 2048]; + let packet = match client.poll_output(&mut packet_buf) { + Output::Packet(packet) => packet.to_vec(), + output => panic!("large buffer should yield Packet, got: {output:?}"), + }; + server.handle_packet(&packet).expect("server handle packet"); + + let mut tiny_app_buf = [0u8; 2]; + let expected_app_len = b"hello".len(); + let output = server.poll_output(&mut tiny_app_buf); + assert!( + matches!(output, Output::BufferTooSmall { needed } if needed == expected_app_len), + "undersized app-data buffer should yield BufferTooSmall, got: {output:?}" + ); + + let mut app_buf = [0u8; 64]; + match server.poll_output(&mut app_buf) { + Output::ApplicationData(data) => assert_eq!(data, b"hello"), + output => panic!("large buffer should yield ApplicationData, got: {output:?}"), + } +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_multiple_application_data_messages() { diff --git a/tests/dtls13/handshake.rs b/tests/dtls13/handshake.rs index 1799bf7..26e9858 100644 --- a/tests/dtls13/handshake.rs +++ b/tests/dtls13/handshake.rs @@ -200,6 +200,72 @@ fn dtls13_peer_certificate_exchange() { ); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_peer_certificate_output_retries_after_small_buffer() { + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + + let expected_client_cert = client_cert.certificate.clone(); + let expected_server_cert = server_cert.certificate.clone(); + + let config = dtls13_config(); + let mut now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + let mut client_peer_cert = None; + let mut server_peer_cert = None; + let mut client_cert_deferred = false; + let mut server_cert_deferred = false; + + for _ in 0..30 { + client.handle_timeout(now).expect("client timeout"); + server.handle_timeout(now).expect("server timeout"); + + let client_out = drain_outputs_with_initial_buffer(&mut client, 1); + let server_out = drain_outputs_with_initial_buffer(&mut server, 1); + + client_cert_deferred |= client_out.peer_cert_deferred_for_small_buffer; + server_cert_deferred |= server_out.peer_cert_deferred_for_small_buffer; + + if let Some(cert) = client_out.peer_cert { + client_peer_cert = Some(cert); + } + if let Some(cert) = server_out.peer_cert { + server_peer_cert = Some(cert); + } + + deliver_packets(&client_out.packets, &mut server); + deliver_packets(&server_out.packets, &mut client); + + if client_peer_cert.is_some() && server_peer_cert.is_some() { + break; + } + + now += Duration::from_millis(10); + } + + assert_eq!(client_peer_cert, Some(expected_server_cert)); + assert_eq!(server_peer_cert, Some(expected_client_cert)); + assert!( + client_cert_deferred, + "client PeerCert should be deferred once" + ); + assert!( + server_cert_deferred, + "server PeerCert should be deferred once" + ); +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_srtp_keying_material_correct_size() {