diff --git a/v2/crates/wifi-densepose-temporal/tests/dense_vs_sparse.rs b/v2/crates/wifi-densepose-temporal/tests/dense_vs_sparse.rs new file mode 100644 index 00000000..222b6b73 --- /dev/null +++ b/v2/crates/wifi-densepose-temporal/tests/dense_vs_sparse.rs @@ -0,0 +1,184 @@ +//! Numerical A/B test for ADR-096 §5: do Dense and SparseGqa produce +//! comparable outputs on the same input? +//! +//! Background. Sparse attention is *structurally* an approximation — +//! it skips edges that the local window + log-stride + landmark +//! pattern decided wouldn't matter. The §5 validation gate cares +//! about whether that approximation degrades downstream metrics +//! (contrastive loss, rank-1 accuracy, Spearman correlation), not +//! whether outputs are bit-equal. This file establishes the *direct* +//! output-level error envelope so the gate can be calibrated against +//! it. +//! +//! Two regimes: +//! +//! 1. **Sparse pattern is dense.** When window ≥ N AND block_size ≥ N +//! AND every position is global, sparse and dense visit the same +//! edge set. Output divergence then reflects only floating-point +//! accumulation order, which is a tight bound (~1e-5 for f32 sums +//! of ~100 terms at 0.1 magnitude). +//! +//! 2. **Sparse pattern is sparse.** Default config drops most edges +//! at long N. Output divergence here is the *real* approximation +//! error — and the §5 gate's tolerances apply downstream of it. + +use ruvllm_sparse_attention::Tensor3; +use wifi_densepose_temporal::{ + AetherTemporalHead, TemporalBackendKind, TemporalHeadConfig, +}; + +fn make_qkv(seq: usize, heads: usize, dim: usize) -> (Tensor3, Tensor3, Tensor3) { + let mut q = Tensor3::zeros(seq, heads, dim); + let mut k = Tensor3::zeros(seq, heads, dim); + let mut v = Tensor3::zeros(seq, heads, dim); + for s in 0..seq { + for h in 0..heads { + for d in 0..dim { + let qv = ((s * 31 + h * 7 + d) as f32).sin() * 0.1; + let kv = (((s * 17 + h * 3 + d) as f32).cos()) * 0.1; + q.set(s, h, d, qv); + k.set(s, h, d, kv); + v.set(s, h, d, kv * 0.5); + } + } + } + (q, k, v) +} + +fn max_abs_err(a: &Tensor3, b: &Tensor3) -> f32 { + let (s, h, d) = a.shape(); + assert_eq!((s, h, d), b.shape(), "shape mismatch"); + let mut max_err = 0.0f32; + for ti in 0..s { + for hi in 0..h { + for di in 0..d { + let e = (a.get(ti, hi, di) - b.get(ti, hi, di)).abs(); + if e > max_err { + max_err = e; + } + } + } + } + max_err +} + +fn mean_abs_err(a: &Tensor3, b: &Tensor3) -> f32 { + let (s, h, d) = a.shape(); + let mut sum = 0.0f32; + let mut n = 0usize; + for ti in 0..s { + for hi in 0..h { + for di in 0..d { + sum += (a.get(ti, hi, di) - b.get(ti, hi, di)).abs(); + n += 1; + } + } + } + sum / n.max(1) as f32 +} + +#[test] +fn dense_and_sparse_agree_when_sparse_pattern_is_dense() { + // Saturate the sparse pattern: window ≥ N means the local-window + // primitive includes every causal predecessor, so the attention + // edge set is identical to dense MHA's. The remaining gap is + // floating-point accumulation order (sparse goes + // window-then-stride-then-landmark, dense goes naive 0..i). + let seq = 32; + let heads = 4; + let dim = 16; + let (q, k, v) = make_qkv(seq, heads, dim); + + let dense_cfg = TemporalHeadConfig { + backend: TemporalBackendKind::Dense, + q_heads: heads, + kv_heads: heads, + head_dim: dim, + window: seq, // saturate + block_size: seq, + causal: true, + }; + let sparse_cfg = TemporalHeadConfig { + backend: TemporalBackendKind::SparseGqa, + ..dense_cfg.clone() + }; + + let dense = AetherTemporalHead::new(&dense_cfg).expect("dense"); + let sparse = AetherTemporalHead::new(&sparse_cfg).expect("sparse"); + + let d = dense.forward(&q, &k, &v).expect("dense forward"); + let s = sparse.forward(&q, &k, &v).expect("sparse forward"); + + let max_err = max_abs_err(&d, &s); + let mean_err = mean_abs_err(&d, &s); + + // 1e-4 covers a generous f32-summation-order envelope at 0.1 + // input magnitude. If this ever blows up, either the saturation + // assumption is wrong (window/block_size no longer covers + // everything) or the kernel changed semantics. + assert!( + max_err < 1.0e-4, + "saturated-pattern max_abs_err exceeds 1e-4: max={max_err} mean={mean_err}" + ); +} + +#[test] +fn dense_and_sparse_diverge_predictably_at_long_n() { + // The interesting case: real sparse pattern (window << N), real + // approximation. We don't assert a specific error bound here — + // that's what ADR-096 §5's validation gate calibrates. We only + // check the numbers come out finite and plausible (per-position + // outputs stay within a few × the input magnitude after + // attention-weighted averaging — softmax can't blow them up). + let seq = 256; + let heads = 4; + let dim = 16; + let (q, k, v) = make_qkv(seq, heads, dim); + + let dense_cfg = TemporalHeadConfig { + backend: TemporalBackendKind::Dense, + q_heads: heads, + kv_heads: heads, + head_dim: dim, + window: seq, // dense — placeholder; ignored by Dense backend + block_size: seq, + causal: true, + }; + let sparse_cfg = TemporalHeadConfig { + backend: TemporalBackendKind::SparseGqa, + q_heads: heads, + kv_heads: heads, + head_dim: dim, + window: 16, // realistic sparse window + block_size: 32, + causal: true, + }; + + let dense = AetherTemporalHead::new(&dense_cfg).expect("dense"); + let sparse = AetherTemporalHead::new(&sparse_cfg).expect("sparse"); + + let d = dense.forward(&q, &k, &v).expect("dense forward"); + let s = sparse.forward(&q, &k, &v).expect("sparse forward"); + + let max_err = max_abs_err(&d, &s); + let mean_err = mean_abs_err(&d, &s); + + // Sanity bounds. Inputs are scaled to 0.1, attention is a softmax + // average so outputs stay in roughly [-0.1, 0.1]. If max_err > 1.0 + // something is structurally broken (NaN, underflow, etc). + assert!( + max_err.is_finite() && mean_err.is_finite(), + "non-finite error: max={max_err} mean={mean_err}" + ); + assert!( + max_err < 1.0, + "implausibly large divergence: max={max_err} mean={mean_err}" + ); + + // Print the numbers so they're visible when running `cargo test -- + // --nocapture`. These are what ADR-096 §5's gate would calibrate + // against on real AETHER inputs. + eprintln!( + "dense_vs_sparse @ N={seq}, window=16, block=32: max_abs_err={max_err:.6e}, mean_abs_err={mean_err:.6e}" + ); +}