//! Inference engine — loads `pose_v1.safetensors` (produced by the //! Candle training run on `ruvultra`'s RTX 5080, see //! `cog/artifacts/pose_v1.safetensors` + `docs/benchmarks/pose-estimation-cog.md`) //! and runs the encoder + pose head on each CSI window. //! //! Architecture mirrors the training script exactly: //! Conv1d(56 -> 64, k=3, dilation=1, padding=1) //! Conv1d(64 -> 128, k=3, dilation=2, padding=2) //! Conv1d(128 -> 128, k=3, dilation=4, padding=4) //! mean over time -> [128] //! Linear(128 -> 256) -> ReLU //! Linear(256 -> 34) -> sigmoid -> reshape [17, 2] //! //! When the safetensors file is missing the engine falls back to a //! centred-skeleton baseline with `confidence=0` so the cog still //! satisfies the ADR-100 runtime contract and the dashboard surfaces //! "no model yet" instead of dropping frames silently. use candle_core::{DType, Device, Tensor}; use candle_nn::{Conv1d, Conv1dConfig, Linear, Module, VarBuilder}; use std::path::Path; use std::sync::Arc; /// 56 subcarriers × 20 frames per CSI window — matches the format /// produced by `scripts/align-ground-truth.js` after #641. pub const INPUT_SUBCARRIERS: usize = 56; pub const INPUT_TIMESTEPS: usize = 20; pub const OUTPUT_KEYPOINTS: usize = 17; #[derive(Debug, Clone)] pub struct CsiWindow { pub data: Vec, // length INPUT_SUBCARRIERS * INPUT_TIMESTEPS } #[derive(Debug, Clone)] pub struct PoseOutput { /// Flat `[OUTPUT_KEYPOINTS * 2]` keypoints in `[0, 1]` normalised /// image coords, ordered (x0, y0, x1, y1, …). pub keypoints: Vec, pub confidence: f32, } impl PoseOutput { pub fn is_finite(&self) -> bool { self.keypoints.iter().all(|v| v.is_finite()) && self.confidence.is_finite() } } /// Internal model — mirrors the training script's `PoseModel` exactly. struct PoseNet { c1: Conv1d, c2: Conv1d, c3: Conv1d, fc1: Linear, fc2: Linear, } impl PoseNet { fn new(vb: VarBuilder<'_>) -> candle_core::Result { let enc = vb.pp("enc"); let head = vb.pp("head"); let c1 = candle_nn::conv1d( 56, 64, 3, Conv1dConfig { padding: 1, stride: 1, dilation: 1, groups: 1, ..Default::default() }, enc.pp("c1"), )?; let c2 = candle_nn::conv1d( 64, 128, 3, Conv1dConfig { padding: 2, stride: 1, dilation: 2, groups: 1, ..Default::default() }, enc.pp("c2"), )?; let c3 = candle_nn::conv1d( 128, 128, 3, Conv1dConfig { padding: 4, stride: 1, dilation: 4, groups: 1, ..Default::default() }, enc.pp("c3"), )?; let fc1 = candle_nn::linear(128, 256, head.pp("fc1"))?; let fc2 = candle_nn::linear(256, 34, head.pp("fc2"))?; Ok(Self { c1, c2, c3, fc1, fc2, }) } /// Forward pass: `[B, 56, 20]` -> `[B, 34]` in `[0, 1]`. fn forward(&self, x: &Tensor) -> candle_core::Result { let h = self.c1.forward(x)?.relu()?; let h = self.c2.forward(&h)?.relu()?; let h = self.c3.forward(&h)?.relu()?; // Global average pool over time dim (last dim) -> [B, 128] let h = h.mean(2)?; let h = self.fc1.forward(&h)?.relu()?; let h = self.fc2.forward(&h)?; // sigmoid -> keep in [0, 1] candle_nn::ops::sigmoid(&h) } } pub struct InferenceEngine { inner: Option>, device: Device, } struct LoadedModel { net: PoseNet, } impl InferenceEngine { /// Create an engine. Tries to load weights from `cog/artifacts/pose_v1.safetensors` /// (relative to current dir or the cog install dir under /// `/var/lib/cognitum/apps/pose-estimation/`). Returns a usable /// engine either way — without weights, `infer` produces the /// stub output. pub fn new() -> Result> { Self::with_weights(default_weights_path().as_deref()) } /// Create an engine with a specific weights path (used by `--config` /// in `cog-pose-estimation run`). If `weights_path` is `None`, the /// stub fallback is used. pub fn with_weights(weights_path: Option<&Path>) -> Result> { let device = pick_device(); let inner = match weights_path { Some(p) if p.exists() => { // SAFETY: `from_mmaped_safetensors` mmaps the file for the // VarBuilder's lifetime. We don't modify the file while the // VarBuilder is alive, and the file is read-only on disk on // appliance installs. let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[p.to_path_buf()], DType::F32, &device)? }; let net = PoseNet::new(vb)?; Some(Arc::new(LoadedModel { net })) } _ => None, }; Ok(Self { inner, device }) } /// Where the weights actually came from. Useful for the run.started event. pub fn backend(&self) -> &'static str { match (&self.inner, &self.device) { (Some(_), Device::Cuda(_)) => "candle-cuda", (Some(_), _) => "candle-cpu", (None, _) => "stub", } } pub fn infer(&self, window: &CsiWindow) -> Result> { if window.data.len() != INPUT_SUBCARRIERS * INPUT_TIMESTEPS { return Err(format!( "expected {} input values, got {}", INPUT_SUBCARRIERS * INPUT_TIMESTEPS, window.data.len() ) .into()); } let Some(model) = &self.inner else { // Stub fallback — model not loaded. return Ok(PoseOutput { keypoints: vec![0.5f32; OUTPUT_KEYPOINTS * 2], confidence: 0.0, }); }; // Build [1, 56, 20] tensor from the flat row-major buffer. let t = Tensor::from_slice( &window.data, (1, INPUT_SUBCARRIERS, INPUT_TIMESTEPS), &self.device, )?; let out = model.net.forward(&t)?; // [1, 34] let flat: Vec = out.flatten_all()?.to_vec1()?; // Confidence from pose_v1 is a published constant rather than per-frame — // the trained model didn't emit a confidence head. Use the validation-set // PCK@50 (18.5%) as the published self-reported confidence so downstream // consumers can gate display decisions on it. Ok(PoseOutput { keypoints: flat, confidence: 0.185, }) } } /// Synthetic CSI window for the `health` subcommand. Zeros — exercises /// the I/O surface; the model never touches values that produce NaN. pub struct SyntheticInput; impl Default for SyntheticInput { fn default() -> Self { Self } } impl SyntheticInput { pub fn as_window(&self) -> CsiWindow { CsiWindow { data: vec![0.0; INPUT_SUBCARRIERS * INPUT_TIMESTEPS], } } } // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- fn pick_device() -> Device { #[cfg(feature = "cuda")] if let Ok(d) = Device::cuda_if_available(0) { return d; } Device::Cpu } fn default_weights_path() -> Option { // Search in the order an installed Cog would see it. let candidates = [ std::path::PathBuf::from("/var/lib/cognitum/apps/pose-estimation/pose_v1.safetensors"), std::path::PathBuf::from("./pose_v1.safetensors"), std::path::PathBuf::from("./cog/artifacts/pose_v1.safetensors"), // From the repo root. std::path::PathBuf::from("v2/crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors"), // From inside v2/. std::path::PathBuf::from("crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors"), ]; candidates.into_iter().find(|p| p.exists()) }