From c9fde3cba54f77012cae70ca2d01972cd14b66ef Mon Sep 17 00:00:00 2001 From: ruv Date: Fri, 8 May 2026 12:42:41 -0400 Subject: [PATCH] =?UTF-8?q?feat(train):=20AetherTemporalAggregator=20?= =?UTF-8?q?=E2=80=94=20wire=20wifi-densepose-temporal=20into=20the=20tch?= =?UTF-8?q?=20graph=20(#513)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ADR-096 train integration. Additive — does NOT modify model.rs. The existing WiFiDensePoseModel forward stays bit-equivalent for back-compat. New code lives in temporal_aether.rs behind the `aether-sparse-temporal` feature flag (which itself requires `tch-backend`). Architecture: tch::Tensor [T, in_dim] ──── tch nn::Linear (q/k/v projections) ↓ [T, q_heads*head_dim] etc ↓ tch_to_tensor3 (CPU, f32, 1× copy) ↓ ruvllm_sparse_attention::Tensor3 ↓ AetherTemporalHead::forward() ↓ Tensor3 [T, q_heads, head_dim] ↓ tensor3_to_tch (1× copy) ↓ tch::Tensor [T, q_heads*head_dim] ↓ tch nn::Linear (output projection) ↓ tch::Tensor [T, in_dim] Why additive rather than swapping `apply_antenna_attention` / `apply_spatial_attention` in model.rs: those are over antenna and spatial axes, not temporal — ADR-096 §8.1 was right that AETHER doesn't currently HAVE a temporal-axis attention. This commit adds that path without disturbing the others, so the §5 validation gate can A/B the two options before flipping the production default. Scope notes: - B=1 prefill only this version. Multi-batch lands when §5 turns green and we need to take perf seriously. The forward expects `[T, in_dim]` not `[B, T, in_dim]`; documented in the file. - Streaming step() bridge deferred — KvCache lifecycle ties to PoseTrack per ADR-096 §8.5, which is signal-side not train-side. - Two CPU memory copies per call (in + out). For training-rate forwards (~100/sec at batch 16) this is negligible vs the actual attention work; for inference-rate streaming it'd be the bottleneck and a zero-copy path is the natural follow-up. Build verification: - Source compiles cleanly with cargo check on the host crate (`-p wifi-densepose-temporal`, 21/21 tests still passing). - The train crate's tch-backend build is environmentally blocked on this Windows machine — torch-sys fails to link against the system PyTorch 2.11 + MSVC 14.50 toolchain. This predates this commit and affects all tch-bound code paths in the workspace. CI runners with working libtorch will verify the new module builds; the source follows the same nn::Linear / Module patterns the existing model.rs uses. Feature gating ensures default builds are byte-equivalent. Off by default; enable with `--features aether-sparse-temporal`. Co-Authored-By: claude-flow --- v2/crates/wifi-densepose-train/Cargo.toml | 9 + v2/crates/wifi-densepose-train/src/lib.rs | 7 + .../src/temporal_aether.rs | 178 ++++++++++++++++++ 3 files changed, 194 insertions(+) create mode 100644 v2/crates/wifi-densepose-train/src/temporal_aether.rs diff --git a/v2/crates/wifi-densepose-train/Cargo.toml b/v2/crates/wifi-densepose-train/Cargo.toml index ac0fa37d..491cc732 100644 --- a/v2/crates/wifi-densepose-train/Cargo.toml +++ b/v2/crates/wifi-densepose-train/Cargo.toml @@ -24,6 +24,11 @@ required-features = ["tch-backend"] default = [] tch-backend = ["tch"] cuda = ["tch-backend"] +# ADR-096 sparse-GQA temporal head. Pulls wifi-densepose-temporal in +# alongside tch — the new path is additive, doesn't touch the existing +# model.rs code paths, and stays opt-in until the §5 validation gate +# clears. +aether-sparse-temporal = ["tch-backend", "dep:wifi-densepose-temporal"] [dependencies] # Internal crates @@ -54,6 +59,10 @@ ruvector-temporal-tensor = { workspace = true } ruvector-solver = { workspace = true } ruvector-attention = { workspace = true } +# AETHER temporal head (ADR-096). Optional + tch-gated — only meaningful +# alongside the existing tch-bound model graph. +wifi-densepose-temporal = { workspace = true, optional = true } + # Data loading ndarray-npy.workspace = true memmap2 = "0.9" diff --git a/v2/crates/wifi-densepose-train/src/lib.rs b/v2/crates/wifi-densepose-train/src/lib.rs index 8831c549..162bc487 100644 --- a/v2/crates/wifi-densepose-train/src/lib.rs +++ b/v2/crates/wifi-densepose-train/src/lib.rs @@ -69,6 +69,13 @@ pub mod proof; #[cfg(feature = "tch-backend")] pub mod trainer; +// ADR-096 AETHER temporal head — additive integration. Pulled in via +// the `aether-sparse-temporal` feature, which itself requires +// `tch-backend`. Kept under its own cfg so the existing build with +// just `tch-backend` is byte-equivalent to before. +#[cfg(feature = "aether-sparse-temporal")] +pub mod temporal_aether; + // Convenient re-exports at the crate root. pub use config::TrainingConfig; pub use dataset::{CsiDataset, CsiSample, DataLoader, MmFiDataset, SyntheticCsiDataset, SyntheticConfig}; diff --git a/v2/crates/wifi-densepose-train/src/temporal_aether.rs b/v2/crates/wifi-densepose-train/src/temporal_aether.rs new file mode 100644 index 00000000..67d4a69e --- /dev/null +++ b/v2/crates/wifi-densepose-train/src/temporal_aether.rs @@ -0,0 +1,178 @@ +//! ADR-096 AETHER temporal head — `tch::nn` bridge. +//! +//! Additive integration: wires `wifi-densepose-temporal` (sparse-GQA +//! attention + streaming KvCache) into the train crate's tch graph. +//! Does NOT modify the existing `WiFiDensePoseModel` forward in +//! `model.rs` — that path stays bit-equivalent for back-compat. Use +//! this aggregator alongside the existing model when you want a +//! temporal-axis pooling on top of per-frame backbone features. +//! +//! Bridge boundary: +//! tch::Tensor [T, in_dim] → Tensor3 (seq=T, heads, dim) → attention +//! ← Tensor3 ← forward() +//! tch::Tensor [in_dim] (pooled embedding) +//! +//! Memory pattern: tch.copy_data → Vec → Tensor3::from_vec on the +//! way in; Tensor3 raw → Tensor::of_slice on the way out. Two host +//! copies per call. For training-rate forwards (~100 calls/sec at +//! batch 16) this is negligible vs the actual attention work; for +//! inference-rate streaming it'd be the bottleneck and a +//! zero-copy path is the natural Phase 2. +//! +//! Only the B=1 prefill path is implemented in this commit. Multi-batch +//! and the streaming `step()` bridge land when the §5 validation gate +//! turns green and we need to take the perf hit seriously. +//! +//! Feature-gated: `aether-sparse-temporal` (also requires `tch-backend`). + +use tch::{ + nn::{self, Module}, + Device, Kind, Tensor, +}; + +use wifi_densepose_temporal::{ + AetherTemporalHead, TemporalBackendKind, TemporalError, TemporalHeadConfig, Tensor3, +}; + +/// Aggregator: tch-side projections + the pure-Rust sparse attention +/// kernel + a tch-side output projection. The projection layers are +/// `nn::Linear` so they participate in the tch VarStore the same way +/// the rest of the model does — gradients, save/load, etc. +pub struct AetherTemporalAggregator { + cfg: TemporalHeadConfig, + in_dim: i64, + + // tch-side learnable projections. + q_proj: nn::Linear, + k_proj: nn::Linear, + v_proj: nn::Linear, + o_proj: nn::Linear, + + // The kernel itself is configuration-only; no weights live inside + // because the sparse attention forward is purely a function of + // q/k/v + the SparseAttentionConfig. + head: AetherTemporalHead, +} + +impl AetherTemporalAggregator { + /// Build the aggregator. `vs` is the tch namespace under which + /// the four projection layers register. `in_dim` is the input + /// feature dimension per frame (e.g. backbone output dim). + pub fn new(vs: nn::Path, in_dim: i64, cfg: TemporalHeadConfig) -> Result { + cfg.validate()?; + // Backend has to be Sparse — Dense projections would still + // work, but the whole point of this integration is the new + // sparse-GQA path. If a caller wants dense, they can keep + // using `apply_antenna_attention` / `apply_spatial_attention` + // from model.rs. + if !matches!(cfg.backend, TemporalBackendKind::SparseGqa) { + return Err(TemporalError::InvalidConfig( + "aggregator only wires SparseGqa; use existing model.rs paths for dense", + )); + } + + let total_q = (cfg.q_heads * cfg.head_dim) as i64; + let total_kv = (cfg.kv_heads * cfg.head_dim) as i64; + + let q_proj = nn::linear(&vs / "q_proj", in_dim, total_q, Default::default()); + let k_proj = nn::linear(&vs / "k_proj", in_dim, total_kv, Default::default()); + let v_proj = nn::linear(&vs / "v_proj", in_dim, total_kv, Default::default()); + let o_proj = nn::linear(&vs / "o_proj", total_q, in_dim, Default::default()); + + let head = AetherTemporalHead::new(&cfg)?; + + Ok(Self { + cfg, + in_dim, + q_proj, + k_proj, + v_proj, + o_proj, + head, + }) + } + + /// Forward over a single sequence of frames. Input shape: + /// `[T, in_dim]` (NB: B=1 only this version — see file header). + /// Returns the per-token attention output passed through the + /// output projection: `[T, in_dim]`. + /// + /// Pooling (mean over T, last-token, attention-pool, etc.) is the + /// caller's job — different downstream consumers want different + /// pools and we don't want to bake one in. + pub fn forward(&self, frames: &Tensor) -> Result { + let dims = frames.size(); + if dims.len() != 2 || dims[1] != self.in_dim { + return Err(TemporalError::InvalidConfig( + "aggregator.forward expects [T, in_dim] tch::Tensor", + )); + } + let t = dims[0] as usize; + let device = frames.device(); + + // ── Project to Q/K/V on the tch side ────────────────────── + let q_th = self.q_proj.forward(frames); // [T, q_heads*head_dim] + let k_th = self.k_proj.forward(frames); // [T, kv_heads*head_dim] + let v_th = self.v_proj.forward(frames); // [T, kv_heads*head_dim] + + // ── Bridge to Tensor3 (CPU, f32) ────────────────────────── + let q_t3 = tch_to_tensor3(&q_th, t, self.cfg.q_heads, self.cfg.head_dim)?; + let k_t3 = tch_to_tensor3(&k_th, t, self.cfg.kv_heads, self.cfg.head_dim)?; + let v_t3 = tch_to_tensor3(&v_th, t, self.cfg.kv_heads, self.cfg.head_dim)?; + + // ── Sparse attention forward (pure-Rust path) ──────────── + let attn_out = self.head.forward(&q_t3, &k_t3, &v_t3)?; + + // ── Bridge back to tch ─────────────────────────────────── + let attn_th = tensor3_to_tch(&attn_out, device); + // attn_th shape is [T, q_heads*head_dim]. + + // ── Output projection on tch side ──────────────────────── + let out = self.o_proj.forward(&attn_th); // [T, in_dim] + Ok(out) + } +} + +/// Reshape a `[T, heads*head_dim]` tch::Tensor on (any device, any +/// kind) into a CPU `Tensor3(seq=T, heads, head_dim)`. Forces f32 + +/// CPU + contiguous memory; copies once. +fn tch_to_tensor3( + th: &Tensor, + seq: usize, + heads: usize, + head_dim: usize, +) -> Result { + let dims = th.size(); + if dims.len() != 2 || dims[0] as usize != seq || dims[1] as usize != heads * head_dim { + return Err(TemporalError::InvalidConfig( + "tch_to_tensor3 shape mismatch", + )); + } + let cpu = th.to_kind(Kind::Float).to_device(Device::Cpu).contiguous(); + let total = seq * heads * head_dim; + let mut buf = vec![0.0f32; total]; + cpu.copy_data(&mut buf, total); + // tch row-major flatten gives [seq][heads*head_dim]. Tensor3 + // expects [seq][heads][dim] in the same row-major order, so the + // contiguous bytes are layout-compatible — no per-element + // transpose required. + Tensor3::from_vec(buf, seq, heads, head_dim) + .map_err(|e| TemporalError::InvalidConfig(Box::leak(format!("from_vec: {e}").into_boxed_str()))) +} + +/// Inverse of `tch_to_tensor3`: take a `Tensor3(seq, heads, dim)` and +/// produce a `[seq, heads*dim]` tch::Tensor on the requested device. +fn tensor3_to_tch(t3: &Tensor3, device: Device) -> Tensor { + let (seq, heads, dim) = t3.shape(); + // Tensor3 stores seq×heads×dim contiguously; flatten heads/dim + // by reading the row at each (seq, head) and concatenating. + let mut flat = Vec::with_capacity(seq * heads * dim); + for s in 0..seq { + for h in 0..heads { + flat.extend_from_slice(t3.row(s, h)); + } + } + Tensor::from_slice(&flat) + .reshape([seq as i64, (heads * dim) as i64]) + .to_device(device) +}