From 55ebb4e23f1b82e6a908de3170fe0597cdd27fdf Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 10 Jun 2026 22:02:16 +0800 Subject: [PATCH 1/3] perf(vector)!: add dedicated SIMD kernels for RaBitQ ex-code reranking Replace the per-candidate table-gather ex distance (dim * 2^ex_bits f32 LUT) with direct f32-query x packed-code FMA kernels for all ex_bits 1..=8, with scalar/AVX2/AVX-512/NEON variants. ex_bits {1,2,4,8} consume the sequential packed rows as stored; {3,5,6,7} are repacked once at load into bit-plane layouts. On-disk index format is unchanged. Co-Authored-By: Claude Fable 5 --- rust/lance-index/benches/rq.rs | 108 ++- rust/lance-index/src/vector/bq.rs | 1 + rust/lance-index/src/vector/bq/ex_dot.rs | 1047 +++++++++++++++++++++ rust/lance-index/src/vector/bq/storage.rs | 442 +++++++-- rust/lance-index/src/vector/storage.rs | 5 + rust/lance/src/index/vector/ivf/v2.rs | 60 +- 6 files changed, 1556 insertions(+), 107 deletions(-) create mode 100644 rust/lance-index/src/vector/bq/ex_dot.rs diff --git a/rust/lance-index/benches/rq.rs b/rust/lance-index/benches/rq.rs index 4a7364d1313..d51db3c98f0 100644 --- a/rust/lance-index/benches/rq.rs +++ b/rust/lance-index/benches/rq.rs @@ -17,11 +17,17 @@ use lance_datagen::array::rand_type; use lance_datagen::{BatchGeneratorBuilder, RowCount}; use lance_index::vector::bq::RQRotationType; use lance_index::vector::bq::builder::RabitQuantizer; +use lance_index::vector::bq::ex_dot::{ + build_ex_query, ex_dot_code_bytes, ex_dot_kernel, needs_plane_repack, packed_ex_code_value, + plane_pack_row, +}; use lance_index::vector::bq::storage::*; use lance_index::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN}; use lance_index::vector::quantizer::{Quantization, QuantizerStorage}; use lance_index::vector::storage::{DistCalculator, VectorStore}; use lance_linalg::distance::DistanceType; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; const DIM: usize = 128; const TOTAL: usize = 16 * 1000; @@ -119,16 
+125,104 @@ fn compute_distances(c: &mut Criterion) { } } -#[cfg(target_os = "linux")] -criterion_group!( - name=benches; - config = Criterion::default().measurement_time(Duration::from_secs(10)); - targets = construct_dist_table, compute_distances); +/// The table-gather ex distance used before the dedicated ex-dot kernels, +/// kept here as the baseline: per dim, extract the packed code and gather +/// `query[d] * code` from a `dim * 2^ex_bits` table. +fn gather_ex_distance(row_codes: &[u8], dim: usize, ex_bits: u8, ex_dist_table: &[f32]) -> f32 { + let entries_per_dim = 1usize << ex_bits; + (0..dim) + .map(|dim_idx| { + let code = packed_ex_code_value(row_codes, dim_idx, ex_bits) as usize; + ex_dist_table[dim_idx * entries_per_dim + code] + }) + .sum() +} + +fn ex_dot_kernels(c: &mut Criterion) { + for ex_dim in [1536usize, 2048] { + ex_dot_kernels_for_dim(c, ex_dim); + } +} + +fn ex_dot_kernels_for_dim(c: &mut Criterion, ex_dim: usize) { + const NUM_ROWS: usize = 1024; + + let mut rng = SmallRng::seed_from_u64(42); + let query = (0..ex_dim) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + + for ex_bits in 1..=8u8 { + let max_code = ((1u16 << ex_bits) - 1) as u8; + let seq_code_len = (ex_dim * ex_bits as usize).div_ceil(8); + let mut seq_codes = vec![0u8; NUM_ROWS * seq_code_len]; + for row in seq_codes.chunks_exact_mut(seq_code_len) { + for dim in 0..ex_dim { + let value = rng.random_range(0..=max_code); + let bit_offset = dim * ex_bits as usize; + let bits = (value as u16) << (bit_offset % 8); + row[bit_offset / 8] |= bits as u8; + if bits >> 8 != 0 { + row[bit_offset / 8 + 1] |= (bits >> 8) as u8; + } + } + } + + let kernel_code_len = ex_dot_code_bytes(ex_dim, ex_bits); + let kernel_codes = if needs_plane_repack(ex_bits) { + let mut out = vec![0u8; NUM_ROWS * kernel_code_len]; + for (seq_row, plane_row) in seq_codes + .chunks_exact(seq_code_len) + .zip(out.chunks_exact_mut(kernel_code_len)) + { + plane_pack_row(seq_row, ex_dim, ex_bits, 
plane_row); + } + out + } else { + seq_codes.clone() + }; + + let ex_query = build_ex_query(&query, ex_bits); + let kernel = ex_dot_kernel(ex_bits); + c.bench_function( + format!("RQ ex_dot kernel: ex_bits={ex_bits}, DIM={ex_dim}, rows={NUM_ROWS}").as_str(), + |b| { + b.iter(|| { + let mut sum = 0.0f32; + for row in kernel_codes.chunks_exact(kernel_code_len) { + sum += kernel(&ex_query, row); + } + black_box(sum) + }) + }, + ); + + let entries_per_dim = 1usize << ex_bits; + let mut ex_dist_table = vec![0.0f32; ex_dim * entries_per_dim]; + for (dim, table) in ex_dist_table.chunks_exact_mut(entries_per_dim).enumerate() { + for (code, value) in table.iter_mut().enumerate() { + *value = query[dim] * code as f32; + } + } + c.bench_function( + format!("RQ ex_dot table-gather: ex_bits={ex_bits}, DIM={ex_dim}, rows={NUM_ROWS}") + .as_str(), + |b| { + b.iter(|| { + let mut sum = 0.0f32; + for row in seq_codes.chunks_exact(seq_code_len) { + sum += gather_ex_distance(row, ex_dim, ex_bits, &ex_dist_table); + } + black_box(sum) + }) + }, + ); + } +} -#[cfg(not(target_os = "linux"))] criterion_group!( name=benches; config = Criterion::default().measurement_time(Duration::from_secs(10)); - targets = construct_dist_table, compute_distances); + targets = construct_dist_table, compute_distances, ex_dot_kernels); criterion_main!(benches); diff --git a/rust/lance-index/src/vector/bq.rs b/rust/lance-index/src/vector/bq.rs index 0fdd918edab..71c4eed7fd8 100644 --- a/rust/lance-index/src/vector/bq.rs +++ b/rust/lance-index/src/vector/bq.rs @@ -18,6 +18,7 @@ use crate::vector::bq::storage::RabitQuantizationMetadata; use crate::vector::quantizer::QuantizerBuildParams; pub mod builder; +pub mod ex_dot; pub mod rotation; pub mod storage; pub mod transform; diff --git a/rust/lance-index/src/vector/bq/ex_dot.rs b/rust/lance-index/src/vector/bq/ex_dot.rs new file mode 100644 index 00000000000..afe54e6ee37 --- /dev/null +++ b/rust/lance-index/src/vector/bq/ex_dot.rs @@ -0,0 +1,1047 @@ +// 
SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Inner-product kernels between an `f32` query and bit-packed RaBitQ ex-codes. +//! +//! Multi-bit RaBitQ reranking reduces to `sum_d query[d] * ex_code[d]`, where +//! `ex_code[d]` is an unsigned `ex_bits`-wide integer. Materializing a +//! `dim * 2^ex_bits` lookup table and gathering one entry per dimension is +//! cache-hostile (the table is 1MiB for `ex_bits=8`, `dim=1024`); these kernels +//! instead unpack the codes with shifts and masks and FMA them against the +//! query directly, following the kernel design of the RaBitQ reference library +//! (<https://github.com/VectorDB-NTU/RaBitQ-Library>, Apache-2.0). +//! +//! Two code layouts are consumed: +//! +//! - `ex_bits` ∈ {1, 2, 4, 8}: the sequential LSB-first layout written by the +//! index builder is already byte-aligned, so rows are used as stored. +//! - `ex_bits` ∈ {3, 5, 6, 7}: codes straddle byte boundaries in the sequential +//! layout, so rows are repacked once at load time ([`plane_pack_row`]) into +//! bit-planes that unpack with byte-wise shifts: +//! +//! ```text +//! per 64-dim group (T = ex_bits - 1, the top bit): +//! 3 bits: [16B 2-bit plane][8B top-bit plane] +//! 5 bits: [32B 4-bit plane][8B top-bit plane] +//! 6 bits: [48B: 3 blocks of "6 low bits | 2 stolen bits"] +//! 7 bits: [48B: same as 6 bits][8B top-bit plane] +//! ``` +//! +//! Kernels unpack each group into runs of 16 code bytes whose dimension order +//! differs from the natural order, so the query is permuted once per search +//! with [`build_ex_query_into`] and the kernels then read both sides +//! sequentially. The permuted query is zero-padded to a multiple of 64 so that +//! padded lanes contribute nothing. + +use std::sync::LazyLock; + +/// Dimensions are processed in groups; the permuted query is padded to this +/// multiple so every kernel sees whole groups.
pub const EX_DOT_GROUP_DIMS: usize = 64;

/// `f32` length of the kernel-order query built by [`build_ex_query_into`]:
/// `dim` rounded up to a multiple of [`EX_DOT_GROUP_DIMS`].
pub fn ex_query_len(dim: usize) -> usize {
    dim.next_multiple_of(EX_DOT_GROUP_DIMS)
}

/// Whether the sequential ex-code layout must be repacked into bit-planes for
/// the dot kernels (`ex_bits` 3, 5, 6, 7). The remaining widths are consumed
/// in the sequential layout as stored.
pub fn needs_plane_repack(ex_bits: u8) -> bool {
    matches!(ex_bits, 3 | 5 | 6 | 7)
}

/// Bytes per row of the code layout consumed by the dot kernels.
///
/// Bit-plane rows always cover whole 64-dim groups (`dim` is rounded up);
/// sequential rows are `ceil(dim * ex_bits / 8)` bytes.
pub fn ex_dot_code_bytes(dim: usize, ex_bits: u8) -> usize {
    debug_assert!((1..=8).contains(&ex_bits));
    if needs_plane_repack(ex_bits) {
        ex_query_len(dim) * ex_bits as usize / 8
    } else {
        (dim * ex_bits as usize).div_ceil(u8::BITS as usize)
    }
}

/// Dimensions per unpacking group for the given code width: 16 for widths
/// 1, 4 and 8, [`EX_DOT_GROUP_DIMS`] for every other width.
fn group_dims(ex_bits: u8) -> usize {
    match ex_bits {
        1 | 4 | 8 => 16,
        _ => EX_DOT_GROUP_DIMS,
    }
}

/// Code bytes occupied by one unpacking group of [`group_dims`] dimensions.
fn group_bytes(ex_bits: u8) -> usize {
    group_dims(ex_bits) * ex_bits as usize / 8
}

/// Extract the `ex_bits`-wide code of `dim_idx` from a sequentially bit-packed
/// row (LSB-first, codes may straddle byte boundaries).
///
/// # Panics
///
/// Panics if the byte holding the first bit of the code is out of bounds. A
/// missing following byte is read as zero.
#[inline]
pub fn packed_ex_code_value(row_codes: &[u8], dim_idx: usize, ex_bits: u8) -> u8 {
    debug_assert!(ex_bits > 0);
    let bit_offset = dim_idx * ex_bits as usize;
    let byte_idx = bit_offset / u8::BITS as usize;
    let bit_shift = bit_offset % u8::BITS as usize;
    // Two bytes are enough: at most 7 bits of shift plus an 8-bit code.
    let bits = row_codes[byte_idx] as u16
        | row_codes
            .get(byte_idx + 1)
            .map(|byte| (*byte as u16) << u8::BITS)
            .unwrap_or_default();
    let mask = (1u16 << ex_bits) - 1;
    ((bits >> bit_shift) & mask) as u8
}

/// Kernel-order position of `dim` within its group (see [`build_ex_query_into`]).
fn kernel_position(dim: usize, ex_bits: u8) -> usize {
    match ex_bits {
        1 | 8 => dim,
        // 16-dim groups unpack the low nibbles (even dims) before the high
        // nibbles (odd dims).
        4 => {
            let group = dim / 16;
            let r = dim % 16;
            group * 16 + r / 2 + (r % 2) * 8
        }
        // 64-dim groups unpack four 16-byte runs holding dims k, k+4, k+8, ...
        2 | 3 | 5 | 6 | 7 => {
            let group = dim / 64;
            let r = dim % 64;
            group * 64 + (r % 4) * 16 + r / 4
        }
        _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"),
    }
}

/// Permute the rotated query into the order the dot kernels unpack codes in,
/// zero-padding to a multiple of [`EX_DOT_GROUP_DIMS`].
///
/// `out` must hold [`ex_query_len`]`(rotated_query.len())` floats; it is
/// fully overwritten.
pub fn build_ex_query_into(rotated_query: &[f32], ex_bits: u8, out: &mut [f32]) {
    debug_assert_eq!(out.len(), ex_query_len(rotated_query.len()));
    out.fill(0.0);
    for (dim, &value) in rotated_query.iter().enumerate() {
        out[kernel_position(dim, ex_bits)] = value;
    }
}

/// Allocating convenience wrapper around [`build_ex_query_into`]: returns the
/// permuted, zero-padded query as a new vector.
pub fn build_ex_query(rotated_query: &[f32], ex_bits: u8) -> Vec<f32> {
    let mut out = vec![0.0; ex_query_len(rotated_query.len())];
    build_ex_query_into(rotated_query, ex_bits, &mut out);
    out
}

/// Pack the top bit of each of 64 codes into a `u64` so kernels can position
/// it with two shifts per 16-code run: the top bit of dim `4j + k` is stored
/// at bit `8 * (j % 8) + 2k + j / 8`.
fn pack_top_plane(group_values: &[u8; 64], top_bit: u8) -> u64 {
    let mut plane = 0u64;
    for j in 0..16 {
        for k in 0..4 {
            let bit = (group_values[4 * j + k] >> top_bit) & 1;
            plane |= (bit as u64) << (8 * (j % 8) + 2 * k + j / 8);
        }
    }
    plane
}

/// Shift `plane` so that its bit `8j + from_bit` lands at bit `8j + to_bit`.
///
/// Bits of neighbouring bytes are dragged along; callers mask per byte.
#[inline(always)]
fn shift_plane(plane: u64, from_bit: usize, to_bit: usize) -> u64 {
    if from_bit >= to_bit {
        plane >> (from_bit - to_bit)
    } else {
        plane << (to_bit - from_bit)
    }
}

/// Pack one group of 64 code values (natural dim order) into the bit-plane
/// layout described in the module docs.
///
/// `out` must hold at least [`group_bytes`]`(ex_bits)` bytes.
fn plane_pack_group(ex_bits: u8, group_values: &[u8; 64], out: &mut [u8]) {
    let v = group_values;
    match ex_bits {
        3 => {
            // Low two bits of four consecutive dims share one byte.
            for b in 0..16 {
                out[b] = (v[4 * b] & 0b11)
                    | ((v[4 * b + 1] & 0b11) << 2)
                    | ((v[4 * b + 2] & 0b11) << 4)
                    | ((v[4 * b + 3] & 0b11) << 6);
            }
            out[16..24].copy_from_slice(&pack_top_plane(v, 2).to_le_bytes());
        }
        5 => {
            // Low nibbles: dims 4b, 4b+1 in block 0 and 4b+2, 4b+3 in block 1.
            for b in 0..16 {
                out[b] = (v[4 * b] & 0x0f) | ((v[4 * b + 1] & 0x0f) << 4);
                out[16 + b] = (v[4 * b + 2] & 0x0f) | ((v[4 * b + 3] & 0x0f) << 4);
            }
            out[32..40].copy_from_slice(&pack_top_plane(v, 4).to_le_bytes());
        }
        6 | 7 => {
            // Dims k, k+4, ... (k < 3) keep their 6 low bits in block k; the
            // fourth dim of each quad is split into three 2-bit pieces stored
            // in the blocks' top bits.
            for k in 0..3 {
                for b in 0..16 {
                    out[16 * k + b] =
                        (v[4 * b + k] & 0x3f) | (((v[4 * b + 3] >> (2 * k)) & 0b11) << 6);
                }
            }
            if ex_bits == 7 {
                out[48..56].copy_from_slice(&pack_top_plane(v, 6).to_le_bytes());
            }
        }
        _ => unreachable!("plane packing is only used for ex_bits 3, 5, 6, 7"),
    }
}

/// Repack one sequentially bit-packed row into the kernel bit-plane layout.
/// `out` must have [`ex_dot_code_bytes`] bytes.
///
/// Dimensions past `dim` in the last group are packed as zero codes.
pub fn plane_pack_row(seq_row: &[u8], dim: usize, ex_bits: u8, out: &mut [u8]) {
    debug_assert!(needs_plane_repack(ex_bits));
    debug_assert_eq!(out.len(), ex_dot_code_bytes(dim, ex_bits));
    let bytes_per_group = group_bytes(ex_bits);
    let mut group_values = [0u8; 64];
    for (group, out) in out.chunks_exact_mut(bytes_per_group).enumerate() {
        group_values.fill(0);
        let base = group * 64;
        let count = 64.min(dim.saturating_sub(base));
        for (i, value) in group_values[..count].iter_mut().enumerate() {
            *value = packed_ex_code_value(seq_row, base + i, ex_bits);
        }
        plane_pack_group(ex_bits, &group_values, out);
    }
}

/// Unpack one code group into per-dim values in kernel order (the order
/// [`build_ex_query_into`] permutes the query into). Reference implementation
/// for the SIMD unpackers; also the scalar fallback.
///
/// Only the first [`group_dims`]`(ex_bits)` entries of `out` are written.
fn unpack_group(ex_bits: u8, group_codes: &[u8], out: &mut [u8; 64]) {
    debug_assert_eq!(group_codes.len(), group_bytes(ex_bits));
    match ex_bits {
        1 => {
            for (i, value) in out[..16].iter_mut().enumerate() {
                *value = (group_codes[i / 8] >> (i % 8)) & 1;
            }
        }
        2 => {
            for k in 0..4 {
                for b in 0..16 {
                    out[16 * k + b] = (group_codes[b] >> (2 * k)) & 0b11;
                }
            }
        }
        3 => {
            let plane = u64::from_le_bytes(group_codes[16..24].try_into().unwrap());
            for k in 0..4 {
                for b in 0..16 {
                    let top = (plane >> (8 * (b % 8) + 2 * k + b / 8)) & 1;
                    out[16 * k + b] = ((group_codes[b] >> (2 * k)) & 0b11) | ((top as u8) << 2);
                }
            }
        }
        4 => {
            for b in 0..8 {
                out[b] = group_codes[b] & 0x0f;
                out[8 + b] = group_codes[b] >> 4;
            }
        }
        5 => {
            let plane = u64::from_le_bytes(group_codes[32..40].try_into().unwrap());
            for k in 0..4 {
                for b in 0..16 {
                    let nibble = (group_codes[16 * (k / 2) + b] >> (4 * (k % 2))) & 0x0f;
                    let top = (plane >> (8 * (b % 8) + 2 * k + b / 8)) & 1;
                    out[16 * k + b] = nibble | ((top as u8) << 4);
                }
            }
        }
        6 | 7 => {
            for k in 0..3 {
                for b in 0..16 {
                    out[16 * k + b] = group_codes[16 * k + b] & 0x3f;
                }
            }
            // Reassemble the fourth dim of each quad from the stolen bit pairs.
            for b in 0..16 {
                out[48 + b] = (group_codes[b] >> 6)
                    | ((group_codes[16 + b] >> 6) << 2)
                    | ((group_codes[32 + b] >> 6) << 4);
            }
            if ex_bits == 7 {
                let plane = u64::from_le_bytes(group_codes[48..56].try_into().unwrap());
                for k in 0..4 {
                    for b in 0..16 {
                        let top = (plane >> (8 * (b % 8) + 2 * k + b / 8)) & 1;
                        out[16 * k + b] |= (top as u8) << 6;
                    }
                }
            }
        }
        8 => out[..16].copy_from_slice(group_codes),
        _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"),
    }
}
+/// +/// `ex_query` must be the kernel-order query from [`build_ex_query_into`]; +/// `codes` is the row slice (sequential layout for `ex_bits` ∈ {1, 2, 4, 8}, +/// bit-plane layout otherwise). Rows shorter than the padded query length are +/// treated as zero-padded. +pub type ExDotFn = fn(&[f32], &[u8]) -> f32; + +/// Resolve the dot kernel for `ex_bits` once; the result can be cached by the +/// caller for per-candidate use. +pub fn ex_dot_kernel(ex_bits: u8) -> ExDotFn { + debug_assert!((1..=8).contains(&ex_bits)); + static KERNELS: LazyLock<[ExDotFn; 8]> = + LazyLock::new(|| std::array::from_fn(|i| select_ex_dot_kernel(i as u8 + 1))); + KERNELS[usize::from(ex_bits) - 1] +} + +fn select_ex_dot_kernel(ex_bits: u8) -> ExDotFn { + #[cfg(target_arch = "x86_64")] + { + if std::arch::is_x86_feature_detected!("avx512f") { + return x86::avx512_kernel(ex_bits); + } + if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") + { + return x86::avx2_kernel(ex_bits); + } + } + #[cfg(target_arch = "aarch64")] + { + // NEON is part of the aarch64 baseline. 
+ return neon::kernel(ex_bits); + } + #[allow(unreachable_code)] + scalar_kernel(ex_bits) +} + +fn scalar_kernel(ex_bits: u8) -> ExDotFn { + match ex_bits { + 1 => ex_dot_scalar::<1>, + 2 => ex_dot_scalar::<2>, + 3 => ex_dot_scalar::<3>, + 4 => ex_dot_scalar::<4>, + 5 => ex_dot_scalar::<5>, + 6 => ex_dot_scalar::<6>, + 7 => ex_dot_scalar::<7>, + 8 => ex_dot_scalar::<8>, + _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"), + } +} + +fn ex_dot_scalar(ex_query: &[f32], codes: &[u8]) -> f32 { + let group_dims = group_dims(EX_BITS); + let bytes_per_group = group_bytes(EX_BITS); + debug_assert_eq!(ex_query.len() % EX_DOT_GROUP_DIMS, 0); + debug_assert!(codes.len() * u8::BITS as usize <= ex_query.len() * EX_BITS as usize); + + let mut sum = 0.0f32; + let mut unpacked = [0u8; 64]; + let mut padded = [0u8; 56]; + for (group, query) in ex_query.chunks_exact(group_dims).enumerate() { + let start = group * bytes_per_group; + if start >= codes.len() { + // The remaining query lanes are zero padding. 
+ break; + } + let group_codes = if start + bytes_per_group <= codes.len() { + &codes[start..start + bytes_per_group] + } else { + let avail = codes.len() - start; + padded[..bytes_per_group].fill(0); + padded[..avail].copy_from_slice(&codes[start..]); + &padded[..bytes_per_group] + }; + unpack_group(EX_BITS, group_codes, &mut unpacked); + for (q, &code) in query.iter().zip(unpacked[..group_dims].iter()) { + sum += q * code as f32; + } + } + sum +} + +#[cfg(target_arch = "x86_64")] +mod x86 { + use super::ExDotFn; + use std::arch::x86_64::*; + + pub(super) fn avx2_kernel(ex_bits: u8) -> ExDotFn { + match ex_bits { + 1 => dot_u1_avx2_dispatch, + 2 => dot_u2_avx2_dispatch, + 3 => dot_u3_avx2_dispatch, + 4 => dot_u4_avx2_dispatch, + 5 => dot_u5_avx2_dispatch, + 6 => dot_u6_avx2_dispatch, + 7 => dot_u7_avx2_dispatch, + 8 => dot_u8_avx2_dispatch, + _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"), + } + } + + pub(super) fn avx512_kernel(ex_bits: u8) -> ExDotFn { + match ex_bits { + 1 => dot_u1_avx512_dispatch, + 2 => dot_u2_avx512_dispatch, + 3 => dot_u3_avx512_dispatch, + 4 => dot_u4_avx512_dispatch, + 5 => dot_u5_avx512_dispatch, + 6 => dot_u6_avx512_dispatch, + 7 => dot_u7_avx512_dispatch, + 8 => dot_u8_avx512_dispatch, + _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"), + } + } + + /// Broadcast a byte to the 8 bytes of a `u64`. + #[inline(always)] + fn splat_byte(byte: u8) -> u64 { + byte as u64 * 0x0101_0101_0101_0101 + } + + // Unpack helpers. They read exactly one group of code bytes and return + // runs of 16 codes matching the kernel-order query. Only SSE2 (baseline on + // x86_64) is required. + + /// 16 1-bit codes from 2 bytes: compare each replicated byte against + /// per-lane bit masks to turn set bits into 0/1 bytes. 
/// x86_64 kernels: SSE2 unpack helpers shared by an AVX2+FMA and an AVX-512F
/// family of dot kernels, one per code width.
#[cfg(target_arch = "x86_64")]
mod x86 {
    use super::ExDotFn;
    use std::arch::x86_64::*;

    /// AVX2+FMA kernel for `ex_bits`. The returned function executes AVX2 and
    /// FMA instructions unconditionally.
    pub(super) fn avx2_kernel(ex_bits: u8) -> ExDotFn {
        match ex_bits {
            1 => dot_u1_avx2_dispatch,
            2 => dot_u2_avx2_dispatch,
            3 => dot_u3_avx2_dispatch,
            4 => dot_u4_avx2_dispatch,
            5 => dot_u5_avx2_dispatch,
            6 => dot_u6_avx2_dispatch,
            7 => dot_u7_avx2_dispatch,
            8 => dot_u8_avx2_dispatch,
            _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"),
        }
    }

    /// AVX-512F kernel for `ex_bits`. The returned function executes AVX-512F
    /// instructions unconditionally.
    pub(super) fn avx512_kernel(ex_bits: u8) -> ExDotFn {
        match ex_bits {
            1 => dot_u1_avx512_dispatch,
            2 => dot_u2_avx512_dispatch,
            3 => dot_u3_avx512_dispatch,
            4 => dot_u4_avx512_dispatch,
            5 => dot_u5_avx512_dispatch,
            6 => dot_u6_avx512_dispatch,
            7 => dot_u7_avx512_dispatch,
            8 => dot_u8_avx512_dispatch,
            _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"),
        }
    }

    /// Broadcast a byte to the 8 bytes of a `u64`.
    #[inline(always)]
    fn splat_byte(byte: u8) -> u64 {
        byte as u64 * 0x0101_0101_0101_0101
    }

    // Unpack helpers. They read exactly one group of code bytes and return
    // runs of 16 codes matching the kernel-order query. Only SSE2 (baseline on
    // x86_64) is required.

    /// 16 1-bit codes from 2 bytes: compare each replicated byte against
    /// per-lane bit masks to turn set bits into 0/1 bytes.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reads of 2 bytes.
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn unpack_u1(ptr: *const u8) -> [__m128i; 1] {
        let (b0, b1) = unsafe { (ptr.read(), ptr.add(1).read()) };
        // Low 8 lanes replicate `b0`, high 8 lanes replicate `b1`.
        let bytes = _mm_set_epi64x(splat_byte(b1) as i64, splat_byte(b0) as i64);
        // Lane i of each half selects bit i.
        let bit_select = _mm_set1_epi64x(0x8040_2010_0804_0201u64 as i64);
        let selected = _mm_cmpeq_epi8(_mm_and_si128(bytes, bit_select), bit_select);
        [_mm_and_si128(selected, _mm_set1_epi8(1))]
    }

