diff --git a/.github/workflows/ci-python.yml b/.github/workflows/ci-python.yml index 5bb1a39..a109a28 100644 --- a/.github/workflows/ci-python.yml +++ b/.github/workflows/ci-python.yml @@ -57,6 +57,10 @@ jobs: with: python-version: '3.12' + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - name: Copy root LICENSE for maturin run: cp ../LICENSE LICENSE diff --git a/python/Makefile b/python/Makefile index e7ec39d..d9d1849 100644 --- a/python/Makefile +++ b/python/Makefile @@ -67,6 +67,7 @@ fmt: $(VENV) -m black zerobus examples tests $(VENV) -m autoflake -ri --exclude '*_pb2*.py' zerobus examples tests $(VENV) -m isort zerobus examples tests + cd rust && cargo fmt --all lint: $(VENV) -m pycodestyle --exclude='*_pb2*.py' --max-line-length=120 --ignore=E203,W503 zerobus diff --git a/python/rust/src/arrow.rs b/python/rust/src/arrow.rs index 99fadf7..8d5b6be 100644 --- a/python/rust/src/arrow.rs +++ b/python/rust/src/arrow.rs @@ -11,31 +11,25 @@ use tokio::sync::RwLock; use databricks_zerobus_ingest_sdk::{ ArrowStreamConfigurationOptions as RustArrowStreamOptions, - ArrowTableProperties as RustArrowTableProperties, - ZerobusArrowStream as RustZerobusArrowStream, ZerobusError as RustError, - ZerobusSdk as RustSdk, + ArrowTableProperties as RustArrowTableProperties, ZerobusArrowStream as RustZerobusArrowStream, + ZerobusError as RustError, ZerobusSdk as RustSdk, }; use crate::auth::HeadersProviderWrapper; use crate::common::map_error; /// Deserialize Arrow IPC bytes into exactly one RecordBatch. 
-fn ipc_bytes_to_record_batch( - ipc_bytes: &[u8], -) -> Result { - let mut reader = - arrow_ipc::reader::StreamReader::try_new(ipc_bytes, None).map_err(|e| { - RustError::InvalidArgument(format!("Failed to parse Arrow IPC data: {}", e)) - })?; +fn ipc_bytes_to_record_batch(ipc_bytes: &[u8]) -> Result { + let mut reader = arrow_ipc::reader::StreamReader::try_new(ipc_bytes, None).map_err(|e| { + RustError::InvalidArgument(format!("Failed to parse Arrow IPC data: {}", e)) + })?; let batch = reader .next() .ok_or_else(|| { RustError::InvalidArgument("No batches found in Arrow IPC data".to_string()) })? - .map_err(|e| { - RustError::InvalidArgument(format!("Failed to read Arrow batch: {}", e)) - })?; + .map_err(|e| RustError::InvalidArgument(format!("Failed to read Arrow batch: {}", e)))?; if reader.next().is_some() { return Err(RustError::InvalidArgument( @@ -47,27 +41,18 @@ fn ipc_bytes_to_record_batch( } /// Serialize a RecordBatch to Arrow IPC bytes. -fn record_batch_to_ipc_bytes( - batch: &arrow_array::RecordBatch, -) -> Result, RustError> { +fn record_batch_to_ipc_bytes(batch: &arrow_array::RecordBatch) -> Result, RustError> { let mut buffer = Vec::new(); { - let mut writer = - arrow_ipc::writer::StreamWriter::try_new(&mut buffer, &batch.schema()) - .map_err(|e| { - RustError::InvalidArgument(format!( - "Failed to create Arrow IPC writer: {}", - e - )) - })?; + let mut writer = arrow_ipc::writer::StreamWriter::try_new(&mut buffer, &batch.schema()) + .map_err(|e| { + RustError::InvalidArgument(format!("Failed to create Arrow IPC writer: {}", e)) + })?; writer.write(batch).map_err(|e| { RustError::InvalidArgument(format!("Failed to write Arrow batch: {}", e)) })?; writer.finish().map_err(|e| { - RustError::InvalidArgument(format!( - "Failed to finish Arrow IPC stream: {}", - e - )) + RustError::InvalidArgument(format!("Failed to finish Arrow IPC stream: {}", e)) })?; } Ok(buffer) @@ -80,14 +65,13 @@ fn record_batch_to_ipc_bytes( fn ipc_schema_bytes_to_arrow_schema( 
schema_bytes: &[u8], ) -> Result { - let reader = arrow_ipc::reader::StreamReader::try_new(schema_bytes, None) - .map_err(|e| { - RustError::InvalidArgument(format!( - "Failed to parse Arrow IPC schema bytes: {}. \ + let reader = arrow_ipc::reader::StreamReader::try_new(schema_bytes, None).map_err(|e| { + RustError::InvalidArgument(format!( + "Failed to parse Arrow IPC schema bytes: {}. \ Pass bytes from pa.ipc.new_stream(sink, schema) with no batches written.", - e - )) - })?; + e + )) + })?; Ok(reader.schema().as_ref().clone()) } @@ -95,6 +79,32 @@ fn ipc_schema_bytes_to_arrow_schema( // ARROW STREAM CONFIGURATION OPTIONS // ============================================================================= +/// Arrow IPC compression codec. +#[pyclass] +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum IPCCompression { + /// No compression (default). + #[pyo3(name = "NONE")] + Uncompressed = 0, + /// LZ4 frame compression. + #[pyo3(name = "LZ4_FRAME")] + LZ4Frame = 1, + /// Zstandard compression. + #[pyo3(name = "ZSTD")] + Zstd = 2, +} + +#[pymethods] +impl IPCCompression { + fn __repr__(&self) -> &'static str { + match self { + IPCCompression::Uncompressed => "IPCCompression.NONE", + IPCCompression::LZ4Frame => "IPCCompression.LZ4_FRAME", + IPCCompression::Zstd => "IPCCompression.ZSTD", + } + } +} + /// Configuration options for Arrow Flight streams. #[pyclass] #[derive(Clone)] @@ -122,6 +132,10 @@ pub struct ArrowStreamConfigurationOptions { #[pyo3(get, set)] pub connection_timeout_ms: i64, + + /// IPC compression codec. 
Default: IPCCompression.NONE + #[pyo3(get, set)] + pub ipc_compression: IPCCompression, } impl Default for ArrowStreamConfigurationOptions { @@ -133,10 +147,10 @@ impl Default for ArrowStreamConfigurationOptions { recovery_timeout_ms: rust_default.recovery_timeout_ms as i64, recovery_backoff_ms: rust_default.recovery_backoff_ms as i64, recovery_retries: rust_default.recovery_retries as i32, - server_lack_of_ack_timeout_ms: rust_default.server_lack_of_ack_timeout_ms - as i64, + server_lack_of_ack_timeout_ms: rust_default.server_lack_of_ack_timeout_ms as i64, flush_timeout_ms: rust_default.flush_timeout_ms as i64, connection_timeout_ms: rust_default.connection_timeout_ms as i64, + ipc_compression: IPCCompression::Uncompressed, } } } @@ -152,23 +166,18 @@ impl ArrowStreamConfigurationOptions { for (key, value) in kwargs { let key_str: &str = key.extract()?; match key_str { - "max_inflight_batches" => { - options.max_inflight_batches = value.extract()? - } + "max_inflight_batches" => options.max_inflight_batches = value.extract()?, "recovery" => options.recovery = value.extract()?, - "recovery_timeout_ms" => { - options.recovery_timeout_ms = value.extract()? - } - "recovery_backoff_ms" => { - options.recovery_backoff_ms = value.extract()? - } + "recovery_timeout_ms" => options.recovery_timeout_ms = value.extract()?, + "recovery_backoff_ms" => options.recovery_backoff_ms = value.extract()?, "recovery_retries" => options.recovery_retries = value.extract()?, "server_lack_of_ack_timeout_ms" => { options.server_lack_of_ack_timeout_ms = value.extract()? } "flush_timeout_ms" => options.flush_timeout_ms = value.extract()?, - "connection_timeout_ms" => { - options.connection_timeout_ms = value.extract()? 
+ "connection_timeout_ms" => options.connection_timeout_ms = value.extract()?, + "ipc_compression" => { + options.ipc_compression = value.extract()?; } _ => { return Err(pyo3::exceptions::PyValueError::new_err(format!( @@ -187,7 +196,8 @@ impl ArrowStreamConfigurationOptions { format!( "ArrowStreamConfigurationOptions(max_inflight_batches={}, recovery={}, \ recovery_timeout_ms={}, recovery_backoff_ms={}, recovery_retries={}, \ - server_lack_of_ack_timeout_ms={}, flush_timeout_ms={}, connection_timeout_ms={})", + server_lack_of_ack_timeout_ms={}, flush_timeout_ms={}, connection_timeout_ms={}, \ + ipc_compression={})", self.max_inflight_batches, self.recovery, self.recovery_timeout_ms, @@ -196,6 +206,7 @@ impl ArrowStreamConfigurationOptions { self.server_lack_of_ack_timeout_ms, self.flush_timeout_ms, self.connection_timeout_ms, + self.ipc_compression.__repr__(), ) } } @@ -237,6 +248,11 @@ impl ArrowStreamConfigurationOptions { "connection_timeout_ms must be non-negative", )); } + let ipc_compression = match self.ipc_compression { + IPCCompression::Uncompressed => None, + IPCCompression::LZ4Frame => Some(arrow_ipc::CompressionType::LZ4_FRAME), + IPCCompression::Zstd => Some(arrow_ipc::CompressionType::ZSTD), + }; Ok(RustArrowStreamOptions { max_inflight_batches: self.max_inflight_batches as usize, recovery: self.recovery, @@ -246,7 +262,7 @@ impl ArrowStreamConfigurationOptions { server_lack_of_ack_timeout_ms: self.server_lack_of_ack_timeout_ms as u64, flush_timeout_ms: self.flush_timeout_ms as u64, connection_timeout_ms: self.connection_timeout_ms as u64, - ipc_compression: None, + ipc_compression, }) } } @@ -272,8 +288,7 @@ impl ZerobusArrowStream { // TODO(perf): eliminate double IPC serialization - Python-to-IPC-to-RecordBatch here, // then RecordBatch-to-IPC again inside the Rust SDK for Flight. Pass IPC bytes // directly to the SDK instead. 
- let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()) - .map_err(|e| map_error(e))?; + let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()).map_err(|e| map_error(e))?; let stream_clone = self.inner.clone(); let runtime = self.runtime.clone(); @@ -379,10 +394,9 @@ impl ZerobusArrowStream { Python::with_gil(|py| { let mut py_batches: Vec = Vec::with_capacity(batches.len()); for batch in &batches { - let ipc_bytes = record_batch_to_ipc_bytes(batch) - .map_err(|e| map_error(e))?; - py_batches - .push(PyBytes::new(py, &ipc_bytes).into()); + let ipc_bytes = + record_batch_to_ipc_bytes(batch).map_err(|e| map_error(e))?; + py_batches.push(PyBytes::new(py, &ipc_bytes).into()); } Ok(py_batches) }) @@ -404,13 +418,8 @@ pub struct AsyncZerobusArrowStream { #[pymethods] impl AsyncZerobusArrowStream { /// Ingest a single Arrow RecordBatch (as IPC bytes) and return the offset. - fn ingest_batch<'py>( - &self, - py: Python<'py>, - ipc_bytes: &PyBytes, - ) -> PyResult<&'py PyAny> { - let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()) - .map_err(|e| map_error(e))?; + fn ingest_batch<'py>(&self, py: Python<'py>, ipc_bytes: &PyBytes) -> PyResult<&'py PyAny> { + let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()).map_err(|e| map_error(e))?; let stream_clone = self.inner.clone(); @@ -424,11 +433,7 @@ impl AsyncZerobusArrowStream { } /// Wait for a specific offset to be acknowledged. 
- fn wait_for_offset<'py>( - &self, - py: Python<'py>, - offset: i64, - ) -> PyResult<&'py PyAny> { + fn wait_for_offset<'py>(&self, py: Python<'py>, offset: i64) -> PyResult<&'py PyAny> { let stream_clone = self.inner.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { @@ -505,8 +510,7 @@ impl AsyncZerobusArrowStream { Python::with_gil(|py| { let mut py_batches: Vec = Vec::with_capacity(batches.len()); for batch in &batches { - let ipc_bytes = record_batch_to_ipc_bytes(batch) - .map_err(|e| map_error(e))?; + let ipc_bytes = record_batch_to_ipc_bytes(batch).map_err(|e| map_error(e))?; py_batches.push(PyBytes::new(py, &ipc_bytes).into()); } Ok(py_batches) @@ -530,8 +534,7 @@ pub fn create_arrow_stream_sync( client_secret: String, options: Option<&ArrowStreamConfigurationOptions>, ) -> PyResult { - let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes) - .map_err(|e| map_error(e))?; + let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes).map_err(|e| map_error(e))?; let table_props = RustArrowTableProperties { table_name, @@ -569,8 +572,7 @@ pub fn create_arrow_stream_with_headers_provider_sync( headers_provider: PyObject, options: Option<&ArrowStreamConfigurationOptions>, ) -> PyResult { - let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes) - .map_err(|e| map_error(e))?; + let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes).map_err(|e| map_error(e))?; let table_props = RustArrowTableProperties { table_name, @@ -587,11 +589,7 @@ pub fn create_arrow_stream_with_headers_provider_sync( runtime_clone.block_on(async move { let sdk_guard = sdk_clone.read().await; sdk_guard - .create_arrow_stream_with_headers_provider( - table_props, - provider, - rust_options, - ) + .create_arrow_stream_with_headers_provider(table_props, provider, rust_options) .await .map_err(|e| Python::with_gil(|_py| map_error(e))) }) @@ -641,8 +639,7 @@ pub fn create_arrow_stream_async<'py>( client_secret: String, options: Option, ) -> PyResult<&'py 
PyAny> { - let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes) - .map_err(|e| map_error(e))?; + let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes).map_err(|e| map_error(e))?; let table_props = RustArrowTableProperties { table_name, @@ -674,8 +671,7 @@ pub fn create_arrow_stream_with_headers_provider_async<'py>( headers_provider: PyObject, options: Option, ) -> PyResult<&'py PyAny> { - let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes) - .map_err(|e| map_error(e))?; + let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes).map_err(|e| map_error(e))?; let table_props = RustArrowTableProperties { table_name, @@ -689,11 +685,7 @@ pub fn create_arrow_stream_with_headers_provider_async<'py>( pyo3_asyncio::tokio::future_into_py(py, async move { let sdk_guard = sdk_clone.read().await; let stream = sdk_guard - .create_arrow_stream_with_headers_provider( - table_props, - provider, - rust_options, - ) + .create_arrow_stream_with_headers_provider(table_props, provider, rust_options) .await .map_err(|e| Python::with_gil(|_py| map_error(e)))?; diff --git a/python/rust/src/async_wrapper.rs b/python/rust/src/async_wrapper.rs index 7a4a6e4..e4b719e 100644 --- a/python/rust/src/async_wrapper.rs +++ b/python/rust/src/async_wrapper.rs @@ -570,7 +570,9 @@ fn convert_stream_options( 2 => RustRecordType::Json, _ => RustRecordType::Proto, }, - stream_paused_max_wait_time_ms: opts.stream_paused_max_wait_time_ms.map(|v| v as u64), + stream_paused_max_wait_time_ms: opts + .stream_paused_max_wait_time_ms + .map(|v| v as u64), callback_max_wait_time_ms: opts.callback_max_wait_time_ms.map(|v| v as u64), ack_callback, ..Default::default() diff --git a/python/rust/src/lib.rs b/python/rust/src/lib.rs index b8f2eb2..c203c6b 100644 --- a/python/rust/src/lib.rs +++ b/python/rust/src/lib.rs @@ -48,6 +48,7 @@ fn _zerobus_core(py: Python, m: &PyModule) -> PyResult<()> { // Add arrow submodule let arrow_module = PyModule::new(py, "arrow")?; + 
arrow_module.add_class::()?; arrow_module.add_class::()?; arrow_module.add_class::()?; arrow_module.add_class::()?; diff --git a/python/tests/test_arrow.py b/python/tests/test_arrow.py index 80c59d7..ae6d547 100644 --- a/python/tests/test_arrow.py +++ b/python/tests/test_arrow.py @@ -12,6 +12,7 @@ from zerobus.sdk.shared.arrow import ( ArrowStreamConfigurationOptions, + IPCCompression, _check_pyarrow, _deserialize_batch, _serialize_batch, @@ -176,6 +177,7 @@ def test_default_construction(self): self.assertIsInstance(options.server_lack_of_ack_timeout_ms, int) self.assertIsInstance(options.flush_timeout_ms, int) self.assertIsInstance(options.connection_timeout_ms, int) + self.assertEqual(options.ipc_compression, IPCCompression.NONE) def test_kwargs_construction(self): options = ArrowStreamConfigurationOptions( @@ -218,12 +220,41 @@ def test_setters(self): self.assertEqual(options.max_inflight_batches, 99) self.assertFalse(options.recovery) + def test_ipc_compression_lz4(self): + options = ArrowStreamConfigurationOptions(ipc_compression=IPCCompression.LZ4_FRAME) + self.assertEqual(options.ipc_compression, IPCCompression.LZ4_FRAME) + + def test_ipc_compression_zstd(self): + options = ArrowStreamConfigurationOptions(ipc_compression=IPCCompression.ZSTD) + self.assertEqual(options.ipc_compression, IPCCompression.ZSTD) + + def test_ipc_compression_none_explicit(self): + options = ArrowStreamConfigurationOptions(ipc_compression=IPCCompression.NONE) + self.assertEqual(options.ipc_compression, IPCCompression.NONE) + + def test_ipc_compression_default_is_none(self): + options = ArrowStreamConfigurationOptions() + self.assertEqual(options.ipc_compression, IPCCompression.NONE) + + def test_ipc_compression_setter(self): + options = ArrowStreamConfigurationOptions() + options.ipc_compression = IPCCompression.ZSTD + self.assertEqual(options.ipc_compression, IPCCompression.ZSTD) + options.ipc_compression = IPCCompression.NONE + self.assertEqual(options.ipc_compression, 
IPCCompression.NONE) + + def test_ipc_compression_invalid_type_rejected(self): + """Invalid types are rejected at construction time.""" + with self.assertRaises(TypeError): + ArrowStreamConfigurationOptions(ipc_compression="lz4_frame") + def test_repr(self): options = ArrowStreamConfigurationOptions() repr_str = repr(options) self.assertIn("ArrowStreamConfigurationOptions", repr_str) self.assertIn("max_inflight_batches", repr_str) self.assertIn("recovery", repr_str) + self.assertIn("ipc_compression", repr_str) class TestSerializeBatchEmptyRecordBatch(unittest.TestCase): diff --git a/python/zerobus/__init__.py b/python/zerobus/__init__.py index e3499e6..430265b 100644 --- a/python/zerobus/__init__.py +++ b/python/zerobus/__init__.py @@ -45,7 +45,7 @@ # Import from Rust core import zerobus._zerobus_core as _core -from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions +from zerobus.sdk.shared.arrow import ArrowStreamConfigurationOptions, IPCCompression from zerobus.sdk.sync import ZerobusArrowStream, ZerobusSdk, ZerobusStream __version__ = "1.1.0" @@ -67,6 +67,7 @@ # Arrow (experimental) "ZerobusArrowStream", "ArrowStreamConfigurationOptions", + "IPCCompression", "RecordAcknowledgment", # Common types "TableProperties", diff --git a/python/zerobus/_zerobus_core.pyi b/python/zerobus/_zerobus_core.pyi index 8b3be2e..8110880 100644 --- a/python/zerobus/_zerobus_core.pyi +++ b/python/zerobus/_zerobus_core.pyi @@ -516,6 +516,18 @@ class aio: class arrow: """Arrow Flight support submodule.""" + class IPCCompression: + """Arrow IPC compression codec.""" + + NONE: "IPCCompression" + """No compression (default).""" + + LZ4_FRAME: "IPCCompression" + """LZ4 frame compression.""" + + ZSTD: "IPCCompression" + """Zstandard compression.""" + class ArrowStreamConfigurationOptions: """Configuration options for Arrow Flight streams.""" @@ -543,6 +555,9 @@ class arrow: connection_timeout_ms: int """Connection establishment timeout in milliseconds. 
Default: 30000""" + ipc_compression: "arrow.IPCCompression" + """IPC compression codec. Default: IPCCompression.NONE""" + def __init__( self, *, @@ -554,6 +569,7 @@ server_lack_of_ack_timeout_ms: int = 60000, flush_timeout_ms: int = 300000, connection_timeout_ms: int = 30000, + ipc_compression: "arrow.IPCCompression" = ..., ) -> None: ... def __repr__(self) -> str: ... diff --git a/python/zerobus/sdk/shared/arrow.py b/python/zerobus/sdk/shared/arrow.py index 53d2c83..3af65f9 100644 --- a/python/zerobus/sdk/shared/arrow.py +++ b/python/zerobus/sdk/shared/arrow.py @@ -87,3 +87,4 @@ def _deserialize_batch(ipc_bytes): # Re-export configuration from Rust core ArrowStreamConfigurationOptions = _core.arrow.ArrowStreamConfigurationOptions +IPCCompression = _core.arrow.IPCCompression