wifi-densepose/v2/crates/wifi-densepose-temporal/tests/streaming.rs

140 lines
5.0 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! ADR-096 §3.2 streaming-decode test: decoding token by token with
//! `step()` against a `KvCache` should match a single-shot `forward()`
//! over the same Q/K/V at the final position. Streaming decode is the
//! structural advantage that dense MHA cannot match, so proving it stays
//! correct under streaming is what the §5 validation gate would care
//! about most.
use wifi_densepose_temporal::{
AetherTemporalHead, TemporalBackendKind, TemporalHeadConfig, Tensor3,
};
/// Builds deterministic synthetic Q/K/V tensors for the tests.
///
/// `q` has shape `(seq, q_heads, dim)`; `k` and `v` have shape
/// `(seq, kv_heads, dim)`. Entries are sinusoids of a position/head/dim
/// mix scaled by 0.1, and `v` is `k` scaled by a further 0.5, so every
/// value stays small and reproducible across runs.
fn make_qkv(
    seq: usize,
    q_heads: usize,
    kv_heads: usize,
    dim: usize,
) -> (Tensor3, Tensor3, Tensor3) {
    let mut q = Tensor3::zeros(seq, q_heads, dim);
    let mut k = Tensor3::zeros(seq, kv_heads, dim);
    let mut v = Tensor3::zeros(seq, kv_heads, dim);
    for pos in 0..seq {
        // Query row: sine of a (pos, head, channel) phase.
        for qh in 0..q_heads {
            for ch in 0..dim {
                let phase = (pos * 31 + qh * 7 + ch) as f32;
                q.set(pos, qh, ch, phase.sin() * 0.1);
            }
        }
        // Key/value rows share one cosine pattern; values are half the keys.
        for kh in 0..kv_heads {
            for ch in 0..dim {
                let phase = (pos * 17 + kh * 3 + ch) as f32;
                let key = phase.cos() * 0.1;
                k.set(pos, kh, ch, key);
                v.set(pos, kh, ch, key * 0.5);
            }
        }
    }
    (q, k, v)
}
/// Copies row `idx` of `t` into a fresh single-token tensor of shape
/// `(1, heads, dim)`, where `heads` and `dim` are taken from `t.shape()`.
fn slice_token(t: &Tensor3, idx: usize) -> Tensor3 {
    let (_, heads, dim) = t.shape();
    let mut token = Tensor3::zeros(1, heads, dim);
    // Walk every (head, channel) pair in row-major order.
    for (h, d) in (0..heads).flat_map(|h| (0..dim).map(move |d| (h, d))) {
        token.set(0, h, d, t.get(idx, h, d));
    }
    token
}
/// Small causal head configuration used by the tests below:
/// 2 query heads, 2 KV heads, head_dim 16, window 8, block_size 4.
fn config_mha_small() -> TemporalHeadConfig {
    // The backend is `SparseGqa`, but q_heads == kv_heads is what routes
    // `forward` to its MHA branch. Upstream `decode_step` is wired to that
    // branch; the GQA branch's own decode path is still on upstream's
    // roadmap.
    let heads = 2;
    TemporalHeadConfig {
        causal: true,
        backend: TemporalBackendKind::SparseGqa,
        window: 8,
        block_size: 4,
        head_dim: 16,
        q_heads: heads,
        kv_heads: heads,
    }
}
/// Streams the sequence through `step()` one token at a time and checks
/// that the last decoded token matches the last row of a single-shot
/// `forward()` over the same Q/K/V.
#[test]
fn streaming_step_matches_forward_at_last_position() {
    let cfg = config_mha_small();
    let head = AetherTemporalHead::new(&cfg).expect("construct");
    let seq = 16usize;
    let (q, k, v) = make_qkv(seq, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
    // Reference: single-shot forward over the full sequence.
    let reference = head.forward(&q, &k, &v).expect("forward");
    // The comparison below reads row `seq - 1` of the reference, so pin
    // its layout up front instead of trusting `get` to catch a mismatch.
    let (r_seq, r_heads, r_dim) = reference.shape();
    assert_eq!((r_seq, r_heads, r_dim), (seq, cfg.q_heads, cfg.head_dim));
    // Streaming: append k/v one token at a time, decode the new q.
    let mut cache = head.make_cache(seq).expect("cache");
    let mut last_out: Option<Tensor3> = None;
    for t in 0..seq {
        let qt = slice_token(&q, t);
        let kt = slice_token(&k, t);
        let vt = slice_token(&v, t);
        last_out = Some(head.step(&qt, &kt, &vt, &mut cache).expect("step"));
    }
    let streamed = last_out.expect("at least one step");
    // Compare the streamed last-token output to the reference's
    // last-token output. Tolerance is generous because numerical
    // accumulation differs between the two paths even at exact
    // mathematical equivalence.
    let (s_seq, s_heads, s_dim) = streamed.shape();
    assert_eq!((s_seq, s_heads, s_dim), (1, cfg.q_heads, cfg.head_dim));
    let mut max_abs_err: f32 = 0.0;
    for h in 0..cfg.q_heads {
        for d in 0..cfg.head_dim {
            let a = streamed.get(0, h, d);
            let b = reference.get(seq - 1, h, d);
            // A plain `err > max_abs_err` comparison is false for NaN, so a
            // NaN on either side would be skipped and the test would pass
            // with max_abs_err == 0.0. Reject non-finite values explicitly.
            assert!(
                a.is_finite() && b.is_finite(),
                "non-finite output at head {h}, dim {d}: streamed = {a}, reference = {b}"
            );
            max_abs_err = max_abs_err.max((a - b).abs());
        }
    }
    // 1e-3 absolute is a comfortable bound for activations of this
    // magnitude (~0.1 input scale). Tighten if the kernel ever
    // promises closer match.
    assert!(
        max_abs_err < 1.0e-3,
        "streaming/forward divergence at last token exceeds 1e-3: max_abs_err = {max_abs_err}"
    );
}
/// `step` is a single-token decode, so feeding it a 2-token Q/K/V must
/// return an error whose message points at the sequence-length contract.
#[test]
fn step_rejects_multi_token_q() {
    let cfg = config_mha_small();
    let head = AetherTemporalHead::new(&cfg).expect("construct");
    let mut cache = head.make_cache(8).expect("cache");
    let (q, k, v) = make_qkv(2, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
    let err = match head.step(&q, &k, &v, &mut cache) {
        Ok(_) => panic!("rejected"),
        Err(e) => e,
    };
    let msg = err.to_string();
    // Accept either the exact phrase or any case-insensitive mention of "seq".
    let names_contract = msg.contains("single-token") || msg.to_lowercase().contains("seq");
    assert!(names_contract, "expected single-token rejection, got: {msg}");
}
/// Smoke test for the `make_cache` convenience wrapper.
///
/// Upstream `KvCache::new` takes `(capacity, kv_heads, dim, block_size)`.
/// Appending one token shaped `kv_heads × head_dim` must succeed, which
/// catches a wrapper that transposes any of those dimensions.
#[test]
fn make_cache_returns_kvcache_with_correct_shape() {
    let cfg = config_mha_small();
    let head = AetherTemporalHead::new(&cfg).expect("construct");
    let mut cache = head.make_cache(32).expect("cache");
    // Only the key/value tensors matter here; the query is discarded.
    let (_, keys, values) = make_qkv(1, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
    let (key_tok, value_tok) = (slice_token(&keys, 0), slice_token(&values, 0));
    cache.try_append(&key_tok, &value_tok).expect("append shape ok");
}