From 7d26b15eefbf4edca37526fe1247859c3dd1239c Mon Sep 17 00:00:00 2001 From: ruv Date: Mon, 11 May 2026 13:05:27 -0400 Subject: [PATCH] =?UTF-8?q?feat(train):=20MERIDIAN-MAE=20=E2=80=94=20csi?= =?UTF-8?q?=5Fmae::model=20+=20pretrain=20loop=20+=20pretrain-mae=20bin=20?= =?UTF-8?q?(iter=202b,=20#68)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Real CSI masked-autoencoder behind feature `tch-backend` (ADR-027 §2.0): - CsiMae: dual-stream per-token amp+phase embed → fuse → residual-MLP encoder over the visible tokens → flatten-to-latent bottleneck → learned per-position query + broadcast latent → residual-MLP decoder → dec_amp_head / dec_ph_head → index_select the masked positions. (MLP-based v0; self-attention transformer blocks are iter 3.) - CsiMae::reconstruction_loss(pred_amp, pred_phase, tgt_amp, tgt_phase, phase_w) = MSE(amp) + phase_w * MSE(phase). - MaeBatch::from_windows — partition computed once from window 0 and reused across the batch (the bottleneck fixes n_tokens), ndarray → tch conversion. - pretrain_step(model, opt, batch) -> f64 — one Adam step, returns the loss. - src/bin/pretrain_mae.rs — synthetic-data pre-train driver (required-features = ["tch-backend"]); clap args for epochs/batch/samples/lr/mask-ratio/save. - #[cfg(feature="tch-backend")] smoke test: loss halves when overfitting one batch over 60 steps; also asserts model.n_visible/n_masked match mask_csi_window's clamping. v0 limits (documented in the module): fixed n_tokens; batch-shared masking; MSE on unwrapped phase (vs a circular loss). The dev box has no LibTorch, so the tch path is CI-verified (`--features tch-backend`), not locally. The default `cargo test -p wifi-densepose-train --no-default-features` stays green (121 lib tests) — the model module and the bin are both feature-gated. Co-Authored-By: claude-flow --- v2/crates/wifi-densepose-train/Cargo.toml | 5 + .../src/bin/pretrain_mae.rs | 108 +++++ v2/crates/wifi-densepose-train/src/csi_mae.rs | 390 ++++++++++++++++-- 3 files changed, 475 insertions(+), 28 deletions(-) create mode 100644 v2/crates/wifi-densepose-train/src/bin/pretrain_mae.rs diff --git a/v2/crates/wifi-densepose-train/Cargo.toml b/v2/crates/wifi-densepose-train/Cargo.toml index ac0fa37d..323f33fa 100644 --- a/v2/crates/wifi-densepose-train/Cargo.toml +++ b/v2/crates/wifi-densepose-train/Cargo.toml @@ -20,6 +20,11 @@ name = "verify-training" path = "src/bin/verify_training.rs" required-features = ["tch-backend"] +[[bin]] +name = "pretrain-mae" +path = "src/bin/pretrain_mae.rs" +required-features = ["tch-backend"] + [features] default = [] tch-backend = ["tch"] diff --git a/v2/crates/wifi-densepose-train/src/bin/pretrain_mae.rs b/v2/crates/wifi-densepose-train/src/bin/pretrain_mae.rs new file mode 100644 index 00000000..8e3b370c --- /dev/null +++ b/v2/crates/wifi-densepose-train/src/bin/pretrain_mae.rs @@ -0,0 +1,108 @@ +//! `pretrain-mae` — drive the MERIDIAN CSI masked-autoencoder pre-train on a +//! deterministic `SyntheticCsiDataset` (ADR-027 §2.0, prototype iteration 2). +//! +//! This is the *prototype* driver — it exercises the full pre-train loop +//! (mask → encode visible → reconstruct masked amplitude+phase → optimiser +//! step) end-to-end on synthetic CSI. Real cross-domain pre-training (iter 3+) +//! ingests heterogeneous capture — MM-Fi / Wi-Pose / `data/recordings/` / +//! multi-band virtual sub-carriers — and runs on GPU (`scripts/gcloud-train.sh` +//! / the cognitum project). +//! +//! ```text +//! cargo run -p wifi-densepose-train --features tch-backend --bin pretrain-mae -- --epochs 5 +//! ``` +//! +//! Only compiled with `--features tch-backend` (see Cargo.toml `required-features`). + +use clap::Parser; +use tch::nn::OptimizerConfig; +use tch::{nn, Device}; + +use wifi_densepose_train::csi_mae::model::{pretrain_step, CsiMae, MaeBatch}; +use wifi_densepose_train::csi_mae::{MaeConfig, MaskStrategy, TokenLayout}; +use wifi_densepose_train::dataset::{CsiDataset, SyntheticConfig, SyntheticCsiDataset}; + +/// MERIDIAN CSI masked-autoencoder pre-train (prototype, synthetic data). +#[derive(Parser, Debug)] +#[command(name = "pretrain-mae", version, about)] +struct Cli { + /// Number of epochs over the synthetic dataset. + #[arg(long, default_value_t = 5)] + epochs: usize, + /// Mini-batch size (windows per optimiser step). + #[arg(long, default_value_t = 8)] + batch: usize, + /// Number of synthetic samples to generate. + #[arg(long, default_value_t = 256)] + samples: usize, + /// Adam learning rate. + #[arg(long, default_value_t = 1e-3)] + lr: f64, + /// Fraction of tokens masked per window. + #[arg(long, default_value_t = 0.75)] + mask_ratio: f64, + /// Optional path to save the pre-trained variable store (`.ot`). + #[arg(long)] + save: Option, +} + +fn main() -> anyhow::Result<()> { + let cli = Cli::parse(); + let _ = tracing_subscriber::fmt::try_init(); + + let ds = SyntheticCsiDataset::new(cli.samples, SyntheticConfig::default()); + if ds.len() < cli.batch { + anyhow::bail!("need at least --batch ({}) samples, have {}", cli.batch, ds.len()); + } + let s0 = ds.get(0)?; + let layout = TokenLayout::from_window(s0.amplitude.view()); + let n_tokens = layout.n_tokens as i64; + + let mut cfg = MaeConfig::default(); + cfg.token_dim = layout.token_dim; + cfg.mask_ratio = cli.mask_ratio; + cfg.validate().map_err(anyhow::Error::msg)?; + + let device = Device::cuda_if_available(); + let vs = nn::VarStore::new(device); + let model = CsiMae::new(&vs.root(), &cfg, n_tokens); + let mut opt = nn::Adam::default().build(&vs, cli.lr)?; + + println!( + "pretrain-mae: device={device:?} n_tokens={n_tokens} token_dim={} V={} M={} samples={} batch={} epochs={} lr={} mask_ratio={}", + cfg.token_dim, model.n_visible, model.n_masked, cli.samples, cli.batch, cli.epochs, cli.lr, cli.mask_ratio + ); + + let mut step: u64 = 0; + for epoch in 0..cli.epochs { + let mut epoch_loss = 0.0_f64; + let mut nb = 0_usize; + let mut i = 0_usize; + while i + cli.batch <= ds.len() { + let mut windows = Vec::with_capacity(cli.batch); + for j in i..i + cli.batch { + let s = ds.get(j)?; + windows.push((s.amplitude, s.phase)); + } + let seed = step.wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xC511_0027; + let batch = MaeBatch::from_windows(&windows, &cfg, seed, MaskStrategy::InfoGuided, device) + .map_err(anyhow::Error::msg)?; + let loss = pretrain_step(&model, &mut opt, &batch); + if !loss.is_finite() { + anyhow::bail!("non-finite loss at epoch {epoch} step {step}"); + } + epoch_loss += loss; + nb += 1; + step += 1; + i += cli.batch; + } + let avg = if nb > 0 { epoch_loss / nb as f64 } else { f64::NAN }; + println!("epoch {epoch}: avg reconstruction loss = {avg:.6} ({nb} batches)"); + } + + if let Some(path) = cli.save { + vs.save(&path)?; + println!("saved pre-trained variable store → {path}"); + } + Ok(()) +} diff --git a/v2/crates/wifi-densepose-train/src/csi_mae.rs b/v2/crates/wifi-densepose-train/src/csi_mae.rs index 254ca88d..a2e1b934 100644 --- a/v2/crates/wifi-densepose-train/src/csi_mae.rs +++ b/v2/crates/wifi-densepose-train/src/csi_mae.rs @@ -494,43 +494,377 @@ pub fn reassemble_tokens( /// CSI masked-autoencoder networks (LibTorch / `tch`). /// /// **Compiled only with `--features tch-backend`.** Not exercised by the default -/// `cargo test --workspace --no-default-features` CI job, so compile-checking it -/// requires a LibTorch toolchain. +/// `cargo test --workspace --no-default-features` CI job — compile-/run-checking +/// this submodule requires a LibTorch toolchain (`LIBTORCH` was unset on the dev +/// box that wrote it, so it is CI-verified only; if a `tch` API call below has +/// drifted, it's a localised fix). /// -/// Iteration-1 status: this is a v0 *skeleton* — the public interface -/// ([`model::CsiMae`], `forward`, `reconstruction_loss`, `model::pretrain_step`) -/// is fixed, but the encoder/decoder are placeholder MLPs and per-batch masking -/// is assumed shared. Transformer blocks, per-sample masking, and the -/// `pretrain-mae` binary land in iteration 2. +/// # v0 design (iteration 2) +/// +/// A deliberately small **dual-stream** MAE, MLP-based (no self-attention yet — +/// transformer blocks are iteration 3): +/// +/// ```text +/// visible amplitude [B, V, sub] ─► amp_embed ─┐ +/// ├─ cat ─► tok_fuse ─► relu ─► enc_blocks(residual MLP) ─► [B, V, enc] +/// visible phase [B, V, sub] ─► ph_embed ─┘ │ +/// reshape [B, V·enc] │ +/// to_latent│ +/// ▼ +/// latent [B, enc] +/// from_latent│ +/// ▼ +/// learned per-position query pos_query [N, dec] + ─► relu ─► dec_blocks(residual MLP) ─► [B, N, dec] +/// (broadcast latent over N positions) │ +/// ┌──────────────────────┤ +/// dec_amp_head dec_ph_head +/// [B, N, sub] [B, N, sub] +/// index_select(masked positions) ─► (pred_amp, pred_ph) [B, M, sub] +/// ``` +/// +/// Limitations to lift later: (1) a *fixed* `n_tokens` (the bottleneck flattens +/// all visible token embeddings, so V — hence N and `mask_ratio` — is baked in +/// at `new()` time); (2) **batch-shared masking** (`MaeBatch::from_samples` masks +/// every sample in a batch with the same seed, so `masked_pos` is shared) — +/// per-sample masking via gather/scatter is iteration 3; (3) MSE on unwrapped +/// phase rather than a circular loss. #[cfg(feature = "tch-backend")] pub mod model { - use super::MaeConfig; + use super::{mask_csi_window, MaeConfig, MaskStrategy}; + use ndarray::{Array2, Array4, Axis}; + use tch::{nn, nn::Module, Device, Kind, Reduction, Tensor}; - /// Placeholder for the CSI masked autoencoder network. - /// - /// Iteration 2 fills this in (dual-stream encoder over visible tokens → - /// shared latent; narrow decoder over all positions with learned mask - /// tokens; reconstruct amplitude + phase). For now it exists so downstream - /// code can name the type and the build wiring is in place. + /// A residual MLP block: `LayerNorm(x + relu(Linear(x)))`. #[derive(Debug)] - pub struct CsiMae { - _cfg: MaeConfig, + struct ResidualMlp { + lin: nn::Linear, + ln: nn::LayerNorm, } - - impl CsiMae { - /// Build a `CsiMae` under the given variable-store path with `cfg`. - /// - /// Iteration-1 stub — see the module docs. - pub fn new(_vs: &tch::nn::Path, cfg: &MaeConfig) -> Self { - Self { _cfg: cfg.clone() } + impl ResidualMlp { + fn new(p: &nn::Path, dim: i64) -> Self { + Self { + lin: nn::linear(p / "lin", dim, dim, Default::default()), + ln: nn::layer_norm(p / "ln", vec![dim], Default::default()), + } + } + fn forward(&self, x: &Tensor) -> Tensor { + self.ln.forward(&(x + self.lin.forward(x).relu())) } } - // NOTE (iteration 2): add - // impl CsiMae { fn forward(&self, vis_amp, vis_phase, vis_idx, n_tokens) -> (pred_amp, pred_phase) } - // fn reconstruction_loss(pred_amp, pred_phase, tgt_amp, tgt_phase, phase_w) -> Tensor - // fn pretrain_step(model, opt, batch) -> f64 - // plus a `bin/pretrain_mae.rs` driving SyntheticCsiDataset / MmFiDataset. + /// The CSI masked autoencoder. See the module docs for the v0 design. + #[derive(Debug)] + pub struct CsiMae { + /// Hyper-parameters this model was built with. + pub cfg: MaeConfig, + /// Number of tokens per window (`T·tx·rx`) — fixed at construction. + pub n_tokens: i64, + /// Number of masked (target) tokens per window. + pub n_masked: i64, + /// Number of visible (encoder-input) tokens per window. + pub n_visible: i64, + device: Device, + amp_embed: nn::Linear, + ph_embed: nn::Linear, + tok_fuse: nn::Linear, + enc_blocks: Vec, + to_latent: nn::Linear, + from_latent: nn::Linear, + /// Learned per-position query, shape `[n_tokens, decoder_dim]`. + pos_query: Tensor, + dec_blocks: Vec, + dec_amp_head: nn::Linear, + dec_ph_head: nn::Linear, + } + + impl CsiMae { + /// Build a `CsiMae` under `vs` for windows of exactly `n_tokens` tokens. + /// + /// `n_tokens` is fixed because the bottleneck flattens all visible token + /// embeddings; it must equal `T·tx·rx` of the windows fed at train/eval + /// time (e.g. `TokenLayout::from_window(sample.amplitude.view()).n_tokens`). + pub fn new(vs: &nn::Path, cfg: &MaeConfig, n_tokens: i64) -> Self { + assert!(n_tokens >= 2, "n_tokens must be >= 2"); + let td = cfg.token_dim as i64; + let enc = cfg.encoder_dim as i64; + let dec = cfg.decoder_dim as i64; + // Mirror mask_csi_window's clamping so the shapes line up exactly. + let mut n_mask = (cfg.mask_ratio * n_tokens as f64).round() as i64; + if n_mask < 1 { + n_mask = 1; + } + if n_mask >= n_tokens { + n_mask = n_tokens - 1; + } + let n_vis = n_tokens - n_mask; + + let enc_blocks = (0..cfg.encoder_depth) + .map(|i| ResidualMlp::new(&(vs / "enc" / i), enc)) + .collect(); + let dec_blocks = (0..cfg.decoder_depth) + .map(|i| ResidualMlp::new(&(vs / "dec" / i), dec)) + .collect(); + let pos_query = vs.var( + "pos_query", + &[n_tokens, dec], + nn::Init::Randn { mean: 0.0, stdev: 0.02 }, + ); + + Self { + cfg: cfg.clone(), + n_tokens, + n_masked: n_mask, + n_visible: n_vis, + device: vs.device(), + amp_embed: nn::linear(vs / "amp_embed", td, enc, Default::default()), + ph_embed: nn::linear(vs / "ph_embed", td, enc, Default::default()), + tok_fuse: nn::linear(vs / "tok_fuse", 2 * enc, enc, Default::default()), + enc_blocks, + to_latent: nn::linear(vs / "to_latent", n_vis * enc, enc, Default::default()), + from_latent: nn::linear(vs / "from_latent", enc, dec, Default::default()), + pos_query, + dec_blocks, + dec_amp_head: nn::linear(vs / "dec_amp_head", dec, td, Default::default()), + dec_ph_head: nn::linear(vs / "dec_ph_head", dec, td, Default::default()), + } + } + + /// Reconstruct the masked amplitude & phase tokens. + /// + /// * `vis_amp`, `vis_phase` — `[B, n_visible, token_dim]`. + /// * `masked_pos` — the `n_masked` masked token indices (shared across + /// the batch in this v0; see the module docs). + /// * returns `(pred_amp, pred_phase)`, each `[B, n_masked, token_dim]`. + pub fn forward( + &self, + vis_amp: &Tensor, + vis_phase: &Tensor, + masked_pos: &[i64], + train: bool, + ) -> (Tensor, Tensor) { + let _ = train; // dropout/layernorm-train hooks would go here in iter 3 + let enc = self.cfg.encoder_dim as i64; + let b = vis_amp.size()[0]; + + // Per-token dual-stream embed → fuse. + let a = self.amp_embed.forward(vis_amp); // [B, V, enc] + let p = self.ph_embed.forward(vis_phase); // [B, V, enc] + let mut t = self.tok_fuse.forward(&Tensor::cat(&[&a, &p], -1)).relu(); // [B, V, enc] + for blk in &self.enc_blocks { + t = blk.forward(&t); + } + + // Bottleneck: flatten visible token embeddings → latent [B, enc]. + let flat = t.reshape([b, self.n_visible * enc]); + let latent = self.to_latent.forward(&flat).relu(); // [B, enc] + + // Decoder: learned per-position query + broadcast latent context. + let ctx = self.from_latent.forward(&latent).unsqueeze(1); // [B, 1, dec] + let mut d = (self.pos_query.unsqueeze(0) + ctx).relu(); // [B, N, dec] + for blk in &self.dec_blocks { + d = blk.forward(&d); + } + + let all_amp = self.dec_amp_head.forward(&d); // [B, N, td] + let all_ph = self.dec_ph_head.forward(&d); // [B, N, td] + let idx = Tensor::from_slice(masked_pos).to_device(self.device); // [M] i64 + (all_amp.index_select(1, &idx), all_ph.index_select(1, &idx)) + } + + /// Dual-stream reconstruction loss: `MSE(pred_amp, tgt_amp) + w·MSE(pred_phase, tgt_phase)`. + pub fn reconstruction_loss( + pred_amp: &Tensor, + pred_phase: &Tensor, + tgt_amp: &Tensor, + tgt_phase: &Tensor, + phase_w: f64, + ) -> Tensor { + let amp_l = pred_amp.mse_loss(tgt_amp, Reduction::Mean); + let ph_l = pred_phase.mse_loss(tgt_phase, Reduction::Mean); + amp_l + ph_l * phase_w + } + } + + /// One batch of masked CSI windows ready for [`pretrain_step`]. + /// + /// All windows in the batch are masked with the *same* seed (v0 + /// simplification), so `masked_pos` / `n_visible` / `n_masked` are shared. + #[derive(Debug)] + pub struct MaeBatch { + /// Visible amplitude tokens, `[B, n_visible, token_dim]`. + pub vis_amp: Tensor, + /// Visible phase tokens, `[B, n_visible, token_dim]`. + pub vis_phase: Tensor, + /// Target (masked) amplitude tokens, `[B, n_masked, token_dim]`. + pub tgt_amp: Tensor, + /// Target (masked) phase tokens, `[B, n_masked, token_dim]`. + pub tgt_phase: Tensor, + /// Masked token indices (length `n_masked`), shared across the batch. + pub masked_pos: Vec, + /// `T·tx·rx` of every window in the batch. + pub n_tokens: i64, + } + + impl MaeBatch { + /// Build a batch from `(amplitude, phase)` windows (each `[T,tx,rx,sub]`). + /// + /// The visible/masked token partition is computed once from the **first** + /// window (via [`mask_csi_window`] with `strategy`/`seed`) and reused for + /// every window in the batch, so `masked_pos` is shared — the + /// fixed-`n_tokens` model requires it. Every window must have the same + /// `[T,tx,rx,sub]` shape. Returns `Err` on a shape mismatch / empty batch. + pub fn from_windows( + windows: &[(Array4, Array4)], + cfg: &MaeConfig, + seed: u64, + strategy: MaskStrategy, + device: Device, + ) -> Result { + if windows.is_empty() { + return Err("MaeBatch::from_windows: empty batch".into()); + } + let td = cfg.token_dim; + + // Partition from window 0; reuse it for the rest of the batch. + let m0 = mask_csi_window(windows[0].0.view(), windows[0].1.view(), cfg.mask_ratio, strategy, seed) + .map_err(|e| format!("MaeBatch window 0: {e}"))?; + if m0.layout.token_dim != td { + return Err(format!("MaeBatch window 0: token_dim {} != cfg.token_dim {td}", m0.layout.token_dim)); + } + let n_tokens = m0.layout.n_tokens as i64; + let visible_idx = m0.visible_idx.clone(); + let masked_idx = m0.masked_idx.clone(); + let masked_pos: Vec = masked_idx.iter().map(|&x| x as i64).collect(); + + let gather = |grid: &Array2, idx: &[usize]| -> Array2 { + let mut out = Array2::::zeros((idx.len(), td)); + for (r, &i) in idx.iter().enumerate() { + out.row_mut(r).assign(&grid.row(i)); + } + out + }; + + let mut vis_amp_rows: Vec> = Vec::with_capacity(windows.len()); + let mut vis_ph_rows: Vec> = Vec::with_capacity(windows.len()); + let mut tgt_amp_rows: Vec> = Vec::with_capacity(windows.len()); + let mut tgt_ph_rows: Vec> = Vec::with_capacity(windows.len()); + + for (i, (amp, ph)) in windows.iter().enumerate() { + let layout = super::TokenLayout::from_window(amp.view()); + if layout.token_dim != td || layout.n_tokens as i64 != n_tokens { + return Err(format!( + "MaeBatch window {i}: shape {:?} incompatible with batch (n_tokens={n_tokens}, token_dim={td})", + amp.shape() + )); + } + if amp.shape() != ph.shape() { + return Err(format!("MaeBatch window {i}: amplitude/phase shape mismatch")); + } + let amp_flat = super::TokenLayout::flatten(amp.view()); + let ph_flat = super::TokenLayout::flatten(ph.view()); + vis_amp_rows.push(gather(&_flat, &visible_idx)); + vis_ph_rows.push(gather(&ph_flat, &visible_idx)); + tgt_amp_rows.push(gather(&_flat, &masked_idx)); + tgt_ph_rows.push(gather(&ph_flat, &masked_idx)); + } + + let stack3 = |rows: &[Array2]| -> Tensor { + let views: Vec<_> = rows.iter().map(|r| r.view()).collect(); + let a3 = ndarray::stack(Axis(0), &views).expect("uniform [k, td] rows stack"); + let (b, k, d) = a3.dim(); + let std = a3.as_standard_layout(); + Tensor::from_slice(std.as_slice().expect("contiguous")) + .reshape([b as i64, k as i64, d as i64]) + .to_device(device) + }; + + Ok(MaeBatch { + vis_amp: stack3(&vis_amp_rows), + vis_phase: stack3(&vis_ph_rows), + tgt_amp: stack3(&tgt_amp_rows), + tgt_phase: stack3(&tgt_ph_rows), + masked_pos, + n_tokens, + }) + } + } + + /// Run one optimiser step on `batch`. Returns the (scalar) reconstruction loss. + pub fn pretrain_step(model: &CsiMae, opt: &mut nn::Optimizer, batch: &MaeBatch) -> f64 { + let (pred_amp, pred_ph) = model.forward(&batch.vis_amp, &batch.vis_phase, &batch.masked_pos, true); + let loss = CsiMae::reconstruction_loss( + &pred_amp, + &pred_ph, + &batch.tgt_amp, + &batch.tgt_phase, + model.cfg.phase_loss_weight, + ); + opt.backward_step(&loss); + f64::try_from(&loss).unwrap_or(f64::NAN) + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::csi_mae::{MaeConfig, MaskStrategy, TokenLayout}; + use tch::nn::OptimizerConfig; + + /// Deterministic synthetic CSI window `[T, tx, rx, sub]` with structure. + fn synth(seed: u64, frames: usize, tx: usize, rx: usize, sub: usize) -> (Array4, Array4) { + let mut s = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xDEAD_BEEF; + let mut next = || { + s = s.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = s; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + ((z ^ (z >> 31)) as f64 / u64::MAX as f64) as f32 + }; + let amp = Array4::from_shape_fn((frames, tx, rx, sub), |(f, _, _, c)| { + 0.5 + 0.4 * ((f as f32 * 0.3 + c as f32 * 0.1).sin()) + 0.05 * next() + }); + let ph = Array4::from_shape_fn((frames, tx, rx, sub), |(f, _, _, c)| { + 0.3 * ((f as f32 * 0.2 - c as f32 * 0.05).cos()) + 0.05 * next() + }); + (amp, ph) + } + + #[test] + fn loss_decreases_when_overfitting_one_batch() { + tch::manual_seed(7); + let (frames, tx, rx, sub) = (6usize, 1usize, 1usize, 8usize); + let n_tokens = (frames * tx * rx) as i64; + let windows: Vec<_> = (0..3).map(|i| synth(i, frames, tx, rx, sub)).collect(); + + let mut cfg = MaeConfig::default(); + cfg.token_dim = sub; + cfg.encoder_dim = 32; + cfg.decoder_dim = 16; + cfg.encoder_depth = 1; + cfg.decoder_depth = 1; + cfg.mask_ratio = 0.5; + cfg.validate().unwrap(); + + // sanity: the model's derived n_visible matches mask_csi_window's. + let m0 = mask_csi_window(windows[0].0.view(), windows[0].1.view(), cfg.mask_ratio, MaskStrategy::Random, 1).unwrap(); + assert_eq!(TokenLayout::from_window(windows[0].0.view()).n_tokens as i64, n_tokens); + + let vs = nn::VarStore::new(Device::Cpu); + let model = CsiMae::new(&vs.root(), &cfg, n_tokens); + assert_eq!(model.n_visible, m0.visible_idx.len() as i64); + assert_eq!(model.n_masked, m0.masked_idx.len() as i64); + + let mut opt = nn::Adam::default().build(&vs, 1e-2).unwrap(); + let batch = MaeBatch::from_windows(&windows, &cfg, 1, MaskStrategy::Random, Device::Cpu).unwrap(); + + let l0 = pretrain_step(&model, &mut opt, &batch); + let mut last = l0; + for _ in 0..60 { + last = pretrain_step(&model, &mut opt, &batch); + } + assert!(l0.is_finite() && last.is_finite(), "loss must be finite (l0={l0}, last={last})"); + assert!(last < 0.5 * l0, "overfitting one batch should cut loss in half: l0={l0}, last={last}"); + } + } } // ---------------------------------------------------------------------------