Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,7 @@ generation → `RankQuant` rerank) and the full mode comparison, see
For runtimes that own their own parallelism — an embedded vector DB driving a
bounded thread pool, or a binding releasing the GIL — ordvec exposes a
**no-rayon** serial two-stage path so the *caller* schedules the work, with an
**allocation-free rerank step** (`_into`, on the AVX-512/AVX2 path) for the
steady-state hot loop:
**allocation-free rerank step** (`_into`) for the steady-state hot loop:

```rust
use ordvec::{RankQuant, SignBitmap, SubsetScratch};
Expand All @@ -235,12 +234,11 @@ Contract: candidates are **CSR** (`offsets.len() == nq + 1`; row `qi` is
underfull rows — size both buffers to `nq * k.min(index.len())`. Scores, row ids,
and the deterministic tie policy (`score desc, global row-id asc`) match the
single-query `search_asymmetric_subset`. **Only the `_into` rerank step is
allocation-free** — on the **AVX-512 / AVX2** SIMD path, and only on repeated
calls of the *same* batch shape — reusing the warmed `SubsetScratch` and your
output buffers (no per-row alloc, no whole-buffer preclear). The scalar fallback
(no AVX2, e.g. aarch64) allocates a per-query scoring LUT. Stage 1
(`top_m_candidates_batched_serial_csr`) also allocates a fresh `CandidateBatch`
each call. Neither primitive enters rayon —
allocation-free** — SIMD or scalar — on repeated calls of the *same* batch
shape, reusing the warmed `SubsetScratch` and your output buffers (no per-row
alloc, no scalar-LUT alloc, no whole-buffer preclear). Stage 1
(`top_m_candidates_batched_serial_csr`) still allocates a fresh
`CandidateBatch` each call. Neither primitive enters rayon —
partition the query batch and call `_into` once per worker range from your own
pool. A focused decomposition benchmark lives in
[`examples/two_stage_bench.rs`](examples/two_stage_bench.rs).
Expand Down
90 changes: 79 additions & 11 deletions src/quant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
use rayon::prelude::*;

use crate::quant_kernels::{
scan_b1_to_topk, scan_b2_to_topk, scan_b4_to_topk, scan_b8_asym, scan_b8_to_topk,
scan_via_lut_scalar,
scan_b1_to_topk, scan_b2_to_topk, scan_b4_to_topk, scan_b8_asym, scan_b8_asym_with_lut,
scan_b8_to_topk, scan_via_lut_scalar, scan_via_lut_scalar_with_lut,
};
#[cfg(target_arch = "x86_64")]
use crate::quant_kernels::{
Expand All @@ -48,6 +48,7 @@ use crate::{validate_candidate_ids, OrdvecError, SearchResults};
pub struct SubsetScratch {
q_unit: Vec<f32>,
sub_packed: Vec<u8>,
scalar_lut: Vec<f32>,
top: TopK,
local_indices: Vec<i64>,
final_order: Vec<(f32, i64, i64, usize)>,
Expand All @@ -58,6 +59,7 @@ impl Default for SubsetScratch {
Self {
q_unit: Vec::new(),
sub_packed: Vec::new(),
scalar_lut: Vec::new(),
top: TopK::new(0),
local_indices: Vec::new(),
final_order: Vec::new(),
Expand Down Expand Up @@ -318,10 +320,10 @@ fn select_simd_tier(dim: usize, bits: u8) -> SimdTier {
///
/// Returns `true` when the asymmetric subset rerank takes a SIMD kernel (vs the
/// scalar LUT fallback) for a **constructor-valid** `(dim, bits)` on this CPU.
/// The scalar fallback allocates a per-query LUT, so the allocation-free
/// steady-state guarantee of
/// [`RankQuant::search_asymmetric_subset_batched_serial_into`] holds exactly
/// when this is `true`.
/// The allocation-free tests use this to force coverage of both dispatch
/// families; the steady-state allocation-free guarantee of
/// [`RankQuant::search_asymmetric_subset_batched_serial_into`] applies after
/// the caller-provided [`SubsetScratch`] is warmed.
///
/// Returns `false` for any `(dim, bits)` that [`RankQuant::new`] would reject,
/// so it answers "the rerank will take a SIMD kernel" rather than acting as a
Expand Down Expand Up @@ -1160,13 +1162,14 @@ impl RankQuant {
// The tie keys on `scratch.top` still map local scratch positions →
// global row IDs exactly as for b ∈ {1,2,4}.
if bits == 8 {
scan_b8_asym(
scan_b8_asym_with_lut(
&scratch.sub_packed,
m,
dim,
&scratch.q_unit,
inv_norm,
&mut scratch.top,
&mut scratch.scalar_lut,
);
} else {
#[cfg(target_arch = "x86_64")]
Expand Down Expand Up @@ -1216,7 +1219,7 @@ impl RankQuant {
&mut scratch.top,
);
}
_ => scan_via_lut_scalar(
_ => scan_via_lut_scalar_with_lut(
&scratch.sub_packed,
m,
dim,
Expand All @@ -1225,11 +1228,12 @@ impl RankQuant {
&scratch.q_unit,
inv_norm,
&mut scratch.top,
&mut scratch.scalar_lut,
),
}
}
#[cfg(not(target_arch = "x86_64"))]
scan_via_lut_scalar(
scan_via_lut_scalar_with_lut(
&scratch.sub_packed,
m,
dim,
Expand All @@ -1238,6 +1242,7 @@ impl RankQuant {
&scratch.q_unit,
inv_norm,
&mut scratch.top,
&mut scratch.scalar_lut,
);
}

Expand All @@ -1259,8 +1264,8 @@ impl RankQuant {
}

/// Serial (NO rayon) batched subset rerank into caller-owned buffers.
/// Allocation-free after `scratch` warmup **on the SIMD rerank path
/// (AVX-512 / AVX2)**; the scalar fallback allocates a per-query scoring LUT.
/// Allocation-free after `scratch` warmup; both SIMD and scalar rerank
/// paths reuse caller-owned scratch buffers, including the scalar LUT.
/// The integration contract for runtimes that own their own parallelism
/// (call this from a bounded pool, with the GIL released, one row range per
/// worker is the caller's choice).
Expand Down Expand Up @@ -1761,3 +1766,66 @@ pub fn search_asymmetric_byte_lut(index: &RankQuant, queries: &[f32], k: usize)
k,
}
}

#[cfg(test)]
mod tests {
use super::*;

/// Builds a deterministic row-major `rows * dim` test corpus.
///
/// Element `(row, col)` is `((row + 3) * (col + 5)) % 23 - 11`, so every value
/// is a whole number in `-11.0..=11.0`.
fn corpus(rows: usize, dim: usize) -> Vec<f32> {
    let mut values = Vec::with_capacity(rows * dim);
    values.extend((0..rows).flat_map(|row| {
        (0..dim).map(move |col| (((row + 3) * (col + 5)) % 23) as f32 - 11.0)
    }));
    values
}

/// Runs the same serial batched rerank twice through one `SubsetScratch`.
///
/// The first call must leave `scratch.scalar_lut` with room for at least
/// `dim * 2` entries; the second call must not change that capacity.
#[test]
fn scalar_lut_scratch_reuses_capacity_after_warmup() {
    let (dim, rows, nq, k) = (64usize, 16usize, 2usize, 4usize);

    let mut index = RankQuant::new(dim, 1);
    let data = corpus(rows, dim);
    index.add(&data);

    // The first `nq` corpus rows double as the queries; each query's candidate
    // row lists every corpus row id, laid out back to back with CSR offsets.
    let queries = data[..nq * dim].to_vec();
    let candidates: Vec<u32> = (0..rows as u32).chain(0..rows as u32).collect();
    let candidate_offsets = vec![0usize, rows, rows * 2];
    let mut scores = vec![0.0f32; nq * k];
    let mut indices = vec![0i64; nq * k];
    let mut scratch = SubsetScratch::new();

    // The scratch is passed in per call (not captured) so its LUT capacity can
    // be inspected between the two invocations.
    let mut rerank = |scratch: &mut SubsetScratch| {
        index.search_asymmetric_subset_batched_serial_into(
            &queries,
            &candidate_offsets,
            &candidates,
            k,
            scratch,
            &mut scores,
            &mut indices,
        );
    };

    // Warmup call: sizes the scratch buffers.
    rerank(&mut scratch);
    let warmed_capacity = scratch.scalar_lut.capacity();
    assert!(
        warmed_capacity >= dim * 2,
        "b=1 scalar LUT should reserve one row per coordinate and bucket"
    );

    // Steady-state call with the same batch shape: capacity must be unchanged.
    rerank(&mut scratch);
    assert_eq!(
        scratch.scalar_lut.capacity(),
        warmed_capacity,
        "scalar LUT scratch must reuse capacity after warmup"
    );
}
}
94 changes: 71 additions & 23 deletions src/quant_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,48 @@ pub(crate) fn scan_via_lut_scalar(
scale: f32,
top: &mut TopK,
) {
let mut lut = vec![0.0f32; dim * n_buckets];
for d in 0..dim {
for b in 0..n_buckets {
lut[d * n_buckets + b] = q_unit[d] * bucket_centre(b as u8, bits);
let mut lut = Vec::new();
scan_via_lut_scalar_with_lut(
packed, n, dim, bits, n_buckets, q_unit, scale, top, &mut lut,
);
}

/// Fills `lut` with the `dim * n_buckets` per-coordinate scoring table for one
/// query.
///
/// Entry `d * n_buckets + b` is `q_unit[d] * bucket_centre(b, bits)`. The
/// buffer is resized in place, so a `lut` that already has enough capacity is
/// reused instead of reallocated.
///
/// # Panics
///
/// Panics if `q_unit.len() != dim`, or if `n_buckets` is zero (a zero chunk
/// size is rejected by `chunks_exact_mut`).
pub(crate) fn build_asym_lut_into(
    lut: &mut Vec<f32>,
    dim: usize,
    bits: u8,
    n_buckets: usize,
    q_unit: &[f32],
) {
    assert_eq!(q_unit.len(), dim);
    lut.resize(dim * n_buckets, 0.0);
    // One `n_buckets`-wide row per query coordinate.
    let rows = lut.chunks_exact_mut(n_buckets);
    for (row, &coord) in rows.zip(q_unit) {
        for (bucket, entry) in row.iter_mut().enumerate() {
            *entry = coord * bucket_centre(bucket as u8, bits);
        }
    }
}
Comment thread
Fieldnote-Echo marked this conversation as resolved.

/// Same scalar LUT scan as [`scan_via_lut_scalar`], but the caller supplies the
/// LUT buffer so hot paths can reuse capacity after warmup.
#[allow(clippy::too_many_arguments)] // kernel arity is intrinsic to the packed-scan signature
pub(crate) fn scan_via_lut_scalar_with_lut(
packed: &[u8],
n: usize,
dim: usize,
bits: u8,
n_buckets: usize,
q_unit: &[f32],
scale: f32,
top: &mut TopK,
lut: &mut Vec<f32>,
) {
build_asym_lut_into(lut, dim, bits, n_buckets, q_unit);
match bits {
1 => scan_b1_to_topk(packed, n, dim, &lut, scale, top),
2 => scan_b2_to_topk(packed, n, dim, &lut, scale, top),
4 => scan_b4_to_topk(packed, n, dim, &lut, scale, top),
8 => scan_b8_to_topk(packed, n, dim, &lut, scale, top),
1 => scan_b1_to_topk(packed, n, dim, lut, scale, top),
2 => scan_b2_to_topk(packed, n, dim, lut, scale, top),
4 => scan_b4_to_topk(packed, n, dim, lut, scale, top),
8 => scan_b8_to_topk(packed, n, dim, lut, scale, top),
_ => unreachable!("bits validated in new()"),
}
}
Expand Down Expand Up @@ -135,17 +166,14 @@ pub(crate) fn scan_b4_to_topk(
///
/// `bucket_centre(code, 8) = code - 127.5`, so each row is the query
/// coordinate scaled across the 256 centred bucket values.
pub(crate) fn build_b8_asym_lut(q_unit: &[f32]) -> Vec<f32> {
pub(crate) fn build_b8_asym_lut_into(lut: &mut Vec<f32>, q_unit: &[f32]) {
let dim = q_unit.len();
let mut lut = vec![0.0f32; dim * 256];
for d in 0..dim {
let qd = q_unit[d];
let row = &mut lut[d * 256..(d + 1) * 256];
lut.resize(dim * 256, 0.0);
for (&qd, row) in q_unit.iter().zip(lut.chunks_exact_mut(256)) {
for (code, slot) in row.iter_mut().enumerate() {
*slot = qd * bucket_centre(code as u8, 8);
}
}
lut
}
Comment thread
Fieldnote-Echo marked this conversation as resolved.

/// 8-bit scan. 1 code per byte; n_buckets = 256. The degenerate
Expand Down Expand Up @@ -555,7 +583,7 @@ pub(crate) unsafe fn scan_b4_asym_avx512(
/// Single entry point for the `b=8` asymmetric scan.
///
/// Builds the shared `dim * 256` per-coordinate LUT once
/// ([`build_b8_asym_lut`]), then dispatches to the AVX-512 gather kernel
/// ([`build_b8_asym_lut_into`]), then dispatches to the AVX-512 gather kernel
/// ([`scan_b8_asym_avx512_gather`]) when `avx512f` + `avx512bw` are detected at
/// runtime and `dim % 16 == 0`, falling back to the portable scalar reference
/// ([`scan_b8_to_topk`]) on every other target / CPU / dim. Centralising
Expand All @@ -569,7 +597,21 @@ pub(crate) fn scan_b8_asym(
scale: f32,
top: &mut TopK,
) {
let lut = build_b8_asym_lut(q_unit);
let mut lut = Vec::new();
scan_b8_asym_with_lut(packed, n, dim, q_unit, scale, top, &mut lut);
}

pub(crate) fn scan_b8_asym_with_lut(
packed: &[u8],
n: usize,
dim: usize,
q_unit: &[f32],
scale: f32,
top: &mut TopK,
lut: &mut Vec<f32>,
) {
assert_eq!(q_unit.len(), dim);
build_b8_asym_lut_into(lut, q_unit);
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f")
Expand All @@ -583,12 +625,12 @@ pub(crate) fn scan_b8_asym(
// above). The explicit block is required by
// `#![deny(unsafe_op_in_unsafe_fn)]`.
unsafe {
scan_b8_asym_avx512_gather(packed, n, dim, &lut, scale, top);
scan_b8_asym_avx512_gather(packed, n, dim, lut, scale, top);
}
return;
}
}
scan_b8_to_topk(packed, n, dim, &lut, scale, top);
scan_b8_to_topk(packed, n, dim, lut, scale, top);
}

