diff --git a/v2/crates/wifi-densepose-ruvector/src/sketch.rs b/v2/crates/wifi-densepose-ruvector/src/sketch.rs index 10aead72..9045d2c4 100644 --- a/v2/crates/wifi-densepose-ruvector/src/sketch.rs +++ b/v2/crates/wifi-densepose-ruvector/src/sketch.rs @@ -41,6 +41,8 @@ //! embeddings is `Sketch::from_embedding`. use ruvector_core::quantization::{BinaryQuantized, QuantizedVector}; +use std::cmp::Reverse; +use std::collections::BinaryHeap; /// Errors raised by the sketch API. #[derive(Debug, thiserror::Error)] @@ -295,17 +297,47 @@ impl SketchBank { }); } } - // O(n log k) using a partial sort; for small k (typical k = 8 to 64) - // and bank sizes up to a few thousand sketches, the simple sort-all - // approach is faster in practice (cache-friendly) and easier to audit. - // Switch to a max-heap if profiling shows this becomes a hot spot. - let mut scored: Vec<(u32, u32)> = self - .entries - .iter() - .map(|(id, sk)| (*id, sk.distance_unchecked(query))) + // Pass-1.5 optimisation: O(n log k) partial sort via a fixed-size + // max-heap of `Reverse((distance, id))`. The heap's `peek()` + // returns the *largest* of the current best-k. Each candidate is + // compared against the heap top in O(1); only better candidates + // trigger an O(log k) push/pop. Avoids touching the long tail of + // large-distance entries that the truncate would have discarded. + // + // Fast path: when n ≤ k there is nothing to discard, so a plain + // collect + sort is faster than building a heap. + let n = self.entries.len(); + if n <= k { + let mut scored: Vec<(u32, u32)> = self + .entries + .iter() + .map(|(id, sk)| (*id, sk.distance_unchecked(query))) + .collect(); + scored.sort_by_key(|&(_, d)| d); + return Ok(scored); + } + + let mut heap: BinaryHeap> = BinaryHeap::with_capacity(k + 1); + for (id, sk) in &self.entries { + let d = sk.distance_unchecked(query); + if heap.len() < k { + heap.push(Reverse((d, *id))); + } else { + // Safe: heap has exactly k > 0 elements, just checked. + let worst = heap.peek().expect("heap len == k > 0").0 .0; + if d < worst { + heap.pop(); + heap.push(Reverse((d, *id))); + } + } + } + // Drain heap into a Vec — already in (Reverse) descending order; + // sort to expose ascending-by-distance per the public contract. + let mut scored: Vec<(u32, u32)> = heap + .into_iter() + .map(|Reverse((d, id))| (id, d)) .collect(); scored.sort_by_key(|&(_, d)| d); - scored.truncate(k); Ok(scored) }