wifi-densepose/v2/crates/wifi-densepose-sensing-server/src/training_api.rs

1958 lines
66 KiB
Rust

//! Training API with WebSocket progress streaming.
//!
//! Provides REST endpoints for starting, stopping, and monitoring training runs.
//! Training runs in a background tokio task. Progress updates are broadcast via
//! a `tokio::sync::broadcast` channel that the WebSocket handler subscribes to.
//!
//! Uses a **real training pipeline** that loads recorded CSI data from `.csi.jsonl`
//! files, extracts signal features (subcarrier variance, temporal gradients, Goertzel
//! frequency-domain power), trains a regularised linear model via mini-batch
//! gradient descent, and exports calibrated `.rvf` model containers.
//!
//! No PyTorch / `tch` dependency is required. All linear algebra is implemented
//! inline using standard Rust math.
//!
//! On completion, the best model is automatically exported as `.rvf` using `RvfBuilder`.
//!
//! REST endpoints:
//! - `POST /api/v1/train/start` -- start a training run
//! - `POST /api/v1/train/stop` -- stop the active training
//! - `GET /api/v1/train/status` -- get current training status
//! - `POST /api/v1/train/pretrain` -- start contrastive pretraining
//! - `POST /api/v1/train/lora` -- start LoRA fine-tuning
//!
//! WebSocket:
//! - `WS /ws/train/progress` -- streaming training progress
use std::collections::VecDeque;
use std::path::PathBuf;
use std::sync::Arc;
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::{IntoResponse, Json},
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use tokio::sync::{broadcast, RwLock};
use tracing::{error, info, warn};
use crate::recording::{RecordedFrame, RECORDINGS_DIR};
use crate::rvf_container::RvfBuilder;
// ── Constants ────────────────────────────────────────────────────────────────
/// Directory for trained model output.
///
/// This is a relative path; exported containers are written to
/// `{MODELS_DIR}/{model_id}.rvf`.
pub const MODELS_DIR: &str = "data/models";
/// Number of COCO keypoints.
const N_KEYPOINTS: usize = 17;
/// Dimensions per keypoint in the target vector (x, y, z).
const DIMS_PER_KP: usize = 3;
/// Total target dimensionality: 17 * 3 = 51.
const N_TARGETS: usize = N_KEYPOINTS * DIMS_PER_KP;
/// Default number of subcarriers when data is unavailable
/// (used when the frame list is empty and no first frame can be inspected).
const DEFAULT_N_SUB: usize = 56;
/// Sliding window size for computing per-subcarrier variance.
///
/// The window holds up to this many frames *preceding* the current frame.
const VARIANCE_WINDOW: usize = 10;
/// Number of Goertzel frequency bands to probe.
///
/// Must stay equal to the length of the band list in `extract_features_for_frame`,
/// otherwise `feature_dim` no longer matches the produced vector length.
const N_FREQ_BANDS: usize = 9;
/// Number of global scalar features (mean amplitude, std, motion score).
const N_GLOBAL_FEATURES: usize = 3;
// ── Types ────────────────────────────────────────────────────────────────────
/// Training configuration submitted with a start request.
///
/// Each numeric field carries a `#[serde(default = ...)]` attribute, so it may
/// be omitted from the request body and falls back to the matching `default_*`
/// function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
/// Maximum number of training epochs (default 100).
#[serde(default = "default_epochs")]
pub epochs: u32,
/// Mini-batch size (default 8). The training loop clamps this to at least 1.
#[serde(default = "default_batch_size")]
pub batch_size: u32,
/// Peak learning rate (default 0.001), reached at the end of warmup and then
/// decayed with a cosine schedule.
#[serde(default = "default_learning_rate")]
pub learning_rate: f64,
/// L2 regularisation coefficient applied to the weights (default 1e-4).
#[serde(default = "default_weight_decay")]
pub weight_decay: f64,
/// Number of consecutive epochs without a validation PCK improvement that
/// are tolerated before training stops early (default 20).
#[serde(default = "default_early_stopping_patience")]
pub early_stopping_patience: u32,
/// Number of initial epochs during which the learning rate ramps up
/// linearly (default 5).
#[serde(default = "default_warmup_epochs")]
pub warmup_epochs: u32,
/// Path to a pretrained RVF model to fine-tune from.
pub pretrained_rvf: Option<String>,
/// LoRA profile name for environment-specific fine-tuning.
pub lora_profile: Option<String>,
}
/// Serde default for `TrainingConfig::epochs`.
fn default_epochs() -> u32 {
    100
}

/// Serde default for `TrainingConfig::batch_size`.
fn default_batch_size() -> u32 {
    8
}

/// Serde default for `TrainingConfig::learning_rate` (also reused by `PretrainRequest::lr`).
fn default_learning_rate() -> f64 {
    0.001
}

/// Serde default for `TrainingConfig::weight_decay`.
fn default_weight_decay() -> f64 {
    1e-4
}

/// Serde default for `TrainingConfig::early_stopping_patience`.
fn default_early_stopping_patience() -> u32 {
    20
}

/// Serde default for `TrainingConfig::warmup_epochs`.
fn default_warmup_epochs() -> u32 {
    5
}
/// Programmatic default that mirrors the per-field serde defaults: every
/// numeric field comes from its `default_*` function and both optional
/// fields start as `None`.
impl Default for TrainingConfig {
fn default() -> Self {
Self {
epochs: default_epochs(),
batch_size: default_batch_size(),
learning_rate: default_learning_rate(),
weight_decay: default_weight_decay(),
early_stopping_patience: default_early_stopping_patience(),
warmup_epochs: default_warmup_epochs(),
pretrained_rvf: None,
lora_profile: None,
}
}
}
/// Request body for `POST /api/v1/train/start`.
#[derive(Debug, Deserialize)]
pub struct StartTrainingRequest {
/// Identifiers of the recordings to train on. Each id is resolved to
/// `{RECORDINGS_DIR}/{id}.csi.jsonl` by `load_recording_frames`.
pub dataset_ids: Vec<String>,
/// Hyper-parameters for the run.
// NOTE(review): this field has no `#[serde(default)]`, so a `config` object
// must be present in the body even though `TrainingConfig: Default` —
// confirm that is intended.
pub config: TrainingConfig,
}
/// Request body for `POST /api/v1/train/pretrain`.
#[derive(Debug, Deserialize)]
pub struct PretrainRequest {
/// Identifiers of the recordings to pretrain on.
pub dataset_ids: Vec<String>,
/// Number of pretraining epochs (default 50 via `default_pretrain_epochs`).
#[serde(default = "default_pretrain_epochs")]
pub epochs: u32,
/// Learning rate; shares `default_learning_rate` (0.001) with `TrainingConfig`.
#[serde(default = "default_learning_rate")]
pub lr: f64,
}
/// Serde default for `PretrainRequest::epochs`.
fn default_pretrain_epochs() -> u32 {
    50
}
/// Request body for `POST /api/v1/train/lora`.
#[derive(Debug, Deserialize)]
pub struct LoraTrainRequest {
/// Identifier of the base model to fine-tune.
// NOTE(review): how this id is resolved to a model file is not visible here.
pub base_model_id: String,
/// Identifiers of the recordings to fine-tune on.
pub dataset_ids: Vec<String>,
/// Name of the LoRA profile being trained.
pub profile_name: String,
/// LoRA rank (default 8 via `default_lora_rank`).
#[serde(default = "default_lora_rank")]
pub rank: u8,
/// Number of fine-tuning epochs (default 30 via `default_lora_epochs`).
#[serde(default = "default_lora_epochs")]
pub epochs: u32,
}
/// Serde default for `LoraTrainRequest::rank`.
fn default_lora_rank() -> u8 {
    8
}

