From 48c7d03250b084da4662ff43f58e77329728e541 Mon Sep 17 00:00:00 2001 From: ruv Date: Mon, 11 May 2026 12:57:42 -0400 Subject: [PATCH] =?UTF-8?q?feat(train):=20MERIDIAN-MAE=20=E2=80=94=20infor?= =?UTF-8?q?mation-guided=20masking=20(iter=202a,=20#68)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit csi_mae::mask_csi_window now dispatches on MaskStrategy: - Random: uniform Fisher–Yates (as before). - InfoGuided: CIG-MAE-style — preferentially mask high-information tokens. A token's "information" = variance of its amplitude values + variance of its phase values (token_information()); near-constant tokens are trivially in-painted so masking them teaches less. Selection is weighted-without-replacement (Efraimidis–Spirakis: key_i = u_i^(1/w_i), ranked by ln(u_i)/w_i) — exact, and deterministic given `seed` (the u_i come from SplitMix64). Replaces the iteration-1 "InfoGuided falls back to Random with a warning" stub. +3 unit tests (info-guided skews ≥7.5/10 toward high-info tokens; deterministic in seed; token_information ≈ 0 for constant tokens). `cargo test -p wifi-densepose-train --no-default-features` → 121 lib tests pass. Still to do (iter 2b, next loop tick): the real csi_mae::model (tch encoder/ decoder + reconstruction_loss + pretrain_step), bin/pretrain_mae.rs, a gated "loss decreases" smoke test. 
Co-Authored-By: claude-flow --- v2/crates/wifi-densepose-train/src/csi_mae.rs | 129 ++++++++++++++++-- 1 file changed, 117 insertions(+), 12 deletions(-) diff --git a/v2/crates/wifi-densepose-train/src/csi_mae.rs b/v2/crates/wifi-densepose-train/src/csi_mae.rs index a4661776..254ca88d 100644 --- a/v2/crates/wifi-densepose-train/src/csi_mae.rs +++ b/v2/crates/wifi-densepose-train/src/csi_mae.rs @@ -91,6 +91,23 @@ fn shuffle<T>(xs: &mut [T], rng: &mut SplitMix64) { } } +/// Per-token "information" score used by [`MaskStrategy::InfoGuided`]: the +/// (population) variance of the token's amplitude values plus the variance of +/// its phase values. Near-constant tokens (e.g. a quiet sub-carrier slice) score +/// near zero, so they're less likely to be masked; structured tokens score +/// higher. `amp`/`phase` are the flattened `[N, sub]` grids; `i` is the token row. +fn token_information(amp: &Array2<f32>, phase: &Array2<f32>, i: usize) -> f64 { + let var = |row: ndarray::ArrayView1<f32>| -> f64 { + let m = row.len(); + if m == 0 { + return 0.0; + } + let mean = row.iter().map(|&x| x as f64).sum::<f64>() / m as f64; + row.iter().map(|&x| { let d = x as f64 - mean; d * d }).sum::<f64>() / m as f64 + }; + var(amp.row(i)) + var(phase.row(i)) +} + // --------------------------------------------------------------------------- // Masking strategy // --------------------------------------------------------------------------- @@ -300,8 +317,11 @@ pub struct MaskedCsi { /// * `amplitude`, `phase` — `[T, tx, rx, sub]`, identical shapes. /// * `mask_ratio` — fraction hidden; clamped so at least one token is visible /// and at least one is masked. -/// * `strategy` — [`MaskStrategy::InfoGuided`] currently falls back to -/// [`MaskStrategy::Random`] with a warning (lands in iteration 2). 
+/// * `strategy` — [`MaskStrategy::Random`] (uniform) or [`MaskStrategy::InfoGuided`] +/// (CIG-MAE-style: preferentially mask high-information tokens, where a token's +/// "information" is the variance of its amplitude + phase values — flat tokens +/// are trivially in-painted, so masking them teaches less). Both are +/// deterministic in `seed`. /// * `seed` — makes the choice reproducible. A good per-sample seed is /// `base_seed ^ (sample_index as u64).wrapping_mul(0x9E3779B97F4A7C15)`. /// @@ -332,9 +352,6 @@ pub fn mask_csi_window( reason: format!("must be in (0, 1), got {mask_ratio}"), }); } - if matches!(strategy, MaskStrategy::InfoGuided) { - tracing::warn!("MaskStrategy::InfoGuided not yet implemented — falling back to Random"); - } let layout = TokenLayout::from_window(amplitude); let n = layout.n_tokens; @@ -354,13 +371,42 @@ pub fn mask_csi_window( n_mask = n - 1; } - // Random permutation of [0, n); first n_mask = masked, rest = visible. + let amp_flat = TokenLayout::flatten(amplitude); + let phase_flat = TokenLayout::flatten(phase); + + // Pick the n_mask masked token indices according to the strategy. let mut rng = SplitMix64::new(seed); - let mut perm: Vec<usize> = (0..n).collect(); - shuffle(&mut perm, &mut rng); - let mut masked_idx: Vec<usize> = perm[..n_mask].to_vec(); - let mut visible_idx: Vec<usize> = perm[n_mask..].to_vec(); + let masked_set: Vec<usize> = match strategy { + MaskStrategy::Random => { + // Uniform: shuffle [0, n) and take the first n_mask. + let mut perm: Vec<usize> = (0..n).collect(); + shuffle(&mut perm, &mut rng); + perm[..n_mask].to_vec() + } + MaskStrategy::InfoGuided => { + // Weighted-without-replacement by per-token information (variance of + // amplitude+phase). Efraimidis–Spirakis: key_i = u_i^(1/w_i), + // pick the n_mask largest keys. Deterministic given `seed`. + let mut keyed: Vec<(f64, usize)> = (0..n) + .map(|i| { + let w = token_information(&amp_flat, &phase_flat, i) + 1e-6; + // u in (0, 1]: avoid 0 so ln() is finite. 
key = u^(1/w); + // rank by ln(key) = ln(u)/w (monotone, avoids tiny powers). + let u = ((rng.next_u64() >> 11) as f64 + 1.0) / (((1u64 << 53) as f64) + 1.0); + let key = u.ln() / w; // larger (closer to 0) ⇒ more likely chosen + (key, i) + }) + .collect(); + // Largest key = least-negative ln(u)/w ⇒ sort descending by key. + keyed.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); + keyed[..n_mask].iter().map(|&(_, i)| i).collect() + } + }; + + let mut masked_idx = masked_set; masked_idx.sort_unstable(); + let masked_lookup: std::collections::HashSet<usize> = masked_idx.iter().copied().collect(); + let mut visible_idx: Vec<usize> = (0..n).filter(|i| !masked_lookup.contains(i)).collect(); visible_idx.sort_unstable(); let mut mask = vec![false; n]; @@ -368,8 +414,6 @@ pub fn mask_csi_window( mask[i] = true; } - let amp_flat = TokenLayout::flatten(amplitude); - let phase_flat = TokenLayout::flatten(phase); let gather = |src: &Array2<f32>, idx: &[usize]| -> Array2<f32> { let mut out = Array2::<f32>::zeros((idx.len(), layout.token_dim)); for (row, &i) in idx.iter().enumerate() { @@ -568,6 +612,67 @@ mod tests { assert_ne!(m1.masked_idx, m3.masked_idx); // different seed → different partition } + /// Build a window where the first half of the tokens are (near-)constant + /// (low information) and the second half are noisy (high information). + /// Returns `(amp, phase, n_tokens, n_low)`. + fn split_info_window() -> (ndarray::Array4<f32>, ndarray::Array4<f32>, usize, usize) { + // 20 frames, 1x1, 8 sub → 20 tokens; first 10 constant, last 10 noisy. 
+ let frames = 20; + let sub = 8; + let mut rng = SplitMix64::new(999); + let amp = ndarray::Array4::<f32>::from_shape_fn((frames, 1, 1, sub), |(f, _, _, _)| { + if f < 10 { 1.0 } else { (rng.next_u64() as f32) / (u64::MAX as f32) } + }); + let phase = ndarray::Array4::<f32>::from_shape_fn((frames, 1, 1, sub), |(f, _, _, _)| { + if f < 10 { 0.0 } else { (rng.next_u64() as f32) / (u64::MAX as f32) } + }); + (amp, phase, frames, 10) + } + + #[test] + fn info_guided_masking_prefers_high_information_tokens() { + let (a, p, _n, n_low) = split_info_window(); + // Mask 50% (10 of 20). With info-guided selection the noisy tokens + // (indices 10..20) should dominate the masked set far beyond chance. + let mut high_count_total = 0usize; + let trials = 8; + for seed in 0..trials { + let m = mask_csi_window(a.view(), p.view(), 0.5, MaskStrategy::InfoGuided, seed).unwrap(); + assert_eq!(m.masked_idx.len(), 10); + let high = m.masked_idx.iter().filter(|&&i| i >= n_low).count(); + high_count_total += high; + } + // Random would average ~5/10 high per trial; info-guided should be ≥ ~8/10. 
+ let avg_high = high_count_total as f64 / trials as f64; + assert!(avg_high >= 7.5, "info-guided avg high-info masked = {avg_high}, expected >= 7.5"); + } + + #[test] + fn info_guided_masking_is_deterministic_in_seed() { + let (a, p, _n, _) = split_info_window(); + let m1 = mask_csi_window(a.view(), p.view(), 0.4, MaskStrategy::InfoGuided, 5).unwrap(); + let m2 = mask_csi_window(a.view(), p.view(), 0.4, MaskStrategy::InfoGuided, 5).unwrap(); + let m3 = mask_csi_window(a.view(), p.view(), 0.4, MaskStrategy::InfoGuided, 6).unwrap(); + assert_eq!(m1.masked_idx, m2.masked_idx); + assert_eq!(m1.target_amp, m2.target_amp); + assert_ne!(m1.masked_idx, m3.masked_idx); + // still a valid exhaustive/disjoint partition + let n = m1.layout.n_tokens; + assert_eq!(m1.visible_idx.len() + m1.masked_idx.len(), n); + let mut all: Vec<usize> = m1.visible_idx.iter().chain(m1.masked_idx.iter()).copied().collect(); + all.sort_unstable(); + assert_eq!(all, (0..n).collect::<Vec<usize>>()); + } + + #[test] + fn token_information_is_zero_for_constant_and_positive_for_varied() { + let (a, p, _n, _) = split_info_window(); + let amp_flat = TokenLayout::flatten(a.view()); + let ph_flat = TokenLayout::flatten(p.view()); + assert!(token_information(&amp_flat, &ph_flat, 0) < 1e-9); // constant token + assert!(token_information(&amp_flat, &ph_flat, 15) > 1e-6); // noisy token + } + #[test] fn masking_clamps_extreme_ratios() { let (a, p) = synth_window(4, 1, 1, 8, 9);