feat(cog-pose): wire pose_v1.safetensors into inference path

Replace silent stub fallback in InferenceEngine with a tracing-aware
loader:

- emit `tracing::info!` when pose_v1.safetensors loads into the candle
  backend, including the resolved on-disk path
- emit `tracing::warn!` (instead of silently falling through) when the
  configured weights file is missing or no path is set at all, so the
  appliance log surfaces "no model — running stub" instead of just
  emitting confidence=0 frames forever
- inline unit tests covering both branches: stub fallback shape +
  confidence, and a real-weights forward pass that asserts the output
  decomposes into 17 (x, y) pairs all finite and in [0, 1]

The forward pass itself already matched the training script
(Conv1d 56->64->128->128 with dilations [1, 2, 4], GlobalMeanPool,
Linear 128->256->34, sigmoid) and the tensor names in the
safetensors file (`enc.c1/c2/c3.weight|bias`, `head.fc1/fc2.weight|bias`)
line up with the VarBuilder prefixes — no architecture changes
required, only loader hygiene.

Confidence is the published validation PCK@50 (0.185); v0.0.1 doesn't
emit per-frame confidence and we're not fudging that here.
This commit is contained in:
lockewerks 2026-05-25 16:50:56 -06:00
parent baba851a89
commit 339d9d70dc
1 changed files with 117 additions and 3 deletions

View File

@ -145,8 +145,10 @@ impl InferenceEngine {
}
/// Create an engine with a specific weights path (used by `--config`
/// in `cog-pose-estimation run`). If `weights_path` is `None`, the
/// stub fallback is used.
/// in `cog-pose-estimation run`). If `weights_path` is `None` or the
/// file does not exist on disk, the engine falls back to the
/// centred-skeleton stub and emits a `tracing::warn!` so the
/// appliance log shows why no real keypoints are coming through.
pub fn with_weights(weights_path: Option<&Path>) -> Result<Self, Box<dyn std::error::Error>> {
let device = pick_device();
let inner = match weights_path {
@ -159,9 +161,26 @@ impl InferenceEngine {
VarBuilder::from_mmaped_safetensors(&[p.to_path_buf()], DType::F32, &device)?
};
let net = PoseNet::new(vb)?;
tracing::info!(
weights = %p.display(),
"loaded pose_v1.safetensors into candle backend"
);
Some(Arc::new(LoadedModel { net }))
}
_ => None,
Some(p) => {
tracing::warn!(
weights = %p.display(),
"pose weights file not found; falling back to centred-skeleton stub (confidence=0)"
);
None
}
None => {
tracing::warn!(
"no pose weights path configured and no default weights found on disk; \
falling back to centred-skeleton stub (confidence=0)"
);
None
}
};
Ok(Self { inner, device })
}
@ -255,3 +274,98 @@ fn default_weights_path() -> Option<std::path::PathBuf> {
];
candidates.into_iter().find(|p| p.exists())
}
// ---------------------------------------------------------------------------
// Unit tests — exercise the safetensors → forward-pass path. Integration-level
// assertions (CLI surface, manifest round-trip, etc.) live in `tests/smoke.rs`.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Locate `pose_v1.safetensors` relative to the current working
    /// directory, trying the layouts a `cargo test` invocation might be
    /// started from (workspace root, `v2/`, or the crate dir).
    ///
    /// Returns `None` when none of the candidate paths exists.
    fn locate_weights() -> Option<std::path::PathBuf> {
        let candidates = [
            std::path::PathBuf::from("cog/artifacts/pose_v1.safetensors"),
            std::path::PathBuf::from("crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors"),
            std::path::PathBuf::from(
                "v2/crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors",
            ),
        ];
        candidates.into_iter().find(|p| p.exists())
    }

    /// Shared assertions for an engine that is expected to be running the
    /// stub fallback: it reports the `"stub"` backend and still yields a
    /// finite, correctly sized output with zero confidence.
    fn assert_stub_output(engine: &InferenceEngine) {
        assert_eq!(engine.backend(), "stub");
        let out = engine.infer(&SyntheticInput.as_window()).expect("infer");
        assert!(out.is_finite());
        assert_eq!(out.keypoints.len(), OUTPUT_KEYPOINTS * 2);
        assert_eq!(out.confidence, 0.0);
    }

    #[test]
    fn stub_fallback_when_weights_missing() {
        // `with_weights(None)` must never panic and must produce a
        // finite, well-shaped output so the runtime loop keeps making
        // progress while the operator notices the warn log.
        let engine = InferenceEngine::with_weights(None).expect("engine init");
        assert_stub_output(&engine);
    }

    #[test]
    fn stub_fallback_when_weights_file_does_not_exist() {
        // A configured path that is not on disk is documented to fall
        // back to the stub (with a warning) rather than fail engine
        // construction, so it must behave exactly like the `None` case.
        let missing = std::path::Path::new("definitely/not/a/real/dir/pose_v1.safetensors");
        assert!(
            !missing.exists(),
            "test precondition: {} must not exist",
            missing.display()
        );
        let engine = InferenceEngine::with_weights(Some(missing)).expect("engine init");
        assert_stub_output(&engine);
    }

    #[test]
    fn weights_load_and_forward_produces_seventeen_keypoint_pairs() {
        // The weights artifact may not be present in every checkout;
        // skip (rather than fail) when it cannot be found.
        let Some(weights) = locate_weights() else {
            eprintln!(
                "(skipping — pose_v1.safetensors not on disk; run from the cog crate or repo root)"
            );
            return;
        };
        let engine = InferenceEngine::with_weights(Some(&weights)).expect("load real weights");
        assert!(
            engine.backend().starts_with("candle-"),
            "expected candle backend, got {}",
            engine.backend()
        );

        // Synthetic test window. Anything finite and well-shaped proves
        // the safetensors weights flowed through the forward pass.
        let out = engine.infer(&SyntheticInput.as_window()).expect("infer");

        // Shape: one (x, y) pair per keypoint, flattened into scalars.
        assert_eq!(
            out.keypoints.len(),
            OUTPUT_KEYPOINTS * 2,
            "expected {} scalars for 17 keypoint pairs",
            OUTPUT_KEYPOINTS * 2
        );
        assert_eq!(
            out.keypoints.chunks_exact(2).len(),
            OUTPUT_KEYPOINTS,
            "expected 17 (x, y) pairs"
        );

        // Walk the pairs in place; no intermediate Vec is needed.
        for (i, pair) in out.keypoints.chunks_exact(2).enumerate() {
            let (x, y) = (pair[0], pair[1]);
            // `is_finite` rejects NaN as well as +/-Inf, so a separate
            // NaN check would be unreachable.
            assert!(
                x.is_finite() && y.is_finite(),
                "keypoint {i} not finite: ({x}, {y})"
            );
            assert!(
                (0.0..=1.0).contains(&x) && (0.0..=1.0).contains(&y),
                "keypoint {i} out of [0,1]: ({x}, {y})"
            );
        }

        // The stub reports confidence 0, so a strictly positive, finite
        // confidence shows we did not silently fall through to it.
        assert!(out.confidence > 0.0);
        assert!(out.confidence.is_finite());
    }

    #[test]
    fn rejects_wrong_shape_input_before_any_forward_pass() {
        // Seven samples is not a valid window, so `infer` must return an
        // error instead of producing output.
        let engine = InferenceEngine::with_weights(None).expect("engine init");
        let bad = CsiWindow { data: vec![0.0; 7] };
        assert!(engine.infer(&bad).is_err());
    }
}