diff --git a/v2/crates/wifi-densepose-train/src/bin/train.rs b/v2/crates/wifi-densepose-train/src/bin/train.rs index 7126d24f..5fe7bab2 100644 --- a/v2/crates/wifi-densepose-train/src/bin/train.rs +++ b/v2/crates/wifi-densepose-train/src/bin/train.rs @@ -25,7 +25,7 @@ use clap::Parser; use std::path::PathBuf; -use tracing::{error, info}; +use tracing::{error, info, warn}; use wifi_densepose_train::{ config::TrainingConfig, @@ -170,8 +170,13 @@ fn main() { train_ds.len(), val_ds.len() ); + warn!( + "[SMOKE-TEST ONLY] --dry-run trains and validates on SYNTHETIC data. \ + Any val_pck/val_oks is a pipeline smoke-test and MUST NOT be reported \ + as accuracy (ADR-155 §Tier-1.2)." + ); - run_training(config, &train_ds, &val_ds); + run_smoke_test(config, &train_ds, &val_ds); } else { info!("Loading MM-Fi dataset from {}", data_dir.display()); @@ -199,22 +204,47 @@ fn main() { info!("Dataset: {} samples", train_ds.len()); - // Use a small synthetic validation set when running without a split. - let val_syn_cfg = SyntheticConfig { - num_subcarriers: config.num_subcarriers, - num_antennas_tx: config.num_antennas_tx, - num_antennas_rx: config.num_antennas_rx, - window_frames: config.window_frames, - num_keypoints: config.num_keypoints, - signal_frequency_hz: 2.4e9, - }; - let val_ds = SyntheticCsiDataset::new(config.batch_size.max(1), val_syn_cfg); - info!( - "Using synthetic validation set ({} samples) for pipeline verification", - val_ds.len() - ); - - run_training(config, &train_ds, &val_ds); + // ADR-155 §Tier-1.2: prefer a REAL, leak-free, subject-disjoint split so + // any reported PCK/OKS is honest. MM-Fi windows are stride-1 (≈99% + // overlap), so an index-level split would leak; a synthetic val set + // makes the metric meaningless. Split at the subject level when the + // dataset has ≥2 subjects. 
+ match train_ds.subject_disjoint_split(0.2, config.seed) { + Ok((train_view, val_view)) => { + info!( + "Leak-free subject-disjoint split: {} train windows (subjects {:?}) / \ + {} val windows (subjects {:?})", + train_view.len(), + train_view.subjects(), + val_view.len(), + val_view.subjects(), + ); + run_training(config, &train_view, &val_view); + } + Err(e) => { + // Cannot form a real split (e.g. a single subject). Fall back to + // a SYNTHETIC val set, but make it UNMISTAKABLE that this is a + // smoke-test only — its metric is NOT a reportable number. + warn!("Cannot build a leak-free subject-disjoint split: {e}"); + warn!( + "[SMOKE-TEST ONLY] Falling back to a SYNTHETIC validation set. \ + ANY val_pck/val_oks printed below is a PIPELINE SMOKE-TEST on \ + synthetic data and MUST NOT be reported or claimed as accuracy \ + (ADR-155 §Tier-1.2). Provide a multi-subject dataset for a real \ + measurement." + ); + let val_syn_cfg = SyntheticConfig { + num_subcarriers: config.num_subcarriers, + num_antennas_tx: config.num_antennas_tx, + num_antennas_rx: config.num_antennas_rx, + window_frames: config.window_frames, + num_keypoints: config.num_keypoints, + signal_frequency_hz: 2.4e9, + }; + let val_ds = SyntheticCsiDataset::new(config.batch_size.max(1), val_syn_cfg); + run_smoke_test(config, &train_ds, &val_ds); + } + } } } @@ -265,6 +295,55 @@ fn run_training(_config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn info!("Config and dataset infrastructure: OK"); } +// --------------------------------------------------------------------------- +// run_smoke_test — synthetic-validation path (NOT a reportable metric) +// --------------------------------------------------------------------------- +// +// ADR-155 §Tier-1.2: identical to `run_training` but every metric it surfaces +// is prefixed/labelled as a SMOKE-TEST so a synthetic-val PCK can never be +// mistaken for a measured accuracy number. 
+ +#[cfg(feature = "tch-backend")] +fn run_smoke_test(config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn CsiDataset) { + use wifi_densepose_train::trainer::Trainer; + + warn!( + "[SMOKE-TEST] Starting SYNTHETIC-validation run: {} train / {} val samples. \ + Reported PCK/OKS below are NOT measurements.", + train_ds.len(), + val_ds.len() + ); + + let mut trainer = Trainer::new(config); + match trainer.train(train_ds, val_ds) { + Ok(result) => { + warn!("[SMOKE-TEST] Pipeline ran end-to-end (no crash). Metrics are synthetic:"); + warn!( + "[SMOKE-TEST] (DO NOT REPORT) best_pck@0.2={:.4} @ epoch {} — synthetic val", + result.best_pck, result.best_epoch + ); + info!( + "[SMOKE-TEST] Final train loss: {:.6}", + result.final_train_loss + ); + } + Err(e) => { + error!("[SMOKE-TEST] Pipeline failed: {e}"); + std::process::exit(1); + } + } +} + +#[cfg(not(feature = "tch-backend"))] +fn run_smoke_test(_config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn CsiDataset) { + warn!( + "[SMOKE-TEST] Pipeline verification only: {} train / {} synthetic-val samples loaded. 
\ + No metric is produced; build with --features tch-backend to run the pipeline.", + train_ds.len(), + val_ds.len() + ); +} + // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- diff --git a/v2/crates/wifi-densepose-train/src/dataset.rs b/v2/crates/wifi-densepose-train/src/dataset.rs index 78d77ae2..d13e8329 100644 --- a/v2/crates/wifi-densepose-train/src/dataset.rs +++ b/v2/crates/wifi-densepose-train/src/dataset.rs @@ -519,6 +519,233 @@ impl CsiDataset for MmFiDataset { } } +// --------------------------------------------------------------------------- +// Leak-free train/test split (ADR-155 §Tier-1.2) +// --------------------------------------------------------------------------- +// +// Why this exists: MM-Fi windows are extracted with stride 1 +// (`MmFiEntry::num_windows` = `num_frames − window_frames + 1`), so adjacent +// windows overlap by `window_frames − 1` frames. A naive index-level random +// split therefore puts near-identical windows on both sides of the boundary — +// up to ~99% information leakage — and any PCK it reports is meaningless. The +// leak-free discipline (mirrored from `occupancy_bench::EvalSplit`) is to split +// at the **subject** level: a subject's clips (and thus all of its windows) go +// entirely to train or entirely to test. Disjoint subjects ⇒ no shared window, +// and no temporally-adjacent window can straddle the boundary. + +/// A borrowed, read-only view over a contiguous-by-subject subset of a parent +/// [`MmFiDataset`]'s windows. Implements [`CsiDataset`] so it can be passed +/// straight to the trainer. Produced only by +/// [`MmFiDataset::subject_disjoint_split`], which guarantees the two returned +/// views are subject- and window-disjoint. +pub struct MmFiSplitView<'a> { + parent: &'a MmFiDataset, + /// Global parent window indices owned by this view (sorted, unique). 
+ global_indices: Vec, + /// Subject ids present in this view (for leak validation / reporting). + subjects: std::collections::BTreeSet, + name: &'static str, +} + +impl<'a> MmFiSplitView<'a> { + /// Subject ids covered by this view. + pub fn subjects(&self) -> &std::collections::BTreeSet { + &self.subjects + } + + /// Global parent window indices owned by this view. + pub fn global_indices(&self) -> &[usize] { + &self.global_indices + } +} + +impl<'a> CsiDataset for MmFiSplitView<'a> { + fn len(&self) -> usize { + self.global_indices.len() + } + + fn get(&self, idx: usize) -> Result { + let g = *self + .global_indices + .get(idx) + .ok_or(DatasetError::IndexOutOfBounds { + idx, + len: self.global_indices.len(), + })?; + self.parent.get(g) + } + + fn name(&self) -> &str { + self.name + } +} + +impl MmFiDataset { + /// All subject ids present in the scanned dataset (sorted, unique). + pub fn subjects(&self) -> Vec { + let set: std::collections::BTreeSet = + self.entries.iter().map(|e| e.subject_id).collect(); + set.into_iter().collect() + } + + /// Split into **subject-disjoint** train / test views (ADR-155 §Tier-1.2). + /// + /// Subjects are assigned wholesale to one side: roughly + /// `test_subject_fraction` of the distinct subjects (at least one, and at + /// least one left for train) go to the test view, the rest to train. Because + /// every window of a subject travels with that subject, the two views share + /// **no subject and no window** — the split is leak-free by construction. + /// + /// Assignment is deterministic for a given `seed` (seeded Fisher-Yates over + /// the sorted subject list), so runs are reproducible. + /// + /// # Errors + /// [`DatasetError::InvalidSplit`] when there are fewer than 2 subjects, when + /// `test_subject_fraction` is not in `(0, 1)`, or when either side would be + /// empty. 
+ pub fn subject_disjoint_split( + &self, + test_subject_fraction: f64, + seed: u64, + ) -> Result<(MmFiSplitView<'_>, MmFiSplitView<'_>), DatasetError> { + if !(test_subject_fraction > 0.0 && test_subject_fraction < 1.0) { + return Err(DatasetError::InvalidSplit(format!( + "test_subject_fraction must be in (0,1), got {test_subject_fraction}" + ))); + } + let mut subjects = self.subjects(); + if subjects.len() < 2 { + return Err(DatasetError::InvalidSplit(format!( + "need >= 2 distinct subjects for a subject-disjoint split, got {}", + subjects.len() + ))); + } + + // Deterministic shuffle of the sorted subject list. + xorshift_shuffle_u32(&mut subjects, seed); + let n_test = ((subjects.len() as f64 * test_subject_fraction).round() as usize) + .clamp(1, subjects.len() - 1); + let test_subjects: std::collections::BTreeSet = + subjects[..n_test].iter().copied().collect(); + let train_subjects: std::collections::BTreeSet = + subjects[n_test..].iter().copied().collect(); + + // Partition global window indices by the owning entry's subject. + let mut train_idx = Vec::new(); + let mut test_idx = Vec::new(); + for (entry_i, entry) in self.entries.iter().enumerate() { + let start = self.cumulative[entry_i]; + let end = self.cumulative[entry_i + 1]; + if test_subjects.contains(&entry.subject_id) { + test_idx.extend(start..end); + } else { + train_idx.extend(start..end); + } + } + + if train_idx.is_empty() || test_idx.is_empty() { + return Err(DatasetError::InvalidSplit( + "split produced an empty partition (a subject set has no windows)".into(), + )); + } + + let train = MmFiSplitView { + parent: self, + global_indices: train_idx, + subjects: train_subjects, + name: "MmFiDataset[train]", + }; + let test = MmFiSplitView { + parent: self, + global_indices: test_idx, + subjects: test_subjects, + name: "MmFiDataset[test]", + }; + + // Self-check: never hand out a leaky split. 
+ assert_split_leak_free(&train, &test)?; + Ok((train, test)) + } +} + +/// Verify a train/test split is leak-free: subject-disjoint **and** +/// window-disjoint, with both sides non-empty (ADR-155 §Tier-1.2). +/// +/// Returns [`DatasetError::InvalidSplit`] describing the first violation found. +pub fn assert_split_leak_free( + train: &MmFiSplitView<'_>, + test: &MmFiSplitView<'_>, +) -> Result<(), DatasetError> { + if train.global_indices.is_empty() || test.global_indices.is_empty() { + return Err(DatasetError::InvalidSplit("a partition is empty".into())); + } + // Subject disjointness. + if let Some(shared) = train.subjects.intersection(&test.subjects).next() { + return Err(DatasetError::InvalidSplit(format!( + "subject {shared} appears in both train and test (subject leakage)" + ))); + } + // Window disjointness (guards against any index bug in the partitioner). + let train_set: std::collections::BTreeSet = + train.global_indices.iter().copied().collect(); + if let Some(shared) = test.global_indices.iter().find(|i| train_set.contains(i)) { + return Err(DatasetError::InvalidSplit(format!( + "window {shared} appears in both train and test (window leakage)" + ))); + } + Ok(()) +} + +#[cfg(test)] +impl MmFiDataset { + /// Build a metadata-only `MmFiDataset` for split tests: fabricated entries + /// with given `(subject_id, action_id, num_frames)` and a window size. No + /// files are touched — only the split / leak-check logic (which reads + /// `subject_id` + window counts, never `get()`) is exercised. 
+ fn from_entries_for_test(clips: &[(u32, u32, usize)], window_frames: usize) -> Self { + let entries: Vec = clips + .iter() + .map(|&(subject_id, action_id, num_frames)| MmFiEntry { + subject_id, + action_id, + amp_path: PathBuf::from("/nonexistent/wifi_csi.npy"), + phase_path: PathBuf::from("/nonexistent/wifi_csi_phase.npy"), + kp_path: PathBuf::from("/nonexistent/gt_keypoints.npy"), + num_frames, + window_frames, + }) + .collect(); + let mut cumulative = vec![0usize; entries.len() + 1]; + for (i, e) in entries.iter().enumerate() { + cumulative[i + 1] = cumulative[i] + e.num_windows(); + } + MmFiDataset { + entries, + cumulative, + window_frames, + target_subcarriers: 56, + num_keypoints: 17, + root: PathBuf::from("/nonexistent"), + } + } +} + +/// Deterministic Fisher-Yates shuffle of a `u32` slice (seeded Xorshift64). +fn xorshift_shuffle_u32(items: &mut [u32], seed: u64) { + let n = items.len(); + if n <= 1 { + return; + } + let mut state = if seed == 0 { 0x853c49e6748fea9b } else { seed }; + for i in (1..n).rev() { + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; + let j = (state % (i as u64 + 1)) as usize; + items.swap(i, j); + } +} + // --------------------------------------------------------------------------- // CompressedCsiBuffer // --------------------------------------------------------------------------- @@ -1019,6 +1246,91 @@ mod tests { assert_abs_diff_eq!(s0a.keypoints[[5, 0]], s0b.keypoints[[5, 0]], epsilon = 1e-7); } + // ----- Leak-free subject-disjoint split (ADR-155 §Tier-1.2) ----------- + + fn split_fixture() -> MmFiDataset { + // 6 subjects × 2 clips each, 50 frames per clip, window 10 ⇒ 41 + // overlapping windows per clip. A leaky index-split would put adjacent + // (near-identical) windows on both sides; the subject split cannot. 
+ let mut clips = Vec::new(); + for s in 1..=6u32 { + for a in 1..=2u32 { + clips.push((s, a, 50usize)); + } + } + MmFiDataset::from_entries_for_test(&clips, 10) + } + + #[test] + fn subject_split_is_subject_and_window_disjoint() { + let ds = split_fixture(); + let (train, test) = ds.subject_disjoint_split(0.34, 42).unwrap(); + + // No subject is shared. + assert!(train.subjects().is_disjoint(test.subjects())); + // assert_split_leak_free agrees (subject + window disjoint, non-empty). + assert_split_leak_free(&train, &test).expect("split must be leak-free"); + + // No global window index is shared. + let train_set: std::collections::BTreeSet = + train.global_indices().iter().copied().collect(); + for g in test.global_indices() { + assert!(!train_set.contains(g), "window {g} leaked across the split"); + } + + // Every window is accounted for exactly once (partition, not sample). + assert_eq!(train.len() + test.len(), ds.len()); + assert!(train.len() > 0 && test.len() > 0); + } + + #[test] + fn subject_split_is_deterministic_for_seed() { + let ds = split_fixture(); + let (tr1, te1) = ds.subject_disjoint_split(0.34, 7).unwrap(); + let (tr2, te2) = ds.subject_disjoint_split(0.34, 7).unwrap(); + assert_eq!(tr1.subjects(), tr2.subjects()); + assert_eq!(te1.subjects(), te2.subjects()); + } + + #[test] + fn subject_split_rejects_single_subject() { + // Only one subject ⇒ a subject-disjoint split is impossible. 
+ let ds = MmFiDataset::from_entries_for_test(&[(1, 1, 50), (1, 2, 50)], 10); + assert!(matches!( + ds.subject_disjoint_split(0.3, 1), + Err(DatasetError::InvalidSplit(_)) + )); + } + + #[test] + fn subject_split_rejects_bad_fraction() { + let ds = split_fixture(); + assert!(ds.subject_disjoint_split(0.0, 1).is_err()); + assert!(ds.subject_disjoint_split(1.0, 1).is_err()); + } + + #[test] + fn assert_leak_free_detects_injected_subject_leak() { + // Build two views that deliberately share one of train's subjects and prove the + // validator catches it (a guard against future partitioner bugs). + let ds = split_fixture(); + let (train, _test) = ds.subject_disjoint_split(0.34, 42).unwrap(); + // Fabricate a "test" view overlapping train's subjects. + let mut shared_subjects = std::collections::BTreeSet::new(); + let leaked = *train.subjects().iter().next().unwrap(); + shared_subjects.insert(leaked); + let bad_test = MmFiSplitView { + parent: &ds, + global_indices: train.global_indices().to_vec(), + subjects: shared_subjects, + name: "bad", + }; + assert!(matches!( + assert_split_leak_free(&train, &bad_test), + Err(DatasetError::InvalidSplit(_)) + )); + } + #[test] fn synthetic_different_indices_differ() { let cfg = SyntheticConfig::default(); diff --git a/v2/crates/wifi-densepose-train/src/error.rs b/v2/crates/wifi-densepose-train/src/error.rs index 3d2c2fcd..2a4f824c 100644 --- a/v2/crates/wifi-densepose-train/src/error.rs +++ b/v2/crates/wifi-densepose-train/src/error.rs @@ -280,6 +280,12 @@ pub enum DatasetError { /// An I/O error that carries no path context. #[error("IO error: {0}")] Io(#[from] std::io::Error), + + /// A train/test split is invalid — it leaks information across the boundary + /// (a subject appears in both partitions, or a window is shared) or is + /// degenerate (an empty partition). ADR-155 §Tier-1.2. + #[error("Invalid split: {0}")] + InvalidSplit(String), } impl DatasetError {