From 49e57efcece644aff1ec8aa80e30718dda2a487e Mon Sep 17 00:00:00 2001 From: ruv Date: Fri, 8 May 2026 11:57:31 -0400 Subject: [PATCH] =?UTF-8?q?feat(temporal):=20streaming=20step()=20+=20KvCa?= =?UTF-8?q?che=20(ADR-096=20=C2=A73.2,=20#513)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The structural advantage that's the entire point of ADR-096: O(log T) per new token via decode_step against an accumulated KvCache, vs O(N²) recompute for dense MHA. This commit lands the API and tests numerical equivalence with forward() at the last position. API: - AetherTemporalHead::step(q_new, k_new, v_new, &mut cache) Single-token decode. Appends (k_new, v_new) to cache, runs decode_step(q_new) against the now-updated cache, returns the new position's output. - AetherTemporalHead::make_cache(capacity) Convenience constructor — caller doesn't need to import ruvllm_sparse_attention to size a cache. Per ADR-096 §8.5 the natural lifetime is per-PoseTrack (re-ID) or per-session (online classification); when the track drops, drop the cache. - KvCache re-exported at the crate root. Contract: - q_new/k_new/v_new must each have seq == 1. Multi-token q is the prefill path (forward), not decode_step. - Cache lifetime is the caller's. make_cache sizes the cache from the head's own config, so callers who use it can't mismatch kv_heads / head_dim / block_size. - A full KvCache is the caller's problem. Upstream H2O heavy-hitter eviction is opt-in; this crate's wrapper doesn't pre-pick a policy. Tests (18/18 total now passing): - streaming_step_matches_forward_at_last_position — central claim: 16-token sequence, append k/v one at a time via step(), compare the streamed last-token output to forward(full Q,K,V)[N-1]. max_abs_err < 1e-3 (currently passes well under that bound for the 0.1-magnitude activations the test uses). - step_rejects_multi_token_q — contract enforcement. 
- make_cache_returns_kvcache_with_correct_shape — wiring smoke, confirms (capacity, kv_heads, dim, block_size) ordering is correct through the make_cache wrapper. Test config uses MHA shape (q_heads == kv_heads) because the upstream decode_step is wired to the MHA branch; the GQA decode path is on upstream's roadmap and lands in a separate ADR-096 follow-up when it does. Co-Authored-By: claude-flow --- v2/crates/wifi-densepose-temporal/src/lib.rs | 35 ++++- .../wifi-densepose-temporal/src/sparse.rs | 46 +++++- .../tests/streaming.rs | 139 ++++++++++++++++++ 3 files changed, 216 insertions(+), 4 deletions(-) create mode 100644 v2/crates/wifi-densepose-temporal/tests/streaming.rs diff --git a/v2/crates/wifi-densepose-temporal/src/lib.rs b/v2/crates/wifi-densepose-temporal/src/lib.rs index 744b6e87..64f06129 100644 --- a/v2/crates/wifi-densepose-temporal/src/lib.rs +++ b/v2/crates/wifi-densepose-temporal/src/lib.rs @@ -22,9 +22,9 @@ pub use weights::{ WEIGHT_BLOB_VERSION, }; -// Re-export the upstream Tensor3 so callers don't need a direct -// `ruvllm_sparse_attention` dep. -pub use ruvllm_sparse_attention::Tensor3; +// Re-export the upstream Tensor3 + KvCache so callers don't need a +// direct `ruvllm_sparse_attention` dep. +pub use ruvllm_sparse_attention::{KvCache, Tensor3}; /// Thin facade so callers can pick a backend by name. /// @@ -62,4 +62,33 @@ impl AetherTemporalHead { AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented), } } + + /// Streaming decode (ADR-096 §3.2). Caller owns the `cache`; the + /// natural lifetime is per-tracked-person (one cache per + /// `PoseTrack`, dropped when the track evicts). + /// + /// Returns the attention output for the single new token. Caller + /// is responsible for downstream pooling / classifier head. 
+ pub fn step( + &self, + q_new: &Tensor3, + k_new: &Tensor3, + v_new: &Tensor3, + cache: &mut KvCache, + ) -> Result<Tensor3, TemporalError> { + match self { + AetherTemporalHead::SparseGqa(h) => h.step(q_new, k_new, v_new, cache), + AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented), + } + } + + /// Allocate a `KvCache` sized correctly for this head. Convenience + /// wrapper so AETHER's `pose_tracker.rs` doesn't need to import + /// the upstream crate. + pub fn make_cache(&self, capacity: usize) -> Result<KvCache, TemporalError> { + match self { + AetherTemporalHead::SparseGqa(h) => Ok(h.make_cache(capacity)), + AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented), + } + } } diff --git a/v2/crates/wifi-densepose-temporal/src/sparse.rs b/v2/crates/wifi-densepose-temporal/src/sparse.rs index 17436ec6..edab0022 100644 --- a/v2/crates/wifi-densepose-temporal/src/sparse.rs +++ b/v2/crates/wifi-densepose-temporal/src/sparse.rs @@ -1,5 +1,5 @@ use ruvllm_sparse_attention::{ - AttentionBackend, SparseAttentionConfig, SubquadraticSparseAttention, Tensor3, + AttentionBackend, KvCache, SparseAttentionConfig, SubquadraticSparseAttention, Tensor3, }; use crate::{TemporalError, TemporalHeadConfig}; @@ -57,6 +57,50 @@ impl SparseGqaHead { Ok(self.attn.forward_gqa(q, k, v)?) } } + + /// Streaming decode for re-ID and online classification (ADR-096 §3.2). + /// + /// Given one new token's q/k/v, append (k, v) to `cache` and return + /// the attention output for that one position against the full + /// accumulated history. Cost is O(log T) per step against a cache + /// of capacity T — the structural advantage over dense MHA's O(N²) + /// recompute that ADR-096 specifically calls out as the + /// dense-MHA-cannot-follow path. + /// + /// Cache lifetime is owned by the caller. Per ADR-096 §8.5 the + /// natural place is one cache per `PoseTrack` (re-ID) or one cache + /// per active session (online classification). When the track is + /// dropped, drop the cache. 
+ pub fn step( + &self, + q_new: &Tensor3, + k_new: &Tensor3, + v_new: &Tensor3, + cache: &mut KvCache, + ) -> Result<Tensor3, TemporalError> { + if q_new.seq != 1 || k_new.seq != 1 || v_new.seq != 1 { + return Err(TemporalError::InvalidConfig( + "step() requires single-token q/k/v (seq == 1 each)", + )); + } + // Append must succeed before decode_step sees the cache; if + // the cache fills, eviction/reset is the caller's job per ADR-096 + // §3.2 (upstream H2O eviction is opt-in). NOTE: the append is not + // rolled back here if decode_step fails afterwards. + cache.try_append(k_new, v_new)?; + Ok(self.attn.decode_step(q_new, cache)?) + } + + /// Construct a KvCache sized for this head's shape. Convenience + /// so callers don't need to import the upstream crate directly. + pub fn make_cache(&self, capacity: usize) -> KvCache { + KvCache::new( + capacity, + self.cfg.kv_heads, + self.cfg.head_dim, + self.cfg.block_size, + ) + } } /// Always treat token 0 as a global anchor — AETHER's contrastive diff --git a/v2/crates/wifi-densepose-temporal/tests/streaming.rs b/v2/crates/wifi-densepose-temporal/tests/streaming.rs new file mode 100644 index 00000000..9f3368aa --- /dev/null +++ b/v2/crates/wifi-densepose-temporal/tests/streaming.rs @@ -0,0 +1,139 @@ +//! ADR-096 §3.2 streaming-decode test: token-by-token `step()` against +//! a `KvCache` should match a single-shot `forward()` over the same +//! Q/K/V at the final position. This is the structural advantage +//! dense MHA can't follow — proving it stays correct under streaming +//! is what the §5 validation gate would care about most. 
+ +use wifi_densepose_temporal::{ + AetherTemporalHead, TemporalBackendKind, TemporalHeadConfig, Tensor3, +}; + +fn make_qkv(seq: usize, q_heads: usize, kv_heads: usize, dim: usize) -> (Tensor3, Tensor3, Tensor3) { + let mut q = Tensor3::zeros(seq, q_heads, dim); + let mut k = Tensor3::zeros(seq, kv_heads, dim); + let mut v = Tensor3::zeros(seq, kv_heads, dim); + for s in 0..seq { + for h in 0..q_heads { + for d in 0..dim { + let val = ((s * 31 + h * 7 + d) as f32).sin() * 0.1; + q.set(s, h, d, val); + } + } + for h in 0..kv_heads { + for d in 0..dim { + let val = (((s * 17 + h * 3 + d) as f32).cos()) * 0.1; + k.set(s, h, d, val); + v.set(s, h, d, val * 0.5); + } + } + } + (q, k, v) +} + +fn slice_token(t: &Tensor3, idx: usize) -> Tensor3 { + let (_, heads, dim) = t.shape(); + let mut out = Tensor3::zeros(1, heads, dim); + for h in 0..heads { + for d in 0..dim { + out.set(0, h, d, t.get(idx, h, d)); + } + } + out +} + +fn config_mha_small() -> TemporalHeadConfig { + // Equal q/kv head counts force the `forward` MHA branch — `decode_step` + // upstream is wired to this branch, not the GQA branch (which has + // its own decode path coming in upstream's roadmap). + TemporalHeadConfig { + backend: TemporalBackendKind::SparseGqa, + q_heads: 2, + kv_heads: 2, + head_dim: 16, + window: 8, + block_size: 4, + causal: true, + } +} + +#[test] +fn streaming_step_matches_forward_at_last_position() { + let cfg = config_mha_small(); + let head = AetherTemporalHead::new(&cfg).expect("construct"); + + let seq = 16usize; + let (q, k, v) = make_qkv(seq, cfg.q_heads, cfg.kv_heads, cfg.head_dim); + + // Reference: single-shot forward over the full sequence. + let reference = head.forward(&q, &k, &v).expect("forward"); + + // Streaming: append k/v one token at a time, decode the new q. 
+ let mut cache = head.make_cache(seq).expect("cache"); + let mut last_out: Option<Tensor3> = None; + for t in 0..seq { + let qt = slice_token(&q, t); + let kt = slice_token(&k, t); + let vt = slice_token(&v, t); + last_out = Some(head.step(&qt, &kt, &vt, &mut cache).expect("step")); + } + let streamed = last_out.expect("at least one step"); + + // Compare the streamed last-token output to the reference's + // last-token output. Tolerance is generous because accumulation + // order differs between the two paths. A NaN error is latched into + // max_abs_err so the final assert fails instead of skipping it. + let (s_seq, s_heads, s_dim) = streamed.shape(); + assert_eq!((s_seq, s_heads, s_dim), (1, cfg.q_heads, cfg.head_dim)); + let mut max_abs_err: f32 = 0.0; + for h in 0..cfg.q_heads { + for d in 0..cfg.head_dim { + let a = streamed.get(0, h, d); + let b = reference.get(seq - 1, h, d); + let err = (a - b).abs(); + if err.is_nan() || err > max_abs_err { + max_abs_err = err; + } + } + } + // 1e-3 absolute is a comfortable bound for activations of this + // magnitude (~0.1 input scale). Tighten if the kernel ever + // promises closer match. + assert!( + max_abs_err < 1.0e-3, + "streaming/forward divergence at last token exceeds 1e-3: max_abs_err = {max_abs_err}" + ); +} + +#[test] +fn step_rejects_multi_token_q() { + let cfg = config_mha_small(); + let head = AetherTemporalHead::new(&cfg).expect("construct"); + let mut cache = head.make_cache(8).expect("cache"); + + // Build a 2-token Q/K/V — `step` must reject (its contract is + // single-token decode). 
+ let (q, k, v) = make_qkv(2, cfg.q_heads, cfg.kv_heads, cfg.head_dim); + let err = head.step(&q, &k, &v, &mut cache).err().expect("rejected"); + let s = err.to_string(); + assert!( + s.contains("single-token") || s.to_lowercase().contains("seq"), + "expected single-token rejection, got: {s}" + ); +} + +#[test] +fn make_cache_returns_kvcache_with_correct_shape() { + // Smoke test that the convenience wrapper plumbs the right dims + // into KvCache::new — the upstream constructor takes + // (capacity, kv_heads, dim, block_size) and we want to make sure + // we're not transposing any of those. + let cfg = config_mha_small(); + let head = AetherTemporalHead::new(&cfg).expect("construct"); + let mut cache = head.make_cache(32).expect("cache"); + + // Append one token shaped for kv_heads × head_dim — should not error. + let (_, k, v) = make_qkv(1, cfg.q_heads, cfg.kv_heads, cfg.head_dim); + let kt = slice_token(&k, 0); + let vt = slice_token(&v, 0); + cache.try_append(&kt, &vt).expect("append shape ok"); +}