diff --git a/Cargo.lock b/Cargo.lock index af5a830..526cbf6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -66,6 +66,18 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + [[package]] name = "http" version = "1.3.1" @@ -77,16 +89,43 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + [[package]] name = "http-handler" version = "1.0.0" dependencies = [ "bytes", + "futures-core", "http", + "http-body", + "http-body-util", "napi", "napi-build", "napi-derive", "tokio", + "tokio-util", ] [[package]] @@ -117,6 +156,7 @@ dependencies = [ "napi-sys", "nohash-hasher", "rustc-hash", + "tokio", ] [[package]] @@ -220,6 +260,7 @@ version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" dependencies = [ + "bytes", "pin-project-lite", "tokio-macros", ] @@ -235,6 +276,19 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-util" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "unicode-ident" version = "1.0.20" diff --git a/Cargo.toml b/Cargo.toml index c516d59..e421a79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,14 +24,13 @@ napi-build = { version = "2", optional = true } [dependencies] bytes = "1.10.1" http = "1.0" -# http-body = "1.0" -# http-rewriter = { path = "../http-rewriter" } -# napi = { path = "../napi-rs/crates/napi", features = ["napi4"], optional = true } -# napi-derive = { path = "../napi-rs/crates/macro", optional = true } -napi = { version = "3", features = ["napi4"], optional = true } +tokio = { version = "1.45.1", features = ["sync", "macros", "rt", "io-util"] } +tokio-util = { version = "0.7", features = ["codec"] } +http-body = "1.0" +http-body-util = "0.1" +futures-core = "0.3" +napi = { version = "3", features = ["napi4", "tokio_rt", "async"], optional = true } napi-derive = { version = "3", optional = true } -# napi = { version = "2.12.2", default-features = false, features = ["napi4"], optional = true } -# napi-derive = { version = "2.12.2", optional = true } [dev-dependencies] tokio = { version = "1.45.1", features = ["rt-multi-thread", "macros"] } diff --git a/src/body.rs b/src/body.rs new file mode 100644 index 0000000..cf9a681 --- /dev/null +++ b/src/body.rs @@ -0,0 +1,352 @@ +use std::{ + fmt, io, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use bytes::{Bytes, BytesMut}; +use futures_core::Stream; +use http_body::{Body, Frame}; +use tokio::{ + io::{AsyncRead, AsyncWrite, DuplexStream}, + sync::Mutex, +}; + +/// Error type for stream operations +#[derive(Debug, Clone)] +pub enum StreamError { + /// The stream has been closed and cannot accept more data + StreamClosed, + /// The stream receiver has already been consumed and cannot be taken again + StreamAlreadyConsumed, + /// An I/O error occurred + IoError(String), +} + +impl fmt::Display for StreamError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + StreamError::StreamClosed => write!(f, "Stream closed"), + StreamError::StreamAlreadyConsumed => write!(f, "Stream already consumed"), + StreamError::IoError(msg) => write!(f, "I/O error: {}", msg), + } + } +} + +impl std::error::Error for StreamError {} + +impl From for StreamError { + fn from(err: io::Error) -> Self { + StreamError::IoError(err.to_string()) + } +} + +/// Request body with duplex stream for bidirectional I/O +/// +/// This type holds both halves of a duplex stream pair. One half is used for polling +/// (by the handler), and the other half is accessible via `stream()` for external writes. +/// +/// # Cloning Behavior +/// +/// RequestBody is clonable, and clones share the same underlying streams via Arc. +/// This allows NAPI to clone Request objects while preserving the streams. +#[derive(Debug)] +pub struct RequestBody { + // The half used for polling/reading by the handler + read_side: Arc>, + // The half used by external code to write data into the body + write_side: Arc>, + buffer_size: usize, +} + +impl RequestBody { + /// Create a new request body with specified buffer size + pub fn new_with_buffer_size(buffer_size: usize) -> Self { + let (read_side, write_side) = tokio::io::duplex(buffer_size); + + Self { + read_side: Arc::new(Mutex::new(read_side)), + write_side: Arc::new(Mutex::new(write_side)), + buffer_size, + } + } + + /// Create a new request body with default buffer size (16KB) + pub fn new() -> Self { + Self::new_with_buffer_size(16384) + } + + /// Create from buffered data (writes data to stream immediately) + pub async fn from_data(data: Bytes) -> Result { + let body = Self::new(); + + // Write data to the write side of the stream + use tokio::io::AsyncWriteExt; + let mut stream = body.write_side.lock().await; + stream.write_all(&data).await?; + stream.shutdown().await?; + drop(stream); + + Ok(body) + } + + /// Get the buffer size for this request body + pub fn buffer_size(&self) -> usize { + self.buffer_size + } + + /// Create response body with the same buffer size + /// Returns a new ResponseBody that uses a separate duplex stream + pub fn create_response(&self) -> ResponseBody { + ResponseBody::new_with_buffer_size(self.buffer_size) + } +} + +impl Default for RequestBody { + fn default() -> Self { + Self::new() + } +} + +impl Clone for RequestBody { + fn clone(&self) -> Self { + Self { + read_side: Arc::clone(&self.read_side), + write_side: Arc::clone(&self.write_side), + buffer_size: self.buffer_size, + } + } +} + +impl AsyncRead for RequestBody { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let mut stream = match self.read_side.try_lock() { + Ok(guard) => guard, + Err(_) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + Pin::new(&mut *stream).poll_read(cx, buf) + } +} + +impl AsyncWrite for RequestBody { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut stream = match self.write_side.try_lock() { + Ok(guard) => guard, + Err(_) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + Pin::new(&mut *stream).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut stream = match self.write_side.try_lock() { + Ok(guard) => guard, + Err(_) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + Pin::new(&mut *stream).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut stream = match self.write_side.try_lock() { + Ok(guard) => guard, + Err(_) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + Pin::new(&mut *stream).poll_shutdown(cx) + } +} + +/// Response body with duplex stream for bidirectional I/O +/// +/// This type holds both halves of a duplex stream pair and implements `http-body::Body`. +/// One half is used for polling (reading frames), and the other half is accessible via +/// `stream()` for external writes (e.g., handler writing response data). +/// +/// # Cloning Behavior +/// +/// ResponseBody is clonable, and clones share the same underlying streams via Arc. +/// This allows NAPI to clone Response objects. The `poll_frame` implementation handles +/// concurrent access via the mutex. +/// +/// ## Reading Frames +/// To read frames from this body, use `BodyExt::frame()` from http-body-util. +#[derive(Debug)] +pub struct ResponseBody { + // The half used for polling/reading frames + read_side: Arc>, + // The half used by handlers to write response data + write_side: Arc>, + buffer_size: usize, +} + +impl ResponseBody { + /// Create a new response body with specified buffer size + pub fn new_with_buffer_size(buffer_size: usize) -> Self { + let (read_side, write_side) = tokio::io::duplex(buffer_size); + + Self { + read_side: Arc::new(Mutex::new(read_side)), + write_side: Arc::new(Mutex::new(write_side)), + buffer_size, + } + } + + /// Create a new response body with default buffer size (16KB) + pub fn new() -> Self { + Self::new_with_buffer_size(16384) + } + + /// Get the buffer size for this response body + pub fn buffer_size(&self) -> usize { + self.buffer_size + } +} + +impl Default for ResponseBody { + fn default() -> Self { + Self::new() + } +} + +impl Clone for ResponseBody { + fn clone(&self) -> Self { + Self { + read_side: Arc::clone(&self.read_side), + write_side: Arc::clone(&self.write_side), + buffer_size: self.buffer_size, + } + } +} + +impl AsyncRead for ResponseBody { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let mut stream = match self.read_side.try_lock() { + Ok(guard) => guard, + Err(_) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + Pin::new(&mut *stream).poll_read(cx, buf) + } +} + +impl AsyncWrite for ResponseBody { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut stream = match self.write_side.try_lock() { + Ok(guard) => guard, + Err(_) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + Pin::new(&mut *stream).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut stream = match self.write_side.try_lock() { + Ok(guard) => guard, + Err(_) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + Pin::new(&mut *stream).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut stream = match self.write_side.try_lock() { + Ok(guard) => guard, + Err(_) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + }; + Pin::new(&mut *stream).poll_shutdown(cx) + } +} + +impl Body for ResponseBody { + type Data = Bytes; + type Error = String; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + // Try to read data from the stream + let mut buffer = BytesMut::with_capacity(8192); + unsafe { + buffer.set_len(8192); + } + + let mut read_buf = tokio::io::ReadBuf::new(&mut buffer); + let initial_filled = read_buf.filled().len(); + + match self.as_mut().poll_read(cx, &mut read_buf) { + Poll::Ready(Ok(())) => { + let filled = read_buf.filled().len(); + if filled == initial_filled { + // EOF reached + Poll::Ready(None) + } else { + // Data was read + buffer.truncate(filled); + Poll::Ready(Some(Ok(Frame::data(buffer.freeze())))) + } + } + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e.to_string()))), + Poll::Pending => Poll::Pending, + } + } +} + +/// Implement Stream for ResponseBody to enable async iteration in Rust +impl Stream for ResponseBody { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Use poll_frame and extract data + match self.poll_frame(cx) { + Poll::Ready(Some(Ok(frame))) => { + if let Ok(data) = frame.into_data() { + Poll::Ready(Some(Ok(data))) + } else { + // Frame was not data (e.g., trailers) - skip it + cx.waker().wake_by_ref(); + Poll::Pending + } + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/src/extensions.rs b/src/extensions.rs index b444ac5..f24a199 100644 --- a/src/extensions.rs +++ b/src/extensions.rs @@ -5,7 +5,9 @@ use std::{ net::SocketAddr, ops::{Deref, DerefMut}, path::{Path, PathBuf}, + sync::Arc, }; +use tokio::sync::Mutex; /// Socket information for a request #[derive(Clone, Debug, Default, PartialEq, Eq)] @@ -75,6 +77,47 @@ impl From for DocumentRoot { } } +/// WebSocket mode marker for a request/response +/// +/// This extension indicates that the request/response should be treated as a WebSocket +/// connection, where each write() call represents a complete WebSocket message rather +/// than HTTP chunks. +/// +/// The presence of this extension in the request/response extensions indicates WebSocket mode is enabled. +/// To check if WebSocket mode is enabled, use: `request.extensions().get::().is_some()` +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct WebSocketMode; + +/// WebSocket decoder state for response body decoding +/// +/// This extension stores a persistent WebSocketCodec and buffer that are used across multiple +/// Response::next() calls to properly decode WebSocket frames from HTTP body data. +#[derive(Clone)] +pub struct WebSocketDecoderState { + codec: Arc>, + buffer: Arc>, +} + +impl WebSocketDecoderState { + /// Create a new WebSocketDecoderState + pub fn new() -> Self { + Self { + codec: Arc::new(Mutex::new(crate::websocket::WebSocketCodec::new())), + buffer: Arc::new(Mutex::new(BytesMut::with_capacity(8192))), + } + } + + /// Get a reference to the codec + pub fn codec(&self) -> &Arc> { + &self.codec + } + + /// Get a reference to the buffer + pub fn buffer(&self) -> &Arc> { + &self.buffer + } +} + /// Response log buffer #[derive(Clone, Debug, Default)] pub struct ResponseLog { diff --git a/src/handler.rs b/src/handler.rs index c8d241b..826c2ae 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -9,19 +9,29 @@ //! ## Basic handler implementation //! //! ``` -//! use http_handler::Handler; -//! use bytes::BytesMut; +//! use http_handler::{Handler, Request, Response}; +//! use bytes::Bytes; +//! use tokio::io::AsyncWriteExt; //! //! struct HelloHandler; //! //! impl Handler for HelloHandler { //! type Error = std::convert::Infallible; //! -//! async fn handle(&self, _request: http::Request) -> Result, Self::Error> { +//! async fn handle(&self, request: Request) -> Result { +//! let (_parts, body) = request.into_parts(); +//! let response_body = body.create_response(); +//! +//! let mut response_writer = response_body.clone(); +//! tokio::spawn(async move { +//! let _ = response_writer.write_all(b"Hello, World!").await; +//! let _ = response_writer.shutdown().await; +//! }); +//! //! Ok(http::Response::builder() //! .status(200) //! .header("Content-Type", "text/plain") -//! .body(BytesMut::from("Hello, World!")) +//! .body(response_body) //! .unwrap()) //! } //! } @@ -30,8 +40,9 @@ //! ## Handler composition //! //! ``` -//! use http_handler::Handler; -//! use bytes::BytesMut; +//! use http_handler::{Handler, Request, Response, RequestBody, ResponseBody}; +//! use bytes::Bytes; +//! use tokio::io::AsyncWriteExt; //! //! // Middleware that adds a header //! struct AddHeaderHandler { @@ -40,14 +51,16 @@ //! header_value: &'static str, //! } //! -//! impl Handler for AddHeaderHandler +//! impl Handler for AddHeaderHandler //! where -//! H: Handler + std::marker::Sync, -//! B: std::marker::Send +//! H: Handler + std::marker::Sync, //! { //! type Error = H::Error; //! -//! async fn handle(&self, request: http::Request) -> Result, Self::Error> { +//! async fn handle( +//! &self, +//! request: http::Request +//! ) -> Result, Self::Error> { //! let mut response = self.inner.handle(request).await?; //! response.headers_mut().insert( //! self.header_name, @@ -62,10 +75,19 @@ //! //! impl Handler for ApiHandler { //! type Error = std::convert::Infallible; -//! async fn handle(&self, _req: http::Request) -> Result, Self::Error> { +//! async fn handle(&self, request: Request) -> Result { +//! let (_parts, body) = request.into_parts(); +//! let response_body = body.create_response(); +//! +//! let mut response_writer = response_body.clone(); +//! tokio::spawn(async move { +//! let _ = response_writer.write_all(br#"{"status": "ok"}"#).await; +//! let _ = response_writer.shutdown().await; +//! }); +//! //! Ok(http::Response::builder() //! .status(200) -//! .body(BytesMut::from(r#"{"status": "ok"}"#)) +//! .body(response_body) //! .unwrap()) //! } //! } @@ -77,53 +99,35 @@ //! }; //! ``` -use bytes::BytesMut; - /// Trait for types that can handle HTTP requests and produce responses /// -/// The handler trait is generic over the request body type `B`, allowing -/// handlers to work with different body representations such as `Bytes`, -/// `String`, streaming bodies, or custom types. -/// -/// The response body type is fixed to `Bytes` for simplicity, but handlers -/// can be composed with body transformers if different response types are needed. +/// The handler trait works with duplex stream-based request and response bodies, +/// providing efficient bidirectional I/O with configurable buffer sizes for +/// backpressure control. /// /// # Examples /// -/// ## Handler for Bytes body (default) +/// ## Basic handler /// /// ``` -/// use http_handler::Handler; -/// use bytes::BytesMut; +/// use http_handler::{Handler, Request, Response}; +/// use bytes::Bytes; +/// use tokio::io::AsyncWriteExt; /// /// struct MyHandler; /// /// impl Handler for MyHandler { /// type Error = std::convert::Infallible; /// -/// async fn handle(&self, request: http::Request) -> Result, Self::Error> { -/// Ok(http::Response::builder() -/// .status(200) -/// .body(BytesMut::from("Hello, World!")) -/// .unwrap()) -/// } -/// } -/// ``` +/// async fn handle(&self, request: Request) -> Result { +/// let (_parts, body) = request.into_parts(); +/// let response_body = body.create_response(); /// -/// ## Handler for String body -/// -/// ``` -/// use http_handler::Handler; -/// use bytes::BytesMut; -/// -/// struct StringHandler; -/// -/// impl Handler for StringHandler { -/// type Error = std::convert::Infallible; -/// -/// async fn handle(&self, request: http::Request) -> Result, Self::Error> { -/// let body = request.body(); -/// let response_body = format!("You sent: {}", body); +/// let mut response_writer = response_body.clone(); +/// tokio::spawn(async move { +/// let _ = response_writer.write_all(b"Hello, World!").await; +/// let _ = response_writer.shutdown().await; +/// }); /// /// Ok(http::Response::builder() /// .status(200) @@ -132,12 +136,16 @@ use bytes::BytesMut; /// } /// } /// ``` -pub trait Handler { +pub trait Handler { /// The error type returned by the handler type Error; /// Handle an HTTP request and produce a response - async fn handle(&self, request: http::Request) -> Result, Self::Error>; + #[allow(async_fn_in_trait)] + async fn handle( + &self, + request: http::Request, + ) -> Result, Self::Error>; } #[cfg(test)] @@ -145,54 +153,91 @@ mod tests { use super::*; use crate::extensions::SocketInfo; use crate::extensions::{RequestExt, ResponseExt}; - use bytes::Bytes; + use bytes::{Bytes, BytesMut}; + use http_body_util::BodyExt; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; /// Example handler that echoes the request body pub struct EchoHandler; - impl Handler for EchoHandler { + impl Handler for EchoHandler { type Error = http::Error; - async fn handle( - &self, - request: http::Request, - ) -> Result, Self::Error> { - http::Response::builder() - .status(200) - .body(request.body().clone()) + async fn handle(&self, request: crate::Request) -> Result { + let (_parts, mut body) = request.into_parts(); + let response_body = body.create_response(); + + // Spawn task to echo request to response + let mut response_writer = response_body.clone(); + tokio::spawn(async move { + use tokio::io::AsyncReadExt; + use tokio::io::AsyncWriteExt; + let mut buffer = vec![0u8; 8192]; + loop { + let n = body.read(&mut buffer).await.unwrap_or(0); + if n == 0 { + break; + } + let _ = response_writer.write_all(&buffer[..n]).await; + } + let _ = response_writer.shutdown().await; + }); + + http::Response::builder().status(200).body(response_body) } } #[tokio::test] async fn test_echo_handler() { let handler = EchoHandler; - let request = http::Request::builder() - .uri("/echo") - .body(Bytes::from("Hello, world!")) + let body = crate::RequestBody::from_data(Bytes::from("Hello, world!")) + .await .unwrap(); + let request = http::Request::builder().uri("/echo").body(body).unwrap(); let response = handler.handle(request).await.unwrap(); assert_eq!(response.status(), 200); - assert_eq!(response.body(), &Bytes::from("Hello, world!")); + + // Read the response body + let (_, mut response_body) = response.into_parts(); + let mut collected = BytesMut::new(); + while let Some(result) = response_body.frame().await { + match result { + Ok(frame) => { + if let Ok(data) = frame.into_data() { + collected.extend_from_slice(&data); + } + } + Err(_) => break, + } + } + assert_eq!(&collected[..], b"Hello, world!"); } /// Test handler that adds logging struct LoggingHandler; - impl Handler for LoggingHandler { + impl Handler for LoggingHandler { type Error = String; - async fn handle( - &self, - request: http::Request, - ) -> Result, Self::Error> { - let method = request.method(); - let uri = request.uri(); + async fn handle(&self, request: crate::Request) -> Result { + let method = request.method().clone(); + let uri = request.uri().clone(); + let (_, body) = request.into_parts(); + + let response_body = body.create_response(); + + // Send OK response + let mut response_writer = response_body.clone(); + tokio::spawn(async move { + use tokio::io::AsyncWriteExt; + let _ = response_writer.write_all(b"OK").await; + let _ = response_writer.shutdown().await; + }); let mut response = http::Response::builder() .status(200) - .body(Bytes::from("OK")) + .body(response_body) .unwrap(); response.append_log(format!("{} {}", method, uri)); @@ -204,42 +249,63 @@ mod tests { #[tokio::test] async fn test_logging_handler() { let handler = LoggingHandler; + let body = crate::RequestBody::new(); let request = http::Request::builder() .method("POST") .uri("/api/users") - .body(Bytes::new()) + .body(body) .unwrap(); let response = handler.handle(request).await.unwrap(); assert_eq!(response.status(), 200); - assert_eq!(response.body(), &Bytes::from("OK")); let log = response.log().unwrap(); assert_eq!(log.as_bytes(), b"POST /api/users\n"); + + // Read the response body + let (_, mut response_body) = response.into_parts(); + let mut collected = BytesMut::new(); + while let Some(result) = response_body.frame().await { + match result { + Ok(frame) => { + if let Ok(data) = frame.into_data() { + collected.extend_from_slice(&data); + } + } + Err(_) => break, + } + } + assert_eq!(&collected[..], b"OK"); } /// Test handler that uses socket info struct SocketAwareHandler; - impl Handler for SocketAwareHandler { + impl Handler for SocketAwareHandler { type Error = String; - async fn handle( - &self, - request: http::Request, - ) -> Result, Self::Error> { - let socket_info = request.socket_info(); + async fn handle(&self, request: crate::Request) -> Result { + let socket_info = request.socket_info().cloned(); + let (_, body) = request.into_parts(); + let response_body = body.create_response(); - let body = match socket_info { + let body_text = match socket_info { Some(info) => { format!("Local: {:?}, Remote: {:?}", info.local, info.remote) } None => "No socket info".to_string(), }; + let mut response_writer = response_body.clone(); + tokio::spawn(async move { + use tokio::io::AsyncWriteExt; + let _ = response_writer.write_all(body_text.as_bytes()).await; + let _ = response_writer.shutdown().await; + }); + Ok(http::Response::builder() .status(200) - .body(Bytes::from(body)) + .body(response_body) .unwrap()) } } @@ -249,26 +315,46 @@ mod tests { let handler = SocketAwareHandler; // Test without socket info - let request = http::Request::builder() - .uri("/test") - .body(Bytes::new()) - .unwrap(); + let body = crate::RequestBody::new(); + let request = http::Request::builder().uri("/test").body(body).unwrap(); let response = handler.handle(request).await.unwrap(); - assert_eq!(response.body(), &Bytes::from("No socket info")); + let (_, mut response_body) = response.into_parts(); + let mut collected = BytesMut::new(); + while let Some(result) = response_body.frame().await { + match result { + Ok(frame) => { + if let Ok(data) = frame.into_data() { + collected.extend_from_slice(&data); + } + } + Err(_) => break, + } + } + assert_eq!(&collected[..], b"No socket info"); // Test with socket info - let mut request = http::Request::builder() - .uri("/test") - .body(Bytes::new()) - .unwrap(); + let body = crate::RequestBody::new(); + let mut request = http::Request::builder().uri("/test").body(body).unwrap(); let local = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); let remote = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 5000); request.set_socket_info(SocketInfo::new(Some(local), Some(remote))); let response = handler.handle(request).await.unwrap(); - let body_str = std::str::from_utf8(response.body()).unwrap(); + let (_, mut response_body) = response.into_parts(); + let mut collected = BytesMut::new(); + while let Some(result) = response_body.frame().await { + match result { + Ok(frame) => { + if let Ok(data) = frame.into_data() { + collected.extend_from_slice(&data); + } + } + Err(_) => break, + } + } + let body_str = std::str::from_utf8(&collected).unwrap(); assert!(body_str.contains("127.0.0.1:8080")); assert!(body_str.contains("192.168.1.1:5000")); } @@ -276,13 +362,10 @@ mod tests { /// Test handler that returns errors struct ErrorHandler; - impl Handler for ErrorHandler { + impl Handler for ErrorHandler { type Error = String; - async fn handle( - &self, - _request: http::Request, - ) -> Result, Self::Error> { + async fn handle(&self, _request: crate::Request) -> Result { Err("Something went wrong".to_string()) } } @@ -290,10 +373,8 @@ mod tests { #[tokio::test] async fn test_error_handler() { let handler = ErrorHandler; - let request = http::Request::builder() - .uri("/error") - .body(Bytes::new()) - .unwrap(); + let body = crate::RequestBody::new(); + let request = http::Request::builder().uri("/error").body(body).unwrap(); let result = handler.handle(request).await; assert!(result.is_err()); @@ -303,16 +384,23 @@ mod tests { /// Test handler that sets an exception struct ExceptionHandler; - impl Handler for ExceptionHandler { + impl Handler for ExceptionHandler { type Error = std::convert::Infallible; - async fn handle( - &self, - _request: http::Request, - ) -> Result, Self::Error> { + async fn handle(&self, request: crate::Request) -> Result { + let (_, body) = request.into_parts(); + let response_body = body.create_response(); + + let mut response_writer = response_body.clone(); + tokio::spawn(async move { + use tokio::io::AsyncWriteExt; + let _ = response_writer.write_all(b"Internal Server Error").await; + let _ = response_writer.shutdown().await; + }); + let mut response = http::Response::builder() .status(500) - .body(Bytes::from("Internal Server Error")) + .body(response_body) .unwrap(); response.set_exception("Database connection failed"); @@ -324,146 +412,27 @@ mod tests { #[tokio::test] async fn test_exception_handler() { let handler = ExceptionHandler; - let request = http::Request::builder() - .uri("/fail") - .body(Bytes::new()) - .unwrap(); + let body = crate::RequestBody::new(); + let request = http::Request::builder().uri("/fail").body(body).unwrap(); let response = handler.handle(request).await.unwrap(); assert_eq!(response.status(), 500); - assert_eq!(response.body(), &Bytes::from("Internal Server Error")); let exception = response.exception().unwrap(); assert_eq!(exception.message(), "Database connection failed"); - } - - /// Test handler that works with String bodies - struct StringBodyHandler; - - impl Handler for StringBodyHandler { - type Error = std::convert::Infallible; - async fn handle( - &self, - request: http::Request, - ) -> Result, Self::Error> { - let body = request.body(); - let response_body = format!("Received: {}", body.to_uppercase()); - - Ok(http::Response::builder() - .status(200) - .body(response_body) - .unwrap()) - } - } - - #[tokio::test] - async fn test_string_body_handler() { - let handler = StringBodyHandler; - let request = http::Request::builder() - .uri("/string") - .body("hello world".to_string()) - .unwrap(); - - let response = handler.handle(request).await.unwrap(); - assert_eq!(response.status(), 200); - assert_eq!(response.body(), &Bytes::from("Received: HELLO WORLD")); - } - - // /// Test generic handler with different body types - // struct TypeAwareHandler; - - // impl Handler for TypeAwareHandler { - // type Error = std::convert::Infallible; - - // fn handle(&self, request: http::Request) -> Result, Self::Error> { - // let type_name = std::any::type_name::(); - // let body_debug = format!("{:?}", request.body()); - // let response_body = format!("Type: {}\nBody: {}", type_name, body_debug); - - // Ok(http::Response::builder() - // .status(200) - // .body(response_body) - // .unwrap()) - // } - // } - - // #[test] - // fn test_type_aware_handler() { - // let handler = TypeAwareHandler; - - // // Test with String body - // let request = http::Request::builder() - // .uri("/type") - // .body("test string".to_string()) - // .unwrap(); - - // let response = handler.handle(request).unwrap(); - // let body_str = std::str::from_utf8(response.body()).unwrap(); - // assert!(body_str.contains("alloc::string::String")); - // assert!(body_str.contains("test string")); - - // // Test with Vec body - // let request = http::Request::builder() - // .uri("/type") - // .body(vec![1u8, 2, 3, 4]) - // .unwrap(); - - // let response = handler.handle(request).unwrap(); - // let body_str = std::str::from_utf8(response.body()).unwrap(); - // assert!(body_str.contains("vec::Vec")); - // assert!(body_str.contains("[1, 2, 3, 4]")); - // } - - /// Generic echo handler that works with any cloneable body type - pub struct GenericEchoHandler; - - impl Handler for GenericEchoHandler { - type Error = http::Error; - - async fn handle( - &self, - request: http::Request, - ) -> Result, Self::Error> { - http::Response::builder() - .status(200) - .body(request.into_body()) - } - } - - impl Handler> for GenericEchoHandler { - type Error = http::Error; - - async fn handle( - &self, - request: http::Request>, - ) -> Result>, Self::Error> { - http::Response::builder() - .status(200) - .body(request.into_body()) + let (_, mut response_body) = response.into_parts(); + let mut collected = BytesMut::new(); + while let Some(result) = response_body.frame().await { + match result { + Ok(frame) => { + if let Ok(data) = frame.into_data() { + collected.extend_from_slice(&data); + } + } + Err(_) => break, + } } - } - - #[tokio::test] - async fn test_generic_echo_handler() { - let handler = GenericEchoHandler; - - // Test with Bytes - let request = http::Request::builder() - .uri("/echo") - .body(Bytes::from("echo bytes")) - .unwrap(); - - let response = handler.handle(request).await.unwrap(); - assert_eq!(response.body(), &Bytes::from("echo bytes")); - - // Test with Vec - let request = http::Request::builder() - .uri("/echo") - .body(vec![72, 101, 108, 108, 111]) // "Hello" in ASCII - .unwrap(); - - let response = handler.handle(request).await.unwrap(); - assert_eq!(response.body(), &Bytes::from("Hello")); + assert_eq!(&collected[..], b"Internal Server Error"); } } diff --git a/src/lib.rs b/src/lib.rs index 0f24561..02143cb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,17 +7,23 @@ // Re-export everything from http crate pub use http::*; +/// Body types for HTTP requests and responses with streaming support +pub mod body; pub mod extensions; pub mod handler; pub mod types; +/// WebSocket frame codec for RFC 6455 compliant framing +pub mod websocket; + /// Provides N-API bindings to expose the `http` crate types to Node.js. #[cfg(feature = "napi-support")] pub mod napi; +pub use body::{RequestBody, ResponseBody, StreamError}; pub use extensions::{ BodyBuffer, RequestBuilderExt, RequestExt, ResponseBuilderExt, ResponseException, ResponseExt, - ResponseLog, SocketInfo, + ResponseLog, SocketInfo, WebSocketMode, }; pub use handler::Handler; pub use types::{Request, Response}; diff --git a/src/napi.rs b/src/napi.rs index b708d41..f5860ba 100644 --- a/src/napi.rs +++ b/src/napi.rs @@ -2,19 +2,22 @@ use std::{ collections::HashMap, net::SocketAddr, ops::{Deref, DerefMut}, + pin::Pin, }; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use http::{ HeaderMap as HttpHeaderMap, HeaderName, HeaderValue as HttpHeaderValue, request::Builder as RequestBuilder, response::Builder as ResponseBuilder, }; +use http_body::Body; +use napi::bindgen_prelude::async_iterator::AsyncGenerator; use napi::{Either, Error, Result, Status, bindgen_prelude::*}; use napi_derive::napi; use crate::{ - Request as InnerRequest, RequestBuilderExt, RequestExt, Response as InnerResponse, - ResponseBuilderExt, ResponseExt, SocketInfo as InnerSocketInfo, + RequestBody, RequestBuilderExt, RequestExt, ResponseBody, ResponseBuilderExt, ResponseExt, + SocketInfo as InnerSocketInfo, WebSocketMode, }; // @@ -63,36 +66,6 @@ impl TryFrom for HttpHeaderMap { } } -// -// HeaderValue -// - -struct HeaderValue(HttpHeaderValue); - -impl Deref for HeaderValue { - type Target = HttpHeaderValue; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for HeaderValue { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl TryFrom for HeaderValue { - type Error = Error; - - fn try_from(value: String) -> std::result::Result { - HttpHeaderValue::try_from(value) - .map_err(|e| Error::new(Status::InvalidArg, format!("Invalid header value: {}", e))) - .map(HeaderValue) - } -} - // // SocketInfo // @@ -542,7 +515,7 @@ impl Headers { /// console.log(headers.toJSON()); /// ``` #[napi(js_name = "toJSON")] - pub fn to_json(&self, env: &Env) -> Result { + pub fn to_json(&self, env: &Env) -> Result> { let mut obj = Object::new(env)?; for key in self.keys() { @@ -583,6 +556,8 @@ pub struct RequestOptions { pub socket: Option, /// Document root for the request, if applicable. pub docroot: Option, + /// Whether this is a WebSocket request. + pub websocket: Option, } /// Wraps an http::Request instance to expose it to JavaScript. @@ -591,7 +566,7 @@ pub struct RequestOptions { /// the request along with a toJSON method to convert it to a JSON object. #[napi] #[derive(Debug)] -pub struct Request(InnerRequest); +pub struct Request(crate::Request); #[napi] impl Request { @@ -671,12 +646,26 @@ impl Request { request = request.document_root(docroot.into()); } - let body = options - .body - .map(|body| BytesMut::from(body.deref())) - .unwrap_or_default(); + // Build the request first, then set WebSocket mode extension if specified + let websocket = options.websocket.unwrap_or(false); + + // Create empty request body + let body = RequestBody::new(); - let request = request.body(body).expect("Failed to build request"); + let mut request = request.body(body).expect("Failed to build request"); + + // Store body data in BodyBuffer extension if provided (to be sent later in Task::compute) + if let Some(body_buf) = options.body { + let bytes = Bytes::copy_from_slice(body_buf.as_ref()); + request + .extensions_mut() + .insert(crate::BodyBuffer::from_bytes(bytes)); + } + + // Set WebSocket mode extension after building + if websocket { + request.extensions_mut().insert(WebSocketMode); + } Ok(Request(request)) } @@ -851,6 +840,8 @@ impl Request { /// Get the body of the request as a Buffer. /// + /// Returns buffered data if the request was created with a body in the constructor. + /// /// # Examples /// /// ```js @@ -864,28 +855,12 @@ impl Request { /// console.log(request.body.toString()); // {"message":"Hello, world!"} /// ``` #[napi(getter, enumerable = true)] - pub fn body(&self) -> Buffer { - Buffer::from(self.0.body().to_vec()) - } - - /// Set the body of the request. - /// - /// # Examples - /// - /// ```js - /// const request = new Request({ - /// url: "/v2/api/thing" - /// }); - /// - /// request.body = Buffer.from(JSON.stringify({ - /// message: 'Hello, world!' - /// })); - /// - /// console.log(request.body.toString()); // {"message":"Hello, world!"} - /// ``` - #[napi(setter, enumerable = true, js_name = "body")] - pub fn set_body(&mut self, body: Buffer) { - *self.0.body_mut() = BytesMut::from(body.deref()); + pub fn body(&self) -> Option { + // Check if there's a BodyBuffer extension with buffered data + self.0 + .extensions() + .get::() + .map(|buf| Buffer::from(buf.as_bytes().to_vec())) } /// Convert the response to a JSON object representation. @@ -907,20 +882,116 @@ impl Request { /// console.log(request.toJSON()); /// ``` #[napi(js_name = "toJSON")] - pub fn to_json(&self, env: &Env) -> Result { + pub fn to_json(&self, env: &Env) -> Result> { let mut obj = Object::new(env)?; obj.set("method", self.method())?; obj.set("url", self.url())?; obj.set("headers", self.headers().to_json(env)?)?; - obj.set("body", self.body())?; + + // Include body if available (buffered from constructor) + if let Some(body) = self.body() { + obj.set("body", body)?; + } + Ok(obj) } -} -// Rust-only methods (not exposed to JavaScript) -impl Request { - /// Consume this Request and return the inner HTTP request. - pub fn into_inner(self) -> InnerRequest { + /// Write a chunk to the request body stream + /// + /// # Examples + /// + /// ```js + /// const request = new Request({ + /// method: "POST", + /// url: "/upload" + /// }); + /// + /// await request.write(Buffer.from('chunk 1')); + /// await request.write('chunk 2'); + /// await request.end(); + /// ``` + #[napi] + pub async fn write(&self, chunk: Either) -> Result<()> { + use tokio::io::AsyncWriteExt; + + // Check if a body buffer is already present (body already provided) + if self + .0 + .extensions() + .get::() + .is_some() + { + return Err(Error::from_reason( + "Cannot write to request: body has already been provided", + )); + } + + // Auto-detect WebSocket mode and encode frames transparently + let is_websocket = self.0.extensions().get::().is_some(); + + if is_websocket { + // WebSocket mode: encode as frames + let encoder = crate::websocket::WebSocketEncoder::new(self.0.body().clone()); + match chunk { + Either::A(buf) => encoder + .write_binary(buf.as_ref(), false) + .await + .map_err(|e| Error::from_reason(format!("WebSocket error: {:?}", e))), + Either::B(s) => encoder + .write_text(&s, false) + .await + .map_err(|e| Error::from_reason(format!("WebSocket error: {:?}", e))), + } + } else { + // HTTP mode: write raw bytes + let bytes = match chunk { + Either::A(buf) => Bytes::copy_from_slice(buf.as_ref()), + Either::B(s) => Bytes::from(s), + }; + + let mut body = self.0.body().clone(); + body.write_all(&bytes) + .await + .map_err(|e| Error::from_reason(e.to_string())) + } + } + + /// End the request body stream (HTTP mode only) + /// + /// # Examples + /// + /// ```js + /// const request = new Request({ + /// method: "POST", + /// url: "/upload" + /// }); + /// + /// await request.write(Buffer.from('data')); + /// await request.end(); + /// ``` + #[napi] + pub async fn end(&self) -> Result<()> { + use tokio::io::AsyncWriteExt; + + // If a body buffer is already present, the body has been provided so just return + if self + .0 + .extensions() + .get::() + .is_some() + { + return Ok(()); + } + + // Shutdown the write side of the duplex stream to signal end of request + let mut body = self.0.body().clone(); + body.shutdown() + .await + .map_err(|e| Error::from_reason(e.to_string())) + } + + /// Consume this Request and return the inner Request + pub fn into_inner(self) -> crate::Request { self.0 } } @@ -951,20 +1022,30 @@ impl Clone for Request { req.set_socket_info(socket.clone()); } + // Copy the BodyBuffer extension if it exists (for buffered requests) + if let Some(body_buffer) = self.0.extensions().get::() { + req.extensions_mut().insert(body_buffer.clone()); + } + + // Copy the WebSocketMode extension if it exists + if self.0.extensions().get::().is_some() { + req.extensions_mut().insert(crate::WebSocketMode); + } + Request(req) } } impl Deref for Request { - type Target = InnerRequest; + type Target = crate::Request; fn deref(&self) -> &Self::Target { &self.0 } } -impl From for Request { - fn from(request: InnerRequest) -> Self { +impl From for Request { + fn from(request: crate::Request) -> Self { Request(request) } } @@ -976,7 +1057,7 @@ impl FromNapiValue for Request { return Ok(instance.deref().clone()); } - // If both conversions fail, return an error + // If conversion fails, return an error Err(Error::new(Status::InvalidArg, "Expected Request")) } } @@ -1027,7 +1108,7 @@ pub struct ResponseOptions { /// console.log(response.body.toString()); // {"message":"Hello, world!"} /// ``` #[napi] -pub struct Response(InnerResponse); +pub struct Response(crate::Response); #[napi] impl Response { @@ -1067,18 +1148,32 @@ impl Response { builder = builder.exception(exception); } - let body = options - .body - .map(|body| BytesMut::from(body.deref())) - .unwrap_or_default(); + // Create response body + let response_body = ResponseBody::new(); + + // If body data is provided, store it in buffered_body extension + // The actual writing to the stream happens lazily when the body is accessed + let buffered_body = if let Some(body_buf) = options.body { + let bytes = Bytes::copy_from_slice(body_buf.as_ref()); + Some(bytes) + } else { + None + }; - let response = builder.body(body).map_err(|e| { + let mut response = builder.body(response_body).map_err(|e| { Error::new( Status::InvalidArg, format!("Failed to build response: {}", e), ) })?; + // Store buffered body as extension if provided + if let Some(bytes) = buffered_body { + response + .extensions_mut() + .insert(crate::BodyBuffer::from_bytes(bytes)); + } + Ok(Response(response)) } @@ -1164,40 +1259,33 @@ impl Response { *self.0.headers_mut() = headers.deref().clone(); } - /// Get the body of the response as a Buffer. + /// Get the buffered body of the response as a Buffer. /// - /// # Examples + /// Note: With the new streaming architecture, response bodies are not buffered by default. + /// This getter returns buffered data if it was explicitly buffered (e.g., by handleRequest). + /// For streaming responses, use the AsyncIterator protocol via next(). /// - /// ```js - /// const response = new Response({ - /// body: Buffer.from(JSON.stringify({ - /// message: 'Hello, world!' - /// })) - /// }); - /// - /// console.log(response.body.toString()); // {"message":"Hello, world!"} - /// ``` - #[napi(getter, enumerable = true)] - pub fn body(&self) -> Buffer { - Buffer::from(self.0.body().to_vec()) - } - - /// Set the body of the response. + /// Returns `undefined` for streaming responses without buffering. /// /// # Examples /// /// ```js - /// const response = new Response(); + /// // After handleRequest (automatically buffered) + /// const response = await python.handleRequest(request); + /// console.log(response.body.toString()); // Works - body was buffered /// - /// response.body = Buffer.from(JSON.stringify({ - /// message: 'Hello, world!' - /// })); - /// - /// console.log(response.body.toString()); // {"message":"Hello, world!"} + /// // For streaming responses, use AsyncIterator + /// for await (const chunk of response) { + /// console.log(chunk.toString()); + /// } /// ``` - #[napi(setter, enumerable = true, js_name = "body")] - pub fn set_body(&mut self, body: Buffer) { - *self.0.body_mut() = BytesMut::from(body.deref()); + #[napi(getter, enumerable = true)] + pub fn body(&self) -> Option { + // Check if there's a BodyBuffer extension with buffered data + self.0 + .extensions() + .get::() + .map(|buf| Buffer::from(buf.as_bytes().to_vec())) } /// Get the log of the response as a Buffer. @@ -1253,11 +1341,15 @@ impl Response { /// console.log(response.toJSON()); /// ``` #[napi(js_name = "toJSON")] - pub fn to_json(&self, env: &Env) -> Result { + pub fn to_json(&self, env: &Env) -> Result> { let mut obj = Object::new(env)?; obj.set("status", self.status())?; obj.set("headers", self.headers().to_json(env)?)?; - obj.set("body", self.body())?; + + // Include body if available (either buffered or null) + if let Some(body) = self.body() { + obj.set("body", body)?; + } // Only include log if it has content if let Some(log) = self.0.log() { @@ -1273,10 +1365,112 @@ impl Response { Ok(obj) } + + /// Set up async iteration support on this Response object. + /// + /// This method sets up Symbol.asyncIterator on the JavaScript Response object, + /// allowing the response body to be consumed using `for await...of` loops. + /// + /// # Examples + /// + /// ```js + /// const res = await handler.handleStream(req); + /// + /// // Access response properties immediately + /// console.log(res.status); // 200 + /// console.log(res.headers.get('content-type')); // 'text/plain' + /// + /// // Stream the response body + /// for await (const chunk of res) { + /// console.log(chunk.toString()); + /// } + /// ``` + pub fn make_streamable(self, env: Env) -> Result> { + use napi::bindgen_prelude::async_iterator::symbol_async_generator; + use napi::sys; + use std::ptr; + + let raw_env = env.raw(); + + // Convert this Response to a JavaScript value (this creates the JS object and consumes self) + let response_js_value = unsafe { Self::to_napi_value(raw_env, self)? }; + + // Get Symbol.asyncIterator + let mut global = ptr::null_mut(); + napi::check_status!( + unsafe { sys::napi_get_global(raw_env, &mut global) }, + "Get global failed" + )?; + + let mut symbol_object = ptr::null_mut(); + napi::check_status!( + unsafe { + sys::napi_get_named_property( + raw_env, + global, + c"Symbol".as_ptr().cast(), + &mut symbol_object, + ) + }, + "Get Symbol failed" + )?; + + let mut iterator_symbol = ptr::null_mut(); + napi::check_status!( + unsafe { + sys::napi_get_named_property( + raw_env, + symbol_object, + c"asyncIterator".as_ptr().cast(), + &mut iterator_symbol, + ) + }, + "Get Symbol.asyncIterator failed" + )?; + + // Extract native pointer to use in the generator function + let mut response_ref = ptr::null_mut(); + napi::check_status!( + unsafe { sys::napi_unwrap(raw_env, response_js_value, &mut response_ref) }, + "Failed to unwrap Response" + )?; + + // Create generator function + let mut generator_function = ptr::null_mut(); + napi::check_status!( + unsafe { + sys::napi_create_function( + raw_env, + c"AsyncIterator".as_ptr().cast(), + 13, + Some(symbol_async_generator::), + response_ref, + &mut generator_function, + ) + }, + "Create asyncIterator function failed" + )?; + + // Set Symbol.asyncIterator on the Response object + napi::check_status!( + unsafe { + sys::napi_set_property( + raw_env, + response_js_value, + iterator_symbol, + generator_function, + ) + }, + "Failed to set Symbol.asyncIterator" + )?; + + // Return the JS object we just modified + Ok(Object::from_raw(raw_env, response_js_value)) + } } impl Deref for Response { - type Target = InnerResponse; + type Target = crate::Response; fn deref(&self) -> &Self::Target { &self.0 @@ -1289,8 +1483,236 @@ impl DerefMut for Response { } } -impl From for Response { - fn from(response: InnerResponse) -> Self { +impl From for Response { + fn from(response: crate::Response) -> Self { Response(response) } } + +#[napi] +impl Response { + /// Read the next chunk from the response body stream + /// + /// Returns the next chunk as a Buffer, or undefined if the stream has ended. + /// This method is used to implement AsyncIterator in JavaScript. + /// + /// For WebSocket responses (when WebSocketMode extension is present), this automatically + /// decodes WebSocket frames and returns the payload data. + /// + /// # Examples + /// + /// ```js + /// console.log(await response.next()); // Buffer | undefined + /// ``` + #[napi] + pub async unsafe fn next(&mut self) -> Result> { + use http_body_util::BodyExt; + use tokio_util::codec::Decoder; + + // Auto-detect WebSocket mode and decode frames transparently + let is_websocket = self.0.extensions().get::().is_some(); + + if is_websocket { + // WebSocket mode: read HTTP body frames and decode as WebSocket frames + // Get or create the decoder state from extensions + let (codec, buffer) = { + let extensions = self.0.extensions(); + if extensions + .get::() + .is_none() + { + self.0 + .extensions_mut() + .insert(crate::extensions::WebSocketDecoderState::new()); + } + + let state = self + .0 + .extensions() + .get::() + .unwrap(); + (state.codec().clone(), state.buffer().clone()) + }; + + // Try to decode a frame from existing buffer first + loop { + { + let mut buf = buffer.lock().await; + let mut codec_guard = codec.lock().await; + + match codec_guard.decode(&mut *buf) { + Ok(Some(frame)) => { + // Successfully decoded a frame + // Handle different frame types + if frame.is_close() { + // Close frame - signal end of stream + return Ok(None); + } else if frame.is_text() || frame.is_binary() { + // Data frame - return payload + if frame.payload.is_empty() { + continue; // Empty frame, try next + } + return Ok(Some(Buffer::from(frame.payload))); + } else { + // Control frames (ping/pong) or unknown - skip them + continue; + } + } + Ok(None) => { + // Need more data - fall through to read HTTP body frame + } + Err(e) => { + return Err(Error::from_reason(format!( + "WebSocket decode error: {:?}", + e + ))); + } + } + } + + // Read next HTTP body frame to get more WebSocket data + match self.0.body_mut().frame().await { + Some(Ok(frame)) => { + if let Ok(data) = frame.into_data() { + if data.is_empty() { + continue; // Empty HTTP frame, try next + } + // Append data to buffer and try decoding again + let mut buf = buffer.lock().await; + buf.extend_from_slice(&data); + } else { + // Trailers or empty, continue + continue; + } + } + Some(Err(e)) => { + return Err(Error::from_reason(e)); + } + None => { + // HTTP body ended - check for exception + // Exception is stored as Arc>> by python-node + if let Some(exc_holder) = self.0.extensions().get::>, + >>() { + if let Ok(guard) = exc_holder.try_lock() { + if let Some(exc) = guard.as_ref() { + return Err(Error::from_reason(exc.message().to_string())); + } + } + } + return Ok(None); + } + } + } + } else { + // HTTP mode: read raw body frames + match self.0.body_mut().frame().await { + Some(Ok(frame)) => { + // Extract data from frame if present + if let Ok(data) = frame.into_data() { + Ok(Some(Buffer::from(data.to_vec()))) + } else { + // Frame was trailers, skip it + Ok(None) + } + } + Some(Err(e)) => Err(Error::from_reason(e)), + None => { + // Check if there's a ResponseException before signaling EOF + // Exception is stored as Arc>> by python-node + if let Some(exc_holder) = self.0.extensions().get::>, + >>() { + if let Ok(guard) = exc_holder.try_lock() { + if let Some(exc) = guard.as_ref() { + return Err(Error::from_reason(exc.message().to_string())); + } + } + } + Ok(None) + } + } + } + } +} + +/// Implement AsyncGenerator on Response to enable JavaScript's `for await` syntax. +/// +/// # Safety Considerations +/// +/// This implementation uses unsafe code to work around a fundamental lifetime constraint: +/// - `AsyncGenerator::next(&mut self)` borrows `self` with a limited lifetime +/// - But it must return a `Future + 'static` (required by the trait) +/// +/// We use a raw pointer to the ResponseBody to bridge this gap. This is sound because: +/// +/// 1. **NAPI-RS Lifetime Management**: NAPI-RS leaks the Response object using +/// `Box::leak(Box::from_raw(ptr))`, creating a true `'static` reference. The `&mut self` +/// parameter actually has a `'static` lifetime. +/// +/// 2. **Single-Threaded Execution**: Node.js is single-threaded. While JavaScript can create +/// multiple concurrent promises by calling `next()` repeatedly, the synchronous execution +/// of the `next()` method itself (creating the future) happens sequentially on the event +/// loop thread. There are no overlapping mutable borrows during the synchronous portion. +/// +/// 3. **Independent Futures**: Each returned future captures the pointer independently. +/// While multiple futures may exist concurrently, they access the underlying `ResponseBody` +/// through a channel receiver (`poll_frame()`), which safely handles concurrent polling. +/// +/// The unsafe code is confined to pointer creation and dereferencing within the future, +/// with detailed documentation of the invariants that make it sound. +impl AsyncGenerator for Response { + type Yield = Buffer; + type Next = (); + type Return = (); + + fn next( + &mut self, + _value: Option, + ) -> impl Future>> + Send + 'static { + use std::future::poll_fn; + + // SAFETY: Extend the lifetime of the body reference to 'static. + // This is safe because NAPI-RS has already leaked the Response object via Box::leak, + // so `self` is actually &'static mut Response. We're just making that explicit. + // Node.js is single-threaded, so concurrent calls to next() execute sequentially. + let body: &'static mut crate::ResponseBody = + unsafe { std::mem::transmute(self.0.body_mut()) }; + + // Capture exception holder for checking on EOF + let exception_holder = self.0.extensions() + .get::>>>() + .cloned(); + + async move { + let result = poll_fn(|cx| Pin::new(&mut *body).poll_frame(cx)).await; + + match result { + Some(Ok(frame)) => { + if let Ok(data) = frame.into_data() { + if data.is_empty() { + Ok(None) + } else { + Ok(Some(Buffer::from(data.to_vec()))) + } + } else { + // Frame contains trailers or is empty, treat as no data + Ok(None) + } + } + Some(Err(e)) => Err(Error::from_reason(e)), + None => { + // Stream ended - check for exception stored by python-node + if let Some(exc_holder) = exception_holder { + if let Ok(guard) = exc_holder.try_lock() { + if let Some(exc) = guard.as_ref() { + return Err(Error::from_reason(exc.message().to_string())); + } + } + } + Ok(None) + } + } + } + } +} diff --git a/src/types.rs b/src/types.rs index 7278e90..b913392 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,13 +1,13 @@ -//! Core type aliases and implementations for v2 +//! Core type aliases and implementations +use super::body::{RequestBody, ResponseBody}; use super::extensions::{RequestExt, ResponseExt, SocketInfo}; -use bytes::BytesMut; -/// Type alias for HTTP Request with BytesMut body -pub type Request = http::Request; +/// Type alias for HTTP Request with streaming body +pub type Request = http::Request; -/// Type alias for HTTP Response with BytesMut body -pub type Response = http::Response; +/// Type alias for HTTP Response with streaming body +pub type Response = http::Response; /// Helper functions for building requests with extensions pub mod request { @@ -46,28 +46,33 @@ pub mod response { #[cfg(test)] mod tests { use super::*; + use bytes::Bytes; use http::{Method, StatusCode}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - #[test] - fn test_request_type_alias() { + #[tokio::test] + async fn test_request_type_alias() { + let body = RequestBody::from_data(Bytes::from("request body")) + .await + .unwrap(); let request = http::Request::builder() .method(Method::GET) .uri("/test") - .body(BytesMut::from("request body")) + .body(body) .unwrap(); assert_eq!(request.method(), Method::GET); assert_eq!(request.uri().path(), "/test"); - assert_eq!(request.body(), &BytesMut::from("request body")); } #[test] fn test_response_type_alias() { + let request_body = RequestBody::new(); + let response_body = request_body.create_response(); let response = http::Response::builder() .status(StatusCode::OK) .header("Content-Type", "text/plain") - .body(BytesMut::from("response body")) + .body(response_body) .unwrap(); assert_eq!(response.status(), StatusCode::OK); @@ -75,15 +80,12 @@ mod tests { response.headers().get("content-type").unwrap(), "text/plain" ); - assert_eq!(response.body(), &BytesMut::from("response body")); } #[test] fn test_request_with_socket_info() { - let request = http::Request::builder() - .uri("/test") - .body(BytesMut::new()) - .unwrap(); + let body = RequestBody::new(); + let request = http::Request::builder().uri("/test").body(body).unwrap(); let local = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); let remote = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 5000); @@ -97,9 +99,11 @@ mod tests { #[test] fn test_response_with_log() { + let request_body = RequestBody::new(); + let response_body = request_body.create_response(); let response = http::Response::builder() .status(StatusCode::OK) - .body(BytesMut::new()) + .body(response_body) .unwrap(); let response = response::with_log(response, "Test log message"); @@ -110,9 +114,11 @@ mod tests { #[test] fn test_response_with_exception() { + let request_body = RequestBody::new(); + let response_body = request_body.create_response(); let response = http::Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(BytesMut::new()) + .body(response_body) .unwrap(); let response = response::with_exception(response, "Something went wrong"); @@ -124,9 +130,11 @@ mod tests { #[test] fn test_combined_extensions() { // Test that we can use multiple extensions together + let request_body = RequestBody::new(); + let response_body = request_body.create_response(); let mut response = http::Response::builder() .status(StatusCode::OK) - .body(BytesMut::from("body")) + .body(response_body) .unwrap(); response.set_log("Initial log"); diff --git a/src/websocket/codec.rs b/src/websocket/codec.rs new file mode 100644 index 0000000..c5cd457 --- /dev/null +++ b/src/websocket/codec.rs @@ -0,0 +1,211 @@ +//! WebSocket codec for use with tokio_util::codec::Framed. +//! +//! This codec provides a clean abstraction over DuplexStream, turning raw bytes +//! into a Stream of WebSocket frames. + +use super::frame::{WebSocketError, WebSocketFrame, WebSocketOpcode}; +use bytes::{Buf, BytesMut}; +use tokio_util::codec::{Decoder, Encoder}; + +/// WebSocket codec that implements tokio_util's Decoder and Encoder traits. +/// +/// This codec handles: +/// - Frame parsing from byte buffers +/// - Message fragmentation and reassembly +/// - Frame encoding to byte buffers +/// +/// Use with `tokio_util::codec::Framed` to turn a DuplexStream into a +/// `Stream` and `Sink`. +pub struct WebSocketCodec { + /// Fragments being assembled into a complete message + fragments: Vec>, + /// Opcode of the first fragment (determines final message type) + message_opcode: Option, +} + +impl WebSocketCodec { + /// Create a new WebSocket codec. + pub fn new() -> Self { + Self { + fragments: Vec::new(), + message_opcode: None, + } + } +} + +impl Default for WebSocketCodec { + fn default() -> Self { + Self::new() + } +} + +impl Decoder for WebSocketCodec { + type Item = WebSocketFrame; + type Error = WebSocketError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + // Try to parse a frame from the buffer + match WebSocketFrame::parse(src) { + Ok((frame, consumed)) => { + // Advance the buffer by the number of bytes consumed + src.advance(consumed); + + // Handle control frames (ping, pong, close) + // These are never fragmented and should be returned immediately + if frame.opcode.is_control() { + return Ok(Some(frame)); + } + + // Handle data frames (text, binary, continuation) + match frame.opcode { + WebSocketOpcode::Text | WebSocketOpcode::Binary => { + // First fragment of a new message + self.message_opcode = Some(frame.opcode); + self.fragments.push(frame.payload.clone()); + + if frame.fin { + // Single-frame message - complete immediately + let opcode = self.message_opcode.take().unwrap(); + let payload = self.fragments.drain(..).flatten().collect(); + + Ok(Some(WebSocketFrame::new_data(opcode, payload, true))) + } else { + // More fragments coming, wait for them + Ok(None) + } + } + WebSocketOpcode::Continuation => { + // Continuation of a fragmented message + if self.message_opcode.is_none() { + // Continuation without initial frame - protocol error + // For now, we'll treat this as incomplete + return Ok(None); + } + + self.fragments.push(frame.payload.clone()); + + if frame.fin { + // Final fragment - assemble complete message + let opcode = self.message_opcode.take().unwrap(); + let payload = self.fragments.drain(..).flatten().collect(); + + Ok(Some(WebSocketFrame::new_data(opcode, payload, true))) + } else { + // More fragments coming, wait for them + Ok(None) + } + } + // Control frames handled above + _ => unreachable!(), + } + } + Err(WebSocketError::IncompleteFrame) => { + // Need more data + Ok(None) + } + Err(e) => Err(e), + } + } +} + +impl Encoder for WebSocketCodec { + type Error = WebSocketError; + + fn encode(&mut self, frame: WebSocketFrame, dst: &mut BytesMut) -> Result<(), Self::Error> { + // Encode the frame (no masking for server->client frames) + let encoded = frame.encode(None); + + // Write to the destination buffer + dst.extend_from_slice(&encoded); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn test_decode_single_frame() { + let mut codec = WebSocketCodec::new(); + + // Create a simple text frame + let frame = WebSocketFrame::new_text("Hello".to_string(), true); + let encoded = frame.encode(None); + + let mut buffer = BytesMut::from(&encoded[..]); + let decoded = codec.decode(&mut buffer).unwrap(); + + assert!(decoded.is_some()); + let decoded_frame = decoded.unwrap(); + assert_eq!(decoded_frame.opcode, WebSocketOpcode::Text); + assert_eq!(decoded_frame.payload, b"Hello"); + assert!(decoded_frame.fin); + } + + #[test] + fn test_decode_fragmented_message() { + let mut codec = WebSocketCodec::new(); + + // First fragment + let frame1 = WebSocketFrame::new_text("Hel".to_string(), false); + let encoded1 = frame1.encode(None); + + let mut buffer = BytesMut::from(&encoded1[..]); + let result = codec.decode(&mut buffer).unwrap(); + assert!(result.is_none()); // Not complete yet + + // Second fragment (continuation) + let frame2 = WebSocketFrame::new_continuation(b"lo".to_vec(), true); + let encoded2 = frame2.encode(None); + + buffer.extend_from_slice(&encoded2); + let result = codec.decode(&mut buffer).unwrap(); + + assert!(result.is_some()); + let decoded_frame = result.unwrap(); + assert_eq!(decoded_frame.opcode, WebSocketOpcode::Text); + assert_eq!(decoded_frame.payload, b"Hello"); + assert!(decoded_frame.fin); + } + + #[test] + fn test_encode_frame() { + let mut codec = WebSocketCodec::new(); + let mut buffer = BytesMut::new(); + + let frame = WebSocketFrame::new_binary(vec![1, 2, 3], true); + codec.encode(frame, &mut buffer).unwrap(); + + assert!(!buffer.is_empty()); + + // Decode it back to verify + let mut decode_codec = WebSocketCodec::new(); + let decoded = decode_codec.decode(&mut buffer).unwrap(); + + assert!(decoded.is_some()); + let decoded_frame = decoded.unwrap(); + assert_eq!(decoded_frame.opcode, WebSocketOpcode::Binary); + assert_eq!(decoded_frame.payload, vec![1, 2, 3]); + } + + #[test] + fn test_control_frame_immediate_return() { + let mut codec = WebSocketCodec::new(); + + // Create a ping frame + let frame = WebSocketFrame::new_ping(b"test".to_vec()); + let encoded = frame.encode(None); + + let mut buffer = BytesMut::from(&encoded[..]); + let decoded = codec.decode(&mut buffer).unwrap(); + + // Control frames should be returned immediately + assert!(decoded.is_some()); + let decoded_frame = decoded.unwrap(); + assert_eq!(decoded_frame.opcode, WebSocketOpcode::Ping); + assert_eq!(decoded_frame.payload, b"test"); + } +} diff --git a/src/websocket/frame.rs b/src/websocket/frame.rs new file mode 100644 index 0000000..f8aaa58 --- /dev/null +++ b/src/websocket/frame.rs @@ -0,0 +1,554 @@ +//! WebSocket frame parsing and encoding conforming to RFC 6455. + +use std::fmt; + +/// WebSocket opcodes as defined in RFC 6455 Section 5.2. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum WebSocketOpcode { + /// Continuation frame (0x0) + Continuation = 0x0, + /// Text data frame (0x1) + Text = 0x1, + /// Binary data frame (0x2) + Binary = 0x2, + /// Connection close frame (0x8) + Close = 0x8, + /// Ping frame (0x9) + Ping = 0x9, + /// Pong frame (0xA) + Pong = 0xA, +} + +impl WebSocketOpcode { + /// Parse opcode from 4-bit value. + fn from_u8(value: u8) -> Result { + match value { + 0x0 => Ok(WebSocketOpcode::Continuation), + 0x1 => Ok(WebSocketOpcode::Text), + 0x2 => Ok(WebSocketOpcode::Binary), + 0x8 => Ok(WebSocketOpcode::Close), + 0x9 => Ok(WebSocketOpcode::Ping), + 0xA => Ok(WebSocketOpcode::Pong), + _ => Err(WebSocketError::InvalidOpcode(value)), + } + } + + /// Check if this is a control frame opcode. + pub fn is_control(&self) -> bool { + matches!( + self, + WebSocketOpcode::Close | WebSocketOpcode::Ping | WebSocketOpcode::Pong + ) + } + + /// Check if this is a data frame opcode. + pub fn is_data(&self) -> bool { + matches!( + self, + WebSocketOpcode::Text | WebSocketOpcode::Binary | WebSocketOpcode::Continuation + ) + } +} + +/// WebSocket frame structure per RFC 6455 Section 5.2. +#[derive(Debug, Clone)] +pub struct WebSocketFrame { + /// FIN bit: indicates this is the final fragment of a message + pub fin: bool, + /// RSV1 bit: reserved for extensions + pub rsv1: bool, + /// RSV2 bit: reserved for extensions + pub rsv2: bool, + /// RSV3 bit: reserved for extensions + pub rsv3: bool, + /// Opcode: identifies the frame type + pub opcode: WebSocketOpcode, + /// Mask bit: indicates if payload is masked (always true for client→server) + pub masked: bool, + /// Payload data + pub payload: Vec, +} + +/// Errors that can occur during WebSocket frame parsing/encoding. +#[derive(Debug)] +pub enum WebSocketError { + /// Invalid opcode value + InvalidOpcode(u8), + /// Incomplete frame data + IncompleteFrame, + /// Control frame exceeds maximum length (125 bytes) + ControlFrameTooLarge, + /// Control frame is fragmented (FIN=0) + ControlFrameFragmented, + /// Reserved bits are set without negotiated extension + ReservedBitsSet, + /// Invalid UTF-8 in text frame + InvalidUtf8, + /// Frame too large + FrameTooLarge, + /// I/O error + IoError(String), +} + +impl fmt::Display for WebSocketError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + WebSocketError::InvalidOpcode(op) => write!(f, "Invalid WebSocket opcode: {:#x}", op), + WebSocketError::IncompleteFrame => write!(f, "Incomplete WebSocket frame"), + WebSocketError::ControlFrameTooLarge => { + write!(f, "Control frame payload exceeds 125 bytes") + } + WebSocketError::ControlFrameFragmented => write!(f, "Control frame is fragmented"), + WebSocketError::ReservedBitsSet => write!(f, "Reserved bits set without extension"), + WebSocketError::InvalidUtf8 => write!(f, "Invalid UTF-8 in text frame"), + WebSocketError::FrameTooLarge => write!(f, "Frame too large"), + WebSocketError::IoError(msg) => write!(f, "I/O error: {}", msg), + } + } +} + +impl std::error::Error for WebSocketError {} + +impl From for WebSocketError { + fn from(err: std::io::Error) -> Self { + WebSocketError::IoError(err.to_string()) + } +} + +impl WebSocketFrame { + /// Parse a WebSocket frame from bytes. + /// + /// Returns the parsed frame and the number of bytes consumed. + /// Returns `Err(WebSocketError::IncompleteFrame)` if more data is needed. + /// + /// # RFC 6455 Frame Format + /// + /// ```text + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-------+-+-------------+-------------------------------+ + /// |F|R|R|R| opcode|M| Payload len | Extended payload length | + /// |I|S|S|S| (4) |A| (7) | (16/64) | + /// |N|V|V|V| |S| | (if payload len==126/127) | + /// | |1|2|3| |K| | | + /// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + /// | Extended payload length continued, if payload len == 127 | + /// + - - - - - - - - - - - - - - - +-------------------------------+ + /// | |Masking-key, if MASK set to 1 | + /// +-------------------------------+-------------------------------+ + /// | Masking-key (continued) | Payload Data | + /// +-------------------------------- - - - - - - - - - - - - - - - + + /// : Payload Data continued ... : + /// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + /// | Payload Data continued ... | + /// +---------------------------------------------------------------+ + /// ``` + pub fn parse(data: &[u8]) -> Result<(Self, usize), WebSocketError> { + // Need at least 2 bytes for header + if data.len() < 2 { + return Err(WebSocketError::IncompleteFrame); + } + + // Parse first byte: FIN, RSV1-3, Opcode + let byte1 = data[0]; + let fin = (byte1 & 0b1000_0000) != 0; + let rsv1 = (byte1 & 0b0100_0000) != 0; + let rsv2 = (byte1 & 0b0010_0000) != 0; + let rsv3 = (byte1 & 0b0001_0000) != 0; + let opcode = WebSocketOpcode::from_u8(byte1 & 0b0000_1111)?; + + // Parse second byte: MASK, Payload length + let byte2 = data[1]; + let masked = (byte2 & 0b1000_0000) != 0; + let mut payload_len = (byte2 & 0b0111_1111) as u64; + + let mut offset = 2; + + // Parse extended payload length if needed + if payload_len == 126 { + if data.len() < offset + 2 { + return Err(WebSocketError::IncompleteFrame); + } + payload_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as u64; + offset += 2; + } else if payload_len == 127 { + if data.len() < offset + 8 { + return Err(WebSocketError::IncompleteFrame); + } + payload_len = u64::from_be_bytes([ + data[offset], + data[offset + 1], + data[offset + 2], + data[offset + 3], + data[offset + 4], + data[offset + 5], + data[offset + 6], + data[offset + 7], + ]); + offset += 8; + } + + // Validate payload length + if payload_len > usize::MAX as u64 { + return Err(WebSocketError::FrameTooLarge); + } + let payload_len = payload_len as usize; + + // Validate control frames + if opcode.is_control() { + if payload_len > 125 { + return Err(WebSocketError::ControlFrameTooLarge); + } + if !fin { + return Err(WebSocketError::ControlFrameFragmented); + } + } + + // Validate reserved bits (must be 0 unless extension is negotiated) + if rsv1 || rsv2 || rsv3 { + return Err(WebSocketError::ReservedBitsSet); + } + + // Parse masking key if present + let masking_key = if masked { + if data.len() < offset + 4 { + return Err(WebSocketError::IncompleteFrame); + } + let key = [ + data[offset], + data[offset + 1], + data[offset + 2], + data[offset + 3], + ]; + offset += 4; + Some(key) + } else { + None + }; + + // Parse payload + if data.len() < offset + payload_len { + return Err(WebSocketError::IncompleteFrame); + } + + let mut payload = data[offset..offset + payload_len].to_vec(); + offset += payload_len; + + // Unmask payload if masked + if let Some(mask) = masking_key { + Self::apply_mask(&mut payload, &mask); + } + + // Validate UTF-8 for text frames + if opcode == WebSocketOpcode::Text && fin && std::str::from_utf8(&payload).is_err() { + return Err(WebSocketError::InvalidUtf8); + } + + Ok(( + WebSocketFrame { + fin, + rsv1, + rsv2, + rsv3, + opcode, + masked, + payload, + }, + offset, + )) + } + + /// Encode a WebSocket frame to bytes. + /// + /// # Arguments + /// + /// * `mask` - Optional masking key. If provided, the payload will be masked. + pub fn encode(&self, mask: Option<[u8; 4]>) -> Vec { + let mut frame = Vec::new(); + + // First byte: FIN, RSV1-3, Opcode + let mut byte1 = self.opcode as u8; + if self.fin { + byte1 |= 0b1000_0000; + } + if self.rsv1 { + byte1 |= 0b0100_0000; + } + if self.rsv2 { + byte1 |= 0b0010_0000; + } + if self.rsv3 { + byte1 |= 0b0001_0000; + } + frame.push(byte1); + + // Second byte: MASK, Payload length + let payload_len = self.payload.len(); + let mut byte2 = if mask.is_some() { + 0b1000_0000 + } else { + 0b0000_0000 + }; + + if payload_len < 126 { + byte2 |= payload_len as u8; + frame.push(byte2); + } else if payload_len <= 65535 { + byte2 |= 126; + frame.push(byte2); + frame.extend_from_slice(&(payload_len as u16).to_be_bytes()); + } else { + byte2 |= 127; + frame.push(byte2); + frame.extend_from_slice(&(payload_len as u64).to_be_bytes()); + } + + // Masking key if present + if let Some(masking_key) = mask { + frame.extend_from_slice(&masking_key); + } + + // Payload + if let Some(masking_key) = mask { + let mut masked_payload = self.payload.clone(); + Self::apply_mask(&mut masked_payload, &masking_key); + frame.extend_from_slice(&masked_payload); + } else { + frame.extend_from_slice(&self.payload); + } + + frame + } + + /// Apply XOR mask to payload data per RFC 6455 Section 5.3. + /// + /// This operation is reversible (applying the same mask twice yields the original data). + fn apply_mask(payload: &mut [u8], mask: &[u8; 4]) { + for (i, byte) in payload.iter_mut().enumerate() { + *byte ^= mask[i % 4]; + } + } + + /// Create a new data frame (text or binary). + pub fn new_data(opcode: WebSocketOpcode, payload: Vec, fin: bool) -> Self { + debug_assert!(opcode.is_data()); + WebSocketFrame { + fin, + rsv1: false, + rsv2: false, + rsv3: false, + opcode, + masked: false, + payload, + } + } + + /// Create a new text frame. + pub fn new_text(text: String, fin: bool) -> Self { + Self::new_data(WebSocketOpcode::Text, text.into_bytes(), fin) + } + + /// Create a new binary frame. + pub fn new_binary(data: Vec, fin: bool) -> Self { + Self::new_data(WebSocketOpcode::Binary, data, fin) + } + + /// Create a new continuation frame. + pub fn new_continuation(data: Vec, fin: bool) -> Self { + Self::new_data(WebSocketOpcode::Continuation, data, fin) + } + + /// Create a new close frame with optional status code and reason. + pub fn new_close(code: Option, reason: Option<&str>) -> Self { + let mut payload = Vec::new(); + if let Some(code) = code { + payload.extend_from_slice(&code.to_be_bytes()); + if let Some(reason) = reason { + payload.extend_from_slice(reason.as_bytes()); + } + } + WebSocketFrame { + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: WebSocketOpcode::Close, + masked: false, + payload, + } + } + + /// Create a new ping frame. + pub fn new_ping(data: Vec) -> Self { + WebSocketFrame { + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: WebSocketOpcode::Ping, + masked: false, + payload: data, + } + } + + /// Create a new pong frame. + pub fn new_pong(data: Vec) -> Self { + WebSocketFrame { + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: WebSocketOpcode::Pong, + masked: false, + payload: data, + } + } + + /// Parse close frame payload to extract status code and reason. + pub fn parse_close_payload(&self) -> Option<(u16, String)> { + if self.opcode != WebSocketOpcode::Close { + return None; + } + if self.payload.len() < 2 { + return None; + } + let code = u16::from_be_bytes([self.payload[0], self.payload[1]]); + let reason = String::from_utf8_lossy(&self.payload[2..]).to_string(); + Some((code, reason)) + } + + /// Check if this is a text frame. + pub fn is_text(&self) -> bool { + self.opcode == WebSocketOpcode::Text + } + + /// Check if this is a binary frame. + pub fn is_binary(&self) -> bool { + self.opcode == WebSocketOpcode::Binary + } + + /// Check if this is a close frame. + pub fn is_close(&self) -> bool { + self.opcode == WebSocketOpcode::Close + } + + /// Get the payload as a UTF-8 text string. + /// Returns None if the frame is not a text frame or contains invalid UTF-8. + pub fn payload_as_text(&self) -> Option { + if !self.is_text() { + return None; + } + String::from_utf8(self.payload.clone()).ok() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_simple_text_frame() { + // Simple unmasked text frame: "Hello" + let data = vec![ + 0b1000_0001, // FIN=1, RSV=0, Opcode=Text + 5, // Payload length=5 + b'H', + b'e', + b'l', + b'l', + b'o', + ]; + + let (frame, consumed) = WebSocketFrame::parse(&data).unwrap(); + assert_eq!(consumed, 7); + assert!(frame.fin); + assert_eq!(frame.opcode, WebSocketOpcode::Text); + assert_eq!(frame.payload, b"Hello"); + } + + #[test] + fn test_parse_masked_frame() { + // Masked text frame + let mask = [0x12, 0x34, 0x56, 0x78]; + let mut payload = b"Hello".to_vec(); + WebSocketFrame::apply_mask(&mut payload, &mask); + + let mut data = vec![ + 0b1000_0001, // FIN=1, RSV=0, Opcode=Text + 0b1000_0101, // MASK=1, Payload length=5 + ]; + data.extend_from_slice(&mask); + data.extend_from_slice(&payload); + + let (frame, consumed) = WebSocketFrame::parse(&data).unwrap(); + assert_eq!(consumed, 11); + assert!(frame.fin); + assert_eq!(frame.payload, b"Hello"); + } + + #[test] + fn test_encode_frame() { + let frame = WebSocketFrame::new_text("Hello".to_string(), true); + let encoded = frame.encode(None); + + let expected = vec![ + 0b1000_0001, // FIN=1, Opcode=Text + 5, // Payload length=5 + b'H', + b'e', + b'l', + b'l', + b'o', + ]; + assert_eq!(encoded, expected); + } + + #[test] + fn test_extended_length_16bit() { + let payload = vec![0u8; 200]; + let mut data = vec![ + 0b1000_0010, // FIN=1, Opcode=Binary + 126, // Extended 16-bit length indicator + 0x00, + 0xC8, // Length = 200 + ]; + data.extend_from_slice(&payload); + + let (frame, consumed) = WebSocketFrame::parse(&data).unwrap(); + assert_eq!(consumed, 204); + assert_eq!(frame.payload.len(), 200); + } + + #[test] + fn test_close_frame() { + let frame = WebSocketFrame::new_close(Some(1000), Some("Normal closure")); + let encoded = frame.encode(None); + + let (parsed, _) = WebSocketFrame::parse(&encoded).unwrap(); + let (code, reason) = parsed.parse_close_payload().unwrap(); + assert_eq!(code, 1000); + assert_eq!(reason, "Normal closure"); + } + + #[test] + fn test_control_frame_too_large() { + // Control frame with payload > 125 bytes + let data = vec![ + 0b1000_1000, // FIN=1, Opcode=Close + 126, // Extended length + 0x00, + 0x7F, // Length = 127 (> 125) + ]; + + let result = WebSocketFrame::parse(&data); + assert!(matches!(result, Err(WebSocketError::ControlFrameTooLarge))); + } + + #[test] + fn test_incomplete_frame() { + let data = vec![0b1000_0001]; // Only first byte + let result = WebSocketFrame::parse(&data); + assert!(matches!(result, Err(WebSocketError::IncompleteFrame))); + } +} diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs new file mode 100644 index 0000000..453ab10 --- /dev/null +++ b/src/websocket/mod.rs @@ -0,0 +1,12 @@ +//! WebSocket frame codec implementation conforming to RFC 6455. +//! +//! This module provides WebSocket frame parsing, encoding, and message assembly +//! for bidirectional WebSocket communication using tokio_util::codec. + +mod codec; +mod frame; +mod wrapper; + +pub use codec::WebSocketCodec; +pub use frame::{WebSocketError, WebSocketFrame, WebSocketOpcode}; +pub use wrapper::{WebSocketDecoder, WebSocketEncoder}; diff --git a/src/websocket/wrapper.rs b/src/websocket/wrapper.rs new file mode 100644 index 0000000..4965d2d --- /dev/null +++ b/src/websocket/wrapper.rs @@ -0,0 +1,306 @@ +//! WebSocket decoder and encoder wrappers that use WebSocketCodec. +//! +//! These types provide a clean API for JavaScript bindings while using +//! the WebSocketCodec for frame parsing and encoding. + +use super::{WebSocketCodec, WebSocketError, WebSocketFrame}; +use bytes::BytesMut; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::Mutex; +use tokio_util::codec::{Decoder, Encoder}; + +/// WebSocket message decoder that reads and assembles frames. +/// +/// Uses WebSocketCodec internally to handle frame parsing and message assembly. +pub struct WebSocketDecoder { + reader: R, + codec: WebSocketCodec, + buffer: BytesMut, +} + +impl WebSocketDecoder { + /// Create a new WebSocketDecoder from any AsyncRead type. + pub fn new(reader: R) -> Self { + WebSocketDecoder { + reader, + codec: WebSocketCodec::new(), + buffer: BytesMut::with_capacity(8192), + } + } + + /// Read the next WebSocket message. + /// + /// Returns `Ok(Some(frame))` if a complete frame was read, + /// `Ok(None)` if the stream ended, or `Err` on error. + pub async fn read_message(&mut self) -> Result, WebSocketError> { + loop { + // Try to decode a frame from the buffer + match self.codec.decode(&mut self.buffer)? { + Some(frame) => return Ok(Some(frame)), + None => { + // Need more data - read from stream + let mut temp_buf = vec![0u8; 8192]; + + match self.reader.read(&mut temp_buf).await { + Ok(0) => return Ok(None), // EOF + Ok(n) => { + self.buffer.extend_from_slice(&temp_buf[..n]); + // Loop to try decoding again + } + Err(e) => return Err(WebSocketError::IoError(e.to_string())), + } + } + } + } + } +} + +/// WebSocket message encoder that generates and writes frames. +/// +/// Uses WebSocketCodec internally to handle frame encoding. +pub struct WebSocketEncoder { + writer: Arc>, + codec: Mutex, +} + +impl WebSocketEncoder { + /// Create a new WebSocketEncoder from any AsyncWrite type. + pub fn new(writer: W) -> Self { + WebSocketEncoder { + writer: Arc::new(Mutex::new(writer)), + codec: Mutex::new(WebSocketCodec::new()), + } + } + + /// Write a text message. + pub async fn write_text(&self, text: &str, _masked: bool) -> Result<(), WebSocketError> { + let frame = WebSocketFrame::new_text(text.to_string(), true); + let mut buffer = BytesMut::new(); + + // Lock the codec to encode the frame + let mut codec = self.codec.lock().await; + codec.encode(frame, &mut buffer)?; + drop(codec); // Release lock early + + let mut writer = self.writer.lock().await; + writer + .write_all(&buffer) + .await + .map_err(|e| WebSocketError::IoError(e.to_string()))?; + + Ok(()) + } + + /// Write a binary message. + pub async fn write_binary(&self, data: &[u8], _masked: bool) -> Result<(), WebSocketError> { + let frame = WebSocketFrame::new_binary(data.to_vec(), true); + let mut buffer = BytesMut::new(); + + // Lock the codec to encode the frame + let mut codec = self.codec.lock().await; + codec.encode(frame, &mut buffer)?; + drop(codec); // Release lock early + + let mut writer = self.writer.lock().await; + writer + .write_all(&buffer) + .await + .map_err(|e| WebSocketError::IoError(e.to_string()))?; + + Ok(()) + } + + /// Send a close frame with optional code and reason, then close the stream. + pub async fn write_close( + &self, + code: Option, + reason: Option<&str>, + ) -> Result<(), WebSocketError> { + let frame = WebSocketFrame::new_close(code, reason); + let mut buffer = BytesMut::new(); + + // Lock the codec to encode the frame + let mut codec = self.codec.lock().await; + codec.encode(frame, &mut buffer)?; + drop(codec); // Release lock early + + let mut writer = self.writer.lock().await; + writer + .write_all(&buffer) + .await + .map_err(|e| WebSocketError::IoError(e.to_string()))?; + + // Shutdown the stream + writer + .shutdown() + .await + .map_err(|e| WebSocketError::IoError(e.to_string()))?; + + Ok(()) + } + + /// Close the encoder stream without sending a close frame. + pub async fn end(&self) -> Result<(), WebSocketError> { + let mut writer = self.writer.lock().await; + writer + .shutdown() + .await + .map_err(|e| WebSocketError::IoError(e.to_string()))?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::duplex; + + #[tokio::test] + async fn test_encoder_decoder_creation() { + let (client, server) = duplex(1024); + + let _encoder = WebSocketEncoder::new(client); + let _decoder = WebSocketDecoder::new(server); + } + + #[tokio::test] + async fn test_write_and_read_text_message() { + let (client, server) = duplex(1024); + + let encoder = WebSocketEncoder::new(client); + let mut decoder = WebSocketDecoder::new(server); + + // Write a text message + encoder.write_text("Hello WebSocket!", false).await.unwrap(); + + // Read it back + let frame = decoder.read_message().await.unwrap().unwrap(); + assert!(frame.is_text()); + assert_eq!(frame.payload_as_text().unwrap(), "Hello WebSocket!"); + } + + #[tokio::test] + async fn test_write_and_read_binary_message() { + let (client, server) = duplex(1024); + + let encoder = WebSocketEncoder::new(client); + let mut decoder = WebSocketDecoder::new(server); + + // Write binary data + let data = vec![0x01, 0x02, 0x03, 0x04]; + encoder.write_binary(&data, false).await.unwrap(); + + // Read it back + let frame = decoder.read_message().await.unwrap().unwrap(); + assert!(frame.is_binary()); + assert_eq!(frame.payload, data); + } + + #[tokio::test] + async fn test_write_close_shuts_down_stream() { + let (client, server) = duplex(1024); + + let encoder = WebSocketEncoder::new(client); + let mut decoder = WebSocketDecoder::new(server); + + // Send a close frame + encoder + .write_close(Some(1000), Some("Normal closure")) + .await + .unwrap(); + + // Read the close frame + let frame = decoder.read_message().await.unwrap().unwrap(); + assert!(frame.is_close()); + + // Try to read again - should get None (EOF) because stream was shut down + let eof = decoder.read_message().await.unwrap(); + assert!(eof.is_none(), "Expected EOF after close frame"); + + // Verify we can't write more (stream is closed) + let write_result = encoder.write_text("Should fail", false).await; + assert!(write_result.is_err(), "Write should fail after close"); + } + + #[tokio::test] + async fn test_end_shuts_down_stream_without_close_frame() { + let (client, server) = duplex(1024); + + let encoder = WebSocketEncoder::new(client); + let mut decoder = WebSocketDecoder::new(server); + + // Write a message first + encoder.write_text("Hello", false).await.unwrap(); + + // Read it + let frame = decoder.read_message().await.unwrap().unwrap(); + assert_eq!(frame.payload_as_text().unwrap(), "Hello"); + + // Call end() to close stream without sending close frame + encoder.end().await.unwrap(); + + // Should get EOF immediately (no close frame) + let eof = decoder.read_message().await.unwrap(); + assert!(eof.is_none(), "Expected EOF after end()"); + } + + #[tokio::test] + async fn test_multiple_messages_then_close() { + let (client, server) = duplex(2048); + + let encoder = WebSocketEncoder::new(client); + let mut decoder = WebSocketDecoder::new(server); + + // Send multiple messages + encoder.write_text("Message 1", false).await.unwrap(); + encoder.write_text("Message 2", false).await.unwrap(); + encoder.write_binary(&[1, 2, 3], false).await.unwrap(); + + // Read them back + let msg1 = decoder.read_message().await.unwrap().unwrap(); + assert_eq!(msg1.payload_as_text().unwrap(), "Message 1"); + + let msg2 = decoder.read_message().await.unwrap().unwrap(); + assert_eq!(msg2.payload_as_text().unwrap(), "Message 2"); + + let msg3 = decoder.read_message().await.unwrap().unwrap(); + assert_eq!(msg3.payload, vec![1, 2, 3]); + + // Now close + encoder.write_close(None, None).await.unwrap(); + + let close_frame = decoder.read_message().await.unwrap().unwrap(); + assert!(close_frame.is_close()); + + // EOF after close + assert!(decoder.read_message().await.unwrap().is_none()); + } + + #[tokio::test] + async fn test_close_cannot_be_called_twice() { + let (client, _server) = duplex(1024); + + let encoder = WebSocketEncoder::new(client); + + // First close should succeed + encoder.write_close(Some(1000), None).await.unwrap(); + + // Second close should fail (stream already shut down) + let result = encoder.write_close(Some(1000), None).await; + assert!(result.is_err(), "Second close should fail"); + } + + #[tokio::test] + async fn test_end_is_idempotent() { + let (client, _server) = duplex(1024); + + let encoder = WebSocketEncoder::new(client); + + // First end should succeed + encoder.end().await.unwrap(); + + // Second end should also succeed (shutdown is idempotent) + encoder.end().await.unwrap(); + } +} diff --git a/test/request.test.mjs b/test/request.test.mjs index 4c8bf10..5520c55 100644 --- a/test/request.test.mjs +++ b/test/request.test.mjs @@ -1,4 +1,4 @@ -import { ok, throws, doesNotThrow, deepStrictEqual, strictEqual } from 'node:assert/strict' +import { ok, throws, doesNotThrow, deepStrictEqual, strictEqual, rejects } from 'node:assert/strict' import { test } from 'node:test' import { Request } from '../index.js' @@ -133,9 +133,6 @@ test('Request', async t => { ok(request.body instanceof Buffer, 'should create Buffer instance for body') deepStrictEqual(request.body, body, 'should set the body correctly') - - request.body = Buffer.from('New Body') - deepStrictEqual(request.body, Buffer.from('New Body'), 'should update the body correctly') }) await t.test('toJSON', () => { @@ -153,4 +150,52 @@ test('Request', async t => { body: Buffer.from('Hello, World!') }, 'should convert to JSON correctly') }) + + await t.test('write() should error when body is already provided', async () => { + // Create a request with a body already provided + const request = new Request({ + method: 'POST', + url: 'https://example.com/test', + body: Buffer.from('initial body') + }) + + // Trying to write should throw an error + await rejects( + async () => { + await request.write(Buffer.from('more data')) + }, + { + message: 'Cannot write to request: body has already been provided' + } + ) + }) + + await t.test('end() should succeed silently when body is already provided', async () => { + // Create a request with a body already provided + const request = new Request({ + method: 'POST', + url: 'https://example.com/test', + body: Buffer.from('initial body') + }) + + // Trying to end should not throw (returns silently) + await doesNotThrow(async () => { + await request.end() + }, 'should not throw error when calling end() on request with existing body buffer') + }) + + await t.test('write() and end() should work when body is not provided', async () => { + // Create a request without a body + const request = new Request({ + method: 'POST', + url: 'https://example.com/test' + }) + + // These should not throw + await doesNotThrow(async () => { + await request.write(Buffer.from('chunk 1')) + await request.write(Buffer.from('chunk 2')) + await request.end() + }, 'should allow write() and end() when no body buffer is present') + }) }) diff --git a/test/response.test.mjs b/test/response.test.mjs index d6682ab..e660815 100644 --- a/test/response.test.mjs +++ b/test/response.test.mjs @@ -57,9 +57,6 @@ test('Response', async t => { ok(response.body instanceof Buffer, 'should create Buffer instance for body') strictEqual(response.body.toString('utf8'), 'Hello, World!', 'should set the body correctly') - - response.body = Buffer.from('New body content') - strictEqual(response.body.toString('utf8'), 'New body content', 'should update the body content correctly') }) await t.test('toJSON', () => {