564 lines
20 KiB
Rust
564 lines
20 KiB
Rust
//! Training configuration for WiFi-DensePose.
//!
//! [`TrainingConfig`] is the single source of truth for all hyper-parameters,
//! dataset shapes, loss weights, and infrastructure settings used throughout
//! the training pipeline. It is serializable via [`serde`] so it can be stored
//! to / restored from JSON checkpoint files.
//!
//! # Example
//!
//! ```rust
//! use wifi_densepose_train::config::TrainingConfig;
//!
//! let cfg = TrainingConfig::default();
//! cfg.validate().expect("default config is valid");
//!
//! assert_eq!(cfg.num_subcarriers, 56);
//! assert_eq!(cfg.num_keypoints, 17);
//!
//! // Adapt for a non-MM-Fi source — e.g. an ESP32 HT40 capture (~192 raw
//! // subcarriers) or the ADR-078 multi-band mesh (168). The model still sees
//! // `num_subcarriers`; the loader resamples the native count down to it.
//! let ht40 = TrainingConfig::ht40_192();
//! assert_eq!(ht40.native_subcarriers, 192);
//! assert!(ht40.needs_subcarrier_interp());
//! let mesh = TrainingConfig::for_subcarriers(168, 56);
//! assert_eq!(mesh.native_subcarriers, 168);
//! ```
|
||
|
||
use serde::{Deserialize, Serialize};
|
||
use std::path::{Path, PathBuf};
|
||
|
||
use crate::error::ConfigError;
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// TrainingConfig
|
||
// ---------------------------------------------------------------------------
|
||
|
||
/// Complete configuration for a WiFi-DensePose training run.
|
||
///
|
||
/// All fields have documented defaults that match the paper's experimental
|
||
/// setup. Use [`TrainingConfig::default()`] as a starting point, then override
|
||
/// individual fields as needed.
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct TrainingConfig {
|
||
// -----------------------------------------------------------------------
|
||
// Data / Signal
|
||
// -----------------------------------------------------------------------
|
||
/// Number of subcarriers after interpolation (the *model's* input width).
|
||
///
|
||
/// The model always sees this many subcarriers regardless of the raw
|
||
/// hardware output; [`crate::subcarrier::interpolate_subcarriers`] resamples
|
||
/// `native_subcarriers` → `num_subcarriers` when they differ. Default: **56**.
|
||
pub num_subcarriers: usize,
|
||
|
||
/// Number of subcarriers in the *raw* dataset, before interpolation.
|
||
///
|
||
/// Common sources: MM-Fi = 114, ESP32 HT20 = 56, ESP32 HT40 ≈ 192 (or 114),
|
||
/// multi-band mesh = 168 (ADR-078). When it equals [`Self::num_subcarriers`]
|
||
/// no interpolation happens ([`Self::needs_subcarrier_interp`]). For the
|
||
/// non-MM-Fi shapes prefer the preset constructors
|
||
/// ([`Self::for_subcarriers`], [`Self::ht40_192`], [`Self::multiband_168`])
|
||
/// over overriding both fields by hand. Default: **114**.
|
||
///
|
||
/// **Multi-NIC note:** a 2–3-node CSI mesh currently maps onto the existing
|
||
/// `[T, n_tx, n_rx, n_sc]` layout by treating the nodes' receive chains as
|
||
/// extra `n_rx` (i.e. `num_antennas_rx = nodes × per_node_rx`); a dedicated
|
||
/// node dimension is a separate dataset-loader change.
|
||
pub native_subcarriers: usize,
|
||
|
||
/// Number of transmit antennas. Default: **3**.
|
||
pub num_antennas_tx: usize,
|
||
|
||
/// Number of receive antennas. Default: **3**.
|
||
pub num_antennas_rx: usize,
|
||
|
||
/// Temporal sliding-window length in frames. Default: **100**.
|
||
pub window_frames: usize,
|
||
|
||
/// Side length of the square keypoint heatmap output (H = W). Default: **56**.
|
||
pub heatmap_size: usize,
|
||
|
||
// -----------------------------------------------------------------------
|
||
// Model
|
||
// -----------------------------------------------------------------------
|
||
/// Number of body keypoints (COCO 17-joint skeleton). Default: **17**.
|
||
pub num_keypoints: usize,
|
||
|
||
/// Number of DensePose body-part classes. Default: **24**.
|
||
pub num_body_parts: usize,
|
||
|
||
/// Number of feature-map channels in the backbone encoder. Default: **256**.
|
||
pub backbone_channels: usize,
|
||
|
||
// -----------------------------------------------------------------------
|
||
// Optimisation
|
||
// -----------------------------------------------------------------------
|
||
/// Mini-batch size. Default: **8**.
|
||
pub batch_size: usize,
|
||
|
||
/// Initial learning rate for the Adam / AdamW optimiser. Default: **1e-3**.
|
||
pub learning_rate: f64,
|
||
|
||
/// L2 weight-decay regularisation coefficient. Default: **1e-4**.
|
||
pub weight_decay: f64,
|
||
|
||
/// Total number of training epochs. Default: **50**.
|
||
pub num_epochs: usize,
|
||
|
||
/// Number of linear-warmup epochs at the start. Default: **5**.
|
||
pub warmup_epochs: usize,
|
||
|
||
/// Epochs at which the learning rate is multiplied by `lr_gamma`.
|
||
///
|
||
/// Default: **[30, 45]** (multi-step scheduler).
|
||
pub lr_milestones: Vec<usize>,
|
||
|
||
/// Multiplicative factor applied at each LR milestone. Default: **0.1**.
|
||
pub lr_gamma: f64,
|
||
|
||
/// Maximum gradient L2 norm for gradient clipping. Default: **1.0**.
|
||
pub grad_clip_norm: f64,
|
||
|
||
// -----------------------------------------------------------------------
|
||
// Loss weights
|
||
// -----------------------------------------------------------------------
|
||
/// Weight for the keypoint heatmap loss term. Default: **0.3**.
|
||
pub lambda_kp: f64,
|
||
|
||
/// Weight for the DensePose body-part / UV-coordinate loss. Default: **0.6**.
|
||
pub lambda_dp: f64,
|
||
|
||
/// Weight for the cross-modal transfer / domain-alignment loss. Default: **0.1**.
|
||
pub lambda_tr: f64,
|
||
|
||
// -----------------------------------------------------------------------
|
||
// Validation and checkpointing
|
||
// -----------------------------------------------------------------------
|
||
/// Run validation every N epochs. Default: **1**.
|
||
pub val_every_epochs: usize,
|
||
|
||
/// Stop training if validation loss does not improve for this many
|
||
/// consecutive validation rounds. Default: **10**.
|
||
pub early_stopping_patience: usize,
|
||
|
||
/// Directory where model checkpoints are saved.
|
||
pub checkpoint_dir: PathBuf,
|
||
|
||
/// Directory where TensorBoard / CSV logs are written.
|
||
pub log_dir: PathBuf,
|
||
|
||
/// Keep only the top-K best checkpoints by validation metric. Default: **3**.
|
||
pub save_top_k: usize,
|
||
|
||
// -----------------------------------------------------------------------
|
||
// Device
|
||
// -----------------------------------------------------------------------
|
||
/// Use a CUDA GPU for training when available. Default: **false**.
|
||
pub use_gpu: bool,
|
||
|
||
/// CUDA device index when `use_gpu` is `true`. Default: **0**.
|
||
pub gpu_device_id: i64,
|
||
|
||
/// Number of background data-loading threads. Default: **4**.
|
||
pub num_workers: usize,
|
||
|
||
// -----------------------------------------------------------------------
|
||
// Reproducibility
|
||
// -----------------------------------------------------------------------
|
||
/// Global random seed for all RNG sources in the training pipeline.
|
||
///
|
||
/// This seed is applied to the dataset shuffler, model parameter
|
||
/// initialisation, and any stochastic augmentation. Default: **42**.
|
||
pub seed: u64,
|
||
}
|
||
|
||
impl Default for TrainingConfig {
|
||
fn default() -> Self {
|
||
TrainingConfig {
|
||
// Data
|
||
num_subcarriers: 56,
|
||
native_subcarriers: 114,
|
||
num_antennas_tx: 3,
|
||
num_antennas_rx: 3,
|
||
window_frames: 100,
|
||
heatmap_size: 56,
|
||
// Model
|
||
num_keypoints: 17,
|
||
num_body_parts: 24,
|
||
backbone_channels: 256,
|
||
// Optimisation
|
||
batch_size: 8,
|
||
learning_rate: 1e-3,
|
||
weight_decay: 1e-4,
|
||
num_epochs: 50,
|
||
warmup_epochs: 5,
|
||
lr_milestones: vec![30, 45],
|
||
lr_gamma: 0.1,
|
||
grad_clip_norm: 1.0,
|
||
// Loss weights
|
||
lambda_kp: 0.3,
|
||
lambda_dp: 0.6,
|
||
lambda_tr: 0.1,
|
||
// Validation / checkpointing
|
||
val_every_epochs: 1,
|
||
early_stopping_patience: 10,
|
||
checkpoint_dir: PathBuf::from("checkpoints"),
|
||
log_dir: PathBuf::from("logs"),
|
||
save_top_k: 3,
|
||
// Device
|
||
use_gpu: false,
|
||
gpu_device_id: 0,
|
||
num_workers: 4,
|
||
// Reproducibility
|
||
seed: 42,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl TrainingConfig {
|
||
/// Load a [`TrainingConfig`] from a JSON file at `path`.
|
||
///
|
||
/// # Errors
|
||
///
|
||
/// Returns [`ConfigError::FileRead`] if the file cannot be opened and
|
||
/// [`ConfigError::InvalidValue`] if the JSON is malformed.
|
||
pub fn from_json(path: &Path) -> Result<Self, ConfigError> {
|
||
let contents = std::fs::read_to_string(path).map_err(|source| ConfigError::FileRead {
|
||
path: path.to_path_buf(),
|
||
source,
|
||
})?;
|
||
let cfg: TrainingConfig = serde_json::from_str(&contents)
|
||
.map_err(|e| ConfigError::invalid_value("(file)", e.to_string()))?;
|
||
cfg.validate()?;
|
||
Ok(cfg)
|
||
}
|
||
|
||
/// Serialize this configuration to pretty-printed JSON and write it to
|
||
/// `path`, creating parent directories if necessary.
|
||
///
|
||
/// # Errors
|
||
///
|
||
/// Returns [`ConfigError::FileRead`] if the directory cannot be created or
|
||
/// the file cannot be written.
|
||
pub fn to_json(&self, path: &Path) -> Result<(), ConfigError> {
|
||
if let Some(parent) = path.parent() {
|
||
std::fs::create_dir_all(parent).map_err(|source| ConfigError::FileRead {
|
||
path: parent.to_path_buf(),
|
||
source,
|
||
})?;
|
||
}
|
||
let json = serde_json::to_string_pretty(self)
|
||
.map_err(|e| ConfigError::invalid_value("(serialization)", e.to_string()))?;
|
||
std::fs::write(path, json).map_err(|source| ConfigError::FileRead {
|
||
path: path.to_path_buf(),
|
||
source,
|
||
})?;
|
||
Ok(())
|
||
}
|
||
|
||
/// Build a config for a dataset whose raw CSI has `native` subcarriers,
|
||
/// resampling to `target` (the model's input width) before training.
|
||
///
|
||
/// All other fields take their [`Default`] values. Prefer this over
|
||
/// overriding `native_subcarriers` / `num_subcarriers` directly so the
|
||
/// relationship between the dataset's shape and the model's is explicit.
|
||
#[must_use]
|
||
pub fn for_subcarriers(native: usize, target: usize) -> Self {
|
||
Self {
|
||
native_subcarriers: native,
|
||
num_subcarriers: target,
|
||
..Self::default()
|
||
}
|
||
}
|
||
|
||
/// Preset for the MM-Fi dataset (114 raw subcarriers → 56). Identical to
|
||
/// [`Self::default()`]; provided as a named counterpart to the other
|
||
/// presets.
|
||
#[must_use]
|
||
pub fn mmfi() -> Self {
|
||
Self::default()
|
||
}
|
||
|
||
/// Preset for ESP32 HT40 captures (≈192 raw subcarriers → 56). Use
|
||
/// [`Self::for_subcarriers`] if your capture reports a different native
|
||
/// count (some HT40 firmwares yield 114).
|
||
#[must_use]
|
||
pub fn ht40_192() -> Self {
|
||
Self::for_subcarriers(192, 56)
|
||
}
|
||
|
||
/// Preset for the ADR-078 multi-band mesh (168 raw subcarriers → 56).
|
||
#[must_use]
|
||
pub fn multiband_168() -> Self {
|
||
Self::for_subcarriers(168, 56)
|
||
}
|
||
|
||
/// Returns `true` when the native dataset subcarrier count differs from the
|
||
/// model's target count and interpolation is therefore required.
|
||
pub fn needs_subcarrier_interp(&self) -> bool {
|
||
self.native_subcarriers != self.num_subcarriers
|
||
}
|
||
|
||
/// Validate all fields and return an error describing the first problem
|
||
/// found, or `Ok(())` if the configuration is coherent.
|
||
///
|
||
/// # Validated invariants
|
||
///
|
||
/// - Subcarrier counts must be non-zero.
|
||
/// - Antenna counts must be non-zero.
|
||
/// - `window_frames` must be at least 1.
|
||
/// - `batch_size` must be at least 1.
|
||
/// - `learning_rate` must be strictly positive.
|
||
/// - `weight_decay` must be non-negative.
|
||
/// - Loss weights must be non-negative and sum to a positive value.
|
||
/// - `num_epochs` must be greater than `warmup_epochs`.
|
||
/// - All `lr_milestones` must be within `[1, num_epochs]` and strictly
|
||
/// increasing.
|
||
/// - `save_top_k` must be at least 1.
|
||
/// - `val_every_epochs` must be at least 1.
|
||
pub fn validate(&self) -> Result<(), ConfigError> {
|
||
// Subcarrier counts
|
||
if self.num_subcarriers == 0 {
|
||
return Err(ConfigError::invalid_value("num_subcarriers", "must be > 0"));
|
||
}
|
||
if self.native_subcarriers == 0 {
|
||
return Err(ConfigError::invalid_value(
|
||
"native_subcarriers",
|
||
"must be > 0",
|
||
));
|
||
}
|
||
|
||
// Antenna counts
|
||
if self.num_antennas_tx == 0 {
|
||
return Err(ConfigError::invalid_value("num_antennas_tx", "must be > 0"));
|
||
}
|
||
if self.num_antennas_rx == 0 {
|
||
return Err(ConfigError::invalid_value("num_antennas_rx", "must be > 0"));
|
||
}
|
||
|
||
// Temporal window
|
||
if self.window_frames == 0 {
|
||
return Err(ConfigError::invalid_value("window_frames", "must be > 0"));
|
||
}
|
||
|
||
// Heatmap
|
||
if self.heatmap_size == 0 {
|
||
return Err(ConfigError::invalid_value("heatmap_size", "must be > 0"));
|
||
}
|
||
|
||
// Model dims
|
||
if self.num_keypoints == 0 {
|
||
return Err(ConfigError::invalid_value("num_keypoints", "must be > 0"));
|
||
}
|
||
if self.num_body_parts == 0 {
|
||
return Err(ConfigError::invalid_value("num_body_parts", "must be > 0"));
|
||
}
|
||
if self.backbone_channels == 0 {
|
||
return Err(ConfigError::invalid_value(
|
||
"backbone_channels",
|
||
"must be > 0",
|
||
));
|
||
}
|
||
|
||
// Optimisation
|
||
if self.batch_size == 0 {
|
||
return Err(ConfigError::invalid_value("batch_size", "must be > 0"));
|
||
}
|
||
if self.learning_rate <= 0.0 {
|
||
return Err(ConfigError::invalid_value(
|
||
"learning_rate",
|
||
"must be > 0.0",
|
||
));
|
||
}
|
||
if self.weight_decay < 0.0 {
|
||
return Err(ConfigError::invalid_value(
|
||
"weight_decay",
|
||
"must be >= 0.0",
|
||
));
|
||
}
|
||
if self.grad_clip_norm <= 0.0 {
|
||
return Err(ConfigError::invalid_value(
|
||
"grad_clip_norm",
|
||
"must be > 0.0",
|
||
));
|
||
}
|
||
|
||
// Epochs
|
||
if self.num_epochs == 0 {
|
||
return Err(ConfigError::invalid_value("num_epochs", "must be > 0"));
|
||
}
|
||
if self.warmup_epochs >= self.num_epochs {
|
||
return Err(ConfigError::invalid_value(
|
||
"warmup_epochs",
|
||
"must be < num_epochs",
|
||
));
|
||
}
|
||
|
||
// LR milestones: must be strictly increasing and within bounds
|
||
let mut prev = 0usize;
|
||
for &m in &self.lr_milestones {
|
||
if m == 0 || m > self.num_epochs {
|
||
return Err(ConfigError::invalid_value(
|
||
"lr_milestones",
|
||
"each milestone must be in [1, num_epochs]",
|
||
));
|
||
}
|
||
if m <= prev {
|
||
return Err(ConfigError::invalid_value(
|
||
"lr_milestones",
|
||
"milestones must be strictly increasing",
|
||
));
|
||
}
|
||
prev = m;
|
||
}
|
||
|
||
if self.lr_gamma <= 0.0 || self.lr_gamma >= 1.0 {
|
||
return Err(ConfigError::invalid_value(
|
||
"lr_gamma",
|
||
"must be in (0.0, 1.0)",
|
||
));
|
||
}
|
||
|
||
// Loss weights
|
||
if self.lambda_kp < 0.0 {
|
||
return Err(ConfigError::invalid_value("lambda_kp", "must be >= 0.0"));
|
||
}
|
||
if self.lambda_dp < 0.0 {
|
||
return Err(ConfigError::invalid_value("lambda_dp", "must be >= 0.0"));
|
||
}
|
||
if self.lambda_tr < 0.0 {
|
||
return Err(ConfigError::invalid_value("lambda_tr", "must be >= 0.0"));
|
||
}
|
||
let total_weight = self.lambda_kp + self.lambda_dp + self.lambda_tr;
|
||
if total_weight <= 0.0 {
|
||
return Err(ConfigError::invalid_value(
|
||
"lambda_kp / lambda_dp / lambda_tr",
|
||
"at least one loss weight must be > 0.0",
|
||
));
|
||
}
|
||
|
||
// Validation / checkpoint
|
||
if self.val_every_epochs == 0 {
|
||
return Err(ConfigError::invalid_value(
|
||
"val_every_epochs",
|
||
"must be > 0",
|
||
));
|
||
}
|
||
if self.save_top_k == 0 {
|
||
return Err(ConfigError::invalid_value("save_top_k", "must be > 0"));
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Tests
|
||
// ---------------------------------------------------------------------------
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Applies `mutate` to a default config and asserts that validation
    /// rejects the result.
    fn assert_rejected(mutate: impl FnOnce(&mut TrainingConfig)) {
        let mut cfg = TrainingConfig::default();
        mutate(&mut cfg);
        assert!(
            cfg.validate().is_err(),
            "expected validation to fail for {:?}",
            cfg
        );
    }

    #[test]
    fn default_config_is_valid() {
        TrainingConfig::default()
            .validate()
            .expect("default config should be valid");
    }

    #[test]
    fn json_round_trip() {
        let tmp = tempdir().unwrap();
        let path = tmp.path().join("config.json");

        let original = TrainingConfig::default();
        original.to_json(&path).expect("serialization should succeed");

        let loaded = TrainingConfig::from_json(&path).expect("deserialization should succeed");
        assert_eq!(loaded.num_subcarriers, original.num_subcarriers);
        assert_eq!(loaded.batch_size, original.batch_size);
        assert_eq!(loaded.seed, original.seed);
        assert_eq!(loaded.lr_milestones, original.lr_milestones);
    }

    #[test]
    fn zero_subcarriers_is_invalid() {
        assert_rejected(|cfg| cfg.num_subcarriers = 0);
    }

    #[test]
    fn zero_native_subcarriers_is_invalid() {
        assert_rejected(|cfg| cfg.native_subcarriers = 0);
    }

    #[test]
    fn negative_learning_rate_is_invalid() {
        assert_rejected(|cfg| cfg.learning_rate = -0.001);
    }

    #[test]
    fn warmup_equal_to_epochs_is_invalid() {
        assert_rejected(|cfg| cfg.warmup_epochs = cfg.num_epochs);
    }

    #[test]
    fn non_increasing_milestones_are_invalid() {
        // Milestones listed in the wrong order.
        assert_rejected(|cfg| cfg.lr_milestones = vec![30, 20]);
    }

    #[test]
    fn milestone_beyond_epochs_is_invalid() {
        assert_rejected(|cfg| cfg.lr_milestones = vec![30, cfg.num_epochs + 1]);
    }

    #[test]
    fn lr_gamma_outside_open_unit_interval_is_invalid() {
        assert_rejected(|cfg| cfg.lr_gamma = 0.0);
        assert_rejected(|cfg| cfg.lr_gamma = 1.0);
    }

    #[test]
    fn all_zero_loss_weights_are_invalid() {
        assert_rejected(|cfg| {
            cfg.lambda_kp = 0.0;
            cfg.lambda_dp = 0.0;
            cfg.lambda_tr = 0.0;
        });
    }

    #[test]
    fn needs_subcarrier_interp_when_counts_differ() {
        let mut cfg = TrainingConfig::default();
        cfg.num_subcarriers = 56;
        cfg.native_subcarriers = 114;
        assert!(cfg.needs_subcarrier_interp());

        cfg.native_subcarriers = 56;
        assert!(!cfg.needs_subcarrier_interp());
    }

    #[test]
    fn presets_set_expected_subcarrier_counts() {
        let mmfi = TrainingConfig::mmfi();
        assert_eq!((mmfi.native_subcarriers, mmfi.num_subcarriers), (114, 56));

        let ht40 = TrainingConfig::ht40_192();
        assert_eq!((ht40.native_subcarriers, ht40.num_subcarriers), (192, 56));
        assert!(ht40.needs_subcarrier_interp());

        let mesh = TrainingConfig::multiband_168();
        assert_eq!((mesh.native_subcarriers, mesh.num_subcarriers), (168, 56));
        assert!(mesh.needs_subcarrier_interp());

        let custom = TrainingConfig::for_subcarriers(64, 64);
        assert_eq!((custom.native_subcarriers, custom.num_subcarriers), (64, 64));
        assert!(!custom.needs_subcarrier_interp());
        // Everything other than the two subcarrier counts comes from the defaults.
        assert_eq!(custom.batch_size, TrainingConfig::default().batch_size);
    }

    #[test]
    fn presets_are_valid() {
        for cfg in &[
            TrainingConfig::mmfi(),
            TrainingConfig::ht40_192(),
            TrainingConfig::multiband_168(),
        ] {
            cfg.validate().expect("preset config should be valid");
        }
    }

    #[test]
    fn config_fields_have_expected_defaults() {
        let cfg = TrainingConfig::default();
        assert_eq!(cfg.num_subcarriers, 56);
        assert_eq!(cfg.native_subcarriers, 114);
        assert_eq!(cfg.num_antennas_tx, 3);
        assert_eq!(cfg.num_antennas_rx, 3);
        assert_eq!(cfg.window_frames, 100);
        assert_eq!(cfg.heatmap_size, 56);
        assert_eq!(cfg.num_keypoints, 17);
        assert_eq!(cfg.num_body_parts, 24);
        assert_eq!(cfg.batch_size, 8);
        assert!((cfg.learning_rate - 1e-3).abs() < 1e-10);
        assert_eq!(cfg.num_epochs, 50);
        assert_eq!(cfg.warmup_epochs, 5);
        assert_eq!(cfg.lr_milestones, vec![30, 45]);
        assert!((cfg.lr_gamma - 0.1).abs() < 1e-10);
        assert!((cfg.lambda_kp - 0.3).abs() < 1e-10);
        assert!((cfg.lambda_dp - 0.6).abs() < 1e-10);
        assert!((cfg.lambda_tr - 0.1).abs() < 1e-10);
        assert_eq!(cfg.seed, 42);
    }
}
|