diff --git a/arrow-flight/benches/common/mod.rs b/arrow-flight/benches/common/mod.rs index a55e1dd2f7f9..b716d3f31fd0 100644 --- a/arrow-flight/benches/common/mod.rs +++ b/arrow-flight/benches/common/mod.rs @@ -38,12 +38,10 @@ use tonic::{ pub type Builder = fn(usize) -> ArrayRef; -pub const TYPES: &[(&str, Builder)] = &[ - ("fixed", fixed), - ("nested", nested), - ("variable", variable), - ("dict", dict), -]; +pub const TYPES: &[(&str, Builder)] = + &[("fixed", fixed), ("nested", nested), ("variable", variable)]; + +pub const DICT_TYPES: &[(&str, Builder)] = &[("dict", dict)]; fn fixed(n: usize) -> ArrayRef { Arc::new(Int64Array::from_iter_values(0..n as i64)) diff --git a/arrow-flight/benches/flight.rs b/arrow-flight/benches/flight.rs index 4841e9dd9822..db03380bb005 100644 --- a/arrow-flight/benches/flight.rs +++ b/arrow-flight/benches/flight.rs @@ -16,16 +16,19 @@ // under the License. use arrow_array::RecordBatch; -use arrow_flight::{FlightClient, FlightData, encode::FlightDataEncoderBuilder}; +use arrow_flight::{ + FlightClient, FlightData, + encode::{DictionaryHandling, FlightDataEncoderBuilder}, +}; use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; use futures::TryStreamExt; use tonic::transport::Channel; mod common; -use common::{TYPES, build_batch, start_server}; +use common::{DICT_TYPES, TYPES, build_batch, start_server}; const ROWS: [usize; 2] = [8 * 1024, 64 * 1024]; -const COLS: [usize; 2] = [1, 8]; +const COLS: [usize; 3] = [1, 4, 8]; fn bench_encode(c: &mut Criterion) { let rt = tokio::runtime::Runtime::new().unwrap(); @@ -83,5 +86,55 @@ fn bench_roundtrip(c: &mut Criterion) { } } -criterion_group!(benches, bench_encode, bench_roundtrip); +fn bench_do_put_dictionary(c: &mut Criterion) { + let rt = tokio::runtime::Runtime::new().unwrap(); + let (channel, _) = rt.block_on(start_server()); + let mut g = c.benchmark_group("do_put_dictionary"); + + for &(name, build) in DICT_TYPES { + for &rows in &ROWS { + for &cols in &COLS { + let batch = build_batch(name, rows, cols, build); + g.throughput(Throughput::Bytes(batch.get_array_memory_size() as u64)); + + for (label, handling) in [ + ("hydrate", DictionaryHandling::Hydrate), + ("resend", DictionaryHandling::Resend), + ] { + let frames: Vec = rt + .block_on( + FlightDataEncoderBuilder::new() + .with_dictionary_handling(handling) + .build(futures::stream::iter([Ok(batch.clone())])) + .try_collect(), + ) + .unwrap(); + let id = BenchmarkId::new(format!("{name}/{label}"), format!("{rows}x{cols}")); + g.bench_function(id, |b| { + b.to_async(&rt).iter_batched( + || (FlightClient::new(channel.clone()), frames.clone()), + |(mut client, frames)| async move { + client + .do_put(futures::stream::iter(frames.into_iter().map(Ok))) + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + }, + criterion::BatchSize::SmallInput, + ); + }); + } + } + } + } +} + +criterion_group!( + benches, + bench_encode, + bench_roundtrip, + bench_do_put_dictionary +); criterion_main!(benches); diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 191da024136f..f5961ddb95a7 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -329,6 +329,7 @@ impl FlightDataEncoder { } /// Place the `FlightData` in the queue to send + #[inline] fn queue_message(&mut self, mut data: FlightData) { if let Some(descriptor) = self.descriptor.take() { data.flight_descriptor = Some(descriptor); @@ -336,13 +337,6 @@ impl FlightDataEncoder { self.queue.push_back(data); } - /// Place the `FlightData` in the queue to send - fn queue_messages(&mut self, datas: impl IntoIterator) { - for data in datas { - self.queue_message(data) - } - } - /// Encodes schema as a [`FlightData`] in self.queue. /// Updates `self.schema` and returns the new schema fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef { @@ -381,8 +375,9 @@ impl FlightDataEncoder { for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) { let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?; - - self.queue_messages(flight_dictionaries); + for dict in flight_dictionaries { + self.queue_message(dict); + } self.queue_message(flight_batch); } @@ -671,7 +666,7 @@ fn prepare_schema_for_flight( fn split_batch_for_grpc_response( batch: RecordBatch, max_flight_data_size: usize, -) -> Vec { +) -> impl Iterator { let size = batch .columns() .iter() @@ -680,18 +675,20 @@ fn split_batch_for_grpc_response( let n_batches = (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1); - let rows_per_batch = (batch.num_rows() / n_batches).max(1); - let mut out = Vec::with_capacity(n_batches + 1); - + let num_rows = batch.num_rows(); + let rows_per_batch = (num_rows / n_batches).max(1); let mut offset = 0; - while offset < batch.num_rows() { - let length = (rows_per_batch).min(batch.num_rows() - offset); - out.push(batch.slice(offset, length)); - offset += length; - } - - out + std::iter::from_fn(move || { + if offset < num_rows { + let length = rows_per_batch.min(num_rows - offset); + let slice = batch.slice(offset, length); + offset += length; + Some(slice) + } else { + None + } + }) } /// The data needed to encode a stream of flight data, holding on to @@ -724,7 +721,10 @@ impl FlightIpcEncoder { /// Convert a `RecordBatch` to a Vec of `FlightData` representing /// dictionaries and a `FlightData` representing the batch - fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec, FlightData)> { + fn encode_batch( + &mut self, + batch: &RecordBatch, + ) -> Result<(impl Iterator + use<>, FlightData)> { let (encoded_dictionaries, encoded_batch) = self.data_gen.encode( batch, &mut self.dictionary_tracker, @@ -732,7 +732,7 @@ impl FlightIpcEncoder { &mut self.compression_context, )?; - let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_dictionaries = encoded_dictionaries.into_iter().map(|e| e.into()); let flight_batch = encoded_batch.into(); Ok((flight_dictionaries, flight_batch)) @@ -1838,7 +1838,8 @@ mod tests { let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) .expect("cannot create record batch"); - let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size); + let split: Vec<_> = + split_batch_for_grpc_response(batch.clone(), max_flight_data_size).collect(); assert_eq!(split.len(), 1); assert_eq!(batch, split[0]); @@ -1848,7 +1849,8 @@ mod tests { let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::>()); let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) .expect("cannot create record batch"); - let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size); + let split: Vec<_> = + split_batch_for_grpc_response(batch.clone(), max_flight_data_size).collect(); assert_eq!(split.len(), 3); assert_eq!( split.iter().map(|batch| batch.num_rows()).sum::(), @@ -1892,7 +1894,8 @@ mod tests { let input_rows = batch.num_rows(); - let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes); + let split: Vec<_> = + split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes).collect(); let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect(); let output_rows: usize = sizes.iter().sum();