Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arrow-flight/benches/flight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mod common;
use common::{TYPES, build_batch, start_server};

const ROWS: [usize; 2] = [8 * 1024, 64 * 1024];
const COLS: [usize; 2] = [1, 8];
const COLS: [usize; 4] = [1, 4, 8, 16];

fn bench_encode(c: &mut Criterion) {
let rt = tokio::runtime::Runtime::new().unwrap();
Expand Down
142 changes: 111 additions & 31 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@
// specific language governing permissions and limitations
// under the License.

use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
use std::{
collections::VecDeque,
fmt::Debug,
pin::Pin,
sync::{Arc, Mutex},
task::Poll,
};

use crate::{FlightData, FlightDescriptor, SchemaAsIpc, error::Result};

use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteContext, IpcWriteOptions};

use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
use bytes::Bytes;
Expand Down Expand Up @@ -263,6 +269,65 @@ impl FlightDataEncoderBuilder {
}
}

const DEFAULT_ARROW_DATA_CAPACITY: usize = GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES;

/// Pool of reusable `Vec<u8>` buffers for the IPC arrow data body.
///
/// The pool is a shared, mutex-protected list of empty buffers. Cloning an
/// `ArrowDataPool` clones the `Arc`, so every clone hands out and takes back
/// buffers from the same list.
///
/// The pool never caps how many buffers it retains: every released buffer is
/// kept, together with whatever capacity it had grown to.
#[derive(Clone, Debug)]
pub struct ArrowDataPool(Arc<Mutex<Vec<Vec<u8>>>>);

impl ArrowDataPool {
    /// Create a pool holding `n` empty buffers, each pre-allocated with
    /// `DEFAULT_ARROW_DATA_CAPACITY` bytes of capacity, for reuse when
    /// encoding IPC messages.
    pub fn new(n: usize) -> Self {
        let buffers = std::iter::repeat_with(|| Vec::with_capacity(DEFAULT_ARROW_DATA_CAPACITY))
            .take(n)
            .collect();
        Self(Arc::new(Mutex::new(buffers)))
    }

    /// Lock the shared buffer list.
    ///
    /// A poisoned lock is recovered instead of panicking. The protected state
    /// is only a list of cleared buffers, so it cannot be observed in a broken
    /// state, and `release` runs from a destructor where a panic during
    /// unwinding would abort the process.
    fn buffers(&self) -> std::sync::MutexGuard<'_, Vec<Vec<u8>>> {
        self.0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Take an empty buffer out of the pool, allocating a fresh one with the
    /// default capacity when the pool has none left.
    fn acquire(&self) -> Vec<u8> {
        // The guard is a temporary of this statement, so the lock is already
        // released before the fallback allocation below can run.
        let recycled = self.buffers().pop();
        let mut buf =
            recycled.unwrap_or_else(|| Vec::with_capacity(DEFAULT_ARROW_DATA_CAPACITY));
        buf.clear();
        buf
    }

    /// Hand a buffer back to the pool. Its contents are discarded, its
    /// allocation is kept for the next `acquire`.
    fn release(&self, mut buf: Vec<u8>) {
        // Clear before taking the lock to keep the critical section minimal.
        buf.clear();
        self.buffers().push(buf);
    }
}

impl Default for ArrowDataPool {
    /// A pool that starts out with a single pre-allocated buffer.
    fn default() -> Self {
        Self::new(1)
    }
}

/// Owner type handed to `Bytes::from_owner`: it keeps an encoded buffer alive
/// for as long as the resulting `Bytes` exists and gives the allocation back
/// to its pool once dropped.
pub(crate) struct PooledBuf {
    /// The `Vec<u8>` that encode() wrote into; this is the memory exposed
    /// through `AsRef<[u8]>`.
    data: Vec<u8>,
    /// Shared handle to the `ArrowDataPool` that receives `data` on drop.
    pool: ArrowDataPool,
}

impl AsRef<[u8]> for PooledBuf {
    /// View the held buffer as a byte slice.
    fn as_ref(&self) -> &[u8] {
        self.data.as_slice()
    }
}

impl Drop for PooledBuf {
    /// Move the buffer out (leaving an empty `Vec` behind) and return it to
    /// the pool so its allocation can be reused.
    fn drop(&mut self) {
        let returned = std::mem::take(&mut self.data);
        self.pool.release(returned);
    }
}

/// Stream that encodes a stream of record batches to flight data.
///
/// See [`FlightDataEncoderBuilder`] for details and example.
Expand Down Expand Up @@ -329,20 +394,14 @@ impl FlightDataEncoder {
}

/// Place the `FlightData` in the queue to send
#[inline]

@Rich-T-kid Rich-T-kid Jun 12, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compiler very likely could have inlined this, but I think it's worth adding this explicitly.

fn queue_message(&mut self, mut data: FlightData) {
    // `take()` leaves `self.descriptor` empty, so the descriptor is attached
    // to at most one message: the first one queued after it was set.
    let pending = self.descriptor.take();
    if pending.is_some() {
        data.flight_descriptor = pending;
    }
    // Append to the back of the outgoing queue.
    self.queue.push_back(data);
}

