feat(train): MERIDIAN-MAE — csi_mae::model + pretrain loop + pretrain-mae bin (iter 2b, #68)
Real CSI masked-autoencoder behind feature `tch-backend` (ADR-027 §2.0):
- CsiMae: dual-stream per-token amp+phase embed → fuse → residual-MLP encoder
over the visible tokens → flatten-to-latent bottleneck → learned per-position
query + broadcast latent → residual-MLP decoder → dec_amp_head / dec_ph_head
→ index_select the masked positions. (MLP-based v0; self-attention transformer
blocks are iter 3.)
- CsiMae::reconstruction_loss(pred_amp, pred_phase, tgt_amp, tgt_phase, phase_w)
= MSE(amp) + phase_w * MSE(phase).
- MaeBatch::from_windows — partition computed once from window 0 and reused
across the batch (the bottleneck fixes n_tokens), ndarray → tch conversion.
- pretrain_step(model, opt, batch) -> f64 — one Adam step, returns the loss.
- src/bin/pretrain_mae.rs — synthetic-data pre-train driver (required-features
= ["tch-backend"]); clap args for epochs/batch/samples/lr/mask-ratio/save.
- #[cfg(feature="tch-backend")] smoke test: loss halves when overfitting one
batch over 60 steps; also asserts model.n_visible/n_masked match
mask_csi_window's clamping.
v0 limits (documented in the module): fixed n_tokens; batch-shared masking;
MSE on unwrapped phase (vs a circular loss). The dev box has no LibTorch, so the
tch path is CI-verified (`--features tch-backend`), not locally. The default
`cargo test -p wifi-densepose-train --no-default-features` stays green (121 lib
tests) — the model module and the bin are both feature-gated.
Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
48c7d03250
commit
7d26b15eef
|
|
@ -20,6 +20,11 @@ name = "verify-training"
|
|||
path = "src/bin/verify_training.rs"
|
||||
required-features = ["tch-backend"]
|
||||
|
||||
[[bin]]
|
||||
name = "pretrain-mae"
|
||||
path = "src/bin/pretrain_mae.rs"
|
||||
required-features = ["tch-backend"]
|
||||
|
||||
[features]
|
||||
default = []
|
||||
tch-backend = ["tch"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,108 @@
|
|||
//! `pretrain-mae` — drive the MERIDIAN CSI masked-autoencoder pre-train on a
|
||||
//! deterministic `SyntheticCsiDataset` (ADR-027 §2.0, prototype iteration 2).
|
||||
//!
|
||||
//! This is the *prototype* driver — it exercises the full pre-train loop
|
||||
//! (mask → encode visible → reconstruct masked amplitude+phase → optimiser
|
||||
//! step) end-to-end on synthetic CSI. Real cross-domain pre-training (iter 3+)
|
||||
//! ingests heterogeneous capture — MM-Fi / Wi-Pose / `data/recordings/` /
|
||||
//! multi-band virtual sub-carriers — and runs on GPU (`scripts/gcloud-train.sh`
|
||||
//! / the cognitum project).
|
||||
//!
|
||||
//! ```text
|
||||
//! cargo run -p wifi-densepose-train --features tch-backend --bin pretrain-mae -- --epochs 5
|
||||
//! ```
|
||||
//!
|
||||
//! Only compiled with `--features tch-backend` (see Cargo.toml `required-features`).
|
||||
|
||||
use clap::Parser;
|
||||
use tch::nn::OptimizerConfig;
|
||||
use tch::{nn, Device};
|
||||
|
||||
use wifi_densepose_train::csi_mae::model::{pretrain_step, CsiMae, MaeBatch};
|
||||
use wifi_densepose_train::csi_mae::{MaeConfig, MaskStrategy, TokenLayout};
|
||||
use wifi_densepose_train::dataset::{CsiDataset, SyntheticConfig, SyntheticCsiDataset};
|
||||
|
||||
/// MERIDIAN CSI masked-autoencoder pre-train (prototype, synthetic data).
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "pretrain-mae", version, about)]
|
||||
struct Cli {
|
||||
/// Number of epochs over the synthetic dataset.
|
||||
#[arg(long, default_value_t = 5)]
|
||||
epochs: usize,
|
||||
/// Mini-batch size (windows per optimiser step).
|
||||
#[arg(long, default_value_t = 8)]
|
||||
batch: usize,
|
||||
/// Number of synthetic samples to generate.
|
||||
#[arg(long, default_value_t = 256)]
|
||||
samples: usize,
|
||||
/// Adam learning rate.
|
||||
#[arg(long, default_value_t = 1e-3)]
|
||||
lr: f64,
|
||||
/// Fraction of tokens masked per window.
|
||||
#[arg(long, default_value_t = 0.75)]
|
||||
mask_ratio: f64,
|
||||
/// Optional path to save the pre-trained variable store (`.ot`).
|
||||
#[arg(long)]
|
||||
save: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let cli = Cli::parse();
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let ds = SyntheticCsiDataset::new(cli.samples, SyntheticConfig::default());
|
||||
if ds.len() < cli.batch {
|
||||
anyhow::bail!("need at least --batch ({}) samples, have {}", cli.batch, ds.len());
|
||||
}
|
||||
let s0 = ds.get(0)?;
|
||||
let layout = TokenLayout::from_window(s0.amplitude.view());
|
||||
let n_tokens = layout.n_tokens as i64;
|
||||
|
||||
let mut cfg = MaeConfig::default();
|
||||
cfg.token_dim = layout.token_dim;
|
||||
cfg.mask_ratio = cli.mask_ratio;
|
||||
cfg.validate().map_err(anyhow::Error::msg)?;
|
||||
|
||||
let device = Device::cuda_if_available();
|
||||
let vs = nn::VarStore::new(device);
|
||||
let model = CsiMae::new(&vs.root(), &cfg, n_tokens);
|
||||
let mut opt = nn::Adam::default().build(&vs, cli.lr)?;
|
||||
|
||||
println!(
|
||||
"pretrain-mae: device={device:?} n_tokens={n_tokens} token_dim={} V={} M={} samples={} batch={} epochs={} lr={} mask_ratio={}",
|
||||
cfg.token_dim, model.n_visible, model.n_masked, cli.samples, cli.batch, cli.epochs, cli.lr, cli.mask_ratio
|
||||
);
|
||||
|
||||
let mut step: u64 = 0;
|
||||
for epoch in 0..cli.epochs {
|
||||
let mut epoch_loss = 0.0_f64;
|
||||
let mut nb = 0_usize;
|
||||
let mut i = 0_usize;
|
||||
while i + cli.batch <= ds.len() {
|
||||
let mut windows = Vec::with_capacity(cli.batch);
|
||||
for j in i..i + cli.batch {
|
||||
let s = ds.get(j)?;
|
||||
windows.push((s.amplitude, s.phase));
|
||||
}
|
||||
let seed = step.wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xC511_0027;
|
||||
let batch = MaeBatch::from_windows(&windows, &cfg, seed, MaskStrategy::InfoGuided, device)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let loss = pretrain_step(&model, &mut opt, &batch);
|
||||
if !loss.is_finite() {
|
||||
anyhow::bail!("non-finite loss at epoch {epoch} step {step}");
|
||||
}
|
||||
epoch_loss += loss;
|
||||
nb += 1;
|
||||
step += 1;
|
||||
i += cli.batch;
|
||||
}
|
||||
let avg = if nb > 0 { epoch_loss / nb as f64 } else { f64::NAN };
|
||||
println!("epoch {epoch}: avg reconstruction loss = {avg:.6} ({nb} batches)");
|
||||
}
|
||||
|
||||
if let Some(path) = cli.save {
|
||||
vs.save(&path)?;
|
||||
println!("saved pre-trained variable store → {path}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -494,43 +494,377 @@ pub fn reassemble_tokens(
|
|||
/// CSI masked-autoencoder networks (LibTorch / `tch`).
|
||||
///
|
||||
/// **Compiled only with `--features tch-backend`.** Not exercised by the default
|
||||
/// `cargo test --workspace --no-default-features` CI job, so compile-checking it
|
||||
/// requires a LibTorch toolchain.
|
||||
/// `cargo test --workspace --no-default-features` CI job — compile-/run-checking
|
||||
/// this submodule requires a LibTorch toolchain (`LIBTORCH` was unset on the dev
|
||||
/// box that wrote it, so it is CI-verified only; if a `tch` API call below has
|
||||
/// drifted, it's a localised fix).
|
||||
///
|
||||
/// Iteration-1 status: this is a v0 *skeleton* — the public interface
|
||||
/// ([`model::CsiMae`], `forward`, `reconstruction_loss`, `model::pretrain_step`)
|
||||
/// is fixed, but the encoder/decoder are placeholder MLPs and per-batch masking
|
||||
/// is assumed shared. Transformer blocks, per-sample masking, and the
|
||||
/// `pretrain-mae` binary land in iteration 2.
|
||||
/// # v0 design (iteration 2)
|
||||
///
|
||||
/// A deliberately small **dual-stream** MAE, MLP-based (no self-attention yet —
|
||||
/// transformer blocks are iteration 3):
|
||||
///
|
||||
/// ```text
|
||||
/// visible amplitude [B, V, sub] ─► amp_embed ─┐
|
||||
/// ├─ cat ─► tok_fuse ─► relu ─► enc_blocks(residual MLP) ─► [B, V, enc]
|
||||
/// visible phase [B, V, sub] ─► ph_embed ─┘ │
|
||||
/// reshape [B, V·enc] │
|
||||
/// to_latent│
|
||||
/// ▼
|
||||
/// latent [B, enc]
|
||||
/// from_latent│
|
||||
/// ▼
|
||||
/// learned per-position query pos_query [N, dec] + ─► relu ─► dec_blocks(residual MLP) ─► [B, N, dec]
|
||||
/// (broadcast latent over N positions) │
|
||||
/// ┌──────────────────────┤
|
||||
/// dec_amp_head dec_ph_head
|
||||
/// [B, N, sub] [B, N, sub]
|
||||
/// index_select(masked positions) ─► (pred_amp, pred_ph) [B, M, sub]
|
||||
/// ```
|
||||
///
|
||||
/// Limitations to lift later: (1) a *fixed* `n_tokens` (the bottleneck flattens
|
||||
/// all visible token embeddings, so V — hence N and `mask_ratio` — is baked in
|
||||
/// at `new()` time); (2) **batch-shared masking** (`MaeBatch::from_samples` masks
|
||||
/// every sample in a batch with the same seed, so `masked_pos` is shared) —
|
||||
/// per-sample masking via gather/scatter is iteration 3; (3) MSE on unwrapped
|
||||
/// phase rather than a circular loss.
|
||||
#[cfg(feature = "tch-backend")]
|
||||
pub mod model {
|
||||
use super::MaeConfig;
|
||||
use super::{mask_csi_window, MaeConfig, MaskStrategy};
|
||||
use ndarray::{Array2, Array4, Axis};
|
||||
use tch::{nn, nn::Module, Device, Kind, Reduction, Tensor};
|
||||
|
||||
/// Placeholder for the CSI masked autoencoder network.
|
||||
///
|
||||
/// Iteration 2 fills this in (dual-stream encoder over visible tokens →
|
||||
/// shared latent; narrow decoder over all positions with learned mask
|
||||
/// tokens; reconstruct amplitude + phase). For now it exists so downstream
|
||||
/// code can name the type and the build wiring is in place.
|
||||
/// A residual MLP block: `LayerNorm(x + relu(Linear(x)))`.
|
||||
#[derive(Debug)]
|
||||
pub struct CsiMae {
|
||||
_cfg: MaeConfig,
|
||||
struct ResidualMlp {
|
||||
lin: nn::Linear,
|
||||
ln: nn::LayerNorm,
|
||||
}
|
||||
|
||||
impl CsiMae {
|
||||
/// Build a `CsiMae` under the given variable-store path with `cfg`.
|
||||
///
|
||||
/// Iteration-1 stub — see the module docs.
|
||||
pub fn new(_vs: &tch::nn::Path, cfg: &MaeConfig) -> Self {
|
||||
Self { _cfg: cfg.clone() }
|
||||
impl ResidualMlp {
|
||||
fn new(p: &nn::Path, dim: i64) -> Self {
|
||||
Self {
|
||||
lin: nn::linear(p / "lin", dim, dim, Default::default()),
|
||||
ln: nn::layer_norm(p / "ln", vec![dim], Default::default()),
|
||||
}
|
||||
}
|
||||
fn forward(&self, x: &Tensor) -> Tensor {
|
||||
self.ln.forward(&(x + self.lin.forward(x).relu()))
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE (iteration 2): add
|
||||
// impl CsiMae { fn forward(&self, vis_amp, vis_phase, vis_idx, n_tokens) -> (pred_amp, pred_phase) }
|
||||
// fn reconstruction_loss(pred_amp, pred_phase, tgt_amp, tgt_phase, phase_w) -> Tensor
|
||||
// fn pretrain_step(model, opt, batch) -> f64
|
||||
// plus a `bin/pretrain_mae.rs` driving SyntheticCsiDataset / MmFiDataset.
|
||||
/// The CSI masked autoencoder. See the module docs for the v0 design.
|
||||
#[derive(Debug)]
|
||||
pub struct CsiMae {
|
||||
/// Hyper-parameters this model was built with.
|
||||
pub cfg: MaeConfig,
|
||||
/// Number of tokens per window (`T·tx·rx`) — fixed at construction.
|
||||
pub n_tokens: i64,
|
||||
/// Number of masked (target) tokens per window.
|
||||
pub n_masked: i64,
|
||||
/// Number of visible (encoder-input) tokens per window.
|
||||
pub n_visible: i64,
|
||||
device: Device,
|
||||
amp_embed: nn::Linear,
|
||||
ph_embed: nn::Linear,
|
||||
tok_fuse: nn::Linear,
|
||||
enc_blocks: Vec<ResidualMlp>,
|
||||
to_latent: nn::Linear,
|
||||
from_latent: nn::Linear,
|
||||
/// Learned per-position query, shape `[n_tokens, decoder_dim]`.
|
||||
pos_query: Tensor,
|
||||
dec_blocks: Vec<ResidualMlp>,
|
||||
dec_amp_head: nn::Linear,
|
||||
dec_ph_head: nn::Linear,
|
||||
}
|
||||
|
||||
impl CsiMae {
|
||||
/// Build a `CsiMae` under `vs` for windows of exactly `n_tokens` tokens.
|
||||
///
|
||||
/// `n_tokens` is fixed because the bottleneck flattens all visible token
|
||||
/// embeddings; it must equal `T·tx·rx` of the windows fed at train/eval
|
||||
/// time (e.g. `TokenLayout::from_window(sample.amplitude.view()).n_tokens`).
|
||||
pub fn new(vs: &nn::Path, cfg: &MaeConfig, n_tokens: i64) -> Self {
|
||||
assert!(n_tokens >= 2, "n_tokens must be >= 2");
|
||||
let td = cfg.token_dim as i64;
|
||||
let enc = cfg.encoder_dim as i64;
|
||||
let dec = cfg.decoder_dim as i64;
|
||||
// Mirror mask_csi_window's clamping so the shapes line up exactly.
|
||||
let mut n_mask = (cfg.mask_ratio * n_tokens as f64).round() as i64;
|
||||
if n_mask < 1 {
|
||||
n_mask = 1;
|
||||
}
|
||||
if n_mask >= n_tokens {
|
||||
n_mask = n_tokens - 1;
|
||||
}
|
||||
let n_vis = n_tokens - n_mask;
|
||||
|
||||
let enc_blocks = (0..cfg.encoder_depth)
|
||||
.map(|i| ResidualMlp::new(&(vs / "enc" / i), enc))
|
||||
.collect();
|
||||
let dec_blocks = (0..cfg.decoder_depth)
|
||||
.map(|i| ResidualMlp::new(&(vs / "dec" / i), dec))
|
||||
.collect();
|
||||
let pos_query = vs.var(
|
||||
"pos_query",
|
||||
&[n_tokens, dec],
|
||||
nn::Init::Randn { mean: 0.0, stdev: 0.02 },
|
||||
);
|
||||
|
||||
Self {
|
||||
cfg: cfg.clone(),
|
||||
n_tokens,
|
||||
n_masked: n_mask,
|
||||
n_visible: n_vis,
|
||||
device: vs.device(),
|
||||
amp_embed: nn::linear(vs / "amp_embed", td, enc, Default::default()),
|
||||
ph_embed: nn::linear(vs / "ph_embed", td, enc, Default::default()),
|
||||
tok_fuse: nn::linear(vs / "tok_fuse", 2 * enc, enc, Default::default()),
|
||||
enc_blocks,
|
||||
to_latent: nn::linear(vs / "to_latent", n_vis * enc, enc, Default::default()),
|
||||
from_latent: nn::linear(vs / "from_latent", enc, dec, Default::default()),
|
||||
pos_query,
|
||||
dec_blocks,
|
||||
dec_amp_head: nn::linear(vs / "dec_amp_head", dec, td, Default::default()),
|
||||
dec_ph_head: nn::linear(vs / "dec_ph_head", dec, td, Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reconstruct the masked amplitude & phase tokens.
|
||||
///
|
||||
/// * `vis_amp`, `vis_phase` — `[B, n_visible, token_dim]`.
|
||||
/// * `masked_pos` — the `n_masked` masked token indices (shared across
|
||||
/// the batch in this v0; see the module docs).
|
||||
/// * returns `(pred_amp, pred_phase)`, each `[B, n_masked, token_dim]`.
|
||||
pub fn forward(
|
||||
&self,
|
||||
vis_amp: &Tensor,
|
||||
vis_phase: &Tensor,
|
||||
masked_pos: &[i64],
|
||||
train: bool,
|
||||
) -> (Tensor, Tensor) {
|
||||
let _ = train; // dropout/layernorm-train hooks would go here in iter 3
|
||||
let enc = self.cfg.encoder_dim as i64;
|
||||
let b = vis_amp.size()[0];
|
||||
|
||||
// Per-token dual-stream embed → fuse.
|
||||
let a = self.amp_embed.forward(vis_amp); // [B, V, enc]
|
||||
let p = self.ph_embed.forward(vis_phase); // [B, V, enc]
|
||||
let mut t = self.tok_fuse.forward(&Tensor::cat(&[&a, &p], -1)).relu(); // [B, V, enc]
|
||||
for blk in &self.enc_blocks {
|
||||
t = blk.forward(&t);
|
||||
}
|
||||
|
||||
// Bottleneck: flatten visible token embeddings → latent [B, enc].
|
||||
let flat = t.reshape([b, self.n_visible * enc]);
|
||||
let latent = self.to_latent.forward(&flat).relu(); // [B, enc]
|
||||
|
||||
// Decoder: learned per-position query + broadcast latent context.
|
||||
let ctx = self.from_latent.forward(&latent).unsqueeze(1); // [B, 1, dec]
|
||||
let mut d = (self.pos_query.unsqueeze(0) + ctx).relu(); // [B, N, dec]
|
||||
for blk in &self.dec_blocks {
|
||||
d = blk.forward(&d);
|
||||
}
|
||||
|
||||
let all_amp = self.dec_amp_head.forward(&d); // [B, N, td]
|
||||
let all_ph = self.dec_ph_head.forward(&d); // [B, N, td]
|
||||
let idx = Tensor::from_slice(masked_pos).to_device(self.device); // [M] i64
|
||||
(all_amp.index_select(1, &idx), all_ph.index_select(1, &idx))
|
||||
}
|
||||
|
||||
/// Dual-stream reconstruction loss: `MSE(pred_amp, tgt_amp) + w·MSE(pred_phase, tgt_phase)`.
|
||||
pub fn reconstruction_loss(
|
||||
pred_amp: &Tensor,
|
||||
pred_phase: &Tensor,
|
||||
tgt_amp: &Tensor,
|
||||
tgt_phase: &Tensor,
|
||||
phase_w: f64,
|
||||
) -> Tensor {
|
||||
let amp_l = pred_amp.mse_loss(tgt_amp, Reduction::Mean);
|
||||
let ph_l = pred_phase.mse_loss(tgt_phase, Reduction::Mean);
|
||||
amp_l + ph_l * phase_w
|
||||
}
|
||||
}
|
||||
|
||||
/// One batch of masked CSI windows ready for [`pretrain_step`].
|
||||
///
|
||||
/// All windows in the batch are masked with the *same* seed (v0
|
||||
/// simplification), so `masked_pos` / `n_visible` / `n_masked` are shared.
|
||||
#[derive(Debug)]
|
||||
pub struct MaeBatch {
|
||||
/// Visible amplitude tokens, `[B, n_visible, token_dim]`.
|
||||
pub vis_amp: Tensor,
|
||||
/// Visible phase tokens, `[B, n_visible, token_dim]`.
|
||||
pub vis_phase: Tensor,
|
||||
/// Target (masked) amplitude tokens, `[B, n_masked, token_dim]`.
|
||||
pub tgt_amp: Tensor,
|
||||
/// Target (masked) phase tokens, `[B, n_masked, token_dim]`.
|
||||
pub tgt_phase: Tensor,
|
||||
/// Masked token indices (length `n_masked`), shared across the batch.
|
||||
pub masked_pos: Vec<i64>,
|
||||
/// `T·tx·rx` of every window in the batch.
|
||||
pub n_tokens: i64,
|
||||
}
|
||||
|
||||
impl MaeBatch {
|
||||
/// Build a batch from `(amplitude, phase)` windows (each `[T,tx,rx,sub]`).
|
||||
///
|
||||
/// The visible/masked token partition is computed once from the **first**
|
||||
/// window (via [`mask_csi_window`] with `strategy`/`seed`) and reused for
|
||||
/// every window in the batch, so `masked_pos` is shared — the
|
||||
/// fixed-`n_tokens` model requires it. Every window must have the same
|
||||
/// `[T,tx,rx,sub]` shape. Returns `Err` on a shape mismatch / empty batch.
|
||||
pub fn from_windows(
|
||||
windows: &[(Array4<f32>, Array4<f32>)],
|
||||
cfg: &MaeConfig,
|
||||
seed: u64,
|
||||
strategy: MaskStrategy,
|
||||
device: Device,
|
||||
) -> Result<MaeBatch, String> {
|
||||
if windows.is_empty() {
|
||||
return Err("MaeBatch::from_windows: empty batch".into());
|
||||
}
|
||||
let td = cfg.token_dim;
|
||||
|
||||
// Partition from window 0; reuse it for the rest of the batch.
|
||||
let m0 = mask_csi_window(windows[0].0.view(), windows[0].1.view(), cfg.mask_ratio, strategy, seed)
|
||||
.map_err(|e| format!("MaeBatch window 0: {e}"))?;
|
||||
if m0.layout.token_dim != td {
|
||||
return Err(format!("MaeBatch window 0: token_dim {} != cfg.token_dim {td}", m0.layout.token_dim));
|
||||
}
|
||||
let n_tokens = m0.layout.n_tokens as i64;
|
||||
let visible_idx = m0.visible_idx.clone();
|
||||
let masked_idx = m0.masked_idx.clone();
|
||||
let masked_pos: Vec<i64> = masked_idx.iter().map(|&x| x as i64).collect();
|
||||
|
||||
let gather = |grid: &Array2<f32>, idx: &[usize]| -> Array2<f32> {
|
||||
let mut out = Array2::<f32>::zeros((idx.len(), td));
|
||||
for (r, &i) in idx.iter().enumerate() {
|
||||
out.row_mut(r).assign(&grid.row(i));
|
||||
}
|
||||
out
|
||||
};
|
||||
|
||||
let mut vis_amp_rows: Vec<Array2<f32>> = Vec::with_capacity(windows.len());
|
||||
let mut vis_ph_rows: Vec<Array2<f32>> = Vec::with_capacity(windows.len());
|
||||
let mut tgt_amp_rows: Vec<Array2<f32>> = Vec::with_capacity(windows.len());
|
||||
let mut tgt_ph_rows: Vec<Array2<f32>> = Vec::with_capacity(windows.len());
|
||||
|
||||
for (i, (amp, ph)) in windows.iter().enumerate() {
|
||||
let layout = super::TokenLayout::from_window(amp.view());
|
||||
if layout.token_dim != td || layout.n_tokens as i64 != n_tokens {
|
||||
return Err(format!(
|
||||
"MaeBatch window {i}: shape {:?} incompatible with batch (n_tokens={n_tokens}, token_dim={td})",
|
||||
amp.shape()
|
||||
));
|
||||
}
|
||||
if amp.shape() != ph.shape() {
|
||||
return Err(format!("MaeBatch window {i}: amplitude/phase shape mismatch"));
|
||||
}
|
||||
let amp_flat = super::TokenLayout::flatten(amp.view());
|
||||
let ph_flat = super::TokenLayout::flatten(ph.view());
|
||||
vis_amp_rows.push(gather(&_flat, &visible_idx));
|
||||
vis_ph_rows.push(gather(&ph_flat, &visible_idx));
|
||||
tgt_amp_rows.push(gather(&_flat, &masked_idx));
|
||||
tgt_ph_rows.push(gather(&ph_flat, &masked_idx));
|
||||
}
|
||||
|
||||
let stack3 = |rows: &[Array2<f32>]| -> Tensor {
|
||||
let views: Vec<_> = rows.iter().map(|r| r.view()).collect();
|
||||
let a3 = ndarray::stack(Axis(0), &views).expect("uniform [k, td] rows stack");
|
||||
let (b, k, d) = a3.dim();
|
||||
let std = a3.as_standard_layout();
|
||||
Tensor::from_slice(std.as_slice().expect("contiguous"))
|
||||
.reshape([b as i64, k as i64, d as i64])
|
||||
.to_device(device)
|
||||
};
|
||||
|
||||
Ok(MaeBatch {
|
||||
vis_amp: stack3(&vis_amp_rows),
|
||||
vis_phase: stack3(&vis_ph_rows),
|
||||
tgt_amp: stack3(&tgt_amp_rows),
|
||||
tgt_phase: stack3(&tgt_ph_rows),
|
||||
masked_pos,
|
||||
n_tokens,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Run one optimiser step on `batch`. Returns the (scalar) reconstruction loss.
|
||||
pub fn pretrain_step(model: &CsiMae, opt: &mut nn::Optimizer, batch: &MaeBatch) -> f64 {
|
||||
let (pred_amp, pred_ph) = model.forward(&batch.vis_amp, &batch.vis_phase, &batch.masked_pos, true);
|
||||
let loss = CsiMae::reconstruction_loss(
|
||||
&pred_amp,
|
||||
&pred_ph,
|
||||
&batch.tgt_amp,
|
||||
&batch.tgt_phase,
|
||||
model.cfg.phase_loss_weight,
|
||||
);
|
||||
opt.backward_step(&loss);
|
||||
f64::try_from(&loss).unwrap_or(f64::NAN)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::csi_mae::{MaeConfig, MaskStrategy, TokenLayout};
|
||||
use tch::nn::OptimizerConfig;
|
||||
|
||||
/// Deterministic synthetic CSI window `[T, tx, rx, sub]` with structure.
|
||||
fn synth(seed: u64, frames: usize, tx: usize, rx: usize, sub: usize) -> (Array4<f32>, Array4<f32>) {
|
||||
let mut s = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xDEAD_BEEF;
|
||||
let mut next = || {
|
||||
s = s.wrapping_add(0x9E37_79B9_7F4A_7C15);
|
||||
let mut z = s;
|
||||
z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
|
||||
z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
|
||||
((z ^ (z >> 31)) as f64 / u64::MAX as f64) as f32
|
||||
};
|
||||
let amp = Array4::from_shape_fn((frames, tx, rx, sub), |(f, _, _, c)| {
|
||||
0.5 + 0.4 * ((f as f32 * 0.3 + c as f32 * 0.1).sin()) + 0.05 * next()
|
||||
});
|
||||
let ph = Array4::from_shape_fn((frames, tx, rx, sub), |(f, _, _, c)| {
|
||||
0.3 * ((f as f32 * 0.2 - c as f32 * 0.05).cos()) + 0.05 * next()
|
||||
});
|
||||
(amp, ph)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn loss_decreases_when_overfitting_one_batch() {
|
||||
tch::manual_seed(7);
|
||||
let (frames, tx, rx, sub) = (6usize, 1usize, 1usize, 8usize);
|
||||
let n_tokens = (frames * tx * rx) as i64;
|
||||
let windows: Vec<_> = (0..3).map(|i| synth(i, frames, tx, rx, sub)).collect();
|
||||
|
||||
let mut cfg = MaeConfig::default();
|
||||
cfg.token_dim = sub;
|
||||
cfg.encoder_dim = 32;
|
||||
cfg.decoder_dim = 16;
|
||||
cfg.encoder_depth = 1;
|
||||
cfg.decoder_depth = 1;
|
||||
cfg.mask_ratio = 0.5;
|
||||
cfg.validate().unwrap();
|
||||
|
||||
// sanity: the model's derived n_visible matches mask_csi_window's.
|
||||
let m0 = mask_csi_window(windows[0].0.view(), windows[0].1.view(), cfg.mask_ratio, MaskStrategy::Random, 1).unwrap();
|
||||
assert_eq!(TokenLayout::from_window(windows[0].0.view()).n_tokens as i64, n_tokens);
|
||||
|
||||
let vs = nn::VarStore::new(Device::Cpu);
|
||||
let model = CsiMae::new(&vs.root(), &cfg, n_tokens);
|
||||
assert_eq!(model.n_visible, m0.visible_idx.len() as i64);
|
||||
assert_eq!(model.n_masked, m0.masked_idx.len() as i64);
|
||||
|
||||
let mut opt = nn::Adam::default().build(&vs, 1e-2).unwrap();
|
||||
let batch = MaeBatch::from_windows(&windows, &cfg, 1, MaskStrategy::Random, Device::Cpu).unwrap();
|
||||
|
||||
let l0 = pretrain_step(&model, &mut opt, &batch);
|
||||
let mut last = l0;
|
||||
for _ in 0..60 {
|
||||
last = pretrain_step(&model, &mut opt, &batch);
|
||||
}
|
||||
assert!(l0.is_finite() && last.is_finite(), "loss must be finite (l0={l0}, last={last})");
|
||||
assert!(last < 0.5 * l0, "overfitting one batch should cut loss in half: l0={l0}, last={last}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Reference in New Issue