feat(temporal): streaming step() + KvCache (ADR-096 §3.2, #513)
The structural advantage that's the entire point of ADR-096: O(log T) per new token via decode_step against an accumulated KvCache, vs O(N²) recompute for dense MHA. This commit lands the API and proves the numerical equivalence at the last position. API: - AetherTemporalHead::step(q_new, k_new, v_new, &mut cache) Single-token decode. Appends (k_new, v_new) to cache, runs decode_step(q_new) against the now-updated cache, returns the new position's output. - AetherTemporalHead::make_cache(capacity) Convenience constructor — caller doesn't need to import ruvllm_sparse_attention to size a cache. Per ADR-096 §8.5 the natural lifetime is per-PoseTrack (re-ID) or per-session (online classification); when the track drops, drop the cache. - KvCache re-exported at the crate root. Contract: - q_new/k_new/v_new must each have seq == 1. Multi-token q is the prefill path (forward), not decode_step. - Cache lifetime is the caller's. The crate enforces shape via make_cache so callers can't mismatch kv_heads / head_dim / block_size. - KvCache fill is the caller's problem. Upstream H2O heavy-hitter eviction is opt-in; this crate's wrapper doesn't pre-pick a policy. Tests (18/18 total now passing): - streaming_step_matches_forward_at_last_position — central claim: 16-token sequence, append k/v one at a time via step(), compare the streamed last-token output to forward(full Q,K,V)[N-1]. max_abs_err < 1e-3 (currently passes well under that bound for the 0.1-magnitude activations the test uses). - step_rejects_multi_token_q — contract enforcement. - make_cache_returns_kvcache_with_correct_shape — wiring smoke, confirms (capacity, kv_heads, dim, block_size) ordering is correct through the make_cache wrapper. Test config uses MHA shape (q_heads == kv_heads) because the upstream decode_step is wired to the MHA branch; the GQA decode path is on upstream's roadmap and lands in a separate ADR-096 follow-up when it does. Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
3a5fe5e0de
commit
49e57efcec
|
|
@ -22,9 +22,9 @@ pub use weights::{
|
||||||
WEIGHT_BLOB_VERSION,
|
WEIGHT_BLOB_VERSION,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Re-export the upstream Tensor3 so callers don't need a direct
|
// Re-export the upstream Tensor3 + KvCache so callers don't need a
|
||||||
// `ruvllm_sparse_attention` dep.
|
// direct `ruvllm_sparse_attention` dep.
|
||||||
pub use ruvllm_sparse_attention::Tensor3;
|
pub use ruvllm_sparse_attention::{KvCache, Tensor3};
|
||||||
|
|
||||||
/// Thin facade so callers can pick a backend by name.
|
/// Thin facade so callers can pick a backend by name.
|
||||||
///
|
///
|
||||||
|
|
@ -62,4 +62,33 @@ impl AetherTemporalHead {
|
||||||
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Streaming decode (ADR-096 §3.2). Caller owns the `cache`; the
|
||||||
|
/// natural lifetime is per-tracked-person (one cache per
|
||||||
|
/// `PoseTrack`, dropped when the track evicts).
|
||||||
|
///
|
||||||
|
/// Returns the attention output for the single new token. Caller
|
||||||
|
/// is responsible for downstream pooling / classifier head.
|
||||||
|
pub fn step(
|
||||||
|
&self,
|
||||||
|
q_new: &Tensor3,
|
||||||
|
k_new: &Tensor3,
|
||||||
|
v_new: &Tensor3,
|
||||||
|
cache: &mut KvCache,
|
||||||
|
) -> Result<Tensor3, TemporalError> {
|
||||||
|
match self {
|
||||||
|
AetherTemporalHead::SparseGqa(h) => h.step(q_new, k_new, v_new, cache),
|
||||||
|
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allocate a `KvCache` sized correctly for this head. Convenience
|
||||||
|
/// wrapper so AETHER's `pose_tracker.rs` doesn't need to import
|
||||||
|
/// the upstream crate.
|
||||||
|
pub fn make_cache(&self, capacity: usize) -> Result<KvCache, TemporalError> {
|
||||||
|
match self {
|
||||||
|
AetherTemporalHead::SparseGqa(h) => Ok(h.make_cache(capacity)),
|
||||||
|
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use ruvllm_sparse_attention::{
|
use ruvllm_sparse_attention::{
|
||||||
AttentionBackend, SparseAttentionConfig, SubquadraticSparseAttention, Tensor3,
|
AttentionBackend, KvCache, SparseAttentionConfig, SubquadraticSparseAttention, Tensor3,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{TemporalError, TemporalHeadConfig};
|
use crate::{TemporalError, TemporalHeadConfig};
|
||||||
|
|
@ -57,6 +57,50 @@ impl SparseGqaHead {
|
||||||
Ok(self.attn.forward_gqa(q, k, v)?)
|
Ok(self.attn.forward_gqa(q, k, v)?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Streaming decode for re-ID and online classification (ADR-096 §3.2).
|
||||||
|
///
|
||||||
|
/// Given one new token's q/k/v, append (k, v) to `cache` and return
|
||||||
|
/// the attention output for that one position against the full
|
||||||
|
/// accumulated history. Cost is O(log T) per step against a cache
|
||||||
|
/// of capacity T — the structural advantage over dense MHA's O(N²)
|
||||||
|
/// recompute that ADR-096 specifically calls out as the
|
||||||
|
/// dense-MHA-cannot-follow path.
|
||||||
|
///
|
||||||
|
/// Cache lifetime is owned by the caller. Per ADR-096 §8.5 the
|
||||||
|
/// natural place is one cache per `PoseTrack` (re-ID) or one cache
|
||||||
|
/// per active session (online classification). When the track is
|
||||||
|
/// dropped, drop the cache.
|
||||||
|
pub fn step(
|
||||||
|
&self,
|
||||||
|
q_new: &Tensor3,
|
||||||
|
k_new: &Tensor3,
|
||||||
|
v_new: &Tensor3,
|
||||||
|
cache: &mut KvCache,
|
||||||
|
) -> Result<Tensor3, TemporalError> {
|
||||||
|
if q_new.seq != 1 || k_new.seq != 1 || v_new.seq != 1 {
|
||||||
|
return Err(TemporalError::InvalidConfig(
|
||||||
|
"step() requires single-token q/k/v (seq == 1 each)",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
// Append must succeed before decode_step sees the cache; if
|
||||||
|
// the cache fills, the caller is responsible for eviction or
|
||||||
|
// resetting per ADR-096 §3.2 (H2O heavy-hitter eviction is
|
||||||
|
// available upstream but kept opt-in).
|
||||||
|
cache.try_append(k_new, v_new)?;
|
||||||
|
Ok(self.attn.decode_step(q_new, cache)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Construct a KvCache sized for this head's shape. Convenience
|
||||||
|
/// so callers don't need to import the upstream crate directly.
|
||||||
|
pub fn make_cache(&self, capacity: usize) -> KvCache {
|
||||||
|
KvCache::new(
|
||||||
|
capacity,
|
||||||
|
self.cfg.kv_heads,
|
||||||
|
self.cfg.head_dim,
|
||||||
|
self.cfg.block_size,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Always treat token 0 as a global anchor — AETHER's contrastive
|
/// Always treat token 0 as a global anchor — AETHER's contrastive
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,139 @@
|
||||||
|
//! ADR-096 §3.2 streaming-decode test: token-by-token `step()` against
|
||||||
|
//! a `KvCache` should match a single-shot `forward()` over the same
|
||||||
|
//! Q/K/V at the final position. This is the structural advantage
|
||||||
|
//! dense MHA can't follow — proving it stays correct under streaming
|
||||||
|
//! is what the §5 validation gate would care about most.
|
||||||
|
|
||||||
|
use wifi_densepose_temporal::{
|
||||||
|
AetherTemporalHead, TemporalBackendKind, TemporalHeadConfig, Tensor3,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn make_qkv(seq: usize, q_heads: usize, kv_heads: usize, dim: usize) -> (Tensor3, Tensor3, Tensor3) {
|
||||||
|
let mut q = Tensor3::zeros(seq, q_heads, dim);
|
||||||
|
let mut k = Tensor3::zeros(seq, kv_heads, dim);
|
||||||
|
let mut v = Tensor3::zeros(seq, kv_heads, dim);
|
||||||
|
for s in 0..seq {
|
||||||
|
for h in 0..q_heads {
|
||||||
|
for d in 0..dim {
|
||||||
|
let val = ((s * 31 + h * 7 + d) as f32).sin() * 0.1;
|
||||||
|
q.set(s, h, d, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for h in 0..kv_heads {
|
||||||
|
for d in 0..dim {
|
||||||
|
let val = (((s * 17 + h * 3 + d) as f32).cos()) * 0.1;
|
||||||
|
k.set(s, h, d, val);
|
||||||
|
v.set(s, h, d, val * 0.5);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(q, k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn slice_token(t: &Tensor3, idx: usize) -> Tensor3 {
|
||||||
|
let (_, heads, dim) = t.shape();
|
||||||
|
let mut out = Tensor3::zeros(1, heads, dim);
|
||||||
|
for h in 0..heads {
|
||||||
|
for d in 0..dim {
|
||||||
|
out.set(0, h, d, t.get(idx, h, d));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
fn config_mha_small() -> TemporalHeadConfig {
|
||||||
|
// Equal q/k heads forces the `forward` MHA branch — `decode_step`
|
||||||
|
// upstream is wired to this branch, not the GQA branch (which has
|
||||||
|
// its own decode path coming in upstream's roadmap).
|
||||||
|
TemporalHeadConfig {
|
||||||
|
backend: TemporalBackendKind::SparseGqa,
|
||||||
|
q_heads: 2,
|
||||||
|
kv_heads: 2,
|
||||||
|
head_dim: 16,
|
||||||
|
window: 8,
|
||||||
|
block_size: 4,
|
||||||
|
causal: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Central ADR-096 §3.2 claim: feeding a 16-token sequence through
/// `step()` one token at a time must produce, at the final position,
/// the same output as a single-shot `forward()` over the full Q/K/V.
#[test]
fn streaming_step_matches_forward_at_last_position() {
    let cfg = config_mha_small();
    let head = AetherTemporalHead::new(&cfg).expect("construct");

    let seq = 16usize;
    let (q, k, v) = make_qkv(seq, cfg.q_heads, cfg.kv_heads, cfg.head_dim);

    // Reference: single-shot forward over the full sequence.
    let reference = head.forward(&q, &k, &v).expect("forward");

    // Streaming: append k/v one token at a time, decode the new q.
    let mut cache = head.make_cache(seq).expect("cache");
    let mut last_out: Option<Tensor3> = None;
    for t in 0..seq {
        let qt = slice_token(&q, t);
        let kt = slice_token(&k, t);
        let vt = slice_token(&v, t);
        last_out = Some(head.step(&qt, &kt, &vt, &mut cache).expect("step"));
    }
    let streamed = last_out.expect("at least one step");

    // Compare the streamed last-token output to the reference's
    // last-token output. Tolerance is generous because numerical
    // accumulation differs between the two paths even at exact
    // mathematical equivalence.
    assert_eq!(streamed.shape(), (1, cfg.q_heads, cfg.head_dim));
    let mut max_abs_err: f32 = 0.0;
    for h in 0..cfg.q_heads {
        for d in 0..cfg.head_dim {
            let a = streamed.get(0, h, d);
            let b = reference.get(seq - 1, h, d);
            // Every ordered comparison against NaN is false, so a NaN
            // (or inf - inf) would never raise `max_abs_err` and the
            // tolerance check below would pass vacuously. Reject
            // non-finite outputs explicitly.
            assert!(
                a.is_finite() && b.is_finite(),
                "non-finite output at head {h}, dim {d}: streamed = {a}, reference = {b}"
            );
            max_abs_err = max_abs_err.max((a - b).abs());
        }
    }

    // 1e-3 absolute is a comfortable bound for activations of this
    // magnitude (~0.1 input scale). Tighten if the kernel ever
    // promises closer match.
    assert!(
        max_abs_err < 1.0e-3,
        "streaming/forward divergence at last token exceeds 1e-3: max_abs_err = {max_abs_err}"
    );
}
|
||||||
|
|
||||||
|
/// Contract enforcement: `step` is single-token decode and must refuse
/// a multi-token Q/K/V.
#[test]
fn step_rejects_multi_token_q() {
    let cfg = config_mha_small();
    let head = AetherTemporalHead::new(&cfg).expect("construct");
    let mut cache = head.make_cache(8).expect("cache");

    // Two tokens in each of Q/K/V: outside the decode contract.
    let (q, k, v) = make_qkv(2, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
    let err = match head.step(&q, &k, &v, &mut cache) {
        Err(e) => e,
        Ok(_) => panic!("rejected"),
    };

    let s = err.to_string();
    let mentions_contract = s.contains("single-token") || s.to_lowercase().contains("seq");
    assert!(
        mentions_contract,
        "expected single-token rejection, got: {s}"
    );
}
|
||||||
|
|
||||||
|
/// Wiring smoke test for the `make_cache` convenience wrapper.
///
/// The upstream constructor takes (capacity, kv_heads, dim,
/// block_size); this checks that a cache built through the wrapper
/// accepts a token shaped kv_heads × head_dim, i.e. that the wrapper
/// did not transpose those arguments.
#[test]
fn make_cache_returns_kvcache_with_correct_shape() {
    let cfg = config_mha_small();
    let head = AetherTemporalHead::new(&cfg).expect("construct");
    let mut cache = head.make_cache(32).expect("cache");

    // One token shaped kv_heads × head_dim must append without error.
    let (_, k, v) = make_qkv(1, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
    let (kt, vt) = (slice_token(&k, 0), slice_token(&v, 0));
    cache.try_append(&kt, &vt).expect("append shape ok");
}
|
||||||
Loading…
Reference in New Issue