From 51fd326aaefe3791198071ad07d1ce0b4c1fa627 Mon Sep 17 00:00:00 2001 From: lockewerks <59770696+lockewerks@users.noreply.github.com> Date: Mon, 25 May 2026 16:51:05 -0600 Subject: [PATCH] test(cog-pose): verify weights load and produce valid 17-keypoint output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Strengthen `real_weights_load_when_available` in the smoke suite: - search for `pose_v1.safetensors` under the three plausible test cwds (crate dir, `v2/`, repo root) so the test exercises the real loader regardless of where `cargo test` is invoked from - after inference, chunk the flat 34-scalar output into 17 (x, y) pairs and assert each pair is finite, non-NaN, and inside the sigmoid range [0, 1] - keep the skip path with a clear message when the safetensors blob isn't on disk (so the suite still passes on a fresh appliance install where weights ship separately) This complements the inline unit tests in `src/inference.rs` — those exercise the loader / forward-pass plumbing, the smoke test exercises the public crate surface. --- v2/crates/cog-pose-estimation/tests/smoke.rs | 39 ++++++++++++++++++-- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/v2/crates/cog-pose-estimation/tests/smoke.rs b/v2/crates/cog-pose-estimation/tests/smoke.rs index f44cf9d3..e27ebce1 100644 --- a/v2/crates/cog-pose-estimation/tests/smoke.rs +++ b/v2/crates/cog-pose-estimation/tests/smoke.rs @@ -40,12 +40,22 @@ fn engine_rejects_wrong_shape_input() { #[test] fn real_weights_load_when_available() { use cog_pose_estimation::inference::InferenceEngine; - let weights = std::path::Path::new("cog/artifacts/pose_v1.safetensors"); - if !weights.exists() { + // Search the cwds a `cargo test` invocation can land in + // (workspace root, `v2/`, or the crate dir). + let candidates = [ + "cog/artifacts/pose_v1.safetensors", + "crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors", + "v2/crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors", + ]; + let Some(weights) = candidates + .iter() + .map(std::path::Path::new) + .find(|p| p.exists()) + else { // Skip when running outside the repo (e.g. on a fresh appliance install). - eprintln!("(skipping — cog/artifacts/pose_v1.safetensors not present in cwd)"); + eprintln!("(skipping — pose_v1.safetensors not present under any known cwd)"); return; - } + }; let engine = InferenceEngine::with_weights(Some(weights)).expect("load real weights"); assert!( engine.backend().starts_with("candle-"), @@ -54,6 +64,27 @@ fn real_weights_load_when_available() { ); let out = engine.infer(&SyntheticInput.as_window()).expect("infer"); assert!(out.is_finite()); + + // Chunked into 17 (x, y) pairs — every coordinate must be finite and + // sit inside the sigmoid output range. This is the "real wiring" + // assertion the README points at. + let pairs: Vec<[f32; 2]> = out + .keypoints + .chunks_exact(2) + .map(|c| [c[0], c[1]]) + .collect(); + assert_eq!(pairs.len(), OUTPUT_KEYPOINTS); + for (i, [x, y]) in pairs.iter().enumerate() { + assert!( + x.is_finite() && y.is_finite() && !x.is_nan() && !y.is_nan(), + "keypoint {i} bad: ({x}, {y})" + ); + assert!( + (0.0..=1.0).contains(x) && (0.0..=1.0).contains(y), + "keypoint {i} out of [0, 1]: ({x}, {y})" + ); + } + // Real model emits the published validation PCK@50 as its self-reported // confidence — stub returns 0.0. This is the key assertion that proves // the cog isn't silently falling back to the stub.