diff --git a/src/sign_bitmap.rs b/src/sign_bitmap.rs index 66f971a..8d2e764 100644 --- a/src/sign_bitmap.rs +++ b/src/sign_bitmap.rs @@ -39,6 +39,7 @@ //! scalar path. See [`crate::avx512vpop_supported`]. use rayon::prelude::*; +use std::collections::BinaryHeap; use crate::OrdvecError; @@ -220,6 +221,101 @@ impl SignBitmap { /// SIMD dispatch paths — same audit discipline as /// [`crate::Bitmap::top_m_candidates`]. #[must_use = "this scans the corpus to generate candidates; dropping the result discards that work"] + /// Streamed exact top-m selection shared by [`Self::top_m_candidates`] + /// and [`Self::top_m_candidates_batched_serial_csr`]: the corpus is + /// scanned once per call in L2-sized doc blocks, each hot block is + /// scored against every query (in small query tiles), and per-query + /// bounded min-m collectors keyed by `(hamming, doc_id)` select exactly + /// the lexicographic top-m — bit-identical to a full sort, independent + /// of processing order. Serial by contract: no rayon. + fn top_m_candidates_streamed(&self, queries: &[f32], m_eff: usize) -> Vec> { + const TILE_QUERIES: usize = 32; + const BLOCK_BYTES: usize = 256 * 1024; + + let dim = self.dim; + debug_assert!( + queries.len().is_multiple_of(dim), + "queries buffer must be a whole number of rows" + ); + let nq = queries.len() / dim; + let qpv = self.qwords_per_vec; + let n = self.n_vectors; + debug_assert!(m_eff >= 1 && m_eff <= n); + + // Build bitmaps in place: the entry points already validated the + // whole query buffer, and build_query_bitmap would allocate a fresh + // Vec (and re-validate) per query on this hot path. + let mut q_bitmaps = vec![0u64; nq * qpv]; + for qi in 0..nq { + let q = &queries[qi * dim..(qi + 1) * dim]; + let bm = &mut q_bitmaps[qi * qpv..(qi + 1) * qpv]; + for (j, &value) in q.iter().enumerate() { + if value > 0.0 { + bm[j / 64] |= 1u64 << (j % 64); + } + } + } + + let block_docs = (BLOCK_BYTES / (qpv * 8)).max(64).min(n); + let tile = TILE_QUERIES.min(nq); + let mut block_scores = vec![0u32; tile * block_docs]; + // Max-heap keeps the current worst kept key at the top, so the + // retained set is always the m lexicographically smallest + // (hamming, doc_id) keys seen so far. + // Selection state is O(nq * m_eff) on top of the CSR output — an + // explicit checked bound (32-bit/wasm32 targets can overflow the + // multiplication) with a clear message, per the crate's + // checked-allocation discipline. Exact per-heap reservation of + // m_eff + 1 is deliberate: gradual growth would double-allocate to + // the next power of two (~2x m_eff peak per query); callers with + // extreme nq * m_eff should tile the query batch (as OrdinalDB's + // chunk scheduler does). + let selection_cells = nq.checked_mul(m_eff).unwrap_or_else(|| { + panic!("selection state nq ({nq}) * m ({m_eff}) overflows usize; tile the query batch") + }); + let _ = selection_cells; + let mut heaps: Vec> = (0..nq) + .map(|_| BinaryHeap::with_capacity(m_eff + 1)) + .collect(); + + let mut block_start = 0usize; + while block_start < n { + let bn = block_docs.min(n - block_start); + let block = &self.bitmaps[block_start * qpv..(block_start + bn) * qpv]; + let mut tile_start = 0usize; + while tile_start < nq { + let tq = tile.min(nq - tile_start); + let qb_tile = &q_bitmaps[tile_start * qpv..(tile_start + tq) * qpv]; + let scores = &mut block_scores[..tq * bn]; + sign_scan_collect_batched(block, bn, qpv, qb_tile, tq, scores); + for ti in 0..tq { + let heap = &mut heaps[tile_start + ti]; + let row = &scores[ti * bn..(ti + 1) * bn]; + for (d, &hamming) in row.iter().enumerate() { + let key = (hamming, (block_start + d) as u32); + if heap.len() < m_eff { + heap.push(key); + } else if key < *heap.peek().expect("non-empty full collector") { + heap.pop(); + heap.push(key); + } + } + } + tile_start += tq; + } + block_start += bn; + } + + heaps + .into_iter() + .map(|heap| { + let mut kept = heap.into_vec(); + kept.sort_unstable(); + kept.into_iter().map(|(_, doc)| doc).collect() + }) + .collect() + } + pub fn top_m_candidates(&self, q: &[f32], m: usize) -> Vec { assert_eq!(q.len(), self.dim); crate::util::assert_all_finite(q); @@ -227,6 +323,10 @@ impl SignBitmap { if m_eff == 0 { return Vec::new(); } + // Single-query stays on the dense partition path: with one query + // there is no scan to share, and select_nth_unstable_by (O(n) + // average) measurably beats an O(n log m) bounded heap for m in the + // hundreds at small/medium n (audit: +50-90% regression otherwise). let qb = self.build_query_bitmap(q); let mut scores = vec![0u32; self.n_vectors]; // Hamming distance per doc sign_scan_collect( @@ -313,10 +413,12 @@ impl SignBitmap { /// pool. (The existing [`Self::top_m_candidates_batched`] remains the /// internally-parallel standalone convenience.) /// - /// Track-1 implementation is intentionally naive — it loops the single-query - /// [`Self::top_m_candidates`] (which materialises a per-query `n` Hamming - /// row). A future release may replace the internals with streaming top-m - /// behind this frozen signature; the CSR output contract will not change. + /// The internals stream the corpus **once per call** in L2-sized doc + /// blocks, scoring every query of the call against each hot block and + /// selecting per-query top-m with bounded `(hamming, doc_id)` collectors + /// — per-query corpus traffic drops by the call's query count relative + /// to the historical per-query rescan. The CSR output contract is + /// unchanged and bit-identical to the previous implementation. /// /// # Example /// ```no_run @@ -344,10 +446,17 @@ impl SignBitmap { let m_eff = m.min(self.n_vectors); let mut offsets = Vec::with_capacity(nq + 1); offsets.push(0usize); - let mut candidates = Vec::with_capacity(nq.saturating_mul(m_eff)); - for qi in 0..nq { - let q = &queries[qi * dim..(qi + 1) * dim]; - let row = self.top_m_candidates(q, m); + let mut candidates = Vec::with_capacity(nq.checked_mul(m_eff).unwrap_or_else(|| { + panic!("CSR output nq ({nq}) * m ({m_eff}) overflows usize; tile the query batch") + })); + if nq == 0 || m_eff == 0 { + offsets.extend(std::iter::repeat_n(0usize, nq)); + return CandidateBatch { + candidates, + offsets, + }; + } + for row in self.top_m_candidates_streamed(queries, m_eff) { candidates.extend_from_slice(&row); offsets.push(candidates.len()); } diff --git a/tests/tiled_candgen.rs b/tests/tiled_candgen.rs new file mode 100644 index 0000000..33ac414 --- /dev/null +++ b/tests/tiled_candgen.rs @@ -0,0 +1,175 @@ +//! Contract-pinning tests for sign candidate generation, written ahead of the +//! tiled internals swap of `top_m_candidates` / +//! `top_m_candidates_batched_serial_csr`. The oracle is independent of the +//! implementation under test: `score_all` (dense agreement counts) plus a +//! full lexicographic sort by `(hamming asc, doc_id asc)`. These tests pin +//! today's behavior exactly — including tie handling at the m-th position — +//! and must pass bit-identically before and after the swap. + +use ordvec::SignBitmap; + +/// Deterministic xorshift so corpora are reproducible without a rand dep. +struct XorShift(u64); + +impl XorShift { + fn next_f32(&mut self) -> f32 { + self.0 ^= self.0 << 13; + self.0 ^= self.0 >> 7; + self.0 ^= self.0 << 17; + // Map to [-1, 1) with plenty of sign variety. + ((self.0 >> 40) as f32 / 8_388_608.0) - 1.0 + } +} + +fn random_corpus(dim: usize, n: usize, seed: u64) -> Vec { + let mut rng = XorShift(seed | 1); + (0..n * dim).map(|_| rng.next_f32()).collect() +} + +/// Tie-heavy corpus: every coordinate is +/-1 drawn from a tiny pattern set, +/// so hamming distances collide massively and the (hamming, doc_id) +/// tie-break does real work at the selection boundary. +fn tie_heavy_corpus(dim: usize, n: usize) -> Vec { + (0..n) + .flat_map(|doc| { + let pattern = doc % 4; + (0..dim).map(move |c| if (c + pattern) % 3 == 0 { -1.0 } else { 1.0 }) + }) + .collect() +} + +fn oracle_top_m(sign: &SignBitmap, q: &[f32], m: usize) -> Vec { + let dim_u32 = u32::try_from(q.len()).unwrap(); + // score_all returns agreement (dim - hamming), higher is better. + let agreements = sign.score_all(q); + let mut ids: Vec = (0..agreements.len() as u32).collect(); + ids.sort_by_key(|&i| (dim_u32 - agreements[i as usize], i)); + ids.truncate(m.min(agreements.len())); + ids +} + +fn assert_contract(dim: usize, vectors: &[f32], queries: &[f32], m: usize, label: &str) { + let mut sign = SignBitmap::new(dim); + sign.add(vectors); + let nq = queries.len() / dim; + + // Single-query path. + for qi in 0..nq { + let q = &queries[qi * dim..(qi + 1) * dim]; + let got = sign.top_m_candidates(q, m); + let want = oracle_top_m(&sign, q, m); + assert_eq!( + got, want, + "{label}: single-query mismatch at query {qi}, m={m}" + ); + } + + // Batched serial CSR path: row qi must equal the single-query result. + let cb = sign.top_m_candidates_batched_serial_csr(queries, m); + assert_eq!(cb.offsets.len(), nq + 1, "{label}: CSR offsets length"); + for qi in 0..nq { + let row = &cb.candidates[cb.offsets[qi]..cb.offsets[qi + 1]]; + let want = oracle_top_m(&sign, &queries[qi * dim..(qi + 1) * dim], m); + assert_eq!( + row, + &want[..], + "{label}: CSR row mismatch at query {qi}, m={m}" + ); + } +} + +/// Random corpus large enough to span many doc blocks under any plausible +/// tile size, at a SIMD-friendly dim. +#[test] +fn random_corpus_matches_oracle_across_block_boundaries() { + // dim=512 -> 8 qwords/vec -> 4096-doc blocks; n=10240 spans three + // blocks including a final partial one (audit: the previous dim=128 + // shape fit in a single block, so the loop never crossed a boundary). + let dim = 512; + let n = 10_240; + let vectors = random_corpus(dim, n, 0xC0FFEE); + let queries = random_corpus(dim, 33, 0xBEEF); + for m in [1, 7, 256, 500] { + assert_contract(dim, &vectors, &queries, m, "random"); + } +} + +/// Massive hamming ties: selection at the boundary is decided purely by +/// doc_id ascending. This is the case a streaming collector most easily gets +/// subtly wrong. +#[test] +fn tie_heavy_corpus_selects_lowest_doc_ids_at_boundary() { + let dim = 64; + let n = 4_096; + let vectors = tie_heavy_corpus(dim, n); + let queries = random_corpus(dim, 9, 0xABCD); + for m in [1, 3, 100, 1_000] { + assert_contract(dim, &vectors, &queries, m, "tie-heavy"); + } +} + +/// Exact duplicate documents: every duplicate group is one giant tie run, +/// longer than m, exercising equal-hamming runs that exceed the collector. +#[test] +fn duplicate_documents_tie_runs_longer_than_m() { + let dim = 64; + let base = random_corpus(dim, 8, 0x1234); + // 8 distinct vectors, each repeated 512 times => tie runs of 512. + let mut vectors = Vec::with_capacity(8 * 512 * dim); + for rep in 0..512 { + let _ = rep; + vectors.extend_from_slice(&base); + } + let queries = random_corpus(dim, 5, 0x9999); + for m in [10, 100, 513] { + assert_contract(dim, &vectors, &queries, m, "duplicates"); + } +} + +/// Edge geometry: m >= n, m == n, single doc, single query, nq == 0. +#[test] +fn edge_geometries_match_oracle() { + let dim = 64; + let vectors = random_corpus(dim, 17, 0x42); + let queries = random_corpus(dim, 3, 0x43); + for m in [17, 25, 1] { + assert_contract(dim, &vectors, &queries, m, "edge"); + } + + let single_doc = random_corpus(dim, 1, 0x77); + assert_contract(dim, &single_doc, &queries, 4, "single-doc"); + + // Empty query batch: CSR must be a single zero offset and no candidates. + let mut sign = SignBitmap::new(dim); + sign.add(&vectors); + let cb = sign.top_m_candidates_batched_serial_csr(&[], 8); + assert_eq!(cb.offsets, vec![0]); + assert!(cb.candidates.is_empty()); +} + +/// Large-dim smoke at the shape the arXiv corpus uses (1024 dims), enough +/// rows to cross several L2-sized doc blocks. +#[test] +fn dim_1024_shape_matches_oracle() { + let dim = 1024; + let n = 6_000; + let vectors = random_corpus(dim, n, 0xA5A5); + let queries = random_corpus(dim, 8, 0x5A5A); + for m in [256, 320] { + assert_contract(dim, &vectors, &queries, m, "dim1024"); + } +} + +/// AVX-512 tail residue (dim=768 -> qpv=12, rem=4) composed with +/// multi-block crossing and a final partial block — the kernel-shape case +/// the audit flagged as untested in the permanent suite. +#[test] +fn dim_768_tail_residue_crosses_blocks() { + let dim = 768; + let n = 3_200; // block_docs = 262144/96 = 2730 -> 2 blocks, partial tail + let vectors = random_corpus(dim, n, 0x7E57); + let queries = random_corpus(dim, 7, 0x7E58); + for m in [64, 320] { + assert_contract(dim, &vectors, &queries, m, "dim768-tail"); + } +}