/// Serde default for `LoraTrainRequest::epochs`.
fn default_lora_epochs() -> u32 {
    30
}
/// Current training status (returned by `GET /api/v1/train/status`).
///
/// The training loop overwrites this snapshot once per epoch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStatus {
/// `true` while a run is in progress. The training loop polls this flag at
/// the start of every epoch and stops when it has been cleared.
pub active: bool,
/// Most recently finished epoch (1-based; 0 before the first epoch).
pub epoch: u32,
/// Configured number of epochs for the run.
pub total_epochs: u32,
/// Mean per-batch training loss of the last epoch.
pub train_loss: f64,
/// Validation PCK@0.2 of the last epoch.
pub val_pck: f64,
/// Approximate validation OKS; the training loop derives it as `val_pck * 0.88`.
pub val_oks: f64,
/// Learning rate used in the last epoch.
pub lr: f64,
/// Best validation PCK seen so far.
pub best_pck: f64,
/// Epoch at which `best_pck` was reached.
pub best_epoch: u32,
/// Epochs left before early stopping triggers if validation PCK does not improve.
pub patience_remaining: u32,
/// Estimated seconds until the run finishes; `None` when no estimate exists.
pub eta_secs: Option<u64>,
/// Phase label, e.g. `"idle"`, `"warmup"`, `"training"` or `"failed"`.
pub phase: String,
}
/// Idle snapshot: no run active, every metric zeroed, no ETA, phase `"idle"`.
impl Default for TrainingStatus {
    fn default() -> Self {
        Self {
            // Lifecycle.
            active: false,
            phase: String::from("idle"),
            eta_secs: None,
            // Epoch counters.
            epoch: 0,
            total_epochs: 0,
            best_epoch: 0,
            patience_remaining: 0,
            // Metrics.
            train_loss: 0.0,
            val_pck: 0.0,
            val_oks: 0.0,
            best_pck: 0.0,
            lr: 0.0,
        }
    }
}
/// Progress update sent over WebSocket.
///
/// Instances are serialised to JSON and pushed through the broadcast channel
/// by the training loop.
#[derive(Debug, Clone, Serialize)]
pub struct TrainingProgress {
/// Epoch the update belongs to (0 for the pre-training phases).
pub epoch: u32,
/// 1-based index of the batch just processed within the epoch; 0 when the
/// update is not tied to a batch.
pub batch: u32,
/// Number of batches per epoch.
pub total_batches: u32,
/// Loss value for this update (per-batch loss for batch updates, epoch mean
/// for validation updates).
pub train_loss: f64,
/// Validation PCK@0.2; left at 0.0 in per-batch updates.
pub val_pck: f64,
/// Approximate validation OKS; left at 0.0 in per-batch updates.
pub val_oks: f64,
/// Learning rate in effect.
pub lr: f64,
/// Phase label, e.g. `"loading_data"`, `"extracting_features"`, `"warmup"`,
/// `"training"`, `"validation"`, `"early_stopped"`, `"completed"`,
/// `"cancelled"` or `"failed_insufficient_data"`.
pub phase: String,
}
/// Runtime training state stored in `AppStateInner`.
///
/// `Default` is derived: it yields `TrainingStatus::default()` (the idle
/// snapshot) and no task handle, which is exactly what the previous
/// hand-written implementation produced.
#[derive(Default)]
pub struct TrainingState {
    /// Current status snapshot.
    pub status: TrainingStatus,
    /// Handle to the background training task (for cancellation).
    pub task_handle: Option<tokio::task::JoinHandle<()>>,
}
/// Shared application state type.
///
/// An `Arc` around a tokio `RwLock`, so the state can be handed to several
/// owners and locked with `.read().await` / `.write().await`.
pub type AppState = Arc<RwLock<super::AppStateInner>>;
/// Feature normalization statistics.
/// Stored alongside the model weights inside the .rvf container so that
/// inference can apply the same normalization.
///
/// NOTE(review): `extract_features_and_targets` computes these statistics over
/// every frame it is given, and the training loop performs the train/validation
/// split afterwards — so validation frames contribute to the statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureStats {
/// Per-feature mean (length = n_features).
pub mean: Vec<f64>,
/// Per-feature standard deviation (length = n_features).
/// Values below 1e-9 are replaced by 1.0 so normalization never divides by zero.
pub std: Vec<f64>,
/// Number of features.
pub n_features: usize,
/// Number of raw subcarriers used (taken from the first frame, at least 1).
pub n_subcarriers: usize,
}
// ── Data loading ─────────────────────────────────────────────────────────────
/// Load CSI frames from `.csi.jsonl` recording files for the given dataset IDs.
///
/// Each dataset_id maps to a file at `{RECORDINGS_DIR}/{dataset_id}.csi.jsonl`.
/// An id that fails validation, or a file that cannot be read, is skipped with
/// a warning. Lines that do not parse as a `RecordedFrame` are counted and
/// dropped; blank lines are ignored.
///
/// Returns the frames of all readable recordings, concatenated in the order of
/// `dataset_ids`.
async fn load_recording_frames(dataset_ids: &[String]) -> Vec<RecordedFrame> {
    let recordings_dir = PathBuf::from(RECORDINGS_DIR);
    let mut all_frames = Vec::new();

    for id in dataset_ids {
        // Path-traversal guard (#615). Reject any dataset_id that contains
        // '/', '..', null bytes, or anything outside [A-Za-z0-9._-] BEFORE
        // building the format!() path. Otherwise an attacker could read any
        // file the server process can access via `dataset_ids: ["../../etc/passwd"]`.
        let safe = match crate::path_safety::safe_id(id) {
            Ok(s) => s,
            Err(e) => {
                warn!("Skipping invalid dataset_id {id:?}: {e}");
                continue;
            }
        };

        let file_path = recordings_dir.join(format!("{safe}.csi.jsonl"));
        let data = match tokio::fs::read_to_string(&file_path).await {
            Ok(d) => d,
            Err(e) => {
                warn!("Could not read recording {}: {e}", file_path.display());
                continue;
            }
        };

        // Remember where this recording starts so the log line reports the
        // frames of *this* file rather than the running total.
        let frames_before = all_frames.len();
        let mut line_count = 0u64;
        let mut parse_errors = 0u64;

        for line in data.lines().map(str::trim).filter(|l| !l.is_empty()) {
            line_count += 1;
            match serde_json::from_str::<RecordedFrame>(line) {
                Ok(frame) => all_frames.push(frame),
                Err(_) => parse_errors += 1,
            }
        }

        let frames_loaded = all_frames.len() - frames_before;
        info!(
            "Loaded recording {id}: {line_count} lines, {frames_loaded} frames, {parse_errors} parse errors"
        );
    }

    all_frames
}
/// Build `RecordedFrame`s from the live `frame_history` buffer in `AppState`.
///
/// Every entry of `frame_history` is a subcarrier amplitude vector. The
/// remaining frame fields are filled with placeholders: a timestamp spaced
/// 0.1 s apart by position, fixed RSSI / noise-floor values, and an empty
/// `features` object. The state is held under a read lock while copying.
async fn load_frames_from_history(state: &AppState) -> Vec<RecordedFrame> {
    let guard = state.read().await;
    let history: &VecDeque<Vec<f64>> = &guard.frame_history;

    let mut frames = Vec::with_capacity(history.len());
    for (position, amplitudes) in history.iter().enumerate() {
        frames.push(RecordedFrame {
            timestamp: position as f64 * 0.1, // approximate 10 fps
            subcarriers: amplitudes.clone(),
            rssi: -50.0,
            noise_floor: -90.0,
            features: serde_json::json!({}),
        });
    }
    frames
}
// ── Feature extraction ───────────────────────────────────────────────────────
/// Length of the vector produced by `extract_features_for_frame` for a frame
/// with `n_sub` subcarriers.
///
/// Three per-subcarrier groups (raw amplitudes, windowed variances, temporal
/// gradients) are followed by the Goertzel band powers and the global scalars.
fn feature_dim(n_sub: usize) -> usize {
    const PER_SUBCARRIER_GROUPS: usize = 3;
    PER_SUBCARRIER_GROUPS * n_sub + N_FREQ_BANDS + N_GLOBAL_FEATURES
}
/// Goertzel algorithm: power of `signal` at one normalised frequency.
///
/// `freq_norm` is `target_freq_hz / sample_rate_hz`. The result is divided by
/// the signal length and floored at zero; an empty signal yields `0.0`.
fn goertzel_power(signal: &[f64], freq_norm: f64) -> f64 {
    if signal.is_empty() {
        return 0.0;
    }

    let coeff = 2.0 * (2.0 * std::f64::consts::PI * freq_norm).cos();

    // Second-order recurrence; the state is (latest, previous) filter output.
    let (latest, previous) = signal
        .iter()
        .fold((0.0f64, 0.0f64), |(cur, prev), &sample| {
            (sample + coeff * cur - prev, cur)
        });

    let power = latest * latest + previous * previous - coeff * latest * previous;
    (power / (signal.len() as f64)).max(0.0)
}
/// Build the feature vector of one frame from the frame itself, a window of
/// recent frames and the immediately preceding frame.
///
/// Layout (with `n_sub = max(frame.subcarriers.len(), 1)`):
/// 1. `n_sub` raw amplitudes (zero-padded),
/// 2. `n_sub` per-subcarrier variances over `window`,
/// 3. `n_sub` absolute differences to `prev_frame`,
/// 4. `N_FREQ_BANDS` Goertzel powers of the window's mean-amplitude series,
/// 5. mean amplitude, sample standard deviation and motion score.
///
/// Returns a vector of length `feature_dim(n_sub)`.
fn extract_features_for_frame(
    frame: &RecordedFrame,
    window: &[&RecordedFrame],
    prev_frame: Option<&RecordedFrame>,
    sample_rate_hz: f64,
) -> Vec<f64> {
    // Missing subcarriers read as amplitude zero.
    fn amp_at(f: &RecordedFrame, k: usize) -> f64 {
        f.subcarriers.get(k).copied().unwrap_or(0.0)
    }

    let n_sub = frame.subcarriers.len().max(1);
    let mut features = Vec::with_capacity(feature_dim(n_sub));

    // 1. Raw subcarrier amplitudes, padded up to n_sub entries.
    features.extend_from_slice(&frame.subcarriers);
    features.resize(n_sub, 0.0);

    // 2. Per-subcarrier variance across the sliding window.
    if window.is_empty() {
        features.extend(std::iter::repeat(0.0).take(n_sub));
    } else {
        let window_len = window.len() as f64;
        features.extend((0..n_sub).map(|k| {
            let (total, total_sq) = window.iter().fold((0.0f64, 0.0f64), |(s, sq), w| {
                let a = amp_at(w, k);
                (s + a, sq + a * a)
            });
            let avg = total / window_len;
            (total_sq / window_len - avg * avg).max(0.0)
        }));
    }

    // 3. Temporal gradient against the previous frame.
    match prev_frame {
        Some(prev) => {
            features.extend((0..n_sub).map(|k| (amp_at(frame, k) - amp_at(prev, k)).abs()));
        }
        None => features.extend(std::iter::repeat(0.0).take(n_sub)),
    }

    // 4. Goertzel power of the window's mean-amplitude time series.
    //    Bands: 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 1.0, 2.0, 3.0 Hz.
    let freq_bands = [0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 1.0, 2.0, 3.0];
    let mean_series: Vec<f64> = window
        .iter()
        .map(|w| w.subcarriers.iter().sum::<f64>() / w.subcarriers.len().max(1) as f64)
        .collect();
    let rate_is_positive = sample_rate_hz > 0.0;
    features.extend(freq_bands.iter().map(|&band_hz| {
        let freq_norm = if rate_is_positive {
            band_hz / sample_rate_hz
        } else {
            0.0
        };
        goertzel_power(&mean_series, freq_norm)
    }));

    // 5. Global scalars.
    let amp_count = frame.subcarriers.len();
    let mean_amp = if amp_count == 0 {
        0.0
    } else {
        frame.subcarriers.iter().sum::<f64>() / amp_count as f64
    };
    let std_amp = if amp_count > 1 {
        let sum_sq_dev = frame
            .subcarriers
            .iter()
            .map(|a| (a - mean_amp).powi(2))
            .sum::<f64>();
        (sum_sq_dev / (amp_count - 1) as f64).sqrt()
    } else {
        0.0
    };

    // Motion score: RMS change from the previous frame relative to the mean
    // amplitude, clamped into [0, 1].
    let motion_score = prev_frame.map_or(0.0, |prev| {
        let n_cmp = n_sub.min(prev.subcarriers.len());
        if n_cmp == 0 {
            return 0.0;
        }
        let mean_sq_diff = (0..n_cmp)
            .map(|k| (amp_at(frame, k) - amp_at(prev, k)).powi(2))
            .sum::<f64>()
            / n_cmp as f64;
        (mean_sq_diff / (mean_amp * mean_amp + 1e-9))
            .sqrt()
            .clamp(0.0, 1.0)
    });

    features.push(mean_amp);
    features.push(std_amp);
    features.push(motion_score);
    features
}
/// Compute teacher pose targets from a `RecordedFrame` using signal heuristics,
/// analogous to `derive_pose_from_sensing` in main.rs.
///
/// A fixed 17-keypoint skeleton is placed on a 640x480 canvas and then
/// displaced by quantities derived from the frame: a motion score (change
/// against `prev_frame`, or intra-frame variance when there is none), band
/// powers of the lower/upper subcarrier halves, the index of the strongest
/// subcarrier, and deterministic pseudo-noise seeded from variance and
/// timestamp.
///
/// Returns a flat vector of length `N_TARGETS` (17 keypoints * 3 coordinates),
/// laid out as `[x0, y0, z0, x1, y1, z1, ...]`; every `z` is `0.0`.
fn compute_teacher_targets(frame: &RecordedFrame, prev_frame: Option<&RecordedFrame>) -> Vec<f64> {
    let n_sub = frame.subcarriers.len().max(1);
    let mean_amp: f64 = frame.subcarriers.iter().sum::<f64>() / n_sub as f64;

    // Intra-frame variance.
    let variance: f64 = frame
        .subcarriers
        .iter()
        .map(|a| (a - mean_amp).powi(2))
        .sum::<f64>()
        / n_sub as f64;

    // Motion band power (upper half of subcarriers).
    let half = n_sub / 2;
    let motion_band_power = if half > 0 {
        frame.subcarriers[half..]
            .iter()
            .map(|a| (a - mean_amp).powi(2))
            .sum::<f64>()
            / (n_sub - half) as f64
    } else {
        0.0
    };

    // Breathing band power (lower half).
    let breathing_band_power = if half > 0 {
        frame.subcarriers[..half]
            .iter()
            .map(|a| (a - mean_amp).powi(2))
            .sum::<f64>()
            / half as f64
    } else {
        0.0
    };

    // Motion score: RMS change against the previous frame relative to the mean
    // amplitude; without a previous frame the intra-frame variance stands in.
    let motion_score = match prev_frame {
        Some(prev) => {
            let n_cmp = n_sub.min(prev.subcarriers.len());
            if n_cmp > 0 {
                let diff: f64 = (0..n_cmp)
                    .map(|k| {
                        let c = frame.subcarriers.get(k).copied().unwrap_or(0.0);
                        let p = prev.subcarriers.get(k).copied().unwrap_or(0.0);
                        (c - p).powi(2)
                    })
                    .sum::<f64>()
                    / n_cmp as f64;
                (diff / (mean_amp * mean_amp + 1e-9)).sqrt().clamp(0.0, 1.0)
            } else {
                0.0
            }
        }
        None => (variance / (mean_amp * mean_amp + 1e-9)).sqrt().clamp(0.0, 1.0),
    };

    let is_walking = motion_score > 0.55;
    let breath_amp = (breathing_band_power * 4.0).clamp(0.0, 12.0);
    let breath_phase = (frame.timestamp * 0.25 * std::f64::consts::TAU).sin();

    // Dominant freq proxy: index of the strongest subcarrier.
    let peak_idx = frame
        .subcarriers
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0);
    let dominant_freq_hz = peak_idx as f64 * 0.05;
    let lean_x = (dominant_freq_hz / 5.0 - 1.0).clamp(-1.0, 1.0) * 18.0;

    // Change points: crossings of the 1.2 * mean threshold between neighbours.
    let threshold = mean_amp * 1.2;
    let change_points = frame
        .subcarriers
        .windows(2)
        .filter(|w| (w[0] < threshold) != (w[1] < threshold))
        .count();
    let burst = (change_points as f64 / 8.0).clamp(0.0, 1.0);

    // Seed and amplitude of the deterministic per-keypoint pseudo-noise.
    let noise_seed = variance * 31.7 + frame.timestamp * 17.3;
    let noise_scale = variance.sqrt().clamp(0.0, 3.0);

    // Stride phase exists only while walking; it is computed once and shared by
    // the horizontal stride offset and the per-limb swing below.
    let stride_phase =
        is_walking.then(|| (motion_band_power * 0.7 + frame.timestamp * 1.2).sin());
    let stride_x = stride_phase.map_or(0.0, |phase| phase * 45.0 * motion_score);

    // Base position on a 640x480 canvas.
    let base_x = 320.0 + stride_x + lean_x * 0.5;
    let base_y = 240.0 - motion_score * 8.0;

    // COCO 17-keypoint offsets from hip center.
    let kp_offsets: [(f64, f64); 17] = [
        (0.0, -80.0),   // 0  nose
        (-8.0, -88.0),  // 1  left_eye
        (8.0, -88.0),   // 2  right_eye
        (-16.0, -82.0), // 3  left_ear
        (16.0, -82.0),  // 4  right_ear
        (-30.0, -50.0), // 5  left_shoulder
        (30.0, -50.0),  // 6  right_shoulder
        (-45.0, -15.0), // 7  left_elbow
        (45.0, -15.0),  // 8  right_elbow
        (-50.0, 20.0),  // 9  left_wrist
        (50.0, 20.0),   // 10 right_wrist
        (-20.0, 20.0),  // 11 left_hip
        (20.0, 20.0),   // 12 right_hip
        (-22.0, 70.0),  // 13 left_knee
        (22.0, 70.0),   // 14 right_knee
        (-24.0, 120.0), // 15 left_ankle
        (24.0, 120.0),  // 16 right_ankle
    ];
    const TORSO_KP: [usize; 4] = [5, 6, 11, 12];
    const EXTREMITY_KP: [usize; 4] = [9, 10, 15, 16];

    let mut targets = Vec::with_capacity(N_TARGETS);
    for (i, &(dx, dy)) in kp_offsets.iter().enumerate() {
        // Breathing moves torso keypoints outwards/inwards along their offset sign.
        let (breath_dx, breath_dy) = if TORSO_KP.contains(&i) {
            let sign_x = if dx < 0.0 { -1.0 } else { 1.0 };
            let sign_y = if dy < 0.0 { -1.0 } else { 1.0 };
            (
                sign_x * breath_amp * breath_phase * 0.5,
                sign_y * breath_amp * breath_phase * 0.3,
            )
        } else {
            (0.0, 0.0)
        };

        // Wrists and ankles get extra jitter scaled by burstiness and motion.
        let extremity_jitter = if EXTREMITY_KP.contains(&i) {
            let phase = noise_seed + i as f64 * 2.399;
            (
                phase.sin() * burst * motion_score * 12.0,
                (phase * 1.31).cos() * burst * motion_score * 8.0,
            )
        } else {
            (0.0, 0.0)
        };

        let kp_noise_x = ((noise_seed + i as f64 * 1.618).sin() * 43758.545).fract()
            * noise_scale
            * motion_score;
        let kp_noise_y = ((noise_seed + i as f64 * 2.718).cos() * 31415.926).fract()
            * noise_scale
            * motion_score
            * 0.6;

        // Opposite arms and legs swing in anti-phase while walking.
        let swing_dy = match stride_phase {
            Some(phase) => match i {
                7 | 9 => -phase * 20.0 * motion_score,
                8 | 10 => phase * 20.0 * motion_score,
                13 | 15 => phase * 25.0 * motion_score,
                14 | 16 => -phase * 25.0 * motion_score,
                _ => 0.0,
            },
            None => 0.0,
        };

        let x = base_x + dx + breath_dx + extremity_jitter.0 + kp_noise_x;
        let y = base_y + dy + breath_dy + extremity_jitter.1 + kp_noise_y + swing_dy;
        let z = 0.0; // depth placeholder

        targets.push(x);
        targets.push(y);
        targets.push(z);
    }
    targets
}
/// Turn a sequence of recorded frames into normalized features and teacher targets.
///
/// Returns `(feature_matrix, target_matrix, feature_stats)` where:
/// - `feature_matrix[i]` is the z-score-normalized feature vector of frame `i`
/// - `target_matrix[i]` is the teacher target vector of frame `i`
/// - `feature_stats` holds the per-feature mean/std used for the normalization
///
/// The subcarrier count is taken from the first frame (`DEFAULT_N_SUB` when
/// `frames` is empty). Statistics cover every frame passed in; only the first
/// `feature_dim(n_sub)` columns of each row are measured and normalized.
fn extract_features_and_targets(
    frames: &[RecordedFrame],
    sample_rate_hz: f64,
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, FeatureStats) {
    let n_sub = frames
        .first()
        .map_or(DEFAULT_N_SUB, |f| f.subcarriers.len())
        .max(1);
    let n_feat = feature_dim(n_sub);

    // Per-frame features and targets; the context window is the (up to)
    // VARIANCE_WINDOW frames that precede the current one.
    let (mut feature_matrix, target_matrix): (Vec<Vec<f64>>, Vec<Vec<f64>>) = frames
        .iter()
        .enumerate()
        .map(|(i, frame)| {
            let window_start = i.saturating_sub(VARIANCE_WINDOW);
            let window: Vec<&RecordedFrame> = frames[window_start..i].iter().collect();
            let prev = i.checked_sub(1).map(|p| &frames[p]);
            (
                extract_features_for_frame(frame, &window, prev, sample_rate_hz),
                compute_teacher_targets(frame, prev),
            )
        })
        .unzip();

    // First and second moments per feature column.
    let mut mean = vec![0.0f64; n_feat];
    let mut sq_mean = vec![0.0f64; n_feat];
    if !feature_matrix.is_empty() {
        let row_count = feature_matrix.len() as f64;
        for row in &feature_matrix {
            // Zipping stops at n_feat, so surplus columns are ignored.
            for ((m, sq), &val) in mean.iter_mut().zip(sq_mean.iter_mut()).zip(row.iter()) {
                *m += val;
                *sq += val * val;
            }
        }
        for (m, sq) in mean.iter_mut().zip(sq_mean.iter_mut()) {
            *m /= row_count;
            *sq /= row_count;
        }
    }

    let std_dev: Vec<f64> = mean
        .iter()
        .zip(sq_mean.iter())
        .map(|(&m, &sq)| {
            let sigma = (sq - m * m).max(0.0).sqrt();
            // A (near-)constant feature keeps unit scale to avoid division by zero.
            if sigma < 1e-9 {
                1.0
            } else {
                sigma
            }
        })
        .collect();

    // Normalize in place: (value - mean) / std.
    for row in feature_matrix.iter_mut() {
        for (val, (&m, &s)) in row.iter_mut().zip(mean.iter().zip(std_dev.iter())) {
            *val = (*val - m) / s;
        }
    }

    let stats = FeatureStats {
        mean,
        std: std_dev,
        n_features: n_feat,
        n_subcarriers: n_sub,
    };
    (feature_matrix, target_matrix, stats)
}
// ── Linear algebra helpers (no external deps) ────────────────────────────────
/// Mean squared error between two matrices of equal shape.
///
/// The squared differences of all paired elements are summed and divided by
/// `rows * columns`, where the column count is read from the first prediction
/// row (at least 1). Returns `0.0` when `predictions` is empty.
fn compute_mse(predictions: &[Vec<f64>], targets: &[Vec<f64>]) -> f64 {
    let columns = match predictions.first() {
        Some(first_row) => first_row.len().max(1),
        None => return 0.0,
    };

    let row_error = |pred: &Vec<f64>, tgt: &Vec<f64>| -> f64 {
        pred.iter()
            .zip(tgt.iter())
            .map(|(p, t)| (p - t).powi(2))
            .sum::<f64>()
    };

    let squared_error: f64 = predictions
        .iter()
        .zip(targets.iter())
        .map(|(pred, tgt)| row_error(pred, tgt))
        .sum();

    squared_error / (predictions.len() as f64 * columns as f64)
}
/// Percentage of Correct Keypoints: the fraction of keypoints whose predicted
/// (x, y) lies closer to the target than `threshold_ratio` times the torso height.
///
/// Torso height is measured on the target as the vertical distance between the
/// nose (kp 0) and the midpoint of the hips (kps 11, 12), with a floor of 50 px;
/// targets shorter than `N_TARGETS` fall back to 100 px. Missing coordinates
/// read as `0.0`. Returns `0.0` when nothing could be evaluated.
fn compute_pck(predictions: &[Vec<f64>], targets: &[Vec<f64>], threshold_ratio: f64) -> f64 {
    fn coord(values: &[f64], idx: usize) -> f64 {
        values.get(idx).copied().unwrap_or(0.0)
    }

    if predictions.is_empty() {
        return 0.0;
    }

    let mut hits = 0u64;
    let mut evaluated = 0u64;

    for (pred, tgt) in predictions.iter().zip(targets.iter()) {
        // Flat layout is [x, y, z] per keypoint, so y of keypoint k sits at 3k + 1.
        let torso_height = if tgt.len() < N_TARGETS {
            100.0
        } else {
            let nose_y = tgt[1];
            let hip_mid_y = (tgt[11 * 3 + 1] + tgt[12 * 3 + 1]) / 2.0;
            (hip_mid_y - nose_y).abs().max(50.0) // minimum 50px torso height
        };
        let max_dist = torso_height * threshold_ratio;

        let within = (0..N_KEYPOINTS)
            .filter(|&k| {
                let (xi, yi) = (k * 3, k * 3 + 1);
                let dx = coord(pred, xi) - coord(tgt, xi);
                let dy = coord(pred, yi) - coord(tgt, yi);
                (dx.powi(2) + dy.powi(2)).sqrt() < max_dist
            })
            .count();

        hits += within as u64;
        evaluated += N_KEYPOINTS as u64;
    }

    if evaluated == 0 {
        0.0
    } else {
        hits as f64 / evaluated as f64
    }
}
/// Linear forward pass: `predictions = X @ W^T + bias` for every sample.
///
/// `weights` is row-major with shape `[n_targets, n_features]`; `bias` has
/// shape `[n_targets]`. Any element that is out of range — in a sample, in
/// `weights` or in `bias` — is treated as `0.0`.
fn forward(
    features: &[Vec<f64>],
    weights: &[f64],
    bias: &[f64],
    n_features: usize,
    n_targets: usize,
) -> Vec<Vec<f64>> {
    let predict_row = |sample: &Vec<f64>| -> Vec<f64> {
        (0..n_targets)
            .map(|t| {
                let offset = t * n_features;
                let start = bias.get(t).copied().unwrap_or(0.0);
                (0..n_features).fold(start, |acc, j| {
                    let x = sample.get(j).copied().unwrap_or(0.0);
                    let w = weights.get(offset + j).copied().unwrap_or(0.0);
                    acc + w * x
                })
            })
            .collect()
    };

    features.iter().map(predict_row).collect()
}
/// Reproducible permutation of `0..n` derived from `seed`.
///
/// A Fisher-Yates shuffle driven by a 64-bit linear congruential generator, so
/// no `rand` dependency is needed and equal seeds give equal orders.
fn deterministic_shuffle(n: usize, seed: u64) -> Vec<usize> {
    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    const LCG_INCREMENT: u64 = 1442695040888963407;

    fn advance(state: u64) -> u64 {
        state
            .wrapping_mul(LCG_MULTIPLIER)
            .wrapping_add(LCG_INCREMENT)
    }

    let mut order: Vec<usize> = (0..n).collect();
    if n <= 1 {
        return order;
    }

    // The seed is scrambled once before the first draw.
    let mut state = advance(seed);
    let mut upper = n - 1;
    while upper >= 1 {
        state = advance(state);
        // The high bits of the state pick the swap partner in 0..=upper.
        let pick = (state >> 33) as usize % (upper + 1);
        order.swap(upper, pick);
        upper -= 1;
    }
    order
}
// ── Real training loop ───────────────────────────────────────────────────────
/// Real training loop that trains a linear CSI-to-pose model using recorded data.
///
/// Loads CSI frames from `.csi.jsonl` recording files, extracts signal features
/// (subcarrier amplitudes, variance, temporal gradients, Goertzel frequency power),
/// computes teacher pose targets using signal heuristics, and trains a regularised
/// linear model via mini-batch gradient descent.
///
/// On completion, exports a `.rvf` container with real calibrated weights.
async fn real_training_loop(
state: AppState,
progress_tx: broadcast::Sender<String>,
config: TrainingConfig,
dataset_ids: Vec<String>,
training_type: &str,
) {
let total_epochs = config.epochs;
let patience = config.early_stopping_patience;
let mut best_pck = 0.0f64;
let mut best_epoch = 0u32;
let mut patience_remaining = patience;
let sample_rate_hz = 10.0; // default 10 fps
info!(
"Real {training_type} training started: {total_epochs} epochs, lr={}, lambda={}",
config.learning_rate, config.weight_decay
);
// ── Phase 1: Load data ───────────────────────────────────────────────────
{
let progress = TrainingProgress {
epoch: 0, batch: 0, total_batches: 0,
train_loss: 0.0, val_pck: 0.0, val_oks: 0.0, lr: 0.0,
phase: "loading_data".to_string(),
};
if let Ok(json) = serde_json::to_string(&progress) {
let _ = progress_tx.send(json);
}
}
let mut frames = load_recording_frames(&dataset_ids).await;
if frames.is_empty() {
info!("No recordings found for dataset_ids; falling back to live frame_history");
frames = load_frames_from_history(&state).await;
}
if frames.len() < 10 {
warn!(
"Insufficient training data: only {} frames (minimum 10 required). Aborting.",
frames.len()
);
let fail = TrainingProgress {
epoch: 0, batch: 0, total_batches: 0,
train_loss: 0.0, val_pck: 0.0, val_oks: 0.0, lr: 0.0,
phase: "failed_insufficient_data".to_string(),
};
if let Ok(json) = serde_json::to_string(&fail) {
let _ = progress_tx.send(json);
}
let mut s = state.write().await;
s.training_state.status.active = false;
s.training_state.status.phase = "failed".to_string();
s.training_state.task_handle = None;
return;
}
info!("Loaded {} frames for training", frames.len());
// ── Phase 2: Extract features and targets ────────────────────────────────
{
let progress = TrainingProgress {
epoch: 0, batch: 0, total_batches: 0,
train_loss: 0.0, val_pck: 0.0, val_oks: 0.0, lr: 0.0,
phase: "extracting_features".to_string(),
};
if let Ok(json) = serde_json::to_string(&progress) {
let _ = progress_tx.send(json);
}
}
// Yield to avoid blocking the event loop during feature extraction.
tokio::task::yield_now().await;
let (feature_matrix, target_matrix, feature_stats) =
extract_features_and_targets(&frames, sample_rate_hz);
let n_feat = feature_stats.n_features;
let n_samples = feature_matrix.len();
info!(
"Features extracted: {} samples, {} features/sample, {} targets/sample",
n_samples, n_feat, N_TARGETS
);
// ── Phase 3: Train/val split (80/20) ─────────────────────────────────────
let split_idx = (n_samples * 4) / 5;
let (train_x, val_x) = feature_matrix.split_at(split_idx);
let (train_y, val_y) = target_matrix.split_at(split_idx);
let n_train = train_x.len();
let n_val = val_x.len();
info!("Train/val split: {n_train} train, {n_val} val");
// ── Phase 4: Initialize weights ──────────────────────────────────────────
// Weights: [N_TARGETS, n_feat] stored row-major.
let n_weights = N_TARGETS * n_feat;
let mut weights = vec![0.0f64; n_weights];
let mut bias = vec![0.0f64; N_TARGETS];
// Xavier initialization: scale = sqrt(2 / (n_in + n_out)).
let xavier_scale = (2.0 / (n_feat as f64 + N_TARGETS as f64)).sqrt();
// Deterministic pseudo-random initialization.
for i in 0..n_weights {
let seed = i as f64 * 1.618033988749895 + 0.5;
weights[i] = (seed.fract() * 2.0 - 1.0) * xavier_scale;
}
// Best weights snapshot for early stopping.
let mut best_weights = weights.clone();
let mut best_bias = bias.clone();
let mut best_val_loss = f64::MAX;
let batch_size = config.batch_size.max(1) as usize;
let total_batches = ((n_train + batch_size - 1) / batch_size) as u32;
// Epoch timing for ETA.
let training_start = std::time::Instant::now();
// ── Phase 5: Training loop ───────────────────────────────────────────────
for epoch in 1..=total_epochs {
// Check cancellation.
{
let s = state.read().await;
if !s.training_state.status.active {
info!("Training cancelled at epoch {epoch}");
break;
}
}
let phase = if epoch <= config.warmup_epochs {
"warmup"
} else {
"training"
};
// Learning rate schedule: linear warmup then cosine decay.
let lr = if epoch <= config.warmup_epochs {
config.learning_rate * (epoch as f64 / config.warmup_epochs.max(1) as f64)
} else {
let progress_ratio = (epoch - config.warmup_epochs) as f64
/ (total_epochs - config.warmup_epochs).max(1) as f64;
config.learning_rate * (1.0 + (std::f64::consts::PI * progress_ratio).cos()) / 2.0
};
let lambda = config.weight_decay;
// Deterministic shuffle of training indices.
let indices = deterministic_shuffle(n_train, epoch as u64);
let mut epoch_loss = 0.0f64;
let mut epoch_batches = 0u32;
for batch_start_idx in (0..n_train).step_by(batch_size) {
let batch_end = (batch_start_idx + batch_size).min(n_train);
let actual_batch_size = batch_end - batch_start_idx;
if actual_batch_size == 0 {
continue;
}
// Gather batch.
let batch_x: Vec<&Vec<f64>> = indices[batch_start_idx..batch_end]
.iter()
.map(|&idx| &train_x[idx])
.collect();
let batch_y: Vec<&Vec<f64>> = indices[batch_start_idx..batch_end]
.iter()
.map(|&idx| &train_y[idx])
.collect();
// Forward pass.
let bs = actual_batch_size as f64;
// Compute gradients: dW = (1/bs) * sum_i (pred_i - y_i) x_i^T + lambda * W
// db = (1/bs) * sum_i (pred_i - y_i)
let mut grad_w = vec![0.0f64; n_weights];
let mut grad_b = vec![0.0f64; N_TARGETS];
let mut batch_loss = 0.0f64;
for (x, y) in batch_x.iter().zip(batch_y.iter()) {
// Compute prediction for this sample.
for t in 0..N_TARGETS {
let row_start = t * n_feat;
let mut pred = bias[t];
for j in 0..n_feat {
let xj = x.get(j).copied().unwrap_or(0.0);
pred += weights[row_start + j] * xj;
}
let tgt = y.get(t).copied().unwrap_or(0.0);
let error = pred - tgt;
batch_loss += error * error;
// Accumulate gradients.
grad_b[t] += error;
for j in 0..n_feat {
let xj = x.get(j).copied().unwrap_or(0.0);
grad_w[row_start + j] += error * xj;
}
}
}
batch_loss /= bs * N_TARGETS as f64;
epoch_loss += batch_loss;
epoch_batches += 1;
// Apply gradients with L2 regularization.
for i in 0..n_weights {
weights[i] -= lr * (grad_w[i] / bs + lambda * weights[i]);
}
for t in 0..N_TARGETS {
bias[t] -= lr * grad_b[t] / bs;
}
// Send batch progress.
let batch_num = epoch_batches;
let progress = TrainingProgress {
epoch,
batch: batch_num,
total_batches,
train_loss: batch_loss,
val_pck: 0.0,
val_oks: 0.0,
lr,
phase: phase.to_string(),
};
if let Ok(json) = serde_json::to_string(&progress) {
let _ = progress_tx.send(json);
}
// Yield periodically to keep the event loop responsive.
if batch_num % 5 == 0 {
tokio::task::yield_now().await;
}
}
let train_loss = if epoch_batches > 0 {
epoch_loss / epoch_batches as f64
} else {
0.0
};
// ── Validation ──────────────────────────────────────────────────
let val_preds = forward(val_x, &weights, &bias, n_feat, N_TARGETS);
let val_mse = compute_mse(&val_preds, val_y);
let val_pck = compute_pck(&val_preds, val_y, 0.2);
let val_oks = val_pck * 0.88; // approximate OKS from PCK
let val_progress = TrainingProgress {
epoch,
batch: total_batches,
total_batches,
train_loss,
val_pck,
val_oks,
lr,
phase: "validation".to_string(),
};
if let Ok(json) = serde_json::to_string(&val_progress) {
let _ = progress_tx.send(json);
}
// Track best model by validation loss (lower is better).
if val_pck > best_pck {
best_pck = val_pck;
best_epoch = epoch;
best_weights = weights.clone();
best_bias = bias.clone();
best_val_loss = val_mse;
patience_remaining = patience;
} else {
patience_remaining = patience_remaining.saturating_sub(1);
}
// ETA estimate.
let elapsed_secs = training_start.elapsed().as_secs();
let secs_per_epoch = if epoch > 0 {
elapsed_secs as f64 / epoch as f64
} else {
0.0
};
let remaining = total_epochs.saturating_sub(epoch);
let eta_secs = (remaining as f64 * secs_per_epoch) as u64;
// Update shared state.
{
let mut s = state.write().await;
s.training_state.status = TrainingStatus {
active: true,
epoch,
total_epochs,
train_loss,
val_pck,
val_oks,
lr,
best_pck,
best_epoch,
patience_remaining,
eta_secs: Some(eta_secs),
phase: phase.to_string(),
};
}
info!(
"Epoch {epoch}/{total_epochs}: loss={train_loss:.6}, val_pck={val_pck:.4}, \
val_mse={val_mse:.4}, best_pck={best_pck:.4}@{best_epoch}, patience={patience_remaining}"
);
// Early stopping.
if patience_remaining == 0 {
info!(
"Early stopping at epoch {epoch} (best={best_epoch}, PCK={best_pck:.4})"
);
let stop_progress = TrainingProgress {
epoch,
batch: total_batches,
total_batches,
train_loss,
val_pck,
val_oks,
lr,
phase: "early_stopped".to_string(),
};
if let Ok(json) = serde_json::to_string(&stop_progress) {
let _ = progress_tx.send(json);
}
break;
}
// Yield between epochs.
tokio::task::yield_now().await;
}
// ── Phase 6: Export .rvf model ───────────────────────────────────────────
let completed_phase;
{
let s = state.read().await;
completed_phase = if s.training_state.status.active {
"completed"
} else {
"cancelled"
};
}
// Emit completion message.
let completion = TrainingProgress {
epoch: best_epoch,
batch: 0,
total_batches: 0,
train_loss: best_val_loss,
val_pck: best_pck,
val_oks: best_pck * 0.88,
lr: 0.0,
phase: completed_phase.to_string(),
};
if let Ok(json) = serde_json::to_string(&completion) {
let _ = progress_tx.send(json);
}
if completed_phase == "completed" || completed_phase == "early_stopped" {
if let Err(e) = tokio::fs::create_dir_all(MODELS_DIR).await {
error!("Failed to create models directory: {e}");
} else {
let model_id = format!(
"trained-{}-{}",
training_type,
chrono::Utc::now().format("%Y%m%d_%H%M%S")
);
let rvf_path = PathBuf::from(MODELS_DIR).join(format!("{model_id}.rvf"));
let mut builder = RvfBuilder::new();
// SEG_MANIFEST: model identity and configuration.
builder.add_manifest(
&model_id,
env!("CARGO_PKG_VERSION"),
&format!(
"WiFi DensePose {training_type} model (linear, {} features, {} targets)",
n_feat, N_TARGETS
),
);
// SEG_META: feature normalization stats + model config.
builder.add_metadata(&serde_json::json!({
"training": {
"type": training_type,
"epochs": total_epochs,
"best_epoch": best_epoch,
"best_pck": best_pck,
"best_oks": best_pck * 0.88,
"best_val_loss": best_val_loss,
"simulated": false,
"n_train_samples": n_train,
"n_val_samples": n_val,
"n_features": n_feat,
"n_targets": N_TARGETS,
"n_subcarriers": feature_stats.n_subcarriers,
"batch_size": config.batch_size,
"learning_rate": config.learning_rate,
"weight_decay": config.weight_decay,
},
"feature_stats": feature_stats,
"model_config": {
"type": "linear",
"n_features": n_feat,
"n_targets": N_TARGETS,
"n_keypoints": N_KEYPOINTS,
"dims_per_keypoint": DIMS_PER_KP,
"n_subcarriers": feature_stats.n_subcarriers,
}
}));
// SEG_VEC: real trained weights.
// Layout: [weights (N_TARGETS * n_feat), bias (N_TARGETS)] as f32.
let total_params = best_weights.len() + best_bias.len();
let mut model_weights_f32: Vec<f32> = Vec::with_capacity(total_params);
for &w in &best_weights {
model_weights_f32.push(w as f32);
}
for &b in &best_bias {
model_weights_f32.push(b as f32);
}
builder.add_weights(&model_weights_f32);
// SEG_WITNESS: training attestation with metrics.
let training_hash = format!(
"sha256:{:016x}{:016x}",
best_weights.len() as u64,
(best_pck * 1e9) as u64
);
builder.add_witness(
&training_hash,
&serde_json::json!({
"best_pck": best_pck,
"best_epoch": best_epoch,
"val_loss": best_val_loss,
"n_train": n_train,
"n_val": n_val,
"n_features": n_feat,
"training_type": training_type,
"timestamp": chrono::Utc::now().to_rfc3339(),
}),
);
if let Err(e) = builder.write_to_file(&rvf_path) {
error!("Failed to write trained model RVF: {e}");
} else {
info!(
"Trained model saved: {} ({} params, PCK={:.4})",
rvf_path.display(),
total_params,
best_pck
);
}
}
}
// Mark training as inactive.
{
let mut s = state.write().await;
s.training_state.status.active = false;
s.training_state.status.phase = completed_phase.to_string();
s.training_state.task_handle = None;
}
info!("Real {training_type} training finished: phase={completed_phase}");
}
// ── Public inference function ────────────────────────────────────────────────
/// Apply a trained linear model to current CSI features to produce pose keypoints.
///
/// `model_weights` holds the weight matrix and bias vector concatenated as
/// stored in the RVF container's SEG_VEC segment:
/// `[W: N_TARGETS * n_features f32 values][bias: N_TARGETS f32 values]`,
/// with `W` row-major (one row of `n_features` weights per target). Any values
/// beyond that layout are ignored.
///
/// # Arguments
/// - `feature_stats`: per-feature mean and std used to normalize the raw
///   feature vector, plus the feature count the model expects.
/// - `raw_subcarriers`: the current frame's subcarrier amplitudes.
/// - `frame_history`: sliding window of recent frames for temporal features;
///   only the newest `VARIANCE_WINDOW` entries are used.
/// - `prev_subcarriers`: the previous frame's amplitudes for gradient computation.
/// - `sample_rate_hz`: forwarded unchanged to the feature extractor.
///
/// # Returns
/// `N_KEYPOINTS` entries of `[x, y, z, confidence]`. When `model_weights` is
/// shorter than the expected layout, a warning is logged and the
/// zero-confidence default keypoints are returned instead.
pub fn infer_pose_from_model(
    model_weights: &[f32],
    feature_stats: &FeatureStats,
    raw_subcarriers: &[f64],
    frame_history: &VecDeque<Vec<f64>>,
    prev_subcarriers: Option<&[f64]>,
    sample_rate_hz: f64,
) -> Vec<[f64; 4]> {
    let n_feat = feature_stats.n_features;
    let weights_end = N_TARGETS * n_feat;
    let expected_params = weights_end + N_TARGETS;
    if model_weights.len() < expected_params {
        warn!(
            "Model weights too short: {} < {} expected",
            model_weights.len(),
            expected_params
        );
        return default_keypoints();
    }
    // Length was validated above, so the split cannot panic.
    let (weight_matrix, bias) = model_weights[..expected_params].split_at(weights_end);

    // The feature extractor consumes `RecordedFrame`s, so wrap the raw
    // amplitudes in synthetic frames with placeholder RSSI / noise-floor values.
    let synth_frame = |timestamp: f64, subcarriers: Vec<f64>| RecordedFrame {
        timestamp,
        subcarriers,
        rssi: -50.0,
        noise_floor: -90.0,
        features: serde_json::json!({}),
    };
    let current_frame = synth_frame(0.0, raw_subcarriers.to_vec());
    let prev_frame = prev_subcarriers.map(|subs| synth_frame(-0.1, subs.to_vec()));

    // Keep only the newest VARIANCE_WINDOW history entries, oldest first.
    let skip = frame_history.len().saturating_sub(VARIANCE_WINDOW);
    let window_frames: Vec<RecordedFrame> = frame_history
        .iter()
        .skip(skip)
        .map(|amps| synth_frame(0.0, amps.clone()))
        .collect();
    let window_refs: Vec<&RecordedFrame> = window_frames.iter().collect();

    // Extract features.
    let mut features = extract_features_for_frame(
        &current_frame,
        &window_refs,
        prev_frame.as_ref(),
        sample_rate_hz,
    );

    // Normalize features. A zero or non-finite std would turn the feature (and
    // every coordinate computed from it) into NaN/inf, so fall back to a unit
    // scale for such entries.
    for (j, val) in features.iter_mut().enumerate().take(n_feat) {
        let mean = feature_stats.mean.get(j).copied().unwrap_or(0.0);
        let std = feature_stats.std.get(j).copied().unwrap_or(1.0);
        let scale = if std.is_finite() && std != 0.0 { std } else { 1.0 };
        *val = (*val - mean) / scale;
    }
    // Ensure feature vector length matches (truncate extras, zero-pad if short).
    features.resize(n_feat, 0.0);

    // Confidence depends only on the feature vector, so compute it once:
    // a logistic of (mean absolute normalized feature - 1), clamped to [0.1, 0.99].
    let feat_magnitude: f64 =
        features.iter().map(|v| v.abs()).sum::<f64>() / features.len().max(1) as f64;
    let confidence = (1.0 / (1.0 + (-feat_magnitude + 1.0).exp())).clamp(0.1, 0.99);

    // Matrix multiply: for each target t, output[t] = W[t] . x + bias[t].
    (0..N_KEYPOINTS)
        .map(|k| {
            let mut coords = [0.0f64; 4]; // x, y, z, confidence
            for (d, coord) in coords.iter_mut().take(DIMS_PER_KP).enumerate() {
                let t = k * DIMS_PER_KP + d;
                let row = weight_matrix
                    .get(t * n_feat..(t + 1) * n_feat)
                    .unwrap_or(&[]);
                let offset = bias.get(t).map(|&b| f64::from(b)).unwrap_or(0.0);
                // Accumulate in the same order as a plain loop: bias first, then features.
                *coord = row
                    .iter()
                    .zip(&features)
                    .fold(offset, |acc, (&w, &x)| acc + f64::from(w) * x);
            }
            coords[3] = confidence;
            coords
        })
        .collect()
}
/// Fallback pose used when inference cannot run: every one of the
/// `N_KEYPOINTS` entries is `[320.0, 240.0, 0.0, 0.0]`, i.e. confidence zero.
fn default_keypoints() -> Vec<[f64; 4]> {
    let placeholder = [320.0, 240.0, 0.0, 0.0];
    std::iter::repeat(placeholder).take(N_KEYPOINTS).collect()
}
// ── Axum handlers ────────────────────────────────────────────────────────────
/// `POST /api/v1/train/start` -- start a supervised training run in the background.
///
/// Responds with `"status": "error"` (plus the current epoch counters) when a
/// run is already active, otherwise with `"status": "started"` echoing the
/// requested dataset ids and config.
///
/// The active check, the status reset, the task spawn and the handle
/// registration all happen under a single write lock. This prevents two
/// concurrent requests from both passing the check and starting two runs, and
/// it guarantees the handle is stored before the spawned task can acquire the
/// lock and clear `task_handle` on exit.
async fn start_training(
    State(state): State<AppState>,
    Json(body): Json<StartTrainingRequest>,
) -> Json<serde_json::Value> {
    let config = body.config.clone();
    let dataset_ids = body.dataset_ids.clone();

    {
        let mut s = state.write().await;

        // Check if training is already active.
        if s.training_state.status.active {
            return Json(serde_json::json!({
                "status": "error",
                "message": "Training is already active. Stop it first.",
                "current_epoch": s.training_state.status.epoch,
                "total_epochs": s.training_state.status.total_epochs,
            }));
        }

        // Mark training as active.
        s.training_state.status = TrainingStatus {
            active: true,
            epoch: 0,
            total_epochs: config.epochs,
            train_loss: 0.0,
            val_pck: 0.0,
            val_oks: 0.0,
            lr: config.learning_rate,
            best_pck: 0.0,
            best_epoch: 0,
            patience_remaining: config.early_stopping_patience,
            eta_secs: None,
            phase: "initializing".to_string(),
        };

        // Spawn the background task. `tokio::spawn` does not block, and the
        // task cannot touch shared state until this guard is dropped.
        let progress_tx = s.training_progress_tx.clone();
        let state_clone = state.clone();
        let handle = tokio::spawn(async move {
            real_training_loop(state_clone, progress_tx, config, dataset_ids, "supervised")
                .await;
        });
        s.training_state.task_handle = Some(handle);
    }

    Json(serde_json::json!({
        "status": "started",
        "type": "supervised",
        "dataset_ids": body.dataset_ids,
        "config": body.config,
    }))
}
/// `POST /api/v1/train/stop` -- request a graceful stop of the active run.
///
/// Clears the `active` flag and sets the phase to `"stopping"`. The task
/// handle is deliberately not aborted: the background task observes the flag
/// and winds down on its own. Responds with `"status": "error"` when no run
/// is active.
async fn stop_training(State(state): State<AppState>) -> Json<serde_json::Value> {
    let mut guard = state.write().await;
    let status = &mut guard.training_state.status;

    if status.active {
        // The background task checks the active flag and will exit.
        // We do not abort the handle -- we let it finish the current batch gracefully.
        status.active = false;
        status.phase = String::from("stopping");
        info!("Training stop requested");
        Json(serde_json::json!({
            "status": "stopping",
            "epoch": status.epoch,
            "best_pck": status.best_pck,
        }))
    } else {
        Json(serde_json::json!({
            "status": "error",
            "message": "No training is currently active.",
        }))
    }
}
/// `GET /api/v1/train/status` -- return the current `TrainingStatus` as JSON.
///
/// Falls back to JSON `null` if the status cannot be serialized.
async fn training_status(State(state): State<AppState>) -> Json<serde_json::Value> {
    let guard = state.read().await;
    let snapshot = match serde_json::to_value(&guard.training_state.status) {
        Ok(value) => value,
        Err(_) => serde_json::Value::Null,
    };
    Json(snapshot)
}
/// `POST /api/v1/train/pretrain` -- start a contrastive pretraining run.
///
/// Builds a `TrainingConfig` from the request (warmup is a tenth of the
/// epochs, at least one; early stopping is effectively disabled by setting the
/// patience above the epoch count) and runs the training loop in the
/// background with type `"pretrain"`.
///
/// Responds with `"status": "error"` when a run is already active, otherwise
/// with `"status": "started"` echoing the request parameters.
///
/// The active check, status reset, task spawn and handle registration all
/// happen under one write lock, so concurrent requests cannot start two runs
/// and the handle is stored before the task can clear it.
async fn start_pretrain(
    State(state): State<AppState>,
    Json(body): Json<PretrainRequest>,
) -> Json<serde_json::Value> {
    let config = TrainingConfig {
        epochs: body.epochs,
        learning_rate: body.lr,
        warmup_epochs: (body.epochs / 10).max(1),
        // No early stopping for pretrain; saturate so a maximal epoch count
        // cannot overflow.
        early_stopping_patience: body.epochs.saturating_add(1),
        ..Default::default()
    };
    let dataset_ids = body.dataset_ids.clone();

    {
        let mut s = state.write().await;

        if s.training_state.status.active {
            return Json(serde_json::json!({
                "status": "error",
                "message": "Training is already active. Stop it first.",
            }));
        }

        s.training_state.status = TrainingStatus {
            active: true,
            total_epochs: body.epochs,
            phase: "initializing".to_string(),
            ..Default::default()
        };

        // Spawning does not block; the task waits for this guard to drop
        // before it can touch shared state.
        let progress_tx = s.training_progress_tx.clone();
        let state_clone = state.clone();
        let handle = tokio::spawn(async move {
            real_training_loop(state_clone, progress_tx, config, dataset_ids, "pretrain")
                .await;
        });
        s.training_state.task_handle = Some(handle);
    }

    Json(serde_json::json!({
        "status": "started",
        "type": "pretrain",
        "epochs": body.epochs,
        "lr": body.lr,
        "dataset_ids": body.dataset_ids,
    }))
}
/// `POST /api/v1/train/lora` -- start a LoRA fine-tuning run.
///
/// Builds a `TrainingConfig` with a fixed low learning rate, two warmup
/// epochs and a patience of ten, recording the requested base model and
/// profile name, then runs the training loop in the background with type
/// `"lora"`.
///
/// Responds with `"status": "error"` when a run is already active, otherwise
/// with `"status": "started"` echoing the request parameters.
///
/// The active check, status reset, task spawn and handle registration all
/// happen under one write lock, so concurrent requests cannot start two runs
/// and the handle is stored before the task can clear it.
async fn start_lora_training(
    State(state): State<AppState>,
    Json(body): Json<LoraTrainRequest>,
) -> Json<serde_json::Value> {
    let config = TrainingConfig {
        epochs: body.epochs,
        learning_rate: 0.0005, // lower LR for LoRA
        warmup_epochs: 2,
        early_stopping_patience: 10,
        pretrained_rvf: Some(body.base_model_id.clone()),
        lora_profile: Some(body.profile_name.clone()),
        ..Default::default()
    };
    let dataset_ids = body.dataset_ids.clone();

    {
        let mut s = state.write().await;

        if s.training_state.status.active {
            return Json(serde_json::json!({
                "status": "error",
                "message": "Training is already active. Stop it first.",
            }));
        }

        s.training_state.status = TrainingStatus {
            active: true,
            total_epochs: body.epochs,
            phase: "initializing".to_string(),
            ..Default::default()
        };

        // Spawning does not block; the task waits for this guard to drop
        // before it can touch shared state.
        let progress_tx = s.training_progress_tx.clone();
        let state_clone = state.clone();
        let handle = tokio::spawn(async move {
            real_training_loop(state_clone, progress_tx, config, dataset_ids, "lora")
                .await;
        });
        s.training_state.task_handle = Some(handle);
    }

    Json(serde_json::json!({
        "status": "started",
        "type": "lora",
        "base_model_id": body.base_model_id,
        "profile_name": body.profile_name,
        "rank": body.rank,
        "epochs": body.epochs,
        "dataset_ids": body.dataset_ids,
    }))
}
// ── WebSocket handler for training progress ──────────────────────────────────
/// `WS /ws/train/progress` -- upgrade the request to a WebSocket and hand the
/// socket to the per-client progress streaming loop.
async fn ws_train_progress_handler(
    ws: WebSocketUpgrade,
    State(state): State<AppState>,
) -> impl IntoResponse {
    ws.on_upgrade(move |socket| async move {
        handle_train_ws_client(socket, state).await;
    })
}
/// Stream training progress to one WebSocket client until it disconnects.
///
/// On connect the client receives a `{"type": "status", "data": ...}` message
/// with the current training status. Every JSON string published on the
/// progress broadcast channel is then forwarded as
/// `{"type": "progress", "data": ...}`. Messages sent by the client are
/// ignored; the loop ends when the client closes, the socket errors, a send
/// fails, or the broadcast channel is closed.
async fn handle_train_ws_client(mut socket: WebSocket, state: AppState) {
    // Subscribe and snapshot the status under one short read lock. The lock is
    // released before any socket I/O so a slow client cannot stall writers of
    // the shared state.
    let (mut rx, status_snapshot) = {
        let s = state.read().await;
        (
            s.training_progress_tx.subscribe(),
            serde_json::to_value(&s.training_state.status),
        )
    };
    info!("WebSocket client connected (train/progress)");

    // Send current status immediately (best effort; a dead socket is detected
    // by the loop below).
    if let Ok(data) = status_snapshot {
        let msg = serde_json::json!({
            "type": "status",
            "data": data,
        });
        let _ = socket
            .send(Message::Text(msg.to_string().into()))
            .await;
    }

    loop {
        tokio::select! {
            result = rx.recv() => {
                match result {
                    Ok(progress_json) => {
                        let parsed = serde_json::from_str::<serde_json::Value>(&progress_json)
                            .unwrap_or_default();
                        let ws_msg = serde_json::json!({
                            "type": "progress",
                            "data": parsed,
                        });
                        if socket.send(Message::Text(ws_msg.to_string().into())).await.is_err() {
                            break;
                        }
                    }
                    // Falling behind only drops old updates; keep streaming.
                    Err(broadcast::error::RecvError::Lagged(n)) => {
                        warn!("Train WS client lagged by {n} messages");
                    }
                    Err(_) => break,
                }
            }
            ws_msg = socket.recv() => {
                match ws_msg {
                    // A receive error means the connection is unusable, so
                    // treat it like a close instead of polling it again.
                    Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
                    Some(Ok(_)) => {} // ignore client messages
                }
            }
        }
    }
    info!("WebSocket client disconnected (train/progress)");
}
// ── Router factory ───────────────────────────────────────────────────────────
/// Assemble the training API sub-router: the five REST endpoints under
/// `/api/v1/train/` plus the WebSocket progress stream.
pub fn routes() -> Router<AppState> {
    // REST control surface.
    let rest = Router::new()
        .route("/api/v1/train/start", post(start_training))
        .route("/api/v1/train/stop", post(stop_training))
        .route("/api/v1/train/status", get(training_status))
        .route("/api/v1/train/pretrain", post(start_pretrain))
        .route("/api/v1/train/lora", post(start_lora_training));

    // Streaming progress endpoint.
    rest.route("/ws/train/progress", get(ws_train_progress_handler))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Frame fixture: `n_sub` subcarriers all at `amplitude`, with the fixed
    /// timestamp, RSSI and noise-floor values used throughout these tests.
    fn flat_frame(n_sub: usize, amplitude: f64) -> RecordedFrame {
        RecordedFrame {
            timestamp: 1.0,
            subcarriers: vec![amplitude; n_sub],
            rssi: -50.0,
            noise_floor: -90.0,
            features: serde_json::json!({}),
        }
    }

    /// Stats fixture with zero mean and unit std for every feature.
    fn identity_stats(n_features: usize, n_subcarriers: usize) -> FeatureStats {
        FeatureStats {
            mean: vec![0.0; n_features],
            std: vec![1.0; n_features],
            n_features,
            n_subcarriers,
        }
    }

    /// The default config carries the documented hyper-parameters.
    #[test]
    fn training_config_defaults() {
        let cfg = TrainingConfig::default();
        assert_eq!(cfg.epochs, 100);
        assert_eq!(cfg.batch_size, 8);
        assert!((cfg.learning_rate - 0.001).abs() < 1e-9);
        assert_eq!(cfg.warmup_epochs, 5);
        assert_eq!(cfg.early_stopping_patience, 20);
    }

    /// A fresh status reports no active run and the idle phase.
    #[test]
    fn training_status_default_is_inactive() {
        let TrainingStatus { active, phase, .. } = TrainingStatus::default();
        assert!(!active);
        assert_eq!(phase, "idle");
    }

    /// Progress messages serialize to JSON with the expected fields.
    #[test]
    fn training_progress_serializes() {
        let encoded = serde_json::to_string(&TrainingProgress {
            epoch: 10,
            batch: 25,
            total_batches: 50,
            train_loss: 0.35,
            val_pck: 0.72,
            val_oks: 0.63,
            lr: 0.0008,
            phase: "training".to_string(),
        })
        .unwrap();
        assert!(encoded.contains("\"epoch\":10"));
        assert!(encoded.contains("\"phase\":\"training\""));
    }

    /// Fields missing from the request JSON fall back to their defaults.
    #[test]
    fn training_config_deserializes_with_defaults() {
        let cfg: TrainingConfig = serde_json::from_str(r#"{"epochs": 50}"#).unwrap();
        assert_eq!(cfg.epochs, 50);
        assert_eq!(cfg.batch_size, 8); // default
        assert!((cfg.learning_rate - 0.001).abs() < 1e-9); // default
    }

    /// Feature count: amps + variances + gradients per subcarrier, plus
    /// 9 frequency and 3 global features (56 subcarriers gives 180).
    #[test]
    fn feature_dim_computation() {
        for n_sub in [56usize, 1] {
            assert_eq!(feature_dim(n_sub), n_sub + n_sub + n_sub + 9 + 3);
        }
    }

    /// DC component (freq=0) of a constant signal should be high.
    #[test]
    fn goertzel_dc_power() {
        let constant = vec![1.0; 100];
        let power = goertzel_power(&constant, 0.0);
        assert!(power > 0.5, "DC power should be significant: {power}");
    }

    /// An empty signal has zero power.
    #[test]
    fn goertzel_zero_on_empty() {
        assert_eq!(goertzel_power(&[], 0.1), 0.0);
    }

    /// With no window and no previous frame the feature vector still has
    /// the full expected length.
    #[test]
    fn extract_features_produces_correct_length() {
        let frame = flat_frame(56, 1.0);
        let extracted = extract_features_for_frame(&frame, &[], None, 10.0);
        assert_eq!(extracted.len(), feature_dim(56));
    }

    /// Teacher targets cover every keypoint coordinate.
    #[test]
    fn teacher_targets_produce_51_values() {
        let frame = flat_frame(56, 5.0);
        let teacher = compute_teacher_targets(&frame, None);
        assert_eq!(teacher.len(), N_TARGETS); // 17 * 3 = 51
    }

    /// The same seed reproduces the same order; another seed does not.
    #[test]
    fn deterministic_shuffle_is_stable() {
        let first = deterministic_shuffle(10, 42);
        let repeat = deterministic_shuffle(10, 42);
        assert_eq!(first, repeat);
        // Different seed should produce different order.
        let other_seed = deterministic_shuffle(10, 99);
        assert_ne!(first, other_seed);
    }

    /// The shuffle output contains each index exactly once.
    #[test]
    fn deterministic_shuffle_is_permutation() {
        let mut indices = deterministic_shuffle(20, 12345);
        indices.sort();
        assert_eq!(indices, (0..20).collect::<Vec<usize>>());
    }

    /// All-zero weights and bias predict zero for every target.
    #[test]
    fn forward_pass_zero_weights() {
        let inputs = vec![vec![1.0, 2.0, 3.0]];
        let w = vec![0.0; 3 * 2]; // 2 targets, 3 features
        let b = vec![0.0; 2];
        let out = forward(&inputs, &w, &b, 3, 2);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0], vec![0.0, 0.0]);
    }

    /// W = identity-like: target 0 = feature 0, target 1 = feature 1.
    #[test]
    fn forward_pass_identity() {
        let inputs = vec![vec![3.0, 7.0]];
        let w = vec![1.0, 0.0, 0.0, 1.0]; // 2x2 identity
        let b = vec![0.0, 0.0];
        let out = forward(&inputs, &w, &b, 2, 2);
        assert_eq!(out[0], vec![3.0, 7.0]);
    }

    /// With zero inputs and weights the prediction equals the bias.
    #[test]
    fn forward_pass_with_bias() {
        let inputs = vec![vec![0.0, 0.0]];
        let w = vec![0.0; 4];
        let b = vec![5.0, -3.0];
        let out = forward(&inputs, &w, &b, 2, 2);
        assert_eq!(out[0], vec![5.0, -3.0]);
    }

    /// Identical predictions and targets give zero MSE.
    #[test]
    fn compute_mse_zero_error() {
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let same = rows.clone();
        assert!((compute_mse(&rows, &same)).abs() < 1e-9);
    }

    /// A single unit error gives an MSE of one.
    #[test]
    fn compute_mse_known_value() {
        let predicted = vec![vec![0.0]];
        let actual = vec![vec![1.0]];
        assert!((compute_mse(&predicted, &actual) - 1.0).abs() < 1e-9);
    }

    /// Predictions equal to the targets score a PCK of exactly one.
    #[test]
    fn pck_perfect_prediction() {
        // Build targets where torso height is large so threshold is generous.
        let mut pose = vec![0.0; N_TARGETS];
        pose[1] = 0.0; // nose y
        pose[34] = 100.0; // left hip y
        pose[37] = 100.0; // right hip y
        let predicted = vec![pose.clone()];
        let actual = vec![pose];
        let pck = compute_pck(&predicted, &actual, 0.2);
        assert!((pck - 1.0).abs() < 1e-9, "Perfect prediction should give PCK=1.0");
    }

    /// A correctly sized model yields one 4-value entry per keypoint with a
    /// confidence strictly between zero and one.
    #[test]
    fn infer_pose_returns_17_keypoints() {
        let n_sub = 56;
        let n_feat = feature_dim(n_sub);
        let model: Vec<f32> = vec![0.001; N_TARGETS * n_feat + N_TARGETS];
        let stats = identity_stats(n_feat, n_sub);
        let amplitudes = vec![5.0f64; n_sub];
        let no_history: VecDeque<Vec<f64>> = VecDeque::new();

        let pose = infer_pose_from_model(&model, &stats, &amplitudes, &no_history, None, 10.0);

        assert_eq!(pose.len(), N_KEYPOINTS);
        for kp in pose.iter() {
            // Each keypoint has 4 values.
            assert_eq!(kp.len(), 4);
            // Confidence should be in (0, 1).
            assert!(kp[3] > 0.0 && kp[3] < 1.0, "confidence={}", kp[3]);
        }
    }

    /// A weight vector that is too short yields the default keypoints.
    #[test]
    fn infer_pose_short_weights_returns_defaults() {
        let model: Vec<f32> = vec![0.0; 10]; // too short
        let stats = identity_stats(100, 56);
        let amplitudes = vec![5.0f64; 56];
        let no_history: VecDeque<Vec<f64>> = VecDeque::new();

        let pose = infer_pose_from_model(&model, &stats, &amplitudes, &no_history, None, 10.0);

        assert_eq!(pose.len(), N_KEYPOINTS);
        // Default keypoints have zero confidence.
        for kp in pose.iter() {
            assert!((kp[3]).abs() < 1e-9);
        }
    }

    /// Feature stats survive a JSON round trip.
    #[test]
    fn feature_stats_serialization() {
        let original = FeatureStats {
            mean: vec![1.0, 2.0],
            std: vec![0.5, 1.5],
            n_features: 2,
            n_subcarriers: 1,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("\"n_features\":2"));
        let decoded: FeatureStats = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.n_features, 2);
        assert_eq!(decoded.mean, vec![1.0, 2.0]);
    }
}