feat(train): MERIDIAN-MAE — information-guided masking (iter 2a, #68)

csi_mae::mask_csi_window now dispatches on MaskStrategy:
  - Random:      uniform Fisher–Yates (as before).
  - InfoGuided:  CIG-MAE-style — preferentially mask high-information tokens.
                 A token's "information" = variance of its amplitude values +
                 variance of its phase values (token_information()); near-constant
                 tokens are trivially in-painted so masking them teaches less.
                 Selection is weighted-without-replacement (Efraimidis–Spirakis:
                 key_i = u_i^(1/w_i), ranked by ln(u_i)/w_i) — exact, and
                 deterministic given `seed` (the u_i come from SplitMix64).

Replaces the iteration-1 "InfoGuided falls back to Random with a warning" stub.
+3 unit tests (info-guided skews ≥7.5/10 toward high-info tokens; deterministic
in seed; token_information ≈ 0 for constant tokens). `cargo test -p
wifi-densepose-train --no-default-features` → 121 lib tests pass.

Still to do (iter 2b, next loop tick): the real csi_mae::model (tch encoder/
decoder + reconstruction_loss + pretrain_step), bin/pretrain_mae.rs, a gated
"loss decreases" smoke test.

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv 2026-05-11 12:57:42 -04:00
parent 603ad585b6
commit 48c7d03250
1 changed files with 117 additions and 12 deletions

View File

@ -91,6 +91,23 @@ fn shuffle<T>(xs: &mut [T], rng: &mut SplitMix64) {
}
}
/// Information score of token row `i`, consumed by [`MaskStrategy::InfoGuided`].
///
/// The score is the population variance (sum of squared deviations divided by
/// the element count) of row `i` of `amp`, plus the same quantity for row `i`
/// of `phase`. A row whose values are all equal therefore scores `0.0`, while
/// a row with spread-out values scores higher. An empty row contributes `0.0`.
///
/// `amp` and `phase` are the flattened per-token grids (one row per token);
/// `i` selects the token row. Arithmetic is carried out in `f64`.
fn token_information(amp: &Array2<f32>, phase: &Array2<f32>, i: usize) -> f64 {
    /// Population variance of one row, accumulated in `f64`.
    fn population_variance(values: ndarray::ArrayView1<'_, f32>) -> f64 {
        let count = values.len();
        if count == 0 {
            // Nothing to measure; also avoids a 0/0 division below.
            return 0.0;
        }
        let denom = count as f64;
        let mean = values.iter().map(|&v| v as f64).sum::<f64>() / denom;
        let squared_deviations = values
            .iter()
            .map(|&v| {
                let delta = v as f64 - mean;
                delta * delta
            })
            .sum::<f64>();
        squared_deviations / denom
    }

    let amplitude_score = population_variance(amp.row(i));
    let phase_score = population_variance(phase.row(i));
    amplitude_score + phase_score
}
// ---------------------------------------------------------------------------
// Masking strategy
// ---------------------------------------------------------------------------
@ -300,8 +317,11 @@ pub struct MaskedCsi {
/// * `amplitude`, `phase` — `[T, tx, rx, sub]`, identical shapes.
/// * `mask_ratio` — fraction hidden; clamped so at least one token is visible
/// and at least one is masked.
/// * `strategy` — [`MaskStrategy::InfoGuided`] currently falls back to
/// [`MaskStrategy::Random`] with a warning (lands in iteration 2).
/// * `strategy` — [`MaskStrategy::Random`] (uniform) or [`MaskStrategy::InfoGuided`]
/// (CIG-MAE-style: preferentially mask high-information tokens, where a token's
/// "information" is the variance of its amplitude + phase values — flat tokens
/// are trivially in-painted, so masking them teaches less). Both are
/// deterministic in `seed`.
/// * `seed` — makes the choice reproducible. A good per-sample seed is
/// `base_seed ^ (sample_index as u64).wrapping_mul(0x9E3779B97F4A7C15)`.
///
@ -332,9 +352,6 @@ pub fn mask_csi_window(
reason: format!("must be in (0, 1), got {mask_ratio}"),
});
}
if matches!(strategy, MaskStrategy::InfoGuided) {
tracing::warn!("MaskStrategy::InfoGuided not yet implemented — falling back to Random");
}
let layout = TokenLayout::from_window(amplitude);
let n = layout.n_tokens;
@ -354,13 +371,42 @@ pub fn mask_csi_window(
n_mask = n - 1;
}
// Random permutation of [0, n); first n_mask = masked, rest = visible.
let amp_flat = TokenLayout::flatten(amplitude);
let phase_flat = TokenLayout::flatten(phase);
// Pick the n_mask masked token indices according to the strategy.
let mut rng = SplitMix64::new(seed);
let mut perm: Vec<usize> = (0..n).collect();
shuffle(&mut perm, &mut rng);
let mut masked_idx: Vec<usize> = perm[..n_mask].to_vec();
let mut visible_idx: Vec<usize> = perm[n_mask..].to_vec();
let masked_set: Vec<usize> = match strategy {
MaskStrategy::Random => {
// Uniform: shuffle [0, n) and take the first n_mask.
let mut perm: Vec<usize> = (0..n).collect();
shuffle(&mut perm, &mut rng);
perm[..n_mask].to_vec()
}
MaskStrategy::InfoGuided => {
// Weighted-without-replacement by per-token information (variance of
// amplitude+phase). Efraimidis–Spirakis: key_i = u_i^(1/w_i),
// pick the n_mask largest keys. Deterministic given `seed`.
let mut keyed: Vec<(f64, usize)> = (0..n)
.map(|i| {
let w = token_information(&amp_flat, &phase_flat, i) + 1e-6;
// u in (0, 1]: avoid 0 so ln() is finite. key = u^(1/w);
// rank by ln(key) = ln(u)/w (monotone, avoids tiny powers).
let u = ((rng.next_u64() >> 11) as f64 + 1.0) / (((1u64 << 53) as f64) + 1.0);
let key = u.ln() / w; // larger (closer to 0) ⇒ more likely chosen
(key, i)
})
.collect();
// Largest key = least-negative ln(u)/w ⇒ sort descending by key.
keyed.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
keyed[..n_mask].iter().map(|&(_, i)| i).collect()
}
};
let mut masked_idx = masked_set;
masked_idx.sort_unstable();
let masked_lookup: std::collections::HashSet<usize> = masked_idx.iter().copied().collect();
let mut visible_idx: Vec<usize> = (0..n).filter(|i| !masked_lookup.contains(i)).collect();
visible_idx.sort_unstable();
let mut mask = vec![false; n];
@ -368,8 +414,6 @@ pub fn mask_csi_window(
mask[i] = true;
}
let amp_flat = TokenLayout::flatten(amplitude);
let phase_flat = TokenLayout::flatten(phase);
let gather = |src: &Array2<f32>, idx: &[usize]| -> Array2<f32> {
let mut out = Array2::<f32>::zeros((idx.len(), layout.token_dim));
for (row, &i) in idx.iter().enumerate() {
@ -568,6 +612,67 @@ mod tests {
assert_ne!(m1.masked_idx, m3.masked_idx); // different seed → different partition
}
/// Test fixture: a CSI window whose first half is flat and whose second half
/// is noisy.
///
/// Both grids have shape `[20, 1, 1, 8]`. Frames `0..10` are constant
/// (amplitude `1.0`, phase `0.0`), i.e. low information; frames `10..20` are
/// filled with `SplitMix64` draws (seed 999) scaled by `u64::MAX`, i.e. high
/// information. Returns `(amp, phase, n_tokens, n_low)` where `n_tokens` is the
/// frame count (20) and `n_low` is the number of constant frames (10).
fn split_info_window() -> (ndarray::Array4<f32>, ndarray::Array4<f32>, usize, usize) {
    const N_FRAMES: usize = 20;
    const N_LOW: usize = 10;
    const N_SUB: usize = 8;

    let shape = (N_FRAMES, 1, 1, N_SUB);
    let mut rng = SplitMix64::new(999);
    // One shared generator: the amplitude grid is filled first, then the phase
    // grid, so the draw order stays fixed and the fixture is reproducible.
    let mut unit_noise = || (rng.next_u64() as f32) / (u64::MAX as f32);

    let amp = ndarray::Array4::<f32>::from_shape_fn(shape, |(frame, ..)| {
        if frame >= N_LOW {
            unit_noise()
        } else {
            1.0
        }
    });
    let phase = ndarray::Array4::<f32>::from_shape_fn(shape, |(frame, ..)| {
        if frame >= N_LOW {
            unit_noise()
        } else {
            0.0
        }
    });

    (amp, phase, N_FRAMES, N_LOW)
}
#[test]
fn info_guided_masking_prefers_high_information_tokens() {
    let (amp, phase, _n_tokens, n_low) = split_info_window();
    let trials = 8;

    // Each trial masks 50% of the window (10 of 20 tokens) with a different
    // seed and counts how many masked tokens fall in the noisy half
    // (indices >= n_low).
    let high_count_total: usize = (0..trials)
        .map(|seed| {
            let masked =
                mask_csi_window(amp.view(), phase.view(), 0.5, MaskStrategy::InfoGuided, seed)
                    .unwrap();
            assert_eq!(masked.masked_idx.len(), 10);
            masked.masked_idx.iter().filter(|&&idx| idx >= n_low).count()
        })
        .sum();

    // Uniform masking would put about 5 of 10 masked tokens in the noisy half;
    // info-guided selection must land well above that.
    let avg_high = high_count_total as f64 / trials as f64;
    assert!(avg_high >= 7.5, "info-guided avg high-info masked = {avg_high}, expected >= 7.5");
}
#[test]
fn info_guided_masking_is_deterministic_in_seed() {
    let (amp, phase, _n_tokens, _n_low) = split_info_window();
    let mask_with = |seed| {
        mask_csi_window(amp.view(), phase.view(), 0.4, MaskStrategy::InfoGuided, seed).unwrap()
    };

    let first = mask_with(5);
    let repeat = mask_with(5);
    let other_seed = mask_with(6);

    // Same seed reproduces the selection and the gathered targets; a different
    // seed yields a different selection.
    assert_eq!(first.masked_idx, repeat.masked_idx);
    assert_eq!(first.target_amp, repeat.target_amp);
    assert_ne!(first.masked_idx, other_seed.masked_idx);

    // Visible and masked indices together must cover 0..n exactly once.
    let n = first.layout.n_tokens;
    assert_eq!(first.visible_idx.len() + first.masked_idx.len(), n);
    let mut covered: Vec<usize> = Vec::with_capacity(n);
    covered.extend(first.visible_idx.iter().copied());
    covered.extend(first.masked_idx.iter().copied());
    covered.sort_unstable();
    assert_eq!(covered, (0..n).collect::<Vec<_>>());
}
#[test]
fn token_information_is_zero_for_constant_and_positive_for_varied() {
    let (amp, phase, _n_tokens, _n_low) = split_info_window();
    let amp_rows = TokenLayout::flatten(amp.view());
    let phase_rows = TokenLayout::flatten(phase.view());

    // Row 0 comes from the constant half of the fixture, row 15 from the noisy half.
    let constant_score = token_information(&amp_rows, &phase_rows, 0);
    let noisy_score = token_information(&amp_rows, &phase_rows, 15);

    assert!(constant_score < 1e-9);
    assert!(noisy_score > 1e-6);
}
#[test]
fn masking_clamps_extreme_ratios() {
let (a, p) = synth_window(4, 1, 1, 8, 9);