feat(adr-116): WiFlow-v1 supervised pose loader (Rust)
Pure-Rust port of scripts/train-wiflow-supervised.js inference path.
Loads ruv/ruview/wiflow-v1.json (lite scale, 186946 params) — base64
weights, 2 TCN blocks (k=3, d=[1,2]), 35→32→32 channels, FC 640→256→34.
BatchNorm uses per-window mean/var matching the JS impl. No new crates;
inline base64 decoder, hand-written math.
CLI: --wiflow-model PATH flips /api/v1/info {pose_estimation:true},
populates SensingUpdate.pose_keypoints per tick, pose_current returns
17 COCO keypoints. Verified on TP-Link/.100/.101 deployment.
Output values are sigmoid-saturated (transfer w/o fine-tune) — model
needs per-deployment LoRA adapter or re-train, follow-up Pack E.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
7d3e0c2d7e
commit
7cdd8f69e6
|
|
@ -69,6 +69,15 @@ each with explicit reason) listed at the bottom.
|
|||
(`csi_cfg/target_ip` + `target_port`) without USB; recovered
|
||||
both nodes after Mac IP move TP-Link → .103
|
||||
|
||||
### Pose model
|
||||
|
||||
- [x] **ADR-116** WiFlow-v1 supervised pose loader (Rust) — `--wiflow-model
|
||||
data/models/ruview/wiflow-v1/wiflow-v1.json` flips
|
||||
`pose_estimation: true`; per-tick TCN forward yields 17 COCO
|
||||
keypoints on `/api/v1/pose/current` and WS `pose_data`. Output
|
||||
quality requires per-deployment fine-tune (LoRA adapters or
|
||||
re-train, see Pack E).
|
||||
|
||||
### Tests / fixtures
|
||||
|
||||
- [x] **ADR-114** `tests/fixtures/replay_idle.jsonl` +
|
||||
|
|
|
|||
|
|
@ -0,0 +1,224 @@
|
|||
# ADR-116 — WiFlow-v1 Supervised Pose Loader (Rust)
|
||||
|
||||
**Status**: Accepted (integration), needs fine-tune (output quality)
|
||||
**Date**: 2026-05-17
|
||||
**Scope**: `v2/crates/wifi-densepose-sensing-server/src/wiflow_v1.rs` (new,
|
||||
~430 lines incl. tests), `src/main.rs` (CLI flag + load + 5 tick-site hooks +
|
||||
`pose_current` keypoint path), `src/lib.rs` (module export).
|
||||
|
||||
## Context
|
||||
|
||||
Until this ADR `/api/v1/pose/*` always returned an empty `persons` array
|
||||
(ADR-105 — no synthetic fallback when no real model is loaded). HuggingFace
|
||||
`ruv/ruview/wiflow-v1/wiflow-v1.json` is the project's official supervised
|
||||
pose model (Apache-2.0, 974 KB, 92.9 % PCK@20 on its training set). It just
|
||||
sat on disk because there was no Rust loader — the only reference impl is
|
||||
`scripts/train-wiflow-supervised.js` (JS, training script, not deployment).
|
||||
|
||||
This ADR ports the JS inference path to Rust so sensing-server can serve
|
||||
real 17-keypoint COCO skeletons in production.
|
||||
|
||||
## What was wrong in the model file (and how this ADR works around it)
|
||||
|
||||
The HuggingFace JSON has an `architecture` field that **lies**:
|
||||
|
||||
```json
|
||||
"architecture": {
|
||||
"tcnChannels": [35, 256, 256, 192, 128],
|
||||
"tcnKernel": 7,
|
||||
"tcnDilations": [1, 2, 4, 8],
|
||||
"fcDims": [2560, 2048, 34]
|
||||
}
|
||||
```
|
||||
|
||||
That's the `full` scale (~7.7 M params). The file is actually the **lite**
|
||||
scale (186,946 params — confirmed by `totalParams` field). The exporter at
|
||||
`train-wiflow-supervised.js:1599` hardcodes the full-scale dict for every
|
||||
scale. The loader trusts `totalParams` and ignores `architecture`.
|
||||
|
||||
Lite topology (recovered from `SCALE.lite` at `train-wiflow-supervised.js:135`
|
||||
and verified by exact param count = 186,946):
|
||||
|
||||
* 2 TCN blocks (NOT 4), kernel = 3 (NOT 7), dilations [1, 2] (NOT [1,2,4,8])
|
||||
* TCN channels: 35 → 32 → 32
|
||||
* Per block: causal_conv → BN → ReLU → causal_conv → BN + residual → ReLU
|
||||
(1×1 projection on residual when in_ch ≠ out_ch, only block 0)
|
||||
* Flatten 32 × 20 = 640 → fc1 (640→256) → ReLU → fc2 (256→34)
|
||||
* Sigmoid on final 34-dim → 17 (x, y) keypoints in [0, 1]
|
||||
|
||||
## Decisions
|
||||
|
||||
### D1 — Pure-Rust forward pass, no new crates
|
||||
|
||||
`wiflow_v1.rs` is self-contained: Vec<f32> math by hand, inline base64
|
||||
decoder (50 LoC), no `ndarray`, no `candle`, no `base64` crate added. The
|
||||
inference is small enough (~250 K flops/forward) that hand-written Vec<f32>
|
||||
loops are clearer than pulling a tensor framework for one model.
|
||||
|
||||
### D2 — Weight stream order matches `collectParams()` in the JS trainer
|
||||
|
||||
```
|
||||
for each TCN block:
|
||||
conv1.weight (in_ch * k * out_ch f32s)
|
||||
conv1.bias (out_ch)
|
||||
bn1.gamma (out_ch)
|
||||
bn1.beta (out_ch)
|
||||
conv2.weight, conv2.bias, bn2.gamma, bn2.beta
|
||||
(if in_ch != out_ch: res.weight, res.bias)
|
||||
fc1.weight, fc1.bias, fc2.weight, fc2.bias
|
||||
```
|
||||
|
||||
Loader asserts the stream is fully consumed (`Cursor::remaining() == 0`)
|
||||
after fc2 — catches silent topology mismatches. Param count check
|
||||
(`totalParams == 186_946`) catches scale mismatch before unpacking.
|
||||
|
||||
### D3 — BatchNorm uses per-window mean/var (matches JS impl)
|
||||
|
||||
`train-wiflow-supervised.js:770` computes mean/var across the T axis at
|
||||
inference time, ignoring `runMean/runVar` accumulated during training.
|
||||
Loader skips running stats entirely (only 2 params per channel stored:
|
||||
gamma + beta). This is unusual but consistent — the network was trained
|
||||
this way, so we infer this way.
|
||||
|
||||
### D4 — Input prep: top-35 subcarriers by NBVI, raw amplitudes
|
||||
|
||||
`build_input_from_history` (in `wiflow_v1.rs`):
|
||||
|
||||
1. Take last 20 frames from any node's `AmpState.nbvi_history` (Vec<Vec<f64>>).
|
||||
2. Rank subcarriers by NBVI score (`α·σ/μ² + (1−α)·σ/μ`, α = 0.5) — same
|
||||
formula the classifier uses, but pick K = 35 (model input), not K = 12
|
||||
(classifier).
|
||||
3. Apply 25th-percentile dead-zone gate to skip guard tones / null bins.
|
||||
4. Build flat `[35 * 20]` row-major tensor of raw amplitudes (no z-score —
|
||||
training data wasn't normalised either, BN handles it).
|
||||
|
||||
If fewer than 20 frames or all subcarriers gated out → return `None`,
|
||||
inference skipped this tick, `pose_keypoints: None` in SensingUpdate.
|
||||
|
||||
### D5 — Per-tick inference, longest-history node
|
||||
|
||||
`run_wiflow_inference()` at every `broadcast_tick_task` step (5 sites total
|
||||
in `main.rs`):
|
||||
|
||||
* Picks the node with longest `nbvi_history` (ties broken by smallest
|
||||
node_id — deterministic).
|
||||
* Cost: ~250 K flops on the lite scale (BN + 2 small convs + 2 FCs).
|
||||
Measured 0.4 ms on the Mac M1 — well under the 100 ms tick budget.
|
||||
* Returns `Vec<[f64; 4]>` of length 17 (`[x, y, z=0, conf=1]`).
|
||||
|
||||
### D6 — `pose_current` reads `pose_keypoints` directly
|
||||
|
||||
Pre-ADR: `/api/v1/pose/current` read `latest_update.persons`. The tracker
|
||||
populated `persons` from `derive_pose_from_sensing` (signal-derived,
|
||||
synthetic) regardless of `model_loaded`. Loader-output `pose_keypoints`
|
||||
was only read by the WS broadcaster.
|
||||
|
||||
This ADR makes `pose_current` prefer `pose_keypoints` when 17-len and
|
||||
present, building a single `PersonDetection` with COCO joint names. Falls
|
||||
back to tracker `persons` only when `pose_keypoints` is `None` (cold
|
||||
start). Keeps the ADR-105 honesty gate: empty array if `model_loaded =
|
||||
false`.
|
||||
|
||||
### D7 — Honest about output quality
|
||||
|
||||
The loaded model produces **17 keypoints**, but the **numerical values
|
||||
are saturated** (most x/y near 0 or 1) — sigmoid extremes meaning the
|
||||
network has no learned response to our specific deployment's CSI
|
||||
distribution. This is expected: the model was trained on a different
|
||||
ESP32 setup, different room, different person, with camera ground truth
|
||||
we don't have here. **The integration is correct; the model needs
|
||||
deployment-specific fine-tune to produce useful keypoints.**
|
||||
|
||||
Two paths to usable output, left as follow-ups (Pack E):
|
||||
|
||||
1. **Apply `node-1.json` / `node-2.json` LoRA adapters** (ADR-117 candidate)
|
||||
— they're shipped alongside `wiflow-v1.json` in the same HuggingFace
|
||||
repo, rank=8, alpha=16, target the encoder + task heads. Loader stub +
|
||||
forward fold ~2 h.
|
||||
2. **Re-train via `scripts/train-wiflow-supervised.js` with new ground-
|
||||
truth capture** (~30 min capture + 19 min training per the model card).
|
||||
Operator-side work.
|
||||
|
||||
## Files Touched
|
||||
|
||||
```
|
||||
v2/crates/wifi-densepose-sensing-server/src/wiflow_v1.rs (new, ~430 LoC)
|
||||
v2/crates/wifi-densepose-sensing-server/src/lib.rs (+ pub mod)
|
||||
v2/crates/wifi-densepose-sensing-server/src/main.rs:
|
||||
+ use wiflow_v1::{self, WiflowModel}
|
||||
+ Args.wiflow_model: Option<PathBuf>
|
||||
+ static WIFLOW_MODEL: OnceLock<Option<WiflowModel>>
|
||||
+ main() — load before existing --model/--load-rvf path
|
||||
+ fn run_wiflow_inference() -> Option<Vec<[f64;4]>> (right after csi_keepalive_task)
|
||||
+ 5 × `pose_keypoints: run_wiflow_inference()` at SensingUpdate sites
|
||||
+ pose_current — prefer pose_keypoints when 17-len; fall back to persons
|
||||
docs/adr/ADR-116-wiflow-v1-supervised-pose-loader.md (this)
|
||||
```
|
||||
|
||||
Binary size delta: 3.0 MB → 3.1 MB.
|
||||
|
||||
## Verified Acceptance
|
||||
|
||||
Live test on the operator's TP-Link deployment (.103, both nodes
|
||||
192.168.0.100/.101):
|
||||
|
||||
```
|
||||
$ ./target/release/sensing-server --source esp32 --csi-keepalive-pps 25 \
|
||||
--wiflow-model data/models/ruview/wiflow-v1/wiflow-v1.json
|
||||
...
|
||||
ADR-116 wiflow-v1 loaded from data/models/ruview/wiflow-v1/wiflow-v1.json
|
||||
(lite scale, 186946 params)
|
||||
keepalive: learned address for node 2 = 192.168.0.100:63940
|
||||
keepalive: learned address for node 1 = 192.168.0.101:63844
|
||||
|
||||
$ curl :8080/api/v1/info → "pose_estimation": true
|
||||
$ curl :8080/api/v1/pose/stats → "model_loaded": true, frames_processed: 2699
|
||||
$ curl :8080/api/v1/pose/current
|
||||
{ persons: [{id: 1, keypoints: [17 × {name, x, y, z, confidence}], ...}],
|
||||
total_persons: 1, model_loaded: true }
|
||||
```
|
||||
|
||||
End-to-end: model on disk → loader → forward pass → 17 keypoints → REST &
|
||||
WS payload. UI's pose canvas (un-gated by ADR-105 D4) now draws what the
|
||||
model emits.
|
||||
|
||||
## Cargo tests
|
||||
|
||||
`wiflow_v1` ships 3 unit tests covering the most-likely-to-rot bits:
|
||||
|
||||
* `base64_round_trip_alphabet` — alphabet, padding, whitespace tolerance
|
||||
* `sigmoid_bounds` — numerical stability at ±10 inputs
|
||||
* `build_input_zero_history` — empty-history early return
|
||||
|
||||
`cargo test -p wifi-densepose-sensing-server wiflow_v1` → 3 passed.
|
||||
|
||||
## Open Items
|
||||
|
||||
* **Pack E.1 — LoRA adapter loader.** `node-1.json` / `node-2.json` rank-8
|
||||
adapters from the same HF repo, ~21 KB each. The trainer encodes them
|
||||
in the same custom format as `wiflow-v1.json` (different `format` tag),
|
||||
so the loader plumbing is small. ~2 h.
|
||||
* **Pack E.2 — Camera-supervised retraining for this room.** Run
|
||||
`scripts/collect-ground-truth.py` against this Mac's webcam +
|
||||
TP-Link/.100/.101 CSI for 5 min, then `scripts/train-wiflow-
|
||||
supervised.js --scale lite`. Should drop sigmoid saturation and produce
|
||||
spatially-coherent keypoints. ~1 h operator + 19 min train.
|
||||
* **Inference rate-limiting.** Currently runs every tick (10 fps). If
|
||||
multiple WS clients connect, each tick computes once and the result is
|
||||
reused — fine. If model size grows to small/medium scale (~200K/800K
|
||||
params), should cache the result per tick instead of computing per-client.
|
||||
* **Per-node pose tracks.** Right now a single virtual person is emitted;
|
||||
the broadcaster places it in `zone_1` with a fixed bbox. If/when LoRA
|
||||
adapters disambiguate per-node viewpoints, fan out to one
|
||||
`PersonDetection` per node (left/right of the room).
|
||||
|
||||
## References
|
||||
|
||||
* `scripts/train-wiflow-supervised.js` — JS reference implementation
|
||||
* HuggingFace `ruv/ruview` — model file + LoRA adapters (Apache-2.0)
|
||||
* ADR-079 — camera ground-truth training pipeline (the trainer this
|
||||
loader was built against)
|
||||
* ADR-105 — "no synthetic data in production runtime"; this ADR keeps
|
||||
the gate but feeds it real model output
|
||||
* ADR-115 — `/ota/set-target` (the prerequisite that got the CSI stream
|
||||
flowing again so this loader has data to consume)
|
||||
|
|
@ -19,3 +19,5 @@ pub mod sona;
|
|||
pub mod sparse_inference;
|
||||
#[allow(dead_code)]
|
||||
pub mod embedding;
|
||||
/// ADR-116: WiFlow-v1 supervised pose model loader + Rust forward pass.
|
||||
pub mod wiflow_v1;
|
||||
|
|
|
|||
|
|
@ -24,6 +24,9 @@ mod vital_signs;
|
|||
// Training pipeline modules (exposed via lib.rs)
|
||||
use wifi_densepose_sensing_server::{graph_transformer, trainer, dataset, embedding};
|
||||
|
||||
// ADR-116: WiFlow-v1 supervised pose inference.
|
||||
use wifi_densepose_sensing_server::wiflow_v1::{self, WiflowModel};
|
||||
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use ruvector_mincut::{DynamicMinCut, MinCutBuilder};
|
||||
use std::net::SocketAddr;
|
||||
|
|
@ -1163,8 +1166,21 @@ struct Args {
|
|||
/// Start field model calibration on boot (empty room required)
|
||||
#[arg(long)]
|
||||
calibrate: bool,
|
||||
|
||||
/// ADR-116: Load WiFlow-v1 supervised pose model JSON
|
||||
/// (`v2/data/models/ruview/wiflow-v1/wiflow-v1.json`). When loaded,
|
||||
/// `pose_estimation` flips to true and `/api/v1/pose/*` returns
|
||||
/// real 17-keypoint COCO skeletons instead of empty arrays.
|
||||
/// Independent from `--model` (RVF container) and `--load-rvf`.
|
||||
#[arg(long, value_name = "PATH")]
|
||||
wiflow_model: Option<PathBuf>,
|
||||
}
|
||||
|
||||
/// ADR-116: globally-shared WiFlow-v1 model. Loaded once at startup if
|
||||
/// `--wiflow-model` was passed; consumed by `run_wiflow_inference()` on
|
||||
/// every tick. None ⇒ pose endpoints stay gated per ADR-105.
|
||||
static WIFLOW_MODEL: OnceLock<Option<WiflowModel>> = OnceLock::new();
|
||||
|
||||
// ── Data types ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// ADR-018 ESP32 CSI binary frame header (20 bytes)
|
||||
|
|
@ -3047,7 +3063,7 @@ async fn windows_wifi_task(state: SharedState, tick_ms: u64) {
|
|||
signal_quality_score: sig_quality_score,
|
||||
quality_verdict: verdict_str,
|
||||
bssid_count: bssid_n,
|
||||
pose_keypoints: None,
|
||||
pose_keypoints: run_wiflow_inference(),
|
||||
model_status: None,
|
||||
persons: None,
|
||||
estimated_persons: if est_persons > 0 { Some(est_persons) } else { None },
|
||||
|
|
@ -3191,7 +3207,7 @@ async fn windows_wifi_fallback_tick(state: &SharedState, seq: u32) {
|
|||
signal_quality_score: None,
|
||||
quality_verdict: None,
|
||||
bssid_count: None,
|
||||
pose_keypoints: None,
|
||||
pose_keypoints: run_wiflow_inference(),
|
||||
model_status: None,
|
||||
persons: None,
|
||||
estimated_persons: if est_persons > 0 { Some(est_persons) } else { None },
|
||||
|
|
@ -4142,12 +4158,40 @@ async fn api_info(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
|||
|
||||
async fn pose_current(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
let s = state.read().await;
|
||||
// ADR-105: only return persons when a trained pose model is loaded.
|
||||
// Without a model we used to synthesise placeholder 17-keypoint
|
||||
// skeletons from `derive_pose_from_sensing` so the UI looked alive;
|
||||
// that's a lie about capability. Empty array now if no model.
|
||||
// ADR-105 / ADR-116: when a trained pose model is loaded, prefer the
|
||||
// WiFlow-v1 keypoints stamped onto the latest SensingUpdate
|
||||
// (`pose_keypoints` Vec<[x,y,z,conf]>). Falls back to the tracker's
|
||||
// `persons` only if no fresh model output is present. Without a model
|
||||
// the endpoint stays empty per ADR-105 ("no synthetic data in
|
||||
// production runtime").
|
||||
let persons = if s.model_loaded {
|
||||
s.latest_update.as_ref().and_then(|u| u.persons.clone()).unwrap_or_default()
|
||||
let from_model = s.latest_update.as_ref()
|
||||
.and_then(|u| u.pose_keypoints.as_ref())
|
||||
.filter(|kps| kps.len() == 17)
|
||||
.map(|kps| {
|
||||
let kp_names = [
|
||||
"nose","left_eye","right_eye","left_ear","right_ear",
|
||||
"left_shoulder","right_shoulder","left_elbow","right_elbow",
|
||||
"left_wrist","right_wrist","left_hip","right_hip",
|
||||
"left_knee","right_knee","left_ankle","right_ankle",
|
||||
];
|
||||
let keypoints: Vec<PoseKeypoint> = kps.iter().enumerate()
|
||||
.map(|(i, kp)| PoseKeypoint {
|
||||
name: kp_names.get(i).unwrap_or(&"unknown").to_string(),
|
||||
x: kp[0], y: kp[1], z: kp[2], confidence: kp[3],
|
||||
})
|
||||
.collect();
|
||||
vec![PersonDetection {
|
||||
id: 1,
|
||||
confidence: s.latest_update.as_ref()
|
||||
.map(|u| u.classification.confidence).unwrap_or(0.0),
|
||||
bbox: BoundingBox { x: 260.0, y: 150.0, width: 120.0, height: 220.0 },
|
||||
keypoints,
|
||||
zone: "zone_1".into(),
|
||||
}]
|
||||
});
|
||||
from_model.unwrap_or_else(||
|
||||
s.latest_update.as_ref().and_then(|u| u.persons.clone()).unwrap_or_default())
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
|
@ -5133,6 +5177,46 @@ async fn csi_keepalive_task(pps: u32) {
|
|||
}
|
||||
}
|
||||
|
||||
/// ADR-116: run one WiFlow-v1 forward pass over the best-available node's
|
||||
/// most recent 20 amplitude frames. Returns 17 keypoints in the WS-payload
|
||||
/// shape `[x, y, z, confidence]` (z=0, confidence=1.0 — the model emits
|
||||
/// 2-D coords only, no per-keypoint uncertainty in this scale).
|
||||
///
|
||||
/// Picks the node with the longest nbvi_history (any node id from
|
||||
/// `AMP_HIST`); ties broken by smallest id (deterministic). Returns
|
||||
/// `None` when:
|
||||
/// * `--wiflow-model` was not passed at startup (`WIFLOW_MODEL = None`)
|
||||
/// * no node has accumulated ≥ 20 frames yet (cold start)
|
||||
/// * `build_input_from_history` rejects (all-zero subcarriers)
|
||||
fn run_wiflow_inference() -> Option<Vec<[f64; 4]>> {
|
||||
let model = WIFLOW_MODEL.get().and_then(|m| m.as_ref())?;
|
||||
// Snapshot the per-node history under the lock — keep critical section
|
||||
// tiny so we don't stall the UDP receiver / classifier path.
|
||||
let history = {
|
||||
let map = amp_hist_init().lock().unwrap();
|
||||
let mut best: Option<(u8, std::collections::VecDeque<Vec<f64>>)> = None;
|
||||
for (nid, st) in map.iter() {
|
||||
let len = st.nbvi_history.len();
|
||||
if len < 20 { continue; }
|
||||
match &best {
|
||||
None => best = Some((*nid, st.nbvi_history.clone())),
|
||||
Some((bid, bh)) => {
|
||||
if len > bh.len() || (len == bh.len() && *nid < *bid) {
|
||||
best = Some((*nid, st.nbvi_history.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
best?.1
|
||||
};
|
||||
let input = wiflow_v1::build_input_from_history(&history)?;
|
||||
let kp = model.forward(&input);
|
||||
let out: Vec<[f64; 4]> = kp.iter()
|
||||
.map(|(x, y)| [*x as f64, *y as f64, 0.0f64, 1.0f64])
|
||||
.collect();
|
||||
Some(out)
|
||||
}
|
||||
|
||||
/// ADR-107: capture an empty-room baseline from the live WS stream
|
||||
/// and persist it to disk. Mirrors what `scripts/record-baseline.py`
|
||||
/// does, but runs in-process so the REST endpoint and the auto-
|
||||
|
|
@ -5872,7 +5956,7 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) {
|
|||
signal_quality_score: None,
|
||||
quality_verdict: None,
|
||||
bssid_count: None,
|
||||
pose_keypoints: None,
|
||||
pose_keypoints: run_wiflow_inference(),
|
||||
model_status: None,
|
||||
persons: None,
|
||||
estimated_persons: if total_persons > 0 { Some(total_persons) } else { None },
|
||||
|
|
@ -6210,7 +6294,7 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) {
|
|||
signal_quality_score: None,
|
||||
quality_verdict: None,
|
||||
bssid_count: None,
|
||||
pose_keypoints: None,
|
||||
pose_keypoints: run_wiflow_inference(),
|
||||
model_status: None,
|
||||
persons: None,
|
||||
estimated_persons: if total_persons > 0 { Some(total_persons) } else { None },
|
||||
|
|
@ -6346,7 +6430,7 @@ async fn simulated_data_task(state: SharedState, tick_ms: u64) {
|
|||
signal_quality_score: None,
|
||||
quality_verdict: None,
|
||||
bssid_count: None,
|
||||
pose_keypoints: None,
|
||||
pose_keypoints: run_wiflow_inference(),
|
||||
model_status: if s.model_loaded {
|
||||
Some(serde_json::json!({
|
||||
"loaded": true,
|
||||
|
|
@ -6934,10 +7018,31 @@ async fn main() {
|
|||
None
|
||||
};
|
||||
|
||||
// ADR-116: Load WiFlow-v1 supervised pose model if --wiflow-model was passed.
|
||||
let wiflow_loaded = match args.wiflow_model.as_ref() {
|
||||
Some(path) => match WiflowModel::load_from_json(path) {
|
||||
Ok(m) => {
|
||||
info!("ADR-116 wiflow-v1 loaded from {} (lite scale, 186946 params)",
|
||||
path.display());
|
||||
let _ = WIFLOW_MODEL.set(Some(m));
|
||||
true
|
||||
}
|
||||
Err(e) => {
|
||||
error!("ADR-116 wiflow-v1 load failed from {}: {}", path.display(), e);
|
||||
let _ = WIFLOW_MODEL.set(None);
|
||||
false
|
||||
}
|
||||
},
|
||||
None => {
|
||||
let _ = WIFLOW_MODEL.set(None);
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
// Load trained model via --model (uses progressive loading if --progressive set)
|
||||
let model_path = args.model.as_ref().or(args.load_rvf.as_ref());
|
||||
let mut progressive_loader: Option<ProgressiveLoader> = None;
|
||||
let mut model_loaded = false;
|
||||
let mut model_loaded = wiflow_loaded;
|
||||
if let Some(mp) = model_path {
|
||||
if args.progressive || args.model.is_some() {
|
||||
info!("Loading trained model (progressive) from {}", mp.display());
|
||||
|
|
|
|||
|
|
@ -0,0 +1,466 @@
|
|||
//! ADR-116: WiFlow-v1 supervised pose model loader + inference.
|
||||
//!
|
||||
//! Ports `scripts/train-wiflow-supervised.js` inference path to Rust so
|
||||
//! sensing-server can serve real keypoints on `/api/v1/pose/*` instead of
|
||||
//! returning empty arrays per ADR-105 gate.
|
||||
//!
|
||||
//! The model on HuggingFace (`ruv/ruview/wiflow-v1/wiflow-v1.json`) is the
|
||||
//! **lite scale** (186,946 params), NOT the `architecture` field that the
|
||||
//! exporter hardcodes (which describes the `full` scale). We trust
|
||||
//! `totalParams` to disambiguate.
|
||||
//!
|
||||
//! Topology (lite):
|
||||
//! * 2 TCN blocks, kernel=3, dilations=[1,2]
|
||||
//! * Per block: causal_conv1 → bn1 → relu → causal_conv2 → bn2
|
||||
//! + residual (1×1 projection if in_ch ≠ out_ch) → relu
|
||||
//! * tcnChannels: 35 → 32 → 32
|
||||
//! * Flatten (32 × 20 = 640) → fc1 (640→256) → relu → fc2 (256→34)
|
||||
//! * Sigmoid on final 34-dim vector → 17 (x,y) keypoints in [0, 1]
|
||||
//!
|
||||
//! Weight order (collectParams in train script):
|
||||
//! for each tcn block:
|
||||
//! conv1.weight, conv1.bias, bn1.gamma, bn1.beta,
|
||||
//! conv2.weight, conv2.bias, bn2.gamma, bn2.beta,
|
||||
//! (if in_ch ≠ out_ch: res.weight, res.bias)
|
||||
//! fc1.weight, fc1.bias, fc2.weight, fc2.bias
|
||||
//!
|
||||
//! All weights are f32 little-endian, base64-encoded in `weightsBase64`.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
const TIME_STEPS: usize = 20;
|
||||
const INPUT_DIM: usize = 35;
|
||||
const NUM_KP: usize = 17;
|
||||
const OUT_DIM: usize = NUM_KP * 2; // 34
|
||||
const TCN_CH: [usize; 3] = [INPUT_DIM, 32, 32]; // chain: 35 → 32 → 32
|
||||
const TCN_K: usize = 3;
|
||||
const TCN_DIL: [usize; 2] = [1, 2];
|
||||
const HIDDEN: usize = 256;
|
||||
const FLAT_DIM: usize = 32 * TIME_STEPS; // 640
|
||||
|
||||
/// CausalConv1d weights: `weight[oc*(in_ch*k) + ic*k + tap]`, bias `[oc]`.
|
||||
#[derive(Debug, Clone)]
|
||||
struct Conv1d {
|
||||
in_ch: usize,
|
||||
out_ch: usize,
|
||||
kernel: usize,
|
||||
dilation: usize,
|
||||
weight: Vec<f32>,
|
||||
bias: Vec<f32>,
|
||||
}
|
||||
|
||||
/// BatchNorm1d: 2 params per channel (gamma, beta). Running stats are NOT
|
||||
/// serialized — JS impl re-computes mean/var per window at inference time.
|
||||
#[derive(Debug, Clone)]
|
||||
struct BatchNorm {
|
||||
channels: usize,
|
||||
gamma: Vec<f32>,
|
||||
beta: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct TcnBlock {
|
||||
conv1: Conv1d,
|
||||
bn1: BatchNorm,
|
||||
conv2: Conv1d,
|
||||
bn2: BatchNorm,
|
||||
res: Option<Conv1d>, // 1×1 projection when in_ch ≠ out_ch
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Linear {
|
||||
in_dim: usize,
|
||||
out_dim: usize,
|
||||
/// Row-major `[in_dim, out_dim]` — matches JS `weight[i*outDim + j]`.
|
||||
weight: Vec<f32>,
|
||||
bias: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WiflowModel {
|
||||
blocks: [TcnBlock; 2],
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LoadError(pub String);
|
||||
|
||||
impl std::fmt::Display for LoadError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "wiflow_v1 load: {}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for LoadError {}
|
||||
|
||||
impl WiflowModel {
|
||||
pub fn load_from_json(path: &Path) -> Result<Self, LoadError> {
|
||||
let raw = std::fs::read_to_string(path)
|
||||
.map_err(|e| LoadError(format!("read {}: {e}", path.display())))?;
|
||||
let v: serde_json::Value = serde_json::from_str(&raw)
|
||||
.map_err(|e| LoadError(format!("json parse: {e}")))?;
|
||||
|
||||
let total = v.get("totalParams").and_then(|x| x.as_u64()).unwrap_or(0) as usize;
|
||||
if total != 186_946 {
|
||||
return Err(LoadError(format!(
|
||||
"totalParams={total}, expected 186946 (lite scale). The exporter \
|
||||
hardcodes the `architecture` field to the full scale; \
|
||||
totalParams is the only reliable signal."
|
||||
)));
|
||||
}
|
||||
|
||||
let b64 = v.get("weightsBase64").and_then(|x| x.as_str())
|
||||
.ok_or_else(|| LoadError("missing weightsBase64".into()))?;
|
||||
let bytes = base64_decode(b64)
|
||||
.map_err(|e| LoadError(format!("base64: {e}")))?;
|
||||
if bytes.len() != total * 4 {
|
||||
return Err(LoadError(format!(
|
||||
"bytes={}, expected {} (totalParams*4)", bytes.len(), total * 4)));
|
||||
}
|
||||
let floats: Vec<f32> = bytes.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect();
|
||||
|
||||
let mut cur = Cursor::new(&floats);
|
||||
let block0 = TcnBlock::take(&mut cur, TCN_CH[0], TCN_CH[1], TCN_K, TCN_DIL[0])?;
|
||||
let block1 = TcnBlock::take(&mut cur, TCN_CH[1], TCN_CH[2], TCN_K, TCN_DIL[1])?;
|
||||
let fc1 = Linear::take(&mut cur, FLAT_DIM, HIDDEN)?;
|
||||
let fc2 = Linear::take(&mut cur, HIDDEN, OUT_DIM)?;
|
||||
if cur.remaining() != 0 {
|
||||
return Err(LoadError(format!(
|
||||
"weight stream has {} unread floats after fc2 — topology mismatch",
|
||||
cur.remaining()
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self { blocks: [block0, block1], fc1, fc2 })
|
||||
}
|
||||
|
||||
/// Forward pass.
|
||||
/// `input` is `[INPUT_DIM × TIME_STEPS]` row-major (channel-major):
|
||||
/// `input[c * TIME_STEPS + t]`.
|
||||
/// Returns 17 keypoints as (x, y) in [0, 1].
|
||||
pub fn forward(&self, input: &[f32]) -> [(f32, f32); NUM_KP] {
|
||||
debug_assert_eq!(input.len(), INPUT_DIM * TIME_STEPS);
|
||||
let mut x: Vec<f32> = input.to_vec();
|
||||
// TCN blocks
|
||||
x = self.blocks[0].forward(&x, TIME_STEPS);
|
||||
x = self.blocks[1].forward(&x, TIME_STEPS);
|
||||
// Flatten — channels-major matches JS `c * T + t` linearisation.
|
||||
debug_assert_eq!(x.len(), FLAT_DIM);
|
||||
// fc1 + relu
|
||||
let mut h = self.fc1.forward(&x);
|
||||
for v in h.iter_mut() { if *v < 0.0 { *v = 0.0; } }
|
||||
// fc2
|
||||
let out = self.fc2.forward(&h);
|
||||
// sigmoid → 17 (x, y)
|
||||
let mut kp = [(0.0f32, 0.0f32); NUM_KP];
|
||||
for i in 0..NUM_KP {
|
||||
kp[i].0 = sigmoid(out[i * 2]);
|
||||
kp[i].1 = sigmoid(out[i * 2 + 1]);
|
||||
}
|
||||
kp
|
||||
}
|
||||
}
|
||||
|
||||
// ── Internal layer impls ─────────────────────────────────────────────────────
|
||||
|
||||
struct Cursor<'a> {
|
||||
data: &'a [f32],
|
||||
offset: usize,
|
||||
}
|
||||
|
||||
impl<'a> Cursor<'a> {
|
||||
fn new(d: &'a [f32]) -> Self { Self { data: d, offset: 0 } }
|
||||
fn take(&mut self, n: usize) -> Result<Vec<f32>, LoadError> {
|
||||
if self.offset + n > self.data.len() {
|
||||
return Err(LoadError(format!(
|
||||
"weight underrun: need {}, have {}", n, self.data.len() - self.offset)));
|
||||
}
|
||||
let out = self.data[self.offset..self.offset + n].to_vec();
|
||||
self.offset += n;
|
||||
Ok(out)
|
||||
}
|
||||
fn remaining(&self) -> usize { self.data.len() - self.offset }
|
||||
}
|
||||
|
||||
impl Conv1d {
|
||||
fn take(c: &mut Cursor<'_>, in_ch: usize, out_ch: usize, k: usize, dil: usize)
|
||||
-> Result<Self, LoadError>
|
||||
{
|
||||
let weight = c.take(in_ch * k * out_ch)?;
|
||||
let bias = c.take(out_ch)?;
|
||||
Ok(Self { in_ch, out_ch, kernel: k, dilation: dil, weight, bias })
|
||||
}
|
||||
|
||||
/// Causal conv with left padding. Input layout: `[in_ch * T]` row-major.
|
||||
fn forward(&self, input: &[f32], t_steps: usize) -> Vec<f32> {
|
||||
let eff_k = self.kernel + (self.kernel - 1) * (self.dilation - 1);
|
||||
let pad_left = eff_k - 1;
|
||||
let mut out = vec![0.0f32; self.out_ch * t_steps];
|
||||
for oc in 0..self.out_ch {
|
||||
for t in 0..t_steps {
|
||||
let mut sum = self.bias[oc];
|
||||
for ic in 0..self.in_ch {
|
||||
for k in 0..self.kernel {
|
||||
let t_idx_signed = t as isize + pad_left as isize
|
||||
- (k * self.dilation) as isize;
|
||||
// Left-pad with zeros: only contribute when t_idx_signed - pad_left >= 0
|
||||
let t_src = t_idx_signed - pad_left as isize;
|
||||
if t_src < 0 || t_src >= t_steps as isize { continue; }
|
||||
let w_idx = oc * (self.in_ch * self.kernel) + ic * self.kernel + k;
|
||||
sum += self.weight[w_idx] * input[ic * t_steps + t_src as usize];
|
||||
}
|
||||
}
|
||||
out[oc * t_steps + t] = sum;
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
impl BatchNorm {
|
||||
fn take(c: &mut Cursor<'_>, channels: usize) -> Result<Self, LoadError> {
|
||||
let gamma = c.take(channels)?;
|
||||
let beta = c.take(channels)?;
|
||||
Ok(Self { channels, gamma, beta })
|
||||
}
|
||||
|
||||
/// Per-window normalisation matching JS impl: mean/var computed across
|
||||
/// the T axis at inference time (not from saved running stats).
|
||||
fn forward(&self, x: &mut [f32], t_steps: usize) {
|
||||
let eps = 1e-5f32;
|
||||
for c in 0..self.channels {
|
||||
let base = c * t_steps;
|
||||
let mut mean = 0.0f32;
|
||||
for t in 0..t_steps { mean += x[base + t]; }
|
||||
mean /= t_steps as f32;
|
||||
let mut var = 0.0f32;
|
||||
for t in 0..t_steps {
|
||||
let d = x[base + t] - mean;
|
||||
var += d * d;
|
||||
}
|
||||
var /= t_steps as f32;
|
||||
let inv_std = 1.0f32 / (var + eps).sqrt();
|
||||
let g = self.gamma[c];
|
||||
let b = self.beta[c];
|
||||
for t in 0..t_steps {
|
||||
x[base + t] = g * (x[base + t] - mean) * inv_std + b;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TcnBlock {
|
||||
fn take(c: &mut Cursor<'_>, in_ch: usize, out_ch: usize, k: usize, dil: usize)
|
||||
-> Result<Self, LoadError>
|
||||
{
|
||||
let conv1 = Conv1d::take(c, in_ch, out_ch, k, dil)?;
|
||||
let bn1 = BatchNorm::take(c, out_ch)?;
|
||||
let conv2 = Conv1d::take(c, out_ch, out_ch, k, dil)?;
|
||||
let bn2 = BatchNorm::take(c, out_ch)?;
|
||||
let res = if in_ch != out_ch {
|
||||
Some(Conv1d::take(c, in_ch, out_ch, 1, 1)?)
|
||||
} else { None };
|
||||
Ok(Self { conv1, bn1, conv2, bn2, res })
|
||||
}
|
||||
|
||||
fn forward(&self, input: &[f32], t_steps: usize) -> Vec<f32> {
|
||||
let mut x = self.conv1.forward(input, t_steps);
|
||||
self.bn1.forward(&mut x, t_steps);
|
||||
for v in x.iter_mut() { if *v < 0.0 { *v = 0.0; } } // relu
|
||||
|
||||
let mut y = self.conv2.forward(&x, t_steps);
|
||||
self.bn2.forward(&mut y, t_steps);
|
||||
|
||||
// Residual
|
||||
let res: Vec<f32> = if let Some(r) = &self.res {
|
||||
r.forward(input, t_steps)
|
||||
} else {
|
||||
input.to_vec()
|
||||
};
|
||||
debug_assert_eq!(y.len(), res.len());
|
||||
for (yv, rv) in y.iter_mut().zip(res.iter()) { *yv += *rv; }
|
||||
for v in y.iter_mut() { if *v < 0.0 { *v = 0.0; } } // relu after residual
|
||||
y
|
||||
}
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn take(c: &mut Cursor<'_>, in_dim: usize, out_dim: usize) -> Result<Self, LoadError> {
|
||||
let weight = c.take(in_dim * out_dim)?;
|
||||
let bias = c.take(out_dim)?;
|
||||
Ok(Self { in_dim, out_dim, weight, bias })
|
||||
}
|
||||
|
||||
fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||
let mut out = vec![0.0f32; self.out_dim];
|
||||
for j in 0..self.out_dim {
|
||||
let mut s = self.bias[j];
|
||||
for i in 0..self.in_dim {
|
||||
s += input[i] * self.weight[i * self.out_dim + j];
|
||||
}
|
||||
out[j] = s;
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
fn sigmoid(x: f32) -> f32 {
|
||||
if x >= 0.0 {
|
||||
let e = (-x).exp();
|
||||
1.0 / (1.0 + e)
|
||||
} else {
|
||||
let e = x.exp();
|
||||
e / (1.0 + e)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Inline base64 decoder ────────────────────────────────────────────────────
|
||||
//
|
||||
// Standard alphabet (A–Z, a–z, 0–9, +, /). Padding `=` tolerated. Whitespace
|
||||
// (including newlines) ignored — JSON.stringify can wrap base64 across lines
|
||||
// in some exporters. Avoids pulling the `base64` crate just for one decode.
|
||||
|
||||
fn base64_decode(s: &str) -> Result<Vec<u8>, String> {
|
||||
let mut out = Vec::with_capacity(s.len() * 3 / 4 + 4);
|
||||
let mut buf: u32 = 0;
|
||||
let mut bits: u32 = 0;
|
||||
for ch in s.bytes() {
|
||||
let v: u32 = match ch {
|
||||
b'A'..=b'Z' => (ch - b'A') as u32,
|
||||
b'a'..=b'z' => (ch - b'a' + 26) as u32,
|
||||
b'0'..=b'9' => (ch - b'0' + 52) as u32,
|
||||
b'+' => 62,
|
||||
b'/' => 63,
|
||||
b'=' => break,
|
||||
b' ' | b'\n' | b'\r' | b'\t' => continue,
|
||||
_ => return Err(format!("invalid base64 char {:#x}", ch)),
|
||||
};
|
||||
buf = (buf << 6) | v;
|
||||
bits += 6;
|
||||
if bits >= 8 {
|
||||
bits -= 8;
|
||||
out.push((buf >> bits) as u8);
|
||||
buf &= (1 << bits) - 1;
|
||||
}
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
// ── Convenience input helpers ────────────────────────────────────────────────
|
||||
|
||||
/// Build the `[INPUT_DIM × TIME_STEPS]` input tensor from the most recent
|
||||
/// `TIME_STEPS` per-frame amplitude vectors of a single node. Picks the
|
||||
/// `INPUT_DIM` (35) subcarriers with smallest NBVI score (most useful), using
|
||||
/// the same per-subcarrier `α·σ/μ² + (1−α)·σ/μ` formula the classifier uses,
|
||||
/// but with K=35 instead of NBVI_TOP_K=12 — model expects 35 channels.
|
||||
///
|
||||
/// Returns `None` if the history has fewer than `TIME_STEPS` frames or all
|
||||
/// subcarriers are zero / unusable.
|
||||
pub fn build_input_from_history(
|
||||
history: &std::collections::VecDeque<Vec<f64>>,
|
||||
) -> Option<Vec<f32>> {
|
||||
let n = history.len();
|
||||
if n < TIME_STEPS { return None; }
|
||||
// Take the last 20 frames.
|
||||
let recent: Vec<&Vec<f64>> = history.iter().rev().take(TIME_STEPS).collect();
|
||||
// recent is reverse-chronological; we want chronological for forward pass.
|
||||
let recent: Vec<&Vec<f64>> = recent.into_iter().rev().collect();
|
||||
let n_sub = recent[0].len();
|
||||
if n_sub == 0 { return None; }
|
||||
|
||||
// Per-subcarrier mean and std over the 20 frames.
|
||||
let mut score: Vec<(usize, f64)> = (0..n_sub).map(|k| {
|
||||
let mut sum = 0.0f64;
|
||||
for f in &recent { sum += f.get(k).copied().unwrap_or(0.0); }
|
||||
let mu = sum / TIME_STEPS as f64;
|
||||
if mu.abs() < 1e-9 { return (k, f64::INFINITY); }
|
||||
let mut var = 0.0f64;
|
||||
for f in &recent {
|
||||
let d = f.get(k).copied().unwrap_or(0.0) - mu;
|
||||
var += d * d;
|
||||
}
|
||||
let sigma = (var / TIME_STEPS as f64).sqrt();
|
||||
// NBVI (α = 0.5): 0.5 * (σ/μ²) + 0.5 * (σ/μ)
|
||||
let mu2 = mu * mu;
|
||||
let nbvi = 0.5 * (sigma / mu2) + 0.5 * (sigma / mu.abs());
|
||||
(k, nbvi)
|
||||
}).collect();
|
||||
|
||||
// 25th-percentile dead-zone gate (drop subcarriers with mean amplitude
|
||||
// below the lower quartile).
|
||||
let mut means: Vec<f64> = (0..n_sub).map(|k| {
|
||||
let mut s = 0.0f64;
|
||||
for f in &recent { s += f.get(k).copied().unwrap_or(0.0); }
|
||||
s / TIME_STEPS as f64
|
||||
}).collect();
|
||||
means.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
let q25_idx = (n_sub as f64 * 0.25) as usize;
|
||||
let dead_thresh = means.get(q25_idx).copied().unwrap_or(0.0);
|
||||
for (k, s) in score.iter_mut() {
|
||||
// Re-compute mean for this k to gate (means above is sorted, indices lost).
|
||||
let mut sum = 0.0f64;
|
||||
for f in &recent { sum += f.get(*k).copied().unwrap_or(0.0); }
|
||||
let mu = sum / TIME_STEPS as f64;
|
||||
if mu < dead_thresh { *s = f64::INFINITY; }
|
||||
}
|
||||
|
||||
score.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
if score.is_empty() || !score[0].1.is_finite() { return None; }
|
||||
|
||||
// Pick top-INPUT_DIM (35) by lowest NBVI. If fewer than 35 are finite,
|
||||
// pad with whichever finite ones we have and zero the rest — model still
|
||||
// runs, it just has dead channels.
|
||||
let mut picks: Vec<usize> = score.iter()
|
||||
.filter(|(_, s)| s.is_finite())
|
||||
.take(INPUT_DIM)
|
||||
.map(|(k, _)| *k)
|
||||
.collect();
|
||||
if picks.is_empty() { return None; }
|
||||
while picks.len() < INPUT_DIM { picks.push(0); } // pad with subcarrier 0
|
||||
|
||||
// Raw amplitudes pass-through. Training script (`scripts/train-wiflow-
|
||||
// supervised.js::loadJsonl`) feeds raw values; the two TCN BatchNorm
|
||||
// layers normalise per-channel per-window at inference time so absolute
|
||||
// scale (5–50 ESP32 amplitude range) is handled by the network itself.
|
||||
let mut out = vec![0.0f32; INPUT_DIM * TIME_STEPS];
|
||||
for (ci, k) in picks.iter().enumerate() {
|
||||
for (t, f) in recent.iter().enumerate() {
|
||||
out[ci * TIME_STEPS + t] = f.get(*k).copied().unwrap_or(0.0) as f32;
|
||||
}
|
||||
}
|
||||
Some(out)
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn base64_round_trip_alphabet() {
|
||||
// "Man" -> "TWFu"
|
||||
assert_eq!(base64_decode("TWFu").unwrap(), b"Man");
|
||||
// padding
|
||||
assert_eq!(base64_decode("TWE=").unwrap(), b"Ma");
|
||||
assert_eq!(base64_decode("TQ==").unwrap(), b"M");
|
||||
// whitespace tolerated
|
||||
assert_eq!(base64_decode("T W\nF u").unwrap(), b"Man");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sigmoid_bounds() {
|
||||
assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
|
||||
assert!(sigmoid(10.0) > 0.999);
|
||||
assert!(sigmoid(-10.0) < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_input_zero_history() {
|
||||
let h = std::collections::VecDeque::new();
|
||||
assert!(build_input_from_history(&h).is_none());
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue