feat(train): MERIDIAN-MAE — information-guided masking (iter 2a, #68)
csi_mae::mask_csi_window now dispatches on MaskStrategy:
- Random: uniform Fisher–Yates (as before).
- InfoGuided: CIG-MAE-style — preferentially mask high-information tokens.
A token's "information" = variance of its amplitude values +
variance of its phase values (token_information()); near-constant
tokens are trivially in-painted so masking them teaches less.
Selection is weighted-without-replacement (Efraimidis–Spirakis:
key_i = u_i^(1/w_i), ranked by ln(u_i)/w_i) — exact, and
deterministic given `seed` (the u_i come from SplitMix64).
Replaces the iteration-1 "InfoGuided falls back to Random with a warning" stub.
+3 unit tests (info-guided skews ≥7.5/10 toward high-info tokens; deterministic
in seed; token_information ≈ 0 for constant tokens). `cargo test -p
wifi-densepose-train --no-default-features` → 121 lib tests pass.
Still to do (iter 2b, next loop tick): the real csi_mae::model (tch encoder/
decoder + reconstruction_loss + pretrain_step), bin/pretrain_mae.rs, a gated
"loss decreases" smoke test.
Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
603ad585b6
commit
48c7d03250
|
|
@ -91,6 +91,23 @@ fn shuffle<T>(xs: &mut [T], rng: &mut SplitMix64) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Per-token "information" score used by [`MaskStrategy::InfoGuided`]: the
|
||||
/// (population) variance of the token's amplitude values plus the variance of
|
||||
/// its phase values. Near-constant tokens (e.g. a quiet sub-carrier slice) score
|
||||
/// near zero, so they're less likely to be masked; structured tokens score
|
||||
/// higher. `amp`/`phase` are the flattened `[N, sub]` grids; `i` is the token row.
|
||||
fn token_information(amp: &Array2<f32>, phase: &Array2<f32>, i: usize) -> f64 {
|
||||
let var = |row: ndarray::ArrayView1<f32>| -> f64 {
|
||||
let m = row.len();
|
||||
if m == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
let mean = row.iter().map(|&x| x as f64).sum::<f64>() / m as f64;
|
||||
row.iter().map(|&x| { let d = x as f64 - mean; d * d }).sum::<f64>() / m as f64
|
||||
};
|
||||
var(amp.row(i)) + var(phase.row(i))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Masking strategy
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
@ -300,8 +317,11 @@ pub struct MaskedCsi {
|
|||
/// * `amplitude`, `phase` — `[T, tx, rx, sub]`, identical shapes.
|
||||
/// * `mask_ratio` — fraction hidden; clamped so at least one token is visible
|
||||
/// and at least one is masked.
|
||||
/// * `strategy` — [`MaskStrategy::InfoGuided`] currently falls back to
|
||||
/// [`MaskStrategy::Random`] with a warning (lands in iteration 2).
|
||||
/// * `strategy` — [`MaskStrategy::Random`] (uniform) or [`MaskStrategy::InfoGuided`]
|
||||
/// (CIG-MAE-style: preferentially mask high-information tokens, where a token's
|
||||
/// "information" is the variance of its amplitude + phase values — flat tokens
|
||||
/// are trivially in-painted, so masking them teaches less). Both are
|
||||
/// deterministic in `seed`.
|
||||
/// * `seed` — makes the choice reproducible. A good per-sample seed is
|
||||
/// `base_seed ^ (sample_index as u64).wrapping_mul(0x9E3779B97F4A7C15)`.
|
||||
///
|
||||
|
|
@ -332,9 +352,6 @@ pub fn mask_csi_window(
|
|||
reason: format!("must be in (0, 1), got {mask_ratio}"),
|
||||
});
|
||||
}
|
||||
if matches!(strategy, MaskStrategy::InfoGuided) {
|
||||
tracing::warn!("MaskStrategy::InfoGuided not yet implemented — falling back to Random");
|
||||
}
|
||||
|
||||
let layout = TokenLayout::from_window(amplitude);
|
||||
let n = layout.n_tokens;
|
||||
|
|
@ -354,13 +371,42 @@ pub fn mask_csi_window(
|
|||
n_mask = n - 1;
|
||||
}
|
||||
|
||||
// Random permutation of [0, n); first n_mask = masked, rest = visible.
|
||||
let amp_flat = TokenLayout::flatten(amplitude);
|
||||
let phase_flat = TokenLayout::flatten(phase);
|
||||
|
||||
// Pick the n_mask masked token indices according to the strategy.
|
||||
let mut rng = SplitMix64::new(seed);
|
||||
let mut perm: Vec<usize> = (0..n).collect();
|
||||
shuffle(&mut perm, &mut rng);
|
||||
let mut masked_idx: Vec<usize> = perm[..n_mask].to_vec();
|
||||
let mut visible_idx: Vec<usize> = perm[n_mask..].to_vec();
|
||||
let masked_set: Vec<usize> = match strategy {
|
||||
MaskStrategy::Random => {
|
||||
// Uniform: shuffle [0, n) and take the first n_mask.
|
||||
let mut perm: Vec<usize> = (0..n).collect();
|
||||
shuffle(&mut perm, &mut rng);
|
||||
perm[..n_mask].to_vec()
|
||||
}
|
||||
MaskStrategy::InfoGuided => {
|
||||
// Weighted-without-replacement by per-token information (variance of
|
||||
// amplitude+phase). Efraimidis–Spirakis: key_i = u_i^(1/w_i),
|
||||
// pick the n_mask largest keys. Deterministic given `seed`.
|
||||
let mut keyed: Vec<(f64, usize)> = (0..n)
|
||||
.map(|i| {
|
||||
let w = token_information(&_flat, &phase_flat, i) + 1e-6;
|
||||
// u in (0, 1]: avoid 0 so ln() is finite. key = u^(1/w);
|
||||
// rank by ln(key) = ln(u)/w (monotone, avoids tiny powers).
|
||||
let u = ((rng.next_u64() >> 11) as f64 + 1.0) / (((1u64 << 53) as f64) + 1.0);
|
||||
let key = u.ln() / w; // larger (closer to 0) ⇒ more likely chosen
|
||||
(key, i)
|
||||
})
|
||||
.collect();
|
||||
// Largest key = least-negative ln(u)/w ⇒ sort descending by key.
|
||||
keyed.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
keyed[..n_mask].iter().map(|&(_, i)| i).collect()
|
||||
}
|
||||
};
|
||||
|
||||
let mut masked_idx = masked_set;
|
||||
masked_idx.sort_unstable();
|
||||
let masked_lookup: std::collections::HashSet<usize> = masked_idx.iter().copied().collect();
|
||||
let mut visible_idx: Vec<usize> = (0..n).filter(|i| !masked_lookup.contains(i)).collect();
|
||||
visible_idx.sort_unstable();
|
||||
|
||||
let mut mask = vec![false; n];
|
||||
|
|
@ -368,8 +414,6 @@ pub fn mask_csi_window(
|
|||
mask[i] = true;
|
||||
}
|
||||
|
||||
let amp_flat = TokenLayout::flatten(amplitude);
|
||||
let phase_flat = TokenLayout::flatten(phase);
|
||||
let gather = |src: &Array2<f32>, idx: &[usize]| -> Array2<f32> {
|
||||
let mut out = Array2::<f32>::zeros((idx.len(), layout.token_dim));
|
||||
for (row, &i) in idx.iter().enumerate() {
|
||||
|
|
@ -568,6 +612,67 @@ mod tests {
|
|||
assert_ne!(m1.masked_idx, m3.masked_idx); // different seed → different partition
|
||||
}
|
||||
|
||||
/// Build a window where the first half of the tokens are (near-)constant
|
||||
/// (low information) and the second half are noisy (high information).
|
||||
/// Returns `(amp, phase, n_tokens, n_low)`.
|
||||
fn split_info_window() -> (ndarray::Array4<f32>, ndarray::Array4<f32>, usize, usize) {
|
||||
// 20 frames, 1x1, 8 sub → 20 tokens; first 10 constant, last 10 noisy.
|
||||
let frames = 20;
|
||||
let sub = 8;
|
||||
let mut rng = SplitMix64::new(999);
|
||||
let amp = ndarray::Array4::<f32>::from_shape_fn((frames, 1, 1, sub), |(f, _, _, _)| {
|
||||
if f < 10 { 1.0 } else { (rng.next_u64() as f32) / (u64::MAX as f32) }
|
||||
});
|
||||
let phase = ndarray::Array4::<f32>::from_shape_fn((frames, 1, 1, sub), |(f, _, _, _)| {
|
||||
if f < 10 { 0.0 } else { (rng.next_u64() as f32) / (u64::MAX as f32) }
|
||||
});
|
||||
(amp, phase, frames, 10)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn info_guided_masking_prefers_high_information_tokens() {
|
||||
let (a, p, _n, n_low) = split_info_window();
|
||||
// Mask 50% (10 of 20). With info-guided selection the noisy tokens
|
||||
// (indices 10..20) should dominate the masked set far beyond chance.
|
||||
let mut high_count_total = 0usize;
|
||||
let trials = 8;
|
||||
for seed in 0..trials {
|
||||
let m = mask_csi_window(a.view(), p.view(), 0.5, MaskStrategy::InfoGuided, seed).unwrap();
|
||||
assert_eq!(m.masked_idx.len(), 10);
|
||||
let high = m.masked_idx.iter().filter(|&&i| i >= n_low).count();
|
||||
high_count_total += high;
|
||||
}
|
||||
// Random would average ~5/10 high per trial; info-guided should be ≥ ~8/10.
|
||||
let avg_high = high_count_total as f64 / trials as f64;
|
||||
assert!(avg_high >= 7.5, "info-guided avg high-info masked = {avg_high}, expected >= 7.5");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn info_guided_masking_is_deterministic_in_seed() {
|
||||
let (a, p, _n, _) = split_info_window();
|
||||
let m1 = mask_csi_window(a.view(), p.view(), 0.4, MaskStrategy::InfoGuided, 5).unwrap();
|
||||
let m2 = mask_csi_window(a.view(), p.view(), 0.4, MaskStrategy::InfoGuided, 5).unwrap();
|
||||
let m3 = mask_csi_window(a.view(), p.view(), 0.4, MaskStrategy::InfoGuided, 6).unwrap();
|
||||
assert_eq!(m1.masked_idx, m2.masked_idx);
|
||||
assert_eq!(m1.target_amp, m2.target_amp);
|
||||
assert_ne!(m1.masked_idx, m3.masked_idx);
|
||||
// still a valid exhaustive/disjoint partition
|
||||
let n = m1.layout.n_tokens;
|
||||
assert_eq!(m1.visible_idx.len() + m1.masked_idx.len(), n);
|
||||
let mut all: Vec<usize> = m1.visible_idx.iter().chain(m1.masked_idx.iter()).copied().collect();
|
||||
all.sort_unstable();
|
||||
assert_eq!(all, (0..n).collect::<Vec<_>>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_information_is_zero_for_constant_and_positive_for_varied() {
|
||||
let (a, p, _n, _) = split_info_window();
|
||||
let amp_flat = TokenLayout::flatten(a.view());
|
||||
let ph_flat = TokenLayout::flatten(p.view());
|
||||
assert!(token_information(&_flat, &ph_flat, 0) < 1e-9); // constant token
|
||||
assert!(token_information(&_flat, &ph_flat, 15) > 1e-6); // noisy token
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn masking_clamps_extreme_ratios() {
|
||||
let (a, p) = synth_window(4, 1, 1, 8, 9);
|
||||
|
|
|
|||
Loading…
Reference in New Issue