    /// 64 2-bit codes from 16 bytes: byte b holds dims 4b..4b+3 at bit pairs.
    /// The 16-bit shifts drag bits across byte boundaries, which the per-byte
    /// mask removes.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reads of 16 bytes (no alignment requirement).
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn unpack_u2(ptr: *const u8) -> [__m128i; 4] {
        let raw = unsafe { _mm_loadu_si128(ptr as *const __m128i) };
        let mask = _mm_set1_epi8(0b11);
        [
            _mm_and_si128(raw, mask),
            _mm_and_si128(_mm_srli_epi16::<2>(raw), mask),
            _mm_and_si128(_mm_srli_epi16::<4>(raw), mask),
            _mm_and_si128(_mm_srli_epi16::<6>(raw), mask),
        ]
    }

    /// Position the top-bit plane (see [`super::pack_top_plane`]) of run `k`
    /// at `top_bit` within each byte.
    ///
    /// The low 8 lanes take plane bit `2k` of each byte and the high 8 lanes
    /// take bit `2k + 1`; every other bit is masked away.
    #[inline]
    #[target_feature(enable = "sse2")]
    fn top_plane_run(plane: u64, k: usize, top_bit: usize) -> __m128i {
        let lo = super::shift_plane(plane, 2 * k, top_bit);
        let hi = super::shift_plane(plane, 2 * k + 1, top_bit);
        _mm_and_si128(
            _mm_set_epi64x(hi as i64, lo as i64),
            _mm_set1_epi8(1 << top_bit),
        )
    }

    /// 64 3-bit codes from 24 bytes: the 2-bit plane unpacked as in
    /// [`unpack_u2`], with the top-bit plane OR-ed in at bit 2.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reads of 24 bytes.
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn unpack_u3(ptr: *const u8) -> [__m128i; 4] {
        let mut runs = unsafe { unpack_u2(ptr) };
        let plane = unsafe { (ptr.add(16) as *const u64).read_unaligned() };
        for (k, run) in runs.iter_mut().enumerate() {
            *run = _mm_or_si128(*run, top_plane_run(plane, k, 2));
        }
        runs
    }

    /// 16 4-bit codes from 8 bytes: low nibbles are the even dims, high
    /// nibbles the odd dims.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reads of 8 bytes (no alignment requirement).
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn unpack_u4(ptr: *const u8) -> [__m128i; 1] {
        let word = unsafe { (ptr as *const u64).read_unaligned() };
        let mask = 0x0f0f_0f0f_0f0f_0f0fu64;
        [_mm_set_epi64x(
            ((word >> 4) & mask) as i64,
            (word & mask) as i64,
        )]
    }

    /// 64 5-bit codes from 40 bytes: two 16-byte nibble blocks followed by
    /// the 8-byte top-bit plane, which is OR-ed in at bit 4.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reads of 40 bytes.
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn unpack_u5(ptr: *const u8) -> [__m128i; 4] {
        let blk0 = unsafe { _mm_loadu_si128(ptr as *const __m128i) };
        let blk1 = unsafe { _mm_loadu_si128(ptr.add(16) as *const __m128i) };
        let plane = unsafe { (ptr.add(32) as *const u64).read_unaligned() };
        let mask = _mm_set1_epi8(0x0f);
        let mut runs = [
            _mm_and_si128(blk0, mask),
            _mm_and_si128(_mm_srli_epi16::<4>(blk0), mask),
            _mm_and_si128(blk1, mask),
            _mm_and_si128(_mm_srli_epi16::<4>(blk1), mask),
        ];
        for (k, run) in runs.iter_mut().enumerate() {
            *run = _mm_or_si128(*run, top_plane_run(plane, k, 4));
        }
        runs
    }

    /// 64 6-bit codes from 48 bytes: three blocks carry 6 low bits each, and
    /// their top bit pairs are reassembled into the fourth run.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reads of 48 bytes.
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn unpack_u6(ptr: *const u8) -> [__m128i; 4] {
        let blk0 = unsafe { _mm_loadu_si128(ptr as *const __m128i) };
        let blk1 = unsafe { _mm_loadu_si128(ptr.add(16) as *const __m128i) };
        let blk2 = unsafe { _mm_loadu_si128(ptr.add(32) as *const __m128i) };
        let mask6 = _mm_set1_epi8(0x3f);
        let mask2 = _mm_set1_epi8(0b1100_0000u8 as i8);
        // Masking before the 16-bit shifts keeps each byte's pair from
        // leaking into its neighbour.
        let stolen = _mm_or_si128(
            _mm_or_si128(
                _mm_srli_epi16::<6>(_mm_and_si128(blk0, mask2)),
                _mm_srli_epi16::<4>(_mm_and_si128(blk1, mask2)),
            ),
            _mm_srli_epi16::<2>(_mm_and_si128(blk2, mask2)),
        );
        [
            _mm_and_si128(blk0, mask6),
            _mm_and_si128(blk1, mask6),
            _mm_and_si128(blk2, mask6),
            stolen,
        ]
    }

    /// 64 7-bit codes from 56 bytes: the 6-bit layout of [`unpack_u6`] plus
    /// the top-bit plane OR-ed in at bit 6.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reads of 56 bytes.
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn unpack_u7(ptr: *const u8) -> [__m128i; 4] {
        let mut runs = unsafe { unpack_u6(ptr) };
        let plane = unsafe { (ptr.add(48) as *const u64).read_unaligned() };
        for (k, run) in runs.iter_mut().enumerate() {
            *run = _mm_or_si128(*run, top_plane_run(plane, k, 6));
        }
        runs
    }

    /// 16 8-bit codes: a plain unaligned 16-byte load.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reads of 16 bytes.
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn unpack_u8x16(ptr: *const u8) -> [__m128i; 1] {
        [unsafe { _mm_loadu_si128(ptr as *const __m128i) }]
    }

