Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ jobs:
with:
python-version: '3.12'

- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt

- name: Copy root LICENSE for maturin
run: cp ../LICENSE LICENSE

Expand Down
1 change: 1 addition & 0 deletions python/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ fmt:
$(VENV) -m black zerobus examples tests
$(VENV) -m autoflake -ri --exclude '*_pb2*.py' zerobus examples tests
$(VENV) -m isort zerobus examples tests
cd rust && cargo fmt --all

lint:
$(VENV) -m pycodestyle --exclude='*_pb2*.py' --max-line-length=120 --ignore=E203,W503 zerobus
Expand Down
168 changes: 80 additions & 88 deletions python/rust/src/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,25 @@ use tokio::sync::RwLock;

use databricks_zerobus_ingest_sdk::{
ArrowStreamConfigurationOptions as RustArrowStreamOptions,
ArrowTableProperties as RustArrowTableProperties,
ZerobusArrowStream as RustZerobusArrowStream, ZerobusError as RustError,
ZerobusSdk as RustSdk,
ArrowTableProperties as RustArrowTableProperties, ZerobusArrowStream as RustZerobusArrowStream,
ZerobusError as RustError, ZerobusSdk as RustSdk,
};

use crate::auth::HeadersProviderWrapper;
use crate::common::map_error;