/// Place the `FlightData` in the queue to send
fn queue_messages(&mut self, datas: impl IntoIterator<Item = FlightData>) {
for data in datas {
self.queue_message(data)
}
}

/// Encodes schema as a [`FlightData`] in self.queue.
/// Updates `self.schema` and returns the new schema
fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef {
Expand Down Expand Up @@ -381,8 +440,9 @@ impl FlightDataEncoder {

for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?;

self.queue_messages(flight_dictionaries);
for dict in flight_dictionaries {
self.queue_message(dict);
}
self.queue_message(flight_batch);
}

Expand Down Expand Up @@ -671,7 +731,7 @@ fn prepare_schema_for_flight(
fn split_batch_for_grpc_response(
batch: RecordBatch,
max_flight_data_size: usize,
) -> Vec<RecordBatch> {
) -> impl Iterator<Item = RecordBatch> {
let size = batch
.columns()
.iter()
Expand All @@ -680,18 +740,20 @@ fn split_batch_for_grpc_response(

let n_batches =
(size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1);
let rows_per_batch = (batch.num_rows() / n_batches).max(1);
let mut out = Vec::with_capacity(n_batches + 1);

let num_rows = batch.num_rows();
let rows_per_batch = (num_rows / n_batches).max(1);
let mut offset = 0;
while offset < batch.num_rows() {
let length = (rows_per_batch).min(batch.num_rows() - offset);
out.push(batch.slice(offset, length));

offset += length;
}

out
std::iter::from_fn(move || {
if offset < num_rows {
let length = rows_per_batch.min(num_rows - offset);
let slice = batch.slice(offset, length);
offset += length;
Some(slice)
} else {
None
}
})
}

/// The data needed to encode a stream of flight data, holding on to
Expand All @@ -704,7 +766,8 @@ struct FlightIpcEncoder {
options: IpcWriteOptions,
data_gen: IpcDataGenerator,
dictionary_tracker: DictionaryTracker,
compression_context: CompressionContext,
compression_context: IpcWriteContext,
pool: ArrowDataPool,
}

impl FlightIpcEncoder {
Expand All @@ -713,7 +776,8 @@ impl FlightIpcEncoder {
options,
data_gen: IpcDataGenerator::default(),
dictionary_tracker: DictionaryTracker::new(error_on_replacement),
compression_context: CompressionContext::default(),
compression_context: IpcWriteContext::default(),
pool: ArrowDataPool::default(),
}
}

Expand All @@ -724,16 +788,29 @@ impl FlightIpcEncoder {

/// Convert a `RecordBatch` to a Vec of `FlightData` representing
/// dictionaries and a `FlightData` representing the batch
fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec<FlightData>, FlightData)> {
fn encode_batch(
&mut self,
batch: &RecordBatch,
) -> Result<(impl Iterator<Item = FlightData> + use<>, FlightData)> {
self.compression_context.scratch = self.pool.acquire();
let (encoded_dictionaries, encoded_batch) = self.data_gen.encode(
batch,
&mut self.dictionary_tracker,
&self.options,
&mut self.compression_context,
)?;

let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
let flight_batch = encoded_batch.into();
let flight_dictionaries = encoded_dictionaries.into_iter().map(|e| e.into());

let pooled = PooledBuf {
data: encoded_batch.arrow_data,
pool: self.pool.clone(),
};
let flight_batch = crate::FlightData {
data_header: encoded_batch.ipc_message.into(),
data_body: Bytes::from_owner(pooled),
..Default::default()
};

Ok((flight_dictionaries, flight_batch))
}
Expand Down Expand Up @@ -1813,7 +1890,7 @@ mod tests {
) -> (Vec<FlightData>, FlightData) {
let data_gen = IpcDataGenerator::default();
let mut dictionary_tracker = DictionaryTracker::new(false);
let mut compression_context = CompressionContext::default();
let mut compression_context = IpcWriteContext::default();

let (encoded_dictionaries, encoded_batch) = data_gen
.encode(
Expand All @@ -1838,7 +1915,8 @@ mod tests {
let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
.expect("cannot create record batch");
let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
let split: Vec<_> =
split_batch_for_grpc_response(batch.clone(), max_flight_data_size).collect();
assert_eq!(split.len(), 1);
assert_eq!(batch, split[0]);

Expand All @@ -1848,7 +1926,8 @@ mod tests {
let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::<Vec<_>>());
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
.expect("cannot create record batch");
let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
let split: Vec<_> =
split_batch_for_grpc_response(batch.clone(), max_flight_data_size).collect();
assert_eq!(split.len(), 3);
assert_eq!(
split.iter().map(|batch| batch.num_rows()).sum::<usize>(),
Expand Down Expand Up @@ -1892,7 +1971,8 @@ mod tests {

let input_rows = batch.num_rows();

let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes);
let split: Vec<_> =
split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes).collect();
let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect();
let output_rows: usize = sizes.iter().sum();

Expand Down
4 changes: 2 additions & 2 deletions arrow-flight/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::sync::Arc;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_buffer::Buffer;
use arrow_ipc::convert::fb_to_schema;
use arrow_ipc::writer::CompressionContext;
use arrow_ipc::writer::IpcWriteContext;
use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions};
use arrow_schema::{ArrowError, Schema, SchemaRef};

Expand Down Expand Up @@ -92,7 +92,7 @@ pub fn batches_to_flight_data(

let data_gen = writer::IpcDataGenerator::default();
let mut dictionary_tracker = writer::DictionaryTracker::new(false);
let mut compression_context = CompressionContext::default();
let mut compression_context = IpcWriteContext::default();

for batch in batches.iter() {
let (encoded_dictionaries, encoded_batch) = data_gen.encode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::{
datatypes::SchemaRef,
ipc::{
self, reader,
writer::{self, CompressionContext},
writer::{self, IpcWriteContext},
},
record_batch::RecordBatch,
};
Expand Down Expand Up @@ -95,7 +95,7 @@ async fn upload_data(

let mut original_data_iter = original_data.iter().enumerate();

let mut compression_context = CompressionContext::default();
let mut compression_context = IpcWriteContext::default();

if let Some((counter, first_batch)) = original_data_iter.next() {
let metadata = counter.to_string().into_bytes();
Expand Down Expand Up @@ -159,7 +159,7 @@ async fn send_batch(
batch: &RecordBatch,
options: &writer::IpcWriteOptions,
dictionary_tracker: &mut writer::DictionaryTracker,
compression_context: &mut CompressionContext,
compression_context: &mut IpcWriteContext,
) -> Result {
let data_gen = writer::IpcDataGenerator::default();

Expand Down
28 changes: 15 additions & 13 deletions arrow-ipc/src/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,24 @@
use crate::CompressionType;
use arrow_buffer::Buffer;
use arrow_schema::ArrowError;
use flatbuffers::FlatBufferBuilder;

const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
const LENGTH_OF_PREFIX_DATA: i64 = 8;

/// Additional context that may be needed for compression.
///
/// In the case of zstd, this will contain the zstd context, which can be reused between subsequent
/// compression calls to avoid the performance overhead of initialising a new context for every
/// compression.
/// - The flatbuffer builder (`fbb`) is reset and reused across calls.
/// - The zstd compressor (when enabled) is kept alive to avoid re-initialisation overhead.
#[derive(Default)]
pub struct CompressionContext {
pub struct IpcWriteContext {
#[cfg(feature = "zstd")]
compressor: Option<zstd::bulk::Compressor<'static>>,
pub(crate) fbb: FlatBufferBuilder<'static>,
/// Scratch buffer for the IPC arrow data body. When set by the caller before
/// encode(), the existing allocation is reused instead of creating a fresh Vec.
pub scratch: Vec<u8>,
}

impl CompressionContext {
impl IpcWriteContext {
#[cfg(feature = "zstd")]
fn zstd_compressor(&mut self) -> &mut zstd::bulk::Compressor<'static> {
self.compressor.get_or_insert_with(|| {
Expand All @@ -43,9 +45,9 @@ impl CompressionContext {
}
}

impl std::fmt::Debug for CompressionContext {
impl std::fmt::Debug for IpcWriteContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut ds = f.debug_struct("CompressionContext");
let mut ds = f.debug_struct("IpcWriteContext");

#[cfg(feature = "zstd")]
ds.field(
Expand Down Expand Up @@ -143,7 +145,7 @@ impl CompressionCodec {
&self,
input: &[u8],
output: &mut Vec<u8>,
context: &mut CompressionContext,
context: &mut IpcWriteContext,
) -> Result<usize, ArrowError> {
let uncompressed_data_len = input.len();
let original_output_len = output.len();
Expand Down Expand Up @@ -209,7 +211,7 @@ impl CompressionCodec {
&self,
input: &[u8],
output: &mut Vec<u8>,
context: &mut CompressionContext,
context: &mut IpcWriteContext,
) -> Result<(), ArrowError> {
match self {
CompressionCodec::Lz4Frame => compress_lz4(input, output),
Expand Down Expand Up @@ -278,7 +280,7 @@ fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, A
fn compress_zstd(
input: &[u8],
output: &mut Vec<u8>,
context: &mut CompressionContext,
context: &mut IpcWriteContext,
) -> Result<(), ArrowError> {
let result = context.zstd_compressor().compress(input)?;
output.extend_from_slice(&result);
Expand All @@ -290,7 +292,7 @@ fn compress_zstd(
fn compress_zstd(
_input: &[u8],
_output: &mut Vec<u8>,
_context: &mut CompressionContext,
_context: &mut IpcWriteContext,
) -> Result<(), ArrowError> {
Err(ArrowError::InvalidArgumentError(
"zstd IPC compression requires the zstd feature".to_string(),
Expand Down
Loading
Loading