From 1f014a5d3f3415e2abbd6cd97d3e3009a010ad7b Mon Sep 17 00:00:00 2001 From: Nelson Spence Date: Fri, 3 Jul 2026 14:39:41 -0500 Subject: [PATCH] perf: LUT + parallel constant-composition check on RankQuant load MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit load_rankquant's forged-buffer defense histogrammed every packed code serially — 1.29 billion shift/mask ops at 1.26M x 1024, ~1s of the 1.27s verified open. A 4KB per-byte bucket-count LUT replaces the per-code inner loop and rows validate in parallel; find_first keeps the lowest-offending-row error contract, with a scalar recheck producing the identical message. The security property is unchanged: every row still proves uniform composition before the index is usable. --- src/rank_io.rs | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/rank_io.rs b/src/rank_io.rs index e05505c..d0b2957 100644 --- a/src/rank_io.rs +++ b/src/rank_io.rs @@ -765,8 +765,39 @@ fn load_rankquant_from_stream( let expected_per_bucket = dim / n_buckets; let mask = (1u8 << bits) - 1; let bits_u = bits as usize; - for (row_idx, row) in packed.chunks_exact(bytes_per_row).enumerate() { - let mut hist = [0usize; 16]; // n_buckets <= 2^4 = 16 + // Per-byte bucket-count LUT: byte value -> how many of its packed codes + // land in each bucket. Replaces the per-code shift/mask loop (dim ops + // per row) with bytes_per_row table lookups, and rows check in parallel + // (they are independent). `find_first` preserves the serial contract of + // reporting the lowest offending row. + let mut lut = [[0u8; 16]; 256]; + for (byte, counts) in lut.iter_mut().enumerate() { + for slot in 0..codes_per_byte { + let shift = (codes_per_byte - 1 - slot) * bits_u; + counts[((byte as u8 >> shift) & mask) as usize] += 1; + } + } + let row_is_valid = |row: &[u8]| { + let mut hist = [0u16; 16]; + for &byte in row { + let counts = &lut[byte as usize]; + for bucket in 0..n_buckets { + hist[bucket] += u16::from(counts[bucket]); + } + } + hist[..n_buckets] + .iter() + .all(|&count| count as usize == expected_per_bucket) + }; + use rayon::prelude::*; + let first_bad = (0..n_vectors).into_par_iter().find_first(|&row_idx| { + !row_is_valid(&packed[row_idx * bytes_per_row..(row_idx + 1) * bytes_per_row]) + }); + if let Some(row_idx) = first_bad { + // Rerun the scalar histogram on the offending row for the exact + // bucket/count in the error message. + let row = &packed[row_idx * bytes_per_row..(row_idx + 1) * bytes_per_row]; + let mut hist = [0usize; 16]; for &byte in row { for slot in 0..codes_per_byte { let shift = (codes_per_byte - 1 - slot) * bits_u; @@ -781,6 +812,7 @@ fn load_rankquant_from_stream( ))); } } + unreachable!("row {row_idx} failed the LUT check but passed the scalar recheck"); } Ok((bits, dim, n_vectors, packed)) }