wifi-densepose/v2/crates/wifi-densepose-ruvector/src/rotation.rs

//! RaBitQ **Pass 2** — deterministic randomized orthogonal rotation.
//!
//! Implements the "Pass 2" deferred in [`crate::sketch`]'s Pass-1 doc and in
//! [ADR-156 §8](../../../../../docs/adr/ADR-156-ruvector-fusion-beyond-sota.md)
//! (Multi-bit / Extended RaBitQ). The published *RaBitQ* algorithm
//! (Gao & Long, SIGMOD 2024) wraps the 1-bit sign-quantization of Pass 1 with
//! a **randomized orthogonal rotation** `R` applied to every embedding *before*
//! sign-quantization. The rotation decorrelates coordinates so the per-bit sign
//! carries more independent information, which gives both the paper's
//! theoretical error bound and better top-K recall on anisotropic / correlated
//! embedding distributions (exactly the case ADR-084's "Open questions" flagged
//! for skewed spectrogram embeddings).
//!
//! # Why a Fast Hadamard Transform, not a dense d×d matrix
//!
//! A full dense orthogonal matrix `R ∈ ℝ^{d×d}` is **O(d²) memory and O(d²)
//! time per vector**. ADR-084's wire format already provisions for embeddings
//! up to `u16::MAX = 65,535` dimensions; a dense `f32` rotation there is ~4.3 G
//! floats (~17.2 GB, just under 16 GiB), completely infeasible on the
//! cluster-Pi / edge targets this sketch is built for.
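//!
//! As a rough check of those figures, the byte counts can be computed directly.
//! This is a hypothetical sketch, not part of this crate. It assumes `f32`
//! storage both for the dense matrix and for the sign diagonal, which is how
//! [`Rotation`] stores its signs:
//!
//! ```rust
//! // Hypothetical helper: bytes for a dense `d × d` rotation matrix of `f32`.
//! fn dense_rotation_bytes(d: u64) -> u64 {
//!     d * d * std::mem::size_of::<f32>() as u64
//! }
//!
//! // Hypothetical helper: bytes for the HD sign diagonal, one `f32` per
//! // padded dimension (padded length = next power of two, at least 1).
//! fn hd_sign_bytes(d: u64) -> u64 {
//!     d.max(1).next_power_of_two() * std::mem::size_of::<f32>() as u64
//! }
//! ```
//!
//! For `d = 65_535` the dense matrix needs 17,179,344,900 bytes (just under
//! 16 GiB). The sign diagonal needs 65,536 × 4 = 262,144 bytes.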
//!
//! Instead we use the **randomized Hadamard transform** (the "HD" construction,
//! a.k.a. a structured Johnson-Lindenstrauss / fast-JL rotation):
//!
//! ```text
//! R · x = H · D · x
//! ```
//!
//! where `D` is a diagonal matrix of random ±1 sign flips and `H` is the
//! (normalized) Walsh-Hadamard matrix applied via the **Fast Hadamard
//! Transform (FHT)**. The FHT is `O(d log d)` time and `O(1)` extra memory
//! (in-place butterfly); `D` is `O(d)` memory (one `f32` ±1 per padded
//! dimension).
//! `H` and `D` are each orthogonal, so `R = H·D` is orthogonal and therefore
//! **norm-preserving** — a hard requirement for a rotation that must not distort
//! relative distances. This is the same fast-orthogonal trick used by Fast-JL,
//! Structured Orthogonal Random Features, and the RaBitQ reference rotation.
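//!
//! The following is a minimal standalone sketch of `R · x = H · D · x` on a
//! power-of-two length. The helper names are hypothetical, and unlike
//! [`Rotation`], the ±1 signs are passed in rather than derived from a seed.
//! The sketch flips signs, runs the same iterative butterfly as this module's
//! `fht_normalized`, and scales by `1/√m`:
//!
//! ```rust
//! // Hypothetical helper: normalized in-place Walsh-Hadamard transform on a
//! // power-of-two-length slice (iterative butterfly, then 1/√m scaling).
//! fn fht_sketch(buf: &mut [f32]) {
//!     let m = buf.len();
//!     assert!(m.is_power_of_two(), "FHT length must be a power of two");
//!     let mut h = 1;
//!     while h < m {
//!         for i in (0..m).step_by(2 * h) {
//!             for j in i..i + h {
//!                 let (x, y) = (buf[j], buf[j + h]);
//!                 buf[j] = x + y;
//!                 buf[j + h] = x - y;
//!             }
//!         }
//!         h *= 2;
//!     }
//!     let scale = 1.0 / (m as f32).sqrt();
//!     buf.iter_mut().for_each(|v| *v *= scale);
//! }
//!
//! // Hypothetical helper: R · x = H · (D · x), with the ±1 diagonal `signs`
//! // supplied by the caller instead of derived from a seed.
//! fn hd_rotate_sketch(x: &[f32], signs: &[f32]) -> Vec<f32> {
//!     assert_eq!(x.len(), signs.len(), "one sign per coordinate");
//!     let mut buf: Vec<f32> = x.iter().zip(signs).map(|(a, s)| a * s).collect();
//!     fht_sketch(&mut buf);
//!     buf
//! }
//!
//! // L2 norm, used to check that the rotation preserves length.
//! fn l2(v: &[f32]) -> f32 {
//!     v.iter().map(|x| x * x).sum::<f32>().sqrt()
//! }
//! ```
//!
//! For `x = [3, 4, 0, 0]` and signs `[+1, -1, +1, -1]`, this returns
//! `[-0.5, 3.5, -0.5, 3.5]`. Its norm is 5, the same as the input's.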
//!
//! # Determinism (index-time == query-time)
//!
//! The rotation **must** be identical when the bank is built and when it is
//! queried, or the two sign-quantizations live in different rotated frames and
//! Hamming distance becomes meaningless. We therefore derive the ±1 sign flips
//! deterministically from a stored `u64` seed via a SplitMix64 PRNG — **never**
//! an unseeded / OS RNG. Two [`Rotation`]s built from the same `(seed, dim)`
//! produce bit-identical output for the same input (pinned by
//! `rotation_is_deterministic_for_seed`).
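//!
//! The seed-to-signs step can be sketched standalone as below. It copies the
//! SplitMix64 constants from this module and mirrors the loop in
//! [`Rotation::new`], including writing each mixed word back into the state.
//! The helper names are hypothetical:
//!
//! ```rust
//! // Copy of this module's SplitMix64 step (same constants).
//! fn split_mix64(state: &mut u64) -> u64 {
//!     *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
//!     let mut z = *state;
//!     z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
//!     z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
//!     z ^ (z >> 31)
//! }
//!
//! // Hypothetical helper: derive `len` ±1 signs from `seed` the way
//! // `Rotation::new` does. Each mixed word is fed back as the next state,
//! // and its top bit picks the sign.
//! fn signs_from_seed(seed: u64, len: usize) -> Vec<f32> {
//!     let mut state = seed;
//!     (0..len)
//!         .map(|_| {
//!             state = split_mix64(&mut state);
//!             if state >> 63 == 1 { 1.0 } else { -1.0 }
//!         })
//!         .collect()
//! }
//! ```
//!
//! The stream depends only on `seed`. Rebuilding from the same seed therefore
//! reproduces the signs exactly, and the signs for a shorter length are a
//! prefix of those for a longer one.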
//!
//! # Power-of-two padding
//!
//! The FHT is defined on lengths that are powers of two. For a `d` that is not
//! a power of two we pad the (sign-flipped) input with zeros up to the next
//! power of two `m = next_pow2(d)`, run the length-`m` FHT, and then **read back
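//! the first `d` coordinates**. Zero-padding plus an orthogonal `H` keeps the
//! transform norm-preserving on the full padded vector. The first `d`
//! coordinates alone, however, are only non-expansive: their norm is at most
//! the input norm, because the FHT can spread energy into the tail. We
//! sign-quantize those first `d` rotated coordinates so the sketch dimension
//! is unchanged from Pass 1 (API-compatible: same `embedding_dim`, same
//! packed-byte length, same `SketchBank` schema). [`Rotation::apply_padded`]
//! keeps all `m` coordinates for callers that need the full frame.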
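//!
//! To see why only the full padded vector keeps its norm, here is a small
//! sketch under stated assumptions. Instead of the FHT, it applies the dense
//! `O(m²)` Hadamard definition `H[i][j] = (-1)^popcount(i & j) / √m`, which
//! matches the ordering the iterative butterfly produces. The helper names
//! are hypothetical:
//!
//! ```rust
//! // Hypothetical helper: sign-flip `x` with `signs`, zero-pad to the next
//! // power of two `m`, and apply the normalized Hadamard matrix densely
//! // (O(m²), for illustration only). Returns all `m` rotated coordinates.
//! fn rotate_padded_dense(x: &[f32], signs: &[f32]) -> Vec<f32> {
//!     assert!(signs.len() >= x.len(), "need one sign per input coordinate");
//!     let m = x.len().max(1).next_power_of_two();
//!     let scale = 1.0 / (m as f32).sqrt();
//!     (0..m)
//!         .map(|i| {
//!             // Only the first x.len() inputs are non-zero; padding adds nothing.
//!             let dot: f32 = x
//!                 .iter()
//!                 .zip(signs)
//!                 .enumerate()
//!                 .map(|(j, (&xj, &sj))| {
//!                     let h = if (i & j).count_ones() % 2 == 0 { 1.0 } else { -1.0 };
//!                     h * sj * xj
//!                 })
//!                 .sum();
//!             dot * scale
//!         })
//!         .collect()
//! }
//!
//! // L2 norm.
//! fn l2(v: &[f32]) -> f32 {
//!     v.iter().map(|x| x * x).sum::<f32>().sqrt()
//! }
//! ```
//!
//! Take `x = [1, 0, 0]` with all-`+1` signs. Then `m = 4` and the padded
//! output is `[0.5, 0.5, 0.5, 0.5]` (norm 1), but its first three coordinates
//! have norm `√0.75 ≈ 0.866`.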
/// A deterministic randomized orthogonal rotation (FHT-based) applied to an
/// embedding before sign-quantization — RaBitQ Pass 2.
///
/// Construct once per `(seed, dim)` and reuse for **every** embedding that goes
/// into the same [`crate::SketchBank`] (and for every query against it). The
/// seed is stored so the rotation is reproducible across processes and runs.
///
/// # Invariants
///
/// - `dim` is the source-embedding dimension (the sketch keeps this dimension).
/// - `padded` is `next_pow2(dim)` — the FHT working length.
/// - `signs` has exactly `padded` entries (`+1.0` / `-1.0`), derived from
///   `seed` via SplitMix64. Padding positions get signs too; they only ever
///   multiply zeros, so their value is irrelevant to the result, but they
///   keep the construction uniform.
#[derive(Debug, Clone)]
pub struct Rotation {
    /// Source-embedding dimension; the rotated sketch keeps this dimension.
    dim: usize,
    /// FHT working length = `next_pow2(dim)`.
    padded: usize,
    /// Random ±1 sign flips (the diagonal `D`), length `padded`.
    signs: Vec<f32>,
    /// The seed the sign flips were derived from (stored for reproducibility).
    seed: u64,
}

impl Rotation {
    /// Build a rotation for `dim`-dimensional embeddings from a fixed `seed`.
    ///
    /// The same `(seed, dim)` always yields a bit-identical rotation, so an
    /// index built with `Rotation::new(seed, d)` and a query rotated with a
    /// freshly-constructed `Rotation::new(seed, d)` agree exactly.
    ///
    /// `dim == 0` yields an identity (empty) rotation — `apply` returns an
    /// empty vector — which keeps the constructor total (no panic on a
    /// degenerate dimension).
    pub fn new(seed: u64, dim: usize) -> Self {
        let padded = next_pow2(dim);
        let mut signs = Vec::with_capacity(padded);
        // SplitMix64 mixing: tiny, well-distributed and fully deterministic.
        // We only need a reproducible stream of bits to pick ±1 per dimension,
        // and the SplitMix64 mixer is more than adequate (and far better-mixed
        // than the LCG used for bench fixtures).
        //
        // Note: each mixed word is written back into `state`. This iterates
        // the SplitMix64 step on its own output instead of emitting the
        // textbook counter-based stream, but the result is still fully
        // determined by `seed`. Changing this loop would change the signs for
        // every seed. Banks built before such a change would then no longer
        // match queries rotated after it.
        let mut state = seed;
        for _ in 0..padded {
            state = split_mix64(&mut state);
            // Use the top bit of the mixed word to choose the sign.
            signs.push(if state >> 63 == 1 { 1.0 } else { -1.0 });
        }
        Self {
            dim,
            padded,
            signs,
            seed,
        }
    }

    /// The seed this rotation was derived from (for serialization / audit).
    #[inline]
    pub fn seed(&self) -> u64 {
        self.seed
    }

    /// Source-embedding dimension this rotation expects.
    #[inline]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// FHT working length (`next_pow2(dim)`).
    #[inline]
    pub fn padded_dim(&self) -> usize {
        self.padded
    }

    /// Apply the rotation `R = H·D` to `embedding`, returning the first `dim`
    /// rotated coordinates.
    ///
    /// If `embedding.len() != dim`, the input is treated charitably: it is
    /// truncated or zero-extended to `dim` before rotation. This mirrors
    /// Pass 1's saturating tolerance and keeps the call total.
    ///
    /// The returned vector has length `self.dim`. When `dim` is a power of
    /// two, its L2 norm equals the L2 norm of the (dim-truncated /
    /// zero-extended) input up to floating-point rounding (pinned by
    /// `rotation_preserves_norm`). When `dim` is not a power of two, the
    /// coordinates past `dim` are dropped, so the norm is at most the input
    /// norm (pinned by `rotation_non_power_of_two_preserves_norm_via_padding`).
    /// Use [`Rotation::apply_padded`] when the full-norm frame is needed.
    pub fn apply(&self, embedding: &[f32]) -> Vec<f32> {
        if self.dim == 0 {
            return Vec::new();
        }
        let mut buf = self.apply_padded(embedding);
        // Read back the first `dim` rotated coordinates as the sketch input.
        buf.truncate(self.dim);
        buf
    }

    /// Apply the rotation `R = H·D` and return **all `padded_dim` rotated
    /// coordinates** (not truncated to `dim`).
    ///
    /// This is the frame the RaBitQ estimator ([`crate::estimator`]) works in.
    /// The 1-bit code `x̄ ∈ {±1/√D}^D` is unit-norm over the **padded** length
    /// `D` (here `D` means `padded_dim`, not the sign diagonal), and the query
    /// dot product `⟨x̄, q'⟩` must be taken over that same `D`. For a
    /// power-of-two `dim`, `padded_dim == dim` and this equals
    /// [`Rotation::apply`]. For a non-power-of-two `dim`, the tail coordinates
    /// (the share of the input's energy that the FHT spreads past index `dim`)
    /// are retained here but dropped by `apply`.
    ///
    /// `dim == 0` yields an empty vector. Ragged input is handled charitably
    /// (truncate / zero-extend to `dim`), as in [`Rotation::apply`].
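    ///
    /// The following small sketch shows why the padded frame matters. The
    /// helpers are hypothetical and illustrate only the `{±1/√D}^D`
    /// normalization above, not the estimator in [`crate::estimator`]. The
    /// `0 → +1` tie-break is an assumption of the sketch:
    ///
    /// ```rust
    /// // Hypothetical helper: the 1-bit code x̄ ∈ {±1/√D}^D for a rotated
    /// // vector of padded length D (zero maps to +1, a sketch-only choice).
    /// fn one_bit_code(rotated_padded: &[f32]) -> Vec<f32> {
    ///     let scale = 1.0 / (rotated_padded.len() as f32).sqrt();
    ///     rotated_padded
    ///         .iter()
    ///         .map(|&v| if v >= 0.0 { scale } else { -scale })
    ///         .collect()
    /// }
    ///
    /// // Inner product ⟨x̄, q'⟩, taken over the same padded length.
    /// fn dot(a: &[f32], b: &[f32]) -> f32 {
    ///     assert_eq!(a.len(), b.len(), "both vectors must be in the padded frame");
    ///     a.iter().zip(b).map(|(x, y)| x * y).sum()
    /// }
    /// ```
    ///
    /// Over the full `D = 256` the code has unit norm. Keeping only the first
    /// `d = 130` coordinates leaves a squared norm of `130/256`.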
    pub fn apply_padded(&self, embedding: &[f32]) -> Vec<f32> {
        if self.dim == 0 {
            return Vec::new();
        }
        // Build the padded, sign-flipped working buffer: buf = D · x, then 0-pad.
        let mut buf = vec![0.0f32; self.padded];
        let n = embedding.len().min(self.dim);
        for i in 0..n {
            buf[i] = embedding[i] * self.signs[i];
        }
        // (positions n..dim and dim..padded stay zero — zero-extend + pad)
        // In-place normalized Fast Hadamard Transform.
        fht_normalized(&mut buf);
        buf
    }
}

/// Smallest power of two `>= n` (with `next_pow2(0) == 1`, `next_pow2(1) == 1`).
///
/// Pulled out (and `pub(crate)`) so the sketch layer and tests can reason about
/// the FHT working length without duplicating the rule.
#[inline]
pub(crate) fn next_pow2(n: usize) -> usize {
    if n <= 1 {
        return 1;
    }
    // `n` here is small relative to usize::MAX in every realistic embedding
    // (<= 65_535), so `next_power_of_two` cannot overflow.
    n.next_power_of_two()
}

/// SplitMix64 step: advance `state` and return a well-mixed 64-bit word.
///
/// Reference algorithm (public domain, by Sebastiano Vigna). Deterministic and
/// dependency-free — exactly what we need for a reproducible sign stream.
#[inline]
fn split_mix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// In-place **normalized** Fast Hadamard Transform on a power-of-two slice.
///
/// Computes `y = (1/√m) · H_m · x` in place, where `H_m` is the `m × m`
/// Walsh-Hadamard matrix and `m = buf.len()` is a power of two. The `1/√m`
/// normalization makes `H` orthogonal (`HᵀH = I`), so the transform preserves
/// the L2 norm. Runs in `O(m log m)` with `O(1)` extra memory (the standard
/// iterative butterfly).
///
/// # Panics
///
/// Debug-asserts that `buf.len()` is a power of two. Callers in this module
/// always pass `next_pow2(dim)`, so this never fires in practice; it documents
/// the precondition.
fn fht_normalized(buf: &mut [f32]) {
    let m = buf.len();
    debug_assert!(m.is_power_of_two(), "FHT length must be a power of two");
    if m <= 1 {
        return;
    }
    // Unnormalized in-place Walsh-Hadamard butterfly.
    let mut h = 1usize;
    while h < m {
        let mut i = 0usize;
        while i < m {
            for j in i..i + h {
                let x = buf[j];
                let y = buf[j + h];
                buf[j] = x + y;
                buf[j + h] = x - y;
            }
            i += h * 2;
        }
        h *= 2;
    }
    // Normalize by 1/√m so H is orthogonal (norm-preserving).
    let inv_sqrt_m = 1.0f32 / (m as f32).sqrt();
    for v in buf.iter_mut() {
        *v *= inv_sqrt_m;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn l2(v: &[f32]) -> f32 {
        v.iter().map(|&x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn next_pow2_rounds_up() {
        assert_eq!(next_pow2(0), 1);
        assert_eq!(next_pow2(1), 1);
        assert_eq!(next_pow2(2), 2);
        assert_eq!(next_pow2(3), 4);
        assert_eq!(next_pow2(128), 128);
        assert_eq!(next_pow2(129), 256);
        assert_eq!(next_pow2(200), 256);
        assert_eq!(next_pow2(65_535), 65_536);
    }

    #[test]
    fn fht_is_norm_preserving_on_power_of_two() {
        // Pure FHT (no sign flips) must preserve L2 norm to fp tolerance.
        let mut v: Vec<f32> = (0..8).map(|i| (i as f32 - 3.5) * 0.7).collect();
        let before = l2(&v);
        fht_normalized(&mut v);
        let after = l2(&v);
        assert!(
            (before - after).abs() < 1e-5,
            "FHT changed norm: {before} -> {after}"
        );
    }

    #[test]
    fn fht_self_inverse_normalized() {
        // Normalized H is symmetric and orthogonal, so H·H·x == x.
        let original: Vec<f32> = vec![1.0, -2.0, 3.0, 0.5];
        let mut v = original.clone();
        fht_normalized(&mut v);
        fht_normalized(&mut v);
        for (a, b) in original.iter().zip(v.iter()) {
            assert!((a - b).abs() < 1e-5, "H·H·x != x: {a} vs {b}");
        }
    }

    #[test]
    fn rotation_is_deterministic_for_seed() {
        // Two rotations from the same (seed, dim) must produce identical
        // output for the same input — the index-time == query-time contract.
        let r1 = Rotation::new(0xDEAD_BEEF_CAFE_1234, 130);
        let r2 = Rotation::new(0xDEAD_BEEF_CAFE_1234, 130);
        let x: Vec<f32> = (0..130).map(|i| (i as f32 * 0.31).sin()).collect();
        let a = r1.apply(&x);
        let b = r2.apply(&x);
        assert_eq!(a.len(), 130);
        assert_eq!(a, b, "same seed must give identical rotation");
        // A different seed must (almost surely) differ.
        let r3 = Rotation::new(0x0000_0000_0000_0001, 130);
        let c = r3.apply(&x);
        assert_ne!(a, c, "different seed must give different rotation");
    }

    #[test]
    fn rotation_preserves_norm() {
        // R = H·D is orthogonal; on a power-of-two dim the first `dim`
        // coordinates ARE the whole transform, so norm is preserved exactly
        // (to fp tolerance). We test a power-of-two dim for the exact claim.
        let r = Rotation::new(42, 128);
        let x: Vec<f32> = (0..128).map(|i| ((i * 7 % 13) as f32 - 6.0) * 0.5).collect();
        let y = r.apply(&x);
        let before = l2(&x);
        let after = l2(&y);
        assert!(
            (before - after).abs() < 1e-3 * before.max(1.0),
            "rotation changed norm: {before} -> {after}"
        );
    }

    #[test]
    fn rotation_non_power_of_two_preserves_norm_via_padding() {
        // For a non-power-of-two dim, reading back the first `dim` coords of a
        // padded FHT only preserves norm if the padded tail carries ~no energy,
        // which is not guaranteed. We therefore assert only that the rotated
        // norm does not EXCEED the input norm (the truncated read-back of the
        // padded transform is non-expansive). That is enough to confirm
        // padding is sane; it is not an exact-norm claim.
        let r = Rotation::new(7, 130); // pads 130 -> 256
        assert_eq!(r.padded_dim(), 256);
        let x: Vec<f32> = (0..130).map(|i| (i as f32 * 0.13).cos()).collect();
        let y = r.apply(&x);
        assert_eq!(y.len(), 130);
        let before = l2(&x);
        let after = l2(&y);
        // Truncated read-back is non-expansive: ||y|| <= ||H·D·x|| == ||x||.
        assert!(
            after <= before + 1e-4,
            "truncated rotation expanded norm: {before} -> {after}"
        );
    }

    #[test]
    fn rotation_dim_zero_is_empty() {
        let r = Rotation::new(1, 0);
        assert!(r.apply(&[]).is_empty());
        assert!(r.apply(&[1.0, 2.0]).is_empty());
    }

    #[test]
    fn rotation_handles_ragged_input() {
        // Charitable length handling: short input zero-extends, long truncates.
        let r = Rotation::new(99, 64);
        let short = r.apply(&[1.0, 2.0, 3.0]); // zero-extended to 64
        assert_eq!(short.len(), 64);
        let long: Vec<f32> = (0..200).map(|i| i as f32).collect();
        let truncated = r.apply(&long); // truncated to 64
        assert_eq!(truncated.len(), 64);
    }
}