feat(temporal): scaffold wifi-densepose-temporal crate (ADR-096 Phase 1-3, #513)
Implements Phases 1-3 of the ADR-096 roadmap: Phase 1: workspace integration - Add `ruvllm_sparse_attention` as a path-vendored workspace dep against `vendor/ruvector/crates/ruvllm_sparse_attention`, default-features=false, features=["fp16"]. Mirrors the no_std posture ADR-095 will need on the firmware side so both consumers share a single feature set. - Register `wifi-densepose-temporal` as workspace member. Phase 2: AETHER temporal head - `AetherTemporalHead` facade dispatches to a `SparseGqa` backend wrapping `SubquadraticSparseAttention`. Selection rule from ADR-096 §4.4 enforced at forward(): MHA branch when q_heads == kv_heads, GQA branch otherwise. - `Dense` backend reserved (returns typed `DenseBackendNotImplemented`) so config-time validation fails loudly instead of at forward(). - `TemporalHeadConfig::default_aether()` matches the AETHER training default per ADR-096 §3.1 (window=32, block=16, q=4, kv=1 → MQA). - Token 0 always wired as a global anchor — preserves AETHER's contrastive "session-start reference" role per ADR-024. Phase 3: smoke tests (5/5 passing) - forward at AETHER default config, both MHA and GQA dispatch paths, rejected dense backend, rejected non-divisible GQA ratio, and the long-window roadmap target (N=1000, the 10s @ 100Hz case from ADR-096 §3.1 — proves the kernel runs at lengths where dense MHA costs 10⁶ edge ops vs sparse 10⁴). Streaming `step()` deferred — KvCache lifecycle ties to PoseTrack per ADR-096 §8.5 and lands when the firmware-side ABI does (Phase 4+). Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
684ef4f1a5
commit
bfb3fdee13
|
|
@ -231,6 +231,18 @@ dependencies = [
|
|||
"wait-timeout",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-compression"
|
||||
version = "0.4.42"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e79b3f8a79cccc2898f31920fc69f304859b3bd567490f75ebf51ae1c792a9ac"
|
||||
dependencies = [
|
||||
"compression-codecs",
|
||||
"compression-core",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.89"
|
||||
|
|
@ -318,7 +330,7 @@ dependencies = [
|
|||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tower",
|
||||
"tower 0.5.3",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
|
|
@ -871,6 +883,23 @@ dependencies = [
|
|||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "compression-codecs"
|
||||
version = "0.4.38"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ce2548391e9c1929c21bf6aa2680af86fe4c1b33e6cea9ac1cfeec0bd11218cf"
|
||||
dependencies = [
|
||||
"compression-core",
|
||||
"flate2",
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "compression-core"
|
||||
version = "0.4.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cc14f565cf027a105f7a44ccf9e5b424348421a1d8952a8fc9d499d313107789"
|
||||
|
||||
[[package]]
|
||||
name = "concurrent-queue"
|
||||
version = "2.5.0"
|
||||
|
|
@ -2371,6 +2400,16 @@ version = "0.16.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
|
||||
|
||||
[[package]]
|
||||
name = "hdrhistogram"
|
||||
version = "7.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heapless"
|
||||
version = "0.6.1"
|
||||
|
|
@ -3892,13 +3931,35 @@ name = "nvsim"
|
|||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"approx 0.5.1",
|
||||
"criterion",
|
||||
"js-sys",
|
||||
"rand 0.8.5",
|
||||
"rand_chacha 0.3.1",
|
||||
"serde",
|
||||
"serde-wasm-bindgen",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"thiserror 1.0.69",
|
||||
"tracing",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nvsim-server"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"clap",
|
||||
"futures-util",
|
||||
"nvsim",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tower 0.4.13",
|
||||
"tower-http 0.5.2",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -4487,6 +4548,26 @@ dependencies = [
|
|||
"siphasher 1.0.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517"
|
||||
dependencies = [
|
||||
"pin-project-internal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-internal"
|
||||
version = "1.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.17"
|
||||
|
|
@ -5278,7 +5359,7 @@ dependencies = [
|
|||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tower",
|
||||
"tower 0.5.3",
|
||||
"tower-http 0.6.8",
|
||||
"tower-service",
|
||||
"url",
|
||||
|
|
@ -5311,7 +5392,7 @@ dependencies = [
|
|||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tower",
|
||||
"tower 0.5.3",
|
||||
"tower-http 0.6.8",
|
||||
"tower-service",
|
||||
"url",
|
||||
|
|
@ -5798,6 +5879,14 @@ version = "2.0.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "178f93f84a4a72c582026a45d9b8710acf188df4a22a25434c5dbba1df6c4cac"
|
||||
|
||||
[[package]]
|
||||
name = "ruvllm_sparse_attention"
|
||||
version = "0.1.1"
|
||||
dependencies = [
|
||||
"half",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.23"
|
||||
|
|
@ -7379,6 +7468,27 @@ dependencies = [
|
|||
"zip 0.6.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"hdrhistogram",
|
||||
"indexmap 1.9.3",
|
||||
"pin-project",
|
||||
"pin-project-lite",
|
||||
"rand 0.8.5",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.5.3"
|
||||
|
|
@ -7401,8 +7511,10 @@ version = "0.5.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
|
||||
dependencies = [
|
||||
"async-compression",
|
||||
"bitflags 2.11.0",
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"http 1.4.0",
|
||||
"http-body 1.0.1",
|
||||
|
|
@ -7433,7 +7545,7 @@ dependencies = [
|
|||
"http-body 1.0.1",
|
||||
"iri-string",
|
||||
"pin-project-lite",
|
||||
"tower",
|
||||
"tower 0.5.3",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
|
@ -8385,6 +8497,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tower-http 0.5.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -8452,6 +8565,15 @@ dependencies = [
|
|||
"wifi-densepose-ruvector",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-temporal"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"approx 0.5.1",
|
||||
"ruvllm_sparse_attention",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-train"
|
||||
version = "0.3.0"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ members = [
|
|||
"crates/wifi-densepose-wifiscan",
|
||||
"crates/wifi-densepose-vitals",
|
||||
"crates/wifi-densepose-ruvector",
|
||||
"crates/wifi-densepose-temporal",
|
||||
"crates/wifi-densepose-desktop",
|
||||
"crates/wifi-densepose-pointcloud",
|
||||
"crates/wifi-densepose-geo",
|
||||
|
|
@ -131,6 +132,11 @@ ruvector-attention = "2.0.4"
|
|||
ruvector-crv = "0.1.1"
|
||||
ruvector-gnn = { version = "2.0.5", default-features = false }
|
||||
|
||||
# ruvllm sparse attention (path-vendored per ADR-095/096)
|
||||
# Default-features=false keeps the kernel no_std-clean so the same workspace
|
||||
# version is consumable by the upcoming ESP-IDF Rust component (ADR-095).
|
||||
ruvllm_sparse_attention = { path = "../vendor/ruvector/crates/ruvllm_sparse_attention", default-features = false, features = ["fp16"] }
|
||||
|
||||
|
||||
# Internal crates
|
||||
wifi-densepose-core = { version = "0.3.0", path = "crates/wifi-densepose-core" }
|
||||
|
|
@ -143,6 +149,7 @@ wifi-densepose-hardware = { version = "0.3.0", path = "crates/wifi-densepose-har
|
|||
wifi-densepose-wasm = { version = "0.3.0", path = "crates/wifi-densepose-wasm" }
|
||||
wifi-densepose-mat = { version = "0.3.0", path = "crates/wifi-densepose-mat" }
|
||||
wifi-densepose-ruvector = { version = "0.3.0", path = "crates/wifi-densepose-ruvector" }
|
||||
wifi-densepose-temporal = { version = "0.1.0", path = "crates/wifi-densepose-temporal" }
|
||||
|
||||
[profile.release]
|
||||
lto = true
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
[package]
|
||||
name = "wifi-densepose-temporal"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
description = "AETHER temporal head for WiFi-DensePose — sparse-GQA attention over CSI feature windows (ADR-096)"
|
||||
repository = "https://github.com/ruvnet/RuView"
|
||||
|
||||
[dependencies]
|
||||
ruvllm_sparse_attention = { workspace = true }
|
||||
thiserror = "1"
|
||||
|
||||
[dev-dependencies]
|
||||
approx = "0.5"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
# Enable FP16 KV cache path (mirrors the firmware-side ADR-095 build).
|
||||
fp16 = []
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
use crate::TemporalError;
|
||||
|
||||
/// Backend choice per ADR-096 §4.4.
|
||||
///
|
||||
/// * `Dense` — back-compat path against `ruvector-attention`. Reserved;
|
||||
/// not yet implemented in this crate (returns a typed error so callers
|
||||
/// can fail loudly during config validation rather than at forward()).
|
||||
/// * `SparseGqa` — `ruvllm_sparse_attention` `forward_gqa` for prefill,
|
||||
/// `decode_step` against `KvCache` for streaming inference.
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub enum TemporalBackendKind {
|
||||
Dense,
|
||||
SparseGqa,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TemporalHeadConfig {
|
||||
pub backend: TemporalBackendKind,
|
||||
|
||||
/// Number of query heads. For pure MHA, equals `kv_heads`.
|
||||
pub q_heads: usize,
|
||||
/// Number of key/value heads. Must divide `q_heads`. GQA group size
|
||||
/// is `q_heads / kv_heads`.
|
||||
pub kv_heads: usize,
|
||||
/// Per-head feature dimension.
|
||||
pub head_dim: usize,
|
||||
|
||||
/// Local attention window radius (sparse pattern primitive #1, ADR-096 §3).
|
||||
pub window: usize,
|
||||
/// Landmark block size (sparse pattern primitive #3).
|
||||
pub block_size: usize,
|
||||
/// Whether the attention is causal. AETHER temporal aggregation is
|
||||
/// causal (cannot peek at future CSI frames during streaming re-ID).
|
||||
pub causal: bool,
|
||||
}
|
||||
|
||||
impl TemporalHeadConfig {
|
||||
/// Default config sized for the AETHER training default
|
||||
/// (`window_frames = 100`) but with the sparse machinery wired up
|
||||
/// so the long-window roadmap (10 s / 1000 frames) only requires
|
||||
/// changing `window` at the call site, not re-architecting.
|
||||
pub fn default_aether() -> Self {
|
||||
Self {
|
||||
backend: TemporalBackendKind::SparseGqa,
|
||||
q_heads: 4,
|
||||
kv_heads: 1, // MQA — collapses to one shared K/V across query heads
|
||||
head_dim: 32,
|
||||
window: 32,
|
||||
block_size: 16,
|
||||
causal: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> Result<(), TemporalError> {
|
||||
if self.q_heads == 0 || self.kv_heads == 0 || self.head_dim == 0 {
|
||||
return Err(TemporalError::InvalidConfig(
|
||||
"q_heads, kv_heads, head_dim must all be > 0",
|
||||
));
|
||||
}
|
||||
if self.q_heads % self.kv_heads != 0 {
|
||||
return Err(TemporalError::InvalidConfig(
|
||||
"q_heads must be divisible by kv_heads (GQA constraint)",
|
||||
));
|
||||
}
|
||||
if self.block_size == 0 {
|
||||
return Err(TemporalError::InvalidConfig("block_size must be > 0"));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TemporalError {
|
||||
#[error("temporal head config invalid: {0}")]
|
||||
InvalidConfig(&'static str),
|
||||
|
||||
#[error("dense MHA backend not implemented yet (ADR-096 §4.4 follow-up)")]
|
||||
DenseBackendNotImplemented,
|
||||
|
||||
#[error("sparse attention kernel error: {0}")]
|
||||
Kernel(String),
|
||||
}
|
||||
|
||||
impl From<ruvllm_sparse_attention::AttentionError> for TemporalError {
|
||||
fn from(e: ruvllm_sparse_attention::AttentionError) -> Self {
|
||||
TemporalError::Kernel(format!("{e}"))
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
// AETHER temporal head over CSI feature windows (ADR-096).
|
||||
//
|
||||
// Wraps `ruvllm_sparse_attention::SubquadraticSparseAttention` so AETHER
|
||||
// callers in `wifi-densepose-train` and `wifi-densepose-signal` can swap
|
||||
// dense MHA for sparse-GQA without touching the contrastive recipe.
|
||||
//
|
||||
// Status: scaffolding for ADR-096 §4.3. Sparse backend is functional;
|
||||
// the dense back-compat backend is a follow-up (Phase 2 of the roadmap
|
||||
// in #513). Streaming `step()` lands once the per-track KvCache lifecycle
|
||||
// (ADR-096 §8.5) is finalized.
|
||||
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod sparse;
|
||||
|
||||
pub use config::{TemporalBackendKind, TemporalHeadConfig};
|
||||
pub use error::TemporalError;
|
||||
pub use sparse::SparseGqaHead;
|
||||
|
||||
// Re-export the upstream Tensor3 so callers don't need a direct
|
||||
// `ruvllm_sparse_attention` dep.
|
||||
pub use ruvllm_sparse_attention::Tensor3;
|
||||
|
||||
/// Thin facade so callers can pick a backend by name.
|
||||
///
|
||||
/// Today only `SparseGqa` is implemented; `Dense` is reserved per
|
||||
/// ADR-096 §4.4 and returns `TemporalError::DenseBackendNotImplemented`
|
||||
/// until the back-compat path lands.
|
||||
pub enum AetherTemporalHead {
|
||||
SparseGqa(SparseGqaHead),
|
||||
Dense, // placeholder; ADR-096 §4.4 selection rule
|
||||
}
|
||||
|
||||
impl AetherTemporalHead {
|
||||
pub fn new(cfg: &TemporalHeadConfig) -> Result<Self, TemporalError> {
|
||||
match cfg.backend {
|
||||
TemporalBackendKind::SparseGqa => {
|
||||
Ok(AetherTemporalHead::SparseGqa(SparseGqaHead::new(cfg)?))
|
||||
}
|
||||
TemporalBackendKind::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||
}
|
||||
}
|
||||
|
||||
/// Window-level prefill. Returns the per-token attention output as
|
||||
/// a Tensor3 of shape (window, q_heads, head_dim). Pooling to a
|
||||
/// single embedding is the caller's responsibility — different
|
||||
/// AETHER consumers use different pool ops (mean for re-ID,
|
||||
/// last-token for streaming).
|
||||
pub fn forward(
|
||||
&self,
|
||||
q: &Tensor3,
|
||||
k: &Tensor3,
|
||||
v: &Tensor3,
|
||||
) -> Result<Tensor3, TemporalError> {
|
||||
match self {
|
||||
AetherTemporalHead::SparseGqa(h) => h.forward(q, k, v),
|
||||
AetherTemporalHead::Dense => Err(TemporalError::DenseBackendNotImplemented),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
use ruvllm_sparse_attention::{
|
||||
AttentionBackend, SparseAttentionConfig, SubquadraticSparseAttention, Tensor3,
|
||||
};
|
||||
|
||||
use crate::{TemporalError, TemporalHeadConfig};
|
||||
|
||||
/// AETHER temporal head implemented with `ruvllm_sparse_attention`.
|
||||
///
|
||||
/// The selection rule from ADR-096 §4.4 is enforced at `forward()`
|
||||
/// time: when `q_heads == kv_heads` we use `forward()` (plain MHA
|
||||
/// over the sparse pattern); when they differ we use `forward_gqa()`.
|
||||
/// The streaming `step()` path is staged behind a follow-up — KvCache
|
||||
/// lifecycle ties to `PoseTrack` per ADR-096 §8.5 and lives on the
|
||||
/// caller, not here.
|
||||
pub struct SparseGqaHead {
|
||||
cfg: TemporalHeadConfig,
|
||||
attn: SubquadraticSparseAttention,
|
||||
}
|
||||
|
||||
impl SparseGqaHead {
|
||||
pub fn new(cfg: &TemporalHeadConfig) -> Result<Self, TemporalError> {
|
||||
cfg.validate()?;
|
||||
|
||||
let attn_cfg = SparseAttentionConfig {
|
||||
window: cfg.window,
|
||||
block_size: cfg.block_size,
|
||||
global_tokens: alloc_first_token(),
|
||||
causal: cfg.causal,
|
||||
use_log_stride: true,
|
||||
use_landmarks: true,
|
||||
sort_candidates: false,
|
||||
};
|
||||
|
||||
let attn = SubquadraticSparseAttention::new(attn_cfg)?;
|
||||
Ok(Self {
|
||||
cfg: cfg.clone(),
|
||||
attn,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn cfg(&self) -> &TemporalHeadConfig {
|
||||
&self.cfg
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
q: &Tensor3,
|
||||
k: &Tensor3,
|
||||
v: &Tensor3,
|
||||
) -> Result<Tensor3, TemporalError> {
|
||||
// ADR-096 §4.4: dispatch by GQA shape.
|
||||
if self.cfg.q_heads == self.cfg.kv_heads {
|
||||
// Pure MHA — sparse `forward` is the right path.
|
||||
Ok(self.attn.forward(q, k, v)?)
|
||||
} else {
|
||||
// GQA / MQA — kv_heads < q_heads, group share factor = q/kv.
|
||||
Ok(self.attn.forward_gqa(q, k, v)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Always treat token 0 as a global anchor — AETHER's contrastive
|
||||
/// recipe (ADR-024) gives the first token a special role as the
|
||||
/// "session start" reference embedding, and global tokens in the
|
||||
/// sparse pattern preserve full visibility for that one position.
|
||||
fn alloc_first_token() -> Vec<usize> {
|
||||
vec![0]
|
||||
}
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
//! Smoke tests for the AETHER sparse-GQA temporal head (ADR-096 §5 gate is
|
||||
//! a separate accuracy benchmark; this file just proves the wiring works).
|
||||
|
||||
use wifi_densepose_temporal::{
|
||||
AetherTemporalHead, TemporalBackendKind, TemporalHeadConfig, TemporalError, Tensor3,
|
||||
};
|
||||
|
||||
fn make_qkv(seq: usize, q_heads: usize, kv_heads: usize, dim: usize) -> (Tensor3, Tensor3, Tensor3) {
|
||||
// Deterministic synthetic CSI-like activations so the test is
|
||||
// reproducible across machines without bringing in `rand`.
|
||||
let mut q = Tensor3::zeros(seq, q_heads, dim);
|
||||
for s in 0..seq {
|
||||
for h in 0..q_heads {
|
||||
for d in 0..dim {
|
||||
let v = ((s * 31 + h * 7 + d) as f32).sin() * 0.1;
|
||||
q.set(s, h, d, v);
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut k = Tensor3::zeros(seq, kv_heads, dim);
|
||||
let mut v = Tensor3::zeros(seq, kv_heads, dim);
|
||||
for s in 0..seq {
|
||||
for h in 0..kv_heads {
|
||||
for d in 0..dim {
|
||||
let kv = (((s * 17 + h * 3 + d) as f32).cos()) * 0.1;
|
||||
k.set(s, h, d, kv);
|
||||
v.set(s, h, d, kv * 0.5);
|
||||
}
|
||||
}
|
||||
}
|
||||
(q, k, v)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sparse_gqa_forward_runs_at_aether_default() {
|
||||
let cfg = TemporalHeadConfig::default_aether();
|
||||
let head = AetherTemporalHead::new(&cfg).expect("construct");
|
||||
|
||||
let (q, k, vt) = make_qkv(64, cfg.q_heads, cfg.kv_heads, cfg.head_dim);
|
||||
let out = head.forward(&q, &k, &vt).expect("forward");
|
||||
let (oseq, oh, od) = out.shape();
|
||||
assert_eq!(oseq, 64);
|
||||
assert_eq!(oh, cfg.q_heads);
|
||||
assert_eq!(od, cfg.head_dim);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sparse_mha_path_runs_when_qkv_heads_match() {
|
||||
// q_heads == kv_heads forces the `forward` (non-GQA) branch.
|
||||
let cfg = TemporalHeadConfig {
|
||||
backend: TemporalBackendKind::SparseGqa,
|
||||
q_heads: 2,
|
||||
kv_heads: 2,
|
||||
head_dim: 16,
|
||||
window: 8,
|
||||
block_size: 4,
|
||||
causal: true,
|
||||
};
|
||||
let head = AetherTemporalHead::new(&cfg).expect("construct");
|
||||
let (q, k, vt) = make_qkv(32, 2, 2, 16);
|
||||
let out = head.forward(&q, &k, &vt).expect("forward");
|
||||
assert_eq!(out.shape(), (32, 2, 16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dense_backend_returns_typed_error() {
|
||||
let cfg = TemporalHeadConfig {
|
||||
backend: TemporalBackendKind::Dense,
|
||||
q_heads: 4,
|
||||
kv_heads: 1,
|
||||
head_dim: 32,
|
||||
window: 32,
|
||||
block_size: 16,
|
||||
causal: true,
|
||||
};
|
||||
let err = AetherTemporalHead::new(&cfg).err().expect("dense rejected");
|
||||
matches!(err, TemporalError::DenseBackendNotImplemented);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_gqa_ratio_rejected_at_construction() {
|
||||
let cfg = TemporalHeadConfig {
|
||||
backend: TemporalBackendKind::SparseGqa,
|
||||
q_heads: 5,
|
||||
kv_heads: 2, // 5 % 2 != 0
|
||||
head_dim: 16,
|
||||
window: 8,
|
||||
block_size: 4,
|
||||
causal: true,
|
||||
};
|
||||
let err = AetherTemporalHead::new(&cfg).err().expect("rejected");
|
||||
matches!(err, TemporalError::InvalidConfig(_));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn long_window_at_aether_roadmap_target() {
|
||||
// ADR-096 §3.1 roadmap target: 10 s @ 100 Hz = 1000 frames. Verify
|
||||
// the kernel actually runs at this length so the long-window claim
|
||||
// is more than aspirational.
|
||||
let cfg = TemporalHeadConfig {
|
||||
backend: TemporalBackendKind::SparseGqa,
|
||||
q_heads: 4,
|
||||
kv_heads: 1,
|
||||
head_dim: 16,
|
||||
window: 64,
|
||||
block_size: 32,
|
||||
causal: true,
|
||||
};
|
||||
let head = AetherTemporalHead::new(&cfg).expect("construct");
|
||||
let (q, k, vt) = make_qkv(1000, 4, 1, 16);
|
||||
let out = head.forward(&q, &k, &vt).expect("forward at N=1000");
|
||||
assert_eq!(out.shape(), (1000, 4, 16));
|
||||
}
|
||||
Loading…
Reference in New Issue