diff --git a/v2/crates/cog-pose-estimation/src/inference.rs b/v2/crates/cog-pose-estimation/src/inference.rs index 2e1623ed..6f491248 100644 --- a/v2/crates/cog-pose-estimation/src/inference.rs +++ b/v2/crates/cog-pose-estimation/src/inference.rs @@ -46,6 +46,40 @@ impl PoseOutput { } } +/// Per-room LoRA calibration adapter (ADR-150 §3.5–3.6). Low-rank deltas on the pose +/// head: `delta = (x · A) · B`, with `A:[in,r]`, `B:[r,out]` (scale baked into `B` at +/// save time). A handful of labeled in-room samples fit this ~few-KB adapter and recover +/// SOTA-level pose for an unseen room/person, on top of the frozen shared base. +/// Adapter safetensors keys: `fc1.a`, `fc1.b`, `fc2.a`, `fc2.b` (any subset). +#[derive(Clone)] +struct PoseLora { + fc1: Option<(Tensor, Tensor)>, + fc2: Option<(Tensor, Tensor)>, +} + +impl PoseLora { + /// Load from an adapter safetensors. Missing layer keys are simply skipped. + fn load(path: &Path, device: &Device) -> candle_core::Result { + let t = candle_core::safetensors::load(path, device)?; + let pair = |a: &str, b: &str| match (t.get(a), t.get(b)) { + (Some(x), Some(y)) => Some((x.clone(), y.clone())), + _ => None, + }; + Ok(Self { + fc1: pair("fc1.a", "fc1.b"), + fc2: pair("fc2.a", "fc2.b"), + }) + } + + /// `y + (x · A) · B` when an adapter for this layer is present, else `y` unchanged. + fn apply(slot: &Option<(Tensor, Tensor)>, x: &Tensor, y: Tensor) -> candle_core::Result { + match slot { + Some((a, b)) => y + x.matmul(a)?.matmul(b)?, + None => Ok(y), + } + } +} + /// Internal model — mirrors the training script's `PoseModel` exactly. struct PoseNet { c1: Conv1d, @@ -53,6 +87,8 @@ struct PoseNet { c3: Conv1d, fc1: Linear, fc2: Linear, + /// Optional per-room calibration adapter (none = shared base behaviour). + adapter: Option, } impl PoseNet { @@ -108,20 +144,31 @@ impl PoseNet { c3, fc1, fc2, + adapter: None, }) } - /// Forward pass: `[B, 56, 20]` -> `[B, 34]` in `[0, 1]`. 
+ /// Forward pass: `[B, 56, 20]` -> `[B, 34]` in `[0, 1]`. Applies the per-room + /// LoRA calibration adapter on the head layers when one is attached. fn forward(&self, x: &Tensor) -> candle_core::Result { let h = self.c1.forward(x)?.relu()?; let h = self.c2.forward(&h)?.relu()?; let h = self.c3.forward(&h)?.relu()?; // Global average pool over time dim (last dim) -> [B, 128] - let h = h.mean(2)?; - let h = self.fc1.forward(&h)?.relu()?; - let h = self.fc2.forward(&h)?; + let pooled = h.mean(2)?; + // fc1 (+ adapter delta) -> ReLU + let mut h1 = self.fc1.forward(&pooled)?; + if let Some(ad) = &self.adapter { + h1 = PoseLora::apply(&ad.fc1, &pooled, h1)?; + } + let h1 = h1.relu()?; + // fc2 (+ adapter delta) + let mut h2 = self.fc2.forward(&h1)?; + if let Some(ad) = &self.adapter { + h2 = PoseLora::apply(&ad.fc2, &h1, h2)?; + } // sigmoid -> keep in [0, 1] - candle_nn::ops::sigmoid(&h) + candle_nn::ops::sigmoid(&h2) } } @@ -148,6 +195,17 @@ impl InferenceEngine { /// in `cog-pose-estimation run`). If `weights_path` is `None`, the /// stub fallback is used. pub fn with_weights(weights_path: Option<&Path>) -> Result> { + Self::with_weights_and_adapter(weights_path, None) + } + + /// Create an engine with a shared base **and an optional per-room calibration + /// adapter** (ADR-150 §3.5). The adapter is a tiny LoRA safetensors fitted from a + /// short labeled in-room capture (`aether-arena/calibration/calibrate.py`); attaching + /// it recovers SOTA-level pose in an unseen room/person. `None` = uncalibrated base. + pub fn with_weights_and_adapter( + weights_path: Option<&Path>, + adapter_path: Option<&Path>, + ) -> Result> { let device = pick_device(); let inner = match weights_path { Some(p) if p.exists() => { @@ -158,7 +216,12 @@ impl InferenceEngine { let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[p.to_path_buf()], DType::F32, &device)? 
#[test]
fn per_room_adapter_changes_inference_output() {
    // Build a minimal valid base + a non-trivial LoRA adapter in a scratch dir, then verify
    // the calibration adapter (ADR-150 §3.5) is detected and actually alters the output.
    use candle_core::{DType, Device, Tensor};
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Deletes the scratch directory when dropped, so a failing assertion does not
    /// leak it. Declared before the engines: locals drop in reverse declaration
    /// order, so the engines (and the memory-mapped base file they hold) are
    /// released before the directory is removed.
    struct ScratchDir(PathBuf);
    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp dir must not fail the test.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let dev = Device::Cpu;
    let scratch = ScratchDir(
        std::env::temp_dir().join(format!("cogpose_adapter_test_{}", std::process::id())),
    );
    std::fs::create_dir_all(&scratch.0).unwrap();
    let base_p = scratch.0.join("base.safetensors");
    let adapter_p = scratch.0.join("room.adapter.safetensors");

    // --- base weights (random but finite) matching PoseNet's VarBuilder keys ---
    let mut w: HashMap<String, Tensor> = HashMap::new();
    let mut put = |k: &str, t: Tensor| {
        w.insert(k.to_string(), t);
    };
    put("enc.c1.weight", Tensor::randn(0f32, 0.1, (64, 56, 3), &dev).unwrap());
    put("enc.c1.bias", Tensor::zeros(64, DType::F32, &dev).unwrap());
    put("enc.c2.weight", Tensor::randn(0f32, 0.1, (128, 64, 3), &dev).unwrap());
    put("enc.c2.bias", Tensor::zeros(128, DType::F32, &dev).unwrap());
    put("enc.c3.weight", Tensor::randn(0f32, 0.1, (128, 128, 3), &dev).unwrap());
    put("enc.c3.bias", Tensor::zeros(128, DType::F32, &dev).unwrap());
    put("head.fc1.weight", Tensor::randn(0f32, 0.1, (256, 128), &dev).unwrap());
    put("head.fc1.bias", Tensor::zeros(256, DType::F32, &dev).unwrap());
    put("head.fc2.weight", Tensor::randn(0f32, 0.1, (34, 256), &dev).unwrap());
    put("head.fc2.bias", Tensor::zeros(34, DType::F32, &dev).unwrap());
    candle_core::safetensors::save(&w, &base_p).unwrap();

    // --- adapter: non-zero low-rank deltas on both head layers (scale baked into B) ---
    let r = 4usize;
    let mut ad: HashMap<String, Tensor> = HashMap::new();
    ad.insert("fc1.a".into(), Tensor::randn(0f32, 0.5, (128, r), &dev).unwrap());
    ad.insert("fc1.b".into(), Tensor::randn(0f32, 0.5, (r, 256), &dev).unwrap());
    ad.insert("fc2.a".into(), Tensor::randn(0f32, 0.5, (256, r), &dev).unwrap());
    ad.insert("fc2.b".into(), Tensor::randn(0f32, 0.5, (r, 34), &dev).unwrap());
    candle_core::safetensors::save(&ad, &adapter_p).unwrap();

    let base = InferenceEngine::with_weights(Some(&base_p)).expect("base load");
    let cal = InferenceEngine::with_weights_and_adapter(Some(&base_p), Some(&adapter_p))
        .expect("calibrated load");

    assert!(!base.is_calibrated(), "base must report uncalibrated");
    assert!(cal.is_calibrated(), "adapter engine must report calibrated");

    // Non-zero input — a zero window would zero the LoRA delta (x·A·B = 0).
    let win = cog_pose_estimation::inference::CsiWindow {
        data: (0..INPUT_SUBCARRIERS * INPUT_TIMESTEPS)
            .map(|i| ((i % 7) as f32 - 3.0) * 0.2)
            .collect(),
    };
    let a = base.infer(&win).expect("base infer");
    let b = cal.infer(&win).expect("calibrated infer");
    assert!(a.is_finite() && b.is_finite());

    // Total absolute keypoint difference between base and calibrated outputs.
    let diff: f32 = a
        .keypoints
        .iter()
        .zip(&b.keypoints)
        .map(|(x, y)| (x - y).abs())
        .sum();
    assert!(
        diff > 1e-4,
        "per-room adapter must change the output (sum|Δ| = {diff})"
    );
    // `scratch` is dropped last and removes the directory, on success or panic.
}