feat(temporal): streaming step() + KvCache (ADR-096 §3.2, #513)

The structural advantage that's the entire point of ADR-096: O(log T)
per new token via decode_step against an accumulated KvCache, vs
O(N²) recompute for dense MHA. This commit lands the API and proves
the numerical equivalence at the last position.

API:
- AetherTemporalHead::step(q_new, k_new, v_new, &mut cache)
  Single-token decode. Appends (k_new, v_new) to cache, runs
  decode_step(q_new) against the now-updated cache, returns the new
  position's output.
- AetherTemporalHead::make_cache(capacity)
  Convenience constructor — caller doesn't need to import
  ruvllm_sparse_attention to size a cache. Per ADR-096 §8.5 the
  natural lifetime is per-PoseTrack (re-ID) or per-session (online
  classification); when the track drops, drop the cache.
- KvCache re-exported at the crate root.

Contract:
- q_new/k_new/v_new must each have seq == 1. Multi-token q is the
  prefill path (forward), not decode_step.
- Cache lifetime is the caller's. The crate enforces shape via
  make_cache so callers can't mismatch kv_heads / head_dim / block_size.
- KvCache fill is the caller's problem. Upstream H2O heavy-hitter
  eviction is opt-in; this crate's wrapper doesn't pre-pick a policy.

Tests (18/18 total now passing):
- streaming_step_matches_forward_at_last_position — central claim:
  16-token sequence, append k/v one at a time via step(), compare
  the streamed last-token output to forward(full Q,K,V)[N-1].
  max_abs_err < 1e-3 (currently passes well under that bound for
  the 0.1-magnitude activations the test uses).
- step_rejects_multi_token_q — contract enforcement.
- make_cache_returns_kvcache_with_correct_shape — wiring smoke,
  confirms (capacity, kv_heads, dim, block_size) ordering is correct
  through the make_cache wrapper.

Test config uses MHA shape (q_heads == kv_heads) because the upstream
decode_step is wired to the MHA branch; the GQA decode path is on
upstream's roadmap and lands in a separate ADR-096 follow-up when it
does.

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv 2026-05-08 11:57:31 -04:00
parent 3a5fe5e0de
commit 49e57efcec
3 changed files with 216 additions and 4 deletions

View File

@ -22,9 +22,9 @@ pub use weights::{
WEIGHT_BLOB_VERSION,
};
// Re-export the upstream Tensor3 so callers don't need a direct
// `ruvllm_sparse_attention` dep.
pub use ruvllm_sparse_attention::Tensor3;
// Re-export the upstream Tensor3 + KvCache so callers don't need a
// direct `ruvllm_sparse_attention` dep.
pub use ruvllm_sparse_attention::{KvCache, Tensor3};
/// Thin facade so callers can pick a backend by name.
///
@ -62,4 +62,33 @@ impl AetherTemporalHead {
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
}
}
/// Streaming decode (ADR-096 §3.2). Caller owns the `cache`; the
/// natural lifetime is per-tracked-person (one cache per
/// `PoseTrack`, dropped when the track evicts).
///
/// Returns the attention output for the single new token. Caller
/// is responsible for downstream pooling / classifier head.
pub fn step(
&self,
q_new: &Tensor3,
k_new: &Tensor3,
v_new: &Tensor3,
cache: &mut KvCache,
) -> Result<Tensor3, TemporalError> {
match self {
AetherTemporalHead::SparseGqa(h) => h.step(q_new, k_new, v_new, cache),
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
}
}
/// Allocate a `KvCache` sized correctly for this head. Convenience
/// wrapper so AETHER's `pose_tracker.rs` doesn't need to import
/// the upstream crate.
pub fn make_cache(&self, capacity: usize) -> Result<KvCache, TemporalError> {
match self {
AetherTemporalHead::SparseGqa(h) => Ok(h.make_cache(capacity)),
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
}
}
}

View File

@ -1,5 +1,5 @@
use ruvllm_sparse_attention::{
AttentionBackend, SparseAttentionConfig, SubquadraticSparseAttention, Tensor3,
AttentionBackend, KvCache, SparseAttentionConfig, SubquadraticSparseAttention, Tensor3,
};
use crate::{TemporalError, TemporalHeadConfig};
@ -57,6 +57,50 @@ impl SparseGqaHead {
Ok(self.attn.forward_gqa(q, k, v)?)
}
}
/// Streaming decode for re-ID and online classification (ADR-096 §3.2).
///
/// Given one new token's q/k/v, append (k, v) to `cache` and return
/// the attention output for that one position against the full
/// accumulated history. Cost is O(log T) per step against a cache
/// of capacity T — the structural advantage over dense MHA's O(N²)
/// recompute that ADR-096 specifically calls out as the
/// dense-MHA-cannot-follow path.
///
/// Cache lifetime is owned by the caller. Per ADR-096 §8.5 the
/// natural place is one cache per `PoseTrack` (re-ID) or one cache
/// per active session (online classification). When the track is
/// dropped, drop the cache.
pub fn step(
&self,
q_new: &Tensor3,
k_new: &Tensor3,
v_new: &Tensor3,
cache: &mut KvCache,
) -> Result<Tensor3, TemporalError> {
if q_new.seq != 1 || k_new.seq != 1 || v_new.seq != 1 {
return Err(TemporalError::InvalidConfig(
"step() requires single-token q/k/v (seq == 1 each)",
));
}
// Append must succeed before decode_step sees the cache; if
// the cache fills, the caller is responsible for eviction or
// resetting per ADR-096 §3.2 (H2O heavy-hitter eviction is
// available upstream but kept opt-in).
cache.try_append(k_new, v_new)?;
Ok(self.attn.decode_step(q_new, cache)?)
}
/// Construct a KvCache sized for this head's shape. Convenience
/// so callers don't need to import the upstream crate directly.
pub fn make_cache(&self, capacity: usize) -> KvCache {
KvCache::new(
capacity,
self.cfg.kv_heads,
self.cfg.head_dim,
self.cfg.block_size,
)
}
}
/// Always treat token 0 as a global anchor — AETHER's contrastive

