fix(train): leak-free subject-disjoint split + synthetic-val disclosure (ADR-155 §Tier-1.2)

MM-Fi windows are stride-1 (~99% overlap), so an index-level split leaks; and
bin/train.rs validated real training against a SYNTHETIC val set, making any
printed PCK meaningless on two counts.

- MmFiDataset::subject_disjoint_split partitions whole subjects -> the two views
  share no subject and no window (leak-free by construction, deterministic per
  seed). assert_split_leak_free verifies subject- AND window-disjointness and is
  called inside the split so a leaky split is never handed out.
- bin/train.rs now prefers the real split; the synthetic path is a labelled
  run_smoke_test ("[SMOKE-TEST] DO NOT REPORT") reachable only as a fallback.
- New DatasetError::InvalidSplit.

Tests prove disjointness, determinism, single-subject/bad-fraction rejection,
and that the validator catches an injected subject leak.

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv 2026-06-11 19:56:57 -04:00
parent 50b657459f
commit 2a2a2c5b06
3 changed files with 415 additions and 18 deletions

View File

@ -25,7 +25,7 @@
use clap::Parser;
use std::path::PathBuf;
use tracing::{error, info};
use tracing::{error, info, warn};
use wifi_densepose_train::{
config::TrainingConfig,
@ -170,8 +170,13 @@ fn main() {
train_ds.len(),
val_ds.len()
);
warn!(
"[SMOKE-TEST ONLY] --dry-run trains and validates on SYNTHETIC data. \
Any val_pck/val_oks is a pipeline smoke-test and MUST NOT be reported \
as accuracy (ADR-155 §Tier-1.2)."
);
run_training(config, &train_ds, &val_ds);
run_smoke_test(config, &train_ds, &val_ds);
} else {
info!("Loading MM-Fi dataset from {}", data_dir.display());
@ -199,22 +204,47 @@ fn main() {
info!("Dataset: {} samples", train_ds.len());
// Use a small synthetic validation set when running without a split.
let val_syn_cfg = SyntheticConfig {
num_subcarriers: config.num_subcarriers,
num_antennas_tx: config.num_antennas_tx,
num_antennas_rx: config.num_antennas_rx,
window_frames: config.window_frames,
num_keypoints: config.num_keypoints,
signal_frequency_hz: 2.4e9,
};
let val_ds = SyntheticCsiDataset::new(config.batch_size.max(1), val_syn_cfg);
info!(
"Using synthetic validation set ({} samples) for pipeline verification",
val_ds.len()
);
run_training(config, &train_ds, &val_ds);
// ADR-155 §Tier-1.2: prefer a REAL, leak-free, subject-disjoint split so
// any reported PCK/OKS is honest. MM-Fi windows are stride-1 (≈99%
// overlap), so an index-level split would leak; a synthetic val set
// makes the metric meaningless. Split at the subject level when the
// dataset has ≥2 subjects.
match train_ds.subject_disjoint_split(0.2, config.seed) {
Ok((train_view, val_view)) => {
info!(
"Leak-free subject-disjoint split: {} train windows (subjects {:?}) / \
{} val windows (subjects {:?})",
train_view.len(),
train_view.subjects(),
val_view.len(),
val_view.subjects(),
);
run_training(config, &train_view, &val_view);
}
Err(e) => {
// Cannot form a real split (e.g. a single subject). Fall back to
// a SYNTHETIC val set, but make it UNMISTAKABLE that this is a
// smoke-test only — its metric is NOT a reportable number.
warn!("Cannot build a leak-free subject-disjoint split: {e}");
warn!(
"[SMOKE-TEST ONLY] Falling back to a SYNTHETIC validation set. \
ANY val_pck/val_oks printed below is a PIPELINE SMOKE-TEST on \
synthetic data and MUST NOT be reported or claimed as accuracy \
(ADR-155 §Tier-1.2). Provide a multi-subject dataset for a real \
measurement."
);
let val_syn_cfg = SyntheticConfig {
num_subcarriers: config.num_subcarriers,
num_antennas_tx: config.num_antennas_tx,
num_antennas_rx: config.num_antennas_rx,
window_frames: config.window_frames,
num_keypoints: config.num_keypoints,
signal_frequency_hz: 2.4e9,
};
let val_ds = SyntheticCsiDataset::new(config.batch_size.max(1), val_syn_cfg);
run_smoke_test(config, &train_ds, &val_ds);
}
}
}
}
@ -265,6 +295,55 @@ fn run_training(_config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn
info!("Config and dataset infrastructure: OK");
}
// ---------------------------------------------------------------------------
// run_smoke_test — synthetic-validation path (NOT a reportable metric)
// ---------------------------------------------------------------------------
//
// ADR-155 §Tier-1.2: identical to `run_training` but every metric it surfaces
// is prefixed/labelled as a SMOKE-TEST so a synthetic-val PCK can never be
// mistaken for a measured accuracy number.
#[cfg(feature = "tch-backend")]
fn run_smoke_test(config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn CsiDataset) {
use wifi_densepose_train::trainer::Trainer;
warn!(
"[SMOKE-TEST] Starting SYNTHETIC-validation run: {} train / {} val samples. \
Reported PCK/OKS below are NOT measurements.",
train_ds.len(),
val_ds.len()
);
let mut trainer = Trainer::new(config);
match trainer.train(train_ds, val_ds) {
Ok(result) => {
warn!("[SMOKE-TEST] Pipeline ran end-to-end (no crash). Metrics are synthetic:");
warn!(
"[SMOKE-TEST] (DO NOT REPORT) best_pck@0.2={:.4} @ epoch {} — synthetic val",
result.best_pck, result.best_epoch
);
info!(
"[SMOKE-TEST] Final train loss: {:.6}",
result.final_train_loss
);
}
Err(e) => {
error!("[SMOKE-TEST] Pipeline failed: {e}");
std::process::exit(1);
}
}
}
#[cfg(not(feature = "tch-backend"))]
fn run_smoke_test(_config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn CsiDataset) {
warn!(
"[SMOKE-TEST] Pipeline verification only: {} train / {} synthetic-val samples loaded. \
No metric is produced; build with --features tch-backend to run the pipeline.",
train_ds.len(),
val_ds.len()
);
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

View File

@ -519,6 +519,233 @@ impl CsiDataset for MmFiDataset {
}
}
// ---------------------------------------------------------------------------
// Leak-free train/test split (ADR-155 §Tier-1.2)
// ---------------------------------------------------------------------------
//
// Why this exists: MM-Fi windows are extracted with stride 1
// (`MmFiEntry::num_windows` = `num_frames window_frames + 1`), so adjacent
// windows overlap by `window_frames 1` frames. A naive index-level random
// split therefore puts near-identical windows on both sides of the boundary —
// up to ~99% information leakage — and any PCK it reports is meaningless. The
// leak-free discipline (mirrored from `occupancy_bench::EvalSplit`) is to split
// at the **subject** level: a subject's clips (and thus all of its windows) go
// entirely to train or entirely to test. Disjoint subjects ⇒ no shared window,
// and no temporally-adjacent window can straddle the boundary.
/// A borrowed, read-only view over a contiguous-by-subject subset of a parent
/// [`MmFiDataset`]'s windows. Implements [`CsiDataset`] so it can be passed
/// straight to the trainer. Produced only by
/// [`MmFiDataset::subject_disjoint_split`], which guarantees the two returned
/// views are subject- and window-disjoint.
pub struct MmFiSplitView<'a> {
parent: &'a MmFiDataset,
/// Global parent window indices owned by this view (sorted, unique).
global_indices: Vec<usize>,
/// Subject ids present in this view (for leak validation / reporting).
subjects: std::collections::BTreeSet<u32>,
name: &'static str,
}
impl<'a> MmFiSplitView<'a> {
/// Subject ids covered by this view.
pub fn subjects(&self) -> &std::collections::BTreeSet<u32> {
&self.subjects
}
/// Global parent window indices owned by this view.
pub fn global_indices(&self) -> &[usize] {
&self.global_indices
}
}
impl<'a> CsiDataset for MmFiSplitView<'a> {
fn len(&self) -> usize {
self.global_indices.len()
}
fn get(&self, idx: usize) -> Result<CsiSample, DatasetError> {
let g = *self
.global_indices
.get(idx)
.ok_or(DatasetError::IndexOutOfBounds {
idx,
len: self.global_indices.len(),
})?;
self.parent.get(g)
}
fn name(&self) -> &str {
self.name
}
}
impl MmFiDataset {
/// All subject ids present in the scanned dataset (sorted, unique).
pub fn subjects(&self) -> Vec<u32> {
let set: std::collections::BTreeSet<u32> =
self.entries.iter().map(|e| e.subject_id).collect();
set.into_iter().collect()
}
/// Split into **subject-disjoint** train / test views (ADR-155 §Tier-1.2).
///
/// Subjects are assigned wholesale to one side: roughly
/// `test_subject_fraction` of the distinct subjects (at least one, and at
/// least one left for train) go to the test view, the rest to train. Because
/// every window of a subject travels with that subject, the two views share
/// **no subject and no window** — the split is leak-free by construction.
///
/// Assignment is deterministic for a given `seed` (seeded Fisher-Yates over
/// the sorted subject list), so runs are reproducible.
///
/// # Errors
/// [`DatasetError::InvalidSplit`] when there are fewer than 2 subjects, when
/// `test_subject_fraction` is not in `(0, 1)`, or when either side would be
/// empty.
pub fn subject_disjoint_split(
&self,
test_subject_fraction: f64,
seed: u64,
) -> Result<(MmFiSplitView<'_>, MmFiSplitView<'_>), DatasetError> {
if !(test_subject_fraction > 0.0 && test_subject_fraction < 1.0) {
return Err(DatasetError::InvalidSplit(format!(
"test_subject_fraction must be in (0,1), got {test_subject_fraction}"
)));
}
let mut subjects = self.subjects();
if subjects.len() < 2 {
return Err(DatasetError::InvalidSplit(format!(
"need >= 2 distinct subjects for a subject-disjoint split, got {}",
subjects.len()
)));
}
// Deterministic shuffle of the sorted subject list.
xorshift_shuffle_u32(&mut subjects, seed);
let n_test = ((subjects.len() as f64 * test_subject_fraction).round() as usize)
.clamp(1, subjects.len() - 1);
let test_subjects: std::collections::BTreeSet<u32> =
subjects[..n_test].iter().copied().collect();
let train_subjects: std::collections::BTreeSet<u32> =
subjects[n_test..].iter().copied().collect();
// Partition global window indices by the owning entry's subject.
let mut train_idx = Vec::new();
let mut test_idx = Vec::new();
for (entry_i, entry) in self.entries.iter().enumerate() {
let start = self.cumulative[entry_i];
let end = self.cumulative[entry_i + 1];
if test_subjects.contains(&entry.subject_id) {
test_idx.extend(start..end);
} else {
train_idx.extend(start..end);
}
}
if train_idx.is_empty() || test_idx.is_empty() {
return Err(DatasetError::InvalidSplit(
"split produced an empty partition (a subject set has no windows)".into(),
));
}
let train = MmFiSplitView {
parent: self,
global_indices: train_idx,
subjects: train_subjects,
name: "MmFiDataset[train]",
};
let test = MmFiSplitView {
parent: self,
global_indices: test_idx,
subjects: test_subjects,
name: "MmFiDataset[test]",
};
// Self-check: never hand out a leaky split.
assert_split_leak_free(&train, &test)?;
Ok((train, test))
}
}
/// Verify a train/test split is leak-free: subject-disjoint **and**
/// window-disjoint, with both sides non-empty (ADR-155 §Tier-1.2).
///
/// Returns [`DatasetError::InvalidSplit`] describing the first violation found.
pub fn assert_split_leak_free(
train: &MmFiSplitView<'_>,
test: &MmFiSplitView<'_>,
) -> Result<(), DatasetError> {
if train.global_indices.is_empty() || test.global_indices.is_empty() {
return Err(DatasetError::InvalidSplit("a partition is empty".into()));
}
// Subject disjointness.
if let Some(shared) = train.subjects.intersection(&test.subjects).next() {
return Err(DatasetError::InvalidSplit(format!(
"subject {shared} appears in both train and test (subject leakage)"
)));
}
// Window disjointness (guards against any index bug in the partitioner).
let train_set: std::collections::BTreeSet<usize> =
train.global_indices.iter().copied().collect();
if let Some(shared) = test.global_indices.iter().find(|i| train_set.contains(i)) {
return Err(DatasetError::InvalidSplit(format!(
"window {shared} appears in both train and test (window leakage)"
)));
}
Ok(())
}
#[cfg(test)]
impl MmFiDataset {
/// Build a metadata-only `MmFiDataset` for split tests: fabricated entries
/// with given `(subject_id, action_id, num_frames)` and a window size. No
/// files are touched — only the split / leak-check logic (which reads
/// `subject_id` + window counts, never `get()`) is exercised.
fn from_entries_for_test(clips: &[(u32, u32, usize)], window_frames: usize) -> Self {
let entries: Vec<MmFiEntry> = clips
.iter()
.map(|&(subject_id, action_id, num_frames)| MmFiEntry {
subject_id,
action_id,
amp_path: PathBuf::from("/nonexistent/wifi_csi.npy"),
phase_path: PathBuf::from("/nonexistent/wifi_csi_phase.npy"),
kp_path: PathBuf::from("/nonexistent/gt_keypoints.npy"),
num_frames,
window_frames,
})
.collect();
let mut cumulative = vec![0usize; entries.len() + 1];
for (i, e) in entries.iter().enumerate() {
cumulative[i + 1] = cumulative[i] + e.num_windows();
}
MmFiDataset {
entries,
cumulative,
window_frames,
target_subcarriers: 56,
num_keypoints: 17,
root: PathBuf::from("/nonexistent"),
}
}
}
/// Deterministic Fisher-Yates shuffle of a `u32` slice (seeded Xorshift64).
fn xorshift_shuffle_u32(items: &mut [u32], seed: u64) {
let n = items.len();
if n <= 1 {
return;
}
let mut state = if seed == 0 { 0x853c49e6748fea9b } else { seed };
for i in (1..n).rev() {
state ^= state << 13;
state ^= state >> 7;
state ^= state << 17;
let j = (state % (i as u64 + 1)) as usize;
items.swap(i, j);
}
}
// ---------------------------------------------------------------------------
// CompressedCsiBuffer
// ---------------------------------------------------------------------------
@ -1019,6 +1246,91 @@ mod tests {
assert_abs_diff_eq!(s0a.keypoints[[5, 0]], s0b.keypoints[[5, 0]], epsilon = 1e-7);
}
// ----- Leak-free subject-disjoint split (ADR-155 §Tier-1.2) -----------
fn split_fixture() -> MmFiDataset {
// 6 subjects × 2 clips each, 50 frames per clip, window 10 ⇒ 41
// overlapping windows per clip. A leaky index-split would put adjacent
// (near-identical) windows on both sides; the subject split cannot.
let mut clips = Vec::new();
for s in 1..=6u32 {
for a in 1..=2u32 {
clips.push((s, a, 50usize));
}
}
MmFiDataset::from_entries_for_test(&clips, 10)
}
#[test]
fn subject_split_is_subject_and_window_disjoint() {
let ds = split_fixture();
let (train, test) = ds.subject_disjoint_split(0.34, 42).unwrap();
// No subject is shared.
assert!(train.subjects().is_disjoint(test.subjects()));
// assert_split_leak_free agrees (subject + window disjoint, non-empty).
assert_split_leak_free(&train, &test).expect("split must be leak-free");
// No global window index is shared.
let train_set: std::collections::BTreeSet<usize> =
train.global_indices().iter().copied().collect();
for g in test.global_indices() {
assert!(!train_set.contains(g), "window {g} leaked across the split");
}
// Every window is accounted for exactly once (partition, not sample).
assert_eq!(train.len() + test.len(), ds.len());
assert!(train.len() > 0 && test.len() > 0);
}
#[test]
fn subject_split_is_deterministic_for_seed() {
let ds = split_fixture();
let (tr1, te1) = ds.subject_disjoint_split(0.34, 7).unwrap();
let (tr2, te2) = ds.subject_disjoint_split(0.34, 7).unwrap();
assert_eq!(tr1.subjects(), tr2.subjects());
assert_eq!(te1.subjects(), te2.subjects());
}
#[test]
fn subject_split_rejects_single_subject() {
// Only one subject ⇒ a subject-disjoint split is impossible.
let ds = MmFiDataset::from_entries_for_test(&[(1, 1, 50), (1, 2, 50)], 10);
assert!(matches!(
ds.subject_disjoint_split(0.3, 1),
Err(DatasetError::InvalidSplit(_))
));
}
#[test]
fn subject_split_rejects_bad_fraction() {
let ds = split_fixture();
assert!(ds.subject_disjoint_split(0.0, 1).is_err());
assert!(ds.subject_disjoint_split(1.0, 1).is_err());
}
#[test]
fn assert_leak_free_detects_injected_subject_leak() {
// Build two views that deliberately share subject 3 and prove the
// validator catches it (a guard against future partitioner bugs).
let ds = split_fixture();
let (train, _test) = ds.subject_disjoint_split(0.34, 42).unwrap();
// Fabricate a "test" view overlapping train's subjects.
let mut shared_subjects = std::collections::BTreeSet::new();
let leaked = *train.subjects().iter().next().unwrap();
shared_subjects.insert(leaked);
let bad_test = MmFiSplitView {
parent: &ds,
global_indices: train.global_indices().to_vec(),
subjects: shared_subjects,
name: "bad",
};
assert!(matches!(
assert_split_leak_free(&train, &bad_test),
Err(DatasetError::InvalidSplit(_))
));
}
#[test]
fn synthetic_different_indices_differ() {
let cfg = SyntheticConfig::default();

View File

@ -280,6 +280,12 @@ pub enum DatasetError {
/// An I/O error that carries no path context.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// A train/test split is invalid — it leaks information across the boundary
/// (a subject appears in both partitions, or a window is shared) or is
/// degenerate (an empty partition). ADR-155 §Tier-1.2.
#[error("Invalid split: {0}")]
InvalidSplit(String),
}
impl DatasetError {