Merge 3a4576a5f8 into 8d64434d21
This commit is contained in:
commit
72493d8b6d
|
|
@ -58,7 +58,7 @@ Loss curve: 0.181 (epoch 0) → 0.014 (epoch 399), eval loss 0.010. **400 epochs
|
|||
- Re-train with the same Candle pipeline (already validated to converge in seconds on RTX 5080).
|
||||
- Hailo HEF export via the Dataflow Compiler on a self-hosted runner.
|
||||
|
||||
The cog's runtime inference path is currently a centred-skeleton stub returning `confidence=0`. Wiring the `pose_v1.safetensors` weights into `src/inference.rs` is the next code change — separate PR.
|
||||
The cog's runtime inference path now loads `pose_v1.safetensors` directly through Candle in `src/inference.rs` — see `InferenceEngine::with_weights` and the `weights_load_and_forward_produces_seventeen_keypoint_pairs` test in the same file. The forward pass mirrors the training script (`Conv1d 56→64→128→128` encoder with dilations `[1, 2, 4]`, `GlobalMeanPool`, `Linear 128→256→34`, sigmoid) and emits `[17, 2]` keypoints with the published `confidence = 0.185` (PCK@50). If the safetensors file is missing on disk, the engine logs a `tracing::warn!` and falls back to the centred-skeleton stub (`confidence = 0`) so the runtime contract is preserved and the dashboard surfaces "no model yet" instead of crashing. The 3% PCK@20 / 18.5% PCK@50 numbers above remain the right way to read this model — wiring the weights does not improve accuracy, only replaces the placeholder output with the trained values.
|
||||
|
||||
## See also
|
||||
|
||||
|
|
|
|||
|
|
@ -145,8 +145,10 @@ impl InferenceEngine {
|
|||
}
|
||||
|
||||
/// Create an engine with a specific weights path (used by `--config`
|
||||
/// in `cog-pose-estimation run`). If `weights_path` is `None`, the
|
||||
/// stub fallback is used.
|
||||
/// in `cog-pose-estimation run`). If `weights_path` is `None` or the
|
||||
/// file does not exist on disk, the engine falls back to the
|
||||
/// centred-skeleton stub and emits a `tracing::warn!` so the
|
||||
/// appliance log shows why no real keypoints are coming through.
|
||||
pub fn with_weights(weights_path: Option<&Path>) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let device = pick_device();
|
||||
let inner = match weights_path {
|
||||
|
|
@ -159,9 +161,26 @@ impl InferenceEngine {
|
|||
VarBuilder::from_mmaped_safetensors(&[p.to_path_buf()], DType::F32, &device)?
|
||||
};
|
||||
let net = PoseNet::new(vb)?;
|
||||
tracing::info!(
|
||||
weights = %p.display(),
|
||||
"loaded pose_v1.safetensors into candle backend"
|
||||
);
|
||||
Some(Arc::new(LoadedModel { net }))
|
||||
}
|
||||
_ => None,
|
||||
Some(p) => {
|
||||
tracing::warn!(
|
||||
weights = %p.display(),
|
||||
"pose weights file not found; falling back to centred-skeleton stub (confidence=0)"
|
||||
);
|
||||
None
|
||||
}
|
||||
None => {
|
||||
tracing::warn!(
|
||||
"no pose weights path configured and no default weights found on disk; \
|
||||
falling back to centred-skeleton stub (confidence=0)"
|
||||
);
|
||||
None
|
||||
}
|
||||
};
|
||||
Ok(Self { inner, device })
|
||||
}
|
||||
|
|
@ -255,3 +274,98 @@ fn default_weights_path() -> Option<std::path::PathBuf> {
|
|||
];
|
||||
candidates.into_iter().find(|p| p.exists())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Unit tests — exercise the safetensors → forward-pass path. Integration-level
|
||||
// assertions (CLI surface, manifest round-trip, etc.) live in `tests/smoke.rs`.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Find `pose_v1.safetensors` relative to the current working directory.
///
/// The candidate paths are all relative, so they resolve against whatever
/// directory the test process was started in. They are probed in order
/// (crate dir, workspace root, then one level above `v2/`) and the first
/// one that exists on disk is returned; `None` means none of them exist.
fn locate_weights() -> Option<std::path::PathBuf> {
    const RELATIVE_CANDIDATES: [&str; 3] = [
        "cog/artifacts/pose_v1.safetensors",
        "crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors",
        "v2/crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors",
    ];

    RELATIVE_CANDIDATES
        .iter()
        .map(|relative| std::path::PathBuf::from(*relative))
        .find(|candidate| candidate.exists())
}
|
||||
|
||||
#[test]
fn stub_fallback_when_weights_missing() {
    // Constructing the engine without a weights path has to succeed and
    // report the stub backend: the runtime loop must keep producing
    // well-formed frames while the operator reacts to the warn log.
    let stub_engine = InferenceEngine::with_weights(None).expect("engine init");
    assert_eq!(stub_engine.backend(), "stub");

    // The stub output must still be finite, carry one (x, y) scalar pair per
    // keypoint, and advertise zero confidence.
    let window = SyntheticInput.as_window();
    let pose = stub_engine.infer(&window).expect("infer");
    assert!(pose.is_finite());
    assert_eq!(pose.keypoints.len(), OUTPUT_KEYPOINTS * 2);
    assert_eq!(pose.confidence, 0.0);
}
|
||||
|
||||
#[test]
fn weights_load_and_forward_produces_seventeen_keypoint_pairs() {
    // NOTE: when the weights file cannot be found from the current working
    // directory this test returns early and therefore *passes* without
    // exercising the model. The message on stderr is the only trace.
    let Some(weights) = locate_weights() else {
        eprintln!(
            "(skipping — pose_v1.safetensors not on disk; run from the cog crate or repo root)"
        );
        return;
    };
    let engine = InferenceEngine::with_weights(Some(&weights)).expect("load real weights");
    assert!(
        engine.backend().starts_with("candle-"),
        "expected candle backend, got {}",
        engine.backend()
    );

    // Feed the synthetic test window through the engine. Anything finite and
    // well-shaped proves the safetensors weights flowed through the forward
    // pass.
    let out = engine.infer(&SyntheticInput.as_window()).expect("infer");

    // Shape: 17 (x, y) pairs = 34 scalars, every one finite and inside
    // sigmoid's [0, 1] range.
    assert_eq!(
        out.keypoints.len(),
        OUTPUT_KEYPOINTS * 2,
        "expected {} scalars for 17 keypoint pairs",
        OUTPUT_KEYPOINTS * 2
    );
    let pairs: Vec<[f32; 2]> = out
        .keypoints
        .chunks_exact(2)
        .map(|c| [c[0], c[1]])
        .collect();
    assert_eq!(pairs.len(), OUTPUT_KEYPOINTS, "expected 17 (x, y) pairs");
    for (i, [x, y]) in pairs.iter().enumerate() {
        // `is_finite` is false for NaN as well as for ±Inf, so a separate
        // NaN assertion after this one could never fire and is omitted.
        assert!(
            x.is_finite() && y.is_finite(),
            "keypoint {i} not finite: ({x}, {y})"
        );
        assert!(
            (0.0..=1.0).contains(x) && (0.0..=1.0).contains(y),
            "keypoint {i} out of [0,1]: ({x}, {y})"
        );
    }

    // Confidence is the published PCK@50 (constant for v0.0.1), so
    // anything > 0 proves we didn't silently fall through to the stub.
    assert!(out.confidence > 0.0);
    assert!(out.confidence.is_finite());
}
|
||||
|
||||
#[test]
fn rejects_wrong_shape_input_before_any_forward_pass() {
    // No weights path is given, so no model file is involved: the rejection
    // has to come from the engine's own input validation.
    let stub_engine = InferenceEngine::with_weights(None).expect("engine init");

    // A 7-element window is deliberately malformed; `infer` must report an
    // error for it rather than return a pose.
    let malformed = CsiWindow { data: vec![0.0; 7] };
    let outcome = stub_engine.infer(&malformed);
    assert!(outcome.is_err());
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,12 +40,22 @@ fn engine_rejects_wrong_shape_input() {
|
|||
#[test]
|
||||
fn real_weights_load_when_available() {
|
||||
use cog_pose_estimation::inference::InferenceEngine;
|
||||
let weights = std::path::Path::new("cog/artifacts/pose_v1.safetensors");
|
||||
if !weights.exists() {
|
||||
// Search the cwds a `cargo test` invocation can land in
|
||||
// (workspace root, `v2/`, or the crate dir).
|
||||
let candidates = [
|
||||
"cog/artifacts/pose_v1.safetensors",
|
||||
"crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors",
|
||||
"v2/crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors",
|
||||
];
|
||||
let Some(weights) = candidates
|
||||
.iter()
|
||||
.map(std::path::Path::new)
|
||||
.find(|p| p.exists())
|
||||
else {
|
||||
// Skip when running outside the repo (e.g. on a fresh appliance install).
|
||||
eprintln!("(skipping — cog/artifacts/pose_v1.safetensors not present in cwd)");
|
||||
eprintln!("(skipping — pose_v1.safetensors not present under any known cwd)");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let engine = InferenceEngine::with_weights(Some(weights)).expect("load real weights");
|
||||
assert!(
|
||||
engine.backend().starts_with("candle-"),
|
||||
|
|
@ -54,6 +64,27 @@ fn real_weights_load_when_available() {
|
|||
);
|
||||
let out = engine.infer(&SyntheticInput.as_window()).expect("infer");
|
||||
assert!(out.is_finite());
|
||||
|
||||
// Chunked into 17 (x, y) pairs — every coordinate must be finite and
|
||||
// sit inside the sigmoid output range. This is the "real wiring"
|
||||
// assertion the README points at.
|
||||
let pairs: Vec<[f32; 2]> = out
|
||||
.keypoints
|
||||
.chunks_exact(2)
|
||||
.map(|c| [c[0], c[1]])
|
||||
.collect();
|
||||
assert_eq!(pairs.len(), OUTPUT_KEYPOINTS);
|
||||
for (i, [x, y]) in pairs.iter().enumerate() {
|
||||
assert!(
|
||||
x.is_finite() && y.is_finite() && !x.is_nan() && !y.is_nan(),
|
||||
"keypoint {i} bad: ({x}, {y})"
|
||||
);
|
||||
assert!(
|
||||
(0.0..=1.0).contains(x) && (0.0..=1.0).contains(y),
|
||||
"keypoint {i} out of [0, 1]: ({x}, {y})"
|
||||
);
|
||||
}
|
||||
|
||||
// Real model emits the published validation PCK@50 as its self-reported
|
||||
// confidence — stub returns 0.0. This is the key assertion that proves
|
||||
// the cog isn't silently falling back to the stub.
|
||||
|
|
|
|||
Loading…
Reference in New Issue