diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 6d21e842e04..371be42126e 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -29,7 +29,8 @@ use crate::{metrics::NoOpMetricsCollector, scalar::registry::TrainingCriteria}; use crate::{pbold, scalar::btree::flat::FlatIndex}; use arrow_arith::numeric::add; use arrow_array::{ - Array, ArrayAccessor, ArrowNativeTypeOp, PrimitiveArray, RecordBatch, UInt32Array, + Array, ArrayAccessor, ArrowNativeTypeOp, BooleanArray, PrimitiveArray, RecordBatch, + UInt32Array, cast::AsArray, new_empty_array, types::{ @@ -1856,7 +1857,8 @@ impl BTreeIndex { ..Default::default() }, )?; - let merged_stream = chunk_concat_stream(unchunked, first.batch_size as usize); + let deduplicated = deduplicate_value_ordered_btree_rows(unchunked); + let merged_stream = chunk_concat_stream(deduplicated, first.batch_size as usize); let files = train_btree_index(merged_stream, dest_store, first.batch_size, None, None).await?; @@ -1889,6 +1891,51 @@ fn filter_row_ids( Box::pin(RecordBatchStreamAdapter::new(schema, filtered)) } +fn deduplicate_value_ordered_btree_rows( + stream: SendableRecordBatchStream, +) -> SendableRecordBatchStream { + let schema = stream.schema(); + let deduplicated = stream::try_unfold( + (stream, None::, HashSet::::new()), + |(mut stream, mut current_value, mut row_ids_for_value)| async move { + loop { + let Some(batch) = stream.next().await.transpose()? else { + return Ok(None); + }; + + let values = batch.column_by_name(VALUE_COLUMN_NAME).expect_ok()?; + let row_ids = batch + .column_by_name(ROW_ID) + .expect_ok()? + .as_primitive::(); + let mut mask = Vec::with_capacity(batch.num_rows()); + + for idx in 0..batch.num_rows() { + let value = ScalarValue::try_from_array(values, idx)?; + match current_value.as_ref() { + Some(current) + if OrderableScalarValue(current.clone()) + .cmp(&OrderableScalarValue(value.clone())) + == Ordering::Equal => {} + _ => { + current_value = Some(value); + row_ids_for_value.clear(); + } + } + mask.push(row_ids_for_value.insert(row_ids.value(idx))); + } + + let mask = BooleanArray::from(mask); + let batch = arrow_select::filter::filter_record_batch(&batch, &mask)?; + if batch.num_rows() > 0 { + return Ok(Some((batch, (stream, current_value, row_ids_for_value)))); + } + } + }, + ); + Box::pin(RecordBatchStreamAdapter::new(schema, deduplicated)) +} + fn wrap_bound(bound: &Bound) -> Bound { match bound { Bound::Unbounded => Bound::Unbounded, diff --git a/rust/lance/src/dataset/tests/dataset_merge_update.rs b/rust/lance/src/dataset/tests/dataset_merge_update.rs index 7fa03d6e78d..0660e5afcba 100644 --- a/rust/lance/src/dataset/tests/dataset_merge_update.rs +++ b/rust/lance/src/dataset/tests/dataset_merge_update.rs @@ -1626,6 +1626,63 @@ async fn test_merge_insert_with_reordered_columns_and_index() { final_dataset.validate().await.unwrap(); } +#[tokio::test] +async fn test_btree_merge_deduplicates_row_addrs() { + // This public table flow creates an old BTree segment and a delta segment + // for the same row address. Merging them should not leave duplicate row + // addresses in the final flat page. + let batch = arrow_array::record_batch!(("id", Int32, [1]), ("payload", Int32, [10])).unwrap(); + let test_uri = TempStrDir::default(); + let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()); + let mut dataset = Dataset::write(reader, &test_uri, None).await.unwrap(); + + dataset + .create_index( + &["id"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + false, + ) + .await + .unwrap(); + + let source_batch = + arrow_array::record_batch!(("payload", Int32, [100]), ("id", Int32, [1])).unwrap(); + let merge_job = MergeInsertBuilder::try_new(Arc::new(dataset.clone()), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .try_build() + .unwrap(); + let reader = Box::new(RecordBatchIterator::new( + vec![Ok(source_batch.clone())], + source_batch.schema(), + )); + let (updated_dataset, _) = merge_job.execute(reader_to_stream(reader)).await.unwrap(); + let mut dataset = updated_dataset.as_ref().clone(); + + dataset + .optimize_indices(&OptimizeOptions::append()) + .await + .unwrap(); + dataset + .optimize_indices(&OptimizeOptions::merge(2)) + .await + .unwrap(); + + let actual = dataset + .scan() + .filter("id = 1") + .unwrap() + .try_into_batch() + .await + .unwrap(); + + assert_eq!(actual.num_rows(), 1); + assert_eq!(actual["id"].as_primitive::().value(0), 1); + assert_eq!(actual["payload"].as_primitive::().value(0), 100); +} + /// DataReplacement should invalidate index fragment bitmaps for replaced fields. #[tokio::test] async fn test_data_replacement_invalidates_index_bitmap() {