//! Integration + property tests for [`wifi_densepose_train::mae`] //! (ADR-152 §2.3 — UNSW MAE pretraining recipe). //! //! All deterministic tests use fixed seeds; property tests use `proptest` //! with its default deterministic-replay machinery. use proptest::prelude::*; use wifi_densepose_train::mae::{ patchify, random_mask, unpatchify, unpatchify_visible, MaePretrainConfig, }; use wifi_densepose_train::MaeError; /// Deterministic test window: value = t * 1000 + sc (every cell unique). fn window(time: usize, subc: usize) -> Vec { (0..time * subc) .map(|i| ((i / subc) * 1000 + i % subc) as f32) .collect() } // --------------------------------------------------------------------------- // Config defaults + validation // --------------------------------------------------------------------------- #[test] fn default_config_matches_unsw_recipe() { let cfg = MaePretrainConfig::default(); assert!((cfg.mask_ratio - 0.80).abs() < 1e-12); assert_eq!(cfg.patch_time, 30); assert_eq!(cfg.patch_subc, 3); assert_eq!(cfg.seed, 42); cfg.validate().expect("default recipe is valid"); } #[test] fn config_json_round_trip() { let cfg = MaePretrainConfig::default(); let json = serde_json::to_string(&cfg).unwrap(); let back: MaePretrainConfig = serde_json::from_str(&json).unwrap(); assert_eq!(back, cfg); } #[test] fn invalid_mask_ratio_rejected() { for ratio in [0.0, 1.0, -0.1, 1.5, f64::NAN] { let cfg = MaePretrainConfig { mask_ratio: ratio, ..MaePretrainConfig::default() }; assert!(cfg.validate().is_err(), "ratio {ratio} should be invalid"); } } #[test] fn zero_patch_dims_rejected() { let cfg = MaePretrainConfig { patch_time: 0, ..MaePretrainConfig::default() }; assert!(cfg.validate().is_err()); let cfg = MaePretrainConfig { patch_subc: 0, ..MaePretrainConfig::default() }; assert!(cfg.validate().is_err()); } // --------------------------------------------------------------------------- // Divisibility policy: error, never truncate // --------------------------------------------------------------------------- #[test] fn non_divisible_window_errors_with_crop_hint() { let cfg = MaePretrainConfig::default(); // (30, 3) // Default TrainingConfig window 100 × 56 is NOT divisible by (30, 3). let err = cfg.validate_for_window(100, 56).unwrap_err(); match err { MaeError::NotDivisible { axis, window, patch, remainder, crop, } => { assert_eq!(axis, "time"); assert_eq!(window, 100); assert_eq!(patch, 30); assert_eq!(remainder, 10); assert_eq!(crop, 90); } other => panic!("expected NotDivisible, got {other:?}"), } assert_eq!(cfg.cropped_window_shape(100, 56), (90, 54)); // The hinted crop validates cleanly. cfg.validate_for_window(90, 54).expect("crop is divisible"); assert_eq!(cfg.num_patches(90, 54).unwrap(), 3 * 18); } #[test] fn patch_larger_than_window_errors() { let cfg = MaePretrainConfig::default(); let err = cfg.validate_for_window(20, 3).unwrap_err(); assert!(matches!( err, MaeError::PatchExceedsWindow { axis: "time", .. } )); } #[test] fn window_length_mismatch_errors() { let cfg = MaePretrainConfig::default(); let buf = vec![0.0_f32; 89 * 54]; // declared 90 × 54 let err = patchify(&buf, 90, 54, &cfg).unwrap_err(); assert!(matches!(err, MaeError::WindowShapeMismatch { .. })); } // --------------------------------------------------------------------------- // NaN handling // --------------------------------------------------------------------------- #[test] fn nan_and_inf_input_rejected_with_location() { let cfg = MaePretrainConfig::default(); let mut buf = window(90, 54); buf[2 * 54 + 7] = f32::NAN; match patchify(&buf, 90, 54, &cfg).unwrap_err() { MaeError::NonFiniteValue { row, col, .. } => { assert_eq!((row, col), (2, 7)); } other => panic!("expected NonFiniteValue, got {other:?}"), } buf[2 * 54 + 7] = f32::INFINITY; assert!(matches!( patchify(&buf, 90, 54, &cfg), Err(MaeError::NonFiniteValue { .. }) )); } #[test] fn finite_input_is_nan_free_after_round_trip() { let cfg = MaePretrainConfig::default(); let buf = window(90, 54); let grid = patchify(&buf, 90, 54, &cfg).unwrap(); assert!(grid.patches.iter().flatten().all(|v| v.is_finite())); assert!(unpatchify(&grid).iter().all(|v| v.is_finite())); } // --------------------------------------------------------------------------- // Patchify / unpatchify round trip // --------------------------------------------------------------------------- #[test] fn patchify_unpatchify_identity_default_recipe() { let cfg = MaePretrainConfig::default(); let buf = window(90, 54); let grid = patchify(&buf, 90, 54, &cfg).unwrap(); assert_eq!(grid.n_patches(), 54); assert_eq!(grid.patch_len(), 90); assert_eq!(grid.window_shape(), (90, 54)); assert_eq!(unpatchify(&grid), buf); } #[test] fn patch_layout_is_time_major() { // 4 × 4 window, (2, 2) patches → patch 0 is rows 0–1 × cols 0–1. let cfg = MaePretrainConfig { patch_time: 2, patch_subc: 2, ..MaePretrainConfig::default() }; let buf = window(4, 4); let grid = patchify(&buf, 4, 4, &cfg).unwrap(); assert_eq!(grid.patches[0], vec![0.0, 1.0, 1000.0, 1001.0]); // Patch index 1 is the next subcarrier block on the same time rows. assert_eq!(grid.patches[1], vec![2.0, 3.0, 1002.0, 1003.0]); // Patch index n_patches_subc starts the second time row of patches. assert_eq!(grid.patches[2], vec![2000.0, 2001.0, 3000.0, 3001.0]); } #[test] fn unpatchify_visible_restores_visible_and_fills_masked() { let cfg = MaePretrainConfig::default(); let buf = window(90, 54); let (grid, mask) = cfg.mask_window(&buf, 90, 54).unwrap(); let fill = -1.0_f32; let recon = unpatchify_visible(&grid, &mask.visible, fill); // Visible patch regions are identical to the input; masked regions = fill. let full = unpatchify(&grid); assert_eq!(full, buf); let mut n_fill = 0usize; for (i, (&r, &orig)) in recon.iter().zip(buf.iter()).enumerate() { if r == fill && orig != fill { n_fill += 1; } else { assert_eq!(r, orig, "visible value at flat index {i} must round-trip"); } } assert_eq!(n_fill, mask.masked.len() * grid.patch_len()); } // --------------------------------------------------------------------------- // Random mask: exact count, determinism, disjointness // --------------------------------------------------------------------------- #[test] fn mask_count_is_exact_for_default_recipe() { // 54 patches @ 0.80 → round(43.2) = 43 masked, 11 visible. let cfg = MaePretrainConfig::default(); assert_eq!(cfg.num_masked(54), 43); let mask = random_mask(54, cfg.mask_ratio, cfg.seed).unwrap(); assert_eq!(mask.masked.len(), 43); assert_eq!(mask.visible.len(), 11); } #[test] fn same_seed_same_mask_different_seed_differs() { let a = random_mask(100, 0.80, 7).unwrap(); let b = random_mask(100, 0.80, 7).unwrap(); assert_eq!(a, b, "same (n, ratio, seed) must reproduce the mask"); let c = random_mask(100, 0.80, 8).unwrap(); assert_ne!(a.masked, c.masked, "different seeds must differ"); } #[test] fn random_mask_rejects_invalid_ratios() { // Error-not-silent: NaN must not silently mask 0 patches; ratios outside // (0, 1) must not degenerate to all-visible / all-masked grids. for ratio in [ f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 1.0, 1.5, 0.0, -0.1, ] { let err = random_mask(54, ratio, 42).unwrap_err(); assert!( matches!(err, MaeError::InvalidMaskRatio { .. }), "ratio {ratio} must be rejected, got {err:?}" ); } } #[test] fn mask_window_rejects_invalid_ratio_before_masking() { let cfg = MaePretrainConfig { mask_ratio: f64::NAN, ..MaePretrainConfig::default() }; let buf = window(90, 54); assert!(matches!( cfg.mask_window(&buf, 90, 54), Err(MaeError::InvalidMaskRatio { .. }) )); } proptest! { /// Exact count, sortedness, range, disjointness, and full coverage hold /// for arbitrary grid sizes, ratios, and seeds. #[test] fn prop_mask_invariants( n in 1usize..600, ratio in 0.01f64..0.99, seed in any::(), ) { let mask = random_mask(n, ratio, seed).unwrap(); let expected_masked = ((ratio * n as f64).round() as usize).min(n); prop_assert_eq!(mask.masked.len(), expected_masked); prop_assert_eq!(mask.masked.len() + mask.visible.len(), n); // In range, sorted, strictly increasing (no duplicates). for set in [&mask.masked, &mask.visible] { for w in set.windows(2) { prop_assert!(w[0] < w[1]); } if let Some(&last) = set.last() { prop_assert!(last < n); } } // Disjoint + complete: merged sets are exactly 0..n. let mut all: Vec = mask.masked.iter().chain(&mask.visible).copied().collect(); all.sort_unstable(); prop_assert_eq!(all, (0..n).collect::>()); } /// Determinism by seed for arbitrary inputs. #[test] fn prop_mask_deterministic(n in 1usize..400, seed in any::()) { prop_assert_eq!( random_mask(n, 0.80, seed).unwrap(), random_mask(n, 0.80, seed).unwrap() ); } /// Round-trip identity for arbitrary divisible window/patch geometries. #[test] fn prop_patchify_round_trip( pt in 1usize..8, ps in 1usize..8, nt in 1usize..6, ns in 1usize..6, seed in any::(), ) { let (time, subc) = (pt * nt, ps * ns); let cfg = MaePretrainConfig { patch_time: pt, patch_subc: ps, seed, ..MaePretrainConfig::default() }; let buf = window(time, subc); let grid = patchify(&buf, time, subc, &cfg).unwrap(); prop_assert_eq!(grid.n_patches(), nt * ns); prop_assert_eq!(unpatchify(&grid), buf); } }