feat(temporal): Dense backend implementation (ADR-096 §5 A/B gate, #513)
Closes the Dense placeholder from earlier commits. Now both backends implement forward(); only SparseGqa supports streaming step()/KvCache, which is the structural gap dense MHA can't bridge by design. Dense path: - src/dense.rs new — DenseHead wraps upstream dense_attention. Stores causal flag and (cloned) config. forward() is a one-line delegation; no GQA dispatch (dense_attention upstream requires q_heads == kv_heads). - AetherTemporalHead::Dense changed from a unit variant to Dense(DenseHead). Construction succeeds for any valid TemporalHeadConfig where backend is Dense. - AetherTemporalHead.step() returns BackendDoesNotSupportStreaming for Dense — there is no dense-MHA-with-KV-cache equivalent and offering one would silently swallow the ADR-096 §3.2 structural argument. - AetherTemporalHead.make_cache() likewise — there's no cache to size for a dense kernel. Errors: - New TemporalError::BackendDoesNotSupportStreaming variant covers the Dense-step / Dense-make_cache cases. Specific so callers can fall back to forward() instead of giving up entirely. - TemporalError::DenseBackendNotImplemented retained for v0.1 back-compat (no consumers depend on it post-this-commit, but removing a public variant is a hard break). Future work can deprecate it once downstream callers move off. Tests (19/19 passing): - dense_backend_returns_typed_error → renamed and rewritten as dense_backend_forward_runs_with_matching_shape: constructs a Dense head, runs forward over (32, 4, 4, 16) Q/K/V, asserts output shape. - New dense_backend_step_returns_streaming_error: constructs Dense, attempts make_cache, expects BackendDoesNotSupportStreaming. - All 8 weight blob, 2 blob e2e, 3 streaming, 5 other smoke tests unchanged and still passing. This commit completes the ADR-096 §5 A/B gate: callers can now run the same Q/K/V through both backends and compare outputs / latency. The §5 four-gate validation (contrastive loss within 1%, rank-1 within 1pp, Spearman ≥0.95, latency ≥5×) becomes a runnable proposition, not a future task — though the actual gate run requires trained AETHER weights, which is its own track. Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
2aee4d21cf
commit
4ea8457017
|
|
@ -0,0 +1,44 @@
|
|||
use ruvllm_sparse_attention::{dense_attention, Tensor3};
|
||||
|
||||
use crate::{TemporalError, TemporalHeadConfig};
|
||||
|
||||
/// Dense MHA backend (ADR-096 §5 A/B baseline).
|
||||
///
|
||||
/// Wraps upstream `dense_attention` — the naive O(N²) reference kernel.
|
||||
/// Same approximation surface as classical scaled-dot-product attention,
|
||||
/// no log-stride / landmarks / windowing. Exists primarily as the
|
||||
/// reference path for the §5 validation gate (rank correlation,
|
||||
/// contrastive-loss parity, latency baseline).
|
||||
///
|
||||
/// Has no streaming counterpart: dense MHA structurally cannot do
|
||||
/// O(log T) decode — every new token requires recomputing the full
|
||||
/// attention matrix. Callers that want streaming must use SparseGqa.
|
||||
pub struct DenseHead {
|
||||
causal: bool,
|
||||
cfg: TemporalHeadConfig,
|
||||
}
|
||||
|
||||
impl DenseHead {
|
||||
pub fn new(cfg: &TemporalHeadConfig) -> Result<Self, TemporalError> {
|
||||
cfg.validate()?;
|
||||
Ok(Self {
|
||||
causal: cfg.causal,
|
||||
cfg: cfg.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn cfg(&self) -> &TemporalHeadConfig {
|
||||
&self.cfg
|
||||
}
|
||||
|
||||
/// Naive O(N²) prefill. Q/K/V must share the same head count
|
||||
/// (no GQA) — `dense_attention` upstream enforces it.
|
||||
pub fn forward(
|
||||
&self,
|
||||
q: &Tensor3,
|
||||
k: &Tensor3,
|
||||
v: &Tensor3,
|
||||
) -> Result<Tensor3, TemporalError> {
|
||||
Ok(dense_attention(q, k, v, self.causal)?)
|
||||
}
|
||||
}
|
||||
|
|
@ -5,9 +5,18 @@ pub enum TemporalError {
|
|||
#[error("temporal head config invalid: {0}")]
|
||||
InvalidConfig(&'static str),
|
||||
|
||||
/// Retained for back-compat with v0.1 callers; superseded by the
|
||||
/// per-operation errors below now that Dense is implemented.
|
||||
#[error("dense MHA backend not implemented yet (ADR-096 §4.4 follow-up)")]
|
||||
DenseBackendNotImplemented,
|
||||
|
||||
/// Dense MHA has no notion of an accumulated KV cache — every
|
||||
/// new frame requires recomputing the full N² attention matrix
|
||||
/// (the structural gap ADR-096 §3.2 flagged). Callers that want
|
||||
/// streaming decode must use the SparseGqa backend.
|
||||
#[error("dense backend does not support streaming step(); use SparseGqa for online decode")]
|
||||
BackendDoesNotSupportStreaming,
|
||||
|
||||
#[error("sparse attention kernel error: {0}")]
|
||||
Kernel(String),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,11 +10,13 @@
|
|||
// (ADR-096 §8.5) is finalized.
|
||||
|
||||
pub mod config;
|
||||
pub mod dense;
|
||||
pub mod error;
|
||||
pub mod sparse;
|
||||
pub mod weights;
|
||||
|
||||
pub use config::{TemporalBackendKind, TemporalHeadConfig};
|
||||
pub use dense::DenseHead;
|
||||
pub use error::TemporalError;
|
||||
pub use sparse::SparseGqaHead;
|
||||
pub use weights::{
|
||||
|
|
@ -28,12 +30,13 @@ pub use ruvllm_sparse_attention::{KvCache, Tensor3};
|
|||
|
||||
/// Thin facade so callers can pick a backend by name.
|
||||
///
|
||||
/// Today only `SparseGqa` is implemented; `Dense` is reserved per
|
||||
/// ADR-096 §4.4 and returns `TemporalError::DenseBackendNotImplemented`
|
||||
/// until the back-compat path lands.
|
||||
/// Both backends implement `forward()` for prefill. Only `SparseGqa`
|
||||
/// implements `step()` (streaming O(log T) decode against KvCache);
|
||||
/// dense MHA structurally lacks a streaming counterpart and returns
|
||||
/// `TemporalError::BackendDoesNotSupportStreaming` on `step()`.
|
||||
pub enum AetherTemporalHead {
|
||||
SparseGqa(SparseGqaHead),
|
||||
Dense, // placeholder; ADR-096 §4.4 selection rule
|
||||
Dense(DenseHead),
|
||||
}
|
||||
|
||||
impl AetherTemporalHead {
|
||||
|
|
@ -42,7 +45,7 @@ impl AetherTemporalHead {
|
|||
TemporalBackendKind::SparseGqa => {
|
||||
Ok(AetherTemporalHead::SparseGqa(SparseGqaHead::new(cfg)?))
|
||||
}
|
||||
TemporalBackendKind::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||
TemporalBackendKind::Dense => Ok(AetherTemporalHead::Dense(DenseHead::new(cfg)?)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -59,7 +62,7 @@ impl AetherTemporalHead {
|
|||
) -> Result<Tensor3, TemporalError> {
|
||||
match self {
|
||||
AetherTemporalHead::SparseGqa(h) => h.forward(q, k, v),
|
||||
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||
AetherTemporalHead::Dense(h) => h.forward(q, k, v),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -69,6 +72,9 @@ impl AetherTemporalHead {
|
|||
///
|
||||
/// Returns the attention output for the single new token. Caller
|
||||
/// is responsible for downstream pooling / classifier head.
|
||||
///
|
||||
/// Dense backend returns `BackendDoesNotSupportStreaming` — no
|
||||
/// dense-MHA-with-KV-cache equivalent exists, by design.
|
||||
pub fn step(
|
||||
&self,
|
||||
q_new: &Tensor3,
|
||||
|
|
@ -78,17 +84,22 @@ impl AetherTemporalHead {
|
|||
) -> Result<Tensor3, TemporalError> {
|
||||
match self {
|
||||
AetherTemporalHead::SparseGqa(h) => h.step(q_new, k_new, v_new, cache),
|
||||
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||
AetherTemporalHead::Dense(_) => {
|
||||
Err(TemporalError::BackendDoesNotSupportStreaming)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a `KvCache` sized correctly for this head. Convenience
|
||||
/// wrapper so AETHER's `pose_tracker.rs` doesn't need to import
|
||||
/// the upstream crate.
|
||||
///
|
||||
/// Dense backend returns `BackendDoesNotSupportStreaming` — there
|
||||
/// is no cache to size for a dense kernel.
|
||||
pub fn make_cache(&self, capacity: usize) -> Result<KvCache, TemporalError> {
|
||||
match self {
|
||||
AetherTemporalHead::SparseGqa(h) => Ok(h.make_cache(capacity)),
|
||||
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||
AetherTemporalHead::Dense(_) => Err(TemporalError::BackendDoesNotSupportStreaming),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,18 +63,38 @@ fn sparse_mha_path_runs_when_qkv_heads_match() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn dense_backend_returns_typed_error() {
|
||||
fn dense_backend_forward_runs_with_matching_shape() {
|
||||
// Dense_attention upstream requires q_heads == kv_heads (no GQA).
|
||||
// Use MHA shape; n_classes/n_layers don't matter for forward-only.
|
||||
let cfg = TemporalHeadConfig {
|
||||
backend: TemporalBackendKind::Dense,
|
||||
q_heads: 4,
|
||||
kv_heads: 1,
|
||||
head_dim: 32,
|
||||
window: 32,
|
||||
block_size: 16,
|
||||
kv_heads: 4,
|
||||
head_dim: 16,
|
||||
window: 8,
|
||||
block_size: 4,
|
||||
causal: true,
|
||||
};
|
||||
let err = AetherTemporalHead::new(&cfg).err().expect("dense rejected");
|
||||
matches!(err, TemporalError::DenseBackendNotImplemented);
|
||||
let head = AetherTemporalHead::new(&cfg).expect("construct dense");
|
||||
let (q, k, v) = make_qkv(32, 4, 4, 16);
|
||||
let out = head.forward(&q, &k, &v).expect("dense forward");
|
||||
assert_eq!(out.shape(), (32, 4, 16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dense_backend_step_returns_streaming_error() {
|
||||
let cfg = TemporalHeadConfig {
|
||||
backend: TemporalBackendKind::Dense,
|
||||
q_heads: 4,
|
||||
kv_heads: 4,
|
||||
head_dim: 16,
|
||||
window: 8,
|
||||
block_size: 4,
|
||||
causal: true,
|
||||
};
|
||||
let head = AetherTemporalHead::new(&cfg).expect("construct dense");
|
||||
let cache_err = head.make_cache(32).err().expect("no cache for dense");
|
||||
matches!(cache_err, TemporalError::BackendDoesNotSupportStreaming);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
Loading…
Reference in New Issue