Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 117 additions & 8 deletions src/sign_bitmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
//! scalar path. See [`crate::avx512vpop_supported`].

use rayon::prelude::*;
use std::collections::BinaryHeap;

use crate::OrdvecError;

Expand Down Expand Up @@ -220,13 +221,112 @@ impl SignBitmap {
/// SIMD dispatch paths — same audit discipline as
/// [`crate::Bitmap::top_m_candidates`].
#[must_use = "this scans the corpus to generate candidates; dropping the result discards that work"]
/// Streamed exact top-m selection shared by [`Self::top_m_candidates`]
/// and [`Self::top_m_candidates_batched_serial_csr`]: the corpus is
/// scanned once per call in L2-sized doc blocks, each hot block is
/// scored against every query (in small query tiles), and per-query
/// bounded min-m collectors keyed by `(hamming, doc_id)` select exactly
/// the lexicographic top-m — bit-identical to a full sort, independent
/// of processing order. Serial by contract: no rayon.
fn top_m_candidates_streamed(&self, queries: &[f32], m_eff: usize) -> Vec<Vec<u32>> {
const TILE_QUERIES: usize = 32;
const BLOCK_BYTES: usize = 256 * 1024;

let dim = self.dim;
debug_assert!(
queries.len().is_multiple_of(dim),
"queries buffer must be a whole number of rows"
);
let nq = queries.len() / dim;
let qpv = self.qwords_per_vec;
Comment thread
qodo-code-review[bot] marked this conversation as resolved.
let n = self.n_vectors;
debug_assert!(m_eff >= 1 && m_eff <= n);

// Build bitmaps in place: the entry points already validated the
// whole query buffer, and build_query_bitmap would allocate a fresh
// Vec (and re-validate) per query on this hot path.
let mut q_bitmaps = vec![0u64; nq * qpv];
for qi in 0..nq {
let q = &queries[qi * dim..(qi + 1) * dim];
let bm = &mut q_bitmaps[qi * qpv..(qi + 1) * qpv];
for (j, &value) in q.iter().enumerate() {
if value > 0.0 {
bm[j / 64] |= 1u64 << (j % 64);
}
}
}

let block_docs = (BLOCK_BYTES / (qpv * 8)).max(64).min(n);
let tile = TILE_QUERIES.min(nq);
let mut block_scores = vec![0u32; tile * block_docs];
// Max-heap keeps the current worst kept key at the top, so the
// retained set is always the m lexicographically smallest
// (hamming, doc_id) keys seen so far.
// Selection state is O(nq * m_eff) on top of the CSR output — an
// explicit checked bound (32-bit/wasm32 targets can overflow the
// multiplication) with a clear message, per the crate's
// checked-allocation discipline. Exact per-heap reservation of
// m_eff + 1 is deliberate: gradual growth would double-allocate to
// the next power of two (~2x m_eff peak per query); callers with
// extreme nq * m_eff should tile the query batch (as OrdinalDB's
// chunk scheduler does).
let selection_cells = nq.checked_mul(m_eff).unwrap_or_else(|| {
panic!("selection state nq ({nq}) * m ({m_eff}) overflows usize; tile the query batch")
});
let _ = selection_cells;
let mut heaps: Vec<BinaryHeap<(u32, u32)>> = (0..nq)
.map(|_| BinaryHeap::with_capacity(m_eff + 1))
.collect();
Comment thread
project-navi-bot marked this conversation as resolved.

let mut block_start = 0usize;
while block_start < n {
let bn = block_docs.min(n - block_start);
let block = &self.bitmaps[block_start * qpv..(block_start + bn) * qpv];
let mut tile_start = 0usize;
while tile_start < nq {
let tq = tile.min(nq - tile_start);
let qb_tile = &q_bitmaps[tile_start * qpv..(tile_start + tq) * qpv];
let scores = &mut block_scores[..tq * bn];
sign_scan_collect_batched(block, bn, qpv, qb_tile, tq, scores);
for ti in 0..tq {
let heap = &mut heaps[tile_start + ti];
let row = &scores[ti * bn..(ti + 1) * bn];
for (d, &hamming) in row.iter().enumerate() {
let key = (hamming, (block_start + d) as u32);
if heap.len() < m_eff {
heap.push(key);
} else if key < *heap.peek().expect("non-empty full collector") {
heap.pop();
heap.push(key);
}
}
}
tile_start += tq;
}
block_start += bn;
}

heaps
.into_iter()
.map(|heap| {
let mut kept = heap.into_vec();
kept.sort_unstable();
kept.into_iter().map(|(_, doc)| doc).collect()
})
.collect()
}

pub fn top_m_candidates(&self, q: &[f32], m: usize) -> Vec<u32> {
assert_eq!(q.len(), self.dim);
crate::util::assert_all_finite(q);
let m_eff = m.min(self.n_vectors);
if m_eff == 0 {
return Vec::new();
}
// Single-query stays on the dense partition path: with one query
// there is no scan to share, and select_nth_unstable_by (O(n)
// average) measurably beats an O(n log m) bounded heap for m in the
// hundreds at small/medium n (audit: +50-90% regression otherwise).
let qb = self.build_query_bitmap(q);
let mut scores = vec![0u32; self.n_vectors]; // Hamming distance per doc
sign_scan_collect(
Expand Down Expand Up @@ -313,10 +413,12 @@ impl SignBitmap {
/// pool. (The existing [`Self::top_m_candidates_batched`] remains the
/// internally-parallel standalone convenience.)
///
/// Track-1 implementation is intentionally naive — it loops the single-query
/// [`Self::top_m_candidates`] (which materialises a per-query `n` Hamming
/// row). A future release may replace the internals with streaming top-m
/// behind this frozen signature; the CSR output contract will not change.
/// The internals stream the corpus **once per call** in L2-sized doc
/// blocks, scoring every query of the call against each hot block and
/// selecting per-query top-m with bounded `(hamming, doc_id)` collectors
/// — per-query corpus traffic drops by the call's query count relative
/// to the historical per-query rescan. The CSR output contract is
/// unchanged and bit-identical to the previous implementation.
///
/// # Example
/// ```no_run
Expand Down Expand Up @@ -344,10 +446,17 @@ impl SignBitmap {
let m_eff = m.min(self.n_vectors);
let mut offsets = Vec::with_capacity(nq + 1);
offsets.push(0usize);
let mut candidates = Vec::with_capacity(nq.saturating_mul(m_eff));
for qi in 0..nq {
let q = &queries[qi * dim..(qi + 1) * dim];
let row = self.top_m_candidates(q, m);
let mut candidates = Vec::with_capacity(nq.checked_mul(m_eff).unwrap_or_else(|| {
panic!("CSR output nq ({nq}) * m ({m_eff}) overflows usize; tile the query batch")
}));
if nq == 0 || m_eff == 0 {
offsets.extend(std::iter::repeat_n(0usize, nq));
return CandidateBatch {
candidates,
offsets,
};
}
for row in self.top_m_candidates_streamed(queries, m_eff) {
candidates.extend_from_slice(&row);
offsets.push(candidates.len());
}
Expand Down
175 changes: 175 additions & 0 deletions tests/tiled_candgen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
//! Contract-pinning tests for sign candidate generation, written ahead of the
//! tiled internals swap of `top_m_candidates` /
//! `top_m_candidates_batched_serial_csr`. The oracle is independent of the
//! implementation under test: `score_all` (dense agreement counts) plus a
//! full lexicographic sort by `(hamming asc, doc_id asc)`. These tests pin
//! today's behavior exactly — including tie handling at the m-th position —
//! and must pass bit-identically before and after the swap.

use ordvec::SignBitmap;

/// Deterministic xorshift so corpora are reproducible without a rand dep.
struct XorShift(u64);

impl XorShift {
fn next_f32(&mut self) -> f32 {
self.0 ^= self.0 << 13;
self.0 ^= self.0 >> 7;
self.0 ^= self.0 << 17;
// Map to [-1, 1) with plenty of sign variety.
((self.0 >> 40) as f32 / 8_388_608.0) - 1.0
}
}

fn random_corpus(dim: usize, n: usize, seed: u64) -> Vec<f32> {
let mut rng = XorShift(seed | 1);
(0..n * dim).map(|_| rng.next_f32()).collect()
}

/// Tie-heavy corpus: every coordinate is +/-1 drawn from a tiny pattern set,
/// so hamming distances collide massively and the (hamming, doc_id)
/// tie-break does real work at the selection boundary.
fn tie_heavy_corpus(dim: usize, n: usize) -> Vec<f32> {
(0..n)
.flat_map(|doc| {
let pattern = doc % 4;
(0..dim).map(move |c| if (c + pattern) % 3 == 0 { -1.0 } else { 1.0 })
})
.collect()
}

fn oracle_top_m(sign: &SignBitmap, q: &[f32], m: usize) -> Vec<u32> {
let dim_u32 = u32::try_from(q.len()).unwrap();
// score_all returns agreement (dim - hamming), higher is better.
let agreements = sign.score_all(q);
let mut ids: Vec<u32> = (0..agreements.len() as u32).collect();
ids.sort_by_key(|&i| (dim_u32 - agreements[i as usize], i));
ids.truncate(m.min(agreements.len()));
ids
}

fn assert_contract(dim: usize, vectors: &[f32], queries: &[f32], m: usize, label: &str) {
let mut sign = SignBitmap::new(dim);
sign.add(vectors);
let nq = queries.len() / dim;

// Single-query path.
for qi in 0..nq {
let q = &queries[qi * dim..(qi + 1) * dim];
let got = sign.top_m_candidates(q, m);
let want = oracle_top_m(&sign, q, m);
assert_eq!(
got, want,
"{label}: single-query mismatch at query {qi}, m={m}"
);
}

// Batched serial CSR path: row qi must equal the single-query result.
let cb = sign.top_m_candidates_batched_serial_csr(queries, m);
assert_eq!(cb.offsets.len(), nq + 1, "{label}: CSR offsets length");
for qi in 0..nq {
let row = &cb.candidates[cb.offsets[qi]..cb.offsets[qi + 1]];
let want = oracle_top_m(&sign, &queries[qi * dim..(qi + 1) * dim], m);
assert_eq!(
row,
&want[..],
"{label}: CSR row mismatch at query {qi}, m={m}"
);
}
}

/// Random corpus large enough to span many doc blocks under any plausible
/// tile size, at a SIMD-friendly dim.
#[test]
fn random_corpus_matches_oracle_across_block_boundaries() {
// dim=512 -> 8 qwords/vec -> 4096-doc blocks; n=10240 spans three
// blocks including a final partial one (audit: the previous dim=128
// shape fit in a single block, so the loop never crossed a boundary).
let dim = 512;
let n = 10_240;
let vectors = random_corpus(dim, n, 0xC0FFEE);
let queries = random_corpus(dim, 33, 0xBEEF);
for m in [1, 7, 256, 500] {
assert_contract(dim, &vectors, &queries, m, "random");
}
}

/// Massive hamming ties: selection at the boundary is decided purely by
/// doc_id ascending. This is the case a streaming collector most easily gets
/// subtly wrong.
#[test]
fn tie_heavy_corpus_selects_lowest_doc_ids_at_boundary() {
let dim = 64;
let n = 4_096;
let vectors = tie_heavy_corpus(dim, n);
let queries = random_corpus(dim, 9, 0xABCD);
for m in [1, 3, 100, 1_000] {
assert_contract(dim, &vectors, &queries, m, "tie-heavy");
}
}

/// Exact duplicate documents: every duplicate group is one giant tie run,
/// longer than m, exercising equal-hamming runs that exceed the collector.
#[test]
fn duplicate_documents_tie_runs_longer_than_m() {
let dim = 64;
let base = random_corpus(dim, 8, 0x1234);
// 8 distinct vectors, each repeated 512 times => tie runs of 512.
let mut vectors = Vec::with_capacity(8 * 512 * dim);
for rep in 0..512 {
let _ = rep;
vectors.extend_from_slice(&base);
}
let queries = random_corpus(dim, 5, 0x9999);
for m in [10, 100, 513] {
assert_contract(dim, &vectors, &queries, m, "duplicates");
}
}

/// Edge geometry: m >= n, m == n, single doc, single query, nq == 0.
#[test]
fn edge_geometries_match_oracle() {
let dim = 64;
let vectors = random_corpus(dim, 17, 0x42);
let queries = random_corpus(dim, 3, 0x43);
for m in [17, 25, 1] {
assert_contract(dim, &vectors, &queries, m, "edge");
}

let single_doc = random_corpus(dim, 1, 0x77);
assert_contract(dim, &single_doc, &queries, 4, "single-doc");

// Empty query batch: CSR must be a single zero offset and no candidates.
let mut sign = SignBitmap::new(dim);
sign.add(&vectors);
let cb = sign.top_m_candidates_batched_serial_csr(&[], 8);
assert_eq!(cb.offsets, vec![0]);
assert!(cb.candidates.is_empty());
}

/// Large-dim smoke at the shape the arXiv corpus uses (1024 dims), enough
/// rows to cross several L2-sized doc blocks.
#[test]
fn dim_1024_shape_matches_oracle() {
let dim = 1024;
let n = 6_000;
let vectors = random_corpus(dim, n, 0xA5A5);
let queries = random_corpus(dim, 8, 0x5A5A);
for m in [256, 320] {
assert_contract(dim, &vectors, &queries, m, "dim1024");
}
}

/// AVX-512 tail residue (dim=768 -> qpv=12, rem=4) composed with
/// multi-block crossing and a final partial block — the kernel-shape case
/// the audit flagged as untested in the permanent suite.
#[test]
fn dim_768_tail_residue_crosses_blocks() {
let dim = 768;
let n = 3_200; // block_docs = 262144/96 = 2730 -> 2 blocks, partial tail
let vectors = random_corpus(dim, n, 0x7E57);
let queries = random_corpus(dim, 7, 0x7E58);
for m in [64, 320] {
assert_contract(dim, &vectors, &queries, m, "dim768-tail");
}
}
Loading