diff --git a/rust/lance-index/benches/rq.rs b/rust/lance-index/benches/rq.rs index 4a7364d1313..e29ce9c4695 100644 --- a/rust/lance-index/benches/rq.rs +++ b/rust/lance-index/benches/rq.rs @@ -17,11 +17,16 @@ use lance_datagen::array::rand_type; use lance_datagen::{BatchGeneratorBuilder, RowCount}; use lance_index::vector::bq::RQRotationType; use lance_index::vector::bq::builder::RabitQuantizer; +use lance_index::vector::bq::ex_dot::{ + blocked_ex_code_bytes, ex_dot_kernel, pack_blocked_row, packed_ex_code_value, +}; use lance_index::vector::bq::storage::*; use lance_index::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN}; use lance_index::vector::quantizer::{Quantization, QuantizerStorage}; use lance_index::vector::storage::{DistCalculator, VectorStore}; use lance_linalg::distance::DistanceType; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; const DIM: usize = 128; const TOTAL: usize = 16 * 1000; @@ -119,16 +124,397 @@ fn compute_distances(c: &mut Criterion) { } } -#[cfg(target_os = "linux")] -criterion_group!( - name=benches; - config = Criterion::default().measurement_time(Duration::from_secs(10)); - targets = construct_dist_table, compute_distances); +/// The table-gather ex distance used before the dedicated ex-dot kernels, +/// kept here as the baseline: per dim, extract the packed code and gather +/// `query[d] * code` from a `dim * 2^ex_bits` table. +fn gather_ex_distance(row_codes: &[u8], dim: usize, ex_bits: u8, ex_dist_table: &[f32]) -> f32 { + let entries_per_dim = 1usize << ex_bits; + (0..dim) + .map(|dim_idx| { + let code = packed_ex_code_value(row_codes, dim_idx, ex_bits) as usize; + ex_dist_table[dim_idx * entries_per_dim + code] + }) + .sum() +} + +fn ex_dot_kernels(c: &mut Criterion) { + for ex_dim in [1536usize, 2048] { + ex_dot_kernels_for_dim(c, ex_dim); + } +} + +fn ex_dot_kernels_for_dim(c: &mut Criterion, ex_dim: usize) { + const NUM_ROWS: usize = 1024; + + let mut rng = SmallRng::seed_from_u64(42); + let query = (0..ex_dim) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + + for ex_bits in 1..=8u8 { + let max_code = ((1u16 << ex_bits) - 1) as u8; + let values = (0..NUM_ROWS * ex_dim) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + + // The gather baseline reads the legacy sequential layout it shipped + // with; the kernel reads the blocked layout. + let seq_code_len = (ex_dim * ex_bits as usize).div_ceil(8); + let mut seq_codes = vec![0u8; NUM_ROWS * seq_code_len]; + for (row, row_values) in seq_codes + .chunks_exact_mut(seq_code_len) + .zip(values.chunks_exact(ex_dim)) + { + for (dim, &value) in row_values.iter().enumerate() { + let bit_offset = dim * ex_bits as usize; + let bits = (value as u16) << (bit_offset % 8); + row[bit_offset / 8] |= bits as u8; + if bits >> 8 != 0 { + row[bit_offset / 8 + 1] |= (bits >> 8) as u8; + } + } + } + + let kernel_code_len = blocked_ex_code_bytes(ex_dim, ex_bits); + let mut kernel_codes = vec![0u8; NUM_ROWS * kernel_code_len]; + for (row, row_values) in kernel_codes + .chunks_exact_mut(kernel_code_len) + .zip(values.chunks_exact(ex_dim)) + { + pack_blocked_row(row_values, ex_bits, row); + } + + // ex_dim is block-aligned here, so the kernels read the query as-is. + let ex_query = &query; + let kernel = ex_dot_kernel(ex_bits); + c.bench_function( + format!("RQ ex_dot kernel: ex_bits={ex_bits}, DIM={ex_dim}, rows={NUM_ROWS}").as_str(), + |b| { + b.iter(|| { + let mut sum = 0.0f32; + for row in kernel_codes.chunks_exact(kernel_code_len) { + sum += kernel(ex_query, row); + } + black_box(sum) + }) + }, + ); + + let entries_per_dim = 1usize << ex_bits; + let mut ex_dist_table = vec![0.0f32; ex_dim * entries_per_dim]; + for (dim, table) in ex_dist_table.chunks_exact_mut(entries_per_dim).enumerate() { + for (code, value) in table.iter_mut().enumerate() { + *value = query[dim] * code as f32; + } + } + c.bench_function( + format!("RQ ex_dot table-gather: ex_bits={ex_bits}, DIM={ex_dim}, rows={NUM_ROWS}") + .as_str(), + |b| { + b.iter(|| { + let mut sum = 0.0f32; + for row in seq_codes.chunks_exact(seq_code_len) { + sum += gather_ex_distance(row, ex_dim, ex_bits, &ex_dist_table); + } + black_box(sum) + }) + }, + ); + } +} + +/// Storage load cost per format: blocked-format ex codes are aliased as-is, +/// legacy sequential ex codes are repacked row by row. +fn ex_code_storage_load(c: &mut Criterion) { + use arrow_array::{ArrayRef, FixedSizeListArray, Float32Array, UInt8Array, UInt64Array}; + use lance_arrow::FixedSizeListArrayExt; + use lance_index::vector::bq::ex_dot::repack_sequential_row; + use lance_index::vector::bq::rabit_ex_code_bytes; + use lance_index::vector::bq::transform::{EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN}; + use std::sync::Arc; + + const LOAD_DIM: usize = 1536; + const LOAD_ROWS: usize = 8192; + const NUM_BITS: u8 = 4; // ex_bits=3, a bit-plane width + + let ex_bits = NUM_BITS - 1; + let mut rng = SmallRng::seed_from_u64(7); + let metadata = RabitQuantizationMetadata { + rotate_mat: None, + rotate_mat_position: None, + fast_rotation_signs: None, + rotation_type: RQRotationType::Fast, + code_dim: LOAD_DIM as u32, + num_bits: NUM_BITS, + packed: true, + query_estimator: RabitQueryEstimator::RawQuery, + }; + let code_len = LOAD_DIM / 8; + let binary_codes = (0..LOAD_ROWS * code_len) + .map(|_| rng.random_range(0..=u8::MAX)) + .collect::>(); + let seq_code_len = rabit_ex_code_bytes(LOAD_DIM, ex_bits).unwrap(); + let seq_codes = (0..LOAD_ROWS * seq_code_len) + .map(|_| rng.random_range(0..=u8::MAX)) + .collect::>(); + let blocked_code_len = blocked_ex_code_bytes(LOAD_DIM, ex_bits); + let mut blocked_codes = vec![0u8; LOAD_ROWS * blocked_code_len]; + for (seq_row, blocked_row) in seq_codes + .chunks_exact(seq_code_len) + .zip(blocked_codes.chunks_exact_mut(blocked_code_len)) + { + repack_sequential_row(seq_row, LOAD_DIM, ex_bits, blocked_row); + } + + let make_batch = |ex_column: &str, ex_values: Vec, ex_code_len: usize| { + arrow_array::RecordBatch::try_from_iter(vec![ + ( + ROW_ID, + Arc::new(UInt64Array::from_iter_values(0..LOAD_ROWS as u64)) as ArrayRef, + ), + ( + RABIT_CODE_COLUMN, + Arc::new( + FixedSizeListArray::try_new_from_values( + UInt8Array::from(binary_codes.clone()), + code_len as i32, + ) + .unwrap(), + ) as ArrayRef, + ), + ( + ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; LOAD_ROWS])) as ArrayRef, + ), + ( + SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; LOAD_ROWS])) as ArrayRef, + ), + ( + ex_column, + Arc::new( + FixedSizeListArray::try_new_from_values( + UInt8Array::from(ex_values), + ex_code_len as i32, + ) + .unwrap(), + ) as ArrayRef, + ), + ( + EX_ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; LOAD_ROWS])) as ArrayRef, + ), + ( + EX_SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; LOAD_ROWS])) as ArrayRef, + ), + ]) + .unwrap() + }; + + let blocked_batch = make_batch( + RABIT_BLOCKED_EX_CODE_COLUMN, + blocked_codes, + blocked_code_len, + ); + c.bench_function( + format!("RQ storage load (blocked ex codes): num_bits={NUM_BITS}, DIM={LOAD_DIM}, rows={LOAD_ROWS}") + .as_str(), + |b| { + b.iter(|| { + black_box( + RabitQuantizationStorage::try_from_batch( + blocked_batch.clone(), + &metadata, + DistanceType::L2, + None, + ) + .unwrap(), + ) + }) + }, + ); + + let legacy_batch = make_batch(RABIT_EX_CODE_COLUMN, seq_codes, seq_code_len); + c.bench_function( + format!("RQ storage load (legacy ex codes): num_bits={NUM_BITS}, DIM={LOAD_DIM}, rows={LOAD_ROWS}") + .as_str(), + |b| { + b.iter(|| { + black_box( + RabitQuantizationStorage::try_from_batch( + legacy_batch.clone(), + &metadata, + DistanceType::L2, + None, + ) + .unwrap(), + ) + }) + }, + ); +} + +/// Bulk-scoring cost of the ex stage: the quantized ex-FastScan LUT path +/// (inside `distance_all`) vs the exact per-row ex-dot kernel. The +/// binary-only run isolates the shared binary stage so the ex cost is the +/// difference from the full run. +fn ex_bulk_paths(c: &mut Criterion) { + use arrow_array::{ArrayRef, FixedSizeListArray, Float32Array, UInt8Array, UInt64Array}; + use lance_arrow::FixedSizeListArrayExt; + use lance_index::vector::ApproxMode; + use lance_index::vector::bq::ex_dot::pad_query_into; + use lance_index::vector::bq::transform::{EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN}; + use lance_index::vector::storage::DistanceCalculatorOptions; + use std::sync::Arc; + + const BULK_DIM: usize = 1536; + const BULK_ROWS: usize = 16384; + + let mut rng = SmallRng::seed_from_u64(13); + for num_bits in [3u8, 5, 9] { + let ex_bits = num_bits - 1; + let max_code = ((1u16 << ex_bits) - 1) as u8; + + let rq = RabitQuantizer::new_with_rotation::( + num_bits, + BULK_DIM as i32, + RQRotationType::Fast, + ); + let metadata = rq.metadata(None); + + let code_len = BULK_DIM / 8; + let binary_codes = (0..BULK_ROWS * code_len) + .map(|_| rng.random_range(0..=u8::MAX)) + .collect::>(); + let ex_code_len = blocked_ex_code_bytes(BULK_DIM, ex_bits); + let mut ex_codes = vec![0u8; BULK_ROWS * ex_code_len]; + let values = (0..BULK_DIM) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + for row in ex_codes.chunks_exact_mut(ex_code_len) { + pack_blocked_row(&values, ex_bits, row); + } + + // No error factors: `distance_all` takes the FastScan ex bulk branch. + let batch = arrow_array::RecordBatch::try_from_iter(vec![ + ( + ROW_ID, + Arc::new(UInt64Array::from_iter_values(0..BULK_ROWS as u64)) as ArrayRef, + ), + ( + RABIT_CODE_COLUMN, + Arc::new( + FixedSizeListArray::try_new_from_values( + UInt8Array::from(binary_codes), + code_len as i32, + ) + .unwrap(), + ) as ArrayRef, + ), + ( + ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; BULK_ROWS])) as ArrayRef, + ), + ( + SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; BULK_ROWS])) as ArrayRef, + ), + ( + RABIT_BLOCKED_EX_CODE_COLUMN, + Arc::new( + FixedSizeListArray::try_new_from_values( + UInt8Array::from(ex_codes.clone()), + ex_code_len as i32, + ) + .unwrap(), + ) as ArrayRef, + ), + ( + EX_ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; BULK_ROWS])) as ArrayRef, + ), + ( + EX_SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![1.0f32; BULK_ROWS])) as ArrayRef, + ), + ]) + .unwrap(); + let storage = + RabitQuantizationStorage::try_from_batch(batch, &metadata, DistanceType::L2, None) + .unwrap(); + + let query: ArrayRef = Arc::new(Float32Array::from( + (0..BULK_DIM) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(), + )); + + for (label, approx_mode) in [ + ("full distance_all (binary + ex LUT)", ApproxMode::Normal), + ("binary-only distance_all (fast mode)", ApproxMode::Fast), + ] { + let mut f32_scratch = Vec::new(); + let calc = storage.dist_calculator_with_scratch( + query.clone(), + 0.0, + None, + &mut f32_scratch, + DistanceCalculatorOptions { approx_mode }, + ); + let mut dists = Vec::new(); + let mut u16_scratch = Vec::new(); + let mut u8_scratch = Vec::new(); + let mut u32_scratch = Vec::new(); + c.bench_function( + format!("RQ bulk {label}: num_bits={num_bits}, DIM={BULK_DIM}, rows={BULK_ROWS}") + .as_str(), + |b| { + b.iter(|| { + calc.distance_all_with_scratch( + 0, + &mut dists, + &mut u16_scratch, + &mut u8_scratch, + &mut u32_scratch, + ); + black_box(dists.len()) + }) + }, + ); + } + + let kernel = ex_dot_kernel(ex_bits); + let mut ex_query = vec![0.0f32; BULK_DIM]; + pad_query_into( + query + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &mut ex_query, + ); + c.bench_function( + format!( + "RQ bulk ex kernel loop: num_bits={num_bits}, DIM={BULK_DIM}, rows={BULK_ROWS}" + ) + .as_str(), + |b| { + b.iter(|| { + let mut sum = 0.0f32; + for row in ex_codes.chunks_exact(ex_code_len) { + sum += kernel(&ex_query, row); + } + black_box(sum) + }) + }, + ); + } +} -#[cfg(not(target_os = "linux"))] criterion_group!( name=benches; config = Criterion::default().measurement_time(Duration::from_secs(10)); - targets = construct_dist_table, compute_distances); + targets = construct_dist_table, compute_distances, ex_dot_kernels, ex_code_storage_load, ex_bulk_paths); criterion_main!(benches); diff --git a/rust/lance-index/src/vector/bq.rs b/rust/lance-index/src/vector/bq.rs index 0fdd918edab..71c4eed7fd8 100644 --- a/rust/lance-index/src/vector/bq.rs +++ b/rust/lance-index/src/vector/bq.rs @@ -18,6 +18,7 @@ use crate::vector::bq::storage::RabitQuantizationMetadata; use crate::vector::quantizer::QuantizerBuildParams; pub mod builder; +pub mod ex_dot; pub mod rotation; pub mod storage; pub mod transform; diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index 178a6bb5435..9eb7fc76903 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -25,7 +25,7 @@ use crate::vector::bq::transform::{ SCALE_FACTORS_FIELD, }; use crate::vector::bq::{ - RQBuildParams, RQRotationType, rabit_binary_code_bytes, rabit_ex_bits, rabit_ex_code_bytes, + RQBuildParams, RQRotationType, rabit_binary_code_bytes, rabit_ex_bits, rotation::{apply_fast_rotation, fast_rotation_signs_len, random_fast_rotation_signs}, validate_rq_num_bits, }; @@ -78,21 +78,6 @@ fn pack_sign_bits(codes: &mut [u8], rotated: &[f32]) { } } -#[inline] -fn pack_ex_code_bits(codes: &mut [u8], ex_values: &[u8], ex_bits: u8) { - codes.fill(0); - let ex_bits = ex_bits as usize; - for (dim_idx, &value) in ex_values.iter().enumerate() { - let bit_offset = dim_idx * ex_bits; - for bit_idx in 0..ex_bits { - if (value >> bit_idx) & 1 != 0 { - let dst_bit = bit_offset + bit_idx; - codes[dst_bit / u8::BITS as usize] |= 1u8 << (dst_bit % u8::BITS as usize); - } - } - } -} - const EX_QUANTIZATION_EPSILON: f32 = 1.0e-5; const EX_TIGHT_START: [f32; 9] = [0.0, 0.15, 0.20, 0.52, 0.59, 0.71, 0.75, 0.77, 0.81]; @@ -200,7 +185,7 @@ fn quantize_ex_code( *ex_code_value = ex_code; } - pack_ex_code_bits(ex_code_dst, ex_code_values_dst, ex_bits); + crate::vector::bq::ex_dot::pack_blocked_row(ex_code_values_dst, ex_bits, ex_code_dst); residual_dot_code } @@ -599,7 +584,11 @@ impl RabitQuantizer { .as_slice(); let code_dim = self.code_dim(); let code_bytes = rabit_binary_code_bytes(code_dim); - let ex_code_bytes = rabit_ex_code_bytes(code_dim, ex_bits)?; + let ex_code_bytes = if ex_bits == 0 { + 0 + } else { + crate::vector::bq::ex_dot::blocked_ex_code_bytes(code_dim, ex_bits) + }; let mut encoded_codes = vec![0u8; n * code_bytes]; let mut encoded_ex_codes = (ex_bits != 0).then(|| vec![0u8; n * ex_code_bytes]); @@ -901,7 +890,7 @@ mod tests { use lance_linalg::distance::DistanceType; use rstest::rstest; - use crate::vector::bq::storage::RABIT_EX_CODE_COLUMN; + use crate::vector::bq::storage::RABIT_BLOCKED_EX_CODE_COLUMN; #[rstest] #[case(8)] @@ -978,14 +967,14 @@ mod tests { assert!( !fields .iter() - .any(|field| field.name() == RABIT_EX_CODE_COLUMN) + .any(|field| field.name() == RABIT_BLOCKED_EX_CODE_COLUMN) ); let q = RabitQuantizer::new_with_rotation::(3, 128, RQRotationType::Fast); let fields = q.extra_fields(); for expected in [ ERROR_FACTORS_FIELD.name().as_str(), - RABIT_EX_CODE_COLUMN, + RABIT_BLOCKED_EX_CODE_COLUMN, EX_ADD_FACTORS_FIELD.name().as_str(), EX_SCALE_FACTORS_FIELD.name().as_str(), ] { @@ -1095,7 +1084,8 @@ mod tests { .unwrap() .as_fixed_size_list() .value_length(), - 32 + // dim=32 is padded to one 64-dim block at ex_bits=8. + 64 ); } diff --git a/rust/lance-index/src/vector/bq/ex_dot.rs b/rust/lance-index/src/vector/bq/ex_dot.rs new file mode 100644 index 00000000000..1aeb83ba40c --- /dev/null +++ b/rust/lance-index/src/vector/bq/ex_dot.rs @@ -0,0 +1,1078 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Inner-product kernels between an `f32` query and bit-packed RaBitQ ex-codes. +//! +//! Multi-bit RaBitQ reranking reduces to `sum_d query[d] * ex_code[d]`, where +//! `ex_code[d]` is an unsigned `ex_bits`-wide integer. Materializing a +//! `dim * 2^ex_bits` lookup table and gathering one entry per dimension is +//! cache-hostile (the table is 1MiB for `ex_bits=8`, `dim=1024`); these kernels +//! instead unpack the codes with shifts and masks and FMA them against the +//! query directly, following the kernel design of the RaBitQ reference library +//! (, Apache-2.0). +//! +//! Codes are stored in the *blocked* layout: dims are grouped into 64-dim +//! blocks (the last block zero-padded) and bit-interleaved within each block +//! so that the SIMD unpack emits codes in natural dim order: +//! +//! ```text +//! per 64-dim block (T = ex_bits - 1, the top bit; "run k" = dims 16k..16k+16): +//! 1 bit: [8B] bit i of the LE word = dim i +//! 2 bits: [16B] byte b = dims {b, b+16, b+32, b+48} at bit pairs 0/2/4/6 +//! 3 bits: [16B 2-bit plane as above][8B top-bit plane] +//! 4 bits: [32B] byte 8j+b = dim 16j+b (low nibble) | dim 16j+8+b (high nibble) +//! 5 bits: [32B 4-bit plane: byte b = dims b|b+16; byte 16+b = dims b+32|b+48] +//! [8B top-bit plane] +//! 6 bits: [48B] byte 16k+b = dim 16k+b (6 low bits) | bits 2k..2k+2 of +//! dim 48+b (2 high bits) +//! 7 bits: [48B as 6 bits][8B top-bit plane] +//! 8 bits: [64B] identity +//! top-bit plane: top bit of dim 16k+b at bit 8*(b%8) + 2k + b/8 of a LE u64 +//! ``` +//! +//! Because unpack order is natural, the kernels read the rotated query +//! directly; it only needs zero-padding ([`pad_query_into`]) when the rotated +//! dim is not a multiple of 64. Legacy indexes store ex codes sequentially +//! (LSB-first bit stream) and are repacked once at load time +//! ([`repack_sequential_row`]); for `ex_bits` ∈ {1, 8} the two layouts agree +//! (modulo trailing padding, which the kernels tolerate) and rows are used as +//! stored. + +use std::sync::LazyLock; + +/// Dims are packed in blocks of this size; the query is zero-padded to a +/// whole number of blocks when the rotated dim is not already a multiple. +pub const EX_DOT_BLOCK_DIMS: usize = 64; + +/// `f32` length of the query consumed by the kernels. +pub fn padded_query_len(dim: usize) -> usize { + dim.next_multiple_of(EX_DOT_BLOCK_DIMS) +} + +/// Whether the legacy sequential layout of a row already matches the blocked +/// layout (modulo trailing zero padding, which the kernels tolerate), so +/// legacy rows can be consumed without repacking. +pub fn sequential_matches_blocked(ex_bits: u8) -> bool { + matches!(ex_bits, 1 | 8) +} + +/// Bytes per row of the blocked ex-code layout. +pub fn blocked_ex_code_bytes(dim: usize, ex_bits: u8) -> usize { + debug_assert!((1..=8).contains(&ex_bits)); + padded_query_len(dim) * ex_bits as usize / 8 +} + +/// Dimensions per unpacking group for the given code width. +fn group_dims(ex_bits: u8) -> usize { + match ex_bits { + 1 | 4 | 8 => 16, + _ => EX_DOT_BLOCK_DIMS, + } +} + +fn group_bytes(ex_bits: u8) -> usize { + group_dims(ex_bits) * ex_bits as usize / 8 +} + +/// Extract the `ex_bits`-wide code of `dim_idx` from a sequentially bit-packed +/// row (LSB-first, codes may straddle byte boundaries). +#[inline] +pub fn packed_ex_code_value(row_codes: &[u8], dim_idx: usize, ex_bits: u8) -> u8 { + debug_assert!(ex_bits > 0); + let bit_offset = dim_idx * ex_bits as usize; + let byte_idx = bit_offset / u8::BITS as usize; + let bit_shift = bit_offset % u8::BITS as usize; + let bits = row_codes[byte_idx] as u16 + | row_codes + .get(byte_idx + 1) + .map(|byte| (*byte as u16) << u8::BITS) + .unwrap_or_default(); + let mask = (1u16 << ex_bits) - 1; + ((bits >> bit_shift) & mask) as u8 +} + +/// Zero-pad the rotated query to a whole number of 64-dim blocks. Only needed +/// when `dim` is not a multiple of [`EX_DOT_BLOCK_DIMS`]; aligned queries are +/// passed to the kernels as-is. +pub fn pad_query_into(rotated_query: &[f32], out: &mut [f32]) { + debug_assert_eq!(out.len(), padded_query_len(rotated_query.len())); + out[..rotated_query.len()].copy_from_slice(rotated_query); + out[rotated_query.len()..].fill(0.0); +} + +/// Pack the top bit of each of 64 codes into a `u64` so kernels can position +/// it with two shifts per 16-code run: the top bit of dim `16k + b` is stored +/// at bit `8 * (b % 8) + 2k + b / 8`. +fn pack_top_plane(block_values: &[u8; 64], top_bit: u8) -> u64 { + let mut plane = 0u64; + for k in 0..4 { + for b in 0..16 { + let bit = (block_values[16 * k + b] >> top_bit) & 1; + plane |= (bit as u64) << (8 * (b % 8) + 2 * k + b / 8); + } + } + plane +} + +/// Shift `plane` so that its bit `8j + from_bit` lands at bit `8j + to_bit`. +#[inline(always)] +fn shift_plane(plane: u64, from_bit: usize, to_bit: usize) -> u64 { + if from_bit >= to_bit { + plane >> (from_bit - to_bit) + } else { + plane << (to_bit - from_bit) + } +} + +/// Pack one block of 64 code values (natural dim order) into the blocked +/// layout described in the module docs. +fn pack_block(ex_bits: u8, block_values: &[u8; 64], out: &mut [u8]) { + let v = block_values; + match ex_bits { + 1 => { + for (b, byte) in out[..8].iter_mut().enumerate() { + *byte = (0..8).fold(0, |acc, t| acc | ((v[8 * b + t] & 1) << t)); + } + } + 2 | 3 => { + for b in 0..16 { + out[b] = (v[b] & 0b11) + | ((v[16 + b] & 0b11) << 2) + | ((v[32 + b] & 0b11) << 4) + | ((v[48 + b] & 0b11) << 6); + } + if ex_bits == 3 { + out[16..24].copy_from_slice(&pack_top_plane(v, 2).to_le_bytes()); + } + } + 4 => { + for unit in 0..4 { + for b in 0..8 { + out[8 * unit + b] = + (v[16 * unit + b] & 0x0f) | ((v[16 * unit + 8 + b] & 0x0f) << 4); + } + } + } + 5 => { + for b in 0..16 { + out[b] = (v[b] & 0x0f) | ((v[16 + b] & 0x0f) << 4); + out[16 + b] = (v[32 + b] & 0x0f) | ((v[48 + b] & 0x0f) << 4); + } + out[32..40].copy_from_slice(&pack_top_plane(v, 4).to_le_bytes()); + } + 6 | 7 => { + // Runs 0..3 keep their 6 low bits in place; the fourth run's dims + // are split into three 2-bit pieces stored in the runs' top bits. + for k in 0..3 { + for b in 0..16 { + out[16 * k + b] = + (v[16 * k + b] & 0x3f) | (((v[48 + b] >> (2 * k)) & 0b11) << 6); + } + } + if ex_bits == 7 { + out[48..56].copy_from_slice(&pack_top_plane(v, 6).to_le_bytes()); + } + } + 8 => out[..64].copy_from_slice(v), + _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"), + } +} + +/// Pack one row of unpacked code values (one `u8` per dim) into the blocked +/// layout; the writer path. `out` must have [`blocked_ex_code_bytes`] bytes. +pub fn pack_blocked_row(values: &[u8], ex_bits: u8, out: &mut [u8]) { + debug_assert_eq!(out.len(), blocked_ex_code_bytes(values.len(), ex_bits)); + let block_bytes = EX_DOT_BLOCK_DIMS * ex_bits as usize / 8; + let mut block_values = [0u8; 64]; + for (block, out) in out.chunks_exact_mut(block_bytes).enumerate() { + let base = block * EX_DOT_BLOCK_DIMS; + let count = EX_DOT_BLOCK_DIMS.min(values.len() - base); + block_values[..count].copy_from_slice(&values[base..base + count]); + block_values[count..].fill(0); + pack_block(ex_bits, &block_values, out); + } +} + +/// Repack one legacy sequentially bit-packed row into the blocked layout. +/// `out` must have [`blocked_ex_code_bytes`] bytes. +pub fn repack_sequential_row(seq_row: &[u8], dim: usize, ex_bits: u8, out: &mut [u8]) { + debug_assert_eq!(out.len(), blocked_ex_code_bytes(dim, ex_bits)); + let block_bytes = EX_DOT_BLOCK_DIMS * ex_bits as usize / 8; + let mut block_values = [0u8; 64]; + for (block, out) in out.chunks_exact_mut(block_bytes).enumerate() { + block_values.fill(0); + let base = block * EX_DOT_BLOCK_DIMS; + let count = EX_DOT_BLOCK_DIMS.min(dim.saturating_sub(base)); + for (i, value) in block_values[..count].iter_mut().enumerate() { + *value = packed_ex_code_value(seq_row, base + i, ex_bits); + } + pack_block(ex_bits, &block_values, out); + } +} + +/// Unpack one code group into per-dim values (natural dim order). Reference +/// implementation for the SIMD unpackers; also the scalar fallback. +fn unpack_group(ex_bits: u8, group_codes: &[u8], out: &mut [u8; 64]) { + debug_assert_eq!(group_codes.len(), group_bytes(ex_bits)); + match ex_bits { + 1 => { + for (i, value) in out[..16].iter_mut().enumerate() { + *value = (group_codes[i / 8] >> (i % 8)) & 1; + } + } + 2 => { + for k in 0..4 { + for b in 0..16 { + out[16 * k + b] = (group_codes[b] >> (2 * k)) & 0b11; + } + } + } + 3 => { + let plane = u64::from_le_bytes(group_codes[16..24].try_into().unwrap()); + for k in 0..4 { + for b in 0..16 { + let top = (plane >> (8 * (b % 8) + 2 * k + b / 8)) & 1; + out[16 * k + b] = ((group_codes[b] >> (2 * k)) & 0b11) | ((top as u8) << 2); + } + } + } + 4 => { + for b in 0..8 { + out[b] = group_codes[b] & 0x0f; + out[8 + b] = group_codes[b] >> 4; + } + } + 5 => { + let plane = u64::from_le_bytes(group_codes[32..40].try_into().unwrap()); + for k in 0..4 { + for b in 0..16 { + let nibble = (group_codes[16 * (k / 2) + b] >> (4 * (k % 2))) & 0x0f; + let top = (plane >> (8 * (b % 8) + 2 * k + b / 8)) & 1; + out[16 * k + b] = nibble | ((top as u8) << 4); + } + } + } + 6 | 7 => { + for k in 0..3 { + for b in 0..16 { + out[16 * k + b] = group_codes[16 * k + b] & 0x3f; + } + } + for b in 0..16 { + out[48 + b] = (group_codes[b] >> 6) + | ((group_codes[16 + b] >> 6) << 2) + | ((group_codes[32 + b] >> 6) << 4); + } + if ex_bits == 7 { + let plane = u64::from_le_bytes(group_codes[48..56].try_into().unwrap()); + for k in 0..4 { + for b in 0..16 { + let top = (plane >> (8 * (b % 8) + 2 * k + b / 8)) & 1; + out[16 * k + b] |= (top as u8) << 6; + } + } + } + } + 8 => out[..16].copy_from_slice(group_codes), + _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"), + } +} + +/// `sum_d query[d] * code[d]` for one row of blocked-layout codes. +/// +/// The query must cover a whole number of 64-dim blocks (the rotated query +/// as-is for aligned dims, otherwise zero-padded via [`pad_query_into`]); +/// `codes` is the blocked row slice. Rows shorter than the padded query +/// length are treated as zero-padded. +pub type ExDotFn = fn(&[f32], &[u8]) -> f32; + +/// Resolve the dot kernel for `ex_bits` once; the result can be cached by the +/// caller for per-candidate use. +pub fn ex_dot_kernel(ex_bits: u8) -> ExDotFn { + debug_assert!((1..=8).contains(&ex_bits)); + static KERNELS: LazyLock<[ExDotFn; 8]> = + LazyLock::new(|| std::array::from_fn(|i| select_ex_dot_kernel(i as u8 + 1))); + KERNELS[usize::from(ex_bits) - 1] +} + +fn select_ex_dot_kernel(ex_bits: u8) -> ExDotFn { + #[cfg(target_arch = "x86_64")] + { + if std::arch::is_x86_feature_detected!("avx512f") { + return x86::avx512_kernel(ex_bits); + } + if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") + { + return x86::avx2_kernel(ex_bits); + } + } + #[cfg(target_arch = "aarch64")] + { + // NEON is part of the aarch64 baseline. + return neon::kernel(ex_bits); + } + #[allow(unreachable_code)] + scalar_kernel(ex_bits) +} + +fn scalar_kernel(ex_bits: u8) -> ExDotFn { + match ex_bits { + 1 => ex_dot_scalar::<1>, + 2 => ex_dot_scalar::<2>, + 3 => ex_dot_scalar::<3>, + 4 => ex_dot_scalar::<4>, + 5 => ex_dot_scalar::<5>, + 6 => ex_dot_scalar::<6>, + 7 => ex_dot_scalar::<7>, + 8 => ex_dot_scalar::<8>, + _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"), + } +} + +fn ex_dot_scalar(ex_query: &[f32], codes: &[u8]) -> f32 { + let group_dims = group_dims(EX_BITS); + let bytes_per_group = group_bytes(EX_BITS); + debug_assert_eq!(ex_query.len() % EX_DOT_BLOCK_DIMS, 0); + debug_assert!(codes.len() * u8::BITS as usize <= ex_query.len() * EX_BITS as usize); + + let mut sum = 0.0f32; + let mut unpacked = [0u8; 64]; + let mut padded = [0u8; 56]; + for (group, query) in ex_query.chunks_exact(group_dims).enumerate() { + let start = group * bytes_per_group; + if start >= codes.len() { + // The remaining query lanes are zero padding. + break; + } + let group_codes = if start + bytes_per_group <= codes.len() { + &codes[start..start + bytes_per_group] + } else { + let avail = codes.len() - start; + padded[..bytes_per_group].fill(0); + padded[..avail].copy_from_slice(&codes[start..]); + &padded[..bytes_per_group] + }; + unpack_group(EX_BITS, group_codes, &mut unpacked); + for (q, &code) in query.iter().zip(unpacked[..group_dims].iter()) { + sum += q * code as f32; + } + } + sum +} + +#[cfg(target_arch = "x86_64")] +mod x86 { + use super::ExDotFn; + use std::arch::x86_64::*; + + pub(super) fn avx2_kernel(ex_bits: u8) -> ExDotFn { + match ex_bits { + 1 => dot_u1_avx2_dispatch, + 2 => dot_u2_avx2_dispatch, + 3 => dot_u3_avx2_dispatch, + 4 => dot_u4_avx2_dispatch, + 5 => dot_u5_avx2_dispatch, + 6 => dot_u6_avx2_dispatch, + 7 => dot_u7_avx2_dispatch, + 8 => dot_u8_avx2_dispatch, + _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"), + } + } + + pub(super) fn avx512_kernel(ex_bits: u8) -> ExDotFn { + match ex_bits { + 1 => dot_u1_avx512_dispatch, + 2 => dot_u2_avx512_dispatch, + 3 => dot_u3_avx512_dispatch, + 4 => dot_u4_avx512_dispatch, + 5 => dot_u5_avx512_dispatch, + 6 => dot_u6_avx512_dispatch, + 7 => dot_u7_avx512_dispatch, + 8 => dot_u8_avx512_dispatch, + _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"), + } + } + + /// Broadcast a byte to the 8 bytes of a `u64`. + #[inline(always)] + fn splat_byte(byte: u8) -> u64 { + byte as u64 * 0x0101_0101_0101_0101 + } + + // Unpack helpers. They read exactly one group of code bytes and return + // runs of 16 codes matching the kernel-order query. Only SSE2 (baseline on + // x86_64) is required. + + /// 16 1-bit codes from 2 bytes: compare each replicated byte against + /// per-lane bit masks to turn set bits into 0/1 bytes. + #[inline] + #[target_feature(enable = "sse2")] + unsafe fn unpack_u1(ptr: *const u8) -> [__m128i; 1] { + let (b0, b1) = unsafe { (ptr.read(), ptr.add(1).read()) }; + let bytes = _mm_set_epi64x(splat_byte(b1) as i64, splat_byte(b0) as i64); + let bit_select = _mm_set1_epi64x(0x8040_2010_0804_0201u64 as i64); + let selected = _mm_cmpeq_epi8(_mm_and_si128(bytes, bit_select), bit_select); + [_mm_and_si128(selected, _mm_set1_epi8(1))] + } + + /// 64 2-bit codes from 16 bytes: byte b holds dims 4b..4b+3 at bit pairs. + /// The 16-bit shifts drag bits across byte boundaries, which the per-byte + /// mask removes. + #[inline] + #[target_feature(enable = "sse2")] + unsafe fn unpack_u2(ptr: *const u8) -> [__m128i; 4] { + let raw = unsafe { _mm_loadu_si128(ptr as *const __m128i) }; + let mask = _mm_set1_epi8(0b11); + [ + _mm_and_si128(raw, mask), + _mm_and_si128(_mm_srli_epi16::<2>(raw), mask), + _mm_and_si128(_mm_srli_epi16::<4>(raw), mask), + _mm_and_si128(_mm_srli_epi16::<6>(raw), mask), + ] + } + + /// Position the top-bit plane (see [`super::pack_top_plane`]) of run `k` + /// at `top_bit` within each byte. + #[inline] + #[target_feature(enable = "sse2")] + fn top_plane_run(plane: u64, k: usize, top_bit: usize) -> __m128i { + let lo = super::shift_plane(plane, 2 * k, top_bit); + let hi = super::shift_plane(plane, 2 * k + 1, top_bit); + _mm_and_si128( + _mm_set_epi64x(hi as i64, lo as i64), + _mm_set1_epi8(1 << top_bit), + ) + } + + #[inline] + #[target_feature(enable = "sse2")] + unsafe fn unpack_u3(ptr: *const u8) -> [__m128i; 4] { + let mut runs = unsafe { unpack_u2(ptr) }; + let plane = unsafe { (ptr.add(16) as *const u64).read_unaligned() }; + for (k, run) in runs.iter_mut().enumerate() { + *run = _mm_or_si128(*run, top_plane_run(plane, k, 2)); + } + runs + } + + /// 16 4-bit codes from 8 bytes: low nibbles are the even dims, high + /// nibbles the odd dims. + #[inline] + #[target_feature(enable = "sse2")] + unsafe fn unpack_u4(ptr: *const u8) -> [__m128i; 1] { + let word = unsafe { (ptr as *const u64).read_unaligned() }; + let mask = 0x0f0f_0f0f_0f0f_0f0fu64; + [_mm_set_epi64x( + ((word >> 4) & mask) as i64, + (word & mask) as i64, + )] + } + + #[inline] + #[target_feature(enable = "sse2")] + unsafe fn unpack_u5(ptr: *const u8) -> [__m128i; 4] { + let blk0 = unsafe { _mm_loadu_si128(ptr as *const __m128i) }; + let blk1 = unsafe { _mm_loadu_si128(ptr.add(16) as *const __m128i) }; + let plane = unsafe { (ptr.add(32) as *const u64).read_unaligned() }; + let mask = _mm_set1_epi8(0x0f); + let mut runs = [ + _mm_and_si128(blk0, mask), + _mm_and_si128(_mm_srli_epi16::<4>(blk0), mask), + _mm_and_si128(blk1, mask), + _mm_and_si128(_mm_srli_epi16::<4>(blk1), mask), + ]; + for (k, run) in runs.iter_mut().enumerate() { + *run = _mm_or_si128(*run, top_plane_run(plane, k, 4)); + } + runs + } + + #[inline] + #[target_feature(enable = "sse2")] + unsafe fn unpack_u6(ptr: *const u8) -> [__m128i; 4] { + let blk0 = unsafe { _mm_loadu_si128(ptr as *const __m128i) }; + let blk1 = unsafe { _mm_loadu_si128(ptr.add(16) as *const __m128i) }; + let blk2 = unsafe { _mm_loadu_si128(ptr.add(32) as *const __m128i) }; + let mask6 = _mm_set1_epi8(0x3f); + let mask2 = _mm_set1_epi8(0b1100_0000u8 as i8); + let stolen = _mm_or_si128( + _mm_or_si128( + _mm_srli_epi16::<6>(_mm_and_si128(blk0, mask2)), + _mm_srli_epi16::<4>(_mm_and_si128(blk1, mask2)), + ), + _mm_srli_epi16::<2>(_mm_and_si128(blk2, mask2)), + ); + [ + _mm_and_si128(blk0, mask6), + _mm_and_si128(blk1, mask6), + _mm_and_si128(blk2, mask6), + stolen, + ] + } + + #[inline] + #[target_feature(enable = "sse2")] + unsafe fn unpack_u7(ptr: *const u8) -> [__m128i; 4] { + let mut runs = unsafe { unpack_u6(ptr) }; + let plane = unsafe { (ptr.add(48) as *const u64).read_unaligned() }; + for (k, run) in runs.iter_mut().enumerate() { + *run = _mm_or_si128(*run, top_plane_run(plane, k, 6)); + } + runs + } + + #[inline] + #[target_feature(enable = "sse2")] + unsafe fn unpack_u8x16(ptr: *const u8) -> [__m128i; 1] { + [unsafe { _mm_loadu_si128(ptr as *const __m128i) }] + } + + /// FMA 16 code bytes against 16 query floats (AVX2: two 8-float halves). + #[inline] + #[target_feature(enable = "avx2", enable = "fma")] + unsafe fn fma16_avx2(codes: __m128i, query: *const f32, acc: &mut [__m256; 2]) { + let lo = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(codes)); + acc[0] = _mm256_fmadd_ps(lo, unsafe { _mm256_loadu_ps(query) }, acc[0]); + let hi = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_srli_si128::<8>(codes))); + acc[1] = _mm256_fmadd_ps(hi, unsafe { _mm256_loadu_ps(query.add(8)) }, acc[1]); + } + + #[inline] + #[target_feature(enable = "avx2")] + unsafe fn reduce_add_avx2(acc: [__m256; 2]) -> f32 { + let v = _mm256_add_ps(acc[0], acc[1]); + let halves = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps::<1>(v)); + let pairs = _mm_add_ps(halves, _mm_movehl_ps(halves, halves)); + let total = _mm_add_ss(pairs, _mm_shuffle_ps::<0b01>(pairs, pairs)); + _mm_cvtss_f32(total) + } + + /// FMA 16 code bytes against 16 query floats (AVX-512: one 16-float lane). + #[inline] + #[target_feature(enable = "avx512f")] + unsafe fn fma16_avx512(codes: __m128i, query: *const f32, acc: &mut __m512) { + let values = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(codes)); + *acc = _mm512_fmadd_ps(values, unsafe { _mm512_loadu_ps(query) }, *acc); + } + + macro_rules! x86_dot_kernel { + ($name:ident, $dispatch:ident, $unpack:ident, $ex_bits:expr, $runs:expr) => { + #[target_feature(enable = "avx2", enable = "fma")] + unsafe fn $name(ex_query: &[f32], codes: &[u8]) -> f32 { + const GROUP_DIMS: usize = if $runs == 1 { 16 } else { 64 }; + const GROUP_BYTES: usize = GROUP_DIMS * $ex_bits / 8; + debug_assert_eq!(ex_query.len() % super::EX_DOT_BLOCK_DIMS, 0); + debug_assert!(codes.len() * 8 <= ex_query.len() * $ex_bits); + + let groups = ex_query.len() / GROUP_DIMS; + let full_groups = (codes.len() / GROUP_BYTES).min(groups); + // Two accumulators per run position break the FMA latency + // chain; they are summed once at the end. + let mut acc = [_mm256_setzero_ps(); 2]; + for group in 0..full_groups { + // SAFETY: `group < full_groups` keeps both the code group + // and the query run in bounds. + let runs = unsafe { $unpack(codes.as_ptr().add(group * GROUP_BYTES)) }; + for (run, codes16) in runs.into_iter().enumerate() { + unsafe { + fma16_avx2( + codes16, + ex_query.as_ptr().add(group * GROUP_DIMS + run * 16), + &mut acc, + ) + }; + } + } + let consumed = full_groups * GROUP_BYTES; + if consumed < codes.len() && full_groups < groups { + // Zero-pad the final partial code group on the stack. + let mut padded = [0u8; GROUP_BYTES]; + padded[..codes.len() - consumed].copy_from_slice(&codes[consumed..]); + let runs = unsafe { $unpack(padded.as_ptr()) }; + for (run, codes16) in runs.into_iter().enumerate() { + unsafe { + fma16_avx2( + codes16, + ex_query.as_ptr().add(full_groups * GROUP_DIMS + run * 16), + &mut acc, + ) + }; + } + } + unsafe { reduce_add_avx2(acc) } + } + + fn $dispatch(ex_query: &[f32], codes: &[u8]) -> f32 { + // SAFETY: only selected when AVX2 and FMA were detected. + unsafe { $name(ex_query, codes) } + } + }; + } + + macro_rules! x86_dot_kernel_avx512 { + ($name:ident, $dispatch:ident, $unpack:ident, $ex_bits:expr, $runs:expr) => { + #[target_feature(enable = "avx512f")] + unsafe fn $name(ex_query: &[f32], codes: &[u8]) -> f32 { + const GROUP_DIMS: usize = if $runs == 1 { 16 } else { 64 }; + const GROUP_BYTES: usize = GROUP_DIMS * $ex_bits / 8; + debug_assert_eq!(ex_query.len() % super::EX_DOT_BLOCK_DIMS, 0); + debug_assert!(codes.len() * 8 <= ex_query.len() * $ex_bits); + + let groups = ex_query.len() / GROUP_DIMS; + let full_groups = (codes.len() / GROUP_BYTES).min(groups); + // Alternating by group as well as run keeps two independent + // FMA chains even for the single-run widths. + let mut acc = [_mm512_setzero_ps(); 2]; + for group in 0..full_groups { + // SAFETY: `group < full_groups` keeps both the code group + // and the query run in bounds. + let runs = unsafe { $unpack(codes.as_ptr().add(group * GROUP_BYTES)) }; + for (run, codes16) in runs.into_iter().enumerate() { + unsafe { + fma16_avx512( + codes16, + ex_query.as_ptr().add(group * GROUP_DIMS + run * 16), + &mut acc[(group + run) % 2], + ) + }; + } + } + let consumed = full_groups * GROUP_BYTES; + if consumed < codes.len() && full_groups < groups { + let mut padded = [0u8; GROUP_BYTES]; + padded[..codes.len() - consumed].copy_from_slice(&codes[consumed..]); + let runs = unsafe { $unpack(padded.as_ptr()) }; + for (run, codes16) in runs.into_iter().enumerate() { + unsafe { + fma16_avx512( + codes16, + ex_query.as_ptr().add(full_groups * GROUP_DIMS + run * 16), + &mut acc[(full_groups + run) % 2], + ) + }; + } + } + _mm512_reduce_add_ps(_mm512_add_ps(acc[0], acc[1])) + } + + fn $dispatch(ex_query: &[f32], codes: &[u8]) -> f32 { + // SAFETY: only selected when AVX-512F was detected. + unsafe { $name(ex_query, codes) } + } + }; + } + + x86_dot_kernel!(dot_u1_avx2, dot_u1_avx2_dispatch, unpack_u1, 1, 1); + x86_dot_kernel!(dot_u2_avx2, dot_u2_avx2_dispatch, unpack_u2, 2, 4); + x86_dot_kernel!(dot_u3_avx2, dot_u3_avx2_dispatch, unpack_u3, 3, 4); + x86_dot_kernel!(dot_u4_avx2, dot_u4_avx2_dispatch, unpack_u4, 4, 1); + x86_dot_kernel!(dot_u5_avx2, dot_u5_avx2_dispatch, unpack_u5, 5, 4); + x86_dot_kernel!(dot_u6_avx2, dot_u6_avx2_dispatch, unpack_u6, 6, 4); + x86_dot_kernel!(dot_u7_avx2, dot_u7_avx2_dispatch, unpack_u7, 7, 4); + x86_dot_kernel!(dot_u8_avx2, dot_u8_avx2_dispatch, unpack_u8x16, 8, 1); + + x86_dot_kernel_avx512!(dot_u1_avx512, dot_u1_avx512_dispatch, unpack_u1, 1, 1); + x86_dot_kernel_avx512!(dot_u2_avx512, dot_u2_avx512_dispatch, unpack_u2, 2, 4); + x86_dot_kernel_avx512!(dot_u3_avx512, dot_u3_avx512_dispatch, unpack_u3, 3, 4); + x86_dot_kernel_avx512!(dot_u4_avx512, dot_u4_avx512_dispatch, unpack_u4, 4, 1); + x86_dot_kernel_avx512!(dot_u5_avx512, dot_u5_avx512_dispatch, unpack_u5, 5, 4); + x86_dot_kernel_avx512!(dot_u6_avx512, dot_u6_avx512_dispatch, unpack_u6, 6, 4); + x86_dot_kernel_avx512!(dot_u7_avx512, dot_u7_avx512_dispatch, unpack_u7, 7, 4); + x86_dot_kernel_avx512!(dot_u8_avx512, dot_u8_avx512_dispatch, unpack_u8x16, 8, 1); +} + +#[cfg(target_arch = "aarch64")] +mod neon { + use super::ExDotFn; + use std::arch::aarch64::*; + + pub(super) fn kernel(ex_bits: u8) -> ExDotFn { + match ex_bits { + 1 => dot_u1_neon_dispatch, + 2 => dot_u2_neon_dispatch, + 3 => dot_u3_neon_dispatch, + 4 => dot_u4_neon_dispatch, + 5 => dot_u5_neon_dispatch, + 6 => dot_u6_neon_dispatch, + 7 => dot_u7_neon_dispatch, + 8 => dot_u8_neon_dispatch, + _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"), + } + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u1(ptr: *const u8) -> [uint8x16_t; 1] { + let (b0, b1) = unsafe { (ptr.read(), ptr.add(1).read()) }; + let bytes = vcombine_u8(vdup_n_u8(b0), vdup_n_u8(b1)); + let bit_select = vreinterpretq_u8_u64(vdupq_n_u64(0x8040_2010_0804_0201)); + [vandq_u8(vtstq_u8(bytes, bit_select), vdupq_n_u8(1))] + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u2(ptr: *const u8) -> [uint8x16_t; 4] { + let raw = unsafe { vld1q_u8(ptr) }; + let mask = vdupq_n_u8(0b11); + [ + vandq_u8(raw, mask), + vandq_u8(vshrq_n_u8::<2>(raw), mask), + vandq_u8(vshrq_n_u8::<4>(raw), mask), + vshrq_n_u8::<6>(raw), + ] + } + + #[inline] + #[target_feature(enable = "neon")] + fn top_plane_run(plane: u64, k: usize, top_bit: usize) -> uint8x16_t { + let lo = super::shift_plane(plane, 2 * k, top_bit); + let hi = super::shift_plane(plane, 2 * k + 1, top_bit); + vandq_u8( + vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(lo), vcreate_u64(hi))), + vdupq_n_u8(1 << top_bit), + ) + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u3(ptr: *const u8) -> [uint8x16_t; 4] { + let mut runs = unsafe { unpack_u2(ptr) }; + let plane = unsafe { (ptr.add(16) as *const u64).read_unaligned() }; + for (k, run) in runs.iter_mut().enumerate() { + *run = vorrq_u8(*run, top_plane_run(plane, k, 2)); + } + runs + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u4(ptr: *const u8) -> [uint8x16_t; 1] { + let word = unsafe { (ptr as *const u64).read_unaligned() }; + let mask = 0x0f0f_0f0f_0f0f_0f0fu64; + [vreinterpretq_u8_u64(vcombine_u64( + vcreate_u64(word & mask), + vcreate_u64((word >> 4) & mask), + ))] + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u5(ptr: *const u8) -> [uint8x16_t; 4] { + let blk0 = unsafe { vld1q_u8(ptr) }; + let blk1 = unsafe { vld1q_u8(ptr.add(16)) }; + let plane = unsafe { (ptr.add(32) as *const u64).read_unaligned() }; + let mask = vdupq_n_u8(0x0f); + let mut runs = [ + vandq_u8(blk0, mask), + vshrq_n_u8::<4>(blk0), + vandq_u8(blk1, mask), + vshrq_n_u8::<4>(blk1), + ]; + for (k, run) in runs.iter_mut().enumerate() { + *run = vorrq_u8(*run, top_plane_run(plane, k, 4)); + } + runs + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u6(ptr: *const u8) -> [uint8x16_t; 4] { + let blk0 = unsafe { vld1q_u8(ptr) }; + let blk1 = unsafe { vld1q_u8(ptr.add(16)) }; + let blk2 = unsafe { vld1q_u8(ptr.add(32)) }; + let mask6 = vdupq_n_u8(0x3f); + let stolen = vorrq_u8( + vorrq_u8( + vshrq_n_u8::<6>(blk0), + vshlq_n_u8::<2>(vshrq_n_u8::<6>(blk1)), + ), + vshlq_n_u8::<4>(vshrq_n_u8::<6>(blk2)), + ); + [ + vandq_u8(blk0, mask6), + vandq_u8(blk1, mask6), + vandq_u8(blk2, mask6), + stolen, + ] + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u7(ptr: *const u8) -> [uint8x16_t; 4] { + let mut runs = unsafe { unpack_u6(ptr) }; + let plane = unsafe { (ptr.add(48) as *const u64).read_unaligned() }; + for (k, run) in runs.iter_mut().enumerate() { + *run = vorrq_u8(*run, top_plane_run(plane, k, 6)); + } + runs + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u8x16(ptr: *const u8) -> [uint8x16_t; 1] { + [unsafe { vld1q_u8(ptr) }] + } + + /// FMA 16 code bytes against 16 query floats over four 4-float lanes. + #[inline] + #[target_feature(enable = "neon")] + unsafe fn fma16_neon(codes: uint8x16_t, query: *const f32, acc: &mut [float32x4_t; 4]) { + let lo = vmovl_u8(vget_low_u8(codes)); + let hi = vmovl_u8(vget_high_u8(codes)); + let c0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo))); + let c1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo))); + let c2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi))); + let c3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi))); + unsafe { + acc[0] = vfmaq_f32(acc[0], c0, vld1q_f32(query)); + acc[1] = vfmaq_f32(acc[1], c1, vld1q_f32(query.add(4))); + acc[2] = vfmaq_f32(acc[2], c2, vld1q_f32(query.add(8))); + acc[3] = vfmaq_f32(acc[3], c3, vld1q_f32(query.add(12))); + } + } + + macro_rules! neon_dot_kernel { + ($name:ident, $dispatch:ident, $unpack:ident, $ex_bits:expr, $runs:expr) => { + #[target_feature(enable = "neon")] + unsafe fn $name(ex_query: &[f32], codes: &[u8]) -> f32 { + const GROUP_DIMS: usize = if $runs == 1 { 16 } else { 64 }; + const GROUP_BYTES: usize = GROUP_DIMS * $ex_bits / 8; + debug_assert_eq!(ex_query.len() % super::EX_DOT_BLOCK_DIMS, 0); + debug_assert!(codes.len() * 8 <= ex_query.len() * $ex_bits); + + let groups = ex_query.len() / GROUP_DIMS; + let full_groups = (codes.len() / GROUP_BYTES).min(groups); + let mut acc = [vdupq_n_f32(0.0); 4]; + for group in 0..full_groups { + // SAFETY: `group < full_groups` keeps both the code group + // and the query run in bounds. + let runs = unsafe { $unpack(codes.as_ptr().add(group * GROUP_BYTES)) }; + for (run, codes16) in runs.into_iter().enumerate() { + unsafe { + fma16_neon( + codes16, + ex_query.as_ptr().add(group * GROUP_DIMS + run * 16), + &mut acc, + ) + }; + } + } + let consumed = full_groups * GROUP_BYTES; + if consumed < codes.len() && full_groups < groups { + // Zero-pad the final partial code group on the stack. + let mut padded = [0u8; GROUP_BYTES]; + padded[..codes.len() - consumed].copy_from_slice(&codes[consumed..]); + let runs = unsafe { $unpack(padded.as_ptr()) }; + for (run, codes16) in runs.into_iter().enumerate() { + unsafe { + fma16_neon( + codes16, + ex_query.as_ptr().add(full_groups * GROUP_DIMS + run * 16), + &mut acc, + ) + }; + } + } + vaddvq_f32(vaddq_f32( + vaddq_f32(acc[0], acc[1]), + vaddq_f32(acc[2], acc[3]), + )) + } + + fn $dispatch(ex_query: &[f32], codes: &[u8]) -> f32 { + // SAFETY: NEON is part of the aarch64 baseline. + unsafe { $name(ex_query, codes) } + } + }; + } + + neon_dot_kernel!(dot_u1_neon, dot_u1_neon_dispatch, unpack_u1, 1, 1); + neon_dot_kernel!(dot_u2_neon, dot_u2_neon_dispatch, unpack_u2, 2, 4); + neon_dot_kernel!(dot_u3_neon, dot_u3_neon_dispatch, unpack_u3, 3, 4); + neon_dot_kernel!(dot_u4_neon, dot_u4_neon_dispatch, unpack_u4, 4, 1); + neon_dot_kernel!(dot_u5_neon, dot_u5_neon_dispatch, unpack_u5, 5, 4); + neon_dot_kernel!(dot_u6_neon, dot_u6_neon_dispatch, unpack_u6, 6, 4); + neon_dot_kernel!(dot_u7_neon, dot_u7_neon_dispatch, unpack_u7, 7, 4); + neon_dot_kernel!(dot_u8_neon, dot_u8_neon_dispatch, unpack_u8x16, 8, 1); +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + use rstest::rstest; + + /// Bit-pack code values sequentially (LSB-first), the on-disk ex-code layout. + fn pack_sequential(values: &[u8], ex_bits: u8) -> Vec { + let mut out = vec![0u8; (values.len() * ex_bits as usize).div_ceil(8)]; + for (dim, &value) in values.iter().enumerate() { + let bit_offset = dim * ex_bits as usize; + let bits = (value as u16) << (bit_offset % 8); + out[bit_offset / 8] |= bits as u8; + if bits >> 8 != 0 { + out[bit_offset / 8 + 1] |= (bits >> 8) as u8; + } + } + out + } + + fn kernel_codes(values: &[u8], dim: usize, ex_bits: u8) -> Vec { + debug_assert_eq!(values.len(), dim); + let mut out = vec![0u8; blocked_ex_code_bytes(dim, ex_bits)]; + pack_blocked_row(values, ex_bits, &mut out); + out + } + + fn available_kernels(ex_bits: u8) -> Vec<(&'static str, ExDotFn)> { + // `mut` is only exercised on x86_64 where extra kernels may be pushed. + #[allow(unused_mut)] + let mut kernels = vec![ + ("scalar", scalar_kernel(ex_bits)), + ("dispatched", ex_dot_kernel(ex_bits)), + ]; + #[cfg(target_arch = "x86_64")] + { + if std::arch::is_x86_feature_detected!("avx2") + && std::arch::is_x86_feature_detected!("fma") + { + kernels.push(("avx2", x86::avx2_kernel(ex_bits))); + } + if std::arch::is_x86_feature_detected!("avx512f") { + kernels.push(("avx512", x86::avx512_kernel(ex_bits))); + } + } + kernels + } + + #[rstest] + fn test_ex_dot_matches_reference( + #[values(1, 2, 3, 4, 5, 6, 7, 8)] ex_bits: u8, + #[values(7, 16, 60, 64, 100, 128, 1024, 1536, 2048)] dim: usize, + ) { + let mut rng = SmallRng::seed_from_u64(42 + ex_bits as u64 * 1000 + dim as u64); + let max_code = ((1u16 << ex_bits) - 1) as u8; + let values = (0..dim) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + let query = (0..dim) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + + let expected = query + .iter() + .zip(values.iter()) + .map(|(q, &c)| *q as f64 * c as f64) + .sum::(); + + let codes = kernel_codes(&values, dim, ex_bits); + let mut ex_query = vec![0.0; padded_query_len(dim)]; + pad_query_into(&query, &mut ex_query); + + let tolerance = 1e-3 * expected.abs().max(1.0); + for (name, kernel) in available_kernels(ex_bits) { + let actual = kernel(&ex_query, &codes) as f64; + assert!( + (actual - expected).abs() <= tolerance, + "ex_bits={ex_bits} dim={dim} kernel={name}: {actual} != {expected}" + ); + } + } + + #[rstest] + fn test_unpack_group_roundtrip(#[values(1, 2, 3, 4, 5, 6, 7, 8)] ex_bits: u8) { + let mut rng = SmallRng::seed_from_u64(7 + ex_bits as u64); + let max_code = ((1u16 << ex_bits) - 1) as u8; + let values = (0..EX_DOT_BLOCK_DIMS) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + let codes = kernel_codes(&values, EX_DOT_BLOCK_DIMS, ex_bits); + + // Unpacking each kernel group must reproduce the values in natural + // dim order. + let dims = group_dims(ex_bits); + let bytes = group_bytes(ex_bits); + let mut unpacked = [0u8; 64]; + for group in 0..EX_DOT_BLOCK_DIMS / dims { + unpack_group( + ex_bits, + &codes[group * bytes..(group + 1) * bytes], + &mut unpacked, + ); + assert_eq!( + &unpacked[..dims], + &values[group * dims..(group + 1) * dims], + "ex_bits={ex_bits} group={group}" + ); + } + } + + /// The legacy sequential rows must repack into exactly what the writer + /// produces from the unpacked values. + #[rstest] + fn test_repack_sequential_matches_blocked( + #[values(1, 2, 3, 4, 5, 6, 7, 8)] ex_bits: u8, + #[values(7, 64, 100, 1536)] dim: usize, + ) { + let mut rng = SmallRng::seed_from_u64(11 + ex_bits as u64 * 100 + dim as u64); + let max_code = ((1u16 << ex_bits) - 1) as u8; + let values = (0..dim) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + let seq = pack_sequential(&values, ex_bits); + + let mut repacked = vec![0u8; blocked_ex_code_bytes(dim, ex_bits)]; + repack_sequential_row(&seq, dim, ex_bits, &mut repacked); + assert_eq!(repacked, kernel_codes(&values, dim, ex_bits)); + + // For the widths where the sequential layout is already blocked + // (modulo trailing padding), the raw row must be a prefix. + if sequential_matches_blocked(ex_bits) { + assert_eq!(&repacked[..seq.len()], &seq); + assert!(repacked[seq.len()..].iter().all(|&byte| byte == 0)); + } + } + + /// Dense dim sweep for the bit-plane widths: every tail shape within the + /// 64-dim kernel group, plus multi-group sizes. + #[rstest] + fn test_ex_dot_plane_widths_dense_dims(#[values(3, 5)] ex_bits: u8) { + let mut rng = SmallRng::seed_from_u64(97 + ex_bits as u64); + let max_code = ((1u16 << ex_bits) - 1) as u8; + for dim in (1..=160).chain([255, 256, 1000, 1536, 2048]) { + let values = (0..dim) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + let query = (0..dim) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + let expected = query + .iter() + .zip(values.iter()) + .map(|(q, &c)| *q as f64 * c as f64) + .sum::(); + + let codes = kernel_codes(&values, dim, ex_bits); + let mut ex_query = vec![0.0; padded_query_len(dim)]; + pad_query_into(&query, &mut ex_query); + let tolerance = 1e-3 * expected.abs().max(1.0); + for (name, kernel) in available_kernels(ex_bits) { + let actual = kernel(&ex_query, &codes) as f64; + assert!( + (actual - expected).abs() <= tolerance, + "ex_bits={ex_bits} dim={dim} kernel={name}: {actual} != {expected}" + ); + } + } + } + + #[test] + fn test_pad_query_pads_with_zeros() { + let query = vec![1.0f32; 100]; + let mut padded = vec![f32::NAN; padded_query_len(query.len())]; + pad_query_into(&query, &mut padded); + assert_eq!(padded.len(), 128); + assert_eq!(&padded[..100], &query[..]); + assert!(padded[100..].iter().all(|&value| value == 0.0)); + } +} diff --git a/rust/lance-index/src/vector/bq/storage.rs b/rust/lance-index/src/vector/bq/storage.rs index bd70f176c5d..36e56986921 100644 --- a/rust/lance-index/src/vector/bq/storage.rs +++ b/rust/lance-index/src/vector/bq/storage.rs @@ -41,6 +41,10 @@ use serde::{Deserialize, Serialize}; use crate::frag_reuse::FragReuseIndex; use crate::pb; use crate::vector::ApproxMode; +use crate::vector::bq::ex_dot::{ + EX_DOT_BLOCK_DIMS, ExDotFn, blocked_ex_code_bytes, ex_dot_kernel, pad_query_into, + padded_query_len, repack_sequential_row, sequential_matches_blocked, +}; use crate::vector::bq::rotation::{apply_fast_rotation, apply_fast_rotation_in_place}; use crate::vector::bq::transform::{ ADD_FACTORS_COLUMN, ERROR_FACTORS_COLUMN, EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN, @@ -59,7 +63,14 @@ use crate::vector::storage::{ pub const RABIT_METADATA_KEY: &str = "lance:rabit"; pub const RABIT_CODE_COLUMN: &str = "_rabit_codes"; +/// Legacy ex-code column: sequential LSB-first bit stream per row. Read-only; +/// rows are repacked into the blocked layout at load time. pub const RABIT_EX_CODE_COLUMN: &str = "__ex_codes"; +/// Ex-code column in the blocked layout consumed by the ex-dot kernels (see +/// `ex_dot` module docs). Indexes written with this column cannot be read by +/// older versions, which fail with a missing-column error instead of +/// misinterpreting the bytes. +pub const RABIT_BLOCKED_EX_CODE_COLUMN: &str = "__blocked_ex_codes"; pub const SEGMENT_LENGTH: usize = 4; pub const SEGMENT_NUM_CODES: usize = 1 << SEGMENT_LENGTH; const RABIT_PRUNE_STATS_ENV: &str = "LANCE_RQ_PRUNE_STATS"; @@ -210,10 +221,10 @@ pub fn rabit_ex_code_field(rotated_dim: usize, num_bits: u8) -> Result(&rotated_query, &mut dist_table); - let mut ex_dist_table = vec![0.0; ex_dist_table_len]; - build_ex_dist_table_direct_into(&rotated_query, ex_bits, &mut ex_dist_table); + // The kernels consume the rotated query directly; a zero-padded copy + // is only needed when the rotated dim is not block-aligned. + let mut ex_query = Vec::new(); + if ex_bits > 0 && !code_dim.is_multiple_of(EX_DOT_BLOCK_DIMS) { + ex_query.resize(padded_query_len(code_dim), 0.0); + pad_query_into(&rotated_query, &mut ex_query); + } let sum_q = rotated_query.iter().copied().sum(); Ok(RabitRawQueryContext { @@ -370,7 +381,7 @@ impl RabitQuantizationMetadata { ex_bits, rotated_query, dist_table, - ex_dist_table, + ex_query, sum_q, }) } @@ -462,6 +473,10 @@ pub struct RabitQuantizationStorage { add_factors: Float32Array, scale_factors: Float32Array, error_factors: Option, + // ex codes in the blocked kernel layout; always aliases the batch column + // (legacy sequential batches are normalized at load, replacing the + // sequential column with the repacked one, so rewrites emit the blocked + // format). ex_codes: Option, packed_ex_codes: Option, ex_add_factors: Option, @@ -560,12 +575,17 @@ impl RabitQuantizationStorage { let RabitDistCalculatorParts { dim, dist_table, - ex_dist_table, + ex_query, sum_q, query_factor, query_error, approx_mode, } = parts; + let ex_code_len = self + .ex_codes + .as_ref() + .map(|codes| codes.value_length() as usize) + .unwrap_or_default(); let ex_codes = self .ex_codes .as_ref() @@ -579,10 +599,11 @@ impl RabitQuantizationStorage { self.metadata.num_bits, self.metadata.query_estimator, dist_table, - ex_dist_table, + ex_query, sum_q, self.codes.values().as_primitive::().values(), ex_codes, + ex_code_len, self.add_factors.values(), self.scale_factors.values(), self.error_factors @@ -767,25 +788,42 @@ fn copy_subtract_f32(lhs: &[f32], rhs: &[f32], output: &mut [f32]) { struct RabitDistCalculatorParts<'a> { dim: usize, dist_table: Cow<'a, [f32]>, - ex_dist_table: Cow<'a, [f32]>, + ex_query: Cow<'a, [f32]>, sum_q: f32, query_factor: f32, query_error: f32, approx_mode: ApproxMode, } +/// Pick the query slice the ex-dot kernels consume: the rotated query itself +/// when the dim is block-aligned, otherwise a zero-padded copy. +fn kernel_query<'a>(rotated_query: &'a [f32], padded: &'a [f32]) -> &'a [f32] { + if rotated_query.len().is_multiple_of(EX_DOT_BLOCK_DIMS) { + rotated_query + } else { + padded + } +} + pub struct RabitDistCalculator<'a> { dim: usize, num_bits: u8, query_estimator: RabitQueryEstimator, // n * d / 8 binary-code bytes codes: &'a [u8], + // per-row ex codes in the blocked kernel layout ex_codes: Option<&'a [u8]>, + // bytes per ex-code row; legacy rows for layout-compatible widths may be + // shorter than the blocked size, which the kernels treat as zero padding + ex_code_len: usize, // this is a flattened 2D array of size d/4 * 16, // we split the query codes into d/4 chunks, each chunk is with 4 elements, // then dist_table[i][j] is the distance between the i-th query code and the code j dist_table: Cow<'a, [f32]>, - ex_dist_table: Cow<'a, [f32]>, + // the rotated query, zero-padded to a 64-dim multiple when needed; also + // the source for the FastScan ex LUT on the legacy bypass path + ex_query: Cow<'a, [f32]>, + ex_dot: Option, add_factors: &'a [f32], scale_factors: &'a [f32], error_factors: Option<&'a [f32]>, @@ -807,10 +845,11 @@ impl<'a> RabitDistCalculator<'a> { num_bits: u8, query_estimator: RabitQueryEstimator, dist_table: Cow<'a, [f32]>, - ex_dist_table: Cow<'a, [f32]>, + ex_query: Cow<'a, [f32]>, sum_q: f32, codes: &'a [u8], ex_codes: Option<&'a [u8]>, + ex_code_len: usize, add_factors: &'a [f32], scale_factors: &'a [f32], error_factors: Option<&'a [f32]>, @@ -821,14 +860,17 @@ impl<'a> RabitDistCalculator<'a> { query_error: f32, approx_mode: ApproxMode, ) -> Self { + let ex_dot = (num_bits > 1).then(|| ex_dot_kernel(num_bits - 1)); Self { dim, num_bits, query_estimator, codes, ex_codes, + ex_code_len, dist_table, - ex_dist_table, + ex_query, + ex_dot, add_factors, scale_factors, error_factors, @@ -843,6 +885,18 @@ impl<'a> RabitDistCalculator<'a> { } } + /// `sum_d query[d] * ex_code[d]` for the candidate's packed ex codes. + #[inline] + fn ex_code_dot(&self, ex_codes: &[u8], id: usize) -> f32 { + let ex_dot = self + .ex_dot + .expect("raw-query multi-bit RQ requires an ex-dot kernel"); + ex_dot( + self.ex_query.as_ref(), + &ex_codes[id * self.ex_code_len..(id + 1) * self.ex_code_len], + ) + } + #[allow(clippy::uninit_vec)] fn binary_distances_with_scratch( &self, @@ -1030,8 +1084,6 @@ impl<'a> RabitDistCalculator<'a> { let ex_scale_factors = self .ex_scale_factors .expect("raw-query multi-bit RQ requires ex scale factors"); - let ex_code_len = - rabit_ex_code_bytes(self.dim, ex_bits).expect("RabitQ num_bits should be validated"); let code_scale = (1u32 << ex_bits) as f32; let code_bias = -(code_scale - 0.5); @@ -1039,12 +1091,11 @@ impl<'a> RabitDistCalculator<'a> { self.packed_ex_codes .map(|packed_ex_codes| { let fastscan_len = simd_len; - let fastscan_code_len = ex_fastscan_code_len(self.dim, ex_bits) - .expect("RabitQ num_bits should be validated"); + let fastscan_code_len = self.ex_code_len; let (qmin, qmax, quantization_max) = quantize_ex_fastscan_dist_table_into( - self.dim, ex_bits, - &self.ex_dist_table, + self.ex_code_len, + self.ex_query.as_ref(), quantized_dists_table, ); quantized_dists.clear(); @@ -1088,14 +1139,7 @@ impl<'a> RabitDistCalculator<'a> { .enumerate() .skip(fastscan_len) .for_each(|(id, dist)| { - let ex_dist = compute_single_rq_ex_distance( - ex_codes, - id, - ex_code_len, - ex_bits, - self.dim, - &self.ex_dist_table, - ); + let ex_dist = self.ex_code_dot(ex_codes, id); let full_dot = code_scale * *dist + ex_dist + code_bias * self.sum_q; *dist = full_dot * ex_scale_factors[id] + ex_add_factors[id] + self.query_factor; }); @@ -1121,19 +1165,11 @@ impl<'a> RabitDistCalculator<'a> { id: usize, binary_ip: f32, ex_bits: u8, - ex_code_len: usize, ex_codes: &[u8], ex_add_factors: &[f32], ex_scale_factors: &[f32], ) -> f32 { - let ex_dist = compute_single_rq_ex_distance( - ex_codes, - id, - ex_code_len, - ex_bits, - self.dim, - &self.ex_dist_table, - ); + let ex_dist = self.ex_code_dot(ex_codes, id); let code_bias = -((1u32 << ex_bits) as f32 - 0.5); let full_dot = (1u32 << ex_bits) as f32 * binary_ip + ex_dist + code_bias * self.sum_q; full_dot * ex_scale_factors[id] + ex_add_factors[id] + self.query_factor @@ -1180,8 +1216,6 @@ impl<'a> RabitDistCalculator<'a> { let ex_scale_factors = self .ex_scale_factors .expect("raw-query multi-bit RQ requires ex scale factors"); - let ex_code_len = - rabit_ex_code_bytes(self.dim, ex_bits).expect("RabitQ num_bits should be validated"); let query_lower_bound = lower_bound.unwrap_or(f32::MIN); let query_upper_bound = upper_bound.unwrap_or(f32::MAX); let mut max_dist = res.peek().map(|node| node.dist); @@ -1213,7 +1247,6 @@ impl<'a> RabitDistCalculator<'a> { id, binary_ip, ex_bits, - ex_code_len, ex_codes, ex_add_factors, ex_scale_factors, @@ -1276,33 +1309,6 @@ where dist_table } -fn build_ex_dist_table_direct(rotated_query: &[f32], ex_bits: u8) -> Vec { - if ex_bits == 0 { - return Vec::new(); - } - let entries_per_dim = 1usize << ex_bits; - let mut dist_table = vec![0.0; rotated_query.len() * entries_per_dim]; - build_ex_dist_table_direct_into(rotated_query, ex_bits, &mut dist_table); - dist_table -} - -fn build_ex_dist_table_direct_into(rotated_query: &[f32], ex_bits: u8, dist_table: &mut [f32]) { - if ex_bits == 0 { - debug_assert!(dist_table.is_empty()); - return; - } - let entries_per_dim = 1usize << ex_bits; - debug_assert_eq!(dist_table.len(), rotated_query.len() * entries_per_dim); - for (query_value, table) in rotated_query - .iter() - .zip(dist_table.chunks_exact_mut(entries_per_dim)) - { - for (code, value) in table.iter_mut().enumerate() { - *value = *query_value * code as f32; - } - } -} - fn build_dist_table_direct_into(qc: &[T::Native], dist_table: &mut [f32]) where T::Native: AsPrimitive, @@ -1401,33 +1407,20 @@ fn quantize_dist_table_u16_into( (qmin, qmax) } -#[inline] -fn packed_ex_code_value(row_codes: &[u8], dim_idx: usize, ex_bits: u8) -> u8 { - debug_assert!(ex_bits > 0); - let bit_offset = dim_idx * ex_bits as usize; - let byte_idx = bit_offset / u8::BITS as usize; - let bit_shift = bit_offset % u8::BITS as usize; - let bits = row_codes[byte_idx] as u16 - | row_codes - .get(byte_idx + 1) - .map(|byte| (*byte as u16) << u8::BITS) - .unwrap_or_default(); - let mask = (1u16 << ex_bits) - 1; - ((bits >> bit_shift) & mask) as u8 -} - +/// Build the u8 FastScan LUT for the ex codes directly from the rotated +/// query (`ex_query`, natural dim order, padding dims zero): the underlying +/// per-dim table is the pure multiplication `q[d] * code`, so no intermediate +/// `dim * 2^ex_bits` table is materialized. fn quantize_ex_fastscan_dist_table_into( - dim: usize, ex_bits: u8, - ex_dist_table: &[f32], + ex_code_len: usize, + ex_query: &[f32], quantized_dist_table: &mut Vec, ) -> (f32, f32, f32) { debug_assert!(supports_ex_fastscan(ex_bits)); - let entries_per_dim = 1usize << ex_bits; - debug_assert_eq!(ex_dist_table.len(), dim * entries_per_dim); - let num_split_tables = - ex_fastscan_code_len(dim, ex_bits).expect("RabitQ num_bits should be validated") * 2; + // One split table per code nibble of the row. + let num_split_tables = ex_code_len * 2; let quantization_max = (u16::MAX as usize / num_split_tables) .min(u8::MAX as usize) .max(1) as f32; @@ -1436,7 +1429,7 @@ fn quantize_ex_fastscan_dist_table_into( let mut qmax = f32::NEG_INFINITY; for table_idx in 0..num_split_tables { for code in 0..SEGMENT_NUM_CODES { - let value = ex_fastscan_dist_table_value(dim, ex_bits, ex_dist_table, table_idx, code); + let value = ex_fastscan_dist_table_value(ex_query, ex_bits, table_idx, code); qmin = qmin.min(value); qmax = qmax.max(value); } @@ -1452,7 +1445,7 @@ fn quantize_ex_fastscan_dist_table_into( let factor = quantization_max / (qmax - qmin); for table_idx in 0..num_split_tables { for code in 0..SEGMENT_NUM_CODES { - let value = ex_fastscan_dist_table_value(dim, ex_bits, ex_dist_table, table_idx, code); + let value = ex_fastscan_dist_table_value(ex_query, ex_bits, table_idx, code); quantized_dist_table.push(((value - qmin) * factor).round() as u8); } } @@ -1465,91 +1458,153 @@ fn supports_ex_fastscan(ex_bits: u8) -> bool { matches!(ex_bits, 2 | 4 | 8) } -#[inline] -fn ex_fastscan_code_len(dim: usize, ex_bits: u8) -> Option { - match ex_bits { - 2 | 4 | 8 => rabit_ex_code_bytes(dim, ex_bits).ok(), - _ => None, - } -} - +/// The FastScan LUT value for one nibble of a blocked-layout code byte: +/// `table_idx / 2` is the byte position within a row and `table_idx % 2` +/// selects its low/high nibble (see the `ex_dot` module docs for the +/// byte-to-dim mapping per width). Dims beyond the query length (block +/// padding) contribute zero. #[inline] fn ex_fastscan_dist_table_value( - dim: usize, + ex_query: &[f32], ex_bits: u8, - ex_dist_table: &[f32], table_idx: usize, code: usize, ) -> f32 { + let query = |dim_idx: usize| ex_query.get(dim_idx).copied().unwrap_or(0.0); + let byte_idx = table_idx / 2; + let high_nibble = table_idx % 2 == 1; match ex_bits { 2 => { - let dim_idx = table_idx * 2; - let low = code & 0b11; - let high = (code >> 2) & 0b11; - ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, low) - + ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx + 1, high) + // byte 16g+b = dims {64g+b, +16, +32, +48} at bit pairs; the low + // nibble covers the first two dims, the high nibble the last two. + let dim_idx = 64 * (byte_idx / 16) + byte_idx % 16 + 32 * usize::from(high_nibble); + let low = (code & 0b11) as f32; + let high = ((code >> 2) & 0b11) as f32; + query(dim_idx) * low + query(dim_idx + 16) * high + } + 4 => { + // byte 32g+8j+b = dim 64g+16j+b (low nibble) | dim +8 (high). + let in_block = byte_idx % 32; + let dim_idx = 64 * (byte_idx / 32) + + 16 * (in_block / 8) + + in_block % 8 + + 8 * usize::from(high_nibble); + query(dim_idx) * code as f32 } - 4 => ex_dist_table_value(ex_dist_table, dim, ex_bits, table_idx, code), 8 => { - let dim_idx = table_idx / 2; - if table_idx.is_multiple_of(2) { - ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, code) + // byte = dim identity; the high nibble carries code bits 4..8. + let code = if high_nibble { + code << SEGMENT_LENGTH } else { - ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, code << SEGMENT_LENGTH) - } + code + }; + query(byte_idx) * code as f32 } _ => unreachable!("unsupported RabitQ ex_bits={ex_bits} for FastScan"), } } -#[inline] -fn ex_dist_table_value( - ex_dist_table: &[f32], - dim: usize, - ex_bits: u8, - dim_idx: usize, - code: usize, -) -> f32 { - if dim_idx >= dim { - return 0.0; - } - let entries_per_dim = 1usize << ex_bits; - ex_dist_table[dim_idx * entries_per_dim + code] -} - -#[inline] -fn compute_single_rq_ex_distance( - ex_codes: &[u8], - id: usize, - ex_code_len: usize, - ex_bits: u8, - dim: usize, - ex_dist_table: &[f32], -) -> f32 { - if ex_bits == 0 { - return 0.0; - } - let entries_per_dim = 1usize << ex_bits; - let row_codes = &ex_codes[id * ex_code_len..(id + 1) * ex_code_len]; - (0..dim) - .map(|dim_idx| { - let code = packed_ex_code_value(row_codes, dim_idx, ex_bits) as usize; - ex_dist_table[dim_idx * entries_per_dim + code] - }) - .sum() -} - +/// Transpose ex codes for the FastScan bulk path. That path is only reachable +/// when lower-bound gating is disabled, i.e. for legacy indexes without error +/// factors; gated indexes rerank per candidate with the ex-dot kernels and +/// never touch this copy, so skip the transpose (and its resident memory). fn maybe_pack_ex_codes( ex_codes: Option<&FixedSizeListArray>, ex_bits: u8, + error_factors: Option<&Float32Array>, ) -> Option { let ex_codes = ex_codes?; + if error_factors.is_some() { + return None; + } match ex_bits { 2 | 4 | 8 => Some(pack_codes(ex_codes)), _ => None, } } +/// Bring legacy sequential ex codes into the blocked kernel layout: rows are +/// repacked, except for the widths whose layouts agree byte-for-byte (then +/// the column is used as stored). +fn blocked_ex_codes_from_sequential( + seq_codes: &FixedSizeListArray, + dim: usize, + ex_bits: u8, +) -> Result { + if sequential_matches_blocked(ex_bits) + && seq_codes.value_length() as usize == blocked_ex_code_bytes(dim, ex_bits) + { + return Ok(seq_codes.clone()); + } + let seq_code_len = seq_codes.value_length() as usize; + let seq_values = seq_codes.values().as_primitive::().values(); + let blocked_code_len = blocked_ex_code_bytes(dim, ex_bits); + let mut blocked_values = vec![0u8; seq_codes.len() * blocked_code_len]; + for (seq_row, blocked_row) in seq_values + .chunks_exact(seq_code_len) + .zip(blocked_values.chunks_exact_mut(blocked_code_len)) + { + repack_sequential_row(seq_row, dim, ex_bits, blocked_row); + } + Ok(FixedSizeListArray::try_new_from_values( + UInt8Array::from(blocked_values), + blocked_code_len as i32, + )?) +} + +/// Load the ex-code column of an index batch into the blocked kernel layout, +/// accepting both the blocked format and the legacy sequential format. Legacy +/// batches are normalized in place (the sequential column is replaced by the +/// blocked one), so rewrites — remap, optimize merges — always emit the +/// blocked format and legacy indexes upgrade on their next rewrite. +pub(crate) fn load_blocked_ex_codes( + batch: RecordBatch, + rotated_dim: usize, + num_bits: u8, +) -> Result<(RecordBatch, FixedSizeListArray)> { + let ex_bits = rabit_ex_bits(num_bits)?; + if let Some(column) = batch.column_by_name(RABIT_BLOCKED_EX_CODE_COLUMN) { + let codes = column.as_fixed_size_list().clone(); + let expected_bytes = blocked_ex_code_bytes(rotated_dim, ex_bits); + if codes.value_length() as usize != expected_bytes { + return Err(Error::invalid_input(format!( + "RabitQ ex-code byte width mismatch: column {} has {} bytes, metadata rotated_dim={} ex_bits={} requires {} bytes", + RABIT_BLOCKED_EX_CODE_COLUMN, + codes.value_length(), + rotated_dim, + ex_bits, + expected_bytes + ))); + } + return Ok((batch, codes)); + } + let column = batch.column_by_name(RABIT_EX_CODE_COLUMN).ok_or_else(|| { + Error::invalid_input(format!( + "RabitQ num_bits={} requires {} column", + num_bits, RABIT_BLOCKED_EX_CODE_COLUMN + )) + })?; + let codes = column.as_fixed_size_list().clone(); + let expected_bytes = rabit_ex_code_bytes(rotated_dim, ex_bits)?; + if codes.value_length() as usize != expected_bytes { + return Err(Error::invalid_input(format!( + "RabitQ ex-code byte width mismatch: column {} has {} bytes, metadata rotated_dim={} ex_bits={} requires {} bytes", + RABIT_EX_CODE_COLUMN, + codes.value_length(), + rotated_dim, + ex_bits, + expected_bytes + ))); + } + let blocked = blocked_ex_codes_from_sequential(&codes, rotated_dim, ex_bits)?; + let ex_code_field = rabit_ex_code_field(rotated_dim, num_bits)? + .expect("multi-bit RabitQ always has an ex-code field"); + let batch = batch + .drop_column(RABIT_EX_CODE_COLUMN)? + .try_with_column(ex_code_field, Arc::new(blocked.clone()))?; + Ok((batch, blocked)) +} + impl DistCalculator for RabitDistCalculator<'_> { #[inline(always)] fn distance(&self, id: u32) -> f32 { @@ -1580,13 +1635,10 @@ impl DistCalculator for RabitDistCalculator<'_> { let ex_scale_factors = self .ex_scale_factors .expect("raw-query multi-bit RQ requires ex scale factors"); - let ex_code_len = rabit_ex_code_bytes(self.dim, ex_bits) - .expect("RabitQ num_bits should be validated"); self.raw_query_multi_bit_exact_distance( id, dist, ex_bits, - ex_code_len, ex_codes, ex_add_factors, ex_scale_factors, @@ -1865,8 +1917,6 @@ impl VectorStore for RabitQuantizationStorage { let code_dim = self.code_dim(); let rotated_qr = self.rotate_query_vector(code_dim, &qr); let dist_table = build_dist_table_direct::(&rotated_qr); - let ex_bits = self.metadata.num_bits - 1; - let ex_dist_table = build_ex_dist_table_direct(&rotated_qr, ex_bits); let query_factor = match self.metadata.query_estimator { RabitQueryEstimator::ResidualQuery => self.residual_query_factor(dist_q_c), RabitQueryEstimator::RawQuery => self.raw_query_factor(dist_q_c, &rotated_qr, None), @@ -1877,12 +1927,21 @@ impl VectorStore for RabitQuantizationStorage { self.raw_query_error_for_gating(dist_q_c, &rotated_qr, None) } }; - let sum_q = rotated_qr.into_iter().sum(); + let sum_q = rotated_qr.iter().copied().sum(); + // The kernels read the rotated query directly; only unaligned dims + // need a zero-padded copy. + let ex_query = if code_dim.is_multiple_of(EX_DOT_BLOCK_DIMS) { + rotated_qr + } else { + let mut padded = vec![0.0; padded_query_len(code_dim)]; + pad_query_into(&rotated_qr, &mut padded); + padded + }; self.distance_calculator_from_parts(RabitDistCalculatorParts { dim: code_dim, dist_table: Cow::Owned(dist_table), - ex_dist_table: Cow::Owned(ex_dist_table), + ex_query: Cow::Owned(ex_query), sum_q, query_factor, query_error, @@ -1921,7 +1980,10 @@ impl VectorStore for RabitQuantizationStorage { return self.distance_calculator_from_parts(RabitDistCalculatorParts { dim: code_dim, dist_table: Cow::Borrowed(&raw_query.dist_table), - ex_dist_table: Cow::Borrowed(&raw_query.ex_dist_table), + ex_query: Cow::Borrowed(kernel_query( + &raw_query.rotated_query, + &raw_query.ex_query, + )), sum_q: raw_query.sum_q, query_factor, query_error, @@ -1931,18 +1993,20 @@ impl VectorStore for RabitQuantizationStorage { let dist_table_len = code_dim * 4; let ex_bits = self.metadata.num_bits - 1; - let ex_dist_table_len = if ex_bits == 0 { + // The kernels read the rotated query in place; a zero-padded copy is + // only needed when the rotated dim is not block-aligned. + let ex_query_table_len = if ex_bits == 0 || code_dim.is_multiple_of(EX_DOT_BLOCK_DIMS) { 0 } else { - code_dim * (1usize << ex_bits) + padded_query_len(code_dim) }; - f32_scratch.resize(code_dim + dist_table_len + ex_dist_table_len, 0.0); + f32_scratch.resize(code_dim + dist_table_len + ex_query_table_len, 0.0); let query_factor; let query_error; let sum_q = { let (rotated_qr, remaining) = f32_scratch.split_at_mut(code_dim); - let (dist_table, ex_dist_table) = remaining.split_at_mut(dist_table_len); + let (dist_table, ex_query) = remaining.split_at_mut(dist_table_len); match residual { Some(QueryResidual::Centroid(residual_centroid)) => { self.rotate_query_vector_into( @@ -1981,17 +2045,20 @@ impl VectorStore for RabitQuantizationStorage { } }; build_dist_table_direct_into::(rotated_qr, dist_table); - build_ex_dist_table_direct_into(rotated_qr, ex_bits, ex_dist_table); + if ex_query_table_len > 0 { + pad_query_into(rotated_qr, ex_query); + } rotated_qr.iter().copied().sum() }; + let ex_query_start = code_dim + dist_table_len; self.distance_calculator_from_parts(RabitDistCalculatorParts { dim: code_dim, - dist_table: Cow::Borrowed(&f32_scratch[code_dim..code_dim + dist_table_len]), - ex_dist_table: Cow::Borrowed( - &f32_scratch - [code_dim + dist_table_len..code_dim + dist_table_len + ex_dist_table_len], - ), + dist_table: Cow::Borrowed(&f32_scratch[code_dim..ex_query_start]), + ex_query: Cow::Borrowed(kernel_query( + &f32_scratch[..code_dim], + &f32_scratch[ex_query_start..ex_query_start + ex_query_table_len], + )), sum_q, query_factor, query_error, @@ -2192,31 +2259,14 @@ impl QuantizerStorage for RabitQuantizationStorage { .column_by_name(ERROR_FACTORS_COLUMN) .map(|factors| factors.as_primitive::().clone()); let ex_bits = rabit_ex_bits(metadata.num_bits)?; + let mut batch = batch; let mut ex_codes = None; let mut ex_add_factors = None; let mut ex_scale_factors = None; if ex_bits != 0 { - let codes = batch - .column_by_name(RABIT_EX_CODE_COLUMN) - .ok_or_else(|| { - Error::invalid_input(format!( - "RabitQ num_bits={} requires {} column", - metadata.num_bits, RABIT_EX_CODE_COLUMN - )) - })? - .as_fixed_size_list() - .clone(); - let expected_ex_code_bytes = rabit_ex_code_bytes(metadata.rotated_dim(), ex_bits)?; - if codes.value_length() as usize != expected_ex_code_bytes { - return Err(Error::invalid_input(format!( - "RabitQ ex-code byte width mismatch: column {} has {} bytes, metadata rotated_dim={} ex_bits={} requires {} bytes", - RABIT_EX_CODE_COLUMN, - codes.value_length(), - metadata.rotated_dim(), - ex_bits, - expected_ex_code_bytes - ))); - } + let (normalized_batch, codes) = + load_blocked_ex_codes(batch, metadata.rotated_dim(), metadata.num_bits)?; + batch = normalized_batch; ex_codes = Some(codes); ex_add_factors = Some( batch @@ -2246,16 +2296,19 @@ impl QuantizerStorage for RabitQuantizationStorage { if batch.column_by_name(EX_ADD_FACTORS_COLUMN).is_some() || batch.column_by_name(EX_SCALE_FACTORS_COLUMN).is_some() || batch.column_by_name(RABIT_EX_CODE_COLUMN).is_some() + || batch.column_by_name(RABIT_BLOCKED_EX_CODE_COLUMN).is_some() { return Err(Error::invalid_input( "RabitQ num_bits=1 raw-query indexes must not contain ex-code columns" .to_string(), )); } - } else if batch.column_by_name(RABIT_EX_CODE_COLUMN).is_some() { + } else if batch.column_by_name(RABIT_EX_CODE_COLUMN).is_some() + || batch.column_by_name(RABIT_BLOCKED_EX_CODE_COLUMN).is_some() + { return Err(Error::invalid_input(format!( - "RabitQ num_bits={} does not support {} column", - metadata.num_bits, RABIT_EX_CODE_COLUMN + "RabitQ num_bits={} does not support ex-code columns", + metadata.num_bits ))); } @@ -2270,7 +2323,8 @@ impl QuantizerStorage for RabitQuantizationStorage { let mut metadata = metadata.clone(); metadata.packed = true; - let packed_ex_codes = maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits); + let packed_ex_codes = + maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits, error_factors.as_ref()); Ok(Self { metadata, @@ -2353,11 +2407,18 @@ impl QuantizerStorage for RabitQuantizationStorage { let error_factors = batch .column_by_name(ERROR_FACTORS_COLUMN) .map(|factors| factors.as_primitive::().clone()); - let ex_codes = batch - .column_by_name(RABIT_EX_CODE_COLUMN) - .map(|codes| codes.as_fixed_size_list().clone()); + let ex_bits = rabit_ex_bits(self.metadata.num_bits)?; + let (batch, ex_codes) = if ex_bits == 0 { + (batch, None) + } else { + // `self.batch` is already normalized at load, so this is a + // zero-copy column lookup. + let (batch, codes) = + load_blocked_ex_codes(batch, self.metadata.rotated_dim(), self.metadata.num_bits)?; + (batch, Some(codes)) + }; let packed_ex_codes = - maybe_pack_ex_codes(ex_codes.as_ref(), rabit_ex_bits(self.metadata.num_bits)?); + maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits, error_factors.as_ref()); let ex_add_factors = batch .column_by_name(EX_ADD_FACTORS_COLUMN) .map(|factors| factors.as_primitive::().clone()); @@ -2695,7 +2756,7 @@ mod tests { assert!(rabit_ex_code_field(128, 1).unwrap().is_none()); let ex_field = rabit_ex_code_field(128, 9).unwrap().unwrap(); - assert_eq!(ex_field.name(), RABIT_EX_CODE_COLUMN); + assert_eq!(ex_field.name(), RABIT_BLOCKED_EX_CODE_COLUMN); let DataType::FixedSizeList(_, ex_code_bytes) = ex_field.data_type() else { panic!("ex-code field should be FixedSizeList"); }; @@ -2898,6 +2959,229 @@ mod tests { assert_eq!(distances, vec![104.0, 22.0]); } + /// Exercise the ex-dot kernel through the storage API for every ex width, + /// including the widths without FastScan support ({1, 3, 5, 6, 7}), and a + /// dim that is not a multiple of the 64-dim kernel group. + /// + /// The dim must be a multiple of 8: the binary distance stage consumes + /// two 4-dim segments per code byte and ignores trailing dims otherwise. + #[test] + fn test_raw_query_multi_bit_distance_matches_reference_for_all_ex_widths() { + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + // 72 exercises the kernels' padded-tail path; 1536 is a production + // embedding dim exercising the full-group path. Both the blocked + // format and the legacy sequential format must produce the same + // distances. + for (code_dim, num_rows) in [(72usize, 33usize), (1536, 33)] { + for num_bits in 2..=9u8 { + for legacy_format in [false, true] { + let ex_bits = num_bits - 1; + let mut rng = SmallRng::seed_from_u64(num_bits as u64); + + let sign_bits = (0..num_rows * code_dim) + .map(|_| rng.random_bool(0.5)) + .collect::>(); + let max_code = ((1u16 << ex_bits) - 1) as u8; + let ex_values = (0..num_rows * code_dim) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + + let code_len = rabit_binary_code_bytes(code_dim); + let mut code_bytes = vec![0u8; num_rows * code_len]; + for (row, bits) in sign_bits.chunks_exact(code_dim).enumerate() { + for (dim, &bit) in bits.iter().enumerate() { + code_bytes[row * code_len + dim / 8] |= (bit as u8) << (dim % 8); + } + } + let (ex_code_column, ex_code_len, ex_code_bytes) = if legacy_format { + let ex_code_len = rabit_ex_code_bytes(code_dim, ex_bits).unwrap(); + let mut ex_code_bytes = vec![0u8; num_rows * ex_code_len]; + for (row, values) in ex_values.chunks_exact(code_dim).enumerate() { + for (dim, &value) in values.iter().enumerate() { + let bit_offset = dim * ex_bits as usize; + let bits = (value as u16) << (bit_offset % 8); + ex_code_bytes[row * ex_code_len + bit_offset / 8] |= bits as u8; + if bits >> 8 != 0 { + ex_code_bytes[row * ex_code_len + bit_offset / 8 + 1] |= + (bits >> 8) as u8; + } + } + } + (RABIT_EX_CODE_COLUMN, ex_code_len, ex_code_bytes) + } else { + let ex_code_len = blocked_ex_code_bytes(code_dim, ex_bits); + let mut ex_code_bytes = vec![0u8; num_rows * ex_code_len]; + for (row, values) in ex_code_bytes + .chunks_exact_mut(ex_code_len) + .zip(ex_values.chunks_exact(code_dim)) + { + crate::vector::bq::ex_dot::pack_blocked_row(values, ex_bits, row); + } + (RABIT_BLOCKED_EX_CODE_COLUMN, ex_code_len, ex_code_bytes) + }; + + let identity = Float32Array::from_iter_values((0..code_dim).flat_map(|row| { + (0..code_dim).map(move |col| if row == col { 1.0 } else { 0.0 }) + })); + let rotate_mat = + FixedSizeListArray::try_new_from_values(identity, code_dim as i32).unwrap(); + let metadata = RabitQuantizationMetadata { + rotate_mat: Some(rotate_mat), + rotate_mat_position: None, + fast_rotation_signs: None, + rotation_type: RQRotationType::Matrix, + code_dim: code_dim as u32, + num_bits, + packed: false, + query_estimator: RabitQueryEstimator::RawQuery, + }; + let codes = FixedSizeListArray::try_new_from_values( + UInt8Array::from(code_bytes), + code_len as i32, + ) + .unwrap(); + let ex_codes = FixedSizeListArray::try_new_from_values( + UInt8Array::from(ex_code_bytes), + ex_code_len as i32, + ) + .unwrap(); + let ex_add_factors = (0..num_rows) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + let ex_scale_factors = (0..num_rows) + .map(|_| rng.random_range(0.1f32..1.0)) + .collect::>(); + let batch = RecordBatch::try_from_iter(vec![ + ( + ROW_ID, + Arc::new(UInt64Array::from_iter_values(0..num_rows as u64)) as ArrayRef, + ), + (RABIT_CODE_COLUMN, Arc::new(codes) as ArrayRef), + ( + ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, + ), + ( + SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, + ), + (ex_code_column, Arc::new(ex_codes) as ArrayRef), + ( + EX_ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(ex_add_factors.clone())) as ArrayRef, + ), + ( + EX_SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(ex_scale_factors.clone())) as ArrayRef, + ), + ]) + .unwrap(); + let storage = RabitQuantizationStorage::try_from_batch( + batch, + &metadata, + DistanceType::L2, + None, + ) + .unwrap(); + + let query = (0..code_dim) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + let sum_q = query.iter().sum::(); + let calc = storage.dist_calculator( + Arc::new(Float32Array::from(query.clone())) as ArrayRef, + 0.0, + ); + + let code_scale = (1u32 << ex_bits) as f32; + let code_bias = -(code_scale - 0.5); + let expected = (0..num_rows) + .map(|row| { + let binary_ip = (0..code_dim) + .map(|dim| { + query[dim] * sign_bits[row * code_dim + dim] as u8 as f32 + }) + .sum::(); + let ex_dist = (0..code_dim) + .map(|dim| query[dim] * ex_values[row * code_dim + dim] as f32) + .sum::(); + let full_dot = code_scale * binary_ip + ex_dist + code_bias * sum_q; + full_dot * ex_scale_factors[row] + ex_add_factors[row] + }) + .collect::>(); + + for (row, &want) in expected.iter().enumerate() { + let got = calc.distance(row as u32); + assert!( + (got - want).abs() <= 1e-3 * want.abs().max(1.0), + "num_bits={num_bits} row={row}: {got} != {want}" + ); + } + + let mut distances = Vec::new(); + let mut u16_scratch = Vec::new(); + let mut u8_scratch = Vec::new(); + let mut u32_scratch = Vec::new(); + calc.distance_all_with_scratch( + 0, + &mut distances, + &mut u16_scratch, + &mut u8_scratch, + &mut u32_scratch, + ); + assert_eq!(distances.len(), num_rows); + // The bulk path quantizes the binary LUT to u8, and that error is + // amplified by 2^ex_bits in the multi-bit estimate, so the value + // assertions need a quantization-aware bound. The FastScan ex + // widths additionally quantize the ex LUT and are covered by + // `test_raw_query_multi_bit_distance_all_uses_fastscan_for_split_ex_codes`. + if !matches!(ex_bits, 2 | 4 | 8) { + // Worst-case |error| of one u8-quantized binary LUT lookup is + // (table range) / 255 / 2, accumulated over one lookup per + // 8-dim pair of segments. + let num_tables = code_dim.div_ceil(4); + let mut table_min = f32::INFINITY; + let mut table_max = f32::NEG_INFINITY; + for segment in query.chunks(4) { + for subset in 0..16usize { + let value = segment + .iter() + .enumerate() + .filter(|(idx, _)| subset & (1 << idx) != 0) + .map(|(_, q)| *q) + .sum::(); + table_min = table_min.min(value); + table_max = table_max.max(value); + } + } + let binary_bound = + code_scale * num_tables as f32 * (table_max - table_min) / 255.0 / 2.0 + * ex_scale_factors.iter().fold(0.0f32, |max, &s| max.max(s)); + for (row, (&got, &want)) in + distances.iter().zip(expected.iter()).enumerate() + { + assert!( + (got - want).abs() <= binary_bound + 1e-3, + "num_bits={num_bits} row={row} (distance_all): {got} != {want} (bound {binary_bound})" + ); + } + // Rows past the SIMD batch use the exact binary path, so the + // final remainder row must match the per-candidate distance. + let remainder_row = num_rows - 1; + let got = distances[remainder_row]; + let want = calc.distance(remainder_row as u32); + assert!( + (got - want).abs() <= 1e-3 * want.abs().max(1.0), + "num_bits={num_bits} remainder row (distance_all): {got} != {want}" + ); + } + } + } + } + } + #[test] fn test_fast_approx_mode_uses_one_bit_scores_for_multi_bit_raw_query() { let code_dim = 8usize; @@ -3061,10 +3345,17 @@ mod tests { assert_eq!(hacc_accum_len, num_rows); } - fn assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits: u8) { - let code_dim = 8usize; + fn assert_raw_query_multi_bit_distance_all_uses_fastscan( + num_bits: u8, + legacy_format: bool, + with_error_factors: bool, + ) { + // Not a multiple of 64, so the padded-tail LUT entries are exercised; + // a multiple of 8 as the binary stage requires. + let code_dim = 72usize; let num_rows = BATCH_SIZE + 1; let ex_bits = rabit_ex_bits(num_bits).unwrap(); + let max_code = ((1u16 << ex_bits) - 1) as u8; let identity = Float32Array::from_iter_values( (0..code_dim) .flat_map(|row| (0..code_dim).map(move |col| if row == col { 1.0 } else { 0.0 })), @@ -3081,16 +3372,42 @@ mod tests { packed: false, query_estimator: RabitQueryEstimator::RawQuery, }; + let code_len = rabit_binary_code_bytes(code_dim); let codes = FixedSizeListArray::try_new_from_values( - UInt8Array::from_iter_values((0..num_rows).map(|idx| (idx * 13) as u8)), - 1, + UInt8Array::from_iter_values((0..num_rows * code_len).map(|idx| (idx * 13) as u8)), + code_len as i32, ) .unwrap(); - let ex_code_len = rabit_ex_code_bytes(code_dim, ex_bits).unwrap(); + let ex_values = (0..num_rows * code_dim) + .map(|idx| ((idx * 37) % (max_code as usize + 1)) as u8) + .collect::>(); + let (ex_code_column, ex_code_len, ex_code_bytes) = if legacy_format { + let ex_code_len = rabit_ex_code_bytes(code_dim, ex_bits).unwrap(); + let mut ex_code_bytes = vec![0u8; num_rows * ex_code_len]; + for (row, values) in ex_values.chunks_exact(code_dim).enumerate() { + for (dim, &value) in values.iter().enumerate() { + let bit_offset = dim * ex_bits as usize; + let bits = (value as u16) << (bit_offset % 8); + ex_code_bytes[row * ex_code_len + bit_offset / 8] |= bits as u8; + if bits >> 8 != 0 { + ex_code_bytes[row * ex_code_len + bit_offset / 8 + 1] |= (bits >> 8) as u8; + } + } + } + (RABIT_EX_CODE_COLUMN, ex_code_len, ex_code_bytes) + } else { + let ex_code_len = blocked_ex_code_bytes(code_dim, ex_bits); + let mut ex_code_bytes = vec![0u8; num_rows * ex_code_len]; + for (row, values) in ex_code_bytes + .chunks_exact_mut(ex_code_len) + .zip(ex_values.chunks_exact(code_dim)) + { + crate::vector::bq::ex_dot::pack_blocked_row(values, ex_bits, row); + } + (RABIT_BLOCKED_EX_CODE_COLUMN, ex_code_len, ex_code_bytes) + }; let ex_codes = FixedSizeListArray::try_new_from_values( - UInt8Array::from_iter_values( - (0..num_rows * ex_code_len).map(|idx| (idx * 37 % 251) as u8), - ), + UInt8Array::from(ex_code_bytes), ex_code_len as i32, ) .unwrap(); @@ -3108,7 +3425,7 @@ mod tests { SCALE_FACTORS_COLUMN, Arc::new(Float32Array::from(vec![1.0; num_rows])) as ArrayRef, ), - (RABIT_EX_CODE_COLUMN, Arc::new(ex_codes) as ArrayRef), + (ex_code_column, Arc::new(ex_codes) as ArrayRef), ( EX_ADD_FACTORS_COLUMN, Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, @@ -3119,12 +3436,30 @@ mod tests { ), ]) .unwrap(); + let batch = if with_error_factors { + batch + .try_with_column( + crate::vector::bq::transform::ERROR_FACTORS_FIELD.clone(), + Arc::new(Float32Array::from(vec![1000.0; num_rows])) as ArrayRef, + ) + .unwrap() + } else { + batch + }; let storage = RabitQuantizationStorage::try_from_batch(batch, &metadata, DistanceType::L2, None) .unwrap(); - assert!(storage.packed_ex_codes.is_some()); + // The FastScan transpose only exists for indexes that can reach the + // bulk bypass path (no error factors); gated indexes fall through to + // the exact per-row kernels in `distance_all`. + assert_eq!(storage.packed_ex_codes.is_some(), !with_error_factors); - let query = Arc::new(Float32Array::from(vec![1.0; code_dim])) as ArrayRef; + // A per-dim varying query so that any dim-mapping error in the + // FastScan LUT shows up as a value mismatch. + let query_values = (0..code_dim) + .map(|dim| (dim % 11) as f32 * 0.3 - 1.5) + .collect::>(); + let query = Arc::new(Float32Array::from(query_values.clone())) as ArrayRef; let calc = storage.dist_calculator(query, 0.0); let mut distances = Vec::new(); let mut u16_scratch = Vec::new(); @@ -3140,15 +3475,57 @@ mod tests { assert_eq!(distances.len(), num_rows); assert_eq!(u16_scratch.len(), BATCH_SIZE); - assert_eq!( - u8_scratch.len(), - ex_fastscan_code_len(code_dim, ex_bits).unwrap() * 2 * SEGMENT_NUM_CODES + let loaded_ex_code_len = storage.ex_codes.as_ref().unwrap().value_length() as usize; + if with_error_factors { + // The gated path never builds the ex LUT; the scratch holds the + // binary LUT only. + assert_eq!(u8_scratch.len(), code_dim * 4); + } else { + assert_eq!(u8_scratch.len(), loaded_ex_code_len * 2 * SEGMENT_NUM_CODES); + } + + // The fastscan estimate differs from the exact path only by the u8 + // quantization of the binary LUT (amplified by 2^ex_bits) and of the + // ex LUT, so bound the comparison by those quantization errors. + let mut table_min = f32::INFINITY; + let mut table_max = f32::NEG_INFINITY; + for segment in query_values.chunks(4) { + for subset in 0..SEGMENT_NUM_CODES { + let value = segment + .iter() + .enumerate() + .filter(|(idx, _)| subset & (1 << idx) != 0) + .map(|(_, q)| *q) + .sum::(); + table_min = table_min.min(value); + table_max = table_max.max(value); + } + } + let code_scale = (1u32 << ex_bits) as f32; + let binary_bound = + code_scale * code_dim.div_ceil(4) as f32 * (table_max - table_min) / 510.0; + let mut padded_query = vec![0.0f32; crate::vector::bq::ex_dot::padded_query_len(code_dim)]; + crate::vector::bq::ex_dot::pad_query_into(&query_values, &mut padded_query); + let mut quantized_table = Vec::new(); + let (ex_qmin, ex_qmax, ex_qcap) = quantize_ex_fastscan_dist_table_into( + ex_bits, + loaded_ex_code_len, + &padded_query, + &mut quantized_table, ); + // Without the FastScan transpose the ex stage is exact, so only the + // binary LUT quantization remains. + let ex_bound = if with_error_factors { + 0.0 + } else { + (loaded_ex_code_len * 2) as f32 * (ex_qmax - ex_qmin) / ex_qcap / 2.0 + }; + let bound = (binary_bound + ex_bound) * 1.5 + 1e-3; for (id, distance) in distances.iter().take(BATCH_SIZE).enumerate() { let exact = calc.distance(id as u32); assert!( - (*distance - exact).abs() < 10.0, - "distance_all fastscan mismatch for id {id}: actual={distance}, exact={exact}" + (*distance - exact).abs() <= bound, + "distance_all fastscan mismatch for id {id} (num_bits={num_bits} legacy={legacy_format}): actual={distance}, exact={exact}, bound={bound}" ); } assert_eq!(distances[BATCH_SIZE], calc.distance(BATCH_SIZE as u32)); @@ -3156,8 +3533,17 @@ mod tests { #[test] fn test_raw_query_multi_bit_distance_all_uses_fastscan_for_split_ex_codes() { - for num_bits in [3, 9] { - assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits); + for num_bits in [3, 5, 9] { + for legacy_format in [false, true] { + assert_raw_query_multi_bit_distance_all_uses_fastscan( + num_bits, + legacy_format, + false, + ); + } + // Gated indexes (with error factors) skip the FastScan artifacts + // and score the bulk path with the exact kernels. + assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits, false, true); } } @@ -3239,7 +3625,6 @@ mod tests { id, binary_ip, ex_bits, - ex_code_len, ex_codes, ex_add_factors, ex_scale_factors, @@ -3457,7 +3842,8 @@ mod tests { ) .unwrap_err(); assert!( - err.to_string().contains("requires __ex_codes column"), + err.to_string() + .contains("requires __blocked_ex_codes column"), "{}", err ); @@ -3501,9 +3887,11 @@ mod tests { .unwrap(); assert!(storage.metadata().packed); + // Legacy batches are normalized to the blocked column at load. let stored_batch = storage.to_batches().unwrap().next().unwrap(); + assert!(stored_batch.column_by_name(RABIT_EX_CODE_COLUMN).is_none()); assert_eq!( - stored_batch[RABIT_EX_CODE_COLUMN] + stored_batch[RABIT_BLOCKED_EX_CODE_COLUMN] .as_fixed_size_list() .value_length(), 64 @@ -3571,11 +3959,19 @@ mod tests { #[test] fn test_remap_preserves_multi_bit_rq_split_columns() { + // num_bits=9 keeps sequential ex codes; num_bits 4/6/8 (ex_bits + // 3/5/7) also exercise the bit-plane repack rebuild in `remap`. + for num_bits in [4, 6, 8, 9u8] { + test_remap_preserves_multi_bit_rq_split_columns_impl(num_bits); + } + } + + fn test_remap_preserves_multi_bit_rq_split_columns_impl(num_bits: u8) { let original_codes = make_test_codes(50, 64); let code_dim = original_codes.value_length() as usize * 8; - let ex_codes = make_test_ex_codes(original_codes.len(), code_dim, 9); + let ex_codes = make_test_ex_codes(original_codes.len(), code_dim, num_bits); let mut metadata = make_test_metadata(code_dim); - metadata.num_bits = 9; + metadata.num_bits = num_bits; let storage = RabitQuantizationStorage::try_from_batch( make_test_batch_with_ex(original_codes.clone(), ex_codes), &metadata, @@ -3599,11 +3995,14 @@ mod tests { ); assert_eq!(remapped_row_ids, expected_row_ids.values()); + // Legacy batches are normalized to the blocked format at load, so the + // remapped batch carries the blocked column. + let ex_code_len = blocked_ex_code_bytes(code_dim, rabit_ex_bits(num_bits).unwrap()); assert_eq!( - remapped_batch[RABIT_EX_CODE_COLUMN] + remapped_batch[RABIT_BLOCKED_EX_CODE_COLUMN] .as_fixed_size_list() .value_length(), - 64 + ex_code_len as i32 ); assert_eq!( &remapped_batch[EX_ADD_FACTORS_COLUMN] @@ -3623,5 +4022,20 @@ mod tests { .values()[..5], &[0.25, 1.25, 2.25, 4.25, 5.25] ); + + // The remapped storage must hold the same kernel-layout ex codes as a + // storage freshly loaded from the remapped batch. + let reloaded = RabitQuantizationStorage::try_from_batch( + remapped_batch, + &remapped.metadata, + DistanceType::L2, + None, + ) + .unwrap(); + assert_eq!(remapped.ex_codes, reloaded.ex_codes); + assert_eq!( + remapped.ex_codes.as_ref().unwrap().value_length() as usize, + blocked_ex_code_bytes(code_dim, rabit_ex_bits(num_bits).unwrap()) + ); } } diff --git a/rust/lance-index/src/vector/bq/transform.rs b/rust/lance-index/src/vector/bq/transform.rs index c2fc0608102..c87695e14cd 100644 --- a/rust/lance-index/src/vector/bq/transform.rs +++ b/rust/lance-index/src/vector/bq/transform.rs @@ -17,7 +17,9 @@ use tracing::instrument; use crate::vector::bq::builder::RabitQuantizer; use crate::vector::bq::rabit_ex_bits; -use crate::vector::bq::storage::{RABIT_CODE_COLUMN, RABIT_EX_CODE_COLUMN, RabitQueryEstimator}; +use crate::vector::bq::storage::{ + RABIT_BLOCKED_EX_CODE_COLUMN, RABIT_CODE_COLUMN, RabitQueryEstimator, +}; use crate::vector::quantizer::Quantization; use crate::vector::transform::Transformer; use crate::vector::{CENTROID_DIST_COLUMN, PART_ID_COLUMN}; @@ -281,7 +283,7 @@ impl Transformer for RQTransformer { #[instrument(name = "RQTransformer::transform", level = "debug", skip_all)] fn transform(&self, batch: &RecordBatch) -> Result { let has_split_codes = self.rq.num_bits() == 1 - || (batch.column_by_name(RABIT_EX_CODE_COLUMN).is_some() + || (batch.column_by_name(RABIT_BLOCKED_EX_CODE_COLUMN).is_some() && batch.column_by_name(EX_ADD_FACTORS_COLUMN).is_some() && batch.column_by_name(EX_SCALE_FACTORS_COLUMN).is_some()); if batch.column_by_name(RABIT_CODE_COLUMN).is_some() && has_split_codes { @@ -494,7 +496,8 @@ mod tests { use crate::vector::bq::RQRotationType; use crate::vector::bq::builder::RabitQuantizer; - use crate::vector::bq::storage::RABIT_EX_CODE_COLUMN; + use crate::vector::bq::ex_dot::blocked_ex_code_bytes; + use crate::vector::bq::storage::RABIT_BLOCKED_EX_CODE_COLUMN; use crate::vector::transform::Transformer; use crate::vector::{CENTROID_DIST_COLUMN, PART_ID_COLUMN}; @@ -535,15 +538,19 @@ mod tests { .unwrap(); let transformed = transformer.transform(&batch).unwrap(); - assert!(transformed.column_by_name(RABIT_EX_CODE_COLUMN).is_some()); + assert!( + transformed + .column_by_name(RABIT_BLOCKED_EX_CODE_COLUMN) + .is_some() + ); assert_eq!( - transformed[RABIT_EX_CODE_COLUMN] + transformed[RABIT_BLOCKED_EX_CODE_COLUMN] .as_fixed_size_list() .value_length(), - 3 + blocked_ex_code_bytes(8, 3) as i32 ); assert!( - transformed[RABIT_EX_CODE_COLUMN] + transformed[RABIT_BLOCKED_EX_CODE_COLUMN] .as_fixed_size_list() .values() .as_primitive::() diff --git a/rust/lance-index/src/vector/distributed/index_merger.rs b/rust/lance-index/src/vector/distributed/index_merger.rs index 5f59985673e..70371ad4794 100755 --- a/rust/lance-index/src/vector/distributed/index_merger.rs +++ b/rust/lance-index/src/vector/distributed/index_merger.rs @@ -1440,6 +1440,25 @@ pub async fn merge_partial_vector_auxiliary_files( ))); } + // Shards written by older lance versions carry sequential ex + // codes; normalize every batch to the blocked layout before + // concatenation so mixed-version shards merge correctly + // (concat_batches combines columns by position and would + // otherwise mix the two layouts silently). + let batches = match rq_meta.as_ref() { + Some(meta) if meta.num_bits > 1 => batches + .into_iter() + .map(|batch| { + crate::vector::bq::storage::load_blocked_ex_codes( + batch, + meta.rotated_dim(), + meta.num_bits, + ) + .map(|(batch, _)| batch) + }) + .collect::>>()?, + _ => batches, + }; let schema = batches[0].schema(); let partition_batch = concat_batches(&schema, batches.iter())?; if let Some(w) = v2w_opt.as_mut() { @@ -1527,7 +1546,7 @@ mod tests { use prost::Message; use crate::vector::bq::RQRotationType; - use crate::vector::bq::storage::{RABIT_EX_CODE_COLUMN, RabitQueryEstimator}; + use crate::vector::bq::storage::{RABIT_BLOCKED_EX_CODE_COLUMN, RabitQueryEstimator}; use crate::vector::bq::transform::{EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN}; lance_testing::define_stage_event_progress!( RecordingProgress, @@ -2529,11 +2548,14 @@ mod tests { let batch = batch.unwrap(); if !checked_split_columns { let schema = batch.schema(); - let ex_code_field = schema.field_with_name(RABIT_EX_CODE_COLUMN).unwrap(); + let ex_code_field = schema + .field_with_name(RABIT_BLOCKED_EX_CODE_COLUMN) + .unwrap(); let DataType::FixedSizeList(_, ex_code_bytes) = ex_code_field.data_type() else { panic!("RQ ex-code field should be FixedSizeList"); }; - assert_eq!(*ex_code_bytes, 6); + // code_dim=16 padded to one 64-dim block at ex_bits=3. + assert_eq!(*ex_code_bytes, 24); assert!(schema.field_with_name(ERROR_FACTORS_FIELD.name()).is_ok()); assert!(schema.field_with_name(EX_ADD_FACTORS_COLUMN).is_ok()); assert!(schema.field_with_name(EX_SCALE_FACTORS_COLUMN).is_ok()); diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index b036e187b77..4cb61dc9e49 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -249,7 +249,10 @@ pub struct RabitRawQueryContext { pub ex_bits: u8, pub rotated_query: Vec, pub dist_table: Vec, - pub ex_dist_table: Vec, + /// The rotated query zero-padded to a 64-dim multiple for the ex-dot + /// kernels; empty when `code_dim` is already aligned (the kernels then + /// read `rotated_query` directly). + pub ex_query: Vec, pub sum_q: f32, } diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 4ea076ed420..d169cf08190 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -38,8 +38,9 @@ use lance_index::frag_reuse::FragReuseIndex; use lance_index::metrics::{LocalMetricsCollector, MetricsCollector, NoOpMetricsCollector}; use lance_index::vector::VectorIndexCacheEntry; use lance_index::vector::bq::builder::RabitQuantizer; +use lance_index::vector::bq::ex_dot::{blocked_ex_code_bytes, padded_query_len}; +use lance_index::vector::bq::rabit_ex_bits; use lance_index::vector::bq::storage::{RabitQueryEstimator, SEGMENT_NUM_CODES}; -use lance_index::vector::bq::{rabit_ex_bits, rabit_ex_code_bytes}; use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::graph::OrderedNode; use lance_index::vector::hnsw::HNSW; @@ -152,16 +153,18 @@ fn rotated_partition_centroid_slice( cache.rotated_centroids.get(start..end) } -fn rabit_ex_dist_table_len(dim: usize, num_bits: u8) -> usize { - rabit_ex_bits(num_bits) - .map(|ex_bits| { - if ex_bits == 0 { - 0 - } else { - dim * (1usize << usize::from(ex_bits)) - } - }) - .unwrap_or(dim * 256) +/// `f32` scratch needed for the ex-bit query state: a zero-padded query copy +/// when the rotated dim is not a multiple of the 64-dim kernel block (the +/// FastScan ex LUT is built directly from the query, with no f32 table). +fn rabit_ex_scratch_len(dim: usize, num_bits: u8) -> usize { + let multi_bit = rabit_ex_bits(num_bits) + .map(|ex_bits| ex_bits > 0) + .unwrap_or(true); + if !multi_bit || dim.is_multiple_of(64) { + 0 + } else { + padded_query_len(dim) + } } fn rabit_u8_scratch_len(dim: usize, num_bits: u8) -> usize { @@ -169,7 +172,7 @@ fn rabit_u8_scratch_len(dim: usize, num_bits: u8) -> usize { let ex_dist_table_len = rabit_ex_bits(num_bits) .ok() .and_then(|ex_bits| match ex_bits { - 2 | 4 | 8 => rabit_ex_code_bytes(dim, ex_bits).ok(), + 2 | 4 | 8 => Some(blocked_ex_code_bytes(dim, ex_bits)), _ => None, }) .map(|ex_code_len| ex_code_len * 2 * SEGMENT_NUM_CODES) @@ -183,12 +186,12 @@ fn rabit_query_scratch_capacity( num_bits: u8, ) -> QueryScratchCapacity { let dist_table_len = dim * 4; - let ex_dist_table_len = rabit_ex_dist_table_len(dim, num_bits); + let ex_scratch_len = rabit_ex_scratch_len(dim, num_bits); let u8_scratch_len = rabit_u8_scratch_len(dim, num_bits); QueryScratchCapacity::new( max_partition_len, - dim + dist_table_len + ex_dist_table_len, + dim + dist_table_len + ex_scratch_len, max_partition_len.max(dist_table_len), u8_scratch_len, ) @@ -1908,7 +1911,8 @@ mod tests { use lance_arrow::FixedSizeListArrayExt; use lance_index::vector::bq::{ RQBuildParams, RQRotationType, - storage::{RABIT_EX_CODE_COLUMN, RabitQuantizationMetadata, RabitQueryEstimator}, + ex_dot::{blocked_ex_code_bytes, padded_query_len}, + storage::{RABIT_BLOCKED_EX_CODE_COLUMN, RabitQuantizationMetadata, RabitQueryEstimator}, transform::{EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN}, }; use lance_index::vector::storage::VectorStore; @@ -1983,14 +1987,17 @@ mod tests { } #[test] - fn test_rabit_ex_dist_table_len_uses_num_bits() { + fn test_rabit_ex_scratch_len_uses_num_bits() { + // Block-aligned dims read the rotated query in place. let dim = 960; + for num_bits in [1, 3, 5, 7, 9] { + assert_eq!(super::rabit_ex_scratch_len(dim, num_bits), 0); + } - assert_eq!(super::rabit_ex_dist_table_len(dim, 1), 0); - assert_eq!(super::rabit_ex_dist_table_len(dim, 3), dim * 4); - assert_eq!(super::rabit_ex_dist_table_len(dim, 5), dim * 16); - assert_eq!(super::rabit_ex_dist_table_len(dim, 7), dim * 64); - assert_eq!(super::rabit_ex_dist_table_len(dim, 9), dim * 256); + // Unaligned multi-bit queries add one padded query copy. + let dim = 968; + assert_eq!(super::rabit_ex_scratch_len(dim, 1), 0); + assert_eq!(super::rabit_ex_scratch_len(dim, 7), padded_query_len(dim)); } #[test] @@ -2012,7 +2019,7 @@ mod tests { let capacity = super::rabit_query_scratch_capacity(dim, max_partition_len, 5); assert_eq!(capacity.distances, max_partition_len); - assert_eq!(capacity.query_f32, dim + dim * 4 + dim * 16); + assert_eq!(capacity.query_f32, dim + dim * 4); assert_eq!(capacity.u16, max_partition_len); assert_eq!(capacity.u8, dim * 16); assert_eq!(capacity.u32, 0); @@ -4403,18 +4410,24 @@ mod tests { } #[rstest] - #[case::l2(DistanceType::L2)] - #[case::cosine(DistanceType::Cosine)] + #[case::l2(DistanceType::L2, 9)] + #[case::cosine(DistanceType::Cosine, 9)] + // ex_bits=3 and ex_bits=5 have no FastScan support and use the bit-plane + // repack, so these searches go through the exact ex-dot rerank kernels + // end to end. + #[case::l2_plane_repack_3bit(DistanceType::L2, 4)] + #[case::l2_plane_repack_5bit(DistanceType::L2, 6)] #[tokio::test] async fn test_build_ivf_rq_multi_bit_persists_split_codes_and_searches( #[case] distance_type: DistanceType, + #[case] num_bits: u8, ) { let test_dir = TempStrDir::default(); let test_uri = test_dir.as_str(); let (mut dataset, vectors) = generate_test_dataset::(test_uri, 0.0..1.0).await; let ivf_params = IvfBuildParams::new(4); - let rq_params = RQBuildParams::with_rotation_type(9, RQRotationType::Fast); + let rq_params = RQBuildParams::with_rotation_type(num_bits, RQRotationType::Fast); let params = VectorIndexParams::with_ivf_rq_params(distance_type, ivf_params, rq_params); dataset .create_index(&["vector"], IndexType::Vector, None, ¶ms, true) @@ -4427,16 +4440,18 @@ mod tests { let scheduler = ScanScheduler::new(obj_store, SchedulerConfig::default_for_testing()); let index_uuid = indices[0].uuid.to_string(); let rq_meta = get_rq_metadata(&dataset, scheduler.clone(), &index_uuid).await; - assert_eq!(rq_meta.num_bits, 9); + assert_eq!(rq_meta.num_bits, num_bits); assert_eq!(rq_meta.query_estimator, RabitQueryEstimator::RawQuery); let reader = open_rq_aux_reader(&dataset, scheduler, &index_uuid).await; let schema = reader.schema(); - let ex_field = schema.field(RABIT_EX_CODE_COLUMN).unwrap(); + let ex_field = schema.field(RABIT_BLOCKED_EX_CODE_COLUMN).unwrap(); let DataType::FixedSizeList(_, ex_code_bytes) = ex_field.data_type() else { panic!("RQ ex-code field should be FixedSizeList"); }; - assert_eq!(ex_code_bytes, 32); + let expected_ex_code_bytes = + blocked_ex_code_bytes(rq_meta.rotated_dim(), num_bits - 1) as i32; + assert_eq!(ex_code_bytes, expected_ex_code_bytes); assert!(schema.field(EX_ADD_FACTORS_COLUMN).is_some()); assert!(schema.field(EX_SCALE_FACTORS_COLUMN).is_some());