1209 lines
50 KiB
Rust
1209 lines
50 KiB
Rust
//! `wifi-densepose calibrate-serve` — HTTP API around ADR-135 baseline calibration
//! (plus the ADR-151 room-training and guided-enrollment endpoints).
//!
//! Wraps the same [`wifi_densepose_signal::CalibrationRecorder`] used by the
//! `calibrate` subcommand in a small Axum server so a UI (or any client) can
//! drive an empty-room baseline capture remotely:
//!
//! | Method | Path                              | Purpose                                     |
//! |--------|-----------------------------------|---------------------------------------------|
//! | GET    | `/`                               | API descriptor (discovery)                  |
//! | GET    | `/api/v1/calibration/health`      | liveness + UDP ingest stats                 |
//! | POST   | `/api/v1/calibration/start`       | begin a baseline capture session            |
//! | GET    | `/api/v1/calibration/status`      | live session progress (poll this for UI)    |
//! | POST   | `/api/v1/calibration/stop`        | finalize the current session early          |
//! | GET    | `/api/v1/calibration/result`      | summary of the last finalized baseline      |
//! | GET    | `/api/v1/calibration/baselines`   | list persisted baseline files               |
//! | GET    | `/api/v1/room/state`              | live room-state inference (`?bank=<name>`)  |
//! | POST   | `/api/v1/room/train`              | train + persist a per-room specialist bank  |
//! | POST   | `/api/v1/enroll/anchor`           | capture one guided enrollment anchor        |
//! | POST   | `/api/v1/enroll/geometry`         | record the room's transceiver geometry      |
//! | GET    | `/api/v1/enroll/status`           | enrollment progress (`?room=<id>`)          |
//!
//! A single background task owns the UDP socket (ESP32 `0xC511_0001` frames),
//! the optional active recorder, and any in-flight anchor capture. The HTTP
//! handlers communicate with it over an mpsc command channel and read a shared
//! status snapshot. This keeps the `&mut` recorder lock-free and the API
//! non-blocking. CORS is permissive so a browser UI served from any origin can
//! call it during development. Pass `--token` to require a bearer token on
//! every request.
|
|
|
|
use std::collections::{HashMap, VecDeque};
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
|
|
|
use anyhow::Result;
|
|
use axum::{
|
|
extract::{Query, State},
|
|
http::StatusCode,
|
|
response::IntoResponse,
|
|
routing::{get, post},
|
|
Json, Router,
|
|
};
|
|
use clap::Args;
|
|
use serde::{Deserialize, Serialize};
|
|
use tokio::net::UdpSocket;
|
|
use tokio::sync::{mpsc, oneshot, RwLock};
|
|
use tower_http::cors::CorsLayer;
|
|
use wifi_densepose_calibration::extract::{AnchorFeature, Features};
|
|
use wifi_densepose_calibration::{
|
|
AnchorLabel, AnchorQualityGate, AnchorRecorder, MixtureOfSpecialists, NodeGeometry,
|
|
SpecialistBank,
|
|
};
|
|
use wifi_densepose_core::types::CsiFrame;
|
|
use wifi_densepose_signal::{BaselineCalibration, CalibrationRecorder};
|
|
|
|
use crate::calibrate::{parse_csi_packet, tier_config};
|
|
|
|
/// Capacity of the rolling live window: how many of the most recent per-frame
/// scalars (mean amplitude) are kept for live `room-state` inference. The
/// ingest task keeps this window filled whether or not a baseline session is
/// running.
const LIVE_WINDOW: usize = 1 << 8;
|
|
|
|
/// One scalar per frame: mean amplitude across subcarriers/streams.
|
|
fn frame_scalar(frame: &CsiFrame) -> f32 {
|
|
let a = &frame.amplitude;
|
|
if a.is_empty() {
|
|
0.0
|
|
} else {
|
|
(a.sum() / a.len() as f64) as f32
|
|
}
|
|
}
|
|
|
|
/// Size in bytes of the buffer each incoming UDP datagram is received into.
const RECV_BUF: usize = 2 * 1024;
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// CLI arguments
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Arguments for the `calibrate-serve` subcommand.
// NOTE: clap's derive turns the `///` comments on this struct and its fields
// into `--help` text, and field order sets the order the flags are listed in.
// Treat both as user-facing.
#[derive(Args, Debug, Clone)]
pub struct CalibrateServeArgs {
    /// TCP port for the HTTP API.
    #[arg(long, default_value_t = 8090)]
    pub http_port: u16,

    /// Bind address for the HTTP API. Default 127.0.0.1 (localhost only);
    /// use 0.0.0.0 to expose the API to the LAN for a remote UI.
    #[arg(long, default_value = "127.0.0.1")]
    pub http_bind: String,

    /// UDP port to receive CSI frames from the ESP32 (must match provisioned target-port).
    #[arg(long, default_value_t = 5005)]
    pub udp_port: u16,

    /// Bind address for the UDP CSI socket.
    #[arg(long, default_value = "0.0.0.0")]
    pub udp_bind: String,

    /// Default PHY tier when a start request omits one (ht20 / ht40 / he20 / he40).
    // Not validated at startup. An invalid value is only rejected when a
    // `start` request falls back to it.
    #[arg(long, default_value = "ht20")]
    pub tier: String,

    /// Directory where finalized baseline `.bin` files are written.
    // Created (with parents) by `execute` before anything is bound.
    #[arg(long, default_value = "./baselines")]
    pub output_dir: String,

