Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 97 additions & 7 deletions rust/lance/src/io/exec/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use std::task::Poll;

use arrow::array::AsArray;
use arrow::compute::{TakeOptions, concat_batches};
Expand All @@ -27,6 +28,7 @@ use lance_arrow::RecordBatchExt;
use lance_core::datatypes::{Field, OnMissing, Projection};
use lance_core::error::{DataFusionResult, LanceOptionExt};
use lance_core::utils::address::RowAddress;
use lance_core::utils::futures::FinallyStreamExt;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use lance_core::{ROW_ADDR, ROW_ID};
use lance_io::scheduler::{ScanScheduler, SchedulerConfig};
Expand Down Expand Up @@ -353,19 +355,17 @@ impl TakeStream {
(None, None) => {}
}

self.metrics
.baseline_metrics
.record_output(new_data.num_rows());
self.metrics.batches_processed.add(1);
Ok(batch.merge_with_schema(&new_data, self.output_schema.as_ref())?)
}

fn apply<S: Stream<Item = Result<RecordBatch>> + Send + 'static>(
self: Arc<Self>,
input: S,
) -> impl Stream<Item = Result<RecordBatch>> {
let scan_scheduler = self.scan_scheduler.clone();
let metrics = self.metrics.clone();
let result_scan_scheduler = self.scan_scheduler.clone();
let final_scan_scheduler = self.scan_scheduler.clone();
let result_metrics = self.metrics.clone();
let final_metrics = self.metrics.clone();
let batches = input
.enumerate()
.map(move |(batch_index, batch)| {
Expand All @@ -378,8 +378,24 @@ impl TakeStream {
})
.boxed();
batches
.inspect_ok(move |_| metrics.io_metrics.record(&scan_scheduler))
.try_buffered(get_num_compute_intensive_cpus())
.map(move |result| {
if result.is_ok() {
result_metrics.batches_processed.add(1);
}
result_metrics.io_metrics.record(&result_scan_scheduler);
match result_metrics
.baseline_metrics
.record_poll(Poll::Ready(Some(result)))
{
Poll::Ready(Some(result)) => result,
_ => unreachable!("record_poll returned a different poll state"),
}
})
.finally(move || {
final_metrics.baseline_metrics.done();
final_metrics.io_metrics.record(&final_scan_scheduler);
})
}
}

Expand Down Expand Up @@ -839,6 +855,80 @@ mod tests {
}
}

#[tokio::test(flavor = "current_thread")]
async fn test_take_records_output_and_io_metrics() {
use datafusion::physical_plan::metrics::MetricValue;
use lance_datafusion::utils::{BYTES_READ_METRIC, IOPS_METRIC, REQUESTS_METRIC};
let TestFixture {
dataset,
_tmp_dir_guard,
} = test_fixture().await;

let row_addrs = UInt64Array::from(vec![0_u64, 1, 2, 3, 4]);
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
ROW_ADDR,
DataType::UInt64,
true,
)]));
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(row_addrs)]).unwrap();
let stream = futures::stream::iter(vec![Ok(batch)]);
let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream));
let input = Arc::new(OneShotExec::new(stream));

let projection = dataset
.empty_projection()
.union_column("s", OnMissing::Error)
.unwrap();

let take_exec = TakeExec::try_new(dataset, input, projection)
.unwrap()
.unwrap();

let stream = take_exec
.execute(0, Arc::new(TaskContext::default()))
.unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 5);

let metrics = take_exec.metrics().unwrap();

let output_batches: usize = metrics
.iter()
.filter_map(|m| match m.value() {
MetricValue::OutputBatches(count) => Some(count.value()),
_ => None,
})
.sum();

let output_bytes: usize = metrics
.iter()
.filter_map(|m| match m.value() {
MetricValue::OutputBytes(count) => Some(count.value()),
_ => None,
})
.sum();

let gauge = |name: &str| -> usize {
metrics
.iter_gauges()
.find_map(|(metric_name, gauge)| {
(metric_name.as_ref() == name).then(|| gauge.value())
})
.unwrap_or(0)
};

let bytes_read = gauge(BYTES_READ_METRIC);
let iops = gauge(IOPS_METRIC);
let requests = gauge(REQUESTS_METRIC);

assert_eq!(metrics.output_rows(), Some(5));
assert_eq!(metrics.find_count("batches_processed").unwrap().value(), 1);
assert!(
output_batches > 0 && output_bytes > 0 && bytes_read > 0 && iops > 0 && requests > 0,
"expected positive TakeExec metrics, got output_batches={output_batches}, output_bytes={output_bytes}, bytes_read={bytes_read}, iops={iops}, requests={requests}"
);
}

#[tokio::test]
async fn test_take_order() {
let TestFixture {
Expand Down
Loading