feat(train): MERIDIAN-MAE — csi_mae::model + pretrain loop + pretrain-mae bin (iter 2b, #68)

Real CSI masked-autoencoder behind feature `tch-backend` (ADR-027 §2.0):

  - CsiMae: dual-stream per-token amp+phase embed → fuse → residual-MLP encoder
    over the visible tokens → flatten-to-latent bottleneck → learned per-position
    query + broadcast latent → residual-MLP decoder → dec_amp_head / dec_ph_head
    → index_select the masked positions. (MLP-based v0; self-attention transformer
    blocks are iter 3.)
  - CsiMae::reconstruction_loss(pred_amp, pred_phase, tgt_amp, tgt_phase, phase_w)
    = MSE(amp) + phase_w * MSE(phase).
  - MaeBatch::from_windows — partition computed once from window 0 and reused
    across the batch (the bottleneck fixes n_tokens), ndarray → tch conversion.
  - pretrain_step(model, opt, batch) -> f64 — one Adam step, returns the loss.
  - src/bin/pretrain_mae.rs — synthetic-data pre-train driver (required-features
    = ["tch-backend"]); clap args for epochs/batch/samples/lr/mask-ratio/save.
  - #[cfg(feature="tch-backend")] smoke test: loss halves when overfitting one
    batch over 60 steps; also asserts model.n_visible/n_masked match
    mask_csi_window's clamping.

v0 limits (documented in the module): fixed n_tokens; batch-shared masking;
MSE on unwrapped phase (vs a circular loss). The dev box has no LibTorch, so the
tch path is CI-verified (`--features tch-backend`), not locally. The default
`cargo test -p wifi-densepose-train --no-default-features` stays green (121 lib
tests) — the model module and the bin are both feature-gated.

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv 2026-05-11 13:05:27 -04:00
parent 48c7d03250
commit 7d26b15eef
3 changed files with 475 additions and 28 deletions

View File

@ -20,6 +20,11 @@ name = "verify-training"
path = "src/bin/verify_training.rs"
required-features = ["tch-backend"]
[[bin]]
name = "pretrain-mae"
path = "src/bin/pretrain_mae.rs"
required-features = ["tch-backend"]
[features]
default = []
tch-backend = ["tch"]

View File

@ -0,0 +1,108 @@
//! `pretrain-mae` — drive the MERIDIAN CSI masked-autoencoder pre-train on a
//! deterministic `SyntheticCsiDataset` (ADR-027 §2.0, prototype iteration 2).
//!
//! This is the *prototype* driver — it exercises the full pre-train loop
//! (mask → encode visible → reconstruct masked amplitude+phase → optimiser
//! step) end-to-end on synthetic CSI. Real cross-domain pre-training (iter 3+)
//! ingests heterogeneous capture — MM-Fi / Wi-Pose / `data/recordings/` /
//! multi-band virtual sub-carriers — and runs on GPU (`scripts/gcloud-train.sh`
//! / the cognitum project).
//!
//! ```text
//! cargo run -p wifi-densepose-train --features tch-backend --bin pretrain-mae -- --epochs 5
//! ```
//!
//! Only compiled with `--features tch-backend` (see Cargo.toml `required-features`).
use clap::Parser;
use tch::nn::OptimizerConfig;
use tch::{nn, Device};
use wifi_densepose_train::csi_mae::model::{pretrain_step, CsiMae, MaeBatch};
use wifi_densepose_train::csi_mae::{MaeConfig, MaskStrategy, TokenLayout};
use wifi_densepose_train::dataset::{CsiDataset, SyntheticConfig, SyntheticCsiDataset};
// Command-line arguments of the `pretrain-mae` driver, parsed with clap's
// derive API.
// NOTE: clap's derive turns the `///` doc comments below into the `--help`
// text, so rewording them changes user-visible output; maintainer notes are
// therefore written as plain `//` comments.
/// MERIDIAN CSI masked-autoencoder pre-train (prototype, synthetic data).
#[derive(Parser, Debug)]
#[command(name = "pretrain-mae", version, about)]
struct Cli {
// Number of full passes `main` makes over the dataset.
/// Number of epochs over the synthetic dataset.
#[arg(long, default_value_t = 5)]
epochs: usize,
// `main` only runs full batches: samples left over at the end of an epoch
// (when `samples` is not a multiple of `batch`) are skipped.
/// Mini-batch size (windows per optimiser step).
#[arg(long, default_value_t = 8)]
batch: usize,
// Passed to `SyntheticCsiDataset::new`; `main` bails if the resulting dataset
// is smaller than one batch.
/// Number of synthetic samples to generate.
#[arg(long, default_value_t = 256)]
samples: usize,
// Passed to the Adam optimiser builder in `main`.
/// Adam learning rate.
#[arg(long, default_value_t = 1e-3)]
lr: f64,
// Copied into `MaeConfig::mask_ratio` and checked by `cfg.validate()` in `main`.
/// Fraction of tokens masked per window.
#[arg(long, default_value_t = 0.75)]
mask_ratio: f64,
// When absent, `main` writes nothing to disk.
/// Optional path to save the pre-trained variable store (`.ot`).
#[arg(long)]
save: Option<String>,
}
/// Entry point of the `pretrain-mae` driver.
///
/// Builds a `SyntheticCsiDataset`, sizes a `CsiMae` from the token layout of
/// sample 0, and runs `--epochs` passes of Adam steps over consecutive,
/// non-overlapping mini-batches of `--batch` windows. Samples left over after
/// the last full batch of an epoch are not used. The variable store is written
/// to `--save` when that flag is given.
///
/// # Errors
///
/// Returns an error if `--batch` is 0, if the dataset holds fewer samples than
/// one batch, if a window has fewer than 2 tokens, if the configuration fails
/// validation, if a sample or batch cannot be built, if the loss becomes
/// non-finite, or if saving the variable store fails.
fn main() -> anyhow::Result<()> {
    let cli = Cli::parse();
    // Best-effort: installing the subscriber fails harmlessly if one is already set.
    let _ = tracing_subscriber::fmt::try_init();

    // Reject `--batch 0` up front. It passes the size check below and would
    // otherwise only fail later, after the model has been built, with an
    // opaque "empty batch" error (and `step_by(0)` would panic).
    anyhow::ensure!(cli.batch >= 1, "--batch must be at least 1");

    let ds = SyntheticCsiDataset::new(cli.samples, SyntheticConfig::default());
    if ds.len() < cli.batch {
        anyhow::bail!("need at least --batch ({}) samples, have {}", cli.batch, ds.len());
    }

    // The model's token count / token width are taken from the first sample.
    let s0 = ds.get(0)?;
    let layout = TokenLayout::from_window(s0.amplitude.view());
    let n_tokens = layout.n_tokens as i64;
    // `CsiMae::new` asserts this; report it as a normal CLI error instead of a panic.
    anyhow::ensure!(
        n_tokens >= 2,
        "each window must hold at least 2 tokens to be masked, got {n_tokens}"
    );

    let mut cfg = MaeConfig::default();
    cfg.token_dim = layout.token_dim;
    cfg.mask_ratio = cli.mask_ratio;
    cfg.validate().map_err(anyhow::Error::msg)?;

    let device = Device::cuda_if_available();
    let vs = nn::VarStore::new(device);
    let model = CsiMae::new(&vs.root(), &cfg, n_tokens);
    let mut opt = nn::Adam::default().build(&vs, cli.lr)?;
    println!(
        "pretrain-mae: device={device:?} n_tokens={n_tokens} token_dim={} V={} M={} samples={} batch={} epochs={} lr={} mask_ratio={}",
        cfg.token_dim, model.n_visible, model.n_masked, cli.samples, cli.batch, cli.epochs, cli.lr, cli.mask_ratio
    );

    // Start index of the last full batch; cannot underflow because
    // `ds.len() >= cli.batch` was checked above.
    let last_start = ds.len() - cli.batch;
    let mut step: u64 = 0;
    for epoch in 0..cli.epochs {
        let mut epoch_loss = 0.0_f64;
        let mut nb = 0_usize;
        for start in (0..=last_start).step_by(cli.batch) {
            let windows = (start..start + cli.batch)
                .map(|j| ds.get(j).map(|s| (s.amplitude, s.phase)))
                .collect::<Result<Vec<_>, _>>()?;
            // A different mask seed for every optimiser step, derived from the
            // global step counter so runs are reproducible.
            let seed = step.wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xC511_0027;
            let batch = MaeBatch::from_windows(&windows, &cfg, seed, MaskStrategy::InfoGuided, device)
                .map_err(anyhow::Error::msg)?;
            let loss = pretrain_step(&model, &mut opt, &batch);
            if !loss.is_finite() {
                anyhow::bail!("non-finite loss at epoch {epoch} step {step}");
            }
            epoch_loss += loss;
            nb += 1;
            step += 1;
        }
        let avg = if nb > 0 { epoch_loss / nb as f64 } else { f64::NAN };
        println!("epoch {epoch}: avg reconstruction loss = {avg:.6} ({nb} batches)");
    }

    if let Some(path) = cli.save {
        vs.save(&path)?;
        println!("saved pre-trained variable store → {path}");
    }
    Ok(())
}

View File

@ -494,43 +494,377 @@ pub fn reassemble_tokens(
/// CSI masked-autoencoder networks (LibTorch / `tch`).
///
/// **Compiled only with `--features tch-backend`.** Not exercised by the default
/// `cargo test --workspace --no-default-features` CI job, so compile-checking it
/// requires a LibTorch toolchain.
/// `cargo test --workspace --no-default-features` CI job — compile-/run-checking
/// this submodule requires a LibTorch toolchain (`LIBTORCH` was unset on the dev
/// box that wrote it, so it is CI-verified only; if a `tch` API call below has
/// drifted, it's a localised fix).
///
/// Iteration-1 status: this is a v0 *skeleton* — the public interface
/// ([`model::CsiMae`], `forward`, `reconstruction_loss`, `model::pretrain_step`)
/// is fixed, but the encoder/decoder are placeholder MLPs and per-batch masking
/// is assumed shared. Transformer blocks, per-sample masking, and the
/// `pretrain-mae` binary land in iteration 2.
/// # v0 design (iteration 2)
///
/// A deliberately small **dual-stream** MAE, MLP-based (no self-attention yet —
/// transformer blocks are iteration 3):
///
/// ```text
/// visible amplitude [B, V, sub] ─► amp_embed ─┐
/// ├─ cat ─► tok_fuse ─► relu ─► enc_blocks(residual MLP) ─► [B, V, enc]
/// visible phase [B, V, sub] ─► ph_embed ─┘ │
/// reshape [B, V·enc] │
/// to_latent│
/// ▼
/// latent [B, enc]
/// from_latent│
/// ▼
/// learned per-position query pos_query [N, dec] + ─► relu ─► dec_blocks(residual MLP) ─► [B, N, dec]
/// (broadcast latent over N positions) │
/// ┌──────────────────────┤
/// dec_amp_head dec_ph_head
/// [B, N, sub] [B, N, sub]
/// index_select(masked positions) ─► (pred_amp, pred_ph) [B, M, sub]
/// ```
///
/// Limitations to lift later: (1) a *fixed* `n_tokens` (the bottleneck flattens
/// all visible token embeddings, so V — hence N and `mask_ratio` — is baked in
/// at `new()` time); (2) **batch-shared masking** (`MaeBatch::from_windows`
/// computes the visible/masked partition once from the first window of a batch
/// and reuses it for every other window, so `masked_pos` is shared) —
/// per-sample masking via gather/scatter is iteration 3; (3) MSE on unwrapped
/// phase rather than a circular loss.
#[cfg(feature = "tch-backend")]
pub mod model {
use super::MaeConfig;
use super::{mask_csi_window, MaeConfig, MaskStrategy};
use ndarray::{Array2, Array4, Axis};
use tch::{nn, nn::Module, Device, Kind, Reduction, Tensor};
/// Placeholder for the CSI masked autoencoder network.
///
/// Iteration 2 fills this in (dual-stream encoder over visible tokens →
/// shared latent; narrow decoder over all positions with learned mask
/// tokens; reconstruct amplitude + phase). For now it exists so downstream
/// code can name the type and the build wiring is in place.
/// A residual MLP block: `LayerNorm(x + relu(Linear(x)))`.
#[derive(Debug)]
pub struct CsiMae {
_cfg: MaeConfig,
struct ResidualMlp {
lin: nn::Linear,
ln: nn::LayerNorm,
}
impl CsiMae {
/// Build a `CsiMae` under the given variable-store path with `cfg`.
///
/// Iteration-1 stub — see the module docs.
pub fn new(_vs: &tch::nn::Path, cfg: &MaeConfig) -> Self {
Self { _cfg: cfg.clone() }
impl ResidualMlp {
fn new(p: &nn::Path, dim: i64) -> Self {
Self {
lin: nn::linear(p / "lin", dim, dim, Default::default()),
ln: nn::layer_norm(p / "ln", vec![dim], Default::default()),
}
}
fn forward(&self, x: &Tensor) -> Tensor {
self.ln.forward(&(x + self.lin.forward(x).relu()))
}
}
// NOTE (iteration 2): add
// impl CsiMae { fn forward(&self, vis_amp, vis_phase, vis_idx, n_tokens) -> (pred_amp, pred_phase) }
// fn reconstruction_loss(pred_amp, pred_phase, tgt_amp, tgt_phase, phase_w) -> Tensor
// fn pretrain_step(model, opt, batch) -> f64
// plus a `bin/pretrain_mae.rs` driving SyntheticCsiDataset / MmFiDataset.
/// The CSI masked autoencoder. See the module docs for the v0 design.
///
/// The public fields expose the sizes the model was built for; the private
/// fields are the learnable layers, all created in `CsiMae::new`.
#[derive(Debug)]
pub struct CsiMae {
/// Hyper-parameters this model was built with.
pub cfg: MaeConfig,
/// Number of tokens per window (`T·tx·rx`) — fixed at construction.
pub n_tokens: i64,
/// Number of masked (target) tokens per window.
pub n_masked: i64,
/// Number of visible (encoder-input) tokens per window.
pub n_visible: i64,
/// Device of the variable store the model was built under; `forward` moves
/// the masked-index tensor onto it.
device: Device,
/// Per-token amplitude embedding, `token_dim → encoder_dim`.
amp_embed: nn::Linear,
/// Per-token phase embedding, `token_dim → encoder_dim`.
ph_embed: nn::Linear,
/// Fuses the concatenated amplitude+phase embeddings, `2·encoder_dim → encoder_dim`.
tok_fuse: nn::Linear,
/// `cfg.encoder_depth` residual MLP blocks of width `encoder_dim`.
enc_blocks: Vec<ResidualMlp>,
/// Bottleneck over the flattened visible tokens, `n_visible·encoder_dim → encoder_dim`.
to_latent: nn::Linear,
/// Projects the latent into decoder space, `encoder_dim → decoder_dim`.
from_latent: nn::Linear,
/// Learned per-position query, shape `[n_tokens, decoder_dim]`.
pos_query: Tensor,
/// `cfg.decoder_depth` residual MLP blocks of width `decoder_dim`.
dec_blocks: Vec<ResidualMlp>,
/// Amplitude reconstruction head, `decoder_dim → token_dim`.
dec_amp_head: nn::Linear,
/// Phase reconstruction head, `decoder_dim → token_dim`.
dec_ph_head: nn::Linear,
}
impl CsiMae {
/// Build a `CsiMae` under `vs` for windows of exactly `n_tokens` tokens.
///
/// `n_tokens` is fixed because the bottleneck flattens all visible token
/// embeddings; it must equal `T·tx·rx` of the windows fed at train/eval
/// time (e.g. `TokenLayout::from_window(sample.amplitude.view()).n_tokens`).
pub fn new(vs: &nn::Path, cfg: &MaeConfig, n_tokens: i64) -> Self {
assert!(n_tokens >= 2, "n_tokens must be >= 2");
let td = cfg.token_dim as i64;
let enc = cfg.encoder_dim as i64;
let dec = cfg.decoder_dim as i64;
// Mirror mask_csi_window's clamping so the shapes line up exactly.
let mut n_mask = (cfg.mask_ratio * n_tokens as f64).round() as i64;
if n_mask < 1 {
n_mask = 1;
}
if n_mask >= n_tokens {
n_mask = n_tokens - 1;
}
let n_vis = n_tokens - n_mask;
let enc_blocks = (0..cfg.encoder_depth)
.map(|i| ResidualMlp::new(&(vs / "enc" / i), enc))
.collect();
let dec_blocks = (0..cfg.decoder_depth)
.map(|i| ResidualMlp::new(&(vs / "dec" / i), dec))
.collect();
let pos_query = vs.var(
"pos_query",
&[n_tokens, dec],
nn::Init::Randn { mean: 0.0, stdev: 0.02 },
);
Self {
cfg: cfg.clone(),
n_tokens,
n_masked: n_mask,
n_visible: n_vis,
device: vs.device(),
amp_embed: nn::linear(vs / "amp_embed", td, enc, Default::default()),
ph_embed: nn::linear(vs / "ph_embed", td, enc, Default::default()),
tok_fuse: nn::linear(vs / "tok_fuse", 2 * enc, enc, Default::default()),
enc_blocks,
to_latent: nn::linear(vs / "to_latent", n_vis * enc, enc, Default::default()),
from_latent: nn::linear(vs / "from_latent", enc, dec, Default::default()),
pos_query,
dec_blocks,
dec_amp_head: nn::linear(vs / "dec_amp_head", dec, td, Default::default()),
dec_ph_head: nn::linear(vs / "dec_ph_head", dec, td, Default::default()),
}
}
/// Reconstruct the masked amplitude & phase tokens.
///
/// * `vis_amp`, `vis_phase` — `[B, n_visible, token_dim]`.
/// * `masked_pos` — the `n_masked` masked token indices (shared across
/// the batch in this v0; see the module docs).
/// * returns `(pred_amp, pred_phase)`, each `[B, n_masked, token_dim]`.
pub fn forward(
&self,
vis_amp: &Tensor,
vis_phase: &Tensor,
masked_pos: &[i64],
train: bool,
) -> (Tensor, Tensor) {
let _ = train; // dropout/layernorm-train hooks would go here in iter 3
let enc = self.cfg.encoder_dim as i64;
let b = vis_amp.size()[0];
// Per-token dual-stream embed → fuse.
let a = self.amp_embed.forward(vis_amp); // [B, V, enc]
let p = self.ph_embed.forward(vis_phase); // [B, V, enc]
let mut t = self.tok_fuse.forward(&Tensor::cat(&[&a, &p], -1)).relu(); // [B, V, enc]
for blk in &self.enc_blocks {
t = blk.forward(&t);
}
// Bottleneck: flatten visible token embeddings → latent [B, enc].
let flat = t.reshape([b, self.n_visible * enc]);
let latent = self.to_latent.forward(&flat).relu(); // [B, enc]
// Decoder: learned per-position query + broadcast latent context.
let ctx = self.from_latent.forward(&latent).unsqueeze(1); // [B, 1, dec]
let mut d = (self.pos_query.unsqueeze(0) + ctx).relu(); // [B, N, dec]
for blk in &self.dec_blocks {
d = blk.forward(&d);
}
let all_amp = self.dec_amp_head.forward(&d); // [B, N, td]
let all_ph = self.dec_ph_head.forward(&d); // [B, N, td]
let idx = Tensor::from_slice(masked_pos).to_device(self.device); // [M] i64
(all_amp.index_select(1, &idx), all_ph.index_select(1, &idx))
}
/// Dual-stream reconstruction loss: `MSE(pred_amp, tgt_amp) + w·MSE(pred_phase, tgt_phase)`.
pub fn reconstruction_loss(
pred_amp: &Tensor,
pred_phase: &Tensor,
tgt_amp: &Tensor,
tgt_phase: &Tensor,
phase_w: f64,
) -> Tensor {
let amp_l = pred_amp.mse_loss(tgt_amp, Reduction::Mean);
let ph_l = pred_phase.mse_loss(tgt_phase, Reduction::Mean);
amp_l + ph_l * phase_w
}
}
/// One batch of masked CSI windows ready for [`pretrain_step`].
///
/// The visible/masked partition is computed once (from the first window of
/// the batch, see `MaeBatch::from_windows`) and applied to every window — a v0
/// simplification — so `masked_pos` and the visible/masked token counts are
/// shared by all windows in the batch.
#[derive(Debug)]
pub struct MaeBatch {
/// Visible amplitude tokens, `[B, n_visible, token_dim]`.
pub vis_amp: Tensor,
/// Visible phase tokens, `[B, n_visible, token_dim]`.
pub vis_phase: Tensor,
/// Target (masked) amplitude tokens, `[B, n_masked, token_dim]`.
pub tgt_amp: Tensor,
/// Target (masked) phase tokens, `[B, n_masked, token_dim]`.
pub tgt_phase: Tensor,
/// Masked token indices (length `n_masked`), shared across the batch.
pub masked_pos: Vec<i64>,
/// `T·tx·rx` of every window in the batch.
pub n_tokens: i64,
}
impl MaeBatch {
    /// Build a batch from `(amplitude, phase)` windows (each `[T,tx,rx,sub]`).
    ///
    /// The visible/masked token partition is computed once from the **first**
    /// window (via [`mask_csi_window`] with `strategy`/`seed`) and reused for
    /// every window in the batch, so `masked_pos` is shared — the
    /// fixed-`n_tokens` model requires it. The resulting tensors are placed on
    /// `device`.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `windows` is empty, if masking the first window fails,
    /// if a window's token width differs from `cfg.token_dim`, if a window's
    /// token count differs from the first window's, or if a window's amplitude
    /// and phase arrays have different shapes.
    pub fn from_windows(
        windows: &[(Array4<f32>, Array4<f32>)],
        cfg: &MaeConfig,
        seed: u64,
        strategy: MaskStrategy,
        device: Device,
    ) -> Result<MaeBatch, String> {
        if windows.is_empty() {
            return Err("MaeBatch::from_windows: empty batch".into());
        }
        let td = cfg.token_dim;

        // Partition from window 0; reuse it for the rest of the batch.
        let m0 = mask_csi_window(windows[0].0.view(), windows[0].1.view(), cfg.mask_ratio, strategy, seed)
            .map_err(|e| format!("MaeBatch window 0: {e}"))?;
        if m0.layout.token_dim != td {
            return Err(format!("MaeBatch window 0: token_dim {} != cfg.token_dim {td}", m0.layout.token_dim));
        }
        let n_tokens = m0.layout.n_tokens as i64;
        // Borrow the index lists from `m0`; they are only read below.
        let visible_idx: &[usize] = &m0.visible_idx;
        let masked_idx: &[usize] = &m0.masked_idx;
        let masked_pos: Vec<i64> = masked_idx.iter().map(|&x| x as i64).collect();

        let mut vis_amp_rows: Vec<Array2<f32>> = Vec::with_capacity(windows.len());
        let mut vis_ph_rows: Vec<Array2<f32>> = Vec::with_capacity(windows.len());
        let mut tgt_amp_rows: Vec<Array2<f32>> = Vec::with_capacity(windows.len());
        let mut tgt_ph_rows: Vec<Array2<f32>> = Vec::with_capacity(windows.len());
        for (i, (amp, ph)) in windows.iter().enumerate() {
            let layout = super::TokenLayout::from_window(amp.view());
            if layout.token_dim != td || layout.n_tokens as i64 != n_tokens {
                return Err(format!(
                    "MaeBatch window {i}: shape {:?} incompatible with batch (n_tokens={n_tokens}, token_dim={td})",
                    amp.shape()
                ));
            }
            if amp.shape() != ph.shape() {
                return Err(format!("MaeBatch window {i}: amplitude/phase shape mismatch"));
            }
            // One token per row; `select` copies the requested rows in index order.
            let amp_flat = super::TokenLayout::flatten(amp.view());
            let ph_flat = super::TokenLayout::flatten(ph.view());
            vis_amp_rows.push(amp_flat.select(Axis(0), visible_idx));
            vis_ph_rows.push(ph_flat.select(Axis(0), visible_idx));
            tgt_amp_rows.push(amp_flat.select(Axis(0), masked_idx));
            tgt_ph_rows.push(ph_flat.select(Axis(0), masked_idx));
        }

        // Stack the per-window `[k, td]` matrices into one `[B, k, td]` tensor on `device`.
        let stack3 = |rows: &[Array2<f32>]| -> Tensor {
            let views: Vec<_> = rows.iter().map(|r| r.view()).collect();
            let stacked = ndarray::stack(Axis(0), &views).expect("uniform [k, td] rows stack");
            let (b, k, d) = stacked.dim();
            // `from_slice` needs one contiguous buffer in row-major order.
            let contiguous = stacked.as_standard_layout();
            Tensor::from_slice(contiguous.as_slice().expect("contiguous"))
                .reshape([b as i64, k as i64, d as i64])
                .to_device(device)
        };

        Ok(MaeBatch {
            vis_amp: stack3(&vis_amp_rows),
            vis_phase: stack3(&vis_ph_rows),
            tgt_amp: stack3(&tgt_amp_rows),
            tgt_phase: stack3(&tgt_ph_rows),
            masked_pos,
            n_tokens,
        })
    }
}
/// Run one optimiser step on `batch`. Returns the (scalar) reconstruction loss.
pub fn pretrain_step(model: &CsiMae, opt: &mut nn::Optimizer, batch: &MaeBatch) -> f64 {
let (pred_amp, pred_ph) = model.forward(&batch.vis_amp, &batch.vis_phase, &batch.masked_pos, true);
let loss = CsiMae::reconstruction_loss(
&pred_amp,
&pred_ph,
&batch.tgt_amp,
&batch.tgt_phase,
model.cfg.phase_loss_weight,
);
opt.backward_step(&loss);
f64::try_from(&loss).unwrap_or(f64::NAN)
}
// Smoke test for the tch-backed model: builds a tiny CsiMae on the CPU and
// checks that repeated optimiser steps on one fixed batch reduce the loss.
#[cfg(test)]
mod tests {
use super::*;
use crate::csi_mae::{MaeConfig, MaskStrategy, TokenLayout};
use tch::nn::OptimizerConfig;
/// Deterministic synthetic CSI window `[T, tx, rx, sub]` with structure.
///
/// Returns `(amplitude, phase)`. Each value is a smooth sinusoid in the frame
/// and sub-carrier indices (the tx/rx indices are ignored) plus a small
/// pseudo-random term, so the same `seed` always yields the same window.
fn synth(seed: u64, frames: usize, tx: usize, rx: usize, sub: usize) -> (Array4<f32>, Array4<f32>) {
let mut s = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xDEAD_BEEF;
// SplitMix64-style generator: advance the state by a fixed odd constant,
// scramble it with two xor-shift/multiply rounds, and scale to [0, 1].
// The closure is stateful, so the values drawn depend on call order.
let mut next = || {
s = s.wrapping_add(0x9E37_79B9_7F4A_7C15);
let mut z = s;
z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
((z ^ (z >> 31)) as f64 / u64::MAX as f64) as f32
};
// Amplitude: 0.5 + 0.4·sin(...) + noise in [0, 0.05].
let amp = Array4::from_shape_fn((frames, tx, rx, sub), |(f, _, _, c)| {
0.5 + 0.4 * ((f as f32 * 0.3 + c as f32 * 0.1).sin()) + 0.05 * next()
});
// Phase: 0.3·cos(...) + noise in [0, 0.05]; continues the same noise
// stream the amplitude array consumed.
let ph = Array4::from_shape_fn((frames, tx, rx, sub), |(f, _, _, c)| {
0.3 * ((f as f32 * 0.2 - c as f32 * 0.05).cos()) + 0.05 * next()
});
(amp, ph)
}
// Overfits one 3-window batch: the loss after the final step must be finite
// and below half of the loss measured on the very first step.
#[test]
fn loss_decreases_when_overfitting_one_batch() {
// Fix the tch RNG so parameter initialisation is reproducible.
tch::manual_seed(7);
// 6 frames × 1 tx × 1 rx = 6 tokens of 8 sub-carriers each.
let (frames, tx, rx, sub) = (6usize, 1usize, 1usize, 8usize);
let n_tokens = (frames * tx * rx) as i64;
let windows: Vec<_> = (0..3).map(|i| synth(i, frames, tx, rx, sub)).collect();
// Deliberately tiny model so the test stays fast.
let mut cfg = MaeConfig::default();
cfg.token_dim = sub;
cfg.encoder_dim = 32;
cfg.decoder_dim = 16;
cfg.encoder_depth = 1;
cfg.decoder_depth = 1;
cfg.mask_ratio = 0.5;
cfg.validate().unwrap();
// sanity: the model's derived n_visible matches mask_csi_window's.
let m0 = mask_csi_window(windows[0].0.view(), windows[0].1.view(), cfg.mask_ratio, MaskStrategy::Random, 1).unwrap();
assert_eq!(TokenLayout::from_window(windows[0].0.view()).n_tokens as i64, n_tokens);
let vs = nn::VarStore::new(Device::Cpu);
let model = CsiMae::new(&vs.root(), &cfg, n_tokens);
assert_eq!(model.n_visible, m0.visible_idx.len() as i64);
assert_eq!(model.n_masked, m0.masked_idx.len() as i64);
let mut opt = nn::Adam::default().build(&vs, 1e-2).unwrap();
// Same strategy and seed as `m0` above, so the batch uses that partition.
let batch = MaeBatch::from_windows(&windows, &cfg, 1, MaskStrategy::Random, Device::Cpu).unwrap();
// `l0` is the loss of the first step; 60 further steps follow (61 in total).
let l0 = pretrain_step(&model, &mut opt, &batch);
let mut last = l0;
for _ in 0..60 {
last = pretrain_step(&model, &mut opt, &batch);
}
assert!(l0.is_finite() && last.is_finite(), "loss must be finite (l0={l0}, last={last})");
assert!(last < 0.5 * l0, "overfitting one batch should cut loss in half: l0={l0}, last={last}");
}
}
}
// ---------------------------------------------------------------------------