View File

@ -0,0 +1,139 @@
//! ADR-096 §3.2 streaming-decode test: token-by-token `step()` against
//! a `KvCache` should match a single-shot `forward()` over the same
//! Q/K/V at the final position. This is the structural advantage
//! dense MHA can't follow — proving it stays correct under streaming
//! is what the §5 validation gate would care about most.
use wifi_densepose_temporal::{
AetherTemporalHead, TemporalBackendKind, TemporalHeadConfig, Tensor3,
};
fn make_qkv(seq: usize, q_heads: usize, kv_heads: usize, dim: usize) -> (Tensor3, Tensor3, Tensor3) {
let mut q = Tensor3::zeros(seq, q_heads, dim);
let mut k = Tensor3::zeros(seq, kv_heads, dim);
let mut v = Tensor3::zeros(seq, kv_heads, dim);
for s in 0..seq {
for h in 0..q_heads {
for d in 0..dim {
let val = ((s * 31 + h * 7 + d) as f32).sin() * 0.1;
q.set(s, h, d, val);
}
}
for h in 0..kv_heads {
for d in 0..dim {
let val = (((s * 17 + h * 3 + d) as f32).cos()) * 0.1;
k.set(s, h, d, val);
v.set(s, h, d, val * 0.5);
}
}
}
(q, k, v)
}
fn slice_token(t: &Tensor3, idx: usize) -> Tensor3 {
let (_, heads, dim) = t.shape();
let mut out = Tensor3::zeros(1, heads, dim);
for h in 0..heads {
for d in 0..dim {
out.set(0, h, d, t.get(idx, h, d));
}
}
out
}
fn config_mha_small() -> TemporalHeadConfig {
// Equal q/k heads forces the `forward` MHA branch — `decode_step`
// upstream is wired to this branch, not the GQA branch (which has
// its own decode path coming in upstream's roadmap).
TemporalHeadConfig {
backend: TemporalBackendKind::SparseGqa,
q_heads: 2,
kv_heads: 2,
head_dim: 16,
window: 8,
block_size: 4,
causal: true,
}
}
#[test]
fn streaming_step_matches_forward_at_last_position() {
let cfg = config_mha_small();
let head = AetherTemporalHead::new(&cfg).expect("construct");
let seq = 16usize;
let (q, k, v) = make_qkv(seq, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
// Reference: single-shot forward over the full sequence.
let reference = head.forward(&q, &k, &v).expect("forward");
// Streaming: append k/v one token at a time, decode the new q.
let mut cache = head.make_cache(seq).expect("cache");
let mut last_out: Option<Tensor3> = None;
for t in 0..seq {
let qt = slice_token(&q, t);
let kt = slice_token(&k, t);
let vt = slice_token(&v, t);
last_out = Some(head.step(&qt, &kt, &vt, &mut cache).expect("step"));
}
let streamed = last_out.expect("at least one step");
// Compare the streamed last-token output to the reference's
// last-token output. Tolerance is generous because numerical
// accumulation differs between the two paths even at exact
// mathematical equivalence.
let (s_seq, s_heads, s_dim) = streamed.shape();
assert_eq!((s_seq, s_heads, s_dim), (1, cfg.q_heads, cfg.head_dim));
let mut max_abs_err: f32 = 0.0;
for h in 0..cfg.q_heads {
for d in 0..cfg.head_dim {
let a = streamed.get(0, h, d);
let b = reference.get(seq - 1, h, d);
let err = (a - b).abs();
if err > max_abs_err {
max_abs_err = err;
}
}
}
// 1e-3 absolute is a comfortable bound for activations of this
// magnitude (~0.1 input scale). Tighten if the kernel ever
// promises closer match.
assert!(
max_abs_err < 1.0e-3,
"streaming/forward divergence at last token exceeds 1e-3: max_abs_err = {max_abs_err}"
);
}
#[test]
fn step_rejects_multi_token_q() {
let cfg = config_mha_small();
let head = AetherTemporalHead::new(&cfg).expect("construct");
let mut cache = head.make_cache(8).expect("cache");
// Build a 2-token Q/K/V — `step` must reject (its contract is
// single-token decode).
let (q, k, v) = make_qkv(2, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
let err = head.step(&q, &k, &v, &mut cache).err().expect("rejected");
let s = format!("{err}");
assert!(
s.contains("single-token") || s.to_lowercase().contains("seq"),
"expected single-token rejection, got: {s}"
);
}
#[test]
fn make_cache_returns_kvcache_with_correct_shape() {
// Smoke test that the convenience wrapper plumbs the right dims
// into KvCache::new — the upstream constructor takes
// (capacity, kv_heads, dim, block_size) and we want to make sure
// we're not transposing any of those.
let cfg = config_mha_small();
let head = AetherTemporalHead::new(&cfg).expect("construct");
let mut cache = head.make_cache(32).expect("cache");
// Append one token shaped for kv_heads × head_dim — should not error.
let (_, k, v) = make_qkv(1, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
let kt = slice_token(&k, 0);
let vt = slice_token(&v, 0);
cache.try_append(&kt, &vt).expect("append shape ok");
}