/// Deserialize Arrow IPC bytes into exactly one RecordBatch.
fn ipc_bytes_to_record_batch(
ipc_bytes: &[u8],
) -> Result<arrow_array::RecordBatch, RustError> {
let mut reader =
arrow_ipc::reader::StreamReader::try_new(ipc_bytes, None).map_err(|e| {
RustError::InvalidArgument(format!("Failed to parse Arrow IPC data: {}", e))
})?;
fn ipc_bytes_to_record_batch(ipc_bytes: &[u8]) -> Result<arrow_array::RecordBatch, RustError> {
let mut reader = arrow_ipc::reader::StreamReader::try_new(ipc_bytes, None).map_err(|e| {
RustError::InvalidArgument(format!("Failed to parse Arrow IPC data: {}", e))
})?;
Comment on lines +22 to +25
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see a lot of these formatting changes. Were they causing some fmt errors? And if they did, how come we didn't see them before?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch, we previously didn't run formatting of python/rust/*.rs files and we didn't check their formatting in the CI. I added both.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks!


let batch = reader
.next()
.ok_or_else(|| {
RustError::InvalidArgument("No batches found in Arrow IPC data".to_string())
})?
.map_err(|e| {
RustError::InvalidArgument(format!("Failed to read Arrow batch: {}", e))
})?;
.map_err(|e| RustError::InvalidArgument(format!("Failed to read Arrow batch: {}", e)))?;

if reader.next().is_some() {
return Err(RustError::InvalidArgument(
Expand All @@ -47,27 +41,18 @@ fn ipc_bytes_to_record_batch(
}

/// Serialize a RecordBatch to Arrow IPC bytes.
fn record_batch_to_ipc_bytes(
batch: &arrow_array::RecordBatch,
) -> Result<Vec<u8>, RustError> {
fn record_batch_to_ipc_bytes(batch: &arrow_array::RecordBatch) -> Result<Vec<u8>, RustError> {
let mut buffer = Vec::new();
{
let mut writer =
arrow_ipc::writer::StreamWriter::try_new(&mut buffer, &batch.schema())
.map_err(|e| {
RustError::InvalidArgument(format!(
"Failed to create Arrow IPC writer: {}",
e
))
})?;
let mut writer = arrow_ipc::writer::StreamWriter::try_new(&mut buffer, &batch.schema())
.map_err(|e| {
RustError::InvalidArgument(format!("Failed to create Arrow IPC writer: {}", e))
})?;
writer.write(batch).map_err(|e| {
RustError::InvalidArgument(format!("Failed to write Arrow batch: {}", e))
})?;
writer.finish().map_err(|e| {
RustError::InvalidArgument(format!(
"Failed to finish Arrow IPC stream: {}",
e
))
RustError::InvalidArgument(format!("Failed to finish Arrow IPC stream: {}", e))
})?;
}
Ok(buffer)
Expand All @@ -80,21 +65,46 @@ fn record_batch_to_ipc_bytes(
fn ipc_schema_bytes_to_arrow_schema(
schema_bytes: &[u8],
) -> Result<arrow_schema::Schema, RustError> {
let reader = arrow_ipc::reader::StreamReader::try_new(schema_bytes, None)
.map_err(|e| {
RustError::InvalidArgument(format!(
"Failed to parse Arrow IPC schema bytes: {}. \
let reader = arrow_ipc::reader::StreamReader::try_new(schema_bytes, None).map_err(|e| {
RustError::InvalidArgument(format!(
"Failed to parse Arrow IPC schema bytes: {}. \
Pass bytes from pa.ipc.new_stream(sink, schema) with no batches written.",
e
))
})?;
e
))
})?;
Ok(reader.schema().as_ref().clone())
}

// =============================================================================
// ARROW STREAM CONFIGURATION OPTIONS
// =============================================================================

/// Arrow IPC compression codec.
#[pyclass]
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum IPCCompression {
/// No compression (default).
#[pyo3(name = "NONE")]
Uncompressed = 0,
/// LZ4 frame compression.
#[pyo3(name = "LZ4_FRAME")]
LZ4Frame = 1,
/// Zstandard compression.
#[pyo3(name = "ZSTD")]
Zstd = 2,
}

#[pymethods]
impl IPCCompression {
fn __repr__(&self) -> &'static str {
match self {
IPCCompression::Uncompressed => "IPCCompression.NONE",
IPCCompression::LZ4Frame => "IPCCompression.LZ4_FRAME",
IPCCompression::Zstd => "IPCCompression.ZSTD",
}
}
}

/// Configuration options for Arrow Flight streams.
#[pyclass]
#[derive(Clone)]
Expand Down Expand Up @@ -122,6 +132,10 @@ pub struct ArrowStreamConfigurationOptions {

#[pyo3(get, set)]
pub connection_timeout_ms: i64,

/// IPC compression codec. Default: IPCCompression.NONE
#[pyo3(get, set)]
pub ipc_compression: IPCCompression,
}

impl Default for ArrowStreamConfigurationOptions {
Expand All @@ -133,10 +147,10 @@ impl Default for ArrowStreamConfigurationOptions {
recovery_timeout_ms: rust_default.recovery_timeout_ms as i64,
recovery_backoff_ms: rust_default.recovery_backoff_ms as i64,
recovery_retries: rust_default.recovery_retries as i32,
server_lack_of_ack_timeout_ms: rust_default.server_lack_of_ack_timeout_ms
as i64,
server_lack_of_ack_timeout_ms: rust_default.server_lack_of_ack_timeout_ms as i64,
flush_timeout_ms: rust_default.flush_timeout_ms as i64,
connection_timeout_ms: rust_default.connection_timeout_ms as i64,
ipc_compression: IPCCompression::Uncompressed,
}
}
}
Expand All @@ -152,23 +166,18 @@ impl ArrowStreamConfigurationOptions {
for (key, value) in kwargs {
let key_str: &str = key.extract()?;
match key_str {
"max_inflight_batches" => {
options.max_inflight_batches = value.extract()?
}
"max_inflight_batches" => options.max_inflight_batches = value.extract()?,
"recovery" => options.recovery = value.extract()?,
"recovery_timeout_ms" => {
options.recovery_timeout_ms = value.extract()?
}
"recovery_backoff_ms" => {
options.recovery_backoff_ms = value.extract()?
}
"recovery_timeout_ms" => options.recovery_timeout_ms = value.extract()?,
"recovery_backoff_ms" => options.recovery_backoff_ms = value.extract()?,
"recovery_retries" => options.recovery_retries = value.extract()?,
"server_lack_of_ack_timeout_ms" => {
options.server_lack_of_ack_timeout_ms = value.extract()?
}
"flush_timeout_ms" => options.flush_timeout_ms = value.extract()?,
"connection_timeout_ms" => {
options.connection_timeout_ms = value.extract()?
"connection_timeout_ms" => options.connection_timeout_ms = value.extract()?,
"ipc_compression" => {
options.ipc_compression = value.extract()?;
}
_ => {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
Expand All @@ -187,7 +196,8 @@ impl ArrowStreamConfigurationOptions {
format!(
"ArrowStreamConfigurationOptions(max_inflight_batches={}, recovery={}, \
recovery_timeout_ms={}, recovery_backoff_ms={}, recovery_retries={}, \
server_lack_of_ack_timeout_ms={}, flush_timeout_ms={}, connection_timeout_ms={})",
server_lack_of_ack_timeout_ms={}, flush_timeout_ms={}, connection_timeout_ms={}, \
ipc_compression={})",
self.max_inflight_batches,
self.recovery,
self.recovery_timeout_ms,
Expand All @@ -196,6 +206,7 @@ impl ArrowStreamConfigurationOptions {
self.server_lack_of_ack_timeout_ms,
self.flush_timeout_ms,
self.connection_timeout_ms,
self.ipc_compression.__repr__(),
)
}
}
Expand Down Expand Up @@ -237,6 +248,11 @@ impl ArrowStreamConfigurationOptions {
"connection_timeout_ms must be non-negative",
));
}
let ipc_compression = match self.ipc_compression {
IPCCompression::Uncompressed => None,
IPCCompression::LZ4Frame => Some(arrow_ipc::CompressionType::LZ4_FRAME),
IPCCompression::Zstd => Some(arrow_ipc::CompressionType::ZSTD),
};
Ok(RustArrowStreamOptions {
max_inflight_batches: self.max_inflight_batches as usize,
recovery: self.recovery,
Expand All @@ -246,7 +262,7 @@ impl ArrowStreamConfigurationOptions {
server_lack_of_ack_timeout_ms: self.server_lack_of_ack_timeout_ms as u64,
flush_timeout_ms: self.flush_timeout_ms as u64,
connection_timeout_ms: self.connection_timeout_ms as u64,
ipc_compression: None,
ipc_compression,
})
}
}
Expand All @@ -272,8 +288,7 @@ impl ZerobusArrowStream {
// TODO(perf): eliminate double IPC serialization - Python-to-IPC-to-RecordBatch here,
// then RecordBatch-to-IPC again inside the Rust SDK for Flight. Pass IPC bytes
// directly to the SDK instead.
let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes())
.map_err(|e| map_error(e))?;
let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()).map_err(|e| map_error(e))?;

let stream_clone = self.inner.clone();
let runtime = self.runtime.clone();
Expand Down Expand Up @@ -379,10 +394,9 @@ impl ZerobusArrowStream {
Python::with_gil(|py| {
let mut py_batches: Vec<PyObject> = Vec::with_capacity(batches.len());
for batch in &batches {
let ipc_bytes = record_batch_to_ipc_bytes(batch)
.map_err(|e| map_error(e))?;
py_batches
.push(PyBytes::new(py, &ipc_bytes).into());
let ipc_bytes =
record_batch_to_ipc_bytes(batch).map_err(|e| map_error(e))?;
py_batches.push(PyBytes::new(py, &ipc_bytes).into());
}
Ok(py_batches)
})
Expand All @@ -404,13 +418,8 @@ pub struct AsyncZerobusArrowStream {
#[pymethods]
impl AsyncZerobusArrowStream {
/// Ingest a single Arrow RecordBatch (as IPC bytes) and return the offset.
fn ingest_batch<'py>(
&self,
py: Python<'py>,
ipc_bytes: &PyBytes,
) -> PyResult<&'py PyAny> {
let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes())
.map_err(|e| map_error(e))?;
fn ingest_batch<'py>(&self, py: Python<'py>, ipc_bytes: &PyBytes) -> PyResult<&'py PyAny> {
let batch = ipc_bytes_to_record_batch(ipc_bytes.as_bytes()).map_err(|e| map_error(e))?;

let stream_clone = self.inner.clone();

Expand All @@ -424,11 +433,7 @@ impl AsyncZerobusArrowStream {
}

/// Wait for a specific offset to be acknowledged.
fn wait_for_offset<'py>(
&self,
py: Python<'py>,
offset: i64,
) -> PyResult<&'py PyAny> {
fn wait_for_offset<'py>(&self, py: Python<'py>, offset: i64) -> PyResult<&'py PyAny> {
let stream_clone = self.inner.clone();

pyo3_asyncio::tokio::future_into_py(py, async move {
Expand Down Expand Up @@ -505,8 +510,7 @@ impl AsyncZerobusArrowStream {
Python::with_gil(|py| {
let mut py_batches: Vec<PyObject> = Vec::with_capacity(batches.len());
for batch in &batches {
let ipc_bytes = record_batch_to_ipc_bytes(batch)
.map_err(|e| map_error(e))?;
let ipc_bytes = record_batch_to_ipc_bytes(batch).map_err(|e| map_error(e))?;
py_batches.push(PyBytes::new(py, &ipc_bytes).into());
}
Ok(py_batches)
Expand All @@ -530,8 +534,7 @@ pub fn create_arrow_stream_sync(
client_secret: String,
options: Option<&ArrowStreamConfigurationOptions>,
) -> PyResult<ZerobusArrowStream> {
let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes)
.map_err(|e| map_error(e))?;
let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes).map_err(|e| map_error(e))?;

let table_props = RustArrowTableProperties {
table_name,
Expand Down Expand Up @@ -569,8 +572,7 @@ pub fn create_arrow_stream_with_headers_provider_sync(
headers_provider: PyObject,
options: Option<&ArrowStreamConfigurationOptions>,
) -> PyResult<ZerobusArrowStream> {
let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes)
.map_err(|e| map_error(e))?;
let schema = ipc_schema_bytes_to_arrow_schema(schema_ipc_bytes).map_err(|e| map_error(e))?;

let table_props = RustArrowTableProperties {
table_name,
Expand All @@ -587,11 +589,7 @@ pub fn create_arrow_stream_with_headers_provider_sync(
runtime_clone.block_on(async move {
let sdk_guard = sdk_clone.read().await;
sdk_guard
.create_arrow_stream_with_headers_provider(
table_props,
provider,
rust_options,
)
.create_arrow_stream_with_headers_provider(table_props, provider, rust_options)
.await
.map_err(|e| Python::with_gil(|_py| map_error(e)))
})
Expand Down Expand Up @@ -641,8 +639,7 @@ pub fn create_arrow_stream_async<'py>(
client_secret: String,
options: Option<ArrowStreamConfigurationOptions>,
) -> PyResult<&'py PyAny> {
let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes)
.map_err(|e| map_error(e))?;
let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes).map_err(|e| map_error(e))?;

let table_props = RustArrowTableProperties {
table_name,
Expand Down Expand Up @@ -674,8 +671,7 @@ pub fn create_arrow_stream_with_headers_provider_async<'py>(
headers_provider: PyObject,
options: Option<ArrowStreamConfigurationOptions>,
) -> PyResult<&'py PyAny> {
let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes)
.map_err(|e| map_error(e))?;
let schema = ipc_schema_bytes_to_arrow_schema(&schema_ipc_bytes).map_err(|e| map_error(e))?;

let table_props = RustArrowTableProperties {
table_name,
Expand All @@ -689,11 +685,7 @@ pub fn create_arrow_stream_with_headers_provider_async<'py>(
pyo3_asyncio::tokio::future_into_py(py, async move {
let sdk_guard = sdk_clone.read().await;
let stream = sdk_guard
.create_arrow_stream_with_headers_provider(
table_props,
provider,
rust_options,
)
.create_arrow_stream_with_headers_provider(table_props, provider, rust_options)
.await
.map_err(|e| Python::with_gil(|_py| map_error(e)))?;

Expand Down
4 changes: 3 additions & 1 deletion python/rust/src/async_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,9 @@ fn convert_stream_options(
2 => RustRecordType::Json,
_ => RustRecordType::Proto,
},
stream_paused_max_wait_time_ms: opts.stream_paused_max_wait_time_ms.map(|v| v as u64),
stream_paused_max_wait_time_ms: opts
.stream_paused_max_wait_time_ms
.map(|v| v as u64),
callback_max_wait_time_ms: opts.callback_max_wait_time_ms.map(|v| v as u64),
ack_callback,
..Default::default()
Expand Down
1 change: 1 addition & 0 deletions python/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ fn _zerobus_core(py: Python, m: &PyModule) -> PyResult<()> {

// Add arrow submodule
let arrow_module = PyModule::new(py, "arrow")?;
arrow_module.add_class::<arrow::IPCCompression>()?;
arrow_module.add_class::<arrow::ArrowStreamConfigurationOptions>()?;
arrow_module.add_class::<arrow::ZerobusArrowStream>()?;
arrow_module.add_class::<arrow::AsyncZerobusArrowStream>()?;
Expand Down
Loading
Loading