test(temporal): dense-vs-sparse numerical A/B baseline (ADR-096 §5, #513)
Establishes the kernel-level output-divergence envelope between the two backends — what §5's downstream-metric gate (contrastive loss, rank-1, Spearman) would calibrate against. Two regimes: 1. Saturated pattern (window ≥ N, block ≥ N): sparse and dense visit the same edge set, so divergence reflects only float accumulation order. **Asserted < 1e-4** at N=32, heads=4, dim=16. Tight bound. 2. Realistic sparse (window=16, block=32, N=256): real approximation, real divergence. **Measured max_abs_err = 5.22e-3, mean = 1.79e-3** on the deterministic test inputs. Sanity-checked finite + < 1.0 so structural breakage (NaN, softmax overflow) trips a panic, but the specific numbers are *baseline data* not a hard contract — the §5 gate cares about downstream task metrics, not bit-equality. Why this is in the test suite rather than a benchmark: - It runs in <0.2s, no need to gate behind --release. - The saturated-pattern bound IS a hard contract — if that breaks the kernel changed semantics in a way the API hides, and we want CI to catch it. - Printing the realistic-pattern numbers (eprintln, visible with --nocapture) gives a known-good reference point to compare future builds against. Test count is now 21/21 across the crate (6 smoke + 8 weight blob + 2 blob e2e + 3 streaming + 2 dense-vs-sparse). Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
4ea8457017
commit
2b903752c4
|
|
@ -0,0 +1,184 @@
|
|||
//! Numerical A/B test for ADR-096 §5: do Dense and SparseGqa produce
|
||||
//! comparable outputs on the same input?
|
||||
//!
|
||||
//! Background. Sparse attention is *structurally* an approximation —
|
||||
//! it skips edges that the local window + log-stride + landmark
|
||||
//! pattern decided wouldn't matter. The §5 validation gate cares
|
||||
//! about whether that approximation degrades downstream metrics
|
||||
//! (contrastive loss, rank-1 accuracy, Spearman correlation), not
|
||||
//! whether outputs are bit-equal. This file establishes the *direct*
|
||||
//! output-level error envelope so the gate can be calibrated against
|
||||
//! it.
|
||||
//!
|
||||
//! Two regimes:
|
||||
//!
|
||||
//! 1. **Sparse pattern is dense.** When window ≥ N AND block_size ≥ N
|
||||
//! AND every position is global, sparse and dense visit the same
|
||||
//! edge set. Output divergence then reflects only floating-point
|
||||
//! accumulation order, which is a tight bound (~1e-5 for f32 sums
|
||||
//! of ~100 terms at 0.1 magnitude).
|
||||
//!
|
||||
//! 2. **Sparse pattern is sparse.** Default config drops most edges
|
||||
//! at long N. Output divergence here is the *real* approximation
|
||||
//! error — and the §5 gate's tolerances apply downstream of it.
|
||||
|
||||
use ruvllm_sparse_attention::Tensor3;
|
||||
use wifi_densepose_temporal::{
|
||||
AetherTemporalHead, TemporalBackendKind, TemporalHeadConfig,
|
||||
};
|
||||
|
||||
fn make_qkv(seq: usize, heads: usize, dim: usize) -> (Tensor3, Tensor3, Tensor3) {
|
||||
let mut q = Tensor3::zeros(seq, heads, dim);
|
||||
let mut k = Tensor3::zeros(seq, heads, dim);
|
||||
let mut v = Tensor3::zeros(seq, heads, dim);
|
||||
for s in 0..seq {
|
||||
for h in 0..heads {
|
||||
for d in 0..dim {
|
||||
let qv = ((s * 31 + h * 7 + d) as f32).sin() * 0.1;
|
||||
let kv = (((s * 17 + h * 3 + d) as f32).cos()) * 0.1;
|
||||
q.set(s, h, d, qv);
|
||||
k.set(s, h, d, kv);
|
||||
v.set(s, h, d, kv * 0.5);
|
||||
}
|
||||
}
|
||||
}
|
||||
(q, k, v)
|
||||
}
|
||||
|
||||
fn max_abs_err(a: &Tensor3, b: &Tensor3) -> f32 {
|
||||
let (s, h, d) = a.shape();
|
||||
assert_eq!((s, h, d), b.shape(), "shape mismatch");
|
||||
let mut max_err = 0.0f32;
|
||||
for ti in 0..s {
|
||||
for hi in 0..h {
|
||||
for di in 0..d {
|
||||
let e = (a.get(ti, hi, di) - b.get(ti, hi, di)).abs();
|
||||
if e > max_err {
|
||||
max_err = e;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
max_err
|
||||
}
|
||||
|
||||
fn mean_abs_err(a: &Tensor3, b: &Tensor3) -> f32 {
|
||||
let (s, h, d) = a.shape();
|
||||
let mut sum = 0.0f32;
|
||||
let mut n = 0usize;
|
||||
for ti in 0..s {
|
||||
for hi in 0..h {
|
||||
for di in 0..d {
|
||||
sum += (a.get(ti, hi, di) - b.get(ti, hi, di)).abs();
|
||||
n += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
sum / n.max(1) as f32
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dense_and_sparse_agree_when_sparse_pattern_is_dense() {
|
||||
// Saturate the sparse pattern: window ≥ N means the local-window
|
||||
// primitive includes every causal predecessor, so the attention
|
||||
// edge set is identical to dense MHA's. The remaining gap is
|
||||
// floating-point accumulation order (sparse goes
|
||||
// window-then-stride-then-landmark, dense goes naive 0..i).
|
||||
let seq = 32;
|
||||
let heads = 4;
|
||||
let dim = 16;
|
||||
let (q, k, v) = make_qkv(seq, heads, dim);
|
||||
|
||||
let dense_cfg = TemporalHeadConfig {
|
||||
backend: TemporalBackendKind::Dense,
|
||||
q_heads: heads,
|
||||
kv_heads: heads,
|
||||
head_dim: dim,
|
||||
window: seq, // saturate
|
||||
block_size: seq,
|
||||
causal: true,
|
||||
};
|
||||
let sparse_cfg = TemporalHeadConfig {
|
||||
backend: TemporalBackendKind::SparseGqa,
|
||||
..dense_cfg.clone()
|
||||
};
|
||||
|
||||
let dense = AetherTemporalHead::new(&dense_cfg).expect("dense");
|
||||
let sparse = AetherTemporalHead::new(&sparse_cfg).expect("sparse");
|
||||
|
||||
let d = dense.forward(&q, &k, &v).expect("dense forward");
|
||||
let s = sparse.forward(&q, &k, &v).expect("sparse forward");
|
||||
|
||||
let max_err = max_abs_err(&d, &s);
|
||||
let mean_err = mean_abs_err(&d, &s);
|
||||
|
||||
// 1e-4 covers a generous f32-summation-order envelope at 0.1
|
||||
// input magnitude. If this ever blows up, either the saturation
|
||||
// assumption is wrong (window/block_size no longer covers
|
||||
// everything) or the kernel changed semantics.
|
||||
assert!(
|
||||
max_err < 1.0e-4,
|
||||
"saturated-pattern max_abs_err exceeds 1e-4: max={max_err} mean={mean_err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dense_and_sparse_diverge_predictably_at_long_n() {
|
||||
// The interesting case: real sparse pattern (window << N), real
|
||||
// approximation. We don't assert a specific error bound here —
|
||||
// that's what ADR-096 §5's validation gate calibrates. We only
|
||||
// check the numbers come out finite and plausible (per-position
|
||||
// outputs stay within a few × the input magnitude after
|
||||
// attention-weighted averaging — softmax can't blow them up).
|
||||
let seq = 256;
|
||||
let heads = 4;
|
||||
let dim = 16;
|
||||
let (q, k, v) = make_qkv(seq, heads, dim);
|
||||
|
||||
let dense_cfg = TemporalHeadConfig {
|
||||
backend: TemporalBackendKind::Dense,
|
||||
q_heads: heads,
|
||||
kv_heads: heads,
|
||||
head_dim: dim,
|
||||
window: seq, // dense — placeholder; ignored by Dense backend
|
||||
block_size: seq,
|
||||
causal: true,
|
||||
};
|
||||
let sparse_cfg = TemporalHeadConfig {
|
||||
backend: TemporalBackendKind::SparseGqa,
|
||||
q_heads: heads,
|
||||
kv_heads: heads,
|
||||
head_dim: dim,
|
||||
window: 16, // realistic sparse window
|
||||
block_size: 32,
|
||||
causal: true,
|
||||
};
|
||||
|
||||
let dense = AetherTemporalHead::new(&dense_cfg).expect("dense");
|
||||
let sparse = AetherTemporalHead::new(&sparse_cfg).expect("sparse");
|
||||
|
||||
let d = dense.forward(&q, &k, &v).expect("dense forward");
|
||||
let s = sparse.forward(&q, &k, &v).expect("sparse forward");
|
||||
|
||||
let max_err = max_abs_err(&d, &s);
|
||||
let mean_err = mean_abs_err(&d, &s);
|
||||
|
||||
// Sanity bounds. Inputs are scaled to 0.1, attention is a softmax
|
||||
// average so outputs stay in roughly [-0.1, 0.1]. If max_err > 1.0
|
||||
// something is structurally broken (NaN, underflow, etc).
|
||||
assert!(
|
||||
max_err.is_finite() && mean_err.is_finite(),
|
||||
"non-finite error: max={max_err} mean={mean_err}"
|
||||
);
|
||||
assert!(
|
||||
max_err < 1.0,
|
||||
"implausibly large divergence: max={max_err} mean={mean_err}"
|
||||
);
|
||||
|
||||
// Print the numbers so they're visible when running `cargo test --
|
||||
// --nocapture`. These are what ADR-096 §5's gate would calibrate
|
||||
// against on real AETHER inputs.
|
||||
eprintln!(
|
||||
"dense_vs_sparse @ N={seq}, window=16, block=32: max_abs_err={max_err:.6e}, mean_abs_err={mean_err:.6e}"
|
||||
);
|
||||
}
|
||||
Loading…
Reference in New Issue