fix(train): leak-free subject-disjoint split + synthetic-val disclosure (ADR-155 §Tier-1.2)
MM-Fi windows are stride-1 (~99% overlap), so an index-level split leaks; and
bin/train.rs validated real training against a SYNTHETIC val set, making any
printed PCK meaningless on two counts.
- MmFiDataset::subject_disjoint_split partitions whole subjects -> the two views
share no subject and no window (leak-free by construction, deterministic per
seed). assert_split_leak_free verifies subject- AND window-disjointness and is
called inside the split so a leaky split is never handed out.
- bin/train.rs now prefers the real split; the synthetic path is a labelled
run_smoke_test ("[SMOKE-TEST] DO NOT REPORT") reachable only as a fallback.
- New DatasetError::InvalidSplit.
Tests prove disjointness, determinism, single-subject/bad-fraction rejection,
and that the validator catches an injected subject leak.
Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
50b657459f
commit
2a2a2c5b06
|
|
@ -25,7 +25,7 @@
|
|||
|
||||
use clap::Parser;
|
||||
use std::path::PathBuf;
|
||||
use tracing::{error, info};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use wifi_densepose_train::{
|
||||
config::TrainingConfig,
|
||||
|
|
@ -170,8 +170,13 @@ fn main() {
|
|||
train_ds.len(),
|
||||
val_ds.len()
|
||||
);
|
||||
warn!(
|
||||
"[SMOKE-TEST ONLY] --dry-run trains and validates on SYNTHETIC data. \
|
||||
Any val_pck/val_oks is a pipeline smoke-test and MUST NOT be reported \
|
||||
as accuracy (ADR-155 §Tier-1.2)."
|
||||
);
|
||||
|
||||
run_training(config, &train_ds, &val_ds);
|
||||
run_smoke_test(config, &train_ds, &val_ds);
|
||||
} else {
|
||||
info!("Loading MM-Fi dataset from {}", data_dir.display());
|
||||
|
||||
|
|
@ -199,22 +204,47 @@ fn main() {
|
|||
|
||||
info!("Dataset: {} samples", train_ds.len());
|
||||
|
||||
// Use a small synthetic validation set when running without a split.
|
||||
let val_syn_cfg = SyntheticConfig {
|
||||
num_subcarriers: config.num_subcarriers,
|
||||
num_antennas_tx: config.num_antennas_tx,
|
||||
num_antennas_rx: config.num_antennas_rx,
|
||||
window_frames: config.window_frames,
|
||||
num_keypoints: config.num_keypoints,
|
||||
signal_frequency_hz: 2.4e9,
|
||||
};
|
||||
let val_ds = SyntheticCsiDataset::new(config.batch_size.max(1), val_syn_cfg);
|
||||
info!(
|
||||
"Using synthetic validation set ({} samples) for pipeline verification",
|
||||
val_ds.len()
|
||||
);
|
||||
|
||||
run_training(config, &train_ds, &val_ds);
|
||||
// ADR-155 §Tier-1.2: prefer a REAL, leak-free, subject-disjoint split so
|
||||
// any reported PCK/OKS is honest. MM-Fi windows are stride-1 (≈99%
|
||||
// overlap), so an index-level split would leak; a synthetic val set
|
||||
// makes the metric meaningless. Split at the subject level when the
|
||||
// dataset has ≥2 subjects.
|
||||
match train_ds.subject_disjoint_split(0.2, config.seed) {
|
||||
Ok((train_view, val_view)) => {
|
||||
info!(
|
||||
"Leak-free subject-disjoint split: {} train windows (subjects {:?}) / \
|
||||
{} val windows (subjects {:?})",
|
||||
train_view.len(),
|
||||
train_view.subjects(),
|
||||
val_view.len(),
|
||||
val_view.subjects(),
|
||||
);
|
||||
run_training(config, &train_view, &val_view);
|
||||
}
|
||||
Err(e) => {
|
||||
// Cannot form a real split (e.g. a single subject). Fall back to
|
||||
// a SYNTHETIC val set, but make it UNMISTAKABLE that this is a
|
||||
// smoke-test only — its metric is NOT a reportable number.
|
||||
warn!("Cannot build a leak-free subject-disjoint split: {e}");
|
||||
warn!(
|
||||
"[SMOKE-TEST ONLY] Falling back to a SYNTHETIC validation set. \
|
||||
ANY val_pck/val_oks printed below is a PIPELINE SMOKE-TEST on \
|
||||
synthetic data and MUST NOT be reported or claimed as accuracy \
|
||||
(ADR-155 §Tier-1.2). Provide a multi-subject dataset for a real \
|
||||
measurement."
|
||||
);
|
||||
let val_syn_cfg = SyntheticConfig {
|
||||
num_subcarriers: config.num_subcarriers,
|
||||
num_antennas_tx: config.num_antennas_tx,
|
||||
num_antennas_rx: config.num_antennas_rx,
|
||||
window_frames: config.window_frames,
|
||||
num_keypoints: config.num_keypoints,
|
||||
signal_frequency_hz: 2.4e9,
|
||||
};
|
||||
let val_ds = SyntheticCsiDataset::new(config.batch_size.max(1), val_syn_cfg);
|
||||
run_smoke_test(config, &train_ds, &val_ds);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -265,6 +295,55 @@ fn run_training(_config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn
|
|||
info!("Config and dataset infrastructure: OK");
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// run_smoke_test — synthetic-validation path (NOT a reportable metric)
|
||||
// ---------------------------------------------------------------------------
|
||||
//
|
||||
// ADR-155 §Tier-1.2: identical to `run_training` but every metric it surfaces
|
||||
// is prefixed/labelled as a SMOKE-TEST so a synthetic-val PCK can never be
|
||||
// mistaken for a measured accuracy number.
|
||||
|
||||
/// Runs the trainer against a synthetic validation set and logs every metric
/// under a `[SMOKE-TEST]` label so it cannot be mistaken for a measurement.
///
/// Exits the process with status 1 when the trainer returns an error.
#[cfg(feature = "tch-backend")]
fn run_smoke_test(config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn CsiDataset) {
    use wifi_densepose_train::trainer::Trainer;

    let (n_train, n_val) = (train_ds.len(), val_ds.len());
    warn!(
        "[SMOKE-TEST] Starting SYNTHETIC-validation run: {} train / {} val samples. \
         Reported PCK/OKS below are NOT measurements.",
        n_train, n_val
    );

    let mut trainer = Trainer::new(config);

    // Handle the failure path first so the success reporting reads linearly.
    let result = match trainer.train(train_ds, val_ds) {
        Ok(result) => result,
        Err(e) => {
            error!("[SMOKE-TEST] Pipeline failed: {e}");
            std::process::exit(1)
        }
    };

    warn!("[SMOKE-TEST] Pipeline ran end-to-end (no crash). Metrics are synthetic:");
    warn!(
        "[SMOKE-TEST] (DO NOT REPORT) best_pck@0.2={:.4} @ epoch {} — synthetic val",
        result.best_pck, result.best_epoch
    );
    info!(
        "[SMOKE-TEST] Final train loss: {:.6}",
        result.final_train_loss
    );
}
|
||||
|
||||
#[cfg(not(feature = "tch-backend"))]
|
||||
fn run_smoke_test(_config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn CsiDataset) {
|
||||
warn!(
|
||||
"[SMOKE-TEST] Pipeline verification only: {} train / {} synthetic-val samples loaded. \
|
||||
No metric is produced; build with --features tch-backend to run the pipeline.",
|
||||
train_ds.len(),
|
||||
val_ds.len()
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -519,6 +519,233 @@ impl CsiDataset for MmFiDataset {
|
|||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Leak-free train/test split (ADR-155 §Tier-1.2)
|
||||
// ---------------------------------------------------------------------------
|
||||
//
|
||||
// Why this exists: MM-Fi windows are extracted with stride 1
|
||||
// (`MmFiEntry::num_windows` = `num_frames − window_frames + 1`), so adjacent
|
||||
// windows overlap by `window_frames − 1` frames. A naive index-level random
|
||||
// split therefore puts near-identical windows on both sides of the boundary —
|
||||
// up to ~99% information leakage — and any PCK it reports is meaningless. The
|
||||
// leak-free discipline (mirrored from `occupancy_bench::EvalSplit`) is to split
|
||||
// at the **subject** level: a subject's clips (and thus all of its windows) go
|
||||
// entirely to train or entirely to test. Disjoint subjects ⇒ no shared window,
|
||||
// and no temporally-adjacent window can straddle the boundary.
|
||||
|
||||
/// A borrowed, read-only view over a contiguous-by-subject subset of a parent
|
||||
/// [`MmFiDataset`]'s windows. Implements [`CsiDataset`] so it can be passed
|
||||
/// straight to the trainer. Produced only by
|
||||
/// [`MmFiDataset::subject_disjoint_split`], which guarantees the two returned
|
||||
/// views are subject- and window-disjoint.
|
||||
pub struct MmFiSplitView<'a> {
|
||||
parent: &'a MmFiDataset,
|
||||
/// Global parent window indices owned by this view (sorted, unique).
|
||||
global_indices: Vec<usize>,
|
||||
/// Subject ids present in this view (for leak validation / reporting).
|
||||
subjects: std::collections::BTreeSet<u32>,
|
||||
name: &'static str,
|
||||
}
|
||||
|
||||
impl<'a> MmFiSplitView<'a> {
|
||||
/// Subject ids covered by this view.
|
||||
pub fn subjects(&self) -> &std::collections::BTreeSet<u32> {
|
||||
&self.subjects
|
||||
}
|
||||
|
||||
/// Global parent window indices owned by this view.
|
||||
pub fn global_indices(&self) -> &[usize] {
|
||||
&self.global_indices
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> CsiDataset for MmFiSplitView<'a> {
|
||||
fn len(&self) -> usize {
|
||||
self.global_indices.len()
|
||||
}
|
||||
|
||||
fn get(&self, idx: usize) -> Result<CsiSample, DatasetError> {
|
||||
let g = *self
|
||||
.global_indices
|
||||
.get(idx)
|
||||
.ok_or(DatasetError::IndexOutOfBounds {
|
||||
idx,
|
||||
len: self.global_indices.len(),
|
||||
})?;
|
||||
self.parent.get(g)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
self.name
|
||||
}
|
||||
}
|
||||
|
||||
impl MmFiDataset {
|
||||
/// All subject ids present in the scanned dataset (sorted, unique).
|
||||
pub fn subjects(&self) -> Vec<u32> {
|
||||
let set: std::collections::BTreeSet<u32> =
|
||||
self.entries.iter().map(|e| e.subject_id).collect();
|
||||
set.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Split into **subject-disjoint** train / test views (ADR-155 §Tier-1.2).
|
||||
///
|
||||
/// Subjects are assigned wholesale to one side: roughly
|
||||
/// `test_subject_fraction` of the distinct subjects (at least one, and at
|
||||
/// least one left for train) go to the test view, the rest to train. Because
|
||||
/// every window of a subject travels with that subject, the two views share
|
||||
/// **no subject and no window** — the split is leak-free by construction.
|
||||
///
|
||||
/// Assignment is deterministic for a given `seed` (seeded Fisher-Yates over
|
||||
/// the sorted subject list), so runs are reproducible.
|
||||
///
|
||||
/// # Errors
|
||||
/// [`DatasetError::InvalidSplit`] when there are fewer than 2 subjects, when
|
||||
/// `test_subject_fraction` is not in `(0, 1)`, or when either side would be
|
||||
/// empty.
|
||||
pub fn subject_disjoint_split(
|
||||
&self,
|
||||
test_subject_fraction: f64,
|
||||
seed: u64,
|
||||
) -> Result<(MmFiSplitView<'_>, MmFiSplitView<'_>), DatasetError> {
|
||||
if !(test_subject_fraction > 0.0 && test_subject_fraction < 1.0) {
|
||||
return Err(DatasetError::InvalidSplit(format!(
|
||||
"test_subject_fraction must be in (0,1), got {test_subject_fraction}"
|
||||
)));
|
||||
}
|
||||
let mut subjects = self.subjects();
|
||||
if subjects.len() < 2 {
|
||||
return Err(DatasetError::InvalidSplit(format!(
|
||||
"need >= 2 distinct subjects for a subject-disjoint split, got {}",
|
||||
subjects.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Deterministic shuffle of the sorted subject list.
|
||||
xorshift_shuffle_u32(&mut subjects, seed);
|
||||
let n_test = ((subjects.len() as f64 * test_subject_fraction).round() as usize)
|
||||
.clamp(1, subjects.len() - 1);
|
||||
let test_subjects: std::collections::BTreeSet<u32> =
|
||||
subjects[..n_test].iter().copied().collect();
|
||||
let train_subjects: std::collections::BTreeSet<u32> =
|
||||
subjects[n_test..].iter().copied().collect();
|
||||
|
||||
// Partition global window indices by the owning entry's subject.
|
||||
let mut train_idx = Vec::new();
|
||||
let mut test_idx = Vec::new();
|
||||
for (entry_i, entry) in self.entries.iter().enumerate() {
|
||||
let start = self.cumulative[entry_i];
|
||||
let end = self.cumulative[entry_i + 1];
|
||||
if test_subjects.contains(&entry.subject_id) {
|
||||
test_idx.extend(start..end);
|
||||
} else {
|
||||
train_idx.extend(start..end);
|
||||
}
|
||||
}
|
||||
|
||||
if train_idx.is_empty() || test_idx.is_empty() {
|
||||
return Err(DatasetError::InvalidSplit(
|
||||
"split produced an empty partition (a subject set has no windows)".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let train = MmFiSplitView {
|
||||
parent: self,
|
||||
global_indices: train_idx,
|
||||
subjects: train_subjects,
|
||||
name: "MmFiDataset[train]",
|
||||
};
|
||||
let test = MmFiSplitView {
|
||||
parent: self,
|
||||
global_indices: test_idx,
|
||||
subjects: test_subjects,
|
||||
name: "MmFiDataset[test]",
|
||||
};
|
||||
|
||||
// Self-check: never hand out a leaky split.
|
||||
assert_split_leak_free(&train, &test)?;
|
||||
Ok((train, test))
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify a train/test split is leak-free: subject-disjoint **and**
|
||||
/// window-disjoint, with both sides non-empty (ADR-155 §Tier-1.2).
|
||||
///
|
||||
/// Returns [`DatasetError::InvalidSplit`] describing the first violation found.
|
||||
pub fn assert_split_leak_free(
|
||||
train: &MmFiSplitView<'_>,
|
||||
test: &MmFiSplitView<'_>,
|
||||
) -> Result<(), DatasetError> {
|
||||
if train.global_indices.is_empty() || test.global_indices.is_empty() {
|
||||
return Err(DatasetError::InvalidSplit("a partition is empty".into()));
|
||||
}
|
||||
// Subject disjointness.
|
||||
if let Some(shared) = train.subjects.intersection(&test.subjects).next() {
|
||||
return Err(DatasetError::InvalidSplit(format!(
|
||||
"subject {shared} appears in both train and test (subject leakage)"
|
||||
)));
|
||||
}
|
||||
// Window disjointness (guards against any index bug in the partitioner).
|
||||
let train_set: std::collections::BTreeSet<usize> =
|
||||
train.global_indices.iter().copied().collect();
|
||||
if let Some(shared) = test.global_indices.iter().find(|i| train_set.contains(i)) {
|
||||
return Err(DatasetError::InvalidSplit(format!(
|
||||
"window {shared} appears in both train and test (window leakage)"
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
impl MmFiDataset {
    /// Test-only constructor that fabricates dataset metadata without I/O.
    ///
    /// Each `(subject_id, action_id, num_frames)` tuple in `clips` becomes one
    /// entry whose file paths point at a non-existent location. Only the entry
    /// list and the cumulative window offsets are meaningful, which is all the
    /// split and leak-check logic reads; calling `get()` on the result is not
    /// supported by this fixture.
    fn from_entries_for_test(clips: &[(u32, u32, usize)], window_frames: usize) -> Self {
        let mut entries: Vec<MmFiEntry> = Vec::with_capacity(clips.len());
        let mut cumulative: Vec<usize> = Vec::with_capacity(clips.len() + 1);
        cumulative.push(0);

        // Running total of windows gives each entry's global index range.
        let mut windows_so_far = 0usize;
        for &(subject_id, action_id, num_frames) in clips {
            let entry = MmFiEntry {
                subject_id,
                action_id,
                amp_path: PathBuf::from("/nonexistent/wifi_csi.npy"),
                phase_path: PathBuf::from("/nonexistent/wifi_csi_phase.npy"),
                kp_path: PathBuf::from("/nonexistent/gt_keypoints.npy"),
                num_frames,
                window_frames,
            };
            windows_so_far += entry.num_windows();
            cumulative.push(windows_so_far);
            entries.push(entry);
        }

        MmFiDataset {
            entries,
            cumulative,
            window_frames,
            target_subcarriers: 56,
            num_keypoints: 17,
            root: PathBuf::from("/nonexistent"),
        }
    }
}
|
||||
|
||||
/// Shuffle `items` in place with a Fisher-Yates pass driven by a seeded
/// xorshift64 generator (shift triple 13 / 7 / 17).
///
/// The same `seed` always produces the same permutation. A `seed` of zero is
/// replaced by a fixed non-zero constant, because an all-zero xorshift state
/// never changes. Slices with fewer than two elements are left untouched.
fn xorshift_shuffle_u32(items: &mut [u32], seed: u64) {
    if items.len() < 2 {
        return;
    }

    let mut rng: u64 = match seed {
        0 => 0x853c49e6748fea9b,
        nonzero => nonzero,
    };

    // Walk from the last slot down to slot 1, swapping each with a
    // pseudo-randomly chosen slot at or below it.
    let mut upper = items.len() - 1;
    while upper >= 1 {
        rng ^= rng << 13;
        rng ^= rng >> 7;
        rng ^= rng << 17;

        let pick = (rng % (upper as u64 + 1)) as usize;
        items.swap(upper, pick);
        upper -= 1;
    }
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CompressedCsiBuffer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
@ -1019,6 +1246,91 @@ mod tests {
|
|||
assert_abs_diff_eq!(s0a.keypoints[[5, 0]], s0b.keypoints[[5, 0]], epsilon = 1e-7);
|
||||
}
|
||||
|
||||
// ----- Leak-free subject-disjoint split (ADR-155 §Tier-1.2) -----------
|
||||
|
||||
fn split_fixture() -> MmFiDataset {
|
||||
// 6 subjects × 2 clips each, 50 frames per clip, window 10 ⇒ 41
|
||||
// overlapping windows per clip. A leaky index-split would put adjacent
|
||||
// (near-identical) windows on both sides; the subject split cannot.
|
||||
let mut clips = Vec::new();
|
||||
for s in 1..=6u32 {
|
||||
for a in 1..=2u32 {
|
||||
clips.push((s, a, 50usize));
|
||||
}
|
||||
}
|
||||
MmFiDataset::from_entries_for_test(&clips, 10)
|
||||
}
|
||||
|
||||
/// The split must share neither subjects nor windows, and must cover every
/// window of the parent exactly once.
#[test]
fn subject_split_is_subject_and_window_disjoint() {
    let ds = split_fixture();
    let (train, test) = ds.subject_disjoint_split(0.34, 42).unwrap();

    // Subject sets do not overlap.
    assert!(train.subjects().is_disjoint(test.subjects()));
    // The public validator reaches the same verdict.
    assert_split_leak_free(&train, &test).expect("split must be leak-free");

    // Independent re-check of window disjointness.
    let train_windows: std::collections::BTreeSet<usize> =
        train.global_indices().iter().copied().collect();
    test.global_indices().iter().for_each(|g| {
        assert!(!train_windows.contains(g), "window {g} leaked across the split");
    });

    // Together the two views account for every parent window.
    assert_eq!(train.len() + test.len(), ds.len());
    assert!(train.len() > 0 && test.len() > 0);
}
|
||||
|
||||
/// Two splits with the same seed must assign the same subjects to each side.
#[test]
fn subject_split_is_deterministic_for_seed() {
    let ds = split_fixture();
    let first = ds.subject_disjoint_split(0.34, 7).unwrap();
    let second = ds.subject_disjoint_split(0.34, 7).unwrap();

    // Tuple order is (train, test).
    assert_eq!(first.0.subjects(), second.0.subjects());
    assert_eq!(first.1.subjects(), second.1.subjects());
}
|
||||
|
||||
/// A dataset containing a single subject cannot be split by subject.
#[test]
fn subject_split_rejects_single_subject() {
    // Two clips, both belonging to subject 1.
    let single_subject_clips = [(1, 1, 50), (1, 2, 50)];
    let ds = MmFiDataset::from_entries_for_test(&single_subject_clips, 10);

    let outcome = ds.subject_disjoint_split(0.3, 1);
    assert!(matches!(outcome, Err(DatasetError::InvalidSplit(_))));
}
|
||||
|
||||
/// Every fraction outside the open interval (0, 1) must be rejected with
/// `DatasetError::InvalidSplit` — including NaN, which fails all ordered
/// comparisons and would slip through a naive `f <= 0.0 || f >= 1.0` guard.
#[test]
fn subject_split_rejects_bad_fraction() {
    let ds = split_fixture();

    let rejected = [
        0.0,               // lower boundary is exclusive
        1.0,               // upper boundary is exclusive
        -0.25,             // negative
        1.5,               // above range
        f64::NAN,          // unordered
        f64::INFINITY,     // non-finite
        f64::NEG_INFINITY, // non-finite
    ];
    for fraction in rejected {
        assert!(
            matches!(
                ds.subject_disjoint_split(fraction, 1),
                Err(DatasetError::InvalidSplit(_))
            ),
            "fraction {fraction} must be rejected as an invalid split"
        );
    }
}
|
||||
|
||||
/// The validator must reject a pair of views that share a subject even when
/// their window indices are disjoint.
///
/// The fabricated view uses the genuine test windows (disjoint from train by
/// construction) and only its subject set is made to overlap, so the
/// window-leak check cannot be what trips. This isolates the subject check.
#[test]
fn assert_leak_free_detects_injected_subject_leak() {
    let ds = split_fixture();
    let (train, test) = ds.subject_disjoint_split(0.34, 42).unwrap();

    // Claim one of train's subjects (the smallest id) for the fake test view.
    let leaked = *train
        .subjects()
        .iter()
        .next()
        .expect("a valid split always has at least one train subject");
    let shared_subjects = std::collections::BTreeSet::from([leaked]);

    let bad_test = MmFiSplitView {
        parent: &ds,
        global_indices: test.global_indices().to_vec(),
        subjects: shared_subjects,
        name: "bad",
    };

    match assert_split_leak_free(&train, &bad_test) {
        Err(DatasetError::InvalidSplit(msg)) => assert!(
            msg.contains("subject leakage"),
            "expected a subject-leak report, got: {msg}"
        ),
        other => panic!("expected InvalidSplit for a shared subject, got {other:?}"),
    }
}
|
||||
|
||||
#[test]
|
||||
fn synthetic_different_indices_differ() {
|
||||
let cfg = SyntheticConfig::default();
|
||||
|
|
|
|||
|
|
@ -280,6 +280,12 @@ pub enum DatasetError {
|
|||
/// An I/O error that carries no path context.
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// A train/test split is invalid — it leaks information across the boundary
|
||||
/// (a subject appears in both partitions, or a window is shared) or is
|
||||
/// degenerate (an empty partition). ADR-155 §Tier-1.2.
|
||||
#[error("Invalid split: {0}")]
|
||||
InvalidSplit(String),
|
||||
}
|
||||
|
||||
impl DatasetError {
|
||||
|
|
|
|||
Loading…
Reference in New Issue