Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions rust/lance-index/src/scalar/btree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ use crate::{metrics::NoOpMetricsCollector, scalar::registry::TrainingCriteria};
use crate::{pbold, scalar::btree::flat::FlatIndex};
use arrow_arith::numeric::add;
use arrow_array::{
Array, ArrayAccessor, ArrowNativeTypeOp, PrimitiveArray, RecordBatch, UInt32Array,
Array, ArrayAccessor, ArrowNativeTypeOp, BooleanArray, PrimitiveArray, RecordBatch,
UInt32Array,
cast::AsArray,
new_empty_array,
types::{
Expand Down Expand Up @@ -1856,7 +1857,8 @@ impl BTreeIndex {
..Default::default()
},
)?;
let merged_stream = chunk_concat_stream(unchunked, first.batch_size as usize);
let deduplicated = deduplicate_value_ordered_btree_rows(unchunked);
let merged_stream = chunk_concat_stream(deduplicated, first.batch_size as usize);

let files =
train_btree_index(merged_stream, dest_store, first.batch_size, None, None).await?;
Expand Down Expand Up @@ -1889,6 +1891,51 @@ fn filter_row_ids(
Box::pin(RecordBatchStreamAdapter::new(schema, filtered))
}

fn deduplicate_value_ordered_btree_rows(
stream: SendableRecordBatchStream,
) -> SendableRecordBatchStream {
let schema = stream.schema();
let deduplicated = stream::try_unfold(
(stream, None::<ScalarValue>, HashSet::<u64>::new()),
|(mut stream, mut current_value, mut row_ids_for_value)| async move {
loop {
let Some(batch) = stream.next().await.transpose()? else {
return Ok(None);
};

let values = batch.column_by_name(VALUE_COLUMN_NAME).expect_ok()?;
let row_ids = batch
.column_by_name(ROW_ID)
.expect_ok()?
.as_primitive::<UInt64Type>();
let mut mask = Vec::with_capacity(batch.num_rows());

for idx in 0..batch.num_rows() {
let value = ScalarValue::try_from_array(values, idx)?;
match current_value.as_ref() {
Some(current)
if OrderableScalarValue(current.clone())
.cmp(&OrderableScalarValue(value.clone()))
== Ordering::Equal => {}
_ => {
current_value = Some(value);
row_ids_for_value.clear();
}
}
mask.push(row_ids_for_value.insert(row_ids.value(idx)));
}

let mask = BooleanArray::from(mask);
let batch = arrow_select::filter::filter_record_batch(&batch, &mask)?;
if batch.num_rows() > 0 {
return Ok(Some((batch, (stream, current_value, row_ids_for_value))));
}
}
},
);
Box::pin(RecordBatchStreamAdapter::new(schema, deduplicated))
}

fn wrap_bound(bound: &Bound<ScalarValue>) -> Bound<OrderableScalarValue> {
match bound {
Bound::Unbounded => Bound::Unbounded,
Expand Down
57 changes: 57 additions & 0 deletions rust/lance/src/dataset/tests/dataset_merge_update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,63 @@ async fn test_merge_insert_with_reordered_columns_and_index() {
final_dataset.validate().await.unwrap();
}

#[tokio::test]
async fn test_btree_merge_deduplicates_row_addrs() {
// This public table flow creates an old BTree segment and a delta segment
// for the same row address. Merging them should not leave duplicate row
// addresses in the final flat page.
let batch = arrow_array::record_batch!(("id", Int32, [1]), ("payload", Int32, [10])).unwrap();
let test_uri = TempStrDir::default();
let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
let mut dataset = Dataset::write(reader, &test_uri, None).await.unwrap();

dataset
.create_index(
&["id"],
IndexType::BTree,
None,
&ScalarIndexParams::default(),
false,
)
.await
.unwrap();

let source_batch =
arrow_array::record_batch!(("payload", Int32, [100]), ("id", Int32, [1])).unwrap();
let merge_job = MergeInsertBuilder::try_new(Arc::new(dataset.clone()), vec!["id".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.try_build()
.unwrap();
let reader = Box::new(RecordBatchIterator::new(
vec![Ok(source_batch.clone())],
source_batch.schema(),
));
let (updated_dataset, _) = merge_job.execute(reader_to_stream(reader)).await.unwrap();
let mut dataset = updated_dataset.as_ref().clone();

dataset
.optimize_indices(&OptimizeOptions::append())
.await
.unwrap();
dataset
.optimize_indices(&OptimizeOptions::merge(2))
.await
.unwrap();

let actual = dataset
.scan()
.filter("id = 1")
.unwrap()
.try_into_batch()
.await
.unwrap();

assert_eq!(actual.num_rows(), 1);
assert_eq!(actual["id"].as_primitive::<Int32Type>().value(0), 1);
assert_eq!(actual["payload"].as_primitive::<Int32Type>().value(0), 100);
}

/// DataReplacement should invalidate index fragment bitmaps for replaced fields.
#[tokio::test]
async fn test_data_replacement_invalidates_index_bitmap() {
Expand Down
Loading