diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/Cargo.toml b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/Cargo.toml index ee3ce0be..a76e6f1c 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/Cargo.toml +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/Cargo.toml @@ -43,5 +43,8 @@ clap = { workspace = true } # Multi-BSSID WiFi scanning pipeline (ADR-022 Phase 3) wifi-densepose-wifiscan = { version = "0.3.0", path = "../wifi-densepose-wifiscan" } +# Signal processing with RuvSense pose tracker (accuracy sprint) +wifi-densepose-signal = { version = "0.3.0", path = "../wifi-densepose-signal" } + [dev-dependencies] tempfile = "3.10" diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/field_bridge.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/field_bridge.rs new file mode 100644 index 00000000..001f933c --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/field_bridge.rs @@ -0,0 +1,142 @@ +//! Bridge between sensing-server frame data and signal crate FieldModel +//! for eigenvalue-based person counting. +//! +//! The FieldModel decomposes CSI observations into environmental drift and +//! body perturbation via SVD eigenmodes. When calibrated, perturbation energy +//! provides a physics-grounded occupancy estimate that supplements the +//! score-based heuristic in `score_to_person_count`. + +use std::collections::VecDeque; +use wifi_densepose_signal::ruvsense::field_model::{CalibrationStatus, FieldModel}; + +use super::score_to_person_count; + +/// Number of recent frames to feed into perturbation extraction. +const OCCUPANCY_WINDOW: usize = 50; + +/// Perturbation energy threshold for detecting a second person. +const ENERGY_THRESH_2: f64 = 12.0; +/// Perturbation energy threshold for detecting a third person. +const ENERGY_THRESH_3: f64 = 25.0; + +/// Estimate occupancy using the FieldModel when calibrated, falling back +/// to the score-based heuristic otherwise. +/// +/// When the field model is `Fresh` or `Stale`, we extract body perturbation +/// from the most recent frames and map total energy to a person count. +/// On any error or when uncalibrated, we fall through to `score_to_person_count`. +pub fn occupancy_or_fallback( + field: &FieldModel, + frame_history: &VecDeque>, + smoothed_score: f64, + prev_count: usize, +) -> usize { + match field.status() { + CalibrationStatus::Fresh | CalibrationStatus::Stale => { + let frames: Vec> = frame_history + .iter() + .rev() + .take(OCCUPANCY_WINDOW) + .cloned() + .collect(); + + if frames.is_empty() { + return score_to_person_count(smoothed_score, prev_count); + } + + // Use the most recent frame as the observation for perturbation + // extraction. The FieldModel expects [n_links][n_subcarriers], + // so we wrap the single frame as a single-link observation. + let observation = vec![frames[0].clone()]; + match field.extract_perturbation(&observation) { + Ok(perturbation) => { + if perturbation.total_energy > ENERGY_THRESH_3 { + 3 + } else if perturbation.total_energy > ENERGY_THRESH_2 { + 2 + } else { + 1 + } + } + Err(e) => { + tracing::warn!("FieldModel perturbation failed, using fallback: {e}"); + score_to_person_count(smoothed_score, prev_count) + } + } + } + _ => score_to_person_count(smoothed_score, prev_count), + } +} + +/// Feed the latest frame to the FieldModel during calibration collection. +/// +/// Only acts when the model status is `Collecting`. Wraps the latest frame +/// as a single-link observation and feeds it; errors are logged and ignored. +pub fn maybe_feed_calibration(field: &mut FieldModel, frame_history: &VecDeque>) { + if field.status() != CalibrationStatus::Collecting { + return; + } + if let Some(latest) = frame_history.back() { + let observations = vec![latest.clone()]; + if let Err(e) = field.feed_calibration(&observations) { + tracing::warn!("FieldModel calibration feed error: {e}"); + } + } +} + +/// Parse node positions from a semicolon-delimited string. +/// +/// Format: `"x,y,z;x,y,z;..."` where each coordinate is an `f32`. +/// Entries that fail to parse are silently skipped. +pub fn parse_node_positions(input: &str) -> Vec<[f32; 3]> { + if input.is_empty() { + return Vec::new(); + } + input + .split(';') + .filter_map(|triplet| { + let parts: Vec<&str> = triplet.split(',').collect(); + if parts.len() != 3 { + return None; + } + let x = parts[0].parse::().ok()?; + let y = parts[1].parse::().ok()?; + let z = parts[2].parse::().ok()?; + Some([x, y, z]) + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_node_positions() { + let positions = parse_node_positions("0,0,1.5;3,0,1.5;1.5,3,1.5"); + assert_eq!(positions.len(), 3); + assert_eq!(positions[0], [0.0, 0.0, 1.5]); + assert_eq!(positions[1], [3.0, 0.0, 1.5]); + assert_eq!(positions[2], [1.5, 3.0, 1.5]); + } + + #[test] + fn test_parse_node_positions_empty() { + let positions = parse_node_positions(""); + assert!(positions.is_empty()); + } + + #[test] + fn test_parse_node_positions_invalid() { + let positions = parse_node_positions("abc;1,2,3"); + assert_eq!(positions.len(), 1); + assert_eq!(positions[0], [1.0, 2.0, 3.0]); + } + + #[test] + fn test_parse_node_positions_partial_triplet() { + let positions = parse_node_positions("1,2;3,4,5"); + assert_eq!(positions.len(), 1); + assert_eq!(positions[0], [3.0, 4.0, 5.0]); + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs index f4835fc3..e323fb46 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs @@ -9,8 +9,11 @@ //! Replaces both ws_server.py and the Python HTTP server. mod adaptive_classifier; +mod field_bridge; +mod multistatic_bridge; mod rvf_container; mod rvf_pipeline; +mod tracker_bridge; mod vital_signs; // Training pipeline modules (exposed via lib.rs) @@ -52,6 +55,11 @@ use wifi_densepose_wifiscan::{ }; use wifi_densepose_wifiscan::parse_netsh_output as parse_netsh_bssid_output; +// Accuracy sprint: Kalman tracker, multistatic fusion, field model +use wifi_densepose_signal::ruvsense::pose_tracker::PoseTracker; +use wifi_densepose_signal::ruvsense::multistatic::{MultistaticFuser, MultistaticConfig}; +use wifi_densepose_signal::ruvsense::field_model::{FieldModel, FieldModelConfig, CalibrationStatus}; + // ── CLI ────────────────────────────────────────────────────────────────────── #[derive(Parser, Debug)] @@ -144,6 +152,14 @@ struct Args { /// Build fingerprint index from embeddings (env|activity|temporal|person) #[arg(long, value_name = "TYPE")] build_index: Option, + + /// Node positions for multistatic fusion (format: "x,y,z;x,y,z;...") + #[arg(long, env = "SENSING_NODE_POSITIONS")] + node_positions: Option, + + /// Start field model calibration on boot (empty room required) + #[arg(long)] + calibrate: bool, } // ── Data types ─────────────────────────────────────────────────────────────── @@ -282,9 +298,9 @@ struct BoundingBox { /// Each ESP32 node gets its own frame history, smoothing buffers, and vital /// sign detector so that data from different nodes is never mixed. struct NodeState { - frame_history: VecDeque>, + pub(crate) frame_history: VecDeque>, smoothed_person_score: f64, - prev_person_count: usize, + pub(crate) prev_person_count: usize, smoothed_motion: f64, current_motion_level: String, debounce_counter: u32, @@ -300,12 +316,12 @@ struct NodeState { rssi_history: VecDeque, vital_detector: VitalSignDetector, latest_vitals: VitalSigns, - last_frame_time: Option, + pub(crate) last_frame_time: Option, edge_vitals: Option, } impl NodeState { - fn new() -> Self { + pub(crate) fn new() -> Self { Self { frame_history: VecDeque::new(), smoothed_person_score: 0.0, @@ -436,6 +452,15 @@ struct AppStateInner { /// Per-node sensing state for multi-node deployments. /// Keyed by `node_id` from the ESP32 frame header. node_states: HashMap, + // ── Accuracy sprint: Kalman tracker, multistatic fusion, eigenvalue counting ── + /// Global Kalman-based pose tracker for stable person IDs and smoothed keypoints. + pose_tracker: PoseTracker, + /// Instant of last tracker update (for computing dt). + last_tracker_instant: Option, + /// Attention-weighted multi-node CSI fusion engine. + multistatic_fuser: MultistaticFuser, + /// SVD-based room field model for eigenvalue person counting (None until calibration). + field_model: Option, } /// If no ESP32 frame arrives within this duration, source reverts to offline. @@ -445,6 +470,31 @@ impl AppStateInner { /// Return the effective data source, accounting for ESP32 frame timeout. /// If the source is "esp32" but no frame has arrived in 5 seconds, returns /// "esp32:offline" so the UI can distinguish active vs stale connections. + /// Person count: eigenvalue-based if field model is calibrated, else heuristic. + /// Uses global frame_history if populated, otherwise the freshest per-node history. + fn person_count(&self) -> usize { + match self.field_model.as_ref() { + Some(fm) => { + // Prefer global frame_history (populated by wifi/simulate paths). + // Fall back to freshest per-node history (populated by ESP32 paths). + let history = if !self.frame_history.is_empty() { + &self.frame_history + } else { + // Find the node with the most recent frame + self.node_states.values() + .filter(|ns| !ns.frame_history.is_empty()) + .max_by_key(|ns| ns.last_frame_time) + .map(|ns| &ns.frame_history) + .unwrap_or(&self.frame_history) + }; + field_bridge::occupancy_or_fallback( + fm, history, self.smoothed_person_score, self.prev_person_count, + ) + } + None => score_to_person_count(self.smoothed_person_score, self.prev_person_count), + } + } + fn effective_source(&self) -> String { if self.source == "esp32" { if let Some(last) = self.last_esp32_frame { @@ -1435,7 +1485,7 @@ async fn windows_wifi_task(state: SharedState, tick_ms: u64) { let raw_score = compute_person_score(&features); s.smoothed_person_score = s.smoothed_person_score * 0.90 + raw_score * 0.10; let est_persons = if classification.presence { - let count = score_to_person_count(s.smoothed_person_score, s.prev_person_count); + let count = s.person_count(); s.prev_person_count = count; count } else { @@ -1475,10 +1525,13 @@ async fn windows_wifi_task(state: SharedState, tick_ms: u64) { node_features: None, }; - // Populate persons from the sensing update. - let persons = derive_pose_from_sensing(&update); - if !persons.is_empty() { - update.persons = Some(persons); + // Populate persons from the sensing update (Kalman-smoothed via tracker). + let raw_persons = derive_pose_from_sensing(&update); + let tracked = tracker_bridge::tracker_update( + &mut s.pose_tracker, &mut s.last_tracker_instant, raw_persons, + ); + if !tracked.is_empty() { + update.persons = Some(tracked); } if let Ok(json) = serde_json::to_string(&update) { @@ -1569,7 +1622,7 @@ async fn windows_wifi_fallback_tick(state: &SharedState, seq: u32) { let raw_score = compute_person_score(&features); s.smoothed_person_score = s.smoothed_person_score * 0.90 + raw_score * 0.10; let est_persons = if classification.presence { - let count = score_to_person_count(s.smoothed_person_score, s.prev_person_count); + let count = s.person_count(); s.prev_person_count = count; count } else { @@ -1609,9 +1662,12 @@ async fn windows_wifi_fallback_tick(state: &SharedState, seq: u32) { node_features: None, }; - let persons = derive_pose_from_sensing(&update); - if !persons.is_empty() { - update.persons = Some(persons); + let raw_persons = derive_pose_from_sensing(&update); + let tracked = tracker_bridge::tracker_update( + &mut s.pose_tracker, &mut s.last_tracker_instant, raw_persons, + ); + if !tracked.is_empty() { + update.persons = Some(tracked); } if let Ok(json) = serde_json::to_string(&update) { @@ -1788,9 +1844,13 @@ async fn handle_ws_pose_client(mut socket: WebSocket, state: SharedState) { keypoints, zone: "zone_1".into(), }] - }).unwrap_or_else(|| derive_pose_from_sensing(&sensing)) + }).unwrap_or_else(|| { + // Prefer tracked persons from broadcast if available + sensing.persons.clone().unwrap_or_else(|| derive_pose_from_sensing(&sensing)) + }) } else { - derive_pose_from_sensing(&sensing) + // Prefer tracked persons from broadcast if available + sensing.persons.clone().unwrap_or_else(|| derive_pose_from_sensing(&sensing)) }; let pose_msg = serde_json::json!({ @@ -2229,7 +2289,7 @@ async fn api_info(State(state): State) -> Json { async fn pose_current(State(state): State) -> Json { let s = state.read().await; let persons = match &s.latest_update { - Some(update) => derive_pose_from_sensing(update), + Some(update) => update.persons.clone().unwrap_or_else(|| derive_pose_from_sensing(update)), None => vec![], }; Json(serde_json::json!({ @@ -2780,6 +2840,79 @@ async fn adaptive_unload(State(state): State) -> Json) -> Json { + let mut s = state.write().await; + // Guard: don't discard an in-progress calibration + if let Some(ref fm) = s.field_model { + if fm.status() == CalibrationStatus::Collecting { + return Json(serde_json::json!({ + "success": false, + "error": "Calibration already in progress. Call /calibration/stop first.", + "frame_count": fm.calibration_frame_count(), + })); + } + } + match FieldModel::new(FieldModelConfig::default()) { + Ok(fm) => { + s.field_model = Some(fm); + Json(serde_json::json!({ + "success": true, + "message": "Calibration started — keep room empty while frames accumulate.", + })) + } + Err(e) => Json(serde_json::json!({ + "success": false, + "error": format!("{e}"), + })), + } +} + +async fn calibration_stop(State(state): State) -> Json { + let mut s = state.write().await; + if let Some(ref mut fm) = s.field_model { + let ts = chrono::Utc::now().timestamp_micros() as u64; + match fm.finalize_calibration(ts, 0) { + Ok(modes) => { + let baseline = modes.baseline_eigenvalue_count; + let variance_explained = modes.variance_explained; + info!("Field model calibrated: baseline_eigenvalues={baseline}, variance_explained={variance_explained:.2}"); + Json(serde_json::json!({ + "success": true, + "baseline_eigenvalue_count": baseline, + "variance_explained": variance_explained, + "frame_count": fm.calibration_frame_count(), + })) + } + Err(e) => Json(serde_json::json!({ + "success": false, + "error": format!("{e}"), + })), + } + } else { + Json(serde_json::json!({ + "success": false, + "error": "No field model active — call /calibration/start first.", + })) + } +} + +async fn calibration_status(State(state): State) -> Json { + let s = state.read().await; + match s.field_model.as_ref() { + Some(fm) => Json(serde_json::json!({ + "active": true, + "status": format!("{:?}", fm.status()), + "frame_count": fm.calibration_frame_count(), + })), + None => Json(serde_json::json!({ + "active": false, + "status": "none", + })), + } +} + /// Generate a simple timestamp string (epoch seconds) for recording IDs. fn chrono_timestamp() -> u64 { std::time::SystemTime::now() @@ -3045,12 +3178,30 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) { else if vitals.presence { 0.3 } else { 0.05 }; - // Aggregate person count across all active nodes. + // Aggregate person count: attention-weighted fusion or max-per-node fallback. let now = std::time::Instant::now(); - let total_persons: usize = s.node_states.values() - .filter(|n| n.last_frame_time.map_or(false, |t| now.duration_since(t).as_secs() < 10)) - .map(|n| n.prev_person_count) - .sum(); + let total_persons = { + let (fused, fallback_count) = multistatic_bridge::fuse_or_fallback( + &s.multistatic_fuser, &s.node_states, + ); + match fused { + Some(ref f) => { + let score = multistatic_bridge::compute_person_score_from_amplitudes(&f.fused_amplitude); + s.smoothed_person_score = s.smoothed_person_score * 0.90 + score * 0.10; + let count = s.person_count(); + s.prev_person_count = count; + count + } + None => fallback_count, + } + }; + + // Feed field model calibration if active (use per-node history for ESP32). + if let Some(ref mut fm) = s.field_model { + if let Some(ns) = s.node_states.get(&node_id) { + field_bridge::maybe_feed_calibration(fm, &ns.frame_history); + } + } // Build nodes array with all active nodes. let active_nodes: Vec = s.node_states.iter() @@ -3112,9 +3263,12 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) { node_features: None, }; - let persons = derive_pose_from_sensing(&update); - if !persons.is_empty() { - update.persons = Some(persons); + let raw_persons = derive_pose_from_sensing(&update); + let tracked = tracker_bridge::tracker_update( + &mut s.pose_tracker, &mut s.last_tracker_instant, raw_persons, + ); + if !tracked.is_empty() { + update.persons = Some(tracked); } if let Ok(json) = serde_json::to_string(&update) { @@ -3244,12 +3398,30 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) { else if classification.motion_level == "present_still" { 0.3 } else { 0.05 }; - // Aggregate person count across all active nodes. + // Aggregate person count: attention-weighted fusion or naive sum fallback. let now = std::time::Instant::now(); - let total_persons: usize = s.node_states.values() - .filter(|n| n.last_frame_time.map_or(false, |t| now.duration_since(t).as_secs() < 10)) - .map(|n| n.prev_person_count) - .sum(); + let total_persons = { + let (fused, fallback_count) = multistatic_bridge::fuse_or_fallback( + &s.multistatic_fuser, &s.node_states, + ); + match fused { + Some(ref f) => { + let score = multistatic_bridge::compute_person_score_from_amplitudes(&f.fused_amplitude); + s.smoothed_person_score = s.smoothed_person_score * 0.90 + score * 0.10; + let count = s.person_count(); + s.prev_person_count = count; + count + } + None => fallback_count, + } + }; + + // Feed field model calibration if active (use per-node history for ESP32). + if let Some(ref mut fm) = s.field_model { + if let Some(ns) = s.node_states.get(&node_id) { + field_bridge::maybe_feed_calibration(fm, &ns.frame_history); + } + } // Build nodes array with all active nodes. let active_nodes: Vec = s.node_states.iter() @@ -3291,9 +3463,12 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) { node_features: None, }; - let persons = derive_pose_from_sensing(&update); - if !persons.is_empty() { - update.persons = Some(persons); + let raw_persons = derive_pose_from_sensing(&update); + let tracked = tracker_bridge::tracker_update( + &mut s.pose_tracker, &mut s.last_tracker_instant, raw_persons, + ); + if !tracked.is_empty() { + update.persons = Some(tracked); } if let Ok(json) = serde_json::to_string(&update) { @@ -3360,7 +3535,7 @@ async fn simulated_data_task(state: SharedState, tick_ms: u64) { let raw_score = compute_person_score(&features); s.smoothed_person_score = s.smoothed_person_score * 0.90 + raw_score * 0.10; let est_persons = if classification.presence { - let count = score_to_person_count(s.smoothed_person_score, s.prev_person_count); + let count = s.person_count(); s.prev_person_count = count; count } else { @@ -3410,10 +3585,13 @@ async fn simulated_data_task(state: SharedState, tick_ms: u64) { node_features: None, }; - // Populate persons from the sensing update. - let persons = derive_pose_from_sensing(&update); - if !persons.is_empty() { - update.persons = Some(persons); + // Populate persons from the sensing update (Kalman-smoothed via tracker). + let raw_persons = derive_pose_from_sensing(&update); + let tracked = tracker_bridge::tracker_update( + &mut s.pose_tracker, &mut s.last_tracker_instant, raw_persons, + ); + if !tracked.is_empty() { + update.persons = Some(tracked); } if update.classification.presence { @@ -4042,6 +4220,29 @@ async fn main() { m }), node_states: HashMap::new(), + // Accuracy sprint + pose_tracker: PoseTracker::new(), + last_tracker_instant: None, + multistatic_fuser: { + let mut fuser = MultistaticFuser::with_config(MultistaticConfig { + min_nodes: 1, // single-node passthrough + ..Default::default() + }); + if let Some(ref pos_str) = args.node_positions { + let positions = field_bridge::parse_node_positions(pos_str); + if !positions.is_empty() { + info!("Configured {} node positions for multistatic fusion", positions.len()); + fuser.set_node_positions(positions); + } + } + fuser + }, + field_model: if args.calibrate { + info!("Field model calibration enabled — room should be empty during startup"); + FieldModel::new(FieldModelConfig::default()).ok() + } else { + None + }, })); // Start background tasks based on source @@ -4138,6 +4339,10 @@ async fn main() { .route("/api/v1/adaptive/train", post(adaptive_train)) .route("/api/v1/adaptive/status", get(adaptive_status)) .route("/api/v1/adaptive/unload", post(adaptive_unload)) + // Field model calibration (eigenvalue-based person counting) + .route("/api/v1/calibration/start", post(calibration_start)) + .route("/api/v1/calibration/stop", post(calibration_stop)) + .route("/api/v1/calibration/status", get(calibration_status)) // Static UI files .nest_service("/ui", ServeDir::new(&ui_path)) .layer(SetResponseHeaderLayer::overriding( diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/multistatic_bridge.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/multistatic_bridge.rs new file mode 100644 index 00000000..98d89dae --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/multistatic_bridge.rs @@ -0,0 +1,263 @@ +//! Bridge between sensing-server per-node state and the signal crate's +//! `MultistaticFuser` for attention-weighted CSI fusion across ESP32 nodes. +//! +//! This module converts the server's `NodeState` (f64 amplitude history) into +//! `MultiBandCsiFrame`s that the multistatic fusion pipeline expects, then +//! drives `MultistaticFuser::fuse` with a graceful fallback when fusion fails +//! (e.g. insufficient nodes or timestamp spread). + +use std::collections::HashMap; +use std::time::{Duration, Instant}; + +use wifi_densepose_signal::hardware_norm::{CanonicalCsiFrame, HardwareType}; +use wifi_densepose_signal::ruvsense::multiband::MultiBandCsiFrame; +use wifi_densepose_signal::ruvsense::multistatic::{FusedSensingFrame, MultistaticFuser}; + +use super::NodeState; + +/// Maximum age for a node frame to be considered active (10 seconds). +const STALE_THRESHOLD: Duration = Duration::from_secs(10); + +/// Default WiFi channel frequency (MHz) used for single-channel frames. +const DEFAULT_FREQ_MHZ: u32 = 2437; // Channel 6 + +/// Convert a single `NodeState` into a `MultiBandCsiFrame` suitable for +/// multistatic fusion. +/// +/// Returns `None` when the node has no frame history or no recorded +/// `last_frame_time`. +pub fn node_frame_from_state(node_id: u8, ns: &NodeState) -> Option { + let last_time = ns.last_frame_time.as_ref()?; + let latest = ns.frame_history.back()?; + if latest.is_empty() { + return None; + } + + let amplitude: Vec = latest.iter().map(|&v| v as f32).collect(); + let n_sub = amplitude.len(); + let phase = vec![0.0_f32; n_sub]; + + // Derive a monotonic timestamp: use wall-clock time minus elapsed since + // last frame to approximate when the frame was actually received. + let wall_us = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_micros() as u64) + .unwrap_or(0); + let age_us = last_time.elapsed().as_micros() as u64; + let timestamp_us = wall_us.saturating_sub(age_us); + + let canonical = CanonicalCsiFrame { + amplitude, + phase, + hardware_type: HardwareType::Esp32S3, + }; + + Some(MultiBandCsiFrame { + node_id, + timestamp_us, + channel_frames: vec![canonical], + frequencies_mhz: vec![DEFAULT_FREQ_MHZ], + coherence: 1.0, // single-channel, perfect self-coherence + }) +} + +/// Collect `MultiBandCsiFrame`s from all active nodes. +/// +/// A node is considered active if its `last_frame_time` is within +/// [`STALE_THRESHOLD`] of `now`. +pub fn node_frames_from_states(node_states: &HashMap) -> Vec { + let now = Instant::now(); + let mut frames = Vec::with_capacity(node_states.len()); + + for (&node_id, ns) in node_states { + // Skip stale nodes + if let Some(ref t) = ns.last_frame_time { + if now.duration_since(*t) > STALE_THRESHOLD { + continue; + } + } else { + continue; + } + + if let Some(frame) = node_frame_from_state(node_id, ns) { + frames.push(frame); + } + } + + frames +} + +/// Attempt multistatic fusion; fall back to max per-node person count on failure. +/// +/// Returns `(fused_frame, fallback_person_count)`. When fusion succeeds, the +/// caller should compute person count from the fused amplitudes (the returned +/// fallback count is 0 as a sentinel). On failure, returns the maximum +/// per-node count (not the sum, to avoid double-counting overlapping coverage). +pub fn fuse_or_fallback( + fuser: &MultistaticFuser, + node_states: &HashMap, +) -> (Option, usize) { + let frames = node_frames_from_states(node_states); + if frames.is_empty() { + return (None, 0); + } + + match fuser.fuse(&frames) { + Ok(fused) => { + // Return 0 as sentinel — caller must compute count from fused amplitudes. + (Some(fused), 0) + } + Err(e) => { + tracing::debug!("Multistatic fusion failed ({e}), using per-node max fallback"); + // Use max (not sum) to avoid double-counting when nodes have overlapping coverage. + let max_count: usize = node_states + .values() + .filter(|ns| { + ns.last_frame_time + .map(|t| t.elapsed() <= STALE_THRESHOLD) + .unwrap_or(false) + }) + .map(|ns| ns.prev_person_count) + .max() + .unwrap_or(0); + (None, max_count) + } + } +} + +/// Compute a person-presence score from fused amplitude data. +/// +/// Uses the squared coefficient of variation (variance / mean^2) as a +/// lightweight proxy for body-induced CSI perturbation. A flat amplitude +/// vector (no person) yields a score near zero; a vector with high variance +/// relative to its mean (person moving) yields a score approaching 1.0. +pub fn compute_person_score_from_amplitudes(amplitudes: &[f32]) -> f64 { + if amplitudes.is_empty() { + return 0.0; + } + + let n = amplitudes.len() as f64; + let sum: f64 = amplitudes.iter().map(|&a| a as f64).sum(); + let mean = sum / n; + + let variance: f64 = amplitudes.iter().map(|&a| { + let diff = (a as f64) - mean; + diff * diff + }).sum::() / n; + + let score = variance / (mean * mean + 1e-10); + score.clamp(0.0, 1.0) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::VecDeque; + + /// Helper: build a minimal NodeState for testing. Uses `NodeState::new()` + /// then mutates the `pub(crate)` fields the bridge needs. + fn make_node_state( + frame_history: VecDeque>, + last_frame_time: Option, + prev_person_count: usize, + ) -> NodeState { + let mut ns = NodeState::new(); + ns.frame_history = frame_history; + ns.last_frame_time = last_frame_time; + ns.prev_person_count = prev_person_count; + ns + } + + #[test] + fn test_node_frame_from_empty_state() { + let ns = make_node_state(VecDeque::new(), Some(Instant::now()), 0); + assert!(node_frame_from_state(1, &ns).is_none()); + } + + #[test] + fn test_node_frame_from_state_no_time() { + let mut history = VecDeque::new(); + history.push_back(vec![1.0, 2.0, 3.0]); + let ns = make_node_state(history, None, 0); + assert!(node_frame_from_state(1, &ns).is_none()); + } + + #[test] + fn test_node_frame_conversion() { + let mut history = VecDeque::new(); + history.push_back(vec![10.0, 20.0, 30.5]); + let ns = make_node_state(history, Some(Instant::now()), 0); + + let frame = node_frame_from_state(42, &ns).expect("should produce a frame"); + assert_eq!(frame.node_id, 42); + assert_eq!(frame.channel_frames.len(), 1); + + let ch = &frame.channel_frames[0]; + assert_eq!(ch.amplitude.len(), 3); + assert!((ch.amplitude[0] - 10.0_f32).abs() < f32::EPSILON); + assert!((ch.amplitude[1] - 20.0_f32).abs() < f32::EPSILON); + assert!((ch.amplitude[2] - 30.5_f32).abs() < f32::EPSILON); + // Phase should be all zeros + assert!(ch.phase.iter().all(|&p| p == 0.0)); + assert_eq!(ch.hardware_type, HardwareType::Esp32S3); + } + + #[test] + fn test_stale_node_excluded() { + let mut states: HashMap = HashMap::new(); + + // Active node: frame just received + let mut active_history = VecDeque::new(); + active_history.push_back(vec![1.0, 2.0]); + states.insert(1, make_node_state(active_history, Some(Instant::now()), 1)); + + // Stale node: frame 20 seconds ago + let mut stale_history = VecDeque::new(); + stale_history.push_back(vec![3.0, 4.0]); + let stale_time = Instant::now() - Duration::from_secs(20); + states.insert(2, make_node_state(stale_history, Some(stale_time), 1)); + + let frames = node_frames_from_states(&states); + assert_eq!(frames.len(), 1, "stale node should be excluded"); + assert_eq!(frames[0].node_id, 1); + } + + #[test] + fn test_compute_person_score_empty() { + assert!((compute_person_score_from_amplitudes(&[]) - 0.0).abs() < f64::EPSILON); + } + + #[test] + fn test_compute_person_score_flat() { + // Constant amplitude => variance = 0 => score ~ 0 + let flat = vec![5.0_f32; 64]; + let score = compute_person_score_from_amplitudes(&flat); + assert!(score < 0.001, "flat signal should have near-zero score, got {score}"); + } + + #[test] + fn test_compute_person_score_varied() { + // High variance relative to mean should produce a positive score + let varied: Vec = (0..64).map(|i| if i % 2 == 0 { 1.0 } else { 10.0 }).collect(); + let score = compute_person_score_from_amplitudes(&varied); + assert!(score > 0.1, "varied signal should have positive score, got {score}"); + assert!(score <= 1.0, "score should be clamped to 1.0, got {score}"); + } + + #[test] + fn test_compute_person_score_clamped() { + // Near-zero mean with non-zero variance => would blow up without clamp + let vals = vec![0.0_f32, 0.0, 0.0, 0.001]; + let score = compute_person_score_from_amplitudes(&vals); + assert!(score <= 1.0, "score must be clamped to 1.0"); + } + + #[test] + fn test_fuse_or_fallback_empty() { + let fuser = MultistaticFuser::new(); + let states: HashMap = HashMap::new(); + let (fused, count) = fuse_or_fallback(&fuser, &states); + assert!(fused.is_none()); + assert_eq!(count, 0); + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/tracker_bridge.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/tracker_bridge.rs new file mode 100644 index 00000000..cdddc043 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/tracker_bridge.rs @@ -0,0 +1,397 @@ +//! Bridge between sensing-server PersonDetection types and signal crate PoseTracker. +//! +//! The sensing server uses f64 types (PersonDetection, PoseKeypoint, BoundingBox) +//! while the signal crate's PoseTracker operates on f32 Kalman states. This module +//! provides conversion functions and a single `tracker_update` entry point that +//! accepts server-side detections and returns tracker-smoothed results. + +use std::time::Instant; +use wifi_densepose_signal::ruvsense::{ + self, KeypointState, PoseTrack, TrackLifecycleState, TrackId, NUM_KEYPOINTS, +}; +use wifi_densepose_signal::ruvsense::pose_tracker::PoseTracker; + +use super::{BoundingBox, PersonDetection, PoseKeypoint}; + +/// COCO-17 keypoint names in index order. +const COCO_NAMES: [&str; 17] = [ + "nose", + "left_eye", + "right_eye", + "left_ear", + "right_ear", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_hip", + "right_hip", + "left_knee", + "right_knee", + "left_ankle", + "right_ankle", +]; + +/// Map a lowercase keypoint name to its COCO-17 index. +fn keypoint_name_to_coco_index(name: &str) -> Option { + COCO_NAMES.iter().position(|&n| n.eq_ignore_ascii_case(name)) +} + +/// Convert server-side PersonDetection slices into tracker-compatible keypoint arrays. +/// +/// For each person, maps named keypoints to COCO-17 positions. Unmapped slots are +/// filled with the centroid of the mapped keypoints so the Kalman filter has a +/// reasonable initial value rather than zeros. +fn detections_to_tracker_keypoints(persons: &[PersonDetection]) -> Vec<[[f32; 3]; 17]> { + persons + .iter() + .map(|person| { + let mut kps = [[0.0_f32; 3]; 17]; + let mut mapped_count = 0u32; + let mut cx = 0.0_f32; + let mut cy = 0.0_f32; + let mut cz = 0.0_f32; + + // First pass: place mapped keypoints and accumulate centroid + for kp in &person.keypoints { + if let Some(idx) = keypoint_name_to_coco_index(&kp.name) { + kps[idx] = [kp.x as f32, kp.y as f32, kp.z as f32]; + cx += kp.x as f32; + cy += kp.y as f32; + cz += kp.z as f32; + mapped_count += 1; + } + } + + // Compute centroid of mapped keypoints + let centroid = if mapped_count > 0 { + let n = mapped_count as f32; + [cx / n, cy / n, cz / n] + } else { + [0.0, 0.0, 0.0] + }; + + // Second pass: fill unmapped slots with centroid + // Build a set of mapped indices + let mut mapped = [false; 17]; + for kp in &person.keypoints { + if let Some(idx) = keypoint_name_to_coco_index(&kp.name) { + mapped[idx] = true; + } + } + for i in 0..17 { + if !mapped[i] { + kps[i] = centroid; + } + } + + kps + }) + .collect() +} + +/// Convert active PoseTracker tracks back into server-side PersonDetection values. +/// +/// Only tracks whose lifecycle `is_alive()` are included. +pub fn tracker_to_person_detections(tracker: &PoseTracker) -> Vec { + tracker + .active_tracks() + .into_iter() + .map(|track| { + let id = track.id.0 as u32; + + let confidence = match track.lifecycle { + TrackLifecycleState::Active => 0.9, + TrackLifecycleState::Tentative => 0.5, + TrackLifecycleState::Lost => 0.3, + TrackLifecycleState::Terminated => 0.0, + }; + + // Build keypoints from Kalman state + let keypoints: Vec = (0..NUM_KEYPOINTS) + .map(|i| { + let pos = track.keypoints[i].position(); + PoseKeypoint { + name: COCO_NAMES[i].to_string(), + x: pos[0] as f64, + y: pos[1] as f64, + z: pos[2] as f64, + confidence: track.keypoints[i].confidence as f64, + } + }) + .collect(); + + // Compute bounding box from keypoint min/max + let mut min_x = f64::MAX; + let mut min_y = f64::MAX; + let mut max_x = f64::MIN; + let mut max_y = f64::MIN; + for kp in &keypoints { + if kp.x < min_x { min_x = kp.x; } + if kp.y < min_y { min_y = kp.y; } + if kp.x > max_x { max_x = kp.x; } + if kp.y > max_y { max_y = kp.y; } + } + + let bbox = BoundingBox { + x: min_x, + y: min_y, + width: max_x - min_x, + height: max_y - min_y, + }; + + PersonDetection { + id, + confidence, + keypoints, + bbox, + zone: "tracked".to_string(), + } + }) + .collect() +} + +/// Run one tracker cycle: predict, match detections, update, prune. +/// +/// This is the main entry point called each sensing frame. It: +/// 1. Computes dt from the previous call instant +/// 2. Predicts all existing tracks forward +/// 3. Greedily assigns detections to tracks by Mahalanobis cost +/// 4. Updates matched tracks, creates new tracks for unmatched detections +/// 5. Prunes terminated tracks +/// 6. Returns smoothed PersonDetection values from the tracker state +pub fn tracker_update( + tracker: &mut PoseTracker, + last_instant: &mut Option, + persons: Vec, +) -> Vec { + let now = Instant::now(); + let dt = last_instant.map_or(0.1_f32, |prev| now.duration_since(prev).as_secs_f32()); + *last_instant = Some(now); + + // Predict all tracks forward + tracker.predict_all(dt); + + if persons.is_empty() { + tracker.prune_terminated(); + return tracker_to_person_detections(tracker); + } + + // Convert detections to f32 keypoint arrays + let all_keypoints = detections_to_tracker_keypoints(&persons); + + // Compute centroids for each detection + let centroids: Vec<[f32; 3]> = all_keypoints + .iter() + .map(|kps| { + let mut c = [0.0_f32; 3]; + for kp in kps { + c[0] += kp[0]; + c[1] += kp[1]; + c[2] += kp[2]; + } + let n = NUM_KEYPOINTS as f32; + c[0] /= n; + c[1] /= n; + c[2] /= n; + c + }) + .collect(); + + // Greedy assignment: for each detection, find the best matching active track. + // Collect tracks once to avoid re-borrowing tracker per detection. + let active: Vec<(TrackId, [f32; 3])> = tracker.active_tracks().iter().map(|t| { + let centroid = { + let mut c = [0.0_f32; 3]; + for kp in &t.keypoints { + let p = kp.position(); + c[0] += p[0]; c[1] += p[1]; c[2] += p[2]; + } + let n = NUM_KEYPOINTS as f32; + [c[0] / n, c[1] / n, c[2] / n] + }; + (t.id, centroid) + }).collect(); + + let mut used_tracks: Vec = vec![false; active.len()]; + let mut matched: Vec> = vec![None; persons.len()]; + + for det_idx in 0..persons.len() { + let mut best_cost = f32::MAX; + let mut best_track_idx = None; + + let active_refs = tracker.active_tracks(); + for (track_idx, track) in active_refs.iter().enumerate() { + if used_tracks[track_idx] { + continue; + } + let cost = tracker.assignment_cost(track, ¢roids[det_idx], &[]); + if cost < best_cost { + best_cost = cost; + best_track_idx = Some(track_idx); + } + } + + // Mahalanobis gate: 9.0 (default TrackerConfig) + if best_cost < 9.0 { + if let Some(tidx) = best_track_idx { + matched[det_idx] = Some(active[tidx].0); + used_tracks[tidx] = true; + } + } + } + + // Timestamp for new/updated tracks (microseconds since UNIX epoch) + let timestamp_us = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_micros() as u64) + .unwrap_or(0); + + // Update matched tracks (uses update_keypoints for proper lifecycle transitions) + for (det_idx, track_id_opt) in matched.iter().enumerate() { + if let Some(track_id) = track_id_opt { + if let Some(track) = tracker.find_track_mut(*track_id) { + track.update_keypoints(&all_keypoints[det_idx], 0.08, 1.0, timestamp_us); + } + } + } + + // Create new tracks for unmatched detections + for (det_idx, track_id_opt) in matched.iter().enumerate() { + if track_id_opt.is_none() { + tracker.create_track(&all_keypoints[det_idx], timestamp_us); + } + } + + tracker.prune_terminated(); + tracker_to_person_detections(tracker) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_keypoint(name: &str, x: f64, y: f64, z: f64) -> PoseKeypoint { + PoseKeypoint { + name: name.to_string(), + x, + y, + z, + confidence: 0.9, + } + } + + fn make_person(id: u32, keypoints: Vec) -> PersonDetection { + PersonDetection { + id, + confidence: 0.8, + keypoints, + bbox: BoundingBox { + x: 0.0, + y: 0.0, + width: 1.0, + height: 1.0, + }, + zone: "test".to_string(), + } + } + + #[test] + fn test_keypoint_name_to_coco_index() { + assert_eq!(keypoint_name_to_coco_index("nose"), Some(0)); + assert_eq!(keypoint_name_to_coco_index("left_eye"), Some(1)); + assert_eq!(keypoint_name_to_coco_index("right_eye"), Some(2)); + assert_eq!(keypoint_name_to_coco_index("left_ear"), Some(3)); + assert_eq!(keypoint_name_to_coco_index("right_ear"), Some(4)); + assert_eq!(keypoint_name_to_coco_index("left_shoulder"), Some(5)); + assert_eq!(keypoint_name_to_coco_index("right_shoulder"), Some(6)); + assert_eq!(keypoint_name_to_coco_index("left_elbow"), Some(7)); + assert_eq!(keypoint_name_to_coco_index("right_elbow"), Some(8)); + assert_eq!(keypoint_name_to_coco_index("left_wrist"), Some(9)); + assert_eq!(keypoint_name_to_coco_index("right_wrist"), Some(10)); + assert_eq!(keypoint_name_to_coco_index("left_hip"), Some(11)); + assert_eq!(keypoint_name_to_coco_index("right_hip"), Some(12)); + assert_eq!(keypoint_name_to_coco_index("left_knee"), Some(13)); + assert_eq!(keypoint_name_to_coco_index("right_knee"), Some(14)); + assert_eq!(keypoint_name_to_coco_index("left_ankle"), Some(15)); + assert_eq!(keypoint_name_to_coco_index("right_ankle"), Some(16)); + assert_eq!(keypoint_name_to_coco_index("unknown"), None); + // Case insensitive + assert_eq!(keypoint_name_to_coco_index("NOSE"), Some(0)); + assert_eq!(keypoint_name_to_coco_index("Left_Eye"), Some(1)); + } + + #[test] + fn test_detections_to_tracker_keypoints() { + let person = make_person( + 1, + vec![ + make_keypoint("nose", 1.0, 2.0, 0.5), + make_keypoint("left_shoulder", 0.8, 2.5, 0.4), + make_keypoint("right_shoulder", 1.2, 2.5, 0.6), + ], + ); + + let result = detections_to_tracker_keypoints(&[person]); + assert_eq!(result.len(), 1); + + let kps = &result[0]; + + // Mapped keypoints should have correct values + assert!((kps[0][0] - 1.0).abs() < 1e-5); // nose x + assert!((kps[0][1] - 2.0).abs() < 1e-5); // nose y + assert!((kps[0][2] - 0.5).abs() < 1e-5); // nose z + + assert!((kps[5][0] - 0.8).abs() < 1e-5); // left_shoulder x + assert!((kps[6][0] - 1.2).abs() < 1e-5); // right_shoulder x + + // Unmapped keypoints should be at centroid of mapped keypoints + // centroid = ((1.0+0.8+1.2)/3, (2.0+2.5+2.5)/3, (0.5+0.4+0.6)/3) + let cx = (1.0 + 0.8 + 1.2) / 3.0; + let cy = (2.0 + 2.5 + 2.5) / 3.0; + let cz = (0.5 + 0.4 + 0.6) / 3.0; + + // left_eye (index 1) should be at centroid + assert!((kps[1][0] - cx).abs() < 1e-4); + assert!((kps[1][1] - cy).abs() < 1e-4); + assert!((kps[1][2] - cz).abs() < 1e-4); + } + + #[test] + fn test_tracker_update_stable_ids() { + let mut tracker = PoseTracker::new(); + let mut last_instant: Option = None; + + let person = make_person( + 0, + vec![ + make_keypoint("nose", 1.0, 2.0, 0.0), + make_keypoint("left_shoulder", 0.8, 2.5, 0.0), + make_keypoint("right_shoulder", 1.2, 2.5, 0.0), + make_keypoint("left_hip", 0.9, 3.5, 0.0), + make_keypoint("right_hip", 1.1, 3.5, 0.0), + ], + ); + + // First update: creates a new track + let result1 = tracker_update(&mut tracker, &mut last_instant, vec![person.clone()]); + assert_eq!(result1.len(), 1); + let id1 = result1[0].id; + + // Second update: should match the existing track + let result2 = tracker_update(&mut tracker, &mut last_instant, vec![person.clone()]); + assert_eq!(result2.len(), 1); + let id2 = result2[0].id; + + // Third update: same track ID should persist + let result3 = tracker_update(&mut tracker, &mut last_instant, vec![person.clone()]); + assert_eq!(result3.len(), 1); + let id3 = result3[0].id; + + // All three updates should return the same track ID + assert_eq!(id1, id2, "Track ID should be stable across updates"); + assert_eq!(id2, id3, "Track ID should be stable across updates"); + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-signal/Cargo.toml b/rust-port/wifi-densepose-rs/crates/wifi-densepose-signal/Cargo.toml index 11114e9b..782392ff 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-signal/Cargo.toml +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-signal/Cargo.toml @@ -20,6 +20,7 @@ chrono = { version = "0.4", features = ["serde"] } # Signal processing ndarray = { workspace = true } +ndarray-linalg = { workspace = true } rustfft.workspace = true num-complex.workspace = true num-traits.workspace = true diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-signal/src/ruvsense/field_model.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-signal/src/ruvsense/field_model.rs index 7494235e..5bd3cadf 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-signal/src/ruvsense/field_model.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-signal/src/ruvsense/field_model.rs @@ -17,6 +17,10 @@ //! of Squares and Products." Technometrics. //! - ADR-030: RuvSense Persistent Field Model +use ndarray::Array2; +use ndarray_linalg::Eigh; +use ndarray_linalg::UPLO; + // --------------------------------------------------------------------------- // Error types // --------------------------------------------------------------------------- @@ -47,6 +51,14 @@ pub enum FieldModelError { /// Invalid configuration parameter. #[error("Invalid configuration: {0}")] InvalidConfig(String), + + /// Model has not been calibrated yet. + #[error("Field model not calibrated")] + NotCalibrated, + + /// Not enough data for the requested operation. + #[error("Insufficient data: need {need}, have {have}")] + InsufficientData { need: usize, have: usize }, } // --------------------------------------------------------------------------- @@ -260,6 +272,8 @@ pub struct FieldNormalMode { pub calibrated_at_us: u64, /// Hash of mesh geometry at calibration time. pub geometry_hash: u64, + /// Baseline eigenvalue count above Marcenko-Pastur threshold (empty-room). + pub baseline_eigenvalue_count: usize, } /// Body perturbation extracted from a CSI observation. @@ -310,6 +324,60 @@ pub struct FieldModel { status: CalibrationStatus, /// Timestamp of last calibration completion (microseconds). last_calibration_us: u64, + /// Running outer-product sum for full covariance SVD: [n_sub x n_sub]. + covariance_sum: Option>, + /// Number of frames accumulated into covariance_sum. + covariance_count: u64, +} + +/// Diagonal variance fallback for when full covariance SVD is unavailable. +/// +/// Returns `(mode_energies, environmental_modes, baseline_eigenvalue_count)`. +fn diagonal_fallback( + link_stats: &[LinkBaselineStats], + n_sc: usize, + n_modes: usize, +) -> (Vec, Vec>, usize) { + // Average variance across links (diagonal approximation) + let mut avg_variance = vec![0.0_f64; n_sc]; + for ls in link_stats { + let var = ls.variance_vector(); + for (i, v) in var.iter().enumerate() { + avg_variance[i] += v; + } + } + let n_links_f = link_stats.len() as f64; + if n_links_f > 0.0 { + for v in avg_variance.iter_mut() { + *v /= n_links_f; + } + } + + // Sort subcarrier indices by variance (descending) to pick top-K modes + let mut indices: Vec = (0..n_sc).collect(); + indices.sort_by(|&a, &b| { + avg_variance[b] + .partial_cmp(&avg_variance[a]) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let mut environmental_modes = Vec::with_capacity(n_modes); + let mut mode_energies = Vec::with_capacity(n_modes); + + for k in 0..n_modes.min(n_sc) { + let idx = indices[k]; + let mut mode = vec![0.0_f64; n_sc]; + mode[idx] = 1.0; + mode_energies.push(avg_variance[idx]); + environmental_modes.push(mode); + } + + // For diagonal fallback, estimate baseline eigenvalue count from variance + let total_var: f64 = avg_variance.iter().sum(); + let mean_var = if n_sc > 0 { total_var / n_sc as f64 } else { 0.0 }; + let baseline_count = avg_variance.iter().filter(|&&v| v > mean_var * 2.0).count(); + + (mode_energies, environmental_modes, baseline_count) } impl FieldModel { @@ -339,6 +407,8 @@ impl FieldModel { modes: None, status: CalibrationStatus::Uncalibrated, last_calibration_us: 0, + covariance_sum: None, + covariance_count: 0, }) } @@ -375,6 +445,30 @@ impl FieldModel { if self.status == CalibrationStatus::Uncalibrated { self.status = CalibrationStatus::Collecting; } + + // Accumulate raw outer products for SVD covariance (no centering here — + // mean subtraction is deferred to finalize_calibration to avoid bias). + // We average across links so covariance_count tracks frames, not links. + let n = self.config.n_subcarriers; + let cov = self.covariance_sum.get_or_insert_with(|| Array2::zeros((n, n))); + let n_links = observations.len(); + for obs in observations { + if obs.len() >= n { + // Rank-1 update: cov += obs * obs^T (raw, un-centered) + for i in 0..n { + for j in i..n { + let val = obs[i] * obs[j]; + cov[[i, j]] += val; + if i != j { + cov[[j, i]] += val; + } + } + } + } + } + // Count once per frame (not per link) for correct MP ratio + self.covariance_count += 1; + Ok(()) } @@ -396,58 +490,117 @@ impl FieldModel { }); } - // Build covariance matrix from per-link variance data. - // We average the variance vectors across all links to get the - // covariance diagonal, then compute eigenmodes via power iteration. let n_sc = self.config.n_subcarriers; let n_modes = self.config.n_modes.min(n_sc); // Collect per-link baselines let baseline: Vec> = self.link_stats.iter().map(|ls| ls.mean_vector()).collect(); - // Average covariance across links (diagonal approximation) - let mut avg_variance = vec![0.0_f64; n_sc]; - for ls in &self.link_stats { - let var = ls.variance_vector(); - for (i, v) in var.iter().enumerate() { - avg_variance[i] += v; + // --- True eigenvalue decomposition (with diagonal fallback) --- + let (mode_energies, environmental_modes, baseline_eig_count) = + if let Some(ref cov_sum) = self.covariance_sum { + if self.covariance_count > 1 { + // Compute sample covariance from raw outer products: + // cov = (sum_xx / N - mean * mean^T) * N / (N-1) + // where sum_xx accumulated obs * obs^T across all links per frame. + // We average per-link means for centering. + let n_frames = self.covariance_count as f64; + let n_links = self.config.n_links as f64; + // Average mean across all links + let mut avg_mean = vec![0.0f64; n_sc]; + for ls in &self.link_stats { + let m = ls.mean_vector(); + for i in 0..n_sc { avg_mean[i] += m[i]; } + } + for i in 0..n_sc { avg_mean[i] /= n_links; } + // cov = sum_xx / (N * n_links) - mean * mean^T, then Bessel correction + let total_obs = n_frames * n_links; + let mut covariance = cov_sum / total_obs; + for i in 0..n_sc { + for j in 0..n_sc { + covariance[[i, j]] -= avg_mean[i] * avg_mean[j]; + } + } + // Bessel's correction: multiply by N/(N-1) where N = total observations + let bessel = total_obs / (total_obs - 1.0); + covariance *= bessel; + + // Symmetric eigendecomposition + match covariance.eigh(UPLO::Upper) { + Ok((eigenvalues, eigenvectors)) => { + // eigenvalues are in ascending order from ndarray-linalg + // Reverse to get descending + let len = eigenvalues.len(); + let mut sorted_indices: Vec = (0..len).collect(); + sorted_indices.sort_by(|&a, &b| { + eigenvalues[b] + .partial_cmp(&eigenvalues[a]) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // Extract top n_modes + let modes: Vec> = sorted_indices + .iter() + .take(n_modes) + .map(|&idx| eigenvectors.column(idx).to_vec()) + .collect(); + let energies: Vec = sorted_indices + .iter() + .take(n_modes) + .map(|&idx| eigenvalues[idx].max(0.0)) + .collect(); + + // Marcenko-Pastur threshold for baseline eigenvalue count. + // Use median of bottom half as robust noise estimate + // (consistent with estimate_occupancy). + let noise_var = { + let mut sorted_eigs: Vec = eigenvalues + .iter().copied().map(|e| e.max(0.0)).collect(); + sorted_eigs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let half = sorted_eigs.len() / 2; + if half > 0 { + sorted_eigs[..half].iter().sum::() / half as f64 + } else { + sorted_eigs.iter().sum::() / sorted_eigs.len().max(1) as f64 + } + }; + let ratio = n_sc as f64 / self.covariance_count as f64; + let mp_threshold = noise_var * (1.0 + ratio.sqrt()).powi(2); + let baseline_count = eigenvalues + .iter() + .filter(|&&ev| ev > mp_threshold) + .count(); + + (energies, modes, baseline_count) + } + Err(_) => { + // Fallback to diagonal approximation on SVD failure + diagonal_fallback(&self.link_stats, n_sc, n_modes) + } + } + } else { + diagonal_fallback(&self.link_stats, n_sc, n_modes) + } + } else { + diagonal_fallback(&self.link_stats, n_sc, n_modes) + }; + + // Compute variance explained + let total_energy: f64 = mode_energies.iter().sum(); + // For variance_explained, we need total variance across all subcarriers. + // Use the sum of all eigenvalues (== trace of covariance == total variance). + let total_variance = if let Some(ref cov_sum) = self.covariance_sum { + if self.covariance_count > 1 { + let scale = 1.0 / (self.covariance_count as f64 - 1.0); + (0..n_sc).map(|i| (cov_sum[[i, i]] * scale).max(0.0)).sum::() + } else { + total_energy } - } - let n_links_f = self.config.n_links as f64; - for v in avg_variance.iter_mut() { - *v /= n_links_f; - } - - // Extract modes via simplified power iteration on the diagonal - // covariance. Since we use a diagonal approximation, the eigenmodes - // are aligned with the standard basis, sorted by variance. - let total_variance: f64 = avg_variance.iter().sum(); - - // Sort subcarrier indices by variance (descending) to pick top-K modes - let mut indices: Vec = (0..n_sc).collect(); - indices.sort_by(|&a, &b| { - avg_variance[b] - .partial_cmp(&avg_variance[a]) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - let mut environmental_modes = Vec::with_capacity(n_modes); - let mut mode_energies = Vec::with_capacity(n_modes); - let mut explained = 0.0_f64; - - for k in 0..n_modes { - let idx = indices[k]; - // Create a unit vector along the highest-variance subcarrier - let mut mode = vec![0.0_f64; n_sc]; - mode[idx] = 1.0; - let energy = avg_variance[idx]; - environmental_modes.push(mode); - mode_energies.push(energy); - explained += energy; - } - + } else { + total_energy + }; let variance_explained = if total_variance > 1e-15 { - explained / total_variance + total_energy / total_variance } else { 0.0 }; @@ -459,6 +612,7 @@ impl FieldModel { variance_explained, calibrated_at_us: timestamp_us, geometry_hash, + baseline_eigenvalue_count: baseline_eig_count, }; self.modes = Some(field_mode); @@ -541,6 +695,84 @@ impl FieldModel { }) } + /// Estimate room occupancy from eigenvalue analysis of recent CSI frames. + /// + /// `recent_frames`: sliding window of amplitude vectors (recommend 50 frames + /// ~ 2.5s at 20 Hz). Returns estimated person count (0 = empty room). + pub fn estimate_occupancy(&self, recent_frames: &[Vec]) -> Result { + let modes = self.modes.as_ref().ok_or(FieldModelError::NotCalibrated)?; + + let n = self.config.n_subcarriers; + if recent_frames.len() < 10 { + return Err(FieldModelError::InsufficientData { + need: 10, + have: recent_frames.len(), + }); + } + + // Build covariance matrix from recent frames + let mut mean = vec![0.0f64; n]; + let mut count = 0usize; + for frame in recent_frames { + if frame.len() >= n { + for i in 0..n { + mean[i] += frame[i]; + } + count += 1; + } + } + if count < 2 { + return Ok(0); + } + for m in &mut mean { + *m /= count as f64; + } + + let mut cov = Array2::::zeros((n, n)); + for frame in recent_frames { + if frame.len() >= n { + for i in 0..n { + let ci = frame[i] - mean[i]; + for j in i..n { + let val = ci * (frame[j] - mean[j]); + cov[[i, j]] += val; + if i != j { + cov[[j, i]] += val; + } + } + } + } + } + let scale = 1.0 / (count as f64 - 1.0); + cov *= scale; + + // Eigendecompose + let eigenvalues = match cov.eigh(UPLO::Upper) { + Ok((evals, _)) => evals, + Err(_) => return Ok(0), // SVD failure = can't estimate + }; + + // Marcenko-Pastur noise threshold + let noise_var = { + let mut sorted: Vec = eigenvalues.iter().copied().collect(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + // Median of bottom half as robust noise estimate + let half = sorted.len() / 2; + if half > 0 { + sorted[..half].iter().sum::() / half as f64 + } else { + 1.0 + } + }; + let ratio = n as f64 / count as f64; + let mp_threshold = noise_var * (1.0 + ratio.sqrt()).powi(2); + + let significant = eigenvalues.iter().filter(|&&ev| ev > mp_threshold).count(); + let occupancy = significant.saturating_sub(modes.baseline_eigenvalue_count); + + Ok(occupancy.min(10)) // Cap at 10 persons + } + /// Check calibration freshness against a given timestamp. pub fn check_freshness(&self, current_us: u64) -> CalibrationStatus { if self.modes.is_none() { @@ -563,6 +795,8 @@ impl FieldModel { .collect(); self.modes = None; self.status = CalibrationStatus::Uncalibrated; + self.covariance_sum = None; + self.covariance_count = 0; } } @@ -873,6 +1107,179 @@ mod tests { } } + #[test] + fn test_covariance_accumulation() { + let config = make_config(2, 4, 5); + let mut model = FieldModel::new(config).unwrap(); + + // Feed calibration data + for i in 0..10 { + let obs = make_observations(2, 4, 1.0 + 0.1 * i as f64); + model.feed_calibration(&obs).unwrap(); + } + + // covariance_sum should be populated + assert!(model.covariance_sum.is_some()); + assert!(model.covariance_count > 0); + let cov = model.covariance_sum.as_ref().unwrap(); + assert_eq!(cov.shape(), &[4, 4]); + // Diagonal entries should be non-negative (sum of squares) + for i in 0..4 { + assert!(cov[[i, i]] >= 0.0, "Diagonal covariance entry must be >= 0"); + } + // Matrix should be symmetric + for i in 0..4 { + for j in 0..4 { + assert!( + (cov[[i, j]] - cov[[j, i]]).abs() < 1e-10, + "Covariance matrix must be symmetric" + ); + } + } + } + + #[test] + fn test_svd_finalize_produces_orthonormal_modes() { + let config = FieldModelConfig { + n_links: 1, + n_subcarriers: 8, + n_modes: 3, + min_calibration_frames: 20, + baseline_expiry_s: 86_400.0, + }; + let mut model = FieldModel::new(config).unwrap(); + + // Feed frames with correlated subcarrier patterns to produce + // non-trivial eigenmodes + for i in 0..50 { + let t = i as f64 * 0.1; + let obs = vec![vec![ + 1.0 + t.sin(), + 2.0 + t.cos(), + 3.0 + 0.5 * t.sin(), + 4.0 + 0.3 * t.cos(), + 5.0 + 0.1 * t, + 6.0, + 7.0 + 0.2 * (2.0 * t).sin(), + 8.0 + 0.1 * (2.0 * t).cos(), + ]]; + model.feed_calibration(&obs).unwrap(); + } + model.finalize_calibration(1_000_000, 0).unwrap(); + + let modes = model.modes().unwrap(); + // Each mode should be approximately unit length + for (k, mode) in modes.environmental_modes.iter().enumerate() { + let norm: f64 = mode.iter().map(|x| x * x).sum::().sqrt(); + assert!( + (norm - 1.0).abs() < 0.01, + "Mode {} has norm {} (expected ~1.0)", + k, + norm + ); + } + // Modes should be approximately orthogonal + for i in 0..modes.environmental_modes.len() { + for j in (i + 1)..modes.environmental_modes.len() { + let dot: f64 = modes.environmental_modes[i] + .iter() + .zip(modes.environmental_modes[j].iter()) + .map(|(a, b)| a * b) + .sum(); + assert!( + dot.abs() < 0.05, + "Modes {} and {} have dot product {} (expected ~0)", + i, + j, + dot + ); + } + } + } + + #[test] + fn test_estimate_occupancy_noise_only() { + let config = FieldModelConfig { + n_links: 1, + n_subcarriers: 8, + n_modes: 3, + min_calibration_frames: 20, + baseline_expiry_s: 86_400.0, + }; + let mut model = FieldModel::new(config).unwrap(); + + // Calibrate with some deterministic noise-like pattern + for i in 0..50 { + let t = i as f64 * 0.1; + let obs = vec![vec![ + 1.0 + 0.01 * t.sin(), + 2.0 + 0.01 * t.cos(), + 3.0 + 0.01 * (2.0 * t).sin(), + 4.0 + 0.01 * (2.0 * t).cos(), + 5.0 + 0.01 * (3.0 * t).sin(), + 6.0 + 0.01 * (3.0 * t).cos(), + 7.0 + 0.01 * (4.0 * t).sin(), + 8.0 + 0.01 * (4.0 * t).cos(), + ]]; + model.feed_calibration(&obs).unwrap(); + } + model.finalize_calibration(1_000_000, 0).unwrap(); + + // Estimate occupancy with similar noise-only frames + let frames: Vec> = (0..20) + .map(|i| { + let t = (i + 50) as f64 * 0.1; + vec![ + 1.0 + 0.01 * t.sin(), + 2.0 + 0.01 * t.cos(), + 3.0 + 0.01 * (2.0 * t).sin(), + 4.0 + 0.01 * (2.0 * t).cos(), + 5.0 + 0.01 * (3.0 * t).sin(), + 6.0 + 0.01 * (3.0 * t).cos(), + 7.0 + 0.01 * (4.0 * t).sin(), + 8.0 + 0.01 * (4.0 * t).cos(), + ] + }) + .collect(); + let occupancy = model.estimate_occupancy(&frames).unwrap(); + assert_eq!(occupancy, 0, "Noise-only frames should yield 0 occupancy"); + } + + #[test] + fn test_baseline_eigenvalue_count_stored() { + let config = FieldModelConfig { + n_links: 1, + n_subcarriers: 8, + n_modes: 3, + min_calibration_frames: 20, + baseline_expiry_s: 86_400.0, + }; + let mut model = FieldModel::new(config).unwrap(); + + // Feed frames with structured variance so eigenvalues are meaningful + for i in 0..50 { + let t = i as f64 * 0.1; + let obs = vec![vec![ + 1.0 + t.sin(), + 2.0 + t.cos(), + 3.0 + 0.5 * t.sin(), + 4.0 + 0.3 * t.cos(), + 5.0 + 0.1 * t, + 6.0, + 7.0, + 8.0, + ]]; + model.feed_calibration(&obs).unwrap(); + } + let modes = model.finalize_calibration(1_000_000, 0).unwrap(); + // baseline_eigenvalue_count should exist and be a reasonable value + // (at least 0, at most n_subcarriers) + assert!( + modes.baseline_eigenvalue_count <= 8, + "baseline_eigenvalue_count should be <= n_subcarriers" + ); + } + #[test] fn test_environmental_projection_removes_drift() { let config = make_config(1, 4, 10);