// -------------------------------------------------------------------
Expand Down Expand Up @@ -652,7 +694,7 @@ pub(crate) unsafe fn scan_b8_asym_avx512_gather(
// Hard backstop (see `scan_b2_asym_avx2`): mis-dispatch must fail
// loudly in release, not silently drop the trailing chunk.
assert_eq!(dim % 16, 0, "b=8 AVX-512 gather path needs dim % 16 == 0");
debug_assert_eq!(lut.len(), dim * 256, "b=8 LUT must be dim * 256 entries");
assert_eq!(lut.len(), dim * 256, "b=8 LUT must be dim * 256 entries");
let bytes_per_vec = dim; // one byte per coordinate
let lut_ptr = lut.as_ptr();

Expand Down Expand Up @@ -724,7 +766,7 @@ pub(crate) unsafe fn scan_b8_asym_avx512_gather(

#[cfg(all(test, target_arch = "x86_64"))]
mod b8_gather_tests {
use super::{build_b8_asym_lut, scan_b8_asym_avx512_gather, scan_b8_to_topk};
use super::{build_b8_asym_lut_into, scan_b8_asym_avx512_gather, scan_b8_to_topk};
use crate::util::TopK;
use rand::{RngExt, SeedableRng};
use rand_chacha::ChaCha8Rng;
Expand All @@ -739,6 +781,12 @@ mod b8_gather_tests {
(scores, idxs)
}

/// Test convenience: returns a freshly allocated b=8 LUT for `q_unit`, built
/// by `build_b8_asym_lut_into`.
fn b8_lut(q_unit: &[f32]) -> Vec<f32> {
    let mut table = Vec::new();
    build_b8_asym_lut_into(&mut table, q_unit);
    table
}

/// The AVX-512 `vgatherdps` b=8 kernel must match the scalar LUT
/// reference within the crate's 1e-4 cross-backend score tolerance,
/// across the headline embedding dims (all `% 16 == 0`, so the gather
Expand All @@ -764,7 +812,7 @@ mod b8_gather_tests {
let q_unit: Vec<f32> = q.iter().map(|x| x / qn).collect();
let scale = 1.0f32 / 137.0; // arbitrary inv_norm-like scale

let lut = build_b8_asym_lut(&q_unit);
let lut = b8_lut(&q_unit);

let mut top_scalar = TopK::new(k);
scan_b8_to_topk(&packed, n, dim, &lut, scale, &mut top_scalar);
Expand Down Expand Up @@ -816,7 +864,7 @@ mod b8_gather_tests {
let mut rng = ChaCha8Rng::seed_from_u64(0x00B8_FACE);
let packed: Vec<u8> = (0..n * dim).map(|_| rng.random::<u8>()).collect();
let q_unit: Vec<f32> = (0..dim).map(|_| rng.random_range(-1.0..1.0)).collect();
let lut = build_b8_asym_lut(&q_unit);
let lut = b8_lut(&q_unit);

let mut top = TopK::new(k);
// SAFETY: avx512f+avx512bw confirmed; dim % 16 == 0; shapes match.
Expand Down Expand Up @@ -886,7 +934,7 @@ mod b8_gather_tests {
// b=4 corpus: two codes per byte → dim/2 bytes per doc.
let packed4: Vec<u8> = (0..n * dim / 2).map(|_| rng.random::<u8>()).collect();

let lut8 = build_b8_asym_lut(&q_unit);
let lut8 = b8_lut(&q_unit);

let bench = |label: &str, mut f: Box<dyn FnMut()>| {
f(); // warmup
Expand Down
Loading
Loading