    /// Require `Authorization: Bearer <token>` on every API request. Strongly
    /// recommended before binding to anything other than 127.0.0.1.
    // Can also come from the CALIBRATE_TOKEN environment variable. When absent,
    // `execute` installs no auth layer at all.
    #[arg(long, env = "CALIBRATE_TOKEN")]
    pub token: Option<String>,
}
|
|
|
|
/// Reduce a client-supplied `room_id` (or baseline name) to a filename-safe
/// token, so it cannot steer the baseline write path with `../` segments or an
/// absolute path.
///
/// Only ASCII letters, digits, `_` and `-` survive, and at most 64 of them are
/// kept. If nothing survives, the literal `default` is returned instead.
fn sanitize_room_id(raw: &str) -> String {
    const MAX_LEN: usize = 64;

    let mut safe = String::with_capacity(raw.len().min(MAX_LEN));
    for ch in raw.chars() {
        // Every kept char is ASCII, so byte length == char count here.
        if safe.len() == MAX_LEN {
            break;
        }
        if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' {
            safe.push(ch);
        }
    }

    if safe.is_empty() {
        String::from("default")
    } else {
        safe
    }
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Wire types (request / response bodies)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Body for `POST /start`. All fields optional — sensible defaults applied.
///
/// `#[serde(default)]` fills any omitted field from `StartParams::default()`
/// (30 s, no tier override, no room label, tier-default minimum frames).
#[derive(Debug, Deserialize)]
#[serde(default)]
pub struct StartParams {
    /// PHY tier override (falls back to the server default).
    /// The ingest task checks it case-insensitively against
    /// ht20 / ht40 / he20 / he40 and rejects anything else.
    pub tier: Option<String>,
    /// Capture duration in seconds (also bounded by the tier's min-frame target).
    /// The ingest task treats `0` as 1 s.
    pub duration_s: u32,
    /// Optional room label, used in the persisted filename and status.
    /// Sanitized to `[A-Za-z0-9_-]` (at most 64 chars) before use; `None`
    /// becomes `default`.
    pub room_id: Option<String>,
    /// Override the tier's minimum frame count (0 = use tier default).
    pub min_frames: u32,
}
|
|
|
|
impl Default for StartParams {
|
|
fn default() -> Self {
|
|
Self { tier: None, duration_s: 30, room_id: None, min_frames: 0 }
|
|
}
|
|
}
|
|
|
|
/// Live per-session status snapshot returned by `GET /status`.
///
/// Built by the ingest task through `session_snapshot`. While a session is
/// recording it is re-published into the shared status on every 200 ms tick.
#[derive(Debug, Clone, Serialize)]
pub struct SessionStatus {
    /// `recording` | `finalizing` | `complete` | `aborted`.
    pub state: String,
    /// Sanitized room label the session was started with.
    pub room_id: String,
    /// PHY tier the session records at.
    pub tier: String,
    /// Frames recorded so far (presumably the recorder's frame count — see
    /// `session_snapshot`).
    pub frames_recorded: usize,
    /// Frame count at which the session finalizes on its own.
    pub target_frames: usize,
    /// 0.0..=1.0 capture progress.
    pub progress: f32,
    /// Amplitude z-score median / max from the recorder's latest frame score.
    pub z_median: f32,
    pub z_max: f32,
    /// Whether the recorder's latest frame score flagged motion.
    pub motion_flagged: bool,
    /// Elapsed and estimated-remaining capture time in seconds.
    /// NOTE(review): computed in `session_snapshot`; confirm the ETA basis there.
    pub elapsed_s: f32,
    pub eta_s: f32,
    /// Optional human-readable note (e.g. abort reason).
    pub note: Option<String>,
}
|
|
|
|
/// Summary of a finalized baseline, returned by `GET /result` and `POST /stop`.
///
/// Filled in by `finalize` once the baseline file has been written.
#[derive(Debug, Clone, Serialize)]
pub struct ResultSummary {
    /// The baseline's calibration UUID (also embedded in the file name).
    pub calibration_id: String,
    /// Sanitized room label of the session.
    pub room_id: String,
    /// PHY tier the session recorded at.
    pub tier: String,
    /// Frame count stored in the baseline.
    pub frame_count: u64,
    /// Number of subcarriers in the baseline.
    pub subcarriers: usize,
    /// Copied from the baseline's `captured_at_unix_s`.
    pub captured_at_unix_s: i64,
    /// Averages over the baseline as computed by `baseline_averages`
    /// (amplitude mean, amplitude variance, phase dispersion — in that order).
    pub amp_mean_avg: f32,
    pub amp_variance_avg: f32,
    pub phase_dispersion_avg: f32,
    /// File the baseline was written to: `<output_dir>/<room_id>-<calibration_id>.bin`.
    pub output_path: String,
    /// Length in bytes of the serialized baseline that was written.
    pub saved_bytes: usize,
}
|
|
|
|
/// Shared status the HTTP handlers read.
|
|
#[derive(Default)]
|
|
struct SharedStatus {
|
|
udp_port: u16,
|
|
default_tier: String,
|
|
output_dir: String,
|
|
frames_seen: u64,
|
|
last_frame_unix_ms: u64,
|
|
session: Option<SessionStatus>,
|
|
last_result: Option<ResultSummary>,
|
|
}
|
|
|
|
/// Commands sent from HTTP handlers to the ingest task.
|
|
enum CalCommand {
|
|
Start { params: StartParams, reply: oneshot::Sender<Result<SessionStatus, String>> },
|
|
Stop { reply: oneshot::Sender<Result<ResultSummary, String>> },
|
|
EnrollAnchor {
|
|
room_id: String,
|
|
baseline_name: String,
|
|
label: AnchorLabel,
|
|
duration_s: u32,
|
|
reply: oneshot::Sender<Result<AnchorVerdict, String>>,
|
|
},
|
|
}
|
|
|
|
/// Accumulated in-server enrollment for one room (not persisted until train).
|
|
#[derive(Default)]
|
|
struct RoomEnroll {
|
|
baseline_id: String,
|
|
fs_hz: f32,
|
|
anchors: Vec<AnchorFeature>,
|
|
/// Transceiver geometry recorded via `POST /enroll/geometry` (ADR-152
|
|
/// §2.1.1); latest recording wins. Snapshotted into the bank at train time.
|
|
geometry: Vec<NodeGeometry>,
|
|
}
|
|
|
|
/// Result of capturing one anchor (`POST /enroll/anchor`).
///
/// Built by the ingest task once an anchor capture's deadline passes. It
/// combines the recorder's quality-gate outcome with the room's enrollment
/// progress.
#[derive(Debug, Clone, Serialize)]
pub struct AnchorVerdict {
    /// Anchor label (snake_case).
    pub label: String,
    /// Passed the quality gate.
    pub accepted: bool,
    /// Rejection reason, if any.
    pub reason: Option<String>,
    /// Mean amplitude z-score vs baseline.
    pub presence_z: f32,
    /// Fraction of frames flagged as motion.
    pub motion_rate: f32,
    /// Frames captured.
    pub frames: u32,
    /// Accepted anchors so far for this room. Includes this anchor when it was
    /// accepted; re-capturing an already-accepted label replaces it rather
    /// than adding another.
    pub accepted_count: usize,
    /// Next anchor in the sequence, if any: the first label of
    /// `AnchorLabel::SEQUENCE` not yet accepted. Only computed when this
    /// anchor was accepted; otherwise `None`.
    pub next: Option<String>,
}
|
|
|
|
/// In-flight anchor capture owned by the ingest task.
|
|
struct EnrollCapture {
|
|
recorder: AnchorRecorder,
|
|
baseline: BaselineCalibration,
|
|
label: AnchorLabel,
|
|
room_id: String,
|
|
baseline_id: String,
|
|
fs_hz: f32,
|
|
series: Vec<f32>,
|
|
deadline: Instant,
|
|
reply: Option<oneshot::Sender<Result<AnchorVerdict, String>>>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct ApiState {
|
|
cmd_tx: mpsc::Sender<CalCommand>,
|
|
status: Arc<RwLock<SharedStatus>>,
|
|
/// Rolling per-frame scalars for live `room-state` inference.
|
|
window: Arc<RwLock<VecDeque<f32>>>,
|
|
/// Default sample rate for periodicity extraction.
|
|
fs_hz: f32,
|
|
/// In-server enrollment accumulator, keyed by `room_id`.
|
|
enroll: Arc<RwLock<HashMap<String, RoomEnroll>>>,
|
|
}
|
|
|
|
/// Bearer-token gate (applied only when `--token` is set). Constant-time-ish
|
|
/// compare is unnecessary here (local appliance), but reject anything that
|
|
/// isn't an exact `Bearer <token>` match.
|
|
async fn require_bearer(
|
|
axum::extract::State(token): axum::extract::State<String>,
|
|
req: axum::extract::Request,
|
|
next: axum::middleware::Next,
|
|
) -> axum::response::Response {
|
|
let authorized = req
|
|
.headers()
|
|
.get(axum::http::header::AUTHORIZATION)
|
|
.and_then(|v| v.to_str().ok())
|
|
.and_then(|h| h.strip_prefix("Bearer "))
|
|
.map(|t| t == token)
|
|
.unwrap_or(false);
|
|
if authorized {
|
|
next.run(req).await
|
|
} else {
|
|
(
|
|
StatusCode::UNAUTHORIZED,
|
|
Json(serde_json::json!({"error": "missing or invalid bearer token"})),
|
|
)
|
|
.into_response()
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Public entry point
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Build the API router (without the optional auth layer). Shared by `execute`
|
|
/// and the integration tests.
|
|
fn build_router(state: ApiState) -> Router {
|
|
Router::new()
|
|
.route("/", get(descriptor))
|
|
.route("/api/v1/calibration/health", get(health))
|
|
.route("/api/v1/calibration/start", post(start))
|
|
.route("/api/v1/calibration/status", get(status_handler))
|
|
.route("/api/v1/calibration/stop", post(stop))
|
|
.route("/api/v1/calibration/result", get(result))
|
|
.route("/api/v1/calibration/baselines", get(baselines))
|
|
.route("/api/v1/room/state", get(room_state))
|
|
.route("/api/v1/room/train", post(train_room))
|
|
.route("/api/v1/enroll/anchor", post(enroll_anchor))
|
|
.route("/api/v1/enroll/geometry", post(enroll_geometry))
|
|
.route("/api/v1/enroll/status", get(enroll_status))
|
|
.layer(CorsLayer::permissive())
|
|
.with_state(state)
|
|
}
|
|
|
|
/// Run the calibration HTTP API server (blocks until Ctrl-C).
|
|
pub async fn execute(args: CalibrateServeArgs) -> Result<()> {
|
|
std::fs::create_dir_all(&args.output_dir)
|
|
.map_err(|e| anyhow::anyhow!("cannot create output dir {}: {e}", args.output_dir))?;
|
|
|
|
let udp_addr = format!("{}:{}", args.udp_bind, args.udp_port);
|
|
let socket = UdpSocket::bind(&udp_addr)
|
|
.await
|
|
.map_err(|e| anyhow::anyhow!("cannot bind UDP socket on {udp_addr}: {e}"))?;
|
|
eprintln!("[calibrate-serve] CSI ingest on udp://{udp_addr}");
|
|
|
|
let status = Arc::new(RwLock::new(SharedStatus {
|
|
udp_port: args.udp_port,
|
|
default_tier: args.tier.clone(),
|
|
output_dir: args.output_dir.clone(),
|
|
..Default::default()
|
|
}));
|
|
|
|
let (cmd_tx, cmd_rx) = mpsc::channel::<CalCommand>(8);
|
|
let window = Arc::new(RwLock::new(VecDeque::<f32>::with_capacity(LIVE_WINDOW)));
|
|
let enroll = Arc::new(RwLock::new(HashMap::<String, RoomEnroll>::new()));
|
|
|
|
// Background ingest task owns the socket + recorder.
|
|
{
|
|
let status = status.clone();
|
|
let default_tier = args.tier.clone();
|
|
let output_dir = args.output_dir.clone();
|
|
let window = window.clone();
|
|
let enroll = enroll.clone();
|
|
tokio::spawn(async move {
|
|
ingest_loop(socket, cmd_rx, status, default_tier, output_dir, window, enroll).await;
|
|
});
|
|
}
|
|
|
|
let state = ApiState { cmd_tx, status, window, fs_hz: 15.0, enroll };
|
|
let mut app = build_router(state);
|
|
|
|
// Optional bearer auth — required before any non-loopback exposure.
|
|
if let Some(token) = args.token.clone() {
|
|
app = app.layer(axum::middleware::from_fn_with_state(token, require_bearer));
|
|
eprintln!("[calibrate-serve] bearer auth ENABLED");
|
|
} else if args.http_bind != "127.0.0.1" && args.http_bind != "localhost" {
|
|
eprintln!(
|
|
"[calibrate-serve] WARNING: bound to {} with NO --token — anyone on the network can drive calibration",
|
|
args.http_bind
|
|
);
|
|
}
|
|
|
|
let http_addr = format!("{}:{}", args.http_bind, args.http_port);
|
|
let listener = tokio::net::TcpListener::bind(&http_addr)
|
|
.await
|
|
.map_err(|e| anyhow::anyhow!("cannot bind HTTP listener on {http_addr}: {e}"))?;
|
|
eprintln!("[calibrate-serve] HTTP API on http://{http_addr} (GET / for the route list)");
|
|
|
|
axum::serve(listener, app)
|
|
.await
|
|
.map_err(|e| anyhow::anyhow!("HTTP server error: {e}"))?;
|
|
Ok(())
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Ingest task — owns the UDP socket and the optional active recorder
|
|
// ---------------------------------------------------------------------------
|
|
|
|
struct ActiveSession {
|
|
recorder: CalibrationRecorder,
|
|
room_id: String,
|
|
tier: String,
|
|
started: Instant,
|
|
deadline: Instant,
|
|
target_frames: usize,
|
|
z_median: f32,
|
|
z_max: f32,
|
|
motion_flagged: bool,
|
|
}
|
|
|
|
async fn ingest_loop(
|
|
socket: UdpSocket,
|
|
mut cmd_rx: mpsc::Receiver<CalCommand>,
|
|
status: Arc<RwLock<SharedStatus>>,
|
|
default_tier: String,
|
|
output_dir: String,
|
|
window: Arc<RwLock<VecDeque<f32>>>,
|
|
enroll: Arc<RwLock<HashMap<String, RoomEnroll>>>,
|
|
) {
|
|
let mut buf = vec![0u8; RECV_BUF];
|
|
let mut active: Option<ActiveSession> = None;
|
|
let mut active_enroll: Option<EnrollCapture> = None;
|
|
let mut tick = tokio::time::interval(Duration::from_millis(200));
|
|
// Counters mirrored to shared status only on the 200 ms tick — avoids a lock
|
|
// + SessionStatus clone on every UDP frame (CPU starvation under flood).
|
|
let mut frames_seen: u64 = 0;
|
|
let mut last_frame_ms: u64 = 0;
|
|
// Live rolling window, flushed to the shared `window` on the tick.
|
|
let mut win_local: VecDeque<f32> = VecDeque::with_capacity(LIVE_WINDOW);
|
|
|
|
loop {
|
|
tokio::select! {
|
|
// --- incoming command ---
|
|
Some(cmd) = cmd_rx.recv() => match cmd {
|
|
CalCommand::Start { params, reply } => {
|
|
if active.is_some() {
|
|
let _ = reply.send(Err("a calibration session is already running".into()));
|
|
continue;
|
|
}
|
|
let tier = params.tier.unwrap_or_else(|| default_tier.clone());
|
|
if !["ht20", "ht40", "he20", "he40"].contains(&tier.to_ascii_lowercase().as_str()) {
|
|
let _ = reply.send(Err(format!("invalid tier {tier:?}")));
|
|
continue;
|
|
}
|
|
let mut config = tier_config(&tier);
|
|
if params.min_frames > 0 {
|
|
config.min_frames = params.min_frames;
|
|
}
|
|
let target_frames = config.min_frames as usize;
|
|
let dur = params.duration_s.max(1) as u64;
|
|
// Sanitize: room_id is interpolated into the baseline write path.
|
|
let room_id = sanitize_room_id(¶ms.room_id.unwrap_or_else(|| "default".into()));
|
|
let sess = ActiveSession {
|
|
recorder: CalibrationRecorder::new(config),
|
|
room_id: room_id.clone(),
|
|
tier: tier.clone(),
|
|
started: Instant::now(),
|
|
deadline: Instant::now() + Duration::from_secs(dur),
|
|
target_frames,
|
|
z_median: 0.0,
|
|
z_max: 0.0,
|
|
motion_flagged: false,
|
|
};
|
|
let snap = session_snapshot(&sess, "recording", None);
|
|
active = Some(sess);
|
|
{
|
|
let mut s = status.write().await;
|
|
s.session = Some(snap.clone());
|
|
s.last_result = None;
|
|
}
|
|
eprintln!("[calibrate-serve] session start room={room_id} tier={tier} target={target_frames}");
|
|
let _ = reply.send(Ok(snap));
|
|
}
|
|
CalCommand::Stop { reply } => {
|
|
match active.take() {
|
|
Some(sess) => {
|
|
let res = finalize(sess, &output_dir, &status).await;
|
|
let _ = reply.send(res);
|
|
}
|
|
None => { let _ = reply.send(Err("no active calibration session".into())); }
|
|
}
|
|
}
|
|
CalCommand::EnrollAnchor { room_id, baseline_name, label, duration_s, reply } => {
|
|
if active.is_some() || active_enroll.is_some() {
|
|
let _ = reply.send(Err("a capture is already running".into()));
|
|
continue;
|
|
}
|
|
// Resolve the baseline as a sanitized name under output_dir.
|
|
let bname = sanitize_room_id(&baseline_name);
|
|
let bpath = format!("{output_dir}/{bname}.bin");
|
|
let baseline = match tokio::fs::read(&bpath).await {
|
|
Ok(bytes) => match BaselineCalibration::from_bytes(&bytes) {
|
|
Ok(b) => b,
|
|
Err(e) => { let _ = reply.send(Err(format!("invalid baseline {bname}: {e}"))); continue; }
|
|
},
|
|
Err(e) => { let _ = reply.send(Err(format!("baseline {bname} not found: {e}"))); continue; }
|
|
};
|
|
let baseline_id = baseline.calibration_uuid().to_string();
|
|
eprintln!("[calibrate-serve] enroll anchor room={room_id} label={} ({}s)", label.as_str(), duration_s);
|
|
active_enroll = Some(EnrollCapture {
|
|
recorder: AnchorRecorder::new(label),
|
|
baseline,
|
|
label,
|
|
room_id,
|
|
baseline_id,
|
|
fs_hz: 15.0,
|
|
series: Vec::new(),
|
|
deadline: Instant::now() + Duration::from_secs(duration_s.max(1) as u64),
|
|
reply: Some(reply),
|
|
});
|
|
}
|
|
},
|
|
|
|
// --- incoming CSI frame (no shared-status lock here; flushed on tick) ---
|
|
Ok(n) = socket.recv(&mut buf) => {
|
|
frames_seen += 1;
|
|
last_frame_ms = unix_ms();
|
|
let parse_tier = active.as_ref().map(|s| s.tier.clone()).unwrap_or_else(|| default_tier.clone());
|
|
if let Some(frame) = parse_csi_packet(&buf[..n], &parse_tier) {
|
|
// Always maintain the live window (drives /room/state).
|
|
win_local.push_back(frame_scalar(&frame));
|
|
while win_local.len() > LIVE_WINDOW {
|
|
win_local.pop_front();
|
|
}
|
|
if let Some(sess) = active.as_mut() {
|
|
if let Ok(score) = sess.recorder.record(&frame) {
|
|
sess.z_median = score.amplitude_z_median;
|
|
sess.z_max = score.amplitude_z_max;
|
|
sess.motion_flagged = score.motion_flagged;
|
|
}
|
|
if sess.recorder.frames_recorded() as usize >= sess.target_frames {
|
|
if let Some(done) = active.take() {
|
|
let _ = finalize(done, &output_dir, &status).await;
|
|
}
|
|
}
|
|
}
|
|
if let Some(ec) = active_enroll.as_mut() {
|
|
ec.recorder.record_frame(&ec.baseline, &frame);
|
|
ec.series.push(frame_scalar(&frame));
|
|
}
|
|
}
|
|
},
|
|
|
|
// --- 200 ms tick: flush counters + window + session snapshot, deadline check ---
|
|
_ = tick.tick() => {
|
|
{
|
|
let mut s = status.write().await;
|
|
s.frames_seen = frames_seen;
|
|
s.last_frame_unix_ms = last_frame_ms;
|
|
if let Some(sess) = active.as_ref() {
|
|
s.session = Some(session_snapshot(sess, "recording", None));
|
|
}
|
|
}
|
|
{
|
|
let mut w = window.write().await;
|
|
w.clear();
|
|
w.extend(win_local.iter().copied());
|
|
}
|
|
if let Some(sess) = active.as_ref() {
|
|
if Instant::now() >= sess.deadline {
|
|
let frames = sess.recorder.frames_recorded() as usize;
|
|
if frames >= 10 {
|
|
if let Some(done) = active.take() {
|
|
let _ = finalize(done, &output_dir, &status).await;
|
|
}
|
|
} else if let Some(mut done) = active.take() {
|
|
// not enough frames — abort honestly rather than emit a bad baseline
|
|
done.motion_flagged = false;
|
|
let note = format!(
|
|
"aborted: only {frames} frames in the time window (need >=10) — \
|
|
is the ESP32 streaming to udp:{}? ",
|
|
status.read().await.udp_port
|
|
);
|
|
let snap = session_snapshot(&done, "aborted", Some(note.clone()));
|
|
status.write().await.session = Some(snap);
|
|
eprintln!("[calibrate-serve] {note}");
|
|
}
|
|
}
|
|
}
|
|
// Enroll-anchor capture finished?
|
|
let enroll_done = active_enroll.as_ref().map(|ec| Instant::now() >= ec.deadline).unwrap_or(false);
|
|
if enroll_done {
|
|
if let Some(mut ec) = active_enroll.take() {
|
|
let gate = AnchorQualityGate::default();
|
|
let (anchor, reason) = ec.recorder.finalize(&gate, (unix_ms() / 1000) as i64);
|
|
let mut verdict = AnchorVerdict {
|
|
label: ec.label.as_str().into(),
|
|
accepted: anchor.quality.accepted,
|
|
reason,
|
|
presence_z: anchor.quality.presence_z,
|
|
motion_rate: anchor.quality.motion_rate,
|
|
frames: anchor.quality.frames,
|
|
accepted_count: 0,
|
|
next: None,
|
|
};
|
|
if anchor.quality.accepted {
|
|
let feat = AnchorFeature::from_series(&ec.room_id, ec.label, &ec.series, ec.fs_hz);
|
|
let mut map = enroll.write().await;
|
|
let re = map.entry(ec.room_id.clone()).or_insert_with(RoomEnroll::default);
|
|
if re.baseline_id.is_empty() {
|
|
re.baseline_id = ec.baseline_id.clone();
|
|
re.fs_hz = ec.fs_hz;
|
|
}
|
|
if let Some(slot) = re.anchors.iter_mut().find(|a| a.label == ec.label) {
|
|
*slot = feat;
|
|
} else {
|
|
re.anchors.push(feat);
|
|
}
|
|
verdict.accepted_count = re.anchors.len();
|
|
verdict.next = AnchorLabel::SEQUENCE.iter().copied()
|
|
.find(|l| !re.anchors.iter().any(|a| a.label == *l))
|
|
.map(|l| l.as_str().to_string());
|
|
} else {
|
|
verdict.accepted_count = enroll.read().await.get(&ec.room_id).map(|re| re.anchors.len()).unwrap_or(0);
|
|
}
|
|
eprintln!("[calibrate-serve] enroll anchor {} accepted={} ({} total)", verdict.label, verdict.accepted, verdict.accepted_count);
|
|
if let Some(tx) = ec.reply.take() {
|
|
let _ = tx.send(Ok(verdict));
|
|
}
|
|
}
|
|
}
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Finalize a session: persist the baseline and publish the result summary.
|
|
async fn finalize(
|
|
sess: ActiveSession,
|
|
output_dir: &str,
|
|
status: &Arc<RwLock<SharedStatus>>,
|
|
) -> Result<ResultSummary, String> {
|
|
let room_id = sess.room_id.clone();
|
|
let tier = sess.tier.clone();
|
|
// mark finalizing
|
|
{
|
|
let snap = session_snapshot(&sess, "finalizing", None);
|
|
status.write().await.session = Some(snap);
|
|
}
|
|
|
|
let baseline: BaselineCalibration = sess
|
|
.recorder
|
|
.finalize()
|
|
.map_err(|e| format!("finalize failed: {e}"))?;
|
|
|
|
let (amp_mean_avg, amp_var_avg, disp_avg) = baseline_averages(&baseline);
|
|
let uuid = baseline.calibration_uuid().to_string();
|
|
let path = format!("{output_dir}/{room_id}-{uuid}.bin");
|
|
let bytes = baseline.to_bytes();
|
|
// Async write — never block the ingest task's UDP/command path.
|
|
tokio::fs::write(&path, &bytes)
|
|
.await
|
|
.map_err(|e| format!("cannot write {path}: {e}"))?;
|
|
|
|
let summary = ResultSummary {
|
|
calibration_id: uuid,
|
|
room_id: room_id.clone(),
|
|
tier,
|
|
frame_count: baseline.frame_count,
|
|
subcarriers: baseline.subcarriers.len(),
|
|
captured_at_unix_s: baseline.captured_at_unix_s,
|
|
amp_mean_avg,
|
|
amp_variance_avg: amp_var_avg,
|
|
phase_dispersion_avg: disp_avg,
|
|
output_path: path.clone(),
|
|
saved_bytes: bytes.len(),
|
|
};
|
|
|
|
{
|
|
let mut s = status.write().await;
|
|
// reflect completion in the session snapshot, then store the result
|
|
if let Some(sess_status) = s.session.as_mut() {
|
|
sess_status.state = "complete".into();
|
|
sess_status.progress = 1.0;
|
|
}
|
|
s.last_result = Some(summary.clone());
|
|
}
|
|
eprintln!(
|
|
"[calibrate-serve] session complete room={room_id} frames={} -> {path} ({} bytes)",
|
|
summary.frame_count, summary.saved_bytes
|
|
);
|
|
Ok(summary)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// HTTP handlers
|
|
// ---------------------------------------------------------------------------
|
|
|
|
async fn descriptor() -> impl IntoResponse {
|
|
Json(serde_json::json!({
|
|
"service": "wifi-densepose calibration API",
|
|
"adr": "ADR-135 (baseline) / ADR-151 (room calibration & training)",
|
|
"endpoints": {
|
|
"GET /api/v1/calibration/health": "liveness + UDP ingest stats",
|
|
"POST /api/v1/calibration/start": "{ tier?, duration_s?, room_id?, min_frames? }",
|
|
"GET /api/v1/calibration/status": "live session progress (poll for UI)",
|
|
"POST /api/v1/calibration/stop": "finalize current session early",
|
|
"GET /api/v1/calibration/result": "last finalized baseline summary",
|
|
"GET /api/v1/calibration/baselines": "list persisted baseline files",
|
|
"GET /api/v1/room/state?bank=<name>": "live mixture-of-specialists RoomState over the CSI window",
|
|
"POST /api/v1/room/train": "{ room_id, baseline_id, anchors[]?, geometry[]? } → train + persist a specialist bank (anchors[]/geometry[] optional if enrolled in-server)",
|
|
"POST /api/v1/enroll/anchor": "{ room_id, baseline, label, duration_s? } → capture one guided anchor (blocks for the capture)",
|
|
"POST /api/v1/enroll/geometry": "{ room_id, geometry: [NodeGeometry…] } → record transceiver geometry for the room (ADR-152 §2.1.1; latest wins)",
|
|
"GET /api/v1/enroll/status?room=<id>": "enrollment progress (accepted anchors, next, complete)"
|
|
}
|
|
}))
|
|
}
|
|
|
|
async fn health(State(st): State<ApiState>) -> impl IntoResponse {
|
|
let s = st.status.read().await;
|
|
let age = if s.last_frame_unix_ms == 0 { None } else { Some(unix_ms().saturating_sub(s.last_frame_unix_ms)) };
|
|
Json(serde_json::json!({
|
|
"status": "ok",
|
|
"udp_port": s.udp_port,
|
|
"frames_seen": s.frames_seen,
|
|
"last_frame_age_ms": age,
|
|
"streaming": age.map(|a| a < 2000).unwrap_or(false),
|
|
"default_tier": s.default_tier,
|
|
"output_dir": s.output_dir,
|
|
"session_active": s.session.as_ref().map(|x| x.state == "recording").unwrap_or(false),
|
|
}))
|
|
}
|
|
|
|
async fn start(State(st): State<ApiState>, Json(params): Json<StartParams>) -> impl IntoResponse {
|
|
let (tx, rx) = oneshot::channel();
|
|
if st.cmd_tx.send(CalCommand::Start { params, reply: tx }).await.is_err() {
|
|
return (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":"ingest task unavailable"}))).into_response();
|
|
}
|
|
match rx.await {
|
|
Ok(Ok(snap)) => (StatusCode::ACCEPTED, Json(serde_json::to_value(snap).unwrap())).into_response(),
|
|
Ok(Err(e)) => (StatusCode::CONFLICT, Json(serde_json::json!({"error": e}))).into_response(),
|
|
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":"no reply"}))).into_response(),
|
|
}
|
|
}
|
|
|
|
async fn status_handler(State(st): State<ApiState>) -> impl IntoResponse {
|
|
let s = st.status.read().await;
|
|
match &s.session {
|
|
Some(sess) => (StatusCode::OK, Json(serde_json::to_value(sess).unwrap())).into_response(),
|
|
None => (StatusCode::OK, Json(serde_json::json!({"state":"idle"}))).into_response(),
|
|
}
|
|
}
|
|
|
|
async fn stop(State(st): State<ApiState>) -> impl IntoResponse {
|
|
let (tx, rx) = oneshot::channel();
|
|
if st.cmd_tx.send(CalCommand::Stop { reply: tx }).await.is_err() {
|
|
return (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":"ingest task unavailable"}))).into_response();
|
|
}
|
|
match rx.await {
|
|
Ok(Ok(summary)) => (StatusCode::OK, Json(serde_json::to_value(summary).unwrap())).into_response(),
|
|
Ok(Err(e)) => (StatusCode::CONFLICT, Json(serde_json::json!({"error": e}))).into_response(),
|
|
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":"no reply"}))).into_response(),
|
|
}
|
|
}
|
|
|
|
async fn result(State(st): State<ApiState>) -> impl IntoResponse {
|
|
let s = st.status.read().await;
|
|
match &s.last_result {
|
|
Some(r) => (StatusCode::OK, Json(serde_json::to_value(r).unwrap())).into_response(),
|
|
None => (StatusCode::NOT_FOUND, Json(serde_json::json!({"error":"no finalized baseline yet"}))).into_response(),
|
|
}
|
|
}
|
|
|
|
/// Body for `POST /api/v1/room/train` — an enrollment (CLI `enroll` output or
/// any client that gathered labelled anchor features).
#[derive(Deserialize)]
struct TrainRequest {
    /// Room the bank is trained for. Its sanitized form names the output
    /// `<output_dir>/<room>.json` file and keys the in-server enrollment lookup.
    room_id: String,
    /// Baseline the posted anchors were captured against. Only used when
    /// `anchors` is non-empty; otherwise the enrolled baseline id is used.
    baseline_id: String,
    /// Posted anchor features. When empty, the anchors accumulated in-server
    /// via `POST /enroll/anchor` are used instead.
    #[serde(default)]
    anchors: Vec<AnchorFeature>,
    /// Optional transceiver geometry (ADR-152 §2.1.1). Falls back to the
    /// geometry recorded in-server via `POST /enroll/geometry`; absent both,
    /// the bank trains geometry-free (valid, but no geometry conditioning).
    #[serde(default)]
    geometry: Vec<NodeGeometry>,
}
|
|
|
|
/// Train a per-room specialist bank and persist it as `<output_dir>/<room_id>.json`
|
|
/// (the name `room-state` reads back). Uses the posted `anchors` if present, else
|
|
/// falls back to the in-server enrollment accumulated via `POST /enroll/anchor`.
|
|
/// The enrollment's transceiver-geometry snapshot (posted `geometry` or the
|
|
/// `POST /enroll/geometry` record) is threaded into the bank (ADR-152 §2.1.1).
|
|
async fn train_room(State(st): State<ApiState>, Json(req): Json<TrainRequest>) -> impl IntoResponse {
|
|
let (anchors, baseline_id) = if !req.anchors.is_empty() {
|
|
(req.anchors.clone(), req.baseline_id.clone())
|
|
} else {
|
|
match st.enroll.read().await.get(&req.room_id) {
|
|
Some(re) if !re.anchors.is_empty() => (re.anchors.clone(), re.baseline_id.clone()),
|
|
_ => {
|
|
return (StatusCode::BAD_REQUEST, Json(serde_json::json!({"error":"no anchors in request and none enrolled for this room"}))).into_response();
|
|
}
|
|
}
|
|
};
|
|
let geometry = if !req.geometry.is_empty() {
|
|
req.geometry.clone()
|
|
} else {
|
|
st.enroll.read().await.get(&req.room_id).map(|re| re.geometry.clone()).unwrap_or_default()
|
|
};
|
|
let at = (unix_ms() / 1000) as i64;
|
|
let bank = match SpecialistBank::train(&req.room_id, &baseline_id, &anchors, at) {
|
|
Ok(b) => b,
|
|
Err(e) => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": format!("training failed: {e}")}))).into_response(),
|
|
};
|
|
let bank = if geometry.is_empty() {
|
|
eprintln!(
|
|
"[calibrate-serve] no transceiver geometry recorded for room '{}' — bank will not support geometry conditioning (ADR-152 §2.1.2)",
|
|
req.room_id
|
|
);
|
|
bank
|
|
} else {
|
|
bank.with_geometry(geometry)
|
|
};
|
|
let name = sanitize_room_id(&req.room_id);
|
|
let dir = { st.status.read().await.output_dir.clone() };
|
|
let path = format!("{dir}/{name}.json");
|
|
let json = match bank.to_json() {
|
|
Ok(j) => j,
|
|
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": format!("serialize: {e}")}))).into_response(),
|
|
};
|
|
if let Err(e) = tokio::fs::write(&path, json).await {
|
|
return (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": format!("cannot write {path}: {e}")}))).into_response();
|
|
}
|
|
let kinds: Vec<String> = bank.trained_kinds().iter().map(|k| format!("{k:?}")).collect();
|
|
(StatusCode::OK, Json(serde_json::json!({
|
|
"room_id": bank.room_id,
|
|
"bank": name, // pass as ?bank=<name> to /room/state
|
|
"anchor_count": bank.anchor_count,
|
|
"specialists": kinds,
|
|
"geometry_nodes": bank.geometry.len(),
|
|
"path": path,
|
|
}))).into_response()
|
|
}
|
|
|
|
/// Body for `POST /api/v1/enroll/geometry`.
#[derive(Deserialize)]
struct EnrollGeometryBody {
    /// Room whose in-server enrollment receives the geometry. Used as-is
    /// (not sanitized) as the enrollment map key.
    room_id: String,
    /// Per-node transceiver geometry records (ADR-152 §2.1.1).
    /// Must be non-empty; the handler rejects an empty array with `400`.
    geometry: Vec<NodeGeometry>,
}
|
|
|
|
/// Record the room's transceiver geometry (ADR-152 §2.1.1) into the in-server
|
|
/// enrollment; the next `POST /room/train` snapshots it into the bank. A later
|
|
/// POST supersedes an earlier one (latest wins), mirroring
|
|
/// `EnrollmentSession::record_geometry`.
|
|
async fn enroll_geometry(State(st): State<ApiState>, Json(b): Json<EnrollGeometryBody>) -> impl IntoResponse {
|
|
if b.geometry.is_empty() {
|
|
return (StatusCode::BAD_REQUEST, Json(serde_json::json!({"error":"geometry must be a non-empty array of NodeGeometry records"}))).into_response();
|
|
}
|
|
let nodes = b.geometry.len();
|
|
{
|
|
let mut map = st.enroll.write().await;
|
|
let re = map.entry(b.room_id.clone()).or_insert_with(RoomEnroll::default);
|
|
re.geometry = b.geometry;
|
|
}
|
|
eprintln!("[calibrate-serve] enroll geometry room={} nodes={nodes}", b.room_id);
|
|
(StatusCode::OK, Json(serde_json::json!({"room_id": b.room_id, "geometry_nodes": nodes}))).into_response()
|
|
}
|
|
|
|
/// Body for `POST /api/v1/enroll/anchor`.
///
/// Example: `{"room_id":"x","baseline":"b","label":"stand_still"}`.
#[derive(Deserialize)]
struct EnrollAnchorBody {
    /// Room being enrolled; forwarded to the ingest task as-is.
    room_id: String,
    /// Baseline name (without `.bin`), resolved under `output_dir`.
    baseline: String,
    /// Anchor label (snake_case, e.g. `stand_still`). Parsed with
    /// `AnchorLabel::from_str`; an unknown label is rejected with `400`.
    label: String,
    /// Capture duration (s); defaults to the anchor's recommended length
    /// (`AnchorLabel::duration_s`) when omitted.
    duration_s: Option<u32>,
}
|
|
|
|
/// Capture one guided anchor against a baseline. Blocks for the capture
|
|
/// duration, then returns the gate verdict (accept/reject + progress).
|
|
async fn enroll_anchor(State(st): State<ApiState>, Json(b): Json<EnrollAnchorBody>) -> impl IntoResponse {
|
|
let label = match AnchorLabel::from_str(&b.label) {
|
|
Some(l) => l,
|
|
None => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": format!("unknown anchor label {:?}", b.label)}))).into_response(),
|
|
};
|
|
let duration_s = b.duration_s.unwrap_or_else(|| label.duration_s());
|
|
let (tx, rx) = oneshot::channel();
|
|
let cmd = CalCommand::EnrollAnchor {
|
|
room_id: b.room_id,
|
|
baseline_name: b.baseline,
|
|
label,
|
|
duration_s,
|
|
reply: tx,
|
|
};
|
|
if st.cmd_tx.send(cmd).await.is_err() {
|
|
return (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":"ingest task unavailable"}))).into_response();
|
|
}
|
|
match rx.await {
|
|
Ok(Ok(v)) => (StatusCode::OK, Json(serde_json::to_value(v).unwrap())).into_response(),
|
|
Ok(Err(e)) => (StatusCode::CONFLICT, Json(serde_json::json!({"error": e}))).into_response(),
|
|
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error":"no reply"}))).into_response(),
|
|
}
|
|
}
|
|
|
|
/// Query for `GET /api/v1/enroll/status`, e.g. `?room=x`.
#[derive(Deserialize)]
struct EnrollStatusQuery {
    /// Room id, matched verbatim against the enrollment-map keys.
    room: String,
}
|
|
|
|
/// Enrollment progress for a room.
|
|
async fn enroll_status(State(st): State<ApiState>, Query(q): Query<EnrollStatusQuery>) -> impl IntoResponse {
|
|
let map = st.enroll.read().await;
|
|
let (accepted, baseline_id): (Vec<String>, String) = match map.get(&q.room) {
|
|
Some(re) => (
|
|
re.anchors.iter().map(|a| a.label.as_str().to_string()).collect(),
|
|
re.baseline_id.clone(),
|
|
),
|
|
None => (Vec::new(), String::new()),
|
|
};
|
|
let next = AnchorLabel::SEQUENCE
|
|
.iter()
|
|
.copied()
|
|
.find(|l| !accepted.iter().any(|a| a == l.as_str()))
|
|
.map(|l| l.as_str().to_string());
|
|
Json(serde_json::json!({
|
|
"room": q.room,
|
|
"baseline_id": baseline_id,
|
|
"accepted": accepted,
|
|
"count": accepted.len(),
|
|
"total": AnchorLabel::SEQUENCE.len(),
|
|
"next": next,
|
|
"complete": next.is_none() && !accepted.is_empty(),
|
|
}))
|
|
}
|
|
|
|
/// Query for `GET /api/v1/room/state`, e.g. `?bank=t`.
#[derive(Deserialize)]
struct RoomStateQuery {
    /// Bank name (sanitized; resolved as `<output_dir>/<bank>.json`).
    bank: String,
    /// Sample rate override (Hz). When absent, the server's configured rate
    /// (`ApiState::fs_hz`) is used.
    fs: Option<f32>,
}
|
|
|
|
/// Live mixture-of-specialists readout over the current CSI window.
|
|
async fn room_state(State(st): State<ApiState>, Query(q): Query<RoomStateQuery>) -> impl IntoResponse {
|
|
// Resolve the bank as a sanitized name under output_dir — no arbitrary file read.
|
|
let name = sanitize_room_id(&q.bank);
|
|
let dir = { st.status.read().await.output_dir.clone() };
|
|
let path = format!("{dir}/{name}.json");
|
|
let raw = match tokio::fs::read_to_string(&path).await {
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
return (StatusCode::NOT_FOUND, Json(serde_json::json!({"error": format!("bank '{name}' not found: {e}")}))).into_response();
|
|
}
|
|
};
|
|
let bank = match SpecialistBank::from_json(&raw) {
|
|
Ok(b) => b,
|
|
Err(e) => return (StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": format!("invalid bank: {e}")}))).into_response(),
|
|
};
|
|
|
|
let series: Vec<f32> = { st.window.read().await.iter().copied().collect() };
|
|
if series.len() < 32 {
|
|
return (StatusCode::OK, Json(serde_json::json!({"state":"warming_up","frames":series.len()}))).into_response();
|
|
}
|
|
let fs = q.fs.unwrap_or(st.fs_hz);
|
|
let features = Features::from_series(&series, fs);
|
|
let baseline_id = bank.baseline_id.clone();
|
|
let mix = MixtureOfSpecialists::new(bank);
|
|
let room = mix.infer(&features, &baseline_id);
|
|
(StatusCode::OK, Json(serde_json::to_value(room).unwrap())).into_response()
|
|
}
|
|
|
|
async fn baselines(State(st): State<ApiState>) -> impl IntoResponse {
|
|
let dir = { st.status.read().await.output_dir.clone() };
|
|
let mut out = Vec::new();
|
|
if let Ok(rd) = std::fs::read_dir(&dir) {
|
|
for entry in rd.flatten() {
|
|
let path = entry.path();
|
|
if path.extension().and_then(|e| e.to_str()) == Some("bin") {
|
|
let bytes = entry.metadata().map(|m| m.len()).unwrap_or(0);
|
|
out.push(serde_json::json!({
|
|
"file": path.file_name().and_then(|n| n.to_str()).unwrap_or(""),
|
|
"path": path.to_string_lossy(),
|
|
"bytes": bytes,
|
|
}));
|
|
}
|
|
}
|
|
}
|
|
Json(serde_json::json!({ "dir": dir, "baselines": out }))
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Helpers
|
|
// ---------------------------------------------------------------------------
|
|
|
|
fn session_snapshot(sess: &ActiveSession, state: &str, note: Option<String>) -> SessionStatus {
|
|
let frames = sess.recorder.frames_recorded() as usize;
|
|
let progress = if sess.target_frames == 0 {
|
|
0.0
|
|
} else {
|
|
(frames as f32 / sess.target_frames as f32).clamp(0.0, 1.0)
|
|
};
|
|
let elapsed = sess.started.elapsed().as_secs_f32();
|
|
let eta = if frames == 0 {
|
|
sess.deadline.saturating_duration_since(Instant::now()).as_secs_f32()
|
|
} else {
|
|
let per = elapsed / frames as f32;
|
|
(per * (sess.target_frames.saturating_sub(frames)) as f32).max(0.0)
|
|
};
|
|
SessionStatus {
|
|
state: state.into(),
|
|
room_id: sess.room_id.clone(),
|
|
tier: sess.tier.clone(),
|
|
frames_recorded: frames,
|
|
target_frames: sess.target_frames,
|
|
progress,
|
|
z_median: sess.z_median,
|
|
z_max: sess.z_max,
|
|
motion_flagged: sess.motion_flagged,
|
|
elapsed_s: elapsed,
|
|
eta_s: eta,
|
|
note,
|
|
}
|
|
}
|
|
|
|
fn baseline_averages(b: &BaselineCalibration) -> (f32, f32, f32) {
|
|
let n = b.subcarriers.len().max(1) as f32;
|
|
let mut amp = 0.0f32;
|
|
let mut var = 0.0f32;
|
|
let mut disp = 0.0f32;
|
|
for s in &b.subcarriers {
|
|
amp += s.amp_mean;
|
|
var += s.amp_variance;
|
|
disp += s.phase_dispersion;
|
|
}
|
|
(amp / n, var / n, disp / n)
|
|
}
|
|
|
|
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before 1970-01-01.
fn unix_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Tests
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Unit tests for request/arg parsing and name sanitization, plus HTTP
/// integration tests that drive the router in-process (no sockets, no ingest task).
#[cfg(test)]
mod tests {
    use super::*;

    /// `StartParams::default()` carries the documented defaults.
    #[test]
    fn start_params_defaults() {
        let p = StartParams::default();
        assert_eq!(p.duration_s, 30);
        assert_eq!(p.min_frames, 0);
        assert!(p.tier.is_none());
    }

    /// A partial JSON body deserializes, and the omitted fields take their defaults.
    #[test]
    fn start_params_partial_json() {
        let p: StartParams = serde_json::from_str(r#"{"room_id":"living-room","tier":"he20"}"#).unwrap();
        assert_eq!(p.room_id.as_deref(), Some("living-room"));
        assert_eq!(p.tier.as_deref(), Some("he20"));
        assert_eq!(p.duration_s, 30); // default applied
    }

    // NOTE(review): this builds the args struct literally, so it does not
    // exercise clap's default values; it only checks the field set compiles.
    #[test]
    fn args_defaults() {
        let a = CalibrateServeArgs {
            http_port: 8090,
            http_bind: "127.0.0.1".into(),
            udp_port: 5005,
            udp_bind: "0.0.0.0".into(),
            tier: "ht20".into(),
            output_dir: "./baselines".into(),
            token: None,
        };
        assert_eq!(a.http_port, 8090);
        assert_eq!(a.udp_port, 5005);
    }

    /// `sanitize_room_id` strips path separators and dots, and maps "" to "default".
    #[test]
    fn sanitize_blocks_path_traversal() {
        assert_eq!(sanitize_room_id("../../etc/passwd"), "etcpasswd");
        assert_eq!(sanitize_room_id("/abs/path"), "abspath");
        assert_eq!(sanitize_room_id("living-room_1"), "living-room_1");
        assert_eq!(sanitize_room_id(""), "default");
        assert_eq!(sanitize_room_id("..\\..\\win"), "win");
        assert!(!sanitize_room_id("a/b/c").contains('/'));
    }

    // ---- HTTP integration tests (router via tower oneshot, no network/ingest) ----

    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt; // for `oneshot`

    /// Build an `ApiState` rooted at `dir` with an empty live window, no
    /// enrollments and a 15 Hz sample rate.
    ///
    /// The command receiver is dropped. Any handler that sends on `cmd_tx`
    /// therefore fails the send (e.g. `enroll_anchor` answers 500 "ingest
    /// task unavailable").
    fn test_state(dir: &str) -> ApiState {
        let (cmd_tx, _rx) = mpsc::channel::<CalCommand>(8);
        let status = Arc::new(RwLock::new(SharedStatus {
            output_dir: dir.to_string(),
            ..Default::default()
        }));
        let window = Arc::new(RwLock::new(VecDeque::<f32>::new()));
        let enroll = Arc::new(RwLock::new(HashMap::<String, RoomEnroll>::new()));
        // Tested handlers never use cmd_tx; drop the receiver.
        drop(_rx);
        ApiState { cmd_tx, status, window, fs_hz: 15.0, enroll }
    }

    /// Send one request (JSON content-type, optional body) through `app` and
    /// return only the response status.
    async fn req(app: Router, method: &str, uri: &str, body: Option<&str>) -> StatusCode {
        let b = body.map(|s| Body::from(s.to_string())).unwrap_or_else(Body::empty);
        let r = Request::builder()
            .method(method)
            .uri(uri)
            .header("content-type", "application/json")
            .body(b)
            .unwrap();
        app.oneshot(r).await.unwrap().status()
    }

    /// The discovery descriptor and the health endpoint both answer 200.
    #[tokio::test]
    async fn health_and_descriptor_ok() {
        let dir = tempfile::tempdir().unwrap();
        let app = build_router(test_state(dir.path().to_str().unwrap()));
        assert_eq!(req(app.clone(), "GET", "/", None).await, StatusCode::OK);
        assert_eq!(req(app, "GET", "/api/v1/calibration/health", None).await, StatusCode::OK);
    }

    /// End-to-end flow: train a bank, read state from it, confirm traversal
    /// names are confined to `output_dir`, and confirm an empty train is rejected.
    /// The steps are order-dependent (state needs the bank written by train).
    #[tokio::test]
    async fn train_then_state_and_traversal_defense() {
        let dir = tempfile::tempdir().unwrap();
        let state = test_state(dir.path().to_str().unwrap());
        // Fill the live window with a 0.3 Hz breathing sine.
        // The `/ 15.0` matches the 15 Hz `fs_hz` set by `test_state`.
        {
            let mut w = state.window.write().await;
            for i in 0..200 {
                w.push_back((2.0 * std::f32::consts::PI * 0.3 * i as f32 / 15.0).sin());
            }
        }
        let app = build_router(state);

        // POST /room/train with two anchors → bank persisted as t.json.
        let body = r#"{"room_id":"t","baseline_id":"b","anchors":[
            {"room_id":"t","label":"empty","features":{"mean":1.0,"variance":1.0,"motion":0.1,"breathing_score":0.0,"breathing_hz":0.0,"heart_score":0.0,"heart_hz":0.0}},
            {"room_id":"t","label":"stand_still","features":{"mean":1.0,"variance":10.0,"motion":0.2,"breathing_score":0.0,"breathing_hz":0.0,"heart_score":0.0,"heart_hz":0.0}}
        ]}"#;
        assert_eq!(req(app.clone(), "POST", "/api/v1/room/train", Some(body)).await, StatusCode::OK);
        assert!(dir.path().join("t.json").exists(), "bank file written");

        // GET /room/state?bank=t → 200 (trained bank loaded, window present).
        assert_eq!(req(app.clone(), "GET", "/api/v1/room/state?bank=t", None).await, StatusCode::OK);

        // Path-traversal: ?bank=../../etc/passwd is sanitized → NOT_FOUND, never reads outside dir.
        assert_eq!(
            req(app.clone(), "GET", "/api/v1/room/state?bank=../../etc/passwd", None).await,
            StatusCode::NOT_FOUND
        );

        // Train with no anchors and nothing enrolled → 400.
        assert_eq!(
            req(app, "POST", "/api/v1/room/train", Some(r#"{"room_id":"none","baseline_id":"b","anchors":[]}"#)).await,
            StatusCode::BAD_REQUEST
        );
    }

    /// ADR-152 §2.1.1: geometry threads into the trained bank through both API
    /// paths — inline in the train request, or recorded via /enroll/geometry —
    /// and a geometry-free train still produces a valid (unconditioned) bank.
    #[tokio::test]
    async fn train_threads_geometry_into_bank() {
        let dir = tempfile::tempdir().unwrap();
        let app = build_router(test_state(dir.path().to_str().unwrap()));
        // Shared anchor fixture, spliced into each train body via `format!`.
        let anchors = r#"[
            {"room_id":"g","label":"empty","features":{"mean":1.0,"variance":1.0,"motion":0.1,"breathing_score":0.0,"breathing_hz":0.0,"heart_score":0.0,"heart_hz":0.0}},
            {"room_id":"g","label":"stand_still","features":{"mean":1.0,"variance":10.0,"motion":0.2,"breathing_score":0.0,"breathing_hz":0.0,"heart_score":0.0,"heart_hz":0.0}}
        ]"#;
        // Read back `<dir>/<name>.json` as written by /room/train.
        let load_bank = |name: &str| {
            let raw = std::fs::read_to_string(dir.path().join(format!("{name}.json"))).unwrap();
            SpecialistBank::from_json(&raw).unwrap()
        };

        // (1) geometry inline in the train request.
        let body = format!(
            r#"{{"room_id":"g1","baseline_id":"b","anchors":{anchors},
            "geometry":[{{"node_id":1,"position":{{"x_m":0.0,"y_m":0.0,"z_m":1.0}},"method":"tape-measure"}},{{"node_id":2}}]}}"#
        );
        assert_eq!(req(app.clone(), "POST", "/api/v1/room/train", Some(&body)).await, StatusCode::OK);
        let bank = load_bank("g1");
        assert_eq!(bank.geometry.len(), 2);
        assert_eq!(bank.geometry[0].method, "tape-measure");
        assert_eq!(bank.geometry[1].node_id, 2);

        // (2) geometry recorded via /enroll/geometry; train body omits it.
        assert_eq!(
            req(app.clone(), "POST", "/api/v1/enroll/geometry",
                Some(r#"{"room_id":"g2","geometry":[{"node_id":7,"method":"floor-plan"}]}"#)).await,
            StatusCode::OK
        );
        let body2 = format!(r#"{{"room_id":"g2","baseline_id":"b","anchors":{anchors}}}"#);
        assert_eq!(req(app.clone(), "POST", "/api/v1/room/train", Some(&body2)).await, StatusCode::OK);
        let bank2 = load_bank("g2");
        assert_eq!(bank2.geometry.len(), 1);
        assert_eq!(bank2.geometry[0].node_id, 7);

        // (3) no geometry anywhere → valid geometry-free bank (note logged).
        let body3 = format!(r#"{{"room_id":"g3","baseline_id":"b","anchors":{anchors}}}"#);
        assert_eq!(req(app.clone(), "POST", "/api/v1/room/train", Some(&body3)).await, StatusCode::OK);
        let bank3 = load_bank("g3");
        assert!(bank3.geometry.is_empty());
        assert!(bank3.presence.is_some(), "bank still trains without geometry");

        // (4) empty geometry array is rejected.
        assert_eq!(
            req(app, "POST", "/api/v1/enroll/geometry", Some(r#"{"room_id":"g4","geometry":[]}"#)).await,
            StatusCode::BAD_REQUEST
        );
    }

    /// Enrollment status for an unknown room, and label validation on /enroll/anchor.
    #[tokio::test]
    async fn enroll_status_empty_and_bad_label() {
        let dir = tempfile::tempdir().unwrap();
        let app = build_router(test_state(dir.path().to_str().unwrap()));
        // No enrollment yet → 200 with next=empty.
        assert_eq!(req(app.clone(), "GET", "/api/v1/enroll/status?room=x", None).await, StatusCode::OK);
        // Unknown anchor label → 400.
        // The label is rejected before `cmd_tx` is used, so the dropped
        // receiver in `test_state` does not affect this case.
        assert_eq!(
            req(app, "POST", "/api/v1/enroll/anchor", Some(r#"{"room_id":"x","baseline":"b","label":"nope"}"#)).await,
            StatusCode::BAD_REQUEST
        );
    }
}
|