    /// FMA 16 code bytes against 16 query floats (AVX2: two 8-float halves).
    ///
    /// # Safety
    ///
    /// `query` must be valid for reads of 16 `f32` values, and the CPU must
    /// support AVX2 and FMA.
    #[inline]
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn fma16_avx2(codes: __m128i, query: *const f32, acc: &mut [__m256; 2]) {
        let lo = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(codes));
        acc[0] = _mm256_fmadd_ps(lo, unsafe { _mm256_loadu_ps(query) }, acc[0]);
        let hi = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_srli_si128::<8>(codes)));
        acc[1] = _mm256_fmadd_ps(hi, unsafe { _mm256_loadu_ps(query.add(8)) }, acc[1]);
    }

    /// Horizontal sum of both accumulators into a single `f32`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2.
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn reduce_add_avx2(acc: [__m256; 2]) -> f32 {
        let v = _mm256_add_ps(acc[0], acc[1]);
        let halves = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps::<1>(v));
        let pairs = _mm_add_ps(halves, _mm_movehl_ps(halves, halves));
        let total = _mm_add_ss(pairs, _mm_shuffle_ps::<0b01>(pairs, pairs));
        _mm_cvtss_f32(total)
    }

    /// FMA 16 code bytes against 16 query floats (AVX-512: one 16-float lane).
    ///
    /// # Safety
    ///
    /// `query` must be valid for reads of 16 `f32` values, and the CPU must
    /// support AVX-512F.
    #[inline]
    #[target_feature(enable = "avx512f")]
    unsafe fn fma16_avx512(codes: __m128i, query: *const f32, acc: &mut __m512) {
        let values = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(codes));
        *acc = _mm512_fmadd_ps(values, unsafe { _mm512_loadu_ps(query) }, *acc);
    }

    // Generates the AVX2+FMA kernel `$name` for one code width plus its safe
    // `$dispatch` wrapper. `$unpack` yields `$runs` runs of 16 codes per
    // group; single-run widths use 16-dim groups, the others 64-dim groups.
    macro_rules! x86_dot_kernel {
        ($name:ident, $dispatch:ident, $unpack:ident, $ex_bits:expr, $runs:expr) => {
            #[target_feature(enable = "avx2", enable = "fma")]
            unsafe fn $name(ex_query: &[f32], codes: &[u8]) -> f32 {
                const GROUP_DIMS: usize = if $runs == 1 { 16 } else { 64 };
                const GROUP_BYTES: usize = GROUP_DIMS * $ex_bits / 8;
                debug_assert_eq!(ex_query.len() % super::EX_DOT_GROUP_DIMS, 0);
                debug_assert!(codes.len() * 8 <= ex_query.len() * $ex_bits);

                let groups = ex_query.len() / GROUP_DIMS;
                // Clamp to the query so an over-long row cannot read past it.
                let full_groups = (codes.len() / GROUP_BYTES).min(groups);
                // Two accumulators per run position break the FMA latency
                // chain; they are summed once at the end.
                let mut acc = [_mm256_setzero_ps(); 2];
                for group in 0..full_groups {
                    // SAFETY: `group < full_groups` keeps both the code group
                    // and the query run in bounds.
                    let runs = unsafe { $unpack(codes.as_ptr().add(group * GROUP_BYTES)) };
                    for (run, codes16) in runs.into_iter().enumerate() {
                        unsafe {
                            fma16_avx2(
                                codes16,
                                ex_query.as_ptr().add(group * GROUP_DIMS + run * 16),
                                &mut acc,
                            )
                        };
                    }
                }
                let consumed = full_groups * GROUP_BYTES;
                if consumed < codes.len() && full_groups < groups {
                    // Zero-pad the final partial code group on the stack.
                    let mut padded = [0u8; GROUP_BYTES];
                    padded[..codes.len() - consumed].copy_from_slice(&codes[consumed..]);
                    let runs = unsafe { $unpack(padded.as_ptr()) };
                    for (run, codes16) in runs.into_iter().enumerate() {
                        unsafe {
                            fma16_avx2(
                                codes16,
                                ex_query.as_ptr().add(full_groups * GROUP_DIMS + run * 16),
                                &mut acc,
                            )
                        };
                    }
                }
                unsafe { reduce_add_avx2(acc) }
            }

            fn $dispatch(ex_query: &[f32], codes: &[u8]) -> f32 {
                // SAFETY: only selected when AVX2 and FMA were detected.
                unsafe { $name(ex_query, codes) }
            }
        };
    }

    // AVX-512F counterpart of `x86_dot_kernel!`: same group walk, with one
    // 16-float FMA per run.
    macro_rules! x86_dot_kernel_avx512 {
        ($name:ident, $dispatch:ident, $unpack:ident, $ex_bits:expr, $runs:expr) => {
            #[target_feature(enable = "avx512f")]
            unsafe fn $name(ex_query: &[f32], codes: &[u8]) -> f32 {
                const GROUP_DIMS: usize = if $runs == 1 { 16 } else { 64 };
                const GROUP_BYTES: usize = GROUP_DIMS * $ex_bits / 8;
                debug_assert_eq!(ex_query.len() % super::EX_DOT_GROUP_DIMS, 0);
                debug_assert!(codes.len() * 8 <= ex_query.len() * $ex_bits);

                let groups = ex_query.len() / GROUP_DIMS;
                let full_groups = (codes.len() / GROUP_BYTES).min(groups);
                // Alternating by group as well as run keeps two independent
                // FMA chains even for the single-run widths.
                let mut acc = [_mm512_setzero_ps(); 2];
                for group in 0..full_groups {
                    // SAFETY: `group < full_groups` keeps both the code group
                    // and the query run in bounds.
                    let runs = unsafe { $unpack(codes.as_ptr().add(group * GROUP_BYTES)) };
                    for (run, codes16) in runs.into_iter().enumerate() {
                        unsafe {
                            fma16_avx512(
                                codes16,
                                ex_query.as_ptr().add(group * GROUP_DIMS + run * 16),
                                &mut acc[(group + run) % 2],
                            )
                        };
                    }
                }
                let consumed = full_groups * GROUP_BYTES;
                if consumed < codes.len() && full_groups < groups {
                    // Zero-pad the final partial code group on the stack.
                    let mut padded = [0u8; GROUP_BYTES];
                    padded[..codes.len() - consumed].copy_from_slice(&codes[consumed..]);
                    let runs = unsafe { $unpack(padded.as_ptr()) };
                    for (run, codes16) in runs.into_iter().enumerate() {
                        unsafe {
                            fma16_avx512(
                                codes16,
                                ex_query.as_ptr().add(full_groups * GROUP_DIMS + run * 16),
                                &mut acc[(full_groups + run) % 2],
                            )
                        };
                    }
                }
                _mm512_reduce_add_ps(_mm512_add_ps(acc[0], acc[1]))
            }

            fn $dispatch(ex_query: &[f32], codes: &[u8]) -> f32 {
                // SAFETY: only selected when AVX-512F was detected.
                unsafe { $name(ex_query, codes) }
            }
        };
    }

    x86_dot_kernel!(dot_u1_avx2, dot_u1_avx2_dispatch, unpack_u1, 1, 1);
    x86_dot_kernel!(dot_u2_avx2, dot_u2_avx2_dispatch, unpack_u2, 2, 4);
    x86_dot_kernel!(dot_u3_avx2, dot_u3_avx2_dispatch, unpack_u3, 3, 4);
    x86_dot_kernel!(dot_u4_avx2, dot_u4_avx2_dispatch, unpack_u4, 4, 1);
    x86_dot_kernel!(dot_u5_avx2, dot_u5_avx2_dispatch, unpack_u5, 5, 4);
    x86_dot_kernel!(dot_u6_avx2, dot_u6_avx2_dispatch, unpack_u6, 6, 4);
    x86_dot_kernel!(dot_u7_avx2, dot_u7_avx2_dispatch, unpack_u7, 7, 4);
    x86_dot_kernel!(dot_u8_avx2, dot_u8_avx2_dispatch, unpack_u8x16, 8, 1);

    x86_dot_kernel_avx512!(dot_u1_avx512, dot_u1_avx512_dispatch, unpack_u1, 1, 1);
    x86_dot_kernel_avx512!(dot_u2_avx512, dot_u2_avx512_dispatch, unpack_u2, 2, 4);
    x86_dot_kernel_avx512!(dot_u3_avx512, dot_u3_avx512_dispatch, unpack_u3, 3, 4);
    x86_dot_kernel_avx512!(dot_u4_avx512, dot_u4_avx512_dispatch, unpack_u4, 4, 1);
    x86_dot_kernel_avx512!(dot_u5_avx512, dot_u5_avx512_dispatch, unpack_u5, 5, 4);
    x86_dot_kernel_avx512!(dot_u6_avx512, dot_u6_avx512_dispatch, unpack_u6, 6, 4);
    x86_dot_kernel_avx512!(dot_u7_avx512, dot_u7_avx512_dispatch, unpack_u7, 7, 4);
    x86_dot_kernel_avx512!(dot_u8_avx512, dot_u8_avx512_dispatch, unpack_u8x16, 8, 1);
}
vreinterpretq_u8_u64(vdupq_n_u64(0x8040_2010_0804_0201)); + [vandq_u8(vtstq_u8(bytes, bit_select), vdupq_n_u8(1))] + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u2(ptr: *const u8) -> [uint8x16_t; 4] { + let raw = unsafe { vld1q_u8(ptr) }; + let mask = vdupq_n_u8(0b11); + [ + vandq_u8(raw, mask), + vandq_u8(vshrq_n_u8::<2>(raw), mask), + vandq_u8(vshrq_n_u8::<4>(raw), mask), + vshrq_n_u8::<6>(raw), + ] + } + + #[inline] + #[target_feature(enable = "neon")] + fn top_plane_run(plane: u64, k: usize, top_bit: usize) -> uint8x16_t { + let lo = super::shift_plane(plane, 2 * k, top_bit); + let hi = super::shift_plane(plane, 2 * k + 1, top_bit); + vandq_u8( + vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(lo), vcreate_u64(hi))), + vdupq_n_u8(1 << top_bit), + ) + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u3(ptr: *const u8) -> [uint8x16_t; 4] { + let mut runs = unsafe { unpack_u2(ptr) }; + let plane = unsafe { (ptr.add(16) as *const u64).read_unaligned() }; + for (k, run) in runs.iter_mut().enumerate() { + *run = vorrq_u8(*run, top_plane_run(plane, k, 2)); + } + runs + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u4(ptr: *const u8) -> [uint8x16_t; 1] { + let word = unsafe { (ptr as *const u64).read_unaligned() }; + let mask = 0x0f0f_0f0f_0f0f_0f0fu64; + [vreinterpretq_u8_u64(vcombine_u64( + vcreate_u64(word & mask), + vcreate_u64((word >> 4) & mask), + ))] + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u5(ptr: *const u8) -> [uint8x16_t; 4] { + let blk0 = unsafe { vld1q_u8(ptr) }; + let blk1 = unsafe { vld1q_u8(ptr.add(16)) }; + let plane = unsafe { (ptr.add(32) as *const u64).read_unaligned() }; + let mask = vdupq_n_u8(0x0f); + let mut runs = [ + vandq_u8(blk0, mask), + vshrq_n_u8::<4>(blk0), + vandq_u8(blk1, mask), + vshrq_n_u8::<4>(blk1), + ]; + for (k, run) in runs.iter_mut().enumerate() { + *run = vorrq_u8(*run, top_plane_run(plane, k, 4)); + } + runs + } + + 
#[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u6(ptr: *const u8) -> [uint8x16_t; 4] { + let blk0 = unsafe { vld1q_u8(ptr) }; + let blk1 = unsafe { vld1q_u8(ptr.add(16)) }; + let blk2 = unsafe { vld1q_u8(ptr.add(32)) }; + let mask6 = vdupq_n_u8(0x3f); + let stolen = vorrq_u8( + vorrq_u8( + vshrq_n_u8::<6>(blk0), + vshlq_n_u8::<2>(vshrq_n_u8::<6>(blk1)), + ), + vshlq_n_u8::<4>(vshrq_n_u8::<6>(blk2)), + ); + [ + vandq_u8(blk0, mask6), + vandq_u8(blk1, mask6), + vandq_u8(blk2, mask6), + stolen, + ] + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u7(ptr: *const u8) -> [uint8x16_t; 4] { + let mut runs = unsafe { unpack_u6(ptr) }; + let plane = unsafe { (ptr.add(48) as *const u64).read_unaligned() }; + for (k, run) in runs.iter_mut().enumerate() { + *run = vorrq_u8(*run, top_plane_run(plane, k, 6)); + } + runs + } + + #[inline] + #[target_feature(enable = "neon")] + unsafe fn unpack_u8x16(ptr: *const u8) -> [uint8x16_t; 1] { + [unsafe { vld1q_u8(ptr) }] + } + + /// FMA 16 code bytes against 16 query floats over four 4-float lanes. + #[inline] + #[target_feature(enable = "neon")] + unsafe fn fma16_neon(codes: uint8x16_t, query: *const f32, acc: &mut [float32x4_t; 4]) { + let lo = vmovl_u8(vget_low_u8(codes)); + let hi = vmovl_u8(vget_high_u8(codes)); + let c0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo))); + let c1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo))); + let c2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi))); + let c3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi))); + unsafe { + acc[0] = vfmaq_f32(acc[0], c0, vld1q_f32(query)); + acc[1] = vfmaq_f32(acc[1], c1, vld1q_f32(query.add(4))); + acc[2] = vfmaq_f32(acc[2], c2, vld1q_f32(query.add(8))); + acc[3] = vfmaq_f32(acc[3], c3, vld1q_f32(query.add(12))); + } + } + + macro_rules! 
neon_dot_kernel { + ($name:ident, $dispatch:ident, $unpack:ident, $ex_bits:expr, $runs:expr) => { + #[target_feature(enable = "neon")] + unsafe fn $name(ex_query: &[f32], codes: &[u8]) -> f32 { + const GROUP_DIMS: usize = if $runs == 1 { 16 } else { 64 }; + const GROUP_BYTES: usize = GROUP_DIMS * $ex_bits / 8; + debug_assert_eq!(ex_query.len() % super::EX_DOT_GROUP_DIMS, 0); + debug_assert!(codes.len() * 8 <= ex_query.len() * $ex_bits); + + let groups = ex_query.len() / GROUP_DIMS; + let full_groups = (codes.len() / GROUP_BYTES).min(groups); + let mut acc = [vdupq_n_f32(0.0); 4]; + for group in 0..full_groups { + // SAFETY: `group < full_groups` keeps both the code group + // and the query run in bounds. + let runs = unsafe { $unpack(codes.as_ptr().add(group * GROUP_BYTES)) }; + for (run, codes16) in runs.into_iter().enumerate() { + unsafe { + fma16_neon( + codes16, + ex_query.as_ptr().add(group * GROUP_DIMS + run * 16), + &mut acc, + ) + }; + } + } + let consumed = full_groups * GROUP_BYTES; + if consumed < codes.len() && full_groups < groups { + // Zero-pad the final partial code group on the stack. + let mut padded = [0u8; GROUP_BYTES]; + padded[..codes.len() - consumed].copy_from_slice(&codes[consumed..]); + let runs = unsafe { $unpack(padded.as_ptr()) }; + for (run, codes16) in runs.into_iter().enumerate() { + unsafe { + fma16_neon( + codes16, + ex_query.as_ptr().add(full_groups * GROUP_DIMS + run * 16), + &mut acc, + ) + }; + } + } + vaddvq_f32(vaddq_f32( + vaddq_f32(acc[0], acc[1]), + vaddq_f32(acc[2], acc[3]), + )) + } + + fn $dispatch(ex_query: &[f32], codes: &[u8]) -> f32 { + // SAFETY: NEON is part of the aarch64 baseline. 
+ unsafe { $name(ex_query, codes) } + } + }; + } + + neon_dot_kernel!(dot_u1_neon, dot_u1_neon_dispatch, unpack_u1, 1, 1); + neon_dot_kernel!(dot_u2_neon, dot_u2_neon_dispatch, unpack_u2, 2, 4); + neon_dot_kernel!(dot_u3_neon, dot_u3_neon_dispatch, unpack_u3, 3, 4); + neon_dot_kernel!(dot_u4_neon, dot_u4_neon_dispatch, unpack_u4, 4, 1); + neon_dot_kernel!(dot_u5_neon, dot_u5_neon_dispatch, unpack_u5, 5, 4); + neon_dot_kernel!(dot_u6_neon, dot_u6_neon_dispatch, unpack_u6, 6, 4); + neon_dot_kernel!(dot_u7_neon, dot_u7_neon_dispatch, unpack_u7, 7, 4); + neon_dot_kernel!(dot_u8_neon, dot_u8_neon_dispatch, unpack_u8x16, 8, 1); +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + use rstest::rstest; + + /// Bit-pack code values sequentially (LSB-first), the on-disk ex-code layout. + fn pack_sequential(values: &[u8], ex_bits: u8) -> Vec { + let mut out = vec![0u8; (values.len() * ex_bits as usize).div_ceil(8)]; + for (dim, &value) in values.iter().enumerate() { + let bit_offset = dim * ex_bits as usize; + let bits = (value as u16) << (bit_offset % 8); + out[bit_offset / 8] |= bits as u8; + if bits >> 8 != 0 { + out[bit_offset / 8 + 1] |= (bits >> 8) as u8; + } + } + out + } + + fn kernel_codes(values: &[u8], dim: usize, ex_bits: u8) -> Vec { + let seq = pack_sequential(values, ex_bits); + if needs_plane_repack(ex_bits) { + let mut out = vec![0u8; ex_dot_code_bytes(dim, ex_bits)]; + plane_pack_row(&seq, dim, ex_bits, &mut out); + out + } else { + seq + } + } + + fn available_kernels(ex_bits: u8) -> Vec<(&'static str, ExDotFn)> { + // `mut` is only exercised on x86_64 where extra kernels may be pushed. 
+ #[allow(unused_mut)] + let mut kernels = vec![ + ("scalar", scalar_kernel(ex_bits)), + ("dispatched", ex_dot_kernel(ex_bits)), + ]; + #[cfg(target_arch = "x86_64")] + { + if std::arch::is_x86_feature_detected!("avx2") + && std::arch::is_x86_feature_detected!("fma") + { + kernels.push(("avx2", x86::avx2_kernel(ex_bits))); + } + if std::arch::is_x86_feature_detected!("avx512f") { + kernels.push(("avx512", x86::avx512_kernel(ex_bits))); + } + } + kernels + } + + #[rstest] + fn test_ex_dot_matches_reference( + #[values(1, 2, 3, 4, 5, 6, 7, 8)] ex_bits: u8, + #[values(7, 16, 60, 64, 100, 128, 1024, 1536, 2048)] dim: usize, + ) { + let mut rng = SmallRng::seed_from_u64(42 + ex_bits as u64 * 1000 + dim as u64); + let max_code = ((1u16 << ex_bits) - 1) as u8; + let values = (0..dim) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + let query = (0..dim) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + + let expected = query + .iter() + .zip(values.iter()) + .map(|(q, &c)| *q as f64 * c as f64) + .sum::(); + + let codes = kernel_codes(&values, dim, ex_bits); + let ex_query = build_ex_query(&query, ex_bits); + assert_eq!(ex_query.len() % EX_DOT_GROUP_DIMS, 0); + + let tolerance = 1e-3 * expected.abs().max(1.0); + for (name, kernel) in available_kernels(ex_bits) { + let actual = kernel(&ex_query, &codes) as f64; + assert!( + (actual - expected).abs() <= tolerance, + "ex_bits={ex_bits} dim={dim} kernel={name}: {actual} != {expected}" + ); + } + } + + #[rstest] + fn test_unpack_group_roundtrip(#[values(1, 2, 3, 4, 5, 6, 7, 8)] ex_bits: u8) { + let mut rng = SmallRng::seed_from_u64(7 + ex_bits as u64); + let dims = group_dims(ex_bits); + let max_code = ((1u16 << ex_bits) - 1) as u8; + let values = (0..dims) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + let codes = kernel_codes(&values, dims, ex_bits); + + let mut unpacked = [0u8; 64]; + unpack_group(ex_bits, &codes, &mut unpacked); + for dim in 0..dims { + assert_eq!( + 
unpacked[kernel_position(dim, ex_bits)], + values[dim], + "ex_bits={ex_bits} dim={dim}" + ); + } + } + + /// Dense dim sweep for the bit-plane widths: every tail shape within the + /// 64-dim kernel group, plus multi-group sizes. + #[rstest] + fn test_ex_dot_plane_widths_dense_dims(#[values(3, 5)] ex_bits: u8) { + let mut rng = SmallRng::seed_from_u64(97 + ex_bits as u64); + let max_code = ((1u16 << ex_bits) - 1) as u8; + for dim in (1..=160).chain([255, 256, 1000, 1536, 2048]) { + let values = (0..dim) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + let query = (0..dim) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + let expected = query + .iter() + .zip(values.iter()) + .map(|(q, &c)| *q as f64 * c as f64) + .sum::(); + + let codes = kernel_codes(&values, dim, ex_bits); + let ex_query = build_ex_query(&query, ex_bits); + let tolerance = 1e-3 * expected.abs().max(1.0); + for (name, kernel) in available_kernels(ex_bits) { + let actual = kernel(&ex_query, &codes) as f64; + assert!( + (actual - expected).abs() <= tolerance, + "ex_bits={ex_bits} dim={dim} kernel={name}: {actual} != {expected}" + ); + } + } + } + + #[test] + fn test_build_ex_query_pads_with_zeros() { + let query = vec![1.0f32; 100]; + for ex_bits in 1..=8u8 { + let ex_query = build_ex_query(&query, ex_bits); + assert_eq!(ex_query.len(), 128); + let sum = ex_query.iter().sum::(); + assert_eq!(sum, 100.0, "ex_bits={ex_bits}"); + } + } +} diff --git a/rust/lance-index/src/vector/bq/storage.rs b/rust/lance-index/src/vector/bq/storage.rs index bd70f176c5d..b5b38971fed 100644 --- a/rust/lance-index/src/vector/bq/storage.rs +++ b/rust/lance-index/src/vector/bq/storage.rs @@ -41,6 +41,10 @@ use serde::{Deserialize, Serialize}; use crate::frag_reuse::FragReuseIndex; use crate::pb; use crate::vector::ApproxMode; +use crate::vector::bq::ex_dot::{ + ExDotFn, build_ex_query, build_ex_query_into, ex_dot_code_bytes, ex_dot_kernel, ex_query_len, + needs_plane_repack, plane_pack_row, +}; use 
crate::vector::bq::rotation::{apply_fast_rotation, apply_fast_rotation_in_place}; use crate::vector::bq::transform::{ ADD_FACTORS_COLUMN, ERROR_FACTORS_COLUMN, EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN, @@ -349,10 +353,13 @@ impl RabitQuantizationMetadata { let code_dim = self.code_dim(); let ex_bits = rabit_ex_bits(self.num_bits)?; let dist_table_len = code_dim * 4; - let ex_dist_table_len = if ex_bits == 0 { - 0 - } else { + // The quantized ex dist table is only consumed by the FastScan bulk + // path; the exact rerank path multiplies the query against the packed + // codes directly (see `ex_dot`). + let ex_dist_table_len = if supports_ex_fastscan(ex_bits) { code_dim * (1usize << ex_bits) + } else { + 0 }; let mut rotated_query = vec![0.0; code_dim]; @@ -364,6 +371,12 @@ impl RabitQuantizationMetadata { let mut ex_dist_table = vec![0.0; ex_dist_table_len]; build_ex_dist_table_direct_into(&rotated_query, ex_bits, &mut ex_dist_table); + let mut ex_query = Vec::new(); + if ex_bits > 0 { + ex_query.resize(ex_query_len(code_dim), 0.0); + build_ex_query_into(&rotated_query, ex_bits, &mut ex_query); + } + let sum_q = rotated_query.iter().copied().sum(); Ok(RabitRawQueryContext { code_dim, @@ -371,6 +384,7 @@ impl RabitQuantizationMetadata { rotated_query, dist_table, ex_dist_table, + ex_query, sum_q, }) } @@ -464,6 +478,12 @@ pub struct RabitQuantizationStorage { error_factors: Option, ex_codes: Option, packed_ex_codes: Option, + // ex codes repacked into the bit-plane layout consumed by the ex-dot + // kernels; only present for the widths whose sequential layout is not + // byte-aligned (see `ex_dot::needs_plane_repack`). This keeps a second + // resident copy of the ex codes, mirroring `packed_ex_codes` for the + // FastScan widths. 
+ plane_ex_codes: Option, ex_add_factors: Option, ex_scale_factors: Option, } @@ -477,6 +497,11 @@ impl DeepSizeOf for RabitQuantizationStorage { .as_ref() .map(|codes| (codes as &dyn Array).deep_size_of_children(context)) .unwrap_or_default() + + self + .plane_ex_codes + .as_ref() + .map(|codes| (codes as &dyn Array).deep_size_of_children(context)) + .unwrap_or_default() } } @@ -561,14 +586,18 @@ impl RabitQuantizationStorage { dim, dist_table, ex_dist_table, + ex_query, sum_q, query_factor, query_error, approx_mode, } = parts; + // The ex-dot kernels consume the bit-plane repack where one exists and + // the sequential (byte-aligned) layout otherwise. let ex_codes = self - .ex_codes + .plane_ex_codes .as_ref() + .or(self.ex_codes.as_ref()) .map(|codes| codes.values().as_primitive::().values().as_ref()); let packed_ex_codes = self .packed_ex_codes @@ -580,6 +609,7 @@ impl RabitQuantizationStorage { self.metadata.query_estimator, dist_table, ex_dist_table, + ex_query, sum_q, self.codes.values().as_primitive::().values(), ex_codes, @@ -768,6 +798,7 @@ struct RabitDistCalculatorParts<'a> { dim: usize, dist_table: Cow<'a, [f32]>, ex_dist_table: Cow<'a, [f32]>, + ex_query: Cow<'a, [f32]>, sum_q: f32, query_factor: f32, query_error: f32, @@ -780,12 +811,19 @@ pub struct RabitDistCalculator<'a> { query_estimator: RabitQueryEstimator, // n * d / 8 binary-code bytes codes: &'a [u8], + // per-row ex codes in the layout consumed by the ex-dot kernels + // (`ex_dot::ex_dot_code_bytes` bytes per row) ex_codes: Option<&'a [u8]>, // this is a flattened 2D array of size d/4 * 16, // we split the query codes into d/4 chunks, each chunk is with 4 elements, // then dist_table[i][j] is the distance between the i-th query code and the code j dist_table: Cow<'a, [f32]>, + // only built for the ex widths supported by FastScan; the exact rerank + // path uses `ex_query` + `ex_dot` instead ex_dist_table: Cow<'a, [f32]>, + // rotated query permuted into kernel order (see 
`ex_dot::build_ex_query_into`) + ex_query: Cow<'a, [f32]>, + ex_dot: Option, add_factors: &'a [f32], scale_factors: &'a [f32], error_factors: Option<&'a [f32]>, @@ -808,6 +846,7 @@ impl<'a> RabitDistCalculator<'a> { query_estimator: RabitQueryEstimator, dist_table: Cow<'a, [f32]>, ex_dist_table: Cow<'a, [f32]>, + ex_query: Cow<'a, [f32]>, sum_q: f32, codes: &'a [u8], ex_codes: Option<&'a [u8]>, @@ -821,6 +860,7 @@ impl<'a> RabitDistCalculator<'a> { query_error: f32, approx_mode: ApproxMode, ) -> Self { + let ex_dot = (num_bits > 1).then(|| ex_dot_kernel(num_bits - 1)); Self { dim, num_bits, @@ -829,6 +869,8 @@ impl<'a> RabitDistCalculator<'a> { ex_codes, dist_table, ex_dist_table, + ex_query, + ex_dot, add_factors, scale_factors, error_factors, @@ -843,6 +885,18 @@ impl<'a> RabitDistCalculator<'a> { } } + /// `sum_d query[d] * ex_code[d]` for the candidate's packed ex codes. + #[inline] + fn ex_code_dot(&self, ex_codes: &[u8], id: usize, ex_code_len: usize) -> f32 { + let ex_dot = self + .ex_dot + .expect("raw-query multi-bit RQ requires an ex-dot kernel"); + ex_dot( + self.ex_query.as_ref(), + &ex_codes[id * ex_code_len..(id + 1) * ex_code_len], + ) + } + #[allow(clippy::uninit_vec)] fn binary_distances_with_scratch( &self, @@ -1030,8 +1084,7 @@ impl<'a> RabitDistCalculator<'a> { let ex_scale_factors = self .ex_scale_factors .expect("raw-query multi-bit RQ requires ex scale factors"); - let ex_code_len = - rabit_ex_code_bytes(self.dim, ex_bits).expect("RabitQ num_bits should be validated"); + let ex_code_len = ex_dot_code_bytes(self.dim, ex_bits); let code_scale = (1u32 << ex_bits) as f32; let code_bias = -(code_scale - 0.5); @@ -1088,14 +1141,7 @@ impl<'a> RabitDistCalculator<'a> { .enumerate() .skip(fastscan_len) .for_each(|(id, dist)| { - let ex_dist = compute_single_rq_ex_distance( - ex_codes, - id, - ex_code_len, - ex_bits, - self.dim, - &self.ex_dist_table, - ); + let ex_dist = self.ex_code_dot(ex_codes, id, ex_code_len); let full_dot = code_scale * *dist + 
ex_dist + code_bias * self.sum_q; *dist = full_dot * ex_scale_factors[id] + ex_add_factors[id] + self.query_factor; }); @@ -1126,14 +1172,7 @@ impl<'a> RabitDistCalculator<'a> { ex_add_factors: &[f32], ex_scale_factors: &[f32], ) -> f32 { - let ex_dist = compute_single_rq_ex_distance( - ex_codes, - id, - ex_code_len, - ex_bits, - self.dim, - &self.ex_dist_table, - ); + let ex_dist = self.ex_code_dot(ex_codes, id, ex_code_len); let code_bias = -((1u32 << ex_bits) as f32 - 0.5); let full_dot = (1u32 << ex_bits) as f32 * binary_ip + ex_dist + code_bias * self.sum_q; full_dot * ex_scale_factors[id] + ex_add_factors[id] + self.query_factor @@ -1180,8 +1219,7 @@ impl<'a> RabitDistCalculator<'a> { let ex_scale_factors = self .ex_scale_factors .expect("raw-query multi-bit RQ requires ex scale factors"); - let ex_code_len = - rabit_ex_code_bytes(self.dim, ex_bits).expect("RabitQ num_bits should be validated"); + let ex_code_len = ex_dot_code_bytes(self.dim, ex_bits); let query_lower_bound = lower_bound.unwrap_or(f32::MIN); let query_upper_bound = upper_bound.unwrap_or(f32::MAX); let mut max_dist = res.peek().map(|node| node.dist); @@ -1287,7 +1325,9 @@ fn build_ex_dist_table_direct(rotated_query: &[f32], ex_bits: u8) -> Vec { } fn build_ex_dist_table_direct_into(rotated_query: &[f32], ex_bits: u8, dist_table: &mut [f32]) { - if ex_bits == 0 { + // The table may legitimately be empty for multi-bit widths without + // FastScan support; the exact path uses the ex-dot kernels instead. 
+ if ex_bits == 0 || dist_table.is_empty() { debug_assert!(dist_table.is_empty()); return; } @@ -1401,21 +1441,6 @@ fn quantize_dist_table_u16_into( (qmin, qmax) } -#[inline] -fn packed_ex_code_value(row_codes: &[u8], dim_idx: usize, ex_bits: u8) -> u8 { - debug_assert!(ex_bits > 0); - let bit_offset = dim_idx * ex_bits as usize; - let byte_idx = bit_offset / u8::BITS as usize; - let bit_shift = bit_offset % u8::BITS as usize; - let bits = row_codes[byte_idx] as u16 - | row_codes - .get(byte_idx + 1) - .map(|byte| (*byte as u16) << u8::BITS) - .unwrap_or_default(); - let mask = (1u16 << ex_bits) - 1; - ((bits >> bit_shift) & mask) as u8 -} - fn quantize_ex_fastscan_dist_table_into( dim: usize, ex_bits: u8, @@ -1517,28 +1542,6 @@ fn ex_dist_table_value( ex_dist_table[dim_idx * entries_per_dim + code] } -#[inline] -fn compute_single_rq_ex_distance( - ex_codes: &[u8], - id: usize, - ex_code_len: usize, - ex_bits: u8, - dim: usize, - ex_dist_table: &[f32], -) -> f32 { - if ex_bits == 0 { - return 0.0; - } - let entries_per_dim = 1usize << ex_bits; - let row_codes = &ex_codes[id * ex_code_len..(id + 1) * ex_code_len]; - (0..dim) - .map(|dim_idx| { - let code = packed_ex_code_value(row_codes, dim_idx, ex_bits) as usize; - ex_dist_table[dim_idx * entries_per_dim + code] - }) - .sum() -} - fn maybe_pack_ex_codes( ex_codes: Option<&FixedSizeListArray>, ex_bits: u8, @@ -1550,6 +1553,33 @@ fn maybe_pack_ex_codes( } } +/// Repack sequential ex codes into the bit-plane layout the ex-dot kernels +/// consume, for the widths whose sequential layout is not byte-aligned. 
+fn maybe_plane_pack_ex_codes( + ex_codes: Option<&FixedSizeListArray>, + dim: usize, + ex_bits: u8, +) -> Result> { + let ex_codes = match ex_codes { + Some(ex_codes) if needs_plane_repack(ex_bits) => ex_codes, + _ => return Ok(None), + }; + let seq_code_len = ex_codes.value_length() as usize; + let seq_values = ex_codes.values().as_primitive::().values(); + let plane_code_len = ex_dot_code_bytes(dim, ex_bits); + let mut plane_values = vec![0u8; ex_codes.len() * plane_code_len]; + for (seq_row, plane_row) in seq_values + .chunks_exact(seq_code_len) + .zip(plane_values.chunks_exact_mut(plane_code_len)) + { + plane_pack_row(seq_row, dim, ex_bits, plane_row); + } + Ok(Some(FixedSizeListArray::try_new_from_values( + UInt8Array::from(plane_values), + plane_code_len as i32, + )?)) +} + impl DistCalculator for RabitDistCalculator<'_> { #[inline(always)] fn distance(&self, id: u32) -> f32 { @@ -1580,8 +1610,7 @@ impl DistCalculator for RabitDistCalculator<'_> { let ex_scale_factors = self .ex_scale_factors .expect("raw-query multi-bit RQ requires ex scale factors"); - let ex_code_len = rabit_ex_code_bytes(self.dim, ex_bits) - .expect("RabitQ num_bits should be validated"); + let ex_code_len = ex_dot_code_bytes(self.dim, ex_bits); self.raw_query_multi_bit_exact_distance( id, dist, @@ -1866,7 +1895,16 @@ impl VectorStore for RabitQuantizationStorage { let rotated_qr = self.rotate_query_vector(code_dim, &qr); let dist_table = build_dist_table_direct::(&rotated_qr); let ex_bits = self.metadata.num_bits - 1; - let ex_dist_table = build_ex_dist_table_direct(&rotated_qr, ex_bits); + let ex_dist_table = if supports_ex_fastscan(ex_bits) { + build_ex_dist_table_direct(&rotated_qr, ex_bits) + } else { + Vec::new() + }; + let ex_query = if ex_bits > 0 { + build_ex_query(&rotated_qr, ex_bits) + } else { + Vec::new() + }; let query_factor = match self.metadata.query_estimator { RabitQueryEstimator::ResidualQuery => self.residual_query_factor(dist_q_c), RabitQueryEstimator::RawQuery => 
self.raw_query_factor(dist_q_c, &rotated_qr, None), @@ -1883,6 +1921,7 @@ impl VectorStore for RabitQuantizationStorage { dim: code_dim, dist_table: Cow::Owned(dist_table), ex_dist_table: Cow::Owned(ex_dist_table), + ex_query: Cow::Owned(ex_query), sum_q, query_factor, query_error, @@ -1922,6 +1961,7 @@ impl VectorStore for RabitQuantizationStorage { dim: code_dim, dist_table: Cow::Borrowed(&raw_query.dist_table), ex_dist_table: Cow::Borrowed(&raw_query.ex_dist_table), + ex_query: Cow::Borrowed(&raw_query.ex_query), sum_q: raw_query.sum_q, query_factor, query_error, @@ -1931,18 +1971,29 @@ impl VectorStore for RabitQuantizationStorage { let dist_table_len = code_dim * 4; let ex_bits = self.metadata.num_bits - 1; - let ex_dist_table_len = if ex_bits == 0 { + // Only the FastScan bulk path consumes the quantized ex dist table; + // the exact rerank path uses the kernel-order query instead. + let ex_dist_table_len = if supports_ex_fastscan(ex_bits) { + code_dim * (1usize << ex_bits) + } else { + 0 + }; + let ex_query_table_len = if ex_bits == 0 { 0 } else { - code_dim * (1usize << ex_bits) + ex_query_len(code_dim) }; - f32_scratch.resize(code_dim + dist_table_len + ex_dist_table_len, 0.0); + f32_scratch.resize( + code_dim + dist_table_len + ex_dist_table_len + ex_query_table_len, + 0.0, + ); let query_factor; let query_error; let sum_q = { let (rotated_qr, remaining) = f32_scratch.split_at_mut(code_dim); - let (dist_table, ex_dist_table) = remaining.split_at_mut(dist_table_len); + let (dist_table, remaining) = remaining.split_at_mut(dist_table_len); + let (ex_dist_table, ex_query) = remaining.split_at_mut(ex_dist_table_len); match residual { Some(QueryResidual::Centroid(residual_centroid)) => { self.rotate_query_vector_into( @@ -1982,15 +2033,20 @@ impl VectorStore for RabitQuantizationStorage { }; build_dist_table_direct_into::(rotated_qr, dist_table); build_ex_dist_table_direct_into(rotated_qr, ex_bits, ex_dist_table); + if ex_bits > 0 { + 
build_ex_query_into(rotated_qr, ex_bits, ex_query); + } rotated_qr.iter().copied().sum() }; + let ex_dist_table_start = code_dim + dist_table_len; + let ex_query_start = ex_dist_table_start + ex_dist_table_len; self.distance_calculator_from_parts(RabitDistCalculatorParts { dim: code_dim, - dist_table: Cow::Borrowed(&f32_scratch[code_dim..code_dim + dist_table_len]), - ex_dist_table: Cow::Borrowed( - &f32_scratch - [code_dim + dist_table_len..code_dim + dist_table_len + ex_dist_table_len], + dist_table: Cow::Borrowed(&f32_scratch[code_dim..ex_dist_table_start]), + ex_dist_table: Cow::Borrowed(&f32_scratch[ex_dist_table_start..ex_query_start]), + ex_query: Cow::Borrowed( + &f32_scratch[ex_query_start..ex_query_start + ex_query_table_len], ), sum_q, query_factor, @@ -2271,6 +2327,8 @@ impl QuantizerStorage for RabitQuantizationStorage { let mut metadata = metadata.clone(); metadata.packed = true; let packed_ex_codes = maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits); + let plane_ex_codes = + maybe_plane_pack_ex_codes(ex_codes.as_ref(), metadata.rotated_dim(), ex_bits)?; Ok(Self { metadata, @@ -2283,6 +2341,7 @@ impl QuantizerStorage for RabitQuantizationStorage { error_factors, ex_codes, packed_ex_codes, + plane_ex_codes, ex_add_factors, ex_scale_factors, }) @@ -2356,8 +2415,10 @@ impl QuantizerStorage for RabitQuantizationStorage { let ex_codes = batch .column_by_name(RABIT_EX_CODE_COLUMN) .map(|codes| codes.as_fixed_size_list().clone()); - let packed_ex_codes = - maybe_pack_ex_codes(ex_codes.as_ref(), rabit_ex_bits(self.metadata.num_bits)?); + let ex_bits = rabit_ex_bits(self.metadata.num_bits)?; + let packed_ex_codes = maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits); + let plane_ex_codes = + maybe_plane_pack_ex_codes(ex_codes.as_ref(), self.metadata.rotated_dim(), ex_bits)?; let ex_add_factors = batch .column_by_name(EX_ADD_FACTORS_COLUMN) .map(|factors| factors.as_primitive::().clone()); @@ -2375,6 +2436,7 @@ impl QuantizerStorage for RabitQuantizationStorage { 
error_factors, ex_codes, packed_ex_codes, + plane_ex_codes, ex_add_factors, ex_scale_factors, row_ids: new_row_ids, @@ -2898,6 +2960,206 @@ mod tests { assert_eq!(distances, vec![104.0, 22.0]); } + /// Exercise the ex-dot kernel through the storage API for every ex width, + /// including the widths without FastScan support ({1, 3, 5, 6, 7}), and a + /// dim that is not a multiple of the 64-dim kernel group. + /// + /// The dim must be a multiple of 8: the binary distance stage consumes + /// two 4-dim segments per code byte and ignores trailing dims otherwise. + #[test] + fn test_raw_query_multi_bit_distance_matches_reference_for_all_ex_widths() { + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + // 72 exercises the kernels' padded-tail path; 1536 is a production + // embedding dim exercising the full-group path. + for (code_dim, num_rows) in [(72usize, 33usize), (1536, 33)] { + for num_bits in 2..=9u8 { + let ex_bits = num_bits - 1; + let mut rng = SmallRng::seed_from_u64(num_bits as u64); + + let sign_bits = (0..num_rows * code_dim) + .map(|_| rng.random_bool(0.5)) + .collect::>(); + let max_code = ((1u16 << ex_bits) - 1) as u8; + let ex_values = (0..num_rows * code_dim) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + + let code_len = rabit_binary_code_bytes(code_dim); + let mut code_bytes = vec![0u8; num_rows * code_len]; + for (row, bits) in sign_bits.chunks_exact(code_dim).enumerate() { + for (dim, &bit) in bits.iter().enumerate() { + code_bytes[row * code_len + dim / 8] |= (bit as u8) << (dim % 8); + } + } + let ex_code_len = rabit_ex_code_bytes(code_dim, ex_bits).unwrap(); + let mut ex_code_bytes = vec![0u8; num_rows * ex_code_len]; + for (row, values) in ex_values.chunks_exact(code_dim).enumerate() { + for (dim, &value) in values.iter().enumerate() { + let bit_offset = dim * ex_bits as usize; + let bits = (value as u16) << (bit_offset % 8); + ex_code_bytes[row * ex_code_len + bit_offset / 8] |= bits as u8; + if bits >> 8 != 0 { + 
ex_code_bytes[row * ex_code_len + bit_offset / 8 + 1] |= + (bits >> 8) as u8; + } + } + } + + let identity = Float32Array::from_iter_values((0..code_dim).flat_map(|row| { + (0..code_dim).map(move |col| if row == col { 1.0 } else { 0.0 }) + })); + let rotate_mat = + FixedSizeListArray::try_new_from_values(identity, code_dim as i32).unwrap(); + let metadata = RabitQuantizationMetadata { + rotate_mat: Some(rotate_mat), + rotate_mat_position: None, + fast_rotation_signs: None, + rotation_type: RQRotationType::Matrix, + code_dim: code_dim as u32, + num_bits, + packed: false, + query_estimator: RabitQueryEstimator::RawQuery, + }; + let codes = FixedSizeListArray::try_new_from_values( + UInt8Array::from(code_bytes), + code_len as i32, + ) + .unwrap(); + let ex_codes = FixedSizeListArray::try_new_from_values( + UInt8Array::from(ex_code_bytes), + ex_code_len as i32, + ) + .unwrap(); + let ex_add_factors = (0..num_rows) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + let ex_scale_factors = (0..num_rows) + .map(|_| rng.random_range(0.1f32..1.0)) + .collect::>(); + let batch = RecordBatch::try_from_iter(vec![ + ( + ROW_ID, + Arc::new(UInt64Array::from_iter_values(0..num_rows as u64)) as ArrayRef, + ), + (RABIT_CODE_COLUMN, Arc::new(codes) as ArrayRef), + ( + ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, + ), + ( + SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, + ), + (RABIT_EX_CODE_COLUMN, Arc::new(ex_codes) as ArrayRef), + ( + EX_ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(ex_add_factors.clone())) as ArrayRef, + ), + ( + EX_SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(ex_scale_factors.clone())) as ArrayRef, + ), + ]) + .unwrap(); + let storage = RabitQuantizationStorage::try_from_batch( + batch, + &metadata, + DistanceType::L2, + None, + ) + .unwrap(); + + let query = (0..code_dim) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + let sum_q = 
query.iter().sum::(); + let calc = storage + .dist_calculator(Arc::new(Float32Array::from(query.clone())) as ArrayRef, 0.0); + + let code_scale = (1u32 << ex_bits) as f32; + let code_bias = -(code_scale - 0.5); + let expected = (0..num_rows) + .map(|row| { + let binary_ip = (0..code_dim) + .map(|dim| query[dim] * sign_bits[row * code_dim + dim] as u8 as f32) + .sum::(); + let ex_dist = (0..code_dim) + .map(|dim| query[dim] * ex_values[row * code_dim + dim] as f32) + .sum::(); + let full_dot = code_scale * binary_ip + ex_dist + code_bias * sum_q; + full_dot * ex_scale_factors[row] + ex_add_factors[row] + }) + .collect::>(); + + for (row, &want) in expected.iter().enumerate() { + let got = calc.distance(row as u32); + assert!( + (got - want).abs() <= 1e-3 * want.abs().max(1.0), + "num_bits={num_bits} row={row}: {got} != {want}" + ); + } + + let mut distances = Vec::new(); + let mut u16_scratch = Vec::new(); + let mut u8_scratch = Vec::new(); + let mut u32_scratch = Vec::new(); + calc.distance_all_with_scratch( + 0, + &mut distances, + &mut u16_scratch, + &mut u8_scratch, + &mut u32_scratch, + ); + assert_eq!(distances.len(), num_rows); + // The bulk path quantizes the binary LUT to u8, and that error is + // amplified by 2^ex_bits in the multi-bit estimate, so the value + // assertions need a quantization-aware bound. The FastScan ex + // widths additionally quantize the ex LUT and are covered by + // `test_raw_query_multi_bit_distance_all_uses_fastscan_for_split_ex_codes`. + if !matches!(ex_bits, 2 | 4 | 8) { + // Worst-case |error| of one u8-quantized binary LUT lookup is + // (table range) / 255 / 2, accumulated over one lookup per + // 8-dim pair of segments. 
+ let num_tables = code_dim.div_ceil(4); + let mut table_min = f32::INFINITY; + let mut table_max = f32::NEG_INFINITY; + for segment in query.chunks(4) { + for subset in 0..16usize { + let value = segment + .iter() + .enumerate() + .filter(|(idx, _)| subset & (1 << idx) != 0) + .map(|(_, q)| *q) + .sum::(); + table_min = table_min.min(value); + table_max = table_max.max(value); + } + } + let binary_bound = + code_scale * num_tables as f32 * (table_max - table_min) / 255.0 / 2.0 + * ex_scale_factors.iter().fold(0.0f32, |max, &s| max.max(s)); + for (row, (&got, &want)) in distances.iter().zip(expected.iter()).enumerate() { + assert!( + (got - want).abs() <= binary_bound + 1e-3, + "num_bits={num_bits} row={row} (distance_all): {got} != {want} (bound {binary_bound})" + ); + } + // Rows past the SIMD batch use the exact binary path, so the + // final remainder row must match the per-candidate distance. + let remainder_row = num_rows - 1; + let got = distances[remainder_row]; + let want = calc.distance(remainder_row as u32); + assert!( + (got - want).abs() <= 1e-3 * want.abs().max(1.0), + "num_bits={num_bits} remainder row (distance_all): {got} != {want}" + ); + } + } + } + } + #[test] fn test_fast_approx_mode_uses_one_bit_scores_for_multi_bit_raw_query() { let code_dim = 8usize; @@ -3571,11 +3833,19 @@ mod tests { #[test] fn test_remap_preserves_multi_bit_rq_split_columns() { + // num_bits=9 keeps sequential ex codes; num_bits 4/6/8 (ex_bits + // 3/5/7) also exercise the bit-plane repack rebuild in `remap`. 
+ for num_bits in [4, 6, 8, 9u8] { + test_remap_preserves_multi_bit_rq_split_columns_impl(num_bits); + } + } + + fn test_remap_preserves_multi_bit_rq_split_columns_impl(num_bits: u8) { let original_codes = make_test_codes(50, 64); let code_dim = original_codes.value_length() as usize * 8; - let ex_codes = make_test_ex_codes(original_codes.len(), code_dim, 9); + let ex_codes = make_test_ex_codes(original_codes.len(), code_dim, num_bits); let mut metadata = make_test_metadata(code_dim); - metadata.num_bits = 9; + metadata.num_bits = num_bits; let storage = RabitQuantizationStorage::try_from_batch( make_test_batch_with_ex(original_codes.clone(), ex_codes), &metadata, @@ -3599,11 +3869,12 @@ mod tests { ); assert_eq!(remapped_row_ids, expected_row_ids.values()); + let ex_code_len = rabit_ex_code_bytes(code_dim, rabit_ex_bits(num_bits).unwrap()).unwrap(); assert_eq!( remapped_batch[RABIT_EX_CODE_COLUMN] .as_fixed_size_list() .value_length(), - 64 + ex_code_len as i32 ); assert_eq!( &remapped_batch[EX_ADD_FACTORS_COLUMN] @@ -3623,5 +3894,20 @@ mod tests { .values()[..5], &[0.25, 1.25, 2.25, 4.25, 5.25] ); + + // The remapped storage must hold the same kernel-layout ex codes as a + // storage freshly loaded from the remapped batch. 
+ let reloaded = RabitQuantizationStorage::try_from_batch( + remapped_batch, + &remapped.metadata, + DistanceType::L2, + None, + ) + .unwrap(); + assert_eq!(remapped.plane_ex_codes, reloaded.plane_ex_codes); + assert_eq!( + remapped.plane_ex_codes.is_some(), + needs_plane_repack(rabit_ex_bits(num_bits).unwrap()) + ); } } diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index b036e187b77..2b7382e1c65 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -249,7 +249,12 @@ pub struct RabitRawQueryContext { pub ex_bits: u8, pub rotated_query: Vec, pub dist_table: Vec, + /// Quantized-table input for the FastScan ex path; empty for ex widths + /// without FastScan support. pub ex_dist_table: Vec, + /// Rotated query permuted into ex-dot kernel order (see + /// `lance_index::vector::bq::ex_dot::build_ex_query_into`). + pub ex_query: Vec, pub sum_q: f32, } diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 4ea076ed420..4c806355e2b 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -38,6 +38,7 @@ use lance_index::frag_reuse::FragReuseIndex; use lance_index::metrics::{LocalMetricsCollector, MetricsCollector, NoOpMetricsCollector}; use lance_index::vector::VectorIndexCacheEntry; use lance_index::vector::bq::builder::RabitQuantizer; +use lance_index::vector::bq::ex_dot::ex_query_len; use lance_index::vector::bq::storage::{RabitQueryEstimator, SEGMENT_NUM_CODES}; use lance_index::vector::bq::{rabit_ex_bits, rabit_ex_code_bytes}; use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; @@ -152,16 +153,17 @@ fn rotated_partition_centroid_slice( cache.rotated_centroids.get(start..end) } -fn rabit_ex_dist_table_len(dim: usize, num_bits: u8) -> usize { +/// `f32` scratch needed for the ex-bit query state: the quantized-table input +/// for FastScan-supported widths plus the 
kernel-order query for the exact +/// rerank path. +fn rabit_ex_scratch_len(dim: usize, num_bits: u8) -> usize { rabit_ex_bits(num_bits) - .map(|ex_bits| { - if ex_bits == 0 { - 0 - } else { - dim * (1usize << usize::from(ex_bits)) - } + .map(|ex_bits| match ex_bits { + 0 => 0, + 2 | 4 | 8 => dim * (1usize << usize::from(ex_bits)) + ex_query_len(dim), + _ => ex_query_len(dim), }) - .unwrap_or(dim * 256) + .unwrap_or(dim * 256 + ex_query_len(dim)) } fn rabit_u8_scratch_len(dim: usize, num_bits: u8) -> usize { @@ -183,12 +185,12 @@ fn rabit_query_scratch_capacity( num_bits: u8, ) -> QueryScratchCapacity { let dist_table_len = dim * 4; - let ex_dist_table_len = rabit_ex_dist_table_len(dim, num_bits); + let ex_scratch_len = rabit_ex_scratch_len(dim, num_bits); let u8_scratch_len = rabit_u8_scratch_len(dim, num_bits); QueryScratchCapacity::new( max_partition_len, - dim + dist_table_len + ex_dist_table_len, + dim + dist_table_len + ex_scratch_len, max_partition_len.max(dist_table_len), u8_scratch_len, ) @@ -1908,6 +1910,8 @@ mod tests { use lance_arrow::FixedSizeListArrayExt; use lance_index::vector::bq::{ RQBuildParams, RQRotationType, + ex_dot::ex_query_len, + rabit_ex_code_bytes, storage::{RABIT_EX_CODE_COLUMN, RabitQuantizationMetadata, RabitQueryEstimator}, transform::{EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN}, }; @@ -1983,14 +1987,15 @@ mod tests { } #[test] - fn test_rabit_ex_dist_table_len_uses_num_bits() { + fn test_rabit_ex_scratch_len_uses_num_bits() { let dim = 960; + let ex_query = ex_query_len(dim); - assert_eq!(super::rabit_ex_dist_table_len(dim, 1), 0); - assert_eq!(super::rabit_ex_dist_table_len(dim, 3), dim * 4); - assert_eq!(super::rabit_ex_dist_table_len(dim, 5), dim * 16); - assert_eq!(super::rabit_ex_dist_table_len(dim, 7), dim * 64); - assert_eq!(super::rabit_ex_dist_table_len(dim, 9), dim * 256); + assert_eq!(super::rabit_ex_scratch_len(dim, 1), 0); + assert_eq!(super::rabit_ex_scratch_len(dim, 3), dim * 4 + ex_query); + 
assert_eq!(super::rabit_ex_scratch_len(dim, 5), dim * 16 + ex_query); + assert_eq!(super::rabit_ex_scratch_len(dim, 7), ex_query); + assert_eq!(super::rabit_ex_scratch_len(dim, 9), dim * 256 + ex_query); } #[test] @@ -2012,7 +2017,10 @@ mod tests { let capacity = super::rabit_query_scratch_capacity(dim, max_partition_len, 5); assert_eq!(capacity.distances, max_partition_len); - assert_eq!(capacity.query_f32, dim + dim * 4 + dim * 16); + assert_eq!( + capacity.query_f32, + dim + dim * 4 + dim * 16 + ex_query_len(dim) + ); assert_eq!(capacity.u16, max_partition_len); assert_eq!(capacity.u8, dim * 16); assert_eq!(capacity.u32, 0); @@ -4403,18 +4411,24 @@ mod tests { } #[rstest] - #[case::l2(DistanceType::L2)] - #[case::cosine(DistanceType::Cosine)] + #[case::l2(DistanceType::L2, 9)] + #[case::cosine(DistanceType::Cosine, 9)] + // ex_bits=3 and ex_bits=5 have no FastScan support and use the bit-plane + // repack, so these searches go through the exact ex-dot rerank kernels + // end to end. + #[case::l2_plane_repack_3bit(DistanceType::L2, 4)] + #[case::l2_plane_repack_5bit(DistanceType::L2, 6)] #[tokio::test] async fn test_build_ivf_rq_multi_bit_persists_split_codes_and_searches( #[case] distance_type: DistanceType, + #[case] num_bits: u8, ) { let test_dir = TempStrDir::default(); let test_uri = test_dir.as_str(); let (mut dataset, vectors) = generate_test_dataset::(test_uri, 0.0..1.0).await; let ivf_params = IvfBuildParams::new(4); - let rq_params = RQBuildParams::with_rotation_type(9, RQRotationType::Fast); + let rq_params = RQBuildParams::with_rotation_type(num_bits, RQRotationType::Fast); let params = VectorIndexParams::with_ivf_rq_params(distance_type, ivf_params, rq_params); dataset .create_index(&["vector"], IndexType::Vector, None, &params, true) @@ -4427,7 +4441,7 @@ mod tests { let scheduler = ScanScheduler::new(obj_store, SchedulerConfig::default_for_testing()); let index_uuid = indices[0].uuid.to_string(); let rq_meta = get_rq_metadata(&dataset, 
scheduler.clone(), &index_uuid).await; - assert_eq!(rq_meta.num_bits, 9); + assert_eq!(rq_meta.num_bits, num_bits); assert_eq!(rq_meta.query_estimator, RabitQueryEstimator::RawQuery); let reader = open_rq_aux_reader(&dataset, scheduler, &index_uuid).await; @@ -4436,7 +4450,9 @@ mod tests { let DataType::FixedSizeList(_, ex_code_bytes) = ex_field.data_type() else { panic!("RQ ex-code field should be FixedSizeList"); }; - assert_eq!(ex_code_bytes, 32); + let expected_ex_code_bytes = + rabit_ex_code_bytes(rq_meta.rotated_dim(), num_bits - 1).unwrap() as i32; + assert_eq!(ex_code_bytes, expected_ex_code_bytes); assert!(schema.field(EX_ADD_FACTORS_COLUMN).is_some()); assert!(schema.field(EX_SCALE_FACTORS_COLUMN).is_some()); From 33c096804610c2181e2d7395ffd0d8eacfcb7f8a Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 11 Jun 2026 14:24:18 +0800 Subject: [PATCH 2/3] perf(vector)!: store RaBitQ ex codes in a blocked SIMD layout Replace the sequential LSB-first ex-code layout with a blocked layout (64-dim blocks, bit-interleaved so the SIMD unpack emits natural dim order) written under a new column (__blocked_ex_codes): - The ex-dot kernels read the rotated query directly (no per-query permutation) and consume rows as stored (no load-time repack, no second resident copy). - Legacy indexes remain readable: sequential rows are repacked at load and the batch is normalized, so rewrites (remap, optimize merges) emit the blocked format. Older lance versions fail loudly on new multi-bit indexes (missing-column error) instead of misreading them. - num_bits=1 indexes carry no ex codes and are unaffected. 
Co-Authored-By: Claude Fable 5 --- rust/lance-index/benches/rq.rs | 177 +++- rust/lance-index/src/vector/bq/builder.rs | 34 +- rust/lance-index/src/vector/bq/ex_dot.rs | 323 +++---- rust/lance-index/src/vector/bq/storage.rs | 808 +++++++++++------- rust/lance-index/src/vector/bq/transform.rs | 21 +- .../src/vector/distributed/index_merger.rs | 28 +- rust/lance-index/src/vector/storage.rs | 5 +- rust/lance/src/index/vector/ivf/v2.rs | 50 +- 8 files changed, 904 insertions(+), 542 deletions(-) diff --git a/rust/lance-index/benches/rq.rs b/rust/lance-index/benches/rq.rs index d51db3c98f0..088927a54da 100644 --- a/rust/lance-index/benches/rq.rs +++ b/rust/lance-index/benches/rq.rs @@ -18,8 +18,7 @@ use lance_datagen::{BatchGeneratorBuilder, RowCount}; use lance_index::vector::bq::RQRotationType; use lance_index::vector::bq::builder::RabitQuantizer; use lance_index::vector::bq::ex_dot::{ - build_ex_query, ex_dot_code_bytes, ex_dot_kernel, needs_plane_repack, packed_ex_code_value, - plane_pack_row, + blocked_ex_code_bytes, ex_dot_kernel, pack_blocked_row, packed_ex_code_value, }; use lance_index::vector::bq::storage::*; use lance_index::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN}; @@ -154,11 +153,19 @@ fn ex_dot_kernels_for_dim(c: &mut Criterion, ex_dim: usize) { for ex_bits in 1..=8u8 { let max_code = ((1u16 << ex_bits) - 1) as u8; + let values = (0..NUM_ROWS * ex_dim) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + + // The gather baseline reads the legacy sequential layout it shipped + // with; the kernel reads the blocked layout. 
let seq_code_len = (ex_dim * ex_bits as usize).div_ceil(8); let mut seq_codes = vec![0u8; NUM_ROWS * seq_code_len]; - for row in seq_codes.chunks_exact_mut(seq_code_len) { - for dim in 0..ex_dim { - let value = rng.random_range(0..=max_code); + for (row, row_values) in seq_codes + .chunks_exact_mut(seq_code_len) + .zip(values.chunks_exact(ex_dim)) + { + for (dim, &value) in row_values.iter().enumerate() { let bit_offset = dim * ex_bits as usize; let bits = (value as u16) << (bit_offset % 8); row[bit_offset / 8] |= bits as u8; @@ -168,21 +175,17 @@ fn ex_dot_kernels_for_dim(c: &mut Criterion, ex_dim: usize) { } } - let kernel_code_len = ex_dot_code_bytes(ex_dim, ex_bits); - let kernel_codes = if needs_plane_repack(ex_bits) { - let mut out = vec![0u8; NUM_ROWS * kernel_code_len]; - for (seq_row, plane_row) in seq_codes - .chunks_exact(seq_code_len) - .zip(out.chunks_exact_mut(kernel_code_len)) - { - plane_pack_row(seq_row, ex_dim, ex_bits, plane_row); - } - out - } else { - seq_codes.clone() - }; + let kernel_code_len = blocked_ex_code_bytes(ex_dim, ex_bits); + let mut kernel_codes = vec![0u8; NUM_ROWS * kernel_code_len]; + for (row, row_values) in kernel_codes + .chunks_exact_mut(kernel_code_len) + .zip(values.chunks_exact(ex_dim)) + { + pack_blocked_row(row_values, ex_bits, row); + } - let ex_query = build_ex_query(&query, ex_bits); + // ex_dim is block-aligned here, so the kernels read the query as-is. 
+ let ex_query = &query; let kernel = ex_dot_kernel(ex_bits); c.bench_function( format!("RQ ex_dot kernel: ex_bits={ex_bits}, DIM={ex_dim}, rows={NUM_ROWS}").as_str(), @@ -190,7 +193,7 @@ fn ex_dot_kernels_for_dim(c: &mut Criterion, ex_dim: usize) { b.iter(|| { let mut sum = 0.0f32; for row in kernel_codes.chunks_exact(kernel_code_len) { - sum += kernel(&ex_query, row); + sum += kernel(ex_query, row); } black_box(sum) }) @@ -220,9 +223,141 @@ fn ex_dot_kernels_for_dim(c: &mut Criterion, ex_dim: usize) { } } +/// Storage load cost per format: blocked-format ex codes are aliased as-is, +/// legacy sequential ex codes are repacked row by row. +fn ex_code_storage_load(c: &mut Criterion) { + use arrow_array::{ArrayRef, FixedSizeListArray, Float32Array, UInt8Array, UInt64Array}; + use lance_arrow::FixedSizeListArrayExt; + use lance_index::vector::bq::ex_dot::repack_sequential_row; + use lance_index::vector::bq::rabit_ex_code_bytes; + use lance_index::vector::bq::transform::{EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN}; + use std::sync::Arc; + + const LOAD_DIM: usize = 1536; + const LOAD_ROWS: usize = 8192; + const NUM_BITS: u8 = 4; // ex_bits=3, a bit-plane width + + let ex_bits = NUM_BITS - 1; + let mut rng = SmallRng::seed_from_u64(7); + let metadata = RabitQuantizationMetadata { + rotate_mat: None, + rotate_mat_position: None, + fast_rotation_signs: None, + rotation_type: RQRotationType::Fast, + code_dim: LOAD_DIM as u32, + num_bits: NUM_BITS, + packed: true, + query_estimator: RabitQueryEstimator::RawQuery, + }; + let code_len = LOAD_DIM / 8; + let binary_codes = (0..LOAD_ROWS * code_len) + .map(|_| rng.random_range(0..=u8::MAX)) + .collect::>(); + let seq_code_len = rabit_ex_code_bytes(LOAD_DIM, ex_bits).unwrap(); + let seq_codes = (0..LOAD_ROWS * seq_code_len) + .map(|_| rng.random_range(0..=u8::MAX)) + .collect::>(); + let blocked_code_len = blocked_ex_code_bytes(LOAD_DIM, ex_bits); + let mut blocked_codes = vec![0u8; LOAD_ROWS * blocked_code_len]; + for 
(seq_row, blocked_row) in seq_codes + .chunks_exact(seq_code_len) + .zip(blocked_codes.chunks_exact_mut(blocked_code_len)) + { + repack_sequential_row(seq_row, LOAD_DIM, ex_bits, blocked_row); + } + + let make_batch = |ex_column: &str, ex_values: Vec, ex_code_len: usize| { + arrow_array::RecordBatch::try_from_iter(vec![ + ( + ROW_ID, + Arc::new(UInt64Array::from_iter_values(0..LOAD_ROWS as u64)) as ArrayRef, + ), + ( + RABIT_CODE_COLUMN, + Arc::new( + FixedSizeListArray::try_new_from_values( + UInt8Array::from(binary_codes.clone()), + code_len as i32, + ) + .unwrap(), + ) as ArrayRef, + ), + ( + ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; LOAD_ROWS])) as ArrayRef, + ), + ( + SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; LOAD_ROWS])) as ArrayRef, + ), + ( + ex_column, + Arc::new( + FixedSizeListArray::try_new_from_values( + UInt8Array::from(ex_values), + ex_code_len as i32, + ) + .unwrap(), + ) as ArrayRef, + ), + ( + EX_ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; LOAD_ROWS])) as ArrayRef, + ), + ( + EX_SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; LOAD_ROWS])) as ArrayRef, + ), + ]) + .unwrap() + }; + + let blocked_batch = make_batch( + RABIT_BLOCKED_EX_CODE_COLUMN, + blocked_codes, + blocked_code_len, + ); + c.bench_function( + format!("RQ storage load (blocked ex codes): num_bits={NUM_BITS}, DIM={LOAD_DIM}, rows={LOAD_ROWS}") + .as_str(), + |b| { + b.iter(|| { + black_box( + RabitQuantizationStorage::try_from_batch( + blocked_batch.clone(), + &metadata, + DistanceType::L2, + None, + ) + .unwrap(), + ) + }) + }, + ); + + let legacy_batch = make_batch(RABIT_EX_CODE_COLUMN, seq_codes, seq_code_len); + c.bench_function( + format!("RQ storage load (legacy ex codes): num_bits={NUM_BITS}, DIM={LOAD_DIM}, rows={LOAD_ROWS}") + .as_str(), + |b| { + b.iter(|| { + black_box( + RabitQuantizationStorage::try_from_batch( + legacy_batch.clone(), + &metadata, + DistanceType::L2, + None, + ) + .unwrap(), + 
) + }) + }, + ); +} + criterion_group!( name=benches; config = Criterion::default().measurement_time(Duration::from_secs(10)); - targets = construct_dist_table, compute_distances, ex_dot_kernels); + targets = construct_dist_table, compute_distances, ex_dot_kernels, ex_code_storage_load); criterion_main!(benches); diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index 178a6bb5435..9eb7fc76903 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -25,7 +25,7 @@ use crate::vector::bq::transform::{ SCALE_FACTORS_FIELD, }; use crate::vector::bq::{ - RQBuildParams, RQRotationType, rabit_binary_code_bytes, rabit_ex_bits, rabit_ex_code_bytes, + RQBuildParams, RQRotationType, rabit_binary_code_bytes, rabit_ex_bits, rotation::{apply_fast_rotation, fast_rotation_signs_len, random_fast_rotation_signs}, validate_rq_num_bits, }; @@ -78,21 +78,6 @@ fn pack_sign_bits(codes: &mut [u8], rotated: &[f32]) { } } -#[inline] -fn pack_ex_code_bits(codes: &mut [u8], ex_values: &[u8], ex_bits: u8) { - codes.fill(0); - let ex_bits = ex_bits as usize; - for (dim_idx, &value) in ex_values.iter().enumerate() { - let bit_offset = dim_idx * ex_bits; - for bit_idx in 0..ex_bits { - if (value >> bit_idx) & 1 != 0 { - let dst_bit = bit_offset + bit_idx; - codes[dst_bit / u8::BITS as usize] |= 1u8 << (dst_bit % u8::BITS as usize); - } - } - } -} - const EX_QUANTIZATION_EPSILON: f32 = 1.0e-5; const EX_TIGHT_START: [f32; 9] = [0.0, 0.15, 0.20, 0.52, 0.59, 0.71, 0.75, 0.77, 0.81]; @@ -200,7 +185,7 @@ fn quantize_ex_code( *ex_code_value = ex_code; } - pack_ex_code_bits(ex_code_dst, ex_code_values_dst, ex_bits); + crate::vector::bq::ex_dot::pack_blocked_row(ex_code_values_dst, ex_bits, ex_code_dst); residual_dot_code } @@ -599,7 +584,11 @@ impl RabitQuantizer { .as_slice(); let code_dim = self.code_dim(); let code_bytes = rabit_binary_code_bytes(code_dim); - let ex_code_bytes = 
rabit_ex_code_bytes(code_dim, ex_bits)?; + let ex_code_bytes = if ex_bits == 0 { + 0 + } else { + crate::vector::bq::ex_dot::blocked_ex_code_bytes(code_dim, ex_bits) + }; let mut encoded_codes = vec![0u8; n * code_bytes]; let mut encoded_ex_codes = (ex_bits != 0).then(|| vec![0u8; n * ex_code_bytes]); @@ -901,7 +890,7 @@ mod tests { use lance_linalg::distance::DistanceType; use rstest::rstest; - use crate::vector::bq::storage::RABIT_EX_CODE_COLUMN; + use crate::vector::bq::storage::RABIT_BLOCKED_EX_CODE_COLUMN; #[rstest] #[case(8)] @@ -978,14 +967,14 @@ mod tests { assert!( !fields .iter() - .any(|field| field.name() == RABIT_EX_CODE_COLUMN) + .any(|field| field.name() == RABIT_BLOCKED_EX_CODE_COLUMN) ); let q = RabitQuantizer::new_with_rotation::(3, 128, RQRotationType::Fast); let fields = q.extra_fields(); for expected in [ ERROR_FACTORS_FIELD.name().as_str(), - RABIT_EX_CODE_COLUMN, + RABIT_BLOCKED_EX_CODE_COLUMN, EX_ADD_FACTORS_FIELD.name().as_str(), EX_SCALE_FACTORS_FIELD.name().as_str(), ] { @@ -1095,7 +1084,8 @@ mod tests { .unwrap() .as_fixed_size_list() .value_length(), - 32 + // dim=32 is padded to one 64-dim block at ex_bits=8. + 64 ); } diff --git a/rust/lance-index/src/vector/bq/ex_dot.rs b/rust/lance-index/src/vector/bq/ex_dot.rs index afe54e6ee37..1aeb83ba40c 100644 --- a/rust/lance-index/src/vector/bq/ex_dot.rs +++ b/rust/lance-index/src/vector/bq/ex_dot.rs @@ -11,60 +11,62 @@ //! query directly, following the kernel design of the RaBitQ reference library //! (, Apache-2.0). //! -//! Two code layouts are consumed: +//! Codes are stored in the *blocked* layout: dims are grouped into 64-dim +//! blocks (the last block zero-padded) and bit-interleaved within each block +//! so that the SIMD unpack emits codes in natural dim order: //! -//! - `ex_bits` ∈ {1, 2, 4, 8}: the sequential LSB-first layout written by the -//! index builder is already byte-aligned, so rows are used as stored. -//! 
- `ex_bits` ∈ {3, 5, 6, 7}: codes straddle byte boundaries in the sequential -//! layout, so rows are repacked once at load time ([`plane_pack_row`]) into -//! bit-planes that unpack with byte-wise shifts: +//! ```text +//! per 64-dim block (T = ex_bits - 1, the top bit; "run k" = dims 16k..16k+16): +//! 1 bit: [8B] bit i of the LE word = dim i +//! 2 bits: [16B] byte b = dims {b, b+16, b+32, b+48} at bit pairs 0/2/4/6 +//! 3 bits: [16B 2-bit plane as above][8B top-bit plane] +//! 4 bits: [32B] byte 8j+b = dim 16j+b (low nibble) | dim 16j+8+b (high nibble) +//! 5 bits: [32B 4-bit plane: byte b = dims b|b+16; byte 16+b = dims b+32|b+48] +//! [8B top-bit plane] +//! 6 bits: [48B] byte 16k+b = dim 16k+b (6 low bits) | bits 2k..2k+2 of +//! dim 48+b (2 high bits) +//! 7 bits: [48B as 6 bits][8B top-bit plane] +//! 8 bits: [64B] identity +//! top-bit plane: top bit of dim 16k+b at bit 8*(b%8) + 2k + b/8 of a LE u64 +//! ``` //! -//! ```text -//! per 64-dim group (T = ex_bits - 1, the top bit): -//! 3 bits: [16B 2-bit plane][8B top-bit plane] -//! 5 bits: [32B 4-bit plane][8B top-bit plane] -//! 6 bits: [48B: 3 blocks of "6 low bits | 2 stolen bits"] -//! 7 bits: [48B: same as 6 bits][8B top-bit plane] -//! ``` -//! -//! Kernels unpack each group into runs of 16 code bytes whose dimension order -//! differs from the natural order, so the query is permuted once per search -//! with [`build_ex_query_into`] and the kernels then read both sides -//! sequentially. The permuted query is zero-padded to a multiple of 64 so that -//! padded lanes contribute nothing. +//! Because unpack order is natural, the kernels read the rotated query +//! directly; it only needs zero-padding ([`pad_query_into`]) when the rotated +//! dim is not a multiple of 64. Legacy indexes store ex codes sequentially +//! (LSB-first bit stream) and are repacked once at load time +//! ([`repack_sequential_row`]); for `ex_bits` ∈ {1, 8} the two layouts agree +//! 
(modulo trailing padding, which the kernels tolerate) and rows are used as +//! stored. use std::sync::LazyLock; -/// Dimensions are processed in groups; the permuted query is padded to this -/// multiple so every kernel sees whole groups. -pub const EX_DOT_GROUP_DIMS: usize = 64; +/// Dims are packed in blocks of this size; the query is zero-padded to a +/// whole number of blocks when the rotated dim is not already a multiple. +pub const EX_DOT_BLOCK_DIMS: usize = 64; -/// `f32` length of the kernel-order query built by [`build_ex_query_into`]. -pub fn ex_query_len(dim: usize) -> usize { - dim.next_multiple_of(EX_DOT_GROUP_DIMS) +/// `f32` length of the query consumed by the kernels. +pub fn padded_query_len(dim: usize) -> usize { + dim.next_multiple_of(EX_DOT_BLOCK_DIMS) } -/// Whether the sequential ex-code layout must be repacked into bit-planes for -/// the dot kernels. For the remaining widths codes are byte-aligned already. -pub fn needs_plane_repack(ex_bits: u8) -> bool { - matches!(ex_bits, 3 | 5 | 6 | 7) +/// Whether the legacy sequential layout of a row already matches the blocked +/// layout (modulo trailing zero padding, which the kernels tolerate), so +/// legacy rows can be consumed without repacking. +pub fn sequential_matches_blocked(ex_bits: u8) -> bool { + matches!(ex_bits, 1 | 8) } -/// Bytes per row of the code layout consumed by the dot kernels. -pub fn ex_dot_code_bytes(dim: usize, ex_bits: u8) -> usize { +/// Bytes per row of the blocked ex-code layout. +pub fn blocked_ex_code_bytes(dim: usize, ex_bits: u8) -> usize { debug_assert!((1..=8).contains(&ex_bits)); - if needs_plane_repack(ex_bits) { - ex_query_len(dim) * ex_bits as usize / 8 - } else { - (dim * ex_bits as usize).div_ceil(u8::BITS as usize) - } + padded_query_len(dim) * ex_bits as usize / 8 } /// Dimensions per unpacking group for the given code width. 
fn group_dims(ex_bits: u8) -> usize { match ex_bits { 1 | 4 | 8 => 16, - _ => EX_DOT_GROUP_DIMS, + _ => EX_DOT_BLOCK_DIMS, } } @@ -89,52 +91,24 @@ pub fn packed_ex_code_value(row_codes: &[u8], dim_idx: usize, ex_bits: u8) -> u8 ((bits >> bit_shift) & mask) as u8 } -/// Kernel-order position of `dim` within its group (see [`build_ex_query_into`]). -fn kernel_position(dim: usize, ex_bits: u8) -> usize { - match ex_bits { - 1 | 8 => dim, - // 16-dim groups unpack the low nibbles (even dims) before the high - // nibbles (odd dims). - 4 => { - let group = dim / 16; - let r = dim % 16; - group * 16 + r / 2 + (r % 2) * 8 - } - // 64-dim groups unpack four 16-byte runs holding dims k, k+4, k+8, ... - 2 | 3 | 5 | 6 | 7 => { - let group = dim / 64; - let r = dim % 64; - group * 64 + (r % 4) * 16 + r / 4 - } - _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"), - } -} - -/// Permute the rotated query into the order the dot kernels unpack codes in, -/// zero-padding to a multiple of [`EX_DOT_GROUP_DIMS`]. -pub fn build_ex_query_into(rotated_query: &[f32], ex_bits: u8, out: &mut [f32]) { - debug_assert_eq!(out.len(), ex_query_len(rotated_query.len())); - out.fill(0.0); - for (dim, &value) in rotated_query.iter().enumerate() { - out[kernel_position(dim, ex_bits)] = value; - } -} - -pub fn build_ex_query(rotated_query: &[f32], ex_bits: u8) -> Vec { - let mut out = vec![0.0; ex_query_len(rotated_query.len())]; - build_ex_query_into(rotated_query, ex_bits, &mut out); - out +/// Zero-pad the rotated query to a whole number of 64-dim blocks. Only needed +/// when `dim` is not a multiple of [`EX_DOT_BLOCK_DIMS`]; aligned queries are +/// passed to the kernels as-is. 
+pub fn pad_query_into(rotated_query: &[f32], out: &mut [f32]) { + debug_assert_eq!(out.len(), padded_query_len(rotated_query.len())); + out[..rotated_query.len()].copy_from_slice(rotated_query); + out[rotated_query.len()..].fill(0.0); } /// Pack the top bit of each of 64 codes into a `u64` so kernels can position -/// it with two shifts per 16-code run: the top bit of dim `4j + k` is stored -/// at bit `8 * (j % 8) + 2k + j / 8`. -fn pack_top_plane(group_values: &[u8; 64], top_bit: u8) -> u64 { +/// it with two shifts per 16-code run: the top bit of dim `16k + b` is stored +/// at bit `8 * (b % 8) + 2k + b / 8`. +fn pack_top_plane(block_values: &[u8; 64], top_bit: u8) -> u64 { let mut plane = 0u64; - for j in 0..16 { - for k in 0..4 { - let bit = (group_values[4 * j + k] >> top_bit) & 1; - plane |= (bit as u64) << (8 * (j % 8) + 2 * k + j / 8); + for k in 0..4 { + for b in 0..16 { + let bit = (block_values[16 * k + b] >> top_bit) & 1; + plane |= (bit as u64) << (8 * (b % 8) + 2 * k + b / 8); } } plane @@ -150,66 +124,94 @@ fn shift_plane(plane: u64, from_bit: usize, to_bit: usize) -> u64 { } } -/// Pack one group of 64 code values (natural dim order) into the bit-plane +/// Pack one block of 64 code values (natural dim order) into the blocked /// layout described in the module docs. 
-fn plane_pack_group(ex_bits: u8, group_values: &[u8; 64], out: &mut [u8]) { - let v = group_values; +fn pack_block(ex_bits: u8, block_values: &[u8; 64], out: &mut [u8]) { + let v = block_values; match ex_bits { - 3 => { + 1 => { + for (b, byte) in out[..8].iter_mut().enumerate() { + *byte = (0..8).fold(0, |acc, t| acc | ((v[8 * b + t] & 1) << t)); + } + } + 2 | 3 => { for b in 0..16 { - out[b] = (v[4 * b] & 0b11) - | ((v[4 * b + 1] & 0b11) << 2) - | ((v[4 * b + 2] & 0b11) << 4) - | ((v[4 * b + 3] & 0b11) << 6); + out[b] = (v[b] & 0b11) + | ((v[16 + b] & 0b11) << 2) + | ((v[32 + b] & 0b11) << 4) + | ((v[48 + b] & 0b11) << 6); + } + if ex_bits == 3 { + out[16..24].copy_from_slice(&pack_top_plane(v, 2).to_le_bytes()); + } + } + 4 => { + for unit in 0..4 { + for b in 0..8 { + out[8 * unit + b] = + (v[16 * unit + b] & 0x0f) | ((v[16 * unit + 8 + b] & 0x0f) << 4); + } } - out[16..24].copy_from_slice(&pack_top_plane(v, 2).to_le_bytes()); } 5 => { for b in 0..16 { - out[b] = (v[4 * b] & 0x0f) | ((v[4 * b + 1] & 0x0f) << 4); - out[16 + b] = (v[4 * b + 2] & 0x0f) | ((v[4 * b + 3] & 0x0f) << 4); + out[b] = (v[b] & 0x0f) | ((v[16 + b] & 0x0f) << 4); + out[16 + b] = (v[32 + b] & 0x0f) | ((v[48 + b] & 0x0f) << 4); } out[32..40].copy_from_slice(&pack_top_plane(v, 4).to_le_bytes()); } 6 | 7 => { - // Dims k, k+4, ... (k < 3) keep their 6 low bits in block k; the - // fourth dim of each quad is split into three 2-bit pieces stored - // in the blocks' top bits. + // Runs 0..3 keep their 6 low bits in place; the fourth run's dims + // are split into three 2-bit pieces stored in the runs' top bits. 
for k in 0..3 { for b in 0..16 { out[16 * k + b] = - (v[4 * b + k] & 0x3f) | (((v[4 * b + 3] >> (2 * k)) & 0b11) << 6); + (v[16 * k + b] & 0x3f) | (((v[48 + b] >> (2 * k)) & 0b11) << 6); } } if ex_bits == 7 { out[48..56].copy_from_slice(&pack_top_plane(v, 6).to_le_bytes()); } } - _ => unreachable!("plane packing is only used for ex_bits 3, 5, 6, 7"), + 8 => out[..64].copy_from_slice(v), + _ => unreachable!("invalid RabitQ ex_bits={ex_bits}"), + } +} + +/// Pack one row of unpacked code values (one `u8` per dim) into the blocked +/// layout; the writer path. `out` must have [`blocked_ex_code_bytes`] bytes. +pub fn pack_blocked_row(values: &[u8], ex_bits: u8, out: &mut [u8]) { + debug_assert_eq!(out.len(), blocked_ex_code_bytes(values.len(), ex_bits)); + let block_bytes = EX_DOT_BLOCK_DIMS * ex_bits as usize / 8; + let mut block_values = [0u8; 64]; + for (block, out) in out.chunks_exact_mut(block_bytes).enumerate() { + let base = block * EX_DOT_BLOCK_DIMS; + let count = EX_DOT_BLOCK_DIMS.min(values.len() - base); + block_values[..count].copy_from_slice(&values[base..base + count]); + block_values[count..].fill(0); + pack_block(ex_bits, &block_values, out); } } -/// Repack one sequentially bit-packed row into the kernel bit-plane layout. -/// `out` must have [`ex_dot_code_bytes`] bytes. -pub fn plane_pack_row(seq_row: &[u8], dim: usize, ex_bits: u8, out: &mut [u8]) { - debug_assert!(needs_plane_repack(ex_bits)); - debug_assert_eq!(out.len(), ex_dot_code_bytes(dim, ex_bits)); - let bytes_per_group = group_bytes(ex_bits); - let mut group_values = [0u8; 64]; - for (group, out) in out.chunks_exact_mut(bytes_per_group).enumerate() { - group_values.fill(0); - let base = group * 64; - let count = 64.min(dim.saturating_sub(base)); - for (i, value) in group_values[..count].iter_mut().enumerate() { +/// Repack one legacy sequentially bit-packed row into the blocked layout. +/// `out` must have [`blocked_ex_code_bytes`] bytes. 
+pub fn repack_sequential_row(seq_row: &[u8], dim: usize, ex_bits: u8, out: &mut [u8]) { + debug_assert_eq!(out.len(), blocked_ex_code_bytes(dim, ex_bits)); + let block_bytes = EX_DOT_BLOCK_DIMS * ex_bits as usize / 8; + let mut block_values = [0u8; 64]; + for (block, out) in out.chunks_exact_mut(block_bytes).enumerate() { + block_values.fill(0); + let base = block * EX_DOT_BLOCK_DIMS; + let count = EX_DOT_BLOCK_DIMS.min(dim.saturating_sub(base)); + for (i, value) in block_values[..count].iter_mut().enumerate() { *value = packed_ex_code_value(seq_row, base + i, ex_bits); } - plane_pack_group(ex_bits, &group_values, out); + pack_block(ex_bits, &block_values, out); } } -/// Unpack one code group into per-dim values in kernel order (the order -/// [`build_ex_query_into`] permutes the query into). Reference implementation -/// for the SIMD unpackers; also the scalar fallback. +/// Unpack one code group into per-dim values (natural dim order). Reference +/// implementation for the SIMD unpackers; also the scalar fallback. fn unpack_group(ex_bits: u8, group_codes: &[u8], out: &mut [u8; 64]) { debug_assert_eq!(group_codes.len(), group_bytes(ex_bits)); match ex_bits { @@ -276,12 +278,12 @@ fn unpack_group(ex_bits: u8, group_codes: &[u8], out: &mut [u8; 64]) { } } -/// `sum_d ex_query[d] * code[d]` for one row of kernel-layout codes. +/// `sum_d query[d] * code[d]` for one row of blocked-layout codes. /// -/// `ex_query` must be the kernel-order query from [`build_ex_query_into`]; -/// `codes` is the row slice (sequential layout for `ex_bits` ∈ {1, 2, 4, 8}, -/// bit-plane layout otherwise). Rows shorter than the padded query length are -/// treated as zero-padded. +/// The query must cover a whole number of 64-dim blocks (the rotated query +/// as-is for aligned dims, otherwise zero-padded via [`pad_query_into`]); +/// `codes` is the blocked row slice. Rows shorter than the padded query +/// length are treated as zero-padded. 
pub type ExDotFn = fn(&[f32], &[u8]) -> f32; /// Resolve the dot kernel for `ex_bits` once; the result can be cached by the @@ -330,7 +332,7 @@ fn scalar_kernel(ex_bits: u8) -> ExDotFn { fn ex_dot_scalar(ex_query: &[f32], codes: &[u8]) -> f32 { let group_dims = group_dims(EX_BITS); let bytes_per_group = group_bytes(EX_BITS); - debug_assert_eq!(ex_query.len() % EX_DOT_GROUP_DIMS, 0); + debug_assert_eq!(ex_query.len() % EX_DOT_BLOCK_DIMS, 0); debug_assert!(codes.len() * u8::BITS as usize <= ex_query.len() * EX_BITS as usize); let mut sum = 0.0f32; @@ -559,7 +561,7 @@ mod x86 { unsafe fn $name(ex_query: &[f32], codes: &[u8]) -> f32 { const GROUP_DIMS: usize = if $runs == 1 { 16 } else { 64 }; const GROUP_BYTES: usize = GROUP_DIMS * $ex_bits / 8; - debug_assert_eq!(ex_query.len() % super::EX_DOT_GROUP_DIMS, 0); + debug_assert_eq!(ex_query.len() % super::EX_DOT_BLOCK_DIMS, 0); debug_assert!(codes.len() * 8 <= ex_query.len() * $ex_bits); let groups = ex_query.len() / GROUP_DIMS; @@ -613,7 +615,7 @@ mod x86 { unsafe fn $name(ex_query: &[f32], codes: &[u8]) -> f32 { const GROUP_DIMS: usize = if $runs == 1 { 16 } else { 64 }; const GROUP_BYTES: usize = GROUP_DIMS * $ex_bits / 8; - debug_assert_eq!(ex_query.len() % super::EX_DOT_GROUP_DIMS, 0); + debug_assert_eq!(ex_query.len() % super::EX_DOT_BLOCK_DIMS, 0); debug_assert!(codes.len() * 8 <= ex_query.len() * $ex_bits); let groups = ex_query.len() / GROUP_DIMS; @@ -835,7 +837,7 @@ mod neon { unsafe fn $name(ex_query: &[f32], codes: &[u8]) -> f32 { const GROUP_DIMS: usize = if $runs == 1 { 16 } else { 64 }; const GROUP_BYTES: usize = GROUP_DIMS * $ex_bits / 8; - debug_assert_eq!(ex_query.len() % super::EX_DOT_GROUP_DIMS, 0); + debug_assert_eq!(ex_query.len() % super::EX_DOT_BLOCK_DIMS, 0); debug_assert!(codes.len() * 8 <= ex_query.len() * $ex_bits); let groups = ex_query.len() / GROUP_DIMS; @@ -916,14 +918,10 @@ mod tests { } fn kernel_codes(values: &[u8], dim: usize, ex_bits: u8) -> Vec { - let seq = pack_sequential(values, 
ex_bits); - if needs_plane_repack(ex_bits) { - let mut out = vec![0u8; ex_dot_code_bytes(dim, ex_bits)]; - plane_pack_row(&seq, dim, ex_bits, &mut out); - out - } else { - seq - } + debug_assert_eq!(values.len(), dim); + let mut out = vec![0u8; blocked_ex_code_bytes(dim, ex_bits)]; + pack_blocked_row(values, ex_bits, &mut out); + out } fn available_kernels(ex_bits: u8) -> Vec<(&'static str, ExDotFn)> { @@ -968,8 +966,8 @@ mod tests { .sum::(); let codes = kernel_codes(&values, dim, ex_bits); - let ex_query = build_ex_query(&query, ex_bits); - assert_eq!(ex_query.len() % EX_DOT_GROUP_DIMS, 0); + let mut ex_query = vec![0.0; padded_query_len(dim)]; + pad_query_into(&query, &mut ex_query); let tolerance = 1e-3 * expected.abs().max(1.0); for (name, kernel) in available_kernels(ex_bits) { @@ -984,24 +982,57 @@ mod tests { #[rstest] fn test_unpack_group_roundtrip(#[values(1, 2, 3, 4, 5, 6, 7, 8)] ex_bits: u8) { let mut rng = SmallRng::seed_from_u64(7 + ex_bits as u64); - let dims = group_dims(ex_bits); let max_code = ((1u16 << ex_bits) - 1) as u8; - let values = (0..dims) + let values = (0..EX_DOT_BLOCK_DIMS) .map(|_| rng.random_range(0..=max_code)) .collect::>(); - let codes = kernel_codes(&values, dims, ex_bits); + let codes = kernel_codes(&values, EX_DOT_BLOCK_DIMS, ex_bits); + // Unpacking each kernel group must reproduce the values in natural + // dim order. 
+ let dims = group_dims(ex_bits); + let bytes = group_bytes(ex_bits); let mut unpacked = [0u8; 64]; - unpack_group(ex_bits, &codes, &mut unpacked); - for dim in 0..dims { + for group in 0..EX_DOT_BLOCK_DIMS / dims { + unpack_group( + ex_bits, + &codes[group * bytes..(group + 1) * bytes], + &mut unpacked, + ); assert_eq!( - unpacked[kernel_position(dim, ex_bits)], - values[dim], - "ex_bits={ex_bits} dim={dim}" + &unpacked[..dims], + &values[group * dims..(group + 1) * dims], + "ex_bits={ex_bits} group={group}" ); } } + /// The legacy sequential rows must repack into exactly what the writer + /// produces from the unpacked values. + #[rstest] + fn test_repack_sequential_matches_blocked( + #[values(1, 2, 3, 4, 5, 6, 7, 8)] ex_bits: u8, + #[values(7, 64, 100, 1536)] dim: usize, + ) { + let mut rng = SmallRng::seed_from_u64(11 + ex_bits as u64 * 100 + dim as u64); + let max_code = ((1u16 << ex_bits) - 1) as u8; + let values = (0..dim) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + let seq = pack_sequential(&values, ex_bits); + + let mut repacked = vec![0u8; blocked_ex_code_bytes(dim, ex_bits)]; + repack_sequential_row(&seq, dim, ex_bits, &mut repacked); + assert_eq!(repacked, kernel_codes(&values, dim, ex_bits)); + + // For the widths where the sequential layout is already blocked + // (modulo trailing padding), the raw row must be a prefix. + if sequential_matches_blocked(ex_bits) { + assert_eq!(&repacked[..seq.len()], &seq); + assert!(repacked[seq.len()..].iter().all(|&byte| byte == 0)); + } + } + /// Dense dim sweep for the bit-plane widths: every tail shape within the /// 64-dim kernel group, plus multi-group sizes. 
#[rstest] @@ -1022,7 +1053,8 @@ mod tests { .sum::(); let codes = kernel_codes(&values, dim, ex_bits); - let ex_query = build_ex_query(&query, ex_bits); + let mut ex_query = vec![0.0; padded_query_len(dim)]; + pad_query_into(&query, &mut ex_query); let tolerance = 1e-3 * expected.abs().max(1.0); for (name, kernel) in available_kernels(ex_bits) { let actual = kernel(&ex_query, &codes) as f64; @@ -1035,13 +1067,12 @@ mod tests { } #[test] - fn test_build_ex_query_pads_with_zeros() { + fn test_pad_query_pads_with_zeros() { let query = vec![1.0f32; 100]; - for ex_bits in 1..=8u8 { - let ex_query = build_ex_query(&query, ex_bits); - assert_eq!(ex_query.len(), 128); - let sum = ex_query.iter().sum::(); - assert_eq!(sum, 100.0, "ex_bits={ex_bits}"); - } + let mut padded = vec![f32::NAN; padded_query_len(query.len())]; + pad_query_into(&query, &mut padded); + assert_eq!(padded.len(), 128); + assert_eq!(&padded[..100], &query[..]); + assert!(padded[100..].iter().all(|&value| value == 0.0)); } } diff --git a/rust/lance-index/src/vector/bq/storage.rs b/rust/lance-index/src/vector/bq/storage.rs index b5b38971fed..6e47fe42e25 100644 --- a/rust/lance-index/src/vector/bq/storage.rs +++ b/rust/lance-index/src/vector/bq/storage.rs @@ -42,8 +42,8 @@ use crate::frag_reuse::FragReuseIndex; use crate::pb; use crate::vector::ApproxMode; use crate::vector::bq::ex_dot::{ - ExDotFn, build_ex_query, build_ex_query_into, ex_dot_code_bytes, ex_dot_kernel, ex_query_len, - needs_plane_repack, plane_pack_row, + EX_DOT_BLOCK_DIMS, ExDotFn, blocked_ex_code_bytes, ex_dot_kernel, pad_query_into, + padded_query_len, repack_sequential_row, sequential_matches_blocked, }; use crate::vector::bq::rotation::{apply_fast_rotation, apply_fast_rotation_in_place}; use crate::vector::bq::transform::{ @@ -63,7 +63,14 @@ use crate::vector::storage::{ pub const RABIT_METADATA_KEY: &str = "lance:rabit"; pub const RABIT_CODE_COLUMN: &str = "_rabit_codes"; +/// Legacy ex-code column: sequential LSB-first bit stream 
per row. Read-only; +/// rows are repacked into the blocked layout at load time. pub const RABIT_EX_CODE_COLUMN: &str = "__ex_codes"; +/// Ex-code column in the blocked layout consumed by the ex-dot kernels (see +/// `ex_dot` module docs). Indexes written with this column cannot be read by +/// older versions, which fail with a missing-column error instead of +/// misinterpreting the bytes. +pub const RABIT_BLOCKED_EX_CODE_COLUMN: &str = "__blocked_ex_codes"; pub const SEGMENT_LENGTH: usize = 4; pub const SEGMENT_NUM_CODES: usize = 1 << SEGMENT_LENGTH; const RABIT_PRUNE_STATS_ENV: &str = "LANCE_RQ_PRUNE_STATS"; @@ -214,10 +221,10 @@ pub fn rabit_ex_code_field(rotated_dim: usize, num_bits: u8) -> Result 0 { - ex_query.resize(ex_query_len(code_dim), 0.0); - build_ex_query_into(&rotated_query, ex_bits, &mut ex_query); + if ex_bits > 0 && !code_dim.is_multiple_of(EX_DOT_BLOCK_DIMS) { + ex_query.resize(padded_query_len(code_dim), 0.0); + pad_query_into(&rotated_query, &mut ex_query); } let sum_q = rotated_query.iter().copied().sum(); @@ -476,14 +485,12 @@ pub struct RabitQuantizationStorage { add_factors: Float32Array, scale_factors: Float32Array, error_factors: Option, + // ex codes in the blocked kernel layout; always aliases the batch column + // (legacy sequential batches are normalized at load, replacing the + // sequential column with the repacked one, so rewrites emit the blocked + // format). ex_codes: Option, packed_ex_codes: Option, - // ex codes repacked into the bit-plane layout consumed by the ex-dot - // kernels; only present for the widths whose sequential layout is not - // byte-aligned (see `ex_dot::needs_plane_repack`). This keeps a second - // resident copy of the ex codes, mirroring `packed_ex_codes` for the - // FastScan widths. 
- plane_ex_codes: Option, ex_add_factors: Option, ex_scale_factors: Option, } @@ -497,11 +504,6 @@ impl DeepSizeOf for RabitQuantizationStorage { .as_ref() .map(|codes| (codes as &dyn Array).deep_size_of_children(context)) .unwrap_or_default() - + self - .plane_ex_codes - .as_ref() - .map(|codes| (codes as &dyn Array).deep_size_of_children(context)) - .unwrap_or_default() } } @@ -592,12 +594,14 @@ impl RabitQuantizationStorage { query_error, approx_mode, } = parts; - // The ex-dot kernels consume the bit-plane repack where one exists and - // the sequential (byte-aligned) layout otherwise. + let ex_code_len = self + .ex_codes + .as_ref() + .map(|codes| codes.value_length() as usize) + .unwrap_or_default(); let ex_codes = self - .plane_ex_codes + .ex_codes .as_ref() - .or(self.ex_codes.as_ref()) .map(|codes| codes.values().as_primitive::().values().as_ref()); let packed_ex_codes = self .packed_ex_codes @@ -613,6 +617,7 @@ impl RabitQuantizationStorage { sum_q, self.codes.values().as_primitive::().values(), ex_codes, + ex_code_len, self.add_factors.values(), self.scale_factors.values(), self.error_factors @@ -805,15 +810,27 @@ struct RabitDistCalculatorParts<'a> { approx_mode: ApproxMode, } +/// Pick the query slice the ex-dot kernels consume: the rotated query itself +/// when the dim is block-aligned, otherwise a zero-padded copy. 
+fn kernel_query<'a>(rotated_query: &'a [f32], padded: &'a [f32]) -> &'a [f32] { + if rotated_query.len().is_multiple_of(EX_DOT_BLOCK_DIMS) { + rotated_query + } else { + padded + } +} + pub struct RabitDistCalculator<'a> { dim: usize, num_bits: u8, query_estimator: RabitQueryEstimator, // n * d / 8 binary-code bytes codes: &'a [u8], - // per-row ex codes in the layout consumed by the ex-dot kernels - // (`ex_dot::ex_dot_code_bytes` bytes per row) + // per-row ex codes in the blocked kernel layout ex_codes: Option<&'a [u8]>, + // bytes per ex-code row; legacy rows for layout-compatible widths may be + // shorter than the blocked size, which the kernels treat as zero padding + ex_code_len: usize, // this is a flattened 2D array of size d/4 * 16, // we split the query codes into d/4 chunks, each chunk is with 4 elements, // then dist_table[i][j] is the distance between the i-th query code and the code j @@ -821,7 +838,7 @@ pub struct RabitDistCalculator<'a> { // only built for the ex widths supported by FastScan; the exact rerank // path uses `ex_query` + `ex_dot` instead ex_dist_table: Cow<'a, [f32]>, - // rotated query permuted into kernel order (see `ex_dot::build_ex_query_into`) + // the rotated query, zero-padded to a 64-dim multiple when needed ex_query: Cow<'a, [f32]>, ex_dot: Option, add_factors: &'a [f32], @@ -850,6 +867,7 @@ impl<'a> RabitDistCalculator<'a> { sum_q: f32, codes: &'a [u8], ex_codes: Option<&'a [u8]>, + ex_code_len: usize, add_factors: &'a [f32], scale_factors: &'a [f32], error_factors: Option<&'a [f32]>, @@ -867,6 +885,7 @@ impl<'a> RabitDistCalculator<'a> { query_estimator, codes, ex_codes, + ex_code_len, dist_table, ex_dist_table, ex_query, @@ -887,13 +906,13 @@ impl<'a> RabitDistCalculator<'a> { /// `sum_d query[d] * ex_code[d]` for the candidate's packed ex codes. 
#[inline] - fn ex_code_dot(&self, ex_codes: &[u8], id: usize, ex_code_len: usize) -> f32 { + fn ex_code_dot(&self, ex_codes: &[u8], id: usize) -> f32 { let ex_dot = self .ex_dot .expect("raw-query multi-bit RQ requires an ex-dot kernel"); ex_dot( self.ex_query.as_ref(), - &ex_codes[id * ex_code_len..(id + 1) * ex_code_len], + &ex_codes[id * self.ex_code_len..(id + 1) * self.ex_code_len], ) } @@ -1084,7 +1103,6 @@ impl<'a> RabitDistCalculator<'a> { let ex_scale_factors = self .ex_scale_factors .expect("raw-query multi-bit RQ requires ex scale factors"); - let ex_code_len = ex_dot_code_bytes(self.dim, ex_bits); let code_scale = (1u32 << ex_bits) as f32; let code_bias = -(code_scale - 0.5); @@ -1092,11 +1110,11 @@ impl<'a> RabitDistCalculator<'a> { self.packed_ex_codes .map(|packed_ex_codes| { let fastscan_len = simd_len; - let fastscan_code_len = ex_fastscan_code_len(self.dim, ex_bits) - .expect("RabitQ num_bits should be validated"); + let fastscan_code_len = self.ex_code_len; let (qmin, qmax, quantization_max) = quantize_ex_fastscan_dist_table_into( self.dim, ex_bits, + self.ex_code_len, &self.ex_dist_table, quantized_dists_table, ); @@ -1141,7 +1159,7 @@ impl<'a> RabitDistCalculator<'a> { .enumerate() .skip(fastscan_len) .for_each(|(id, dist)| { - let ex_dist = self.ex_code_dot(ex_codes, id, ex_code_len); + let ex_dist = self.ex_code_dot(ex_codes, id); let full_dot = code_scale * *dist + ex_dist + code_bias * self.sum_q; *dist = full_dot * ex_scale_factors[id] + ex_add_factors[id] + self.query_factor; }); @@ -1167,12 +1185,11 @@ impl<'a> RabitDistCalculator<'a> { id: usize, binary_ip: f32, ex_bits: u8, - ex_code_len: usize, ex_codes: &[u8], ex_add_factors: &[f32], ex_scale_factors: &[f32], ) -> f32 { - let ex_dist = self.ex_code_dot(ex_codes, id, ex_code_len); + let ex_dist = self.ex_code_dot(ex_codes, id); let code_bias = -((1u32 << ex_bits) as f32 - 0.5); let full_dot = (1u32 << ex_bits) as f32 * binary_ip + ex_dist + code_bias * self.sum_q; full_dot * 
ex_scale_factors[id] + ex_add_factors[id] + self.query_factor @@ -1219,7 +1236,6 @@ impl<'a> RabitDistCalculator<'a> { let ex_scale_factors = self .ex_scale_factors .expect("raw-query multi-bit RQ requires ex scale factors"); - let ex_code_len = ex_dot_code_bytes(self.dim, ex_bits); let query_lower_bound = lower_bound.unwrap_or(f32::MIN); let query_upper_bound = upper_bound.unwrap_or(f32::MAX); let mut max_dist = res.peek().map(|node| node.dist); @@ -1251,7 +1267,6 @@ impl<'a> RabitDistCalculator<'a> { id, binary_ip, ex_bits, - ex_code_len, ex_codes, ex_add_factors, ex_scale_factors, @@ -1444,6 +1459,7 @@ fn quantize_dist_table_u16_into( fn quantize_ex_fastscan_dist_table_into( dim: usize, ex_bits: u8, + ex_code_len: usize, ex_dist_table: &[f32], quantized_dist_table: &mut Vec, ) -> (f32, f32, f32) { @@ -1451,8 +1467,8 @@ fn quantize_ex_fastscan_dist_table_into( let entries_per_dim = 1usize << ex_bits; debug_assert_eq!(ex_dist_table.len(), dim * entries_per_dim); - let num_split_tables = - ex_fastscan_code_len(dim, ex_bits).expect("RabitQ num_bits should be validated") * 2; + // One split table per code nibble of the row. + let num_split_tables = ex_code_len * 2; let quantization_max = (u16::MAX as usize / num_split_tables) .min(u8::MAX as usize) .max(1) as f32; @@ -1490,14 +1506,10 @@ fn supports_ex_fastscan(ex_bits: u8) -> bool { matches!(ex_bits, 2 | 4 | 8) } -#[inline] -fn ex_fastscan_code_len(dim: usize, ex_bits: u8) -> Option { - match ex_bits { - 2 | 4 | 8 => rabit_ex_code_bytes(dim, ex_bits).ok(), - _ => None, - } -} - +/// The FastScan LUT value for one nibble of a blocked-layout code byte: +/// `table_idx / 2` is the byte position within a row and `table_idx % 2` +/// selects its low/high nibble (see the `ex_dot` module docs for the +/// byte-to-dim mapping per width). Padding dims contribute zero. 
#[inline] fn ex_fastscan_dist_table_value( dim: usize, @@ -1506,21 +1518,39 @@ fn ex_fastscan_dist_table_value( table_idx: usize, code: usize, ) -> f32 { + let byte_idx = table_idx / 2; + let high_nibble = table_idx % 2 == 1; match ex_bits { 2 => { - let dim_idx = table_idx * 2; + // byte 16g+b = dims {64g+b, +16, +32, +48} at bit pairs; the low + // nibble covers the first two dims, the high nibble the last two. + let dim_idx = 64 * (byte_idx / 16) + byte_idx % 16 + 32 * usize::from(high_nibble); let low = code & 0b11; let high = (code >> 2) & 0b11; ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, low) - + ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx + 1, high) + + ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx + 16, high) + } + 4 => { + // byte 32g+8j+b = dim 64g+16j+b (low nibble) | dim +8 (high). + let in_block = byte_idx % 32; + let dim_idx = 64 * (byte_idx / 32) + + 16 * (in_block / 8) + + in_block % 8 + + 8 * usize::from(high_nibble); + ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, code) } - 4 => ex_dist_table_value(ex_dist_table, dim, ex_bits, table_idx, code), 8 => { - let dim_idx = table_idx / 2; - if table_idx.is_multiple_of(2) { - ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, code) + // byte = dim identity; the high nibble carries code bits 4..8. + if high_nibble { + ex_dist_table_value( + ex_dist_table, + dim, + ex_bits, + byte_idx, + code << SEGMENT_LENGTH, + ) } else { - ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, code << SEGMENT_LENGTH) + ex_dist_table_value(ex_dist_table, dim, ex_bits, byte_idx, code) } } _ => unreachable!("unsupported RabitQ ex_bits={ex_bits} for FastScan"), @@ -1553,31 +1583,86 @@ fn maybe_pack_ex_codes( } } -/// Repack sequential ex codes into the bit-plane layout the ex-dot kernels -/// consume, for the widths whose sequential layout is not byte-aligned. 
-fn maybe_plane_pack_ex_codes( - ex_codes: Option<&FixedSizeListArray>, +/// Bring legacy sequential ex codes into the blocked kernel layout: rows are +/// repacked, except for the widths whose layouts agree byte-for-byte (then +/// the column is used as stored). +fn blocked_ex_codes_from_sequential( + seq_codes: &FixedSizeListArray, dim: usize, ex_bits: u8, -) -> Result> { - let ex_codes = match ex_codes { - Some(ex_codes) if needs_plane_repack(ex_bits) => ex_codes, - _ => return Ok(None), - }; - let seq_code_len = ex_codes.value_length() as usize; - let seq_values = ex_codes.values().as_primitive::().values(); - let plane_code_len = ex_dot_code_bytes(dim, ex_bits); - let mut plane_values = vec![0u8; ex_codes.len() * plane_code_len]; - for (seq_row, plane_row) in seq_values +) -> Result { + if sequential_matches_blocked(ex_bits) + && seq_codes.value_length() as usize == blocked_ex_code_bytes(dim, ex_bits) + { + return Ok(seq_codes.clone()); + } + let seq_code_len = seq_codes.value_length() as usize; + let seq_values = seq_codes.values().as_primitive::().values(); + let blocked_code_len = blocked_ex_code_bytes(dim, ex_bits); + let mut blocked_values = vec![0u8; seq_codes.len() * blocked_code_len]; + for (seq_row, blocked_row) in seq_values .chunks_exact(seq_code_len) - .zip(plane_values.chunks_exact_mut(plane_code_len)) + .zip(blocked_values.chunks_exact_mut(blocked_code_len)) { - plane_pack_row(seq_row, dim, ex_bits, plane_row); + repack_sequential_row(seq_row, dim, ex_bits, blocked_row); } - Ok(Some(FixedSizeListArray::try_new_from_values( - UInt8Array::from(plane_values), - plane_code_len as i32, - )?)) + Ok(FixedSizeListArray::try_new_from_values( + UInt8Array::from(blocked_values), + blocked_code_len as i32, + )?) +} + +/// Load the ex-code column of an index batch into the blocked kernel layout, +/// accepting both the blocked format and the legacy sequential format. 
Legacy +/// batches are normalized in place (the sequential column is replaced by the +/// blocked one), so rewrites — remap, optimize merges — always emit the +/// blocked format and legacy indexes upgrade on their next rewrite. +pub(crate) fn load_blocked_ex_codes( + batch: RecordBatch, + rotated_dim: usize, + num_bits: u8, +) -> Result<(RecordBatch, FixedSizeListArray)> { + let ex_bits = rabit_ex_bits(num_bits)?; + if let Some(column) = batch.column_by_name(RABIT_BLOCKED_EX_CODE_COLUMN) { + let codes = column.as_fixed_size_list().clone(); + let expected_bytes = blocked_ex_code_bytes(rotated_dim, ex_bits); + if codes.value_length() as usize != expected_bytes { + return Err(Error::invalid_input(format!( + "RabitQ ex-code byte width mismatch: column {} has {} bytes, metadata rotated_dim={} ex_bits={} requires {} bytes", + RABIT_BLOCKED_EX_CODE_COLUMN, + codes.value_length(), + rotated_dim, + ex_bits, + expected_bytes + ))); + } + return Ok((batch, codes)); + } + let column = batch.column_by_name(RABIT_EX_CODE_COLUMN).ok_or_else(|| { + Error::invalid_input(format!( + "RabitQ num_bits={} requires {} column", + num_bits, RABIT_BLOCKED_EX_CODE_COLUMN + )) + })?; + let codes = column.as_fixed_size_list().clone(); + let expected_bytes = rabit_ex_code_bytes(rotated_dim, ex_bits)?; + if codes.value_length() as usize != expected_bytes { + return Err(Error::invalid_input(format!( + "RabitQ ex-code byte width mismatch: column {} has {} bytes, metadata rotated_dim={} ex_bits={} requires {} bytes", + RABIT_EX_CODE_COLUMN, + codes.value_length(), + rotated_dim, + ex_bits, + expected_bytes + ))); + } + let blocked = blocked_ex_codes_from_sequential(&codes, rotated_dim, ex_bits)?; + let ex_code_field = rabit_ex_code_field(rotated_dim, num_bits)? + .expect("multi-bit RabitQ always has an ex-code field"); + let batch = batch + .drop_column(RABIT_EX_CODE_COLUMN)? 
+ .try_with_column(ex_code_field, Arc::new(blocked.clone()))?; + Ok((batch, blocked)) } impl DistCalculator for RabitDistCalculator<'_> { @@ -1610,12 +1695,10 @@ impl DistCalculator for RabitDistCalculator<'_> { let ex_scale_factors = self .ex_scale_factors .expect("raw-query multi-bit RQ requires ex scale factors"); - let ex_code_len = ex_dot_code_bytes(self.dim, ex_bits); self.raw_query_multi_bit_exact_distance( id, dist, ex_bits, - ex_code_len, ex_codes, ex_add_factors, ex_scale_factors, @@ -1900,11 +1983,6 @@ impl VectorStore for RabitQuantizationStorage { } else { Vec::new() }; - let ex_query = if ex_bits > 0 { - build_ex_query(&rotated_qr, ex_bits) - } else { - Vec::new() - }; let query_factor = match self.metadata.query_estimator { RabitQueryEstimator::ResidualQuery => self.residual_query_factor(dist_q_c), RabitQueryEstimator::RawQuery => self.raw_query_factor(dist_q_c, &rotated_qr, None), @@ -1915,7 +1993,16 @@ impl VectorStore for RabitQuantizationStorage { self.raw_query_error_for_gating(dist_q_c, &rotated_qr, None) } }; - let sum_q = rotated_qr.into_iter().sum(); + let sum_q = rotated_qr.iter().copied().sum(); + // The kernels read the rotated query directly; only unaligned dims + // need a zero-padded copy. 
+ let ex_query = if code_dim.is_multiple_of(EX_DOT_BLOCK_DIMS) { + rotated_qr + } else { + let mut padded = vec![0.0; padded_query_len(code_dim)]; + pad_query_into(&rotated_qr, &mut padded); + padded + }; self.distance_calculator_from_parts(RabitDistCalculatorParts { dim: code_dim, @@ -1961,7 +2048,10 @@ impl VectorStore for RabitQuantizationStorage { dim: code_dim, dist_table: Cow::Borrowed(&raw_query.dist_table), ex_dist_table: Cow::Borrowed(&raw_query.ex_dist_table), - ex_query: Cow::Borrowed(&raw_query.ex_query), + ex_query: Cow::Borrowed(kernel_query( + &raw_query.rotated_query, + &raw_query.ex_query, + )), sum_q: raw_query.sum_q, query_factor, query_error, @@ -1978,10 +2068,12 @@ impl VectorStore for RabitQuantizationStorage { } else { 0 }; - let ex_query_table_len = if ex_bits == 0 { + // The kernels read the rotated query in place; a zero-padded copy is + // only needed when the rotated dim is not block-aligned. + let ex_query_table_len = if ex_bits == 0 || code_dim.is_multiple_of(EX_DOT_BLOCK_DIMS) { 0 } else { - ex_query_len(code_dim) + padded_query_len(code_dim) }; f32_scratch.resize( code_dim + dist_table_len + ex_dist_table_len + ex_query_table_len, @@ -2033,8 +2125,8 @@ impl VectorStore for RabitQuantizationStorage { }; build_dist_table_direct_into::(rotated_qr, dist_table); build_ex_dist_table_direct_into(rotated_qr, ex_bits, ex_dist_table); - if ex_bits > 0 { - build_ex_query_into(rotated_qr, ex_bits, ex_query); + if ex_query_table_len > 0 { + pad_query_into(rotated_qr, ex_query); } rotated_qr.iter().copied().sum() }; @@ -2045,9 +2137,10 @@ impl VectorStore for RabitQuantizationStorage { dim: code_dim, dist_table: Cow::Borrowed(&f32_scratch[code_dim..ex_dist_table_start]), ex_dist_table: Cow::Borrowed(&f32_scratch[ex_dist_table_start..ex_query_start]), - ex_query: Cow::Borrowed( + ex_query: Cow::Borrowed(kernel_query( + &f32_scratch[..code_dim], &f32_scratch[ex_query_start..ex_query_start + ex_query_table_len], - ), + )), sum_q, query_factor, 
query_error, @@ -2248,31 +2341,14 @@ impl QuantizerStorage for RabitQuantizationStorage { .column_by_name(ERROR_FACTORS_COLUMN) .map(|factors| factors.as_primitive::().clone()); let ex_bits = rabit_ex_bits(metadata.num_bits)?; + let mut batch = batch; let mut ex_codes = None; let mut ex_add_factors = None; let mut ex_scale_factors = None; if ex_bits != 0 { - let codes = batch - .column_by_name(RABIT_EX_CODE_COLUMN) - .ok_or_else(|| { - Error::invalid_input(format!( - "RabitQ num_bits={} requires {} column", - metadata.num_bits, RABIT_EX_CODE_COLUMN - )) - })? - .as_fixed_size_list() - .clone(); - let expected_ex_code_bytes = rabit_ex_code_bytes(metadata.rotated_dim(), ex_bits)?; - if codes.value_length() as usize != expected_ex_code_bytes { - return Err(Error::invalid_input(format!( - "RabitQ ex-code byte width mismatch: column {} has {} bytes, metadata rotated_dim={} ex_bits={} requires {} bytes", - RABIT_EX_CODE_COLUMN, - codes.value_length(), - metadata.rotated_dim(), - ex_bits, - expected_ex_code_bytes - ))); - } + let (normalized_batch, codes) = + load_blocked_ex_codes(batch, metadata.rotated_dim(), metadata.num_bits)?; + batch = normalized_batch; ex_codes = Some(codes); ex_add_factors = Some( batch @@ -2302,16 +2378,19 @@ impl QuantizerStorage for RabitQuantizationStorage { if batch.column_by_name(EX_ADD_FACTORS_COLUMN).is_some() || batch.column_by_name(EX_SCALE_FACTORS_COLUMN).is_some() || batch.column_by_name(RABIT_EX_CODE_COLUMN).is_some() + || batch.column_by_name(RABIT_BLOCKED_EX_CODE_COLUMN).is_some() { return Err(Error::invalid_input( "RabitQ num_bits=1 raw-query indexes must not contain ex-code columns" .to_string(), )); } - } else if batch.column_by_name(RABIT_EX_CODE_COLUMN).is_some() { + } else if batch.column_by_name(RABIT_EX_CODE_COLUMN).is_some() + || batch.column_by_name(RABIT_BLOCKED_EX_CODE_COLUMN).is_some() + { return Err(Error::invalid_input(format!( - "RabitQ num_bits={} does not support {} column", - metadata.num_bits, 
RABIT_EX_CODE_COLUMN + "RabitQ num_bits={} does not support ex-code columns", + metadata.num_bits ))); } @@ -2327,8 +2406,6 @@ impl QuantizerStorage for RabitQuantizationStorage { let mut metadata = metadata.clone(); metadata.packed = true; let packed_ex_codes = maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits); - let plane_ex_codes = - maybe_plane_pack_ex_codes(ex_codes.as_ref(), metadata.rotated_dim(), ex_bits)?; Ok(Self { metadata, @@ -2341,7 +2418,6 @@ impl QuantizerStorage for RabitQuantizationStorage { error_factors, ex_codes, packed_ex_codes, - plane_ex_codes, ex_add_factors, ex_scale_factors, }) @@ -2412,13 +2488,17 @@ impl QuantizerStorage for RabitQuantizationStorage { let error_factors = batch .column_by_name(ERROR_FACTORS_COLUMN) .map(|factors| factors.as_primitive::().clone()); - let ex_codes = batch - .column_by_name(RABIT_EX_CODE_COLUMN) - .map(|codes| codes.as_fixed_size_list().clone()); let ex_bits = rabit_ex_bits(self.metadata.num_bits)?; + let (batch, ex_codes) = if ex_bits == 0 { + (batch, None) + } else { + // `self.batch` is already normalized at load, so this is a + // zero-copy column lookup. 
+ let (batch, codes) = + load_blocked_ex_codes(batch, self.metadata.rotated_dim(), self.metadata.num_bits)?; + (batch, Some(codes)) + }; let packed_ex_codes = maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits); - let plane_ex_codes = - maybe_plane_pack_ex_codes(ex_codes.as_ref(), self.metadata.rotated_dim(), ex_bits)?; let ex_add_factors = batch .column_by_name(EX_ADD_FACTORS_COLUMN) .map(|factors| factors.as_primitive::().clone()); @@ -2436,7 +2516,6 @@ impl QuantizerStorage for RabitQuantizationStorage { error_factors, ex_codes, packed_ex_codes, - plane_ex_codes, ex_add_factors, ex_scale_factors, row_ids: new_row_ids, @@ -2757,7 +2836,7 @@ mod tests { assert!(rabit_ex_code_field(128, 1).unwrap().is_none()); let ex_field = rabit_ex_code_field(128, 9).unwrap().unwrap(); - assert_eq!(ex_field.name(), RABIT_EX_CODE_COLUMN); + assert_eq!(ex_field.name(), RABIT_BLOCKED_EX_CODE_COLUMN); let DataType::FixedSizeList(_, ex_code_bytes) = ex_field.data_type() else { panic!("ex-code field should be FixedSizeList"); }; @@ -2972,189 +3051,212 @@ mod tests { use rand::{Rng, SeedableRng}; // 72 exercises the kernels' padded-tail path; 1536 is a production - // embedding dim exercising the full-group path. + // embedding dim exercising the full-group path. Both the blocked + // format and the legacy sequential format must produce the same + // distances. 
for (code_dim, num_rows) in [(72usize, 33usize), (1536, 33)] { for num_bits in 2..=9u8 { - let ex_bits = num_bits - 1; - let mut rng = SmallRng::seed_from_u64(num_bits as u64); - - let sign_bits = (0..num_rows * code_dim) - .map(|_| rng.random_bool(0.5)) - .collect::>(); - let max_code = ((1u16 << ex_bits) - 1) as u8; - let ex_values = (0..num_rows * code_dim) - .map(|_| rng.random_range(0..=max_code)) - .collect::>(); - - let code_len = rabit_binary_code_bytes(code_dim); - let mut code_bytes = vec![0u8; num_rows * code_len]; - for (row, bits) in sign_bits.chunks_exact(code_dim).enumerate() { - for (dim, &bit) in bits.iter().enumerate() { - code_bytes[row * code_len + dim / 8] |= (bit as u8) << (dim % 8); - } - } - let ex_code_len = rabit_ex_code_bytes(code_dim, ex_bits).unwrap(); - let mut ex_code_bytes = vec![0u8; num_rows * ex_code_len]; - for (row, values) in ex_values.chunks_exact(code_dim).enumerate() { - for (dim, &value) in values.iter().enumerate() { - let bit_offset = dim * ex_bits as usize; - let bits = (value as u16) << (bit_offset % 8); - ex_code_bytes[row * ex_code_len + bit_offset / 8] |= bits as u8; - if bits >> 8 != 0 { - ex_code_bytes[row * ex_code_len + bit_offset / 8 + 1] |= - (bits >> 8) as u8; + for legacy_format in [false, true] { + let ex_bits = num_bits - 1; + let mut rng = SmallRng::seed_from_u64(num_bits as u64); + + let sign_bits = (0..num_rows * code_dim) + .map(|_| rng.random_bool(0.5)) + .collect::>(); + let max_code = ((1u16 << ex_bits) - 1) as u8; + let ex_values = (0..num_rows * code_dim) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + + let code_len = rabit_binary_code_bytes(code_dim); + let mut code_bytes = vec![0u8; num_rows * code_len]; + for (row, bits) in sign_bits.chunks_exact(code_dim).enumerate() { + for (dim, &bit) in bits.iter().enumerate() { + code_bytes[row * code_len + dim / 8] |= (bit as u8) << (dim % 8); } } - } - - let identity = Float32Array::from_iter_values((0..code_dim).flat_map(|row| { - 
(0..code_dim).map(move |col| if row == col { 1.0 } else { 0.0 }) - })); - let rotate_mat = - FixedSizeListArray::try_new_from_values(identity, code_dim as i32).unwrap(); - let metadata = RabitQuantizationMetadata { - rotate_mat: Some(rotate_mat), - rotate_mat_position: None, - fast_rotation_signs: None, - rotation_type: RQRotationType::Matrix, - code_dim: code_dim as u32, - num_bits, - packed: false, - query_estimator: RabitQueryEstimator::RawQuery, - }; - let codes = FixedSizeListArray::try_new_from_values( - UInt8Array::from(code_bytes), - code_len as i32, - ) - .unwrap(); - let ex_codes = FixedSizeListArray::try_new_from_values( - UInt8Array::from(ex_code_bytes), - ex_code_len as i32, - ) - .unwrap(); - let ex_add_factors = (0..num_rows) - .map(|_| rng.random_range(-1.0f32..1.0)) - .collect::>(); - let ex_scale_factors = (0..num_rows) - .map(|_| rng.random_range(0.1f32..1.0)) - .collect::>(); - let batch = RecordBatch::try_from_iter(vec![ - ( - ROW_ID, - Arc::new(UInt64Array::from_iter_values(0..num_rows as u64)) as ArrayRef, - ), - (RABIT_CODE_COLUMN, Arc::new(codes) as ArrayRef), - ( - ADD_FACTORS_COLUMN, - Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, - ), - ( - SCALE_FACTORS_COLUMN, - Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, - ), - (RABIT_EX_CODE_COLUMN, Arc::new(ex_codes) as ArrayRef), - ( - EX_ADD_FACTORS_COLUMN, - Arc::new(Float32Array::from(ex_add_factors.clone())) as ArrayRef, - ), - ( - EX_SCALE_FACTORS_COLUMN, - Arc::new(Float32Array::from(ex_scale_factors.clone())) as ArrayRef, - ), - ]) - .unwrap(); - let storage = RabitQuantizationStorage::try_from_batch( - batch, - &metadata, - DistanceType::L2, - None, - ) - .unwrap(); - - let query = (0..code_dim) - .map(|_| rng.random_range(-1.0f32..1.0)) - .collect::>(); - let sum_q = query.iter().sum::(); - let calc = storage - .dist_calculator(Arc::new(Float32Array::from(query.clone())) as ArrayRef, 0.0); - - let code_scale = (1u32 << ex_bits) as f32; - let code_bias = 
-(code_scale - 0.5); - let expected = (0..num_rows) - .map(|row| { - let binary_ip = (0..code_dim) - .map(|dim| query[dim] * sign_bits[row * code_dim + dim] as u8 as f32) - .sum::(); - let ex_dist = (0..code_dim) - .map(|dim| query[dim] * ex_values[row * code_dim + dim] as f32) - .sum::(); - let full_dot = code_scale * binary_ip + ex_dist + code_bias * sum_q; - full_dot * ex_scale_factors[row] + ex_add_factors[row] - }) - .collect::>(); - - for (row, &want) in expected.iter().enumerate() { - let got = calc.distance(row as u32); - assert!( - (got - want).abs() <= 1e-3 * want.abs().max(1.0), - "num_bits={num_bits} row={row}: {got} != {want}" + let (ex_code_column, ex_code_len, ex_code_bytes) = if legacy_format { + let ex_code_len = rabit_ex_code_bytes(code_dim, ex_bits).unwrap(); + let mut ex_code_bytes = vec![0u8; num_rows * ex_code_len]; + for (row, values) in ex_values.chunks_exact(code_dim).enumerate() { + for (dim, &value) in values.iter().enumerate() { + let bit_offset = dim * ex_bits as usize; + let bits = (value as u16) << (bit_offset % 8); + ex_code_bytes[row * ex_code_len + bit_offset / 8] |= bits as u8; + if bits >> 8 != 0 { + ex_code_bytes[row * ex_code_len + bit_offset / 8 + 1] |= + (bits >> 8) as u8; + } + } + } + (RABIT_EX_CODE_COLUMN, ex_code_len, ex_code_bytes) + } else { + let ex_code_len = blocked_ex_code_bytes(code_dim, ex_bits); + let mut ex_code_bytes = vec![0u8; num_rows * ex_code_len]; + for (row, values) in ex_code_bytes + .chunks_exact_mut(ex_code_len) + .zip(ex_values.chunks_exact(code_dim)) + { + crate::vector::bq::ex_dot::pack_blocked_row(values, ex_bits, row); + } + (RABIT_BLOCKED_EX_CODE_COLUMN, ex_code_len, ex_code_bytes) + }; + + let identity = Float32Array::from_iter_values((0..code_dim).flat_map(|row| { + (0..code_dim).map(move |col| if row == col { 1.0 } else { 0.0 }) + })); + let rotate_mat = + FixedSizeListArray::try_new_from_values(identity, code_dim as i32).unwrap(); + let metadata = RabitQuantizationMetadata { + rotate_mat: 
Some(rotate_mat), + rotate_mat_position: None, + fast_rotation_signs: None, + rotation_type: RQRotationType::Matrix, + code_dim: code_dim as u32, + num_bits, + packed: false, + query_estimator: RabitQueryEstimator::RawQuery, + }; + let codes = FixedSizeListArray::try_new_from_values( + UInt8Array::from(code_bytes), + code_len as i32, + ) + .unwrap(); + let ex_codes = FixedSizeListArray::try_new_from_values( + UInt8Array::from(ex_code_bytes), + ex_code_len as i32, + ) + .unwrap(); + let ex_add_factors = (0..num_rows) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + let ex_scale_factors = (0..num_rows) + .map(|_| rng.random_range(0.1f32..1.0)) + .collect::>(); + let batch = RecordBatch::try_from_iter(vec![ + ( + ROW_ID, + Arc::new(UInt64Array::from_iter_values(0..num_rows as u64)) as ArrayRef, + ), + (RABIT_CODE_COLUMN, Arc::new(codes) as ArrayRef), + ( + ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, + ), + ( + SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, + ), + (ex_code_column, Arc::new(ex_codes) as ArrayRef), + ( + EX_ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(ex_add_factors.clone())) as ArrayRef, + ), + ( + EX_SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(ex_scale_factors.clone())) as ArrayRef, + ), + ]) + .unwrap(); + let storage = RabitQuantizationStorage::try_from_batch( + batch, + &metadata, + DistanceType::L2, + None, + ) + .unwrap(); + + let query = (0..code_dim) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(); + let sum_q = query.iter().sum::(); + let calc = storage.dist_calculator( + Arc::new(Float32Array::from(query.clone())) as ArrayRef, + 0.0, ); - } - let mut distances = Vec::new(); - let mut u16_scratch = Vec::new(); - let mut u8_scratch = Vec::new(); - let mut u32_scratch = Vec::new(); - calc.distance_all_with_scratch( - 0, - &mut distances, - &mut u16_scratch, - &mut u8_scratch, - &mut u32_scratch, - ); - assert_eq!(distances.len(), 
num_rows); - // The bulk path quantizes the binary LUT to u8, and that error is - // amplified by 2^ex_bits in the multi-bit estimate, so the value - // assertions need a quantization-aware bound. The FastScan ex - // widths additionally quantize the ex LUT and are covered by - // `test_raw_query_multi_bit_distance_all_uses_fastscan_for_split_ex_codes`. - if !matches!(ex_bits, 2 | 4 | 8) { - // Worst-case |error| of one u8-quantized binary LUT lookup is - // (table range) / 255 / 2, accumulated over one lookup per - // 8-dim pair of segments. - let num_tables = code_dim.div_ceil(4); - let mut table_min = f32::INFINITY; - let mut table_max = f32::NEG_INFINITY; - for segment in query.chunks(4) { - for subset in 0..16usize { - let value = segment - .iter() - .enumerate() - .filter(|(idx, _)| subset & (1 << idx) != 0) - .map(|(_, q)| *q) + let code_scale = (1u32 << ex_bits) as f32; + let code_bias = -(code_scale - 0.5); + let expected = (0..num_rows) + .map(|row| { + let binary_ip = (0..code_dim) + .map(|dim| { + query[dim] * sign_bits[row * code_dim + dim] as u8 as f32 + }) .sum::(); - table_min = table_min.min(value); - table_max = table_max.max(value); - } - } - let binary_bound = - code_scale * num_tables as f32 * (table_max - table_min) / 255.0 / 2.0 - * ex_scale_factors.iter().fold(0.0f32, |max, &s| max.max(s)); - for (row, (&got, &want)) in distances.iter().zip(expected.iter()).enumerate() { + let ex_dist = (0..code_dim) + .map(|dim| query[dim] * ex_values[row * code_dim + dim] as f32) + .sum::(); + let full_dot = code_scale * binary_ip + ex_dist + code_bias * sum_q; + full_dot * ex_scale_factors[row] + ex_add_factors[row] + }) + .collect::>(); + + for (row, &want) in expected.iter().enumerate() { + let got = calc.distance(row as u32); assert!( - (got - want).abs() <= binary_bound + 1e-3, - "num_bits={num_bits} row={row} (distance_all): {got} != {want} (bound {binary_bound})" + (got - want).abs() <= 1e-3 * want.abs().max(1.0), + "num_bits={num_bits} row={row}: 
{got} != {want}" ); } - // Rows past the SIMD batch use the exact binary path, so the - // final remainder row must match the per-candidate distance. - let remainder_row = num_rows - 1; - let got = distances[remainder_row]; - let want = calc.distance(remainder_row as u32); - assert!( - (got - want).abs() <= 1e-3 * want.abs().max(1.0), - "num_bits={num_bits} remainder row (distance_all): {got} != {want}" + + let mut distances = Vec::new(); + let mut u16_scratch = Vec::new(); + let mut u8_scratch = Vec::new(); + let mut u32_scratch = Vec::new(); + calc.distance_all_with_scratch( + 0, + &mut distances, + &mut u16_scratch, + &mut u8_scratch, + &mut u32_scratch, ); + assert_eq!(distances.len(), num_rows); + // The bulk path quantizes the binary LUT to u8, and that error is + // amplified by 2^ex_bits in the multi-bit estimate, so the value + // assertions need a quantization-aware bound. The FastScan ex + // widths additionally quantize the ex LUT and are covered by + // `test_raw_query_multi_bit_distance_all_uses_fastscan_for_split_ex_codes`. + if !matches!(ex_bits, 2 | 4 | 8) { + // Worst-case |error| of one u8-quantized binary LUT lookup is + // (table range) / 255 / 2, accumulated over one lookup per + // 8-dim pair of segments. 
+ let num_tables = code_dim.div_ceil(4); + let mut table_min = f32::INFINITY; + let mut table_max = f32::NEG_INFINITY; + for segment in query.chunks(4) { + for subset in 0..16usize { + let value = segment + .iter() + .enumerate() + .filter(|(idx, _)| subset & (1 << idx) != 0) + .map(|(_, q)| *q) + .sum::(); + table_min = table_min.min(value); + table_max = table_max.max(value); + } + } + let binary_bound = + code_scale * num_tables as f32 * (table_max - table_min) / 255.0 / 2.0 + * ex_scale_factors.iter().fold(0.0f32, |max, &s| max.max(s)); + for (row, (&got, &want)) in + distances.iter().zip(expected.iter()).enumerate() + { + assert!( + (got - want).abs() <= binary_bound + 1e-3, + "num_bits={num_bits} row={row} (distance_all): {got} != {want} (bound {binary_bound})" + ); + } + // Rows past the SIMD batch use the exact binary path, so the + // final remainder row must match the per-candidate distance. + let remainder_row = num_rows - 1; + let got = distances[remainder_row]; + let want = calc.distance(remainder_row as u32); + assert!( + (got - want).abs() <= 1e-3 * want.abs().max(1.0), + "num_bits={num_bits} remainder row (distance_all): {got} != {want}" + ); + } } } } @@ -3323,10 +3425,13 @@ mod tests { assert_eq!(hacc_accum_len, num_rows); } - fn assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits: u8) { - let code_dim = 8usize; + fn assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits: u8, legacy_format: bool) { + // Not a multiple of 64, so the padded-tail LUT entries are exercised; + // a multiple of 8 as the binary stage requires. 
+ let code_dim = 72usize; let num_rows = BATCH_SIZE + 1; let ex_bits = rabit_ex_bits(num_bits).unwrap(); + let max_code = ((1u16 << ex_bits) - 1) as u8; let identity = Float32Array::from_iter_values( (0..code_dim) .flat_map(|row| (0..code_dim).map(move |col| if row == col { 1.0 } else { 0.0 })), @@ -3343,16 +3448,42 @@ mod tests { packed: false, query_estimator: RabitQueryEstimator::RawQuery, }; + let code_len = rabit_binary_code_bytes(code_dim); let codes = FixedSizeListArray::try_new_from_values( - UInt8Array::from_iter_values((0..num_rows).map(|idx| (idx * 13) as u8)), - 1, + UInt8Array::from_iter_values((0..num_rows * code_len).map(|idx| (idx * 13) as u8)), + code_len as i32, ) .unwrap(); - let ex_code_len = rabit_ex_code_bytes(code_dim, ex_bits).unwrap(); + let ex_values = (0..num_rows * code_dim) + .map(|idx| ((idx * 37) % (max_code as usize + 1)) as u8) + .collect::>(); + let (ex_code_column, ex_code_len, ex_code_bytes) = if legacy_format { + let ex_code_len = rabit_ex_code_bytes(code_dim, ex_bits).unwrap(); + let mut ex_code_bytes = vec![0u8; num_rows * ex_code_len]; + for (row, values) in ex_values.chunks_exact(code_dim).enumerate() { + for (dim, &value) in values.iter().enumerate() { + let bit_offset = dim * ex_bits as usize; + let bits = (value as u16) << (bit_offset % 8); + ex_code_bytes[row * ex_code_len + bit_offset / 8] |= bits as u8; + if bits >> 8 != 0 { + ex_code_bytes[row * ex_code_len + bit_offset / 8 + 1] |= (bits >> 8) as u8; + } + } + } + (RABIT_EX_CODE_COLUMN, ex_code_len, ex_code_bytes) + } else { + let ex_code_len = blocked_ex_code_bytes(code_dim, ex_bits); + let mut ex_code_bytes = vec![0u8; num_rows * ex_code_len]; + for (row, values) in ex_code_bytes + .chunks_exact_mut(ex_code_len) + .zip(ex_values.chunks_exact(code_dim)) + { + crate::vector::bq::ex_dot::pack_blocked_row(values, ex_bits, row); + } + (RABIT_BLOCKED_EX_CODE_COLUMN, ex_code_len, ex_code_bytes) + }; let ex_codes = FixedSizeListArray::try_new_from_values( - 
UInt8Array::from_iter_values( - (0..num_rows * ex_code_len).map(|idx| (idx * 37 % 251) as u8), - ), + UInt8Array::from(ex_code_bytes), ex_code_len as i32, ) .unwrap(); @@ -3370,7 +3501,7 @@ mod tests { SCALE_FACTORS_COLUMN, Arc::new(Float32Array::from(vec![1.0; num_rows])) as ArrayRef, ), - (RABIT_EX_CODE_COLUMN, Arc::new(ex_codes) as ArrayRef), + (ex_code_column, Arc::new(ex_codes) as ArrayRef), ( EX_ADD_FACTORS_COLUMN, Arc::new(Float32Array::from(vec![0.0; num_rows])) as ArrayRef, @@ -3386,7 +3517,12 @@ mod tests { .unwrap(); assert!(storage.packed_ex_codes.is_some()); - let query = Arc::new(Float32Array::from(vec![1.0; code_dim])) as ArrayRef; + // A per-dim varying query so that any dim-mapping error in the + // FastScan LUT shows up as a value mismatch. + let query_values = (0..code_dim) + .map(|dim| (dim % 11) as f32 * 0.3 - 1.5) + .collect::>(); + let query = Arc::new(Float32Array::from(query_values.clone())) as ArrayRef; let calc = storage.dist_calculator(query, 0.0); let mut distances = Vec::new(); let mut u16_scratch = Vec::new(); @@ -3402,15 +3538,45 @@ mod tests { assert_eq!(distances.len(), num_rows); assert_eq!(u16_scratch.len(), BATCH_SIZE); - assert_eq!( - u8_scratch.len(), - ex_fastscan_code_len(code_dim, ex_bits).unwrap() * 2 * SEGMENT_NUM_CODES + let loaded_ex_code_len = storage.ex_codes.as_ref().unwrap().value_length() as usize; + assert_eq!(u8_scratch.len(), loaded_ex_code_len * 2 * SEGMENT_NUM_CODES); + + // The fastscan estimate differs from the exact path only by the u8 + // quantization of the binary LUT (amplified by 2^ex_bits) and of the + // ex LUT, so bound the comparison by those quantization errors. 
+ let mut table_min = f32::INFINITY; + let mut table_max = f32::NEG_INFINITY; + for segment in query_values.chunks(4) { + for subset in 0..SEGMENT_NUM_CODES { + let value = segment + .iter() + .enumerate() + .filter(|(idx, _)| subset & (1 << idx) != 0) + .map(|(_, q)| *q) + .sum::(); + table_min = table_min.min(value); + table_max = table_max.max(value); + } + } + let code_scale = (1u32 << ex_bits) as f32; + let binary_bound = + code_scale * code_dim.div_ceil(4) as f32 * (table_max - table_min) / 510.0; + let ex_dist_table = build_ex_dist_table_direct(&query_values, ex_bits); + let mut quantized_table = Vec::new(); + let (ex_qmin, ex_qmax, ex_qcap) = quantize_ex_fastscan_dist_table_into( + code_dim, + ex_bits, + loaded_ex_code_len, + &ex_dist_table, + &mut quantized_table, ); + let ex_bound = (loaded_ex_code_len * 2) as f32 * (ex_qmax - ex_qmin) / ex_qcap / 2.0; + let bound = (binary_bound + ex_bound) * 1.5 + 1e-3; for (id, distance) in distances.iter().take(BATCH_SIZE).enumerate() { let exact = calc.distance(id as u32); assert!( - (*distance - exact).abs() < 10.0, - "distance_all fastscan mismatch for id {id}: actual={distance}, exact={exact}" + (*distance - exact).abs() <= bound, + "distance_all fastscan mismatch for id {id} (num_bits={num_bits} legacy={legacy_format}): actual={distance}, exact={exact}, bound={bound}" ); } assert_eq!(distances[BATCH_SIZE], calc.distance(BATCH_SIZE as u32)); @@ -3418,8 +3584,10 @@ mod tests { #[test] fn test_raw_query_multi_bit_distance_all_uses_fastscan_for_split_ex_codes() { - for num_bits in [3, 9] { - assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits); + for num_bits in [3, 5, 9] { + for legacy_format in [false, true] { + assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits, legacy_format); + } } } @@ -3501,7 +3669,6 @@ mod tests { id, binary_ip, ex_bits, - ex_code_len, ex_codes, ex_add_factors, ex_scale_factors, @@ -3719,7 +3886,8 @@ mod tests { ) .unwrap_err(); assert!( - 
err.to_string().contains("requires __ex_codes column"), + err.to_string() + .contains("requires __blocked_ex_codes column"), "{}", err ); @@ -3763,9 +3931,11 @@ mod tests { .unwrap(); assert!(storage.metadata().packed); + // Legacy batches are normalized to the blocked column at load. let stored_batch = storage.to_batches().unwrap().next().unwrap(); + assert!(stored_batch.column_by_name(RABIT_EX_CODE_COLUMN).is_none()); assert_eq!( - stored_batch[RABIT_EX_CODE_COLUMN] + stored_batch[RABIT_BLOCKED_EX_CODE_COLUMN] .as_fixed_size_list() .value_length(), 64 @@ -3869,9 +4039,11 @@ mod tests { ); assert_eq!(remapped_row_ids, expected_row_ids.values()); - let ex_code_len = rabit_ex_code_bytes(code_dim, rabit_ex_bits(num_bits).unwrap()).unwrap(); + // Legacy batches are normalized to the blocked format at load, so the + // remapped batch carries the blocked column. + let ex_code_len = blocked_ex_code_bytes(code_dim, rabit_ex_bits(num_bits).unwrap()); assert_eq!( - remapped_batch[RABIT_EX_CODE_COLUMN] + remapped_batch[RABIT_BLOCKED_EX_CODE_COLUMN] .as_fixed_size_list() .value_length(), ex_code_len as i32 @@ -3904,10 +4076,10 @@ mod tests { None, ) .unwrap(); - assert_eq!(remapped.plane_ex_codes, reloaded.plane_ex_codes); + assert_eq!(remapped.ex_codes, reloaded.ex_codes); assert_eq!( - remapped.plane_ex_codes.is_some(), - needs_plane_repack(rabit_ex_bits(num_bits).unwrap()) + remapped.ex_codes.as_ref().unwrap().value_length() as usize, + blocked_ex_code_bytes(code_dim, rabit_ex_bits(num_bits).unwrap()) ); } } diff --git a/rust/lance-index/src/vector/bq/transform.rs b/rust/lance-index/src/vector/bq/transform.rs index c2fc0608102..c87695e14cd 100644 --- a/rust/lance-index/src/vector/bq/transform.rs +++ b/rust/lance-index/src/vector/bq/transform.rs @@ -17,7 +17,9 @@ use tracing::instrument; use crate::vector::bq::builder::RabitQuantizer; use crate::vector::bq::rabit_ex_bits; -use crate::vector::bq::storage::{RABIT_CODE_COLUMN, RABIT_EX_CODE_COLUMN, RabitQueryEstimator}; +use 
crate::vector::bq::storage::{ + RABIT_BLOCKED_EX_CODE_COLUMN, RABIT_CODE_COLUMN, RabitQueryEstimator, +}; use crate::vector::quantizer::Quantization; use crate::vector::transform::Transformer; use crate::vector::{CENTROID_DIST_COLUMN, PART_ID_COLUMN}; @@ -281,7 +283,7 @@ impl Transformer for RQTransformer { #[instrument(name = "RQTransformer::transform", level = "debug", skip_all)] fn transform(&self, batch: &RecordBatch) -> Result { let has_split_codes = self.rq.num_bits() == 1 - || (batch.column_by_name(RABIT_EX_CODE_COLUMN).is_some() + || (batch.column_by_name(RABIT_BLOCKED_EX_CODE_COLUMN).is_some() && batch.column_by_name(EX_ADD_FACTORS_COLUMN).is_some() && batch.column_by_name(EX_SCALE_FACTORS_COLUMN).is_some()); if batch.column_by_name(RABIT_CODE_COLUMN).is_some() && has_split_codes { @@ -494,7 +496,8 @@ mod tests { use crate::vector::bq::RQRotationType; use crate::vector::bq::builder::RabitQuantizer; - use crate::vector::bq::storage::RABIT_EX_CODE_COLUMN; + use crate::vector::bq::ex_dot::blocked_ex_code_bytes; + use crate::vector::bq::storage::RABIT_BLOCKED_EX_CODE_COLUMN; use crate::vector::transform::Transformer; use crate::vector::{CENTROID_DIST_COLUMN, PART_ID_COLUMN}; @@ -535,15 +538,19 @@ mod tests { .unwrap(); let transformed = transformer.transform(&batch).unwrap(); - assert!(transformed.column_by_name(RABIT_EX_CODE_COLUMN).is_some()); + assert!( + transformed + .column_by_name(RABIT_BLOCKED_EX_CODE_COLUMN) + .is_some() + ); assert_eq!( - transformed[RABIT_EX_CODE_COLUMN] + transformed[RABIT_BLOCKED_EX_CODE_COLUMN] .as_fixed_size_list() .value_length(), - 3 + blocked_ex_code_bytes(8, 3) as i32 ); assert!( - transformed[RABIT_EX_CODE_COLUMN] + transformed[RABIT_BLOCKED_EX_CODE_COLUMN] .as_fixed_size_list() .values() .as_primitive::() diff --git a/rust/lance-index/src/vector/distributed/index_merger.rs b/rust/lance-index/src/vector/distributed/index_merger.rs index 5f59985673e..70371ad4794 100755 --- 
a/rust/lance-index/src/vector/distributed/index_merger.rs +++ b/rust/lance-index/src/vector/distributed/index_merger.rs @@ -1440,6 +1440,25 @@ pub async fn merge_partial_vector_auxiliary_files( ))); } + // Shards written by older lance versions carry sequential ex + // codes; normalize every batch to the blocked layout before + // concatenation so mixed-version shards merge correctly + // (concat_batches combines columns by position and would + // otherwise mix the two layouts silently). + let batches = match rq_meta.as_ref() { + Some(meta) if meta.num_bits > 1 => batches + .into_iter() + .map(|batch| { + crate::vector::bq::storage::load_blocked_ex_codes( + batch, + meta.rotated_dim(), + meta.num_bits, + ) + .map(|(batch, _)| batch) + }) + .collect::>>()?, + _ => batches, + }; let schema = batches[0].schema(); let partition_batch = concat_batches(&schema, batches.iter())?; if let Some(w) = v2w_opt.as_mut() { @@ -1527,7 +1546,7 @@ mod tests { use prost::Message; use crate::vector::bq::RQRotationType; - use crate::vector::bq::storage::{RABIT_EX_CODE_COLUMN, RabitQueryEstimator}; + use crate::vector::bq::storage::{RABIT_BLOCKED_EX_CODE_COLUMN, RabitQueryEstimator}; use crate::vector::bq::transform::{EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN}; lance_testing::define_stage_event_progress!( RecordingProgress, @@ -2529,11 +2548,14 @@ mod tests { let batch = batch.unwrap(); if !checked_split_columns { let schema = batch.schema(); - let ex_code_field = schema.field_with_name(RABIT_EX_CODE_COLUMN).unwrap(); + let ex_code_field = schema + .field_with_name(RABIT_BLOCKED_EX_CODE_COLUMN) + .unwrap(); let DataType::FixedSizeList(_, ex_code_bytes) = ex_code_field.data_type() else { panic!("RQ ex-code field should be FixedSizeList"); }; - assert_eq!(*ex_code_bytes, 6); + // code_dim=16 padded to one 64-dim block at ex_bits=3. 
+ assert_eq!(*ex_code_bytes, 24); assert!(schema.field_with_name(ERROR_FACTORS_FIELD.name()).is_ok()); assert!(schema.field_with_name(EX_ADD_FACTORS_COLUMN).is_ok()); assert!(schema.field_with_name(EX_SCALE_FACTORS_COLUMN).is_ok()); diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index 2b7382e1c65..9a048b5ebbf 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -252,8 +252,9 @@ pub struct RabitRawQueryContext { /// Quantized-table input for the FastScan ex path; empty for ex widths /// without FastScan support. pub ex_dist_table: Vec, - /// Rotated query permuted into ex-dot kernel order (see - /// `lance_index::vector::bq::ex_dot::build_ex_query_into`). + /// The rotated query zero-padded to a 64-dim multiple for the ex-dot + /// kernels; empty when `code_dim` is already aligned (the kernels then + /// read `rotated_query` directly). pub ex_query: Vec, pub sum_q: f32, } diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 4c806355e2b..885c9f6a5e5 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -38,9 +38,9 @@ use lance_index::frag_reuse::FragReuseIndex; use lance_index::metrics::{LocalMetricsCollector, MetricsCollector, NoOpMetricsCollector}; use lance_index::vector::VectorIndexCacheEntry; use lance_index::vector::bq::builder::RabitQuantizer; -use lance_index::vector::bq::ex_dot::ex_query_len; +use lance_index::vector::bq::ex_dot::{blocked_ex_code_bytes, padded_query_len}; +use lance_index::vector::bq::rabit_ex_bits; use lance_index::vector::bq::storage::{RabitQueryEstimator, SEGMENT_NUM_CODES}; -use lance_index::vector::bq::{rabit_ex_bits, rabit_ex_code_bytes}; use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::graph::OrderedNode; use lance_index::vector::hnsw::HNSW; @@ -154,16 +154,21 @@ fn 
rotated_partition_centroid_slice( } /// `f32` scratch needed for the ex-bit query state: the quantized-table input -/// for FastScan-supported widths plus the kernel-order query for the exact -/// rerank path. +/// for FastScan-supported widths, plus a zero-padded query copy when the +/// rotated dim is not a multiple of the 64-dim kernel block. fn rabit_ex_scratch_len(dim: usize, num_bits: u8) -> usize { + let padded_query = if dim.is_multiple_of(64) { + 0 + } else { + padded_query_len(dim) + }; rabit_ex_bits(num_bits) .map(|ex_bits| match ex_bits { 0 => 0, - 2 | 4 | 8 => dim * (1usize << usize::from(ex_bits)) + ex_query_len(dim), - _ => ex_query_len(dim), + 2 | 4 | 8 => dim * (1usize << usize::from(ex_bits)) + padded_query, + _ => padded_query, }) - .unwrap_or(dim * 256 + ex_query_len(dim)) + .unwrap_or(dim * 256 + padded_query) } fn rabit_u8_scratch_len(dim: usize, num_bits: u8) -> usize { @@ -171,7 +176,7 @@ fn rabit_u8_scratch_len(dim: usize, num_bits: u8) -> usize { let ex_dist_table_len = rabit_ex_bits(num_bits) .ok() .and_then(|ex_bits| match ex_bits { - 2 | 4 | 8 => rabit_ex_code_bytes(dim, ex_bits).ok(), + 2 | 4 | 8 => Some(blocked_ex_code_bytes(dim, ex_bits)), _ => None, }) .map(|ex_code_len| ex_code_len * 2 * SEGMENT_NUM_CODES) @@ -1910,9 +1915,8 @@ mod tests { use lance_arrow::FixedSizeListArrayExt; use lance_index::vector::bq::{ RQBuildParams, RQRotationType, - ex_dot::ex_query_len, - rabit_ex_code_bytes, - storage::{RABIT_EX_CODE_COLUMN, RabitQuantizationMetadata, RabitQueryEstimator}, + ex_dot::{blocked_ex_code_bytes, padded_query_len}, + storage::{RABIT_BLOCKED_EX_CODE_COLUMN, RabitQuantizationMetadata, RabitQueryEstimator}, transform::{EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN}, }; use lance_index::vector::storage::VectorStore; @@ -1988,14 +1992,17 @@ mod tests { #[test] fn test_rabit_ex_scratch_len_uses_num_bits() { + // 960 is block-aligned, so no padded query copy is needed. 
let dim = 960; - let ex_query = ex_query_len(dim); - assert_eq!(super::rabit_ex_scratch_len(dim, 1), 0); - assert_eq!(super::rabit_ex_scratch_len(dim, 3), dim * 4 + ex_query); - assert_eq!(super::rabit_ex_scratch_len(dim, 5), dim * 16 + ex_query); - assert_eq!(super::rabit_ex_scratch_len(dim, 7), ex_query); - assert_eq!(super::rabit_ex_scratch_len(dim, 9), dim * 256 + ex_query); + assert_eq!(super::rabit_ex_scratch_len(dim, 3), dim * 4); + assert_eq!(super::rabit_ex_scratch_len(dim, 5), dim * 16); + assert_eq!(super::rabit_ex_scratch_len(dim, 7), 0); + assert_eq!(super::rabit_ex_scratch_len(dim, 9), dim * 256); + + // Unaligned dims add one padded query copy. + let dim = 968; + assert_eq!(super::rabit_ex_scratch_len(dim, 7), padded_query_len(dim)); } #[test] @@ -2017,10 +2024,7 @@ mod tests { let capacity = super::rabit_query_scratch_capacity(dim, max_partition_len, 5); assert_eq!(capacity.distances, max_partition_len); - assert_eq!( - capacity.query_f32, - dim + dim * 4 + dim * 16 + ex_query_len(dim) - ); + assert_eq!(capacity.query_f32, dim + dim * 4 + dim * 16); assert_eq!(capacity.u16, max_partition_len); assert_eq!(capacity.u8, dim * 16); assert_eq!(capacity.u32, 0); @@ -4446,12 +4450,12 @@ mod tests { let reader = open_rq_aux_reader(&dataset, scheduler, &index_uuid).await; let schema = reader.schema(); - let ex_field = schema.field(RABIT_EX_CODE_COLUMN).unwrap(); + let ex_field = schema.field(RABIT_BLOCKED_EX_CODE_COLUMN).unwrap(); let DataType::FixedSizeList(_, ex_code_bytes) = ex_field.data_type() else { panic!("RQ ex-code field should be FixedSizeList"); }; let expected_ex_code_bytes = - rabit_ex_code_bytes(rq_meta.rotated_dim(), num_bits - 1).unwrap() as i32; + blocked_ex_code_bytes(rq_meta.rotated_dim(), num_bits - 1) as i32; assert_eq!(ex_code_bytes, expected_ex_code_bytes); assert!(schema.field(EX_ADD_FACTORS_COLUMN).is_some()); assert!(schema.field(EX_SCALE_FACTORS_COLUMN).is_some()); From da7f2ddd347fb3f21f0c0be1f17169f2ce025dd4 Mon Sep 17 00:00:00 
2001 From: BubbleCal Date: Thu, 11 Jun 2026 17:26:13 +0800 Subject: [PATCH 3/3] perf(vector)!: skip RaBitQ ex-FastScan artifacts for gated indexes The FastScan ex bulk path is only reachable when lower-bound gating is disabled (legacy indexes without error factors); gated indexes rerank per candidate with the ex-dot kernels. Build its artifacts accordingly: - Compute the u8 ex LUT directly from the rotated query (the per-dim table is the pure multiplication q[d] * code), removing the dim * 2^ex_bits f32 table from the query context, the calculator, and the scratch layout entirely. This also speeds the bypass path itself up by 3-6%. - Skip the FastScan transpose (packed_ex_codes) when error factors are present, saving one resident copy of the ex codes and the per-load transpose for every fresh index. Bulk scoring on gated indexes falls through to the exact per-row kernels. The LUT bulk path stays for legacy indexes: at dim=1536 it scores 2.5x (ex_bits=4) to 7x (ex_bits=2) faster per row than the kernel loop. Co-Authored-By: Claude Fable 5 --- rust/lance-index/benches/rq.rs | 159 +++++++++++++++- rust/lance-index/src/vector/bq/storage.rs | 210 +++++++++------------- rust/lance-index/src/vector/storage.rs | 3 - rust/lance/src/index/vector/ivf/v2.rs | 35 ++-- 4 files changed, 256 insertions(+), 151 deletions(-) diff --git a/rust/lance-index/benches/rq.rs b/rust/lance-index/benches/rq.rs index 088927a54da..e29ce9c4695 100644 --- a/rust/lance-index/benches/rq.rs +++ b/rust/lance-index/benches/rq.rs @@ -355,9 +355,166 @@ fn ex_code_storage_load(c: &mut Criterion) { ); } +/// Bulk-scoring cost of the ex stage: the quantized ex-FastScan LUT path +/// (inside `distance_all`) vs the exact per-row ex-dot kernel. The +/// binary-only run isolates the shared binary stage so the ex cost is the +/// difference from the full run. 
+fn ex_bulk_paths(c: &mut Criterion) { + use arrow_array::{ArrayRef, FixedSizeListArray, Float32Array, UInt8Array, UInt64Array}; + use lance_arrow::FixedSizeListArrayExt; + use lance_index::vector::ApproxMode; + use lance_index::vector::bq::ex_dot::pad_query_into; + use lance_index::vector::bq::transform::{EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN}; + use lance_index::vector::storage::DistanceCalculatorOptions; + use std::sync::Arc; + + const BULK_DIM: usize = 1536; + const BULK_ROWS: usize = 16384; + + let mut rng = SmallRng::seed_from_u64(13); + for num_bits in [3u8, 5, 9] { + let ex_bits = num_bits - 1; + let max_code = ((1u16 << ex_bits) - 1) as u8; + + let rq = RabitQuantizer::new_with_rotation::( + num_bits, + BULK_DIM as i32, + RQRotationType::Fast, + ); + let metadata = rq.metadata(None); + + let code_len = BULK_DIM / 8; + let binary_codes = (0..BULK_ROWS * code_len) + .map(|_| rng.random_range(0..=u8::MAX)) + .collect::>(); + let ex_code_len = blocked_ex_code_bytes(BULK_DIM, ex_bits); + let mut ex_codes = vec![0u8; BULK_ROWS * ex_code_len]; + let values = (0..BULK_DIM) + .map(|_| rng.random_range(0..=max_code)) + .collect::>(); + for row in ex_codes.chunks_exact_mut(ex_code_len) { + pack_blocked_row(&values, ex_bits, row); + } + + // No error factors: `distance_all` takes the FastScan ex bulk branch. 
+ let batch = arrow_array::RecordBatch::try_from_iter(vec![ + ( + ROW_ID, + Arc::new(UInt64Array::from_iter_values(0..BULK_ROWS as u64)) as ArrayRef, + ), + ( + RABIT_CODE_COLUMN, + Arc::new( + FixedSizeListArray::try_new_from_values( + UInt8Array::from(binary_codes), + code_len as i32, + ) + .unwrap(), + ) as ArrayRef, + ), + ( + ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; BULK_ROWS])) as ArrayRef, + ), + ( + SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; BULK_ROWS])) as ArrayRef, + ), + ( + RABIT_BLOCKED_EX_CODE_COLUMN, + Arc::new( + FixedSizeListArray::try_new_from_values( + UInt8Array::from(ex_codes.clone()), + ex_code_len as i32, + ) + .unwrap(), + ) as ArrayRef, + ), + ( + EX_ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![0.0f32; BULK_ROWS])) as ArrayRef, + ), + ( + EX_SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from(vec![1.0f32; BULK_ROWS])) as ArrayRef, + ), + ]) + .unwrap(); + let storage = + RabitQuantizationStorage::try_from_batch(batch, &metadata, DistanceType::L2, None) + .unwrap(); + + let query: ArrayRef = Arc::new(Float32Array::from( + (0..BULK_DIM) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect::>(), + )); + + for (label, approx_mode) in [ + ("full distance_all (binary + ex LUT)", ApproxMode::Normal), + ("binary-only distance_all (fast mode)", ApproxMode::Fast), + ] { + let mut f32_scratch = Vec::new(); + let calc = storage.dist_calculator_with_scratch( + query.clone(), + 0.0, + None, + &mut f32_scratch, + DistanceCalculatorOptions { approx_mode }, + ); + let mut dists = Vec::new(); + let mut u16_scratch = Vec::new(); + let mut u8_scratch = Vec::new(); + let mut u32_scratch = Vec::new(); + c.bench_function( + format!("RQ bulk {label}: num_bits={num_bits}, DIM={BULK_DIM}, rows={BULK_ROWS}") + .as_str(), + |b| { + b.iter(|| { + calc.distance_all_with_scratch( + 0, + &mut dists, + &mut u16_scratch, + &mut u8_scratch, + &mut u32_scratch, + ); + black_box(dists.len()) + }) + }, + ); + } + + let 
kernel = ex_dot_kernel(ex_bits); + let mut ex_query = vec![0.0f32; BULK_DIM]; + pad_query_into( + query + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &mut ex_query, + ); + c.bench_function( + format!( + "RQ bulk ex kernel loop: num_bits={num_bits}, DIM={BULK_DIM}, rows={BULK_ROWS}" + ) + .as_str(), + |b| { + b.iter(|| { + let mut sum = 0.0f32; + for row in ex_codes.chunks_exact(ex_code_len) { + sum += kernel(&ex_query, row); + } + black_box(sum) + }) + }, + ); + } +} + criterion_group!( name=benches; config = Criterion::default().measurement_time(Duration::from_secs(10)); - targets = construct_dist_table, compute_distances, ex_dot_kernels, ex_code_storage_load); + targets = construct_dist_table, compute_distances, ex_dot_kernels, ex_code_storage_load, ex_bulk_paths); criterion_main!(benches); diff --git a/rust/lance-index/src/vector/bq/storage.rs b/rust/lance-index/src/vector/bq/storage.rs index 6e47fe42e25..36e56986921 100644 --- a/rust/lance-index/src/vector/bq/storage.rs +++ b/rust/lance-index/src/vector/bq/storage.rs @@ -360,14 +360,6 @@ impl RabitQuantizationMetadata { let code_dim = self.code_dim(); let ex_bits = rabit_ex_bits(self.num_bits)?; let dist_table_len = code_dim * 4; - // The quantized ex dist table is only consumed by the FastScan bulk - // path; the exact rerank path multiplies the query against the packed - // codes directly (see `ex_dot`). 
- let ex_dist_table_len = if supports_ex_fastscan(ex_bits) { - code_dim * (1usize << ex_bits) - } else { - 0 - }; let mut rotated_query = vec![0.0; code_dim]; self.rotate_vector_with_residual_into(query, None, &mut rotated_query); @@ -375,9 +367,6 @@ impl RabitQuantizationMetadata { let mut dist_table = vec![0.0; dist_table_len]; build_dist_table_direct_into::(&rotated_query, &mut dist_table); - let mut ex_dist_table = vec![0.0; ex_dist_table_len]; - build_ex_dist_table_direct_into(&rotated_query, ex_bits, &mut ex_dist_table); - // The kernels consume the rotated query directly; a zero-padded copy // is only needed when the rotated dim is not block-aligned. let mut ex_query = Vec::new(); @@ -392,7 +381,6 @@ impl RabitQuantizationMetadata { ex_bits, rotated_query, dist_table, - ex_dist_table, ex_query, sum_q, }) @@ -587,7 +575,6 @@ impl RabitQuantizationStorage { let RabitDistCalculatorParts { dim, dist_table, - ex_dist_table, ex_query, sum_q, query_factor, @@ -612,7 +599,6 @@ impl RabitQuantizationStorage { self.metadata.num_bits, self.metadata.query_estimator, dist_table, - ex_dist_table, ex_query, sum_q, self.codes.values().as_primitive::().values(), @@ -802,7 +788,6 @@ fn copy_subtract_f32(lhs: &[f32], rhs: &[f32], output: &mut [f32]) { struct RabitDistCalculatorParts<'a> { dim: usize, dist_table: Cow<'a, [f32]>, - ex_dist_table: Cow<'a, [f32]>, ex_query: Cow<'a, [f32]>, sum_q: f32, query_factor: f32, @@ -835,10 +820,8 @@ pub struct RabitDistCalculator<'a> { // we split the query codes into d/4 chunks, each chunk is with 4 elements, // then dist_table[i][j] is the distance between the i-th query code and the code j dist_table: Cow<'a, [f32]>, - // only built for the ex widths supported by FastScan; the exact rerank - // path uses `ex_query` + `ex_dot` instead - ex_dist_table: Cow<'a, [f32]>, - // the rotated query, zero-padded to a 64-dim multiple when needed + // the rotated query, zero-padded to a 64-dim multiple when needed; also + // the source for the 
FastScan ex LUT on the legacy bypass path ex_query: Cow<'a, [f32]>, ex_dot: Option, add_factors: &'a [f32], @@ -862,7 +845,6 @@ impl<'a> RabitDistCalculator<'a> { num_bits: u8, query_estimator: RabitQueryEstimator, dist_table: Cow<'a, [f32]>, - ex_dist_table: Cow<'a, [f32]>, ex_query: Cow<'a, [f32]>, sum_q: f32, codes: &'a [u8], @@ -887,7 +869,6 @@ impl<'a> RabitDistCalculator<'a> { ex_codes, ex_code_len, dist_table, - ex_dist_table, ex_query, ex_dot, add_factors, @@ -1112,10 +1093,9 @@ impl<'a> RabitDistCalculator<'a> { let fastscan_len = simd_len; let fastscan_code_len = self.ex_code_len; let (qmin, qmax, quantization_max) = quantize_ex_fastscan_dist_table_into( - self.dim, ex_bits, self.ex_code_len, - &self.ex_dist_table, + self.ex_query.as_ref(), quantized_dists_table, ); quantized_dists.clear(); @@ -1329,35 +1309,6 @@ where dist_table } -fn build_ex_dist_table_direct(rotated_query: &[f32], ex_bits: u8) -> Vec { - if ex_bits == 0 { - return Vec::new(); - } - let entries_per_dim = 1usize << ex_bits; - let mut dist_table = vec![0.0; rotated_query.len() * entries_per_dim]; - build_ex_dist_table_direct_into(rotated_query, ex_bits, &mut dist_table); - dist_table -} - -fn build_ex_dist_table_direct_into(rotated_query: &[f32], ex_bits: u8, dist_table: &mut [f32]) { - // The table may legitimately be empty for multi-bit widths without - // FastScan support; the exact path uses the ex-dot kernels instead. 
- if ex_bits == 0 || dist_table.is_empty() { - debug_assert!(dist_table.is_empty()); - return; - } - let entries_per_dim = 1usize << ex_bits; - debug_assert_eq!(dist_table.len(), rotated_query.len() * entries_per_dim); - for (query_value, table) in rotated_query - .iter() - .zip(dist_table.chunks_exact_mut(entries_per_dim)) - { - for (code, value) in table.iter_mut().enumerate() { - *value = *query_value * code as f32; - } - } -} - fn build_dist_table_direct_into(qc: &[T::Native], dist_table: &mut [f32]) where T::Native: AsPrimitive, @@ -1456,17 +1407,18 @@ fn quantize_dist_table_u16_into( (qmin, qmax) } +/// Build the u8 FastScan LUT for the ex codes directly from the rotated +/// query (`ex_query`, natural dim order, padding dims zero): the underlying +/// per-dim table is the pure multiplication `q[d] * code`, so no intermediate +/// `dim * 2^ex_bits` table is materialized. fn quantize_ex_fastscan_dist_table_into( - dim: usize, ex_bits: u8, ex_code_len: usize, - ex_dist_table: &[f32], + ex_query: &[f32], quantized_dist_table: &mut Vec, ) -> (f32, f32, f32) { debug_assert!(supports_ex_fastscan(ex_bits)); - let entries_per_dim = 1usize << ex_bits; - debug_assert_eq!(ex_dist_table.len(), dim * entries_per_dim); // One split table per code nibble of the row. 
let num_split_tables = ex_code_len * 2; let quantization_max = (u16::MAX as usize / num_split_tables) @@ -1477,7 +1429,7 @@ fn quantize_ex_fastscan_dist_table_into( let mut qmax = f32::NEG_INFINITY; for table_idx in 0..num_split_tables { for code in 0..SEGMENT_NUM_CODES { - let value = ex_fastscan_dist_table_value(dim, ex_bits, ex_dist_table, table_idx, code); + let value = ex_fastscan_dist_table_value(ex_query, ex_bits, table_idx, code); qmin = qmin.min(value); qmax = qmax.max(value); } @@ -1493,7 +1445,7 @@ fn quantize_ex_fastscan_dist_table_into( let factor = quantization_max / (qmax - qmin); for table_idx in 0..num_split_tables { for code in 0..SEGMENT_NUM_CODES { - let value = ex_fastscan_dist_table_value(dim, ex_bits, ex_dist_table, table_idx, code); + let value = ex_fastscan_dist_table_value(ex_query, ex_bits, table_idx, code); quantized_dist_table.push(((value - qmin) * factor).round() as u8); } } @@ -1509,15 +1461,16 @@ fn supports_ex_fastscan(ex_bits: u8) -> bool { /// The FastScan LUT value for one nibble of a blocked-layout code byte: /// `table_idx / 2` is the byte position within a row and `table_idx % 2` /// selects its low/high nibble (see the `ex_dot` module docs for the -/// byte-to-dim mapping per width). Padding dims contribute zero. +/// byte-to-dim mapping per width). Dims beyond the query length (block +/// padding) contribute zero. #[inline] fn ex_fastscan_dist_table_value( - dim: usize, + ex_query: &[f32], ex_bits: u8, - ex_dist_table: &[f32], table_idx: usize, code: usize, ) -> f32 { + let query = |dim_idx: usize| ex_query.get(dim_idx).copied().unwrap_or(0.0); let byte_idx = table_idx / 2; let high_nibble = table_idx % 2 == 1; match ex_bits { @@ -1525,10 +1478,9 @@ fn ex_fastscan_dist_table_value( // byte 16g+b = dims {64g+b, +16, +32, +48} at bit pairs; the low // nibble covers the first two dims, the high nibble the last two. 
let dim_idx = 64 * (byte_idx / 16) + byte_idx % 16 + 32 * usize::from(high_nibble); - let low = code & 0b11; - let high = (code >> 2) & 0b11; - ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, low) - + ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx + 16, high) + let low = (code & 0b11) as f32; + let high = ((code >> 2) & 0b11) as f32; + query(dim_idx) * low + query(dim_idx + 16) * high } 4 => { // byte 32g+8j+b = dim 64g+16j+b (low nibble) | dim +8 (high). @@ -1537,46 +1489,34 @@ fn ex_fastscan_dist_table_value( + 16 * (in_block / 8) + in_block % 8 + 8 * usize::from(high_nibble); - ex_dist_table_value(ex_dist_table, dim, ex_bits, dim_idx, code) + query(dim_idx) * code as f32 } 8 => { // byte = dim identity; the high nibble carries code bits 4..8. - if high_nibble { - ex_dist_table_value( - ex_dist_table, - dim, - ex_bits, - byte_idx, - code << SEGMENT_LENGTH, - ) + let code = if high_nibble { + code << SEGMENT_LENGTH } else { - ex_dist_table_value(ex_dist_table, dim, ex_bits, byte_idx, code) - } + code + }; + query(byte_idx) * code as f32 } _ => unreachable!("unsupported RabitQ ex_bits={ex_bits} for FastScan"), } } -#[inline] -fn ex_dist_table_value( - ex_dist_table: &[f32], - dim: usize, - ex_bits: u8, - dim_idx: usize, - code: usize, -) -> f32 { - if dim_idx >= dim { - return 0.0; - } - let entries_per_dim = 1usize << ex_bits; - ex_dist_table[dim_idx * entries_per_dim + code] -} - +/// Transpose ex codes for the FastScan bulk path. That path is only reachable +/// when lower-bound gating is disabled, i.e. for legacy indexes without error +/// factors; gated indexes rerank per candidate with the ex-dot kernels and +/// never touch this copy, so skip the transpose (and its resident memory). 
fn maybe_pack_ex_codes( ex_codes: Option<&FixedSizeListArray>, ex_bits: u8, + error_factors: Option<&Float32Array>, ) -> Option { let ex_codes = ex_codes?; + if error_factors.is_some() { + return None; + } match ex_bits { 2 | 4 | 8 => Some(pack_codes(ex_codes)), _ => None, @@ -1977,12 +1917,6 @@ impl VectorStore for RabitQuantizationStorage { let code_dim = self.code_dim(); let rotated_qr = self.rotate_query_vector(code_dim, &qr); let dist_table = build_dist_table_direct::(&rotated_qr); - let ex_bits = self.metadata.num_bits - 1; - let ex_dist_table = if supports_ex_fastscan(ex_bits) { - build_ex_dist_table_direct(&rotated_qr, ex_bits) - } else { - Vec::new() - }; let query_factor = match self.metadata.query_estimator { RabitQueryEstimator::ResidualQuery => self.residual_query_factor(dist_q_c), RabitQueryEstimator::RawQuery => self.raw_query_factor(dist_q_c, &rotated_qr, None), @@ -2007,7 +1941,6 @@ impl VectorStore for RabitQuantizationStorage { self.distance_calculator_from_parts(RabitDistCalculatorParts { dim: code_dim, dist_table: Cow::Owned(dist_table), - ex_dist_table: Cow::Owned(ex_dist_table), ex_query: Cow::Owned(ex_query), sum_q, query_factor, @@ -2047,7 +1980,6 @@ impl VectorStore for RabitQuantizationStorage { return self.distance_calculator_from_parts(RabitDistCalculatorParts { dim: code_dim, dist_table: Cow::Borrowed(&raw_query.dist_table), - ex_dist_table: Cow::Borrowed(&raw_query.ex_dist_table), ex_query: Cow::Borrowed(kernel_query( &raw_query.rotated_query, &raw_query.ex_query, @@ -2061,13 +1993,6 @@ impl VectorStore for RabitQuantizationStorage { let dist_table_len = code_dim * 4; let ex_bits = self.metadata.num_bits - 1; - // Only the FastScan bulk path consumes the quantized ex dist table; - // the exact rerank path uses the kernel-order query instead. 
- let ex_dist_table_len = if supports_ex_fastscan(ex_bits) { - code_dim * (1usize << ex_bits) - } else { - 0 - }; // The kernels read the rotated query in place; a zero-padded copy is // only needed when the rotated dim is not block-aligned. let ex_query_table_len = if ex_bits == 0 || code_dim.is_multiple_of(EX_DOT_BLOCK_DIMS) { @@ -2075,17 +2000,13 @@ impl VectorStore for RabitQuantizationStorage { } else { padded_query_len(code_dim) }; - f32_scratch.resize( - code_dim + dist_table_len + ex_dist_table_len + ex_query_table_len, - 0.0, - ); + f32_scratch.resize(code_dim + dist_table_len + ex_query_table_len, 0.0); let query_factor; let query_error; let sum_q = { let (rotated_qr, remaining) = f32_scratch.split_at_mut(code_dim); - let (dist_table, remaining) = remaining.split_at_mut(dist_table_len); - let (ex_dist_table, ex_query) = remaining.split_at_mut(ex_dist_table_len); + let (dist_table, ex_query) = remaining.split_at_mut(dist_table_len); match residual { Some(QueryResidual::Centroid(residual_centroid)) => { self.rotate_query_vector_into( @@ -2124,19 +2045,16 @@ impl VectorStore for RabitQuantizationStorage { } }; build_dist_table_direct_into::(rotated_qr, dist_table); - build_ex_dist_table_direct_into(rotated_qr, ex_bits, ex_dist_table); if ex_query_table_len > 0 { pad_query_into(rotated_qr, ex_query); } rotated_qr.iter().copied().sum() }; - let ex_dist_table_start = code_dim + dist_table_len; - let ex_query_start = ex_dist_table_start + ex_dist_table_len; + let ex_query_start = code_dim + dist_table_len; self.distance_calculator_from_parts(RabitDistCalculatorParts { dim: code_dim, - dist_table: Cow::Borrowed(&f32_scratch[code_dim..ex_dist_table_start]), - ex_dist_table: Cow::Borrowed(&f32_scratch[ex_dist_table_start..ex_query_start]), + dist_table: Cow::Borrowed(&f32_scratch[code_dim..ex_query_start]), ex_query: Cow::Borrowed(kernel_query( &f32_scratch[..code_dim], &f32_scratch[ex_query_start..ex_query_start + ex_query_table_len], @@ -2405,7 +2323,8 @@ impl 
QuantizerStorage for RabitQuantizationStorage { let mut metadata = metadata.clone(); metadata.packed = true; - let packed_ex_codes = maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits); + let packed_ex_codes = + maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits, error_factors.as_ref()); Ok(Self { metadata, @@ -2498,7 +2417,8 @@ impl QuantizerStorage for RabitQuantizationStorage { load_blocked_ex_codes(batch, self.metadata.rotated_dim(), self.metadata.num_bits)?; (batch, Some(codes)) }; - let packed_ex_codes = maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits); + let packed_ex_codes = + maybe_pack_ex_codes(ex_codes.as_ref(), ex_bits, error_factors.as_ref()); let ex_add_factors = batch .column_by_name(EX_ADD_FACTORS_COLUMN) .map(|factors| factors.as_primitive::().clone()); @@ -3425,7 +3345,11 @@ mod tests { assert_eq!(hacc_accum_len, num_rows); } - fn assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits: u8, legacy_format: bool) { + fn assert_raw_query_multi_bit_distance_all_uses_fastscan( + num_bits: u8, + legacy_format: bool, + with_error_factors: bool, + ) { // Not a multiple of 64, so the padded-tail LUT entries are exercised; // a multiple of 8 as the binary stage requires. let code_dim = 72usize; @@ -3512,10 +3436,23 @@ mod tests { ), ]) .unwrap(); + let batch = if with_error_factors { + batch + .try_with_column( + crate::vector::bq::transform::ERROR_FACTORS_FIELD.clone(), + Arc::new(Float32Array::from(vec![1000.0; num_rows])) as ArrayRef, + ) + .unwrap() + } else { + batch + }; let storage = RabitQuantizationStorage::try_from_batch(batch, &metadata, DistanceType::L2, None) .unwrap(); - assert!(storage.packed_ex_codes.is_some()); + // The FastScan transpose only exists for indexes that can reach the + // bulk bypass path (no error factors); gated indexes fall through to + // the exact per-row kernels in `distance_all`. 
+ assert_eq!(storage.packed_ex_codes.is_some(), !with_error_factors); // A per-dim varying query so that any dim-mapping error in the // FastScan LUT shows up as a value mismatch. @@ -3539,7 +3476,13 @@ mod tests { assert_eq!(distances.len(), num_rows); assert_eq!(u16_scratch.len(), BATCH_SIZE); let loaded_ex_code_len = storage.ex_codes.as_ref().unwrap().value_length() as usize; - assert_eq!(u8_scratch.len(), loaded_ex_code_len * 2 * SEGMENT_NUM_CODES); + if with_error_factors { + // The gated path never builds the ex LUT; the scratch holds the + // binary LUT only. + assert_eq!(u8_scratch.len(), code_dim * 4); + } else { + assert_eq!(u8_scratch.len(), loaded_ex_code_len * 2 * SEGMENT_NUM_CODES); + } // The fastscan estimate differs from the exact path only by the u8 // quantization of the binary LUT (amplified by 2^ex_bits) and of the @@ -3561,16 +3504,22 @@ mod tests { let code_scale = (1u32 << ex_bits) as f32; let binary_bound = code_scale * code_dim.div_ceil(4) as f32 * (table_max - table_min) / 510.0; - let ex_dist_table = build_ex_dist_table_direct(&query_values, ex_bits); + let mut padded_query = vec![0.0f32; crate::vector::bq::ex_dot::padded_query_len(code_dim)]; + crate::vector::bq::ex_dot::pad_query_into(&query_values, &mut padded_query); let mut quantized_table = Vec::new(); let (ex_qmin, ex_qmax, ex_qcap) = quantize_ex_fastscan_dist_table_into( - code_dim, ex_bits, loaded_ex_code_len, - &ex_dist_table, + &padded_query, &mut quantized_table, ); - let ex_bound = (loaded_ex_code_len * 2) as f32 * (ex_qmax - ex_qmin) / ex_qcap / 2.0; + // Without the FastScan transpose the ex stage is exact, so only the + // binary LUT quantization remains. 
+ let ex_bound = if with_error_factors { + 0.0 + } else { + (loaded_ex_code_len * 2) as f32 * (ex_qmax - ex_qmin) / ex_qcap / 2.0 + }; let bound = (binary_bound + ex_bound) * 1.5 + 1e-3; for (id, distance) in distances.iter().take(BATCH_SIZE).enumerate() { let exact = calc.distance(id as u32); @@ -3586,8 +3535,15 @@ mod tests { fn test_raw_query_multi_bit_distance_all_uses_fastscan_for_split_ex_codes() { for num_bits in [3, 5, 9] { for legacy_format in [false, true] { - assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits, legacy_format); + assert_raw_query_multi_bit_distance_all_uses_fastscan( + num_bits, + legacy_format, + false, + ); } + // Gated indexes (with error factors) skip the FastScan artifacts + // and score the bulk path with the exact kernels. + assert_raw_query_multi_bit_distance_all_uses_fastscan(num_bits, false, true); } } diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index 9a048b5ebbf..4cb61dc9e49 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -249,9 +249,6 @@ pub struct RabitRawQueryContext { pub ex_bits: u8, pub rotated_query: Vec, pub dist_table: Vec, - /// Quantized-table input for the FastScan ex path; empty for ex widths - /// without FastScan support. - pub ex_dist_table: Vec, /// The rotated query zero-padded to a 64-dim multiple for the ex-dot /// kernels; empty when `code_dim` is already aligned (the kernels then /// read `rotated_query` directly). 
diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 885c9f6a5e5..d169cf08190 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -153,22 +153,18 @@ fn rotated_partition_centroid_slice( cache.rotated_centroids.get(start..end) } -/// `f32` scratch needed for the ex-bit query state: the quantized-table input -/// for FastScan-supported widths, plus a zero-padded query copy when the -/// rotated dim is not a multiple of the 64-dim kernel block. +/// `f32` scratch needed for the ex-bit query state: a zero-padded query copy +/// when the rotated dim is not a multiple of the 64-dim kernel block (the +/// FastScan ex LUT is built directly from the query, with no f32 table). fn rabit_ex_scratch_len(dim: usize, num_bits: u8) -> usize { - let padded_query = if dim.is_multiple_of(64) { + let multi_bit = rabit_ex_bits(num_bits) + .map(|ex_bits| ex_bits > 0) + .unwrap_or(true); + if !multi_bit || dim.is_multiple_of(64) { 0 } else { padded_query_len(dim) - }; - rabit_ex_bits(num_bits) - .map(|ex_bits| match ex_bits { - 0 => 0, - 2 | 4 | 8 => dim * (1usize << usize::from(ex_bits)) + padded_query, - _ => padded_query, - }) - .unwrap_or(dim * 256 + padded_query) + } } fn rabit_u8_scratch_len(dim: usize, num_bits: u8) -> usize { @@ -1992,16 +1988,15 @@ mod tests { #[test] fn test_rabit_ex_scratch_len_uses_num_bits() { - // 960 is block-aligned, so no padded query copy is needed. + // Block-aligned dims read the rotated query in place. let dim = 960; - assert_eq!(super::rabit_ex_scratch_len(dim, 1), 0); - assert_eq!(super::rabit_ex_scratch_len(dim, 3), dim * 4); - assert_eq!(super::rabit_ex_scratch_len(dim, 5), dim * 16); - assert_eq!(super::rabit_ex_scratch_len(dim, 7), 0); - assert_eq!(super::rabit_ex_scratch_len(dim, 9), dim * 256); + for num_bits in [1, 3, 5, 7, 9] { + assert_eq!(super::rabit_ex_scratch_len(dim, num_bits), 0); + } - // Unaligned dims add one padded query copy. 
+ // Unaligned multi-bit queries add one padded query copy. let dim = 968; + assert_eq!(super::rabit_ex_scratch_len(dim, 1), 0); assert_eq!(super::rabit_ex_scratch_len(dim, 7), padded_query_len(dim)); } @@ -2024,7 +2019,7 @@ mod tests { let capacity = super::rabit_query_scratch_capacity(dim, max_partition_len, 5); assert_eq!(capacity.distances, max_partition_len); - assert_eq!(capacity.query_f32, dim + dim * 4 + dim * 16); + assert_eq!(capacity.query_f32, dim + dim * 4); assert_eq!(capacity.u16, max_partition_len); assert_eq!(capacity.u8, dim * 16); assert_eq!(capacity.u32, 0);