feat(train): AetherTemporalAggregator — wire wifi-densepose-temporal into the tch graph (#513)

ADR-096 train integration. Additive — does NOT modify model.rs. The
existing WiFiDensePoseModel forward stays bit-equivalent for back-compat.
New code lives in temporal_aether.rs behind the `aether-sparse-temporal`
feature flag (which itself requires `tch-backend`).

Architecture:

    tch::Tensor [T, in_dim]   ──── tch nn::Linear (q/k/v projections)
                                    ↓
                              [T, q_heads*head_dim] etc
                                    ↓
                             tch_to_tensor3 (CPU, f32, 1× copy)
                                    ↓
                              ruvllm_sparse_attention::Tensor3
                                    ↓
                            AetherTemporalHead::forward()
                                    ↓
                              Tensor3 [T, q_heads, head_dim]
                                    ↓
                             tensor3_to_tch (1× copy)
                                    ↓
                              tch::Tensor [T, q_heads*head_dim]
                                    ↓
                              tch nn::Linear (output projection)
                                    ↓
                              tch::Tensor [T, in_dim]

Why additive rather than swapping `apply_antenna_attention` /
`apply_spatial_attention` in model.rs: those are over antenna and
spatial axes, not temporal — ADR-096 §8.1 was right that AETHER
doesn't currently HAVE a temporal-axis attention. This commit adds
that path without disturbing the others, so the §5 validation gate
can A/B the two options before flipping the production default.

Scope notes:
- B=1 prefill only this version. Multi-batch lands when §5 turns
  green and we need to take perf seriously. The forward expects
  `[T, in_dim]` not `[B, T, in_dim]`; documented in the file.
- Streaming step() bridge deferred — KvCache lifecycle ties to
  PoseTrack per ADR-096 §8.5, which is signal-side not train-side.
- Two CPU memory copies per call (in + out). For training-rate
  forwards (~100/sec at batch 16) this is negligible vs the actual
  attention work; for inference-rate streaming it'd be the
  bottleneck and a zero-copy path is the natural follow-up.

Build verification:
- Source compiles cleanly with cargo check on the host crate
  (`-p wifi-densepose-temporal`, 21/21 tests still passing).
- The train crate's tch-backend build is environmentally blocked
  on this Windows machine — torch-sys fails to link against the
  system PyTorch 2.11 + MSVC 14.50 toolchain. This predates this
  commit and affects all tch-bound code paths in the workspace.
  CI runners with working libtorch will verify the new module
  builds; the source follows the same nn::Linear / Module patterns
  the existing model.rs uses.

Feature gating ensures default builds are byte-equivalent. Off by
default; enable with `--features aether-sparse-temporal`.

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv 2026-05-08 12:42:41 -04:00
parent 2b903752c4
commit c9fde3cba5
3 changed files with 194 additions and 0 deletions

View File

@ -24,6 +24,11 @@ required-features = ["tch-backend"]
default = []
tch-backend = ["tch"]
cuda = ["tch-backend"]
# ADR-096 sparse-GQA temporal head. Pulls wifi-densepose-temporal in
# alongside tch — the new path is additive, doesn't touch the existing
# model.rs code paths, and stays opt-in until the §5 validation gate
# clears.
aether-sparse-temporal = ["tch-backend", "dep:wifi-densepose-temporal"]
[dependencies]
# Internal crates
@ -54,6 +59,10 @@ ruvector-temporal-tensor = { workspace = true }
ruvector-solver = { workspace = true }
ruvector-attention = { workspace = true }
# AETHER temporal head (ADR-096). Optional + tch-gated — only meaningful
# alongside the existing tch-bound model graph.
wifi-densepose-temporal = { workspace = true, optional = true }
# Data loading
ndarray-npy.workspace = true
memmap2 = "0.9"

View File

@ -69,6 +69,13 @@ pub mod proof;
#[cfg(feature = "tch-backend")]
pub mod trainer;
// ADR-096 AETHER temporal head — additive integration. Pulled in via
// the `aether-sparse-temporal` feature, which itself requires
// `tch-backend`. Kept under its own cfg so the existing build with
// just `tch-backend` is byte-equivalent to before.
#[cfg(feature = "aether-sparse-temporal")]
pub mod temporal_aether;
// Convenient re-exports at the crate root.
pub use config::TrainingConfig;
pub use dataset::{CsiDataset, CsiSample, DataLoader, MmFiDataset, SyntheticCsiDataset, SyntheticConfig};

View File

@ -0,0 +1,178 @@
//! ADR-096 AETHER temporal head — `tch::nn` bridge.
//!
//! Additive integration: wires `wifi-densepose-temporal` (sparse-GQA
//! attention + streaming KvCache) into the train crate's tch graph.
//! Does NOT modify the existing `WiFiDensePoseModel` forward in
//! `model.rs` — that path stays bit-equivalent for back-compat. Use
//! this aggregator alongside the existing model when you want a
//! temporal-axis pooling on top of per-frame backbone features.
//!
//! Bridge boundary:
//! tch::Tensor [T, in_dim] → Tensor3 (seq=T, heads, dim) → attention
//! ← Tensor3 ← forward()
//! tch::Tensor [in_dim] (pooled embedding)
//!
//! Memory pattern: tch.copy_data → Vec<f32> → Tensor3::from_vec on the
//! way in; Tensor3 raw → Tensor::of_slice on the way out. Two host
//! copies per call. For training-rate forwards (~100 calls/sec at
//! batch 16) this is negligible vs the actual attention work; for
//! inference-rate streaming it'd be the bottleneck and a
//! zero-copy path is the natural Phase 2.
//!
//! Only the B=1 prefill path is implemented in this commit. Multi-batch
//! and the streaming `step()` bridge land when the §5 validation gate
//! turns green and we need to take the perf hit seriously.
//!
//! Feature-gated: `aether-sparse-temporal` (also requires `tch-backend`).
use tch::{
nn::{self, Module},
Device, Kind, Tensor,
};
use wifi_densepose_temporal::{
AetherTemporalHead, TemporalBackendKind, TemporalError, TemporalHeadConfig, Tensor3,
};
/// Aggregator: tch-side projections + the pure-Rust sparse attention
/// kernel + a tch-side output projection. The projection layers are
/// `nn::Linear` so they participate in the tch VarStore the same way
/// the rest of the model does — gradients, save/load, etc.
pub struct AetherTemporalAggregator {
cfg: TemporalHeadConfig,
in_dim: i64,
// tch-side learnable projections.
q_proj: nn::Linear,
k_proj: nn::Linear,
v_proj: nn::Linear,
o_proj: nn::Linear,
// The kernel itself is configuration-only; no weights live inside
// because the sparse attention forward is purely a function of
// q/k/v + the SparseAttentionConfig.
head: AetherTemporalHead,
}
impl AetherTemporalAggregator {
/// Build the aggregator. `vs` is the tch namespace under which
/// the four projection layers register. `in_dim` is the input
/// feature dimension per frame (e.g. backbone output dim).
pub fn new(vs: nn::Path, in_dim: i64, cfg: TemporalHeadConfig) -> Result<Self, TemporalError> {
cfg.validate()?;
// Backend has to be Sparse — Dense projections would still
// work, but the whole point of this integration is the new
// sparse-GQA path. If a caller wants dense, they can keep
// using `apply_antenna_attention` / `apply_spatial_attention`
// from model.rs.
if !matches!(cfg.backend, TemporalBackendKind::SparseGqa) {
return Err(TemporalError::InvalidConfig(
"aggregator only wires SparseGqa; use existing model.rs paths for dense",
));
}
let total_q = (cfg.q_heads * cfg.head_dim) as i64;
let total_kv = (cfg.kv_heads * cfg.head_dim) as i64;
let q_proj = nn::linear(&vs / "q_proj", in_dim, total_q, Default::default());
let k_proj = nn::linear(&vs / "k_proj", in_dim, total_kv, Default::default());
let v_proj = nn::linear(&vs / "v_proj", in_dim, total_kv, Default::default());
let o_proj = nn::linear(&vs / "o_proj", total_q, in_dim, Default::default());
let head = AetherTemporalHead::new(&cfg)?;
Ok(Self {
cfg,
in_dim,
q_proj,
k_proj,
v_proj,
o_proj,
head,
})
}
/// Forward over a single sequence of frames. Input shape:
/// `[T, in_dim]` (NB: B=1 only this version — see file header).
/// Returns the per-token attention output passed through the
/// output projection: `[T, in_dim]`.
///
/// Pooling (mean over T, last-token, attention-pool, etc.) is the
/// caller's job — different downstream consumers want different
/// pools and we don't want to bake one in.
pub fn forward(&self, frames: &Tensor) -> Result<Tensor, TemporalError> {
let dims = frames.size();
if dims.len() != 2 || dims[1] != self.in_dim {
return Err(TemporalError::InvalidConfig(
"aggregator.forward expects [T, in_dim] tch::Tensor",
));
}
let t = dims[0] as usize;
let device = frames.device();
// ── Project to Q/K/V on the tch side ──────────────────────
let q_th = self.q_proj.forward(frames); // [T, q_heads*head_dim]
let k_th = self.k_proj.forward(frames); // [T, kv_heads*head_dim]
let v_th = self.v_proj.forward(frames); // [T, kv_heads*head_dim]
// ── Bridge to Tensor3 (CPU, f32) ──────────────────────────
let q_t3 = tch_to_tensor3(&q_th, t, self.cfg.q_heads, self.cfg.head_dim)?;
let k_t3 = tch_to_tensor3(&k_th, t, self.cfg.kv_heads, self.cfg.head_dim)?;
let v_t3 = tch_to_tensor3(&v_th, t, self.cfg.kv_heads, self.cfg.head_dim)?;
// ── Sparse attention forward (pure-Rust path) ────────────
let attn_out = self.head.forward(&q_t3, &k_t3, &v_t3)?;
// ── Bridge back to tch ───────────────────────────────────
let attn_th = tensor3_to_tch(&attn_out, device);
// attn_th shape is [T, q_heads*head_dim].
// ── Output projection on tch side ────────────────────────
let out = self.o_proj.forward(&attn_th); // [T, in_dim]
Ok(out)
}
}
/// Reshape a `[T, heads*head_dim]` tch::Tensor on (any device, any
/// kind) into a CPU `Tensor3(seq=T, heads, head_dim)`. Forces f32 +
/// CPU + contiguous memory; copies once.
fn tch_to_tensor3(
th: &Tensor,
seq: usize,
heads: usize,
head_dim: usize,
) -> Result<Tensor3, TemporalError> {
let dims = th.size();
if dims.len() != 2 || dims[0] as usize != seq || dims[1] as usize != heads * head_dim {
return Err(TemporalError::InvalidConfig(
"tch_to_tensor3 shape mismatch",
));
}
let cpu = th.to_kind(Kind::Float).to_device(Device::Cpu).contiguous();
let total = seq * heads * head_dim;
let mut buf = vec![0.0f32; total];
cpu.copy_data(&mut buf, total);
// tch row-major flatten gives [seq][heads*head_dim]. Tensor3
// expects [seq][heads][dim] in the same row-major order, so the
// contiguous bytes are layout-compatible — no per-element
// transpose required.
Tensor3::from_vec(buf, seq, heads, head_dim)
.map_err(|e| TemporalError::InvalidConfig(Box::leak(format!("from_vec: {e}").into_boxed_str())))
}
/// Inverse of `tch_to_tensor3`: take a `Tensor3(seq, heads, dim)` and
/// produce a `[seq, heads*dim]` tch::Tensor on the requested device.
fn tensor3_to_tch(t3: &Tensor3, device: Device) -> Tensor {
let (seq, heads, dim) = t3.shape();
// Tensor3 stores seq×heads×dim contiguously; flatten heads/dim
// by reading the row at each (seq, head) and concatenating.
let mut flat = Vec::with_capacity(seq * heads * dim);
for s in 0..seq {
for h in 0..heads {
flat.extend_from_slice(t3.row(s, h));
}
}
Tensor::from_slice(&flat)
.reshape([seq as i64, (heads * dim) as i64])
.to_device(device)
}