feat(temporal): streaming step() + KvCache (ADR-096 §3.2, #513)
The structural advantage that's the entire point of ADR-096: O(log T) per new token via decode_step against an accumulated KvCache, vs O(N²) recompute for dense MHA. This commit lands the API and proves the numerical equivalence at the last position. API: - AetherTemporalHead::step(q_new, k_new, v_new, &mut cache) Single-token decode. Appends (k_new, v_new) to cache, runs decode_step(q_new) against the now-updated cache, returns the new position's output. - AetherTemporalHead::make_cache(capacity) Convenience constructor — caller doesn't need to import ruvllm_sparse_attention to size a cache. Per ADR-096 §8.5 the natural lifetime is per-PoseTrack (re-ID) or per-session (online classification); when the track drops, drop the cache. - KvCache re-exported at the crate root. Contract: - q_new/k_new/v_new must each have seq == 1. Multi-token q is the prefill path (forward), not decode_step. - Cache lifetime is the caller's. The crate enforces shape via make_cache so callers can't mismatch kv_heads / head_dim / block_size. - KvCache fill is the caller's problem. Upstream H2O heavy-hitter eviction is opt-in; this crate's wrapper doesn't pre-pick a policy. Tests (18/18 total now passing): - streaming_step_matches_forward_at_last_position — central claim: 16-token sequence, append k/v one at a time via step(), compare the streamed last-token output to forward(full Q,K,V)[N-1]. max_abs_err < 1e-3 (currently passes well under that bound for the 0.1-magnitude activations the test uses). - step_rejects_multi_token_q — contract enforcement. - make_cache_returns_kvcache_with_correct_shape — wiring smoke, confirms (capacity, kv_heads, dim, block_size) ordering is correct through the make_cache wrapper. Test config uses MHA shape (q_heads == kv_heads) because the upstream decode_step is wired to the MHA branch; the GQA decode path is on upstream's roadmap and lands in a separate ADR-096 follow-up when it does. Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
3a5fe5e0de
commit
49e57efcec
|
|
@ -22,9 +22,9 @@ pub use weights::{
|
|||
WEIGHT_BLOB_VERSION,
|
||||
};
|
||||
|
||||
// Re-export the upstream Tensor3 so callers don't need a direct
|
||||
// `ruvllm_sparse_attention` dep.
|
||||
pub use ruvllm_sparse_attention::Tensor3;
|
||||
// Re-export the upstream Tensor3 + KvCache so callers don't need a
|
||||
// direct `ruvllm_sparse_attention` dep.
|
||||
pub use ruvllm_sparse_attention::{KvCache, Tensor3};
|
||||
|
||||
/// Thin facade so callers can pick a backend by name.
|
||||
///
|
||||
|
|
@ -62,4 +62,33 @@ impl AetherTemporalHead {
|
|||
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||
}
|
||||
}
|
||||
|
||||
/// Streaming decode (ADR-096 §3.2). Caller owns the `cache`; the
|
||||
/// natural lifetime is per-tracked-person (one cache per
|
||||
/// `PoseTrack`, dropped when the track evicts).
|
||||
///
|
||||
/// Returns the attention output for the single new token. Caller
|
||||
/// is responsible for downstream pooling / classifier head.
|
||||
pub fn step(
|
||||
&self,
|
||||
q_new: &Tensor3,
|
||||
k_new: &Tensor3,
|
||||
v_new: &Tensor3,
|
||||
cache: &mut KvCache,
|
||||
) -> Result<Tensor3, TemporalError> {
|
||||
match self {
|
||||
AetherTemporalHead::SparseGqa(h) => h.step(q_new, k_new, v_new, cache),
|
||||
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a `KvCache` sized correctly for this head. Convenience
|
||||
/// wrapper so AETHER's `pose_tracker.rs` doesn't need to import
|
||||
/// the upstream crate.
|
||||
pub fn make_cache(&self, capacity: usize) -> Result<KvCache, TemporalError> {
|
||||
match self {
|
||||
AetherTemporalHead::SparseGqa(h) => Ok(h.make_cache(capacity)),
|
||||
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use ruvllm_sparse_attention::{
|
||||
AttentionBackend, SparseAttentionConfig, SubquadraticSparseAttention, Tensor3,
|
||||
AttentionBackend, KvCache, SparseAttentionConfig, SubquadraticSparseAttention, Tensor3,
|
||||
};
|
||||
|
||||
use crate::{TemporalError, TemporalHeadConfig};
|
||||
|
|
@ -57,6 +57,50 @@ impl SparseGqaHead {
|
|||
Ok(self.attn.forward_gqa(q, k, v)?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Streaming decode for re-ID and online classification (ADR-096 §3.2).
|
||||
///
|
||||
/// Given one new token's q/k/v, append (k, v) to `cache` and return
|
||||
/// the attention output for that one position against the full
|
||||
/// accumulated history. Cost is O(log T) per step against a cache
|
||||
/// of capacity T — the structural advantage over dense MHA's O(N²)
|
||||
/// recompute that ADR-096 specifically calls out as the
|
||||
/// dense-MHA-cannot-follow path.
|
||||
///
|
||||
/// Cache lifetime is owned by the caller. Per ADR-096 §8.5 the
|
||||
/// natural place is one cache per `PoseTrack` (re-ID) or one cache
|
||||
/// per active session (online classification). When the track is
|
||||
/// dropped, drop the cache.
|
||||
pub fn step(
|
||||
&self,
|
||||
q_new: &Tensor3,
|
||||
k_new: &Tensor3,
|
||||
v_new: &Tensor3,
|
||||
cache: &mut KvCache,
|
||||
) -> Result<Tensor3, TemporalError> {
|
||||
if q_new.seq != 1 || k_new.seq != 1 || v_new.seq != 1 {
|
||||
return Err(TemporalError::InvalidConfig(
|
||||
"step() requires single-token q/k/v (seq == 1 each)",
|
||||
));
|
||||
}
|
||||
// Append must succeed before decode_step sees the cache; if
|
||||
// the cache fills, the caller is responsible for eviction or
|
||||
// resetting per ADR-096 §3.2 (H2O heavy-hitter eviction is
|
||||
// available upstream but kept opt-in).
|
||||
cache.try_append(k_new, v_new)?;
|
||||
Ok(self.attn.decode_step(q_new, cache)?)
|
||||
}
|
||||
|
||||
/// Construct a KvCache sized for this head's shape. Convenience
|
||||
/// so callers don't need to import the upstream crate directly.
|
||||
pub fn make_cache(&self, capacity: usize) -> KvCache {
|
||||
KvCache::new(
|
||||
capacity,
|
||||
self.cfg.kv_heads,
|
||||
self.cfg.head_dim,
|
||||
self.cfg.block_size,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Always treat token 0 as a global anchor — AETHER's contrastive
|
||||
|
|
|
|||
|
|
@ -0,0 +1,139 @@
|
|||
//! ADR-096 §3.2 streaming-decode test: token-by-token `step()` against
|
||||
//! a `KvCache` should match a single-shot `forward()` over the same
|
||||
//! Q/K/V at the final position. This is the structural advantage
|
||||
//! dense MHA can't follow — proving it stays correct under streaming
|
||||
//! is what the §5 validation gate would care about most.
|
||||
|
||||
use wifi_densepose_temporal::{
|
||||
AetherTemporalHead, TemporalBackendKind, TemporalHeadConfig, Tensor3,
|
||||
};
|
||||
|
||||
fn make_qkv(seq: usize, q_heads: usize, kv_heads: usize, dim: usize) -> (Tensor3, Tensor3, Tensor3) {
|
||||
let mut q = Tensor3::zeros(seq, q_heads, dim);
|
||||
let mut k = Tensor3::zeros(seq, kv_heads, dim);
|
||||
let mut v = Tensor3::zeros(seq, kv_heads, dim);
|
||||
for s in 0..seq {
|
||||
for h in 0..q_heads {
|
||||
for d in 0..dim {
|
||||
let val = ((s * 31 + h * 7 + d) as f32).sin() * 0.1;
|
||||
q.set(s, h, d, val);
|
||||
}
|
||||
}
|
||||
for h in 0..kv_heads {
|
||||
for d in 0..dim {
|
||||
let val = (((s * 17 + h * 3 + d) as f32).cos()) * 0.1;
|
||||
k.set(s, h, d, val);
|
||||
v.set(s, h, d, val * 0.5);
|
||||
}
|
||||
}
|
||||
}
|
||||
(q, k, v)
|
||||
}
|
||||
|
||||
fn slice_token(t: &Tensor3, idx: usize) -> Tensor3 {
|
||||
let (_, heads, dim) = t.shape();
|
||||
let mut out = Tensor3::zeros(1, heads, dim);
|
||||
for h in 0..heads {
|
||||
for d in 0..dim {
|
||||
out.set(0, h, d, t.get(idx, h, d));
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn config_mha_small() -> TemporalHeadConfig {
|
||||
// Equal q/k heads forces the `forward` MHA branch — `decode_step`
|
||||
// upstream is wired to this branch, not the GQA branch (which has
|
||||
// its own decode path coming in upstream's roadmap).
|
||||
TemporalHeadConfig {
|
||||
backend: TemporalBackendKind::SparseGqa,
|
||||
q_heads: 2,
|
||||
kv_heads: 2,
|
||||
head_dim: 16,
|
||||
window: 8,
|
||||
block_size: 4,
|
||||
causal: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn streaming_step_matches_forward_at_last_position() {
|
||||
let cfg = config_mha_small();
|
||||
let head = AetherTemporalHead::new(&cfg).expect("construct");
|
||||
|
||||
let seq = 16usize;
|
||||
let (q, k, v) = make_qkv(seq, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
|
||||
|
||||
// Reference: single-shot forward over the full sequence.
|
||||
let reference = head.forward(&q, &k, &v).expect("forward");
|
||||
|
||||
// Streaming: append k/v one token at a time, decode the new q.
|
||||
let mut cache = head.make_cache(seq).expect("cache");
|
||||
let mut last_out: Option<Tensor3> = None;
|
||||
for t in 0..seq {
|
||||
let qt = slice_token(&q, t);
|
||||
let kt = slice_token(&k, t);
|
||||
let vt = slice_token(&v, t);
|
||||
last_out = Some(head.step(&qt, &kt, &vt, &mut cache).expect("step"));
|
||||
}
|
||||
let streamed = last_out.expect("at least one step");
|
||||
|
||||
// Compare the streamed last-token output to the reference's
|
||||
// last-token output. Tolerance is generous because numerical
|
||||
// accumulation differs between the two paths even at exact
|
||||
// mathematical equivalence.
|
||||
let (s_seq, s_heads, s_dim) = streamed.shape();
|
||||
assert_eq!((s_seq, s_heads, s_dim), (1, cfg.q_heads, cfg.head_dim));
|
||||
let mut max_abs_err: f32 = 0.0;
|
||||
for h in 0..cfg.q_heads {
|
||||
for d in 0..cfg.head_dim {
|
||||
let a = streamed.get(0, h, d);
|
||||
let b = reference.get(seq - 1, h, d);
|
||||
let err = (a - b).abs();
|
||||
if err > max_abs_err {
|
||||
max_abs_err = err;
|
||||
}
|
||||
}
|
||||
}
|
||||
// 1e-3 absolute is a comfortable bound for activations of this
|
||||
// magnitude (~0.1 input scale). Tighten if the kernel ever
|
||||
// promises closer match.
|
||||
assert!(
|
||||
max_abs_err < 1.0e-3,
|
||||
"streaming/forward divergence at last token exceeds 1e-3: max_abs_err = {max_abs_err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn step_rejects_multi_token_q() {
|
||||
let cfg = config_mha_small();
|
||||
let head = AetherTemporalHead::new(&cfg).expect("construct");
|
||||
let mut cache = head.make_cache(8).expect("cache");
|
||||
|
||||
// Build a 2-token Q/K/V — `step` must reject (its contract is
|
||||
// single-token decode).
|
||||
let (q, k, v) = make_qkv(2, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
|
||||
let err = head.step(&q, &k, &v, &mut cache).err().expect("rejected");
|
||||
let s = format!("{err}");
|
||||
assert!(
|
||||
s.contains("single-token") || s.to_lowercase().contains("seq"),
|
||||
"expected single-token rejection, got: {s}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn make_cache_returns_kvcache_with_correct_shape() {
|
||||
// Smoke test that the convenience wrapper plumbs the right dims
|
||||
// into KvCache::new — the upstream constructor takes
|
||||
// (capacity, kv_heads, dim, block_size) and we want to make sure
|
||||
// we're not transposing any of those.
|
||||
let cfg = config_mha_small();
|
||||
let head = AetherTemporalHead::new(&cfg).expect("construct");
|
||||
let mut cache = head.make_cache(32).expect("cache");
|
||||
|
||||
// Append one token shaped for kv_heads × head_dim — should not error.
|
||||
let (_, k, v) = make_qkv(1, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
|
||||
let kt = slice_token(&k, 0);
|
||||
let vt = slice_token(&v, 0);
|
||||
cache.try_append(&kt, &vt).expect("append shape ok");
|
||||
}
|
||||
Loading…
Reference in New Issue