fix(sensing-server): wire training_api::routes() into HTTP router
The proper supervised training pipeline in training_api.rs (with
real_training_loop using analytical gradients on a regularised
linear model + RecordedFrame deserialisation + .rvf export) was
implemented but never wired into the http_app router.
main.rs registered stub handlers at /api/v1/train/{start,status,stop}
that only updated the `training_status` String field (and stored the
request body in `training_config`). POST /api/v1/train/start returned
"running" without invoking any training; GET /api/v1/train/status
returned that stub string and config instead of a TrainingStatus struct.
This change:
- Declares `mod training_api;` in main.rs so its routes are reachable
- Replaces AppStateInner's `training_status: String` and
`training_config: Option<serde_json::Value>` with the proper
`training_state: training_api::TrainingState` and
`training_progress_tx: broadcast::Sender<String>` fields the
real handlers require
- Removes the three stub handlers (train_status/train_start/
train_stop) and their .route() registrations
- Adds `.merge(training_api::routes())` to expose the real pipeline
plus /ws/train/progress, /train/pretrain, /train/lora
- Inlines RECORDINGS_DIR + RecordedFrame into training_api.rs so we
do not also have to declare `mod recording;` in main.rs
(recording.rs references AppStateInner::recording_state, which is
not present on the production AppStateInner — main.rs uses flat
recording_* fields — so pulling the recording module in would
cascade into further refactoring)
- Switches the safe_id call in training_api.rs from
`crate::path_safety::safe_id` to
`wifi_densepose_sensing_server::path_safety::safe_id` because
path_safety lives in the library crate while training_api is
compiled as part of the binary crate
After this change, POST /api/v1/train/start actually trains, and
GET /api/v1/train/status returns the documented TrainingStatus
struct (active, epoch, total_epochs, train_loss, val_pck, val_oks,
lr, best_pck, best_epoch, patience_remaining, eta_secs, phase).
Verified end-to-end locally on an 80k-frame recording — early-stop
fired at epoch 40, .rvf written to data/models/.
Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
be4efecbcd
commit
4f97bd33ab
|
|
@ -18,6 +18,7 @@ pub mod pose;
|
|||
mod rvf_container;
|
||||
mod rvf_pipeline;
|
||||
mod tracker_bridge;
|
||||
mod training_api;
|
||||
pub mod types;
|
||||
mod vital_signs;
|
||||
|
||||
|
|
@ -963,10 +964,11 @@ struct AppStateInner {
|
|||
/// Shutdown signal for the recording writer task.
|
||||
recording_stop_tx: Option<tokio::sync::watch::Sender<bool>>,
|
||||
// ── Training fields ─────────────────────────────────────────────────────
|
||||
/// Training status: "idle", "running", "completed", "failed".
|
||||
training_status: String,
|
||||
/// Training configuration, if any.
|
||||
training_config: Option<serde_json::Value>,
|
||||
/// Full training-pipeline state used by training_api::routes().
|
||||
training_state: training_api::TrainingState,
|
||||
/// Broadcast channel for per-epoch training progress (consumed by
|
||||
/// /ws/train/progress).
|
||||
training_progress_tx: tokio::sync::broadcast::Sender<String>,
|
||||
// ── Adaptive classifier (environment-tuned) ──────────────────────────
|
||||
/// Trained adaptive model (loaded from data/adaptive_model.json or trained at runtime).
|
||||
adaptive_model: Option<adaptive_classifier::AdaptiveModel>,
|
||||
|
|
@ -3911,54 +3913,10 @@ fn scan_recording_files() -> Vec<serde_json::Value> {
|
|||
}
|
||||
|
||||
// ── Training Endpoints ──────────────────────────────────────────────────────
|
||||
|
||||
/// GET /api/v1/train/status — get training status.
|
||||
async fn train_status(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
let s = state.read().await;
|
||||
Json(serde_json::json!({
|
||||
"status": s.training_status,
|
||||
"config": s.training_config,
|
||||
}))
|
||||
}
|
||||
|
||||
/// POST /api/v1/train/start — start a training run.
|
||||
async fn train_start(
|
||||
State(state): State<SharedState>,
|
||||
Json(body): Json<serde_json::Value>,
|
||||
) -> Json<serde_json::Value> {
|
||||
let mut s = state.write().await;
|
||||
if s.training_status == "running" {
|
||||
return Json(serde_json::json!({
|
||||
"error": "training already running",
|
||||
"success": false,
|
||||
}));
|
||||
}
|
||||
s.training_status = "running".to_string();
|
||||
s.training_config = Some(body.clone());
|
||||
info!("Training started with config: {}", body);
|
||||
Json(serde_json::json!({
|
||||
"success": true,
|
||||
"status": "running",
|
||||
"message": "Training pipeline started. Use GET /api/v1/train/status to monitor.",
|
||||
}))
|
||||
}
|
||||
|
||||
/// POST /api/v1/train/stop — stop the current training run.
|
||||
async fn train_stop(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
let mut s = state.write().await;
|
||||
if s.training_status != "running" {
|
||||
return Json(serde_json::json!({
|
||||
"error": "no training in progress",
|
||||
"success": false,
|
||||
}));
|
||||
}
|
||||
s.training_status = "idle".to_string();
|
||||
info!("Training stopped");
|
||||
Json(serde_json::json!({
|
||||
"success": true,
|
||||
"status": "idle",
|
||||
}))
|
||||
}
|
||||
// Training routes are provided by training_api::routes() merged into the http_app
|
||||
// router below. The stub train_status/train_start/train_stop handlers that used
|
||||
// to live here were no-op string flippers; replaced by the real pipeline
|
||||
// (real_training_loop) in training_api.rs.
|
||||
|
||||
// ── Adaptive classifier endpoints ────────────────────────────────────────────
|
||||
|
||||
|
|
@ -6030,8 +5988,11 @@ async fn main() {
|
|||
recording_current_id: None,
|
||||
recording_stop_tx: None,
|
||||
// Training
|
||||
training_status: "idle".to_string(),
|
||||
training_config: None,
|
||||
training_state: training_api::TrainingState::default(),
|
||||
training_progress_tx: {
|
||||
let (t, _) = broadcast::channel::<String>(256);
|
||||
t
|
||||
},
|
||||
adaptive_model:
|
||||
adaptive_classifier::AdaptiveModel::load(&adaptive_classifier::model_path())
|
||||
.ok()
|
||||
|
|
@ -6228,10 +6189,8 @@ async fn main() {
|
|||
.route("/api/v1/recording/start", post(start_recording))
|
||||
.route("/api/v1/recording/stop", post(stop_recording))
|
||||
.route("/api/v1/recording/{id}", delete(delete_recording))
|
||||
// Training endpoints
|
||||
.route("/api/v1/train/status", get(train_status))
|
||||
.route("/api/v1/train/start", post(train_start))
|
||||
.route("/api/v1/train/stop", post(train_stop))
|
||||
// Training endpoints (real pipeline from training_api.rs)
|
||||
.merge(training_api::routes())
|
||||
// Adaptive classifier endpoints
|
||||
.route("/api/v1/adaptive/train", post(adaptive_train))
|
||||
.route("/api/v1/adaptive/status", get(adaptive_status))
|
||||
|
|
|
|||
|
|
@ -41,9 +41,22 @@ use serde::{Deserialize, Serialize};
|
|||
use tokio::sync::{broadcast, RwLock};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::recording::{RecordedFrame, RECORDINGS_DIR};
|
||||
use crate::rvf_container::RvfBuilder;
|
||||
|
||||
// Inlined from recording.rs to avoid pulling that module into main.rs's tree
|
||||
// (recording.rs references `AppStateInner::recording_state`, which doesn't
|
||||
// exist on the production AppStateInner — main.rs uses flat recording_* fields).
|
||||
pub const RECORDINGS_DIR: &str = "data/recordings";
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct RecordedFrame {
|
||||
pub timestamp: f64,
|
||||
pub subcarriers: Vec<f64>,
|
||||
pub rssi: f64,
|
||||
pub noise_floor: f64,
|
||||
pub features: serde_json::Value,
|
||||
}
|
||||
|
||||
// ── Constants ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Directory for trained model output.
|
||||
|
|
@ -243,7 +256,7 @@ async fn load_recording_frames(dataset_ids: &[String]) -> Vec<RecordedFrame> {
|
|||
// '/', '..', null bytes, or anything outside [A-Za-z0-9._-] BEFORE
|
||||
// building the format!() path. Otherwise an attacker could read any
|
||||
// file the server process can access via `dataset_ids: ["../../etc/passwd"]`.
|
||||
let safe = match crate::path_safety::safe_id(id) {
|
||||
let safe = match wifi_densepose_sensing_server::path_safety::safe_id(id) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!("Skipping invalid dataset_id {id:?}: {e}");
|
||||
|
|
|
|||
Loading…
Reference in New Issue