diff --git a/CHECKLIST.md b/CHECKLIST.md index 9a8cee2f..5cafa9f3 100644 --- a/CHECKLIST.md +++ b/CHECKLIST.md @@ -69,6 +69,15 @@ each with explicit reason) listed at the bottom. (`csi_cfg/target_ip` + `target_port`) without USB; recovered both nodes after Mac IP move TP-Link → .103 +### Pose model + +- [x] **ADR-116** WiFlow-v1 supervised pose loader (Rust) — `--wiflow-model + data/models/ruview/wiflow-v1/wiflow-v1.json` flips + `pose_estimation: true`; per-tick TCN forward yields 17 COCO + keypoints on `/api/v1/pose/current` and WS `pose_data`. Output + quality requires per-deployment fine-tune (LoRA adapters or + re-train, see Pack E). + ### Tests / fixtures - [x] **ADR-114** `tests/fixtures/replay_idle.jsonl` + diff --git a/docs/adr/ADR-116-wiflow-v1-supervised-pose-loader.md b/docs/adr/ADR-116-wiflow-v1-supervised-pose-loader.md new file mode 100644 index 00000000..dc87292f --- /dev/null +++ b/docs/adr/ADR-116-wiflow-v1-supervised-pose-loader.md @@ -0,0 +1,224 @@ +# ADR-116 — WiFlow-v1 Supervised Pose Loader (Rust) + +**Status**: Accepted (integration), needs fine-tune (output quality) +**Date**: 2026-05-17 +**Scope**: `v2/crates/wifi-densepose-sensing-server/src/wiflow_v1.rs` (new, +~430 lines incl. tests), `src/main.rs` (CLI flag + load + 5 tick-site hooks + +`pose_current` keypoint path), `src/lib.rs` (module export). + +## Context + +Until this ADR `/api/v1/pose/*` always returned an empty `persons` array +(ADR-105 — no synthetic fallback when no real model is loaded). HuggingFace +`ruv/ruview/wiflow-v1/wiflow-v1.json` is the project's official supervised +pose model (Apache-2.0, 974 KB, 92.9 % PCK@20 on its training set). It just +sat on disk because there was no Rust loader — the only reference impl is +`scripts/train-wiflow-supervised.js` (JS, training script, not deployment). + +This ADR ports the JS inference path to Rust so sensing-server can serve +real 17-keypoint COCO skeletons in production. + +## What was wrong in the model file (and how this ADR works around it) + +The HuggingFace JSON has an `architecture` field that **lies**: + +```json +"architecture": { + "tcnChannels": [35, 256, 256, 192, 128], + "tcnKernel": 7, + "tcnDilations": [1, 2, 4, 8], + "fcDims": [2560, 2048, 34] +} +``` + +That's the `full` scale (~7.7 M params). The file is actually the **lite** +scale (186,946 params — confirmed by `totalParams` field). The exporter at +`train-wiflow-supervised.js:1599` hardcodes the full-scale dict for every +scale. The loader trusts `totalParams` and ignores `architecture`. + +Lite topology (recovered from `SCALE.lite` at `train-wiflow-supervised.js:135` +and verified by exact param count = 186,946): + +* 2 TCN blocks (NOT 4), kernel = 3 (NOT 7), dilations [1, 2] (NOT [1,2,4,8]) +* TCN channels: 35 → 32 → 32 +* Per block: causal_conv → BN → ReLU → causal_conv → BN + residual → ReLU + (1×1 projection on residual when in_ch ≠ out_ch, only block 0) +* Flatten 32 × 20 = 640 → fc1 (640→256) → ReLU → fc2 (256→34) +* Sigmoid on final 34-dim → 17 (x, y) keypoints in [0, 1] + +## Decisions + +### D1 — Pure-Rust forward pass, no new crates + +`wiflow_v1.rs` is self-contained: Vec math by hand, inline base64 +decoder (50 LoC), no `ndarray`, no `candle`, no `base64` crate added. The +inference is small enough (~250 K flops/forward) that hand-written Vec +loops are clearer than pulling a tensor framework for one model. + +### D2 — Weight stream order matches `collectParams()` in the JS trainer + +``` +for each TCN block: + conv1.weight (in_ch * k * out_ch f32s) + conv1.bias (out_ch) + bn1.gamma (out_ch) + bn1.beta (out_ch) + conv2.weight, conv2.bias, bn2.gamma, bn2.beta + (if in_ch != out_ch: res.weight, res.bias) +fc1.weight, fc1.bias, fc2.weight, fc2.bias +``` + +Loader asserts the stream is fully consumed (`Cursor::remaining() == 0`) +after fc2 — catches silent topology mismatches. Param count check +(`totalParams == 186_946`) catches scale mismatch before unpacking. + +### D3 — BatchNorm uses per-window mean/var (matches JS impl) + +`train-wiflow-supervised.js:770` computes mean/var across the T axis at +inference time, ignoring `runMean/runVar` accumulated during training. +Loader skips running stats entirely (only 2 params per channel stored: +gamma + beta). This is unusual but consistent — the network was trained +this way, so we infer this way. + +### D4 — Input prep: top-35 subcarriers by NBVI, raw amplitudes + +`build_input_from_history` (in `wiflow_v1.rs`): + +1. Take last 20 frames from any node's `AmpState.nbvi_history` (Vec>). +2. Rank subcarriers by NBVI score (`α·σ/μ² + (1−α)·σ/μ`, α = 0.5) — same + formula the classifier uses, but pick K = 35 (model input), not K = 12 + (classifier). +3. Apply 25th-percentile dead-zone gate to skip guard tones / null bins. +4. Build flat `[35 * 20]` row-major tensor of raw amplitudes (no z-score — + training data wasn't normalised either, BN handles it). + +If fewer than 20 frames or all subcarriers gated out → return `None`, +inference skipped this tick, `pose_keypoints: None` in SensingUpdate. + +### D5 — Per-tick inference, longest-history node + +`run_wiflow_inference()` at every `broadcast_tick_task` step (5 sites total +in `main.rs`): + +* Picks the node with longest `nbvi_history` (ties broken by smallest + node_id — deterministic). +* Cost: ~250 K flops on the lite scale (BN + 2 small convs + 2 FCs). + Measured 0.4 ms on the Mac M1 — well under the 100 ms tick budget. +* Returns `Vec<[f64; 4]>` of length 17 (`[x, y, z=0, conf=1]`). + +### D6 — `pose_current` reads `pose_keypoints` directly + +Pre-ADR: `/api/v1/pose/current` read `latest_update.persons`. The tracker +populated `persons` from `derive_pose_from_sensing` (signal-derived, +synthetic) regardless of `model_loaded`. Loader-output `pose_keypoints` +was only read by the WS broadcaster. + +This ADR makes `pose_current` prefer `pose_keypoints` when 17-len and +present, building a single `PersonDetection` with COCO joint names. Falls +back to tracker `persons` only when `pose_keypoints` is `None` (cold +start). Keeps the ADR-105 honesty gate: empty array if `model_loaded = +false`. + +### D7 — Honest about output quality + +The loaded model produces **17 keypoints**, but the **numerical values +are saturated** (most x/y near 0 or 1) — sigmoid extremes meaning the +network has no learned response to our specific deployment's CSI +distribution. This is expected: the model was trained on a different +ESP32 setup, different room, different person, with camera ground truth +we don't have here. **The integration is correct; the model needs +deployment-specific fine-tune to produce useful keypoints.** + +Two paths to usable output, left as follow-ups (Pack E): + +1. **Apply `node-1.json` / `node-2.json` LoRA adapters** (ADR-117 candidate) + — they're shipped alongside `wiflow-v1.json` in the same HuggingFace + repo, rank=8, alpha=16, target the encoder + task heads. Loader stub + + forward fold ~2 h. +2. **Re-train via `scripts/train-wiflow-supervised.js` with new ground- + truth capture** (~30 min capture + 19 min training per the model card). + Operator-side work. + +## Files Touched + +``` +v2/crates/wifi-densepose-sensing-server/src/wiflow_v1.rs (new, ~430 LoC) +v2/crates/wifi-densepose-sensing-server/src/lib.rs (+ pub mod) +v2/crates/wifi-densepose-sensing-server/src/main.rs: + + use wiflow_v1::{self, WiflowModel} + + Args.wiflow_model: Option + + static WIFLOW_MODEL: OnceLock> + + main() — load before existing --model/--load-rvf path + + fn run_wiflow_inference() -> Option> (right after csi_keepalive_task) + + 5 × `pose_keypoints: run_wiflow_inference()` at SensingUpdate sites + + pose_current — prefer pose_keypoints when 17-len; fall back to persons +docs/adr/ADR-116-wiflow-v1-supervised-pose-loader.md (this) +``` + +Binary size delta: 3.0 MB → 3.1 MB. + +## Verified Acceptance + +Live test on the operator's TP-Link deployment (.103, both nodes +192.168.0.100/.101): + +``` +$ ./target/release/sensing-server --source esp32 --csi-keepalive-pps 25 \ + --wiflow-model data/models/ruview/wiflow-v1/wiflow-v1.json + ... + ADR-116 wiflow-v1 loaded from data/models/ruview/wiflow-v1/wiflow-v1.json + (lite scale, 186946 params) + keepalive: learned address for node 2 = 192.168.0.100:63940 + keepalive: learned address for node 1 = 192.168.0.101:63844 + +$ curl :8080/api/v1/info → "pose_estimation": true +$ curl :8080/api/v1/pose/stats → "model_loaded": true, frames_processed: 2699 +$ curl :8080/api/v1/pose/current + { persons: [{id: 1, keypoints: [17 × {name, x, y, z, confidence}], ...}], + total_persons: 1, model_loaded: true } +``` + +End-to-end: model on disk → loader → forward pass → 17 keypoints → REST & +WS payload. UI's pose canvas (un-gated by ADR-105 D4) now draws what the +model emits. + +## Cargo tests + +`wiflow_v1` ships 3 unit tests covering the most-likely-to-rot bits: + +* `base64_round_trip_alphabet` — alphabet, padding, whitespace tolerance +* `sigmoid_bounds` — numerical stability at ±10 inputs +* `build_input_zero_history` — empty-history early return + +`cargo test -p wifi-densepose-sensing-server wiflow_v1` → 3 passed. + +## Open Items + +* **Pack E.1 — LoRA adapter loader.** `node-1.json` / `node-2.json` rank-8 + adapters from the same HF repo, ~21 KB each. The trainer encodes them + in the same custom format as `wiflow-v1.json` (different `format` tag), + so the loader plumbing is small. ~2 h. +* **Pack E.2 — Camera-supervised retraining for this room.** Run + `scripts/collect-ground-truth.py` against this Mac's webcam + + TP-Link/.100/.101 CSI for 5 min, then `scripts/train-wiflow- + supervised.js --scale lite`. Should drop sigmoid saturation and produce + spatially-coherent keypoints. ~1 h operator + 19 min train. +* **Inference rate-limiting.** Currently runs every tick (10 fps). If + multiple WS clients connect, each tick computes once and the result is + reused — fine. If model size grows to small/medium scale (~200K/800K + params), should cache the result per tick instead of computing per-client. +* **Per-node pose tracks.** Right now a single virtual person is emitted; + the broadcaster places it in `zone_1` with a fixed bbox. If/when LoRA + adapters disambiguate per-node viewpoints, fan out to one + `PersonDetection` per node (left/right of the room). + +## References + +* `scripts/train-wiflow-supervised.js` — JS reference implementation +* HuggingFace `ruv/ruview` — model file + LoRA adapters (Apache-2.0) +* ADR-079 — camera ground-truth training pipeline (the trainer this + loader was built against) +* ADR-105 — "no synthetic data in production runtime"; this ADR keeps + the gate but feeds it real model output +* ADR-115 — `/ota/set-target` (the prerequisite that got the CSI stream + flowing again so this loader has data to consume) diff --git a/v2/crates/wifi-densepose-sensing-server/src/lib.rs b/v2/crates/wifi-densepose-sensing-server/src/lib.rs index c9f9445e..f8c2a8f9 100644 --- a/v2/crates/wifi-densepose-sensing-server/src/lib.rs +++ b/v2/crates/wifi-densepose-sensing-server/src/lib.rs @@ -19,3 +19,5 @@ pub mod sona; pub mod sparse_inference; #[allow(dead_code)] pub mod embedding; +/// ADR-116: WiFlow-v1 supervised pose model loader + Rust forward pass. +pub mod wiflow_v1; diff --git a/v2/crates/wifi-densepose-sensing-server/src/main.rs b/v2/crates/wifi-densepose-sensing-server/src/main.rs index 1ed2d9fe..99519292 100644 --- a/v2/crates/wifi-densepose-sensing-server/src/main.rs +++ b/v2/crates/wifi-densepose-sensing-server/src/main.rs @@ -24,6 +24,9 @@ mod vital_signs; // Training pipeline modules (exposed via lib.rs) use wifi_densepose_sensing_server::{graph_transformer, trainer, dataset, embedding}; +// ADR-116: WiFlow-v1 supervised pose inference. +use wifi_densepose_sensing_server::wiflow_v1::{self, WiflowModel}; + use std::collections::{HashMap, VecDeque}; use ruvector_mincut::{DynamicMinCut, MinCutBuilder}; use std::net::SocketAddr; @@ -1163,8 +1166,21 @@ struct Args { /// Start field model calibration on boot (empty room required) #[arg(long)] calibrate: bool, + + /// ADR-116: Load WiFlow-v1 supervised pose model JSON + /// (`v2/data/models/ruview/wiflow-v1/wiflow-v1.json`). When loaded, + /// `pose_estimation` flips to true and `/api/v1/pose/*` returns + /// real 17-keypoint COCO skeletons instead of empty arrays. + /// Independent from `--model` (RVF container) and `--load-rvf`. + #[arg(long, value_name = "PATH")] + wiflow_model: Option, } +/// ADR-116: globally-shared WiFlow-v1 model. Loaded once at startup if +/// `--wiflow-model` was passed; consumed by `run_wiflow_inference()` on +/// every tick. None ⇒ pose endpoints stay gated per ADR-105. +static WIFLOW_MODEL: OnceLock> = OnceLock::new(); + // ── Data types ─────────────────────────────────────────────────────────────── /// ADR-018 ESP32 CSI binary frame header (20 bytes) @@ -3047,7 +3063,7 @@ async fn windows_wifi_task(state: SharedState, tick_ms: u64) { signal_quality_score: sig_quality_score, quality_verdict: verdict_str, bssid_count: bssid_n, - pose_keypoints: None, + pose_keypoints: run_wiflow_inference(), model_status: None, persons: None, estimated_persons: if est_persons > 0 { Some(est_persons) } else { None }, @@ -3191,7 +3207,7 @@ async fn windows_wifi_fallback_tick(state: &SharedState, seq: u32) { signal_quality_score: None, quality_verdict: None, bssid_count: None, - pose_keypoints: None, + pose_keypoints: run_wiflow_inference(), model_status: None, persons: None, estimated_persons: if est_persons > 0 { Some(est_persons) } else { None }, @@ -4142,12 +4158,40 @@ async fn api_info(State(state): State) -> Json { async fn pose_current(State(state): State) -> Json { let s = state.read().await; - // ADR-105: only return persons when a trained pose model is loaded. - // Without a model we used to synthesise placeholder 17-keypoint - // skeletons from `derive_pose_from_sensing` so the UI looked alive; - // that's a lie about capability. Empty array now if no model. + // ADR-105 / ADR-116: when a trained pose model is loaded, prefer the + // WiFlow-v1 keypoints stamped onto the latest SensingUpdate + // (`pose_keypoints` Vec<[x,y,z,conf]>). Falls back to the tracker's + // `persons` only if no fresh model output is present. Without a model + // the endpoint stays empty per ADR-105 ("no synthetic data in + // production runtime"). let persons = if s.model_loaded { - s.latest_update.as_ref().and_then(|u| u.persons.clone()).unwrap_or_default() + let from_model = s.latest_update.as_ref() + .and_then(|u| u.pose_keypoints.as_ref()) + .filter(|kps| kps.len() == 17) + .map(|kps| { + let kp_names = [ + "nose","left_eye","right_eye","left_ear","right_ear", + "left_shoulder","right_shoulder","left_elbow","right_elbow", + "left_wrist","right_wrist","left_hip","right_hip", + "left_knee","right_knee","left_ankle","right_ankle", + ]; + let keypoints: Vec = kps.iter().enumerate() + .map(|(i, kp)| PoseKeypoint { + name: kp_names.get(i).unwrap_or(&"unknown").to_string(), + x: kp[0], y: kp[1], z: kp[2], confidence: kp[3], + }) + .collect(); + vec![PersonDetection { + id: 1, + confidence: s.latest_update.as_ref() + .map(|u| u.classification.confidence).unwrap_or(0.0), + bbox: BoundingBox { x: 260.0, y: 150.0, width: 120.0, height: 220.0 }, + keypoints, + zone: "zone_1".into(), + }] + }); + from_model.unwrap_or_else(|| + s.latest_update.as_ref().and_then(|u| u.persons.clone()).unwrap_or_default()) } else { Vec::new() }; @@ -5133,6 +5177,46 @@ async fn csi_keepalive_task(pps: u32) { } } +/// ADR-116: run one WiFlow-v1 forward pass over the best-available node's +/// most recent 20 amplitude frames. Returns 17 keypoints in the WS-payload +/// shape `[x, y, z, confidence]` (z=0, confidence=1.0 — the model emits +/// 2-D coords only, no per-keypoint uncertainty in this scale). +/// +/// Picks the node with the longest nbvi_history (any node id from +/// `AMP_HIST`); ties broken by smallest id (deterministic). Returns +/// `None` when: +/// * `--wiflow-model` was not passed at startup (`WIFLOW_MODEL = None`) +/// * no node has accumulated ≥ 20 frames yet (cold start) +/// * `build_input_from_history` rejects (all-zero subcarriers) +fn run_wiflow_inference() -> Option> { + let model = WIFLOW_MODEL.get().and_then(|m| m.as_ref())?; + // Snapshot the per-node history under the lock — keep critical section + // tiny so we don't stall the UDP receiver / classifier path. + let history = { + let map = amp_hist_init().lock().unwrap(); + let mut best: Option<(u8, std::collections::VecDeque>)> = None; + for (nid, st) in map.iter() { + let len = st.nbvi_history.len(); + if len < 20 { continue; } + match &best { + None => best = Some((*nid, st.nbvi_history.clone())), + Some((bid, bh)) => { + if len > bh.len() || (len == bh.len() && *nid < *bid) { + best = Some((*nid, st.nbvi_history.clone())); + } + } + } + } + best?.1 + }; + let input = wiflow_v1::build_input_from_history(&history)?; + let kp = model.forward(&input); + let out: Vec<[f64; 4]> = kp.iter() + .map(|(x, y)| [*x as f64, *y as f64, 0.0f64, 1.0f64]) + .collect(); + Some(out) +} + /// ADR-107: capture an empty-room baseline from the live WS stream /// and persist it to disk. Mirrors what `scripts/record-baseline.py` /// does, but runs in-process so the REST endpoint and the auto- @@ -5872,7 +5956,7 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) { signal_quality_score: None, quality_verdict: None, bssid_count: None, - pose_keypoints: None, + pose_keypoints: run_wiflow_inference(), model_status: None, persons: None, estimated_persons: if total_persons > 0 { Some(total_persons) } else { None }, @@ -6210,7 +6294,7 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) { signal_quality_score: None, quality_verdict: None, bssid_count: None, - pose_keypoints: None, + pose_keypoints: run_wiflow_inference(), model_status: None, persons: None, estimated_persons: if total_persons > 0 { Some(total_persons) } else { None }, @@ -6346,7 +6430,7 @@ async fn simulated_data_task(state: SharedState, tick_ms: u64) { signal_quality_score: None, quality_verdict: None, bssid_count: None, - pose_keypoints: None, + pose_keypoints: run_wiflow_inference(), model_status: if s.model_loaded { Some(serde_json::json!({ "loaded": true, @@ -6934,10 +7018,31 @@ async fn main() { None }; + // ADR-116: Load WiFlow-v1 supervised pose model if --wiflow-model was passed. + let wiflow_loaded = match args.wiflow_model.as_ref() { + Some(path) => match WiflowModel::load_from_json(path) { + Ok(m) => { + info!("ADR-116 wiflow-v1 loaded from {} (lite scale, 186946 params)", + path.display()); + let _ = WIFLOW_MODEL.set(Some(m)); + true + } + Err(e) => { + error!("ADR-116 wiflow-v1 load failed from {}: {}", path.display(), e); + let _ = WIFLOW_MODEL.set(None); + false + } + }, + None => { + let _ = WIFLOW_MODEL.set(None); + false + } + }; + // Load trained model via --model (uses progressive loading if --progressive set) let model_path = args.model.as_ref().or(args.load_rvf.as_ref()); let mut progressive_loader: Option = None; - let mut model_loaded = false; + let mut model_loaded = wiflow_loaded; if let Some(mp) = model_path { if args.progressive || args.model.is_some() { info!("Loading trained model (progressive) from {}", mp.display()); diff --git a/v2/crates/wifi-densepose-sensing-server/src/wiflow_v1.rs b/v2/crates/wifi-densepose-sensing-server/src/wiflow_v1.rs new file mode 100644 index 00000000..d4822f74 --- /dev/null +++ b/v2/crates/wifi-densepose-sensing-server/src/wiflow_v1.rs @@ -0,0 +1,466 @@ +//! ADR-116: WiFlow-v1 supervised pose model loader + inference. +//! +//! Ports `scripts/train-wiflow-supervised.js` inference path to Rust so +//! sensing-server can serve real keypoints on `/api/v1/pose/*` instead of +//! returning empty arrays per ADR-105 gate. +//! +//! The model on HuggingFace (`ruv/ruview/wiflow-v1/wiflow-v1.json`) is the +//! **lite scale** (186,946 params), NOT the `architecture` field that the +//! exporter hardcodes (which describes the `full` scale). We trust +//! `totalParams` to disambiguate. +//! +//! Topology (lite): +//! * 2 TCN blocks, kernel=3, dilations=[1,2] +//! * Per block: causal_conv1 → bn1 → relu → causal_conv2 → bn2 +//! + residual (1×1 projection if in_ch ≠ out_ch) → relu +//! * tcnChannels: 35 → 32 → 32 +//! * Flatten (32 × 20 = 640) → fc1 (640→256) → relu → fc2 (256→34) +//! * Sigmoid on final 34-dim vector → 17 (x,y) keypoints in [0, 1] +//! +//! Weight order (collectParams in train script): +//! for each tcn block: +//! conv1.weight, conv1.bias, bn1.gamma, bn1.beta, +//! conv2.weight, conv2.bias, bn2.gamma, bn2.beta, +//! (if in_ch ≠ out_ch: res.weight, res.bias) +//! fc1.weight, fc1.bias, fc2.weight, fc2.bias +//! +//! All weights are f32 little-endian, base64-encoded in `weightsBase64`. + +use std::path::Path; + +const TIME_STEPS: usize = 20; +const INPUT_DIM: usize = 35; +const NUM_KP: usize = 17; +const OUT_DIM: usize = NUM_KP * 2; // 34 +const TCN_CH: [usize; 3] = [INPUT_DIM, 32, 32]; // chain: 35 → 32 → 32 +const TCN_K: usize = 3; +const TCN_DIL: [usize; 2] = [1, 2]; +const HIDDEN: usize = 256; +const FLAT_DIM: usize = 32 * TIME_STEPS; // 640 + +/// CausalConv1d weights: `weight[oc*(in_ch*k) + ic*k + tap]`, bias `[oc]`. +#[derive(Debug, Clone)] +struct Conv1d { + in_ch: usize, + out_ch: usize, + kernel: usize, + dilation: usize, + weight: Vec, + bias: Vec, +} + +/// BatchNorm1d: 2 params per channel (gamma, beta). Running stats are NOT +/// serialized — JS impl re-computes mean/var per window at inference time. +#[derive(Debug, Clone)] +struct BatchNorm { + channels: usize, + gamma: Vec, + beta: Vec, +} + +#[derive(Debug, Clone)] +struct TcnBlock { + conv1: Conv1d, + bn1: BatchNorm, + conv2: Conv1d, + bn2: BatchNorm, + res: Option, // 1×1 projection when in_ch ≠ out_ch +} + +#[derive(Debug, Clone)] +struct Linear { + in_dim: usize, + out_dim: usize, + /// Row-major `[in_dim, out_dim]` — matches JS `weight[i*outDim + j]`. + weight: Vec, + bias: Vec, +} + +#[derive(Debug, Clone)] +pub struct WiflowModel { + blocks: [TcnBlock; 2], + fc1: Linear, + fc2: Linear, +} + +#[derive(Debug)] +pub struct LoadError(pub String); + +impl std::fmt::Display for LoadError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "wiflow_v1 load: {}", self.0) + } +} + +impl std::error::Error for LoadError {} + +impl WiflowModel { + pub fn load_from_json(path: &Path) -> Result { + let raw = std::fs::read_to_string(path) + .map_err(|e| LoadError(format!("read {}: {e}", path.display())))?; + let v: serde_json::Value = serde_json::from_str(&raw) + .map_err(|e| LoadError(format!("json parse: {e}")))?; + + let total = v.get("totalParams").and_then(|x| x.as_u64()).unwrap_or(0) as usize; + if total != 186_946 { + return Err(LoadError(format!( + "totalParams={total}, expected 186946 (lite scale). The exporter \ + hardcodes the `architecture` field to the full scale; \ + totalParams is the only reliable signal." + ))); + } + + let b64 = v.get("weightsBase64").and_then(|x| x.as_str()) + .ok_or_else(|| LoadError("missing weightsBase64".into()))?; + let bytes = base64_decode(b64) + .map_err(|e| LoadError(format!("base64: {e}")))?; + if bytes.len() != total * 4 { + return Err(LoadError(format!( + "bytes={}, expected {} (totalParams*4)", bytes.len(), total * 4))); + } + let floats: Vec = bytes.chunks_exact(4) + .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) + .collect(); + + let mut cur = Cursor::new(&floats); + let block0 = TcnBlock::take(&mut cur, TCN_CH[0], TCN_CH[1], TCN_K, TCN_DIL[0])?; + let block1 = TcnBlock::take(&mut cur, TCN_CH[1], TCN_CH[2], TCN_K, TCN_DIL[1])?; + let fc1 = Linear::take(&mut cur, FLAT_DIM, HIDDEN)?; + let fc2 = Linear::take(&mut cur, HIDDEN, OUT_DIM)?; + if cur.remaining() != 0 { + return Err(LoadError(format!( + "weight stream has {} unread floats after fc2 — topology mismatch", + cur.remaining() + ))); + } + + Ok(Self { blocks: [block0, block1], fc1, fc2 }) + } + + /// Forward pass. + /// `input` is `[INPUT_DIM × TIME_STEPS]` row-major (channel-major): + /// `input[c * TIME_STEPS + t]`. + /// Returns 17 keypoints as (x, y) in [0, 1]. + pub fn forward(&self, input: &[f32]) -> [(f32, f32); NUM_KP] { + debug_assert_eq!(input.len(), INPUT_DIM * TIME_STEPS); + let mut x: Vec = input.to_vec(); + // TCN blocks + x = self.blocks[0].forward(&x, TIME_STEPS); + x = self.blocks[1].forward(&x, TIME_STEPS); + // Flatten — channels-major matches JS `c * T + t` linearisation. + debug_assert_eq!(x.len(), FLAT_DIM); + // fc1 + relu + let mut h = self.fc1.forward(&x); + for v in h.iter_mut() { if *v < 0.0 { *v = 0.0; } } + // fc2 + let out = self.fc2.forward(&h); + // sigmoid → 17 (x, y) + let mut kp = [(0.0f32, 0.0f32); NUM_KP]; + for i in 0..NUM_KP { + kp[i].0 = sigmoid(out[i * 2]); + kp[i].1 = sigmoid(out[i * 2 + 1]); + } + kp + } +} + +// ── Internal layer impls ───────────────────────────────────────────────────── + +struct Cursor<'a> { + data: &'a [f32], + offset: usize, +} + +impl<'a> Cursor<'a> { + fn new(d: &'a [f32]) -> Self { Self { data: d, offset: 0 } } + fn take(&mut self, n: usize) -> Result, LoadError> { + if self.offset + n > self.data.len() { + return Err(LoadError(format!( + "weight underrun: need {}, have {}", n, self.data.len() - self.offset))); + } + let out = self.data[self.offset..self.offset + n].to_vec(); + self.offset += n; + Ok(out) + } + fn remaining(&self) -> usize { self.data.len() - self.offset } +} + +impl Conv1d { + fn take(c: &mut Cursor<'_>, in_ch: usize, out_ch: usize, k: usize, dil: usize) + -> Result + { + let weight = c.take(in_ch * k * out_ch)?; + let bias = c.take(out_ch)?; + Ok(Self { in_ch, out_ch, kernel: k, dilation: dil, weight, bias }) + } + + /// Causal conv with left padding. Input layout: `[in_ch * T]` row-major. + fn forward(&self, input: &[f32], t_steps: usize) -> Vec { + let eff_k = self.kernel + (self.kernel - 1) * (self.dilation - 1); + let pad_left = eff_k - 1; + let mut out = vec![0.0f32; self.out_ch * t_steps]; + for oc in 0..self.out_ch { + for t in 0..t_steps { + let mut sum = self.bias[oc]; + for ic in 0..self.in_ch { + for k in 0..self.kernel { + let t_idx_signed = t as isize + pad_left as isize + - (k * self.dilation) as isize; + // Left-pad with zeros: only contribute when t_idx_signed - pad_left >= 0 + let t_src = t_idx_signed - pad_left as isize; + if t_src < 0 || t_src >= t_steps as isize { continue; } + let w_idx = oc * (self.in_ch * self.kernel) + ic * self.kernel + k; + sum += self.weight[w_idx] * input[ic * t_steps + t_src as usize]; + } + } + out[oc * t_steps + t] = sum; + } + } + out + } +} + +impl BatchNorm { + fn take(c: &mut Cursor<'_>, channels: usize) -> Result { + let gamma = c.take(channels)?; + let beta = c.take(channels)?; + Ok(Self { channels, gamma, beta }) + } + + /// Per-window normalisation matching JS impl: mean/var computed across + /// the T axis at inference time (not from saved running stats). + fn forward(&self, x: &mut [f32], t_steps: usize) { + let eps = 1e-5f32; + for c in 0..self.channels { + let base = c * t_steps; + let mut mean = 0.0f32; + for t in 0..t_steps { mean += x[base + t]; } + mean /= t_steps as f32; + let mut var = 0.0f32; + for t in 0..t_steps { + let d = x[base + t] - mean; + var += d * d; + } + var /= t_steps as f32; + let inv_std = 1.0f32 / (var + eps).sqrt(); + let g = self.gamma[c]; + let b = self.beta[c]; + for t in 0..t_steps { + x[base + t] = g * (x[base + t] - mean) * inv_std + b; + } + } + } +} + +impl TcnBlock { + fn take(c: &mut Cursor<'_>, in_ch: usize, out_ch: usize, k: usize, dil: usize) + -> Result + { + let conv1 = Conv1d::take(c, in_ch, out_ch, k, dil)?; + let bn1 = BatchNorm::take(c, out_ch)?; + let conv2 = Conv1d::take(c, out_ch, out_ch, k, dil)?; + let bn2 = BatchNorm::take(c, out_ch)?; + let res = if in_ch != out_ch { + Some(Conv1d::take(c, in_ch, out_ch, 1, 1)?) + } else { None }; + Ok(Self { conv1, bn1, conv2, bn2, res }) + } + + fn forward(&self, input: &[f32], t_steps: usize) -> Vec { + let mut x = self.conv1.forward(input, t_steps); + self.bn1.forward(&mut x, t_steps); + for v in x.iter_mut() { if *v < 0.0 { *v = 0.0; } } // relu + + let mut y = self.conv2.forward(&x, t_steps); + self.bn2.forward(&mut y, t_steps); + + // Residual + let res: Vec = if let Some(r) = &self.res { + r.forward(input, t_steps) + } else { + input.to_vec() + }; + debug_assert_eq!(y.len(), res.len()); + for (yv, rv) in y.iter_mut().zip(res.iter()) { *yv += *rv; } + for v in y.iter_mut() { if *v < 0.0 { *v = 0.0; } } // relu after residual + y + } +} + +impl Linear { + fn take(c: &mut Cursor<'_>, in_dim: usize, out_dim: usize) -> Result { + let weight = c.take(in_dim * out_dim)?; + let bias = c.take(out_dim)?; + Ok(Self { in_dim, out_dim, weight, bias }) + } + + fn forward(&self, input: &[f32]) -> Vec { + let mut out = vec![0.0f32; self.out_dim]; + for j in 0..self.out_dim { + let mut s = self.bias[j]; + for i in 0..self.in_dim { + s += input[i] * self.weight[i * self.out_dim + j]; + } + out[j] = s; + } + out + } +} + +fn sigmoid(x: f32) -> f32 { + if x >= 0.0 { + let e = (-x).exp(); + 1.0 / (1.0 + e) + } else { + let e = x.exp(); + e / (1.0 + e) + } +} + +// ── Inline base64 decoder ──────────────────────────────────────────────────── +// +// Standard alphabet (A–Z, a–z, 0–9, +, /). Padding `=` tolerated. Whitespace +// (including newlines) ignored — JSON.stringify can wrap base64 across lines +// in some exporters. Avoids pulling the `base64` crate just for one decode. + +fn base64_decode(s: &str) -> Result, String> { + let mut out = Vec::with_capacity(s.len() * 3 / 4 + 4); + let mut buf: u32 = 0; + let mut bits: u32 = 0; + for ch in s.bytes() { + let v: u32 = match ch { + b'A'..=b'Z' => (ch - b'A') as u32, + b'a'..=b'z' => (ch - b'a' + 26) as u32, + b'0'..=b'9' => (ch - b'0' + 52) as u32, + b'+' => 62, + b'/' => 63, + b'=' => break, + b' ' | b'\n' | b'\r' | b'\t' => continue, + _ => return Err(format!("invalid base64 char {:#x}", ch)), + }; + buf = (buf << 6) | v; + bits += 6; + if bits >= 8 { + bits -= 8; + out.push((buf >> bits) as u8); + buf &= (1 << bits) - 1; + } + } + Ok(out) +} + +// ── Convenience input helpers ──────────────────────────────────────────────── + +/// Build the `[INPUT_DIM × TIME_STEPS]` input tensor from the most recent +/// `TIME_STEPS` per-frame amplitude vectors of a single node. Picks the +/// `INPUT_DIM` (35) subcarriers with smallest NBVI score (most useful), using +/// the same per-subcarrier `α·σ/μ² + (1−α)·σ/μ` formula the classifier uses, +/// but with K=35 instead of NBVI_TOP_K=12 — model expects 35 channels. +/// +/// Returns `None` if the history has fewer than `TIME_STEPS` frames or all +/// subcarriers are zero / unusable. +pub fn build_input_from_history( + history: &std::collections::VecDeque>, +) -> Option> { + let n = history.len(); + if n < TIME_STEPS { return None; } + // Take the last 20 frames. + let recent: Vec<&Vec> = history.iter().rev().take(TIME_STEPS).collect(); + // recent is reverse-chronological; we want chronological for forward pass. + let recent: Vec<&Vec> = recent.into_iter().rev().collect(); + let n_sub = recent[0].len(); + if n_sub == 0 { return None; } + + // Per-subcarrier mean and std over the 20 frames. + let mut score: Vec<(usize, f64)> = (0..n_sub).map(|k| { + let mut sum = 0.0f64; + for f in &recent { sum += f.get(k).copied().unwrap_or(0.0); } + let mu = sum / TIME_STEPS as f64; + if mu.abs() < 1e-9 { return (k, f64::INFINITY); } + let mut var = 0.0f64; + for f in &recent { + let d = f.get(k).copied().unwrap_or(0.0) - mu; + var += d * d; + } + let sigma = (var / TIME_STEPS as f64).sqrt(); + // NBVI (α = 0.5): 0.5 * (σ/μ²) + 0.5 * (σ/μ) + let mu2 = mu * mu; + let nbvi = 0.5 * (sigma / mu2) + 0.5 * (sigma / mu.abs()); + (k, nbvi) + }).collect(); + + // 25th-percentile dead-zone gate (drop subcarriers with mean amplitude + // below the lower quartile). + let mut means: Vec = (0..n_sub).map(|k| { + let mut s = 0.0f64; + for f in &recent { s += f.get(k).copied().unwrap_or(0.0); } + s / TIME_STEPS as f64 + }).collect(); + means.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let q25_idx = (n_sub as f64 * 0.25) as usize; + let dead_thresh = means.get(q25_idx).copied().unwrap_or(0.0); + for (k, s) in score.iter_mut() { + // Re-compute mean for this k to gate (means above is sorted, indices lost). + let mut sum = 0.0f64; + for f in &recent { sum += f.get(*k).copied().unwrap_or(0.0); } + let mu = sum / TIME_STEPS as f64; + if mu < dead_thresh { *s = f64::INFINITY; } + } + + score.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + if score.is_empty() || !score[0].1.is_finite() { return None; } + + // Pick top-INPUT_DIM (35) by lowest NBVI. If fewer than 35 are finite, + // pad with whichever finite ones we have and zero the rest — model still + // runs, it just has dead channels. + let mut picks: Vec = score.iter() + .filter(|(_, s)| s.is_finite()) + .take(INPUT_DIM) + .map(|(k, _)| *k) + .collect(); + if picks.is_empty() { return None; } + while picks.len() < INPUT_DIM { picks.push(0); } // pad with subcarrier 0 + + // Raw amplitudes pass-through. Training script (`scripts/train-wiflow- + // supervised.js::loadJsonl`) feeds raw values; the two TCN BatchNorm + // layers normalise per-channel per-window at inference time so absolute + // scale (5–50 ESP32 amplitude range) is handled by the network itself. + let mut out = vec![0.0f32; INPUT_DIM * TIME_STEPS]; + for (ci, k) in picks.iter().enumerate() { + for (t, f) in recent.iter().enumerate() { + out[ci * TIME_STEPS + t] = f.get(*k).copied().unwrap_or(0.0) as f32; + } + } + Some(out) +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn base64_round_trip_alphabet() { + // "Man" -> "TWFu" + assert_eq!(base64_decode("TWFu").unwrap(), b"Man"); + // padding + assert_eq!(base64_decode("TWE=").unwrap(), b"Ma"); + assert_eq!(base64_decode("TQ==").unwrap(), b"M"); + // whitespace tolerated + assert_eq!(base64_decode("T W\nF u").unwrap(), b"Man"); + } + + #[test] + fn sigmoid_bounds() { + assert!((sigmoid(0.0) - 0.5).abs() < 1e-6); + assert!(sigmoid(10.0) > 0.999); + assert!(sigmoid(-10.0) < 0.001); + } + + #[test] + fn build_input_zero_history() { + let h = std::collections::VecDeque::new(); + assert!(build_input_from_history(&h).is_none()); + } +}