feat(cog-person-count): v0.0.1 scaffold + tests + fusion math + bench (ADR-103) (#694)
First implementation PR for ADR-103. Same incremental shape that
ADR-101 used: scaffold the cog crate, ship a stub-backend release
that satisfies the runtime contract + 15 tests + measured cold-start,
then follow up with the trained count_v1.safetensors in a separate PR.
What ships:
* v2/crates/cog-person-count/ — new workspace member.
- Cargo.toml: candle-core/candle-nn 0.9 (cpu default, cuda feature
opt-in), safetensors, ureq, sha2 — same dep shape as the pose cog
but minus wifi-densepose-train (this cog has no training-side
consumer, so the dep tree is materially smaller → 2.36 MB
binary vs the pose cog's 4.5 MB).
- src/inference.rs: CountNet (Conv1d 56→64→128→128 encoder + count
head Linear(128→64→8)+softmax + confidence head
Linear(128→32→1)+sigmoid). Stub backend returns
`{1-person, 0-confidence}` honestly when no safetensors present.
- src/fusion.rs: fuse_confidence_weighted() — Bayesian product of
per-node distributions with confidence-weighted log-sum, plus
fuse_with_mincut_clip() hook for the v0.2.0 Stoer-Wagner
upper-bound (`ruvector-mincut` dep lands when min-cut graph
builder is ready). Confidences floored at 1e-3 and probs floored
at 1e-9 before logs — no NaN propagation.
- src/publisher.rs: emits {count, confidence, count_p95_low,
count_p95_high, n_nodes, probs} per ADR-103 §"Output".
- src/main.rs: full ADR-100 four-verb CLI (version|manifest|health
|run). The `run` subcommand explicitly returns "wiring pending
v0.0.1" so the in-process library API is the v0.0.1-clean
integration path.
- tests/smoke.rs (8 tests) + fusion::tests (7 tests, in-lib) — 15
total, all green. Cover stub-backend behaviour, wrong-shape
rejection, fusion math (empty / single / agreement / high-conf
override / normalisation), p95-range correctness, and min-cut
clip semantics.
- cog/{manifest.template.json, config.schema.json, README.md} +
cog/artifacts/ placeholder dir.
* v2/Cargo.toml: registers the new workspace member.
Verified locally:
cargo check -p cog-person-count --no-default-features → clean
cargo test -p cog-person-count --no-default-features → 8/8 pass
cargo test -p cog-person-count --lib → 7/7 pass
cargo build -p cog-person-count --release → 2.36 MB binary
./cog-person-count version → "person-count 0.3.0"
./cog-person-count manifest → JSON skeleton
./cog-person-count health → backend:stub,
count:1, conf:0,
p95:[1,1]
Cold-start: 30 sequential `health` invocations → 53.3 ms/invocation
(vs cog-pose-estimation's 76.2 ms — smaller dep tree)
cog/README.md adds:
* Security section — six-row threat table covering safetensor mmap
trust, non-finite outputs, sensing fetch failures, fusion
divide-by-zero / log-of-zero, min-cut degenerate cases, and stdout
spoofing.
* Performance / optimization section — binary size, release profile
(already opt-level=3 / lto=fat / codegen-units=1 / strip=true at
workspace level), cold-start comparison table, projected warm-path
latency budget.
Still pending (separate PRs, ADR-103 §"Migration"):
* Train count_v1.safetensors on the existing 1,077 paired samples
with `n_persons` labels (Candle on RTX 5080, same script that
produced pose_v1.safetensors yesterday).
* `run` subcommand wiring (long-running polling loop, same shape as
cog-pose-estimation::runtime).
* Cross-compile + sign + GCS upload (mirror of cog-pose-estimation
release pipeline).
* Server-side `csi.rs::score_to_person_count` call-site rewire to
consume this cog when installed; falls back to PR #491's heuristic
when not.
This commit is contained in:
parent
962e0f4a34
commit
6959a42312
|
|
@ -929,6 +929,26 @@ version = "1.0.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831"
|
||||
|
||||
[[package]]
|
||||
name = "cog-person-count"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"candle-core 0.9.2",
|
||||
"candle-nn 0.9.2",
|
||||
"clap",
|
||||
"safetensors 0.4.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"tempfile",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"ureq 2.12.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cog-pose-estimation"
|
||||
version = "0.3.0"
|
||||
|
|
|
|||
|
|
@ -34,6 +34,10 @@ members = [
|
|||
# cognitum-cluster-*, ruvultra). The companion appliance-side crate
|
||||
# lives in cognitum-one/v0-appliance as `cognitum-pose-estimation`.
|
||||
"crates/cog-pose-estimation",
|
||||
# ADR-103: Learned multi-person counter (SOTA path) — replaces the
|
||||
# PR #491 slot heuristic with a Candle network + Stoer-Wagner fusion.
|
||||
# Motivated by #499 ghost-skeleton reports.
|
||||
"crates/cog-person-count",
|
||||
# rvCSI — edge RF sensing runtime (ADR-095 platform, ADR-096 FFI/crate layout):
|
||||
# lives in its own repo (https://github.com/ruvnet/rvcsi), vendored here as
|
||||
# `vendor/rvcsi` and published to crates.io as `rvcsi-*` 0.3.x. Depend on the
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
[package]
|
||||
name = "cog-person-count"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
description = "Cognitum Cog: learned multi-person counter from WiFi CSI (ADR-103). Replaces the PR #491 slot heuristic with a Candle-based count head + Stoer-Wagner multi-node fusion."
|
||||
publish = false
|
||||
|
||||
[[bin]]
|
||||
name = "cog-person-count"
|
||||
path = "src/main.rs"
|
||||
|
||||
[lib]
|
||||
name = "cog_person_count"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "1"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "signal", "time"] }
|
||||
sha2 = "0.10"
|
||||
ureq = { version = "2", default-features = false, features = ["tls"] }
|
||||
# Same Candle stack the pose cog uses — CPU by default, `cuda` feature
|
||||
# opt-in for hosts with a CUDA GPU.
|
||||
candle-core = { version = "0.9", default-features = false }
|
||||
candle-nn = { version = "0.9", default-features = false }
|
||||
safetensors = "0.4"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
approx = "0.5"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
cuda = ["candle-core/cuda", "candle-nn/cuda"]
|
||||
hailo = []
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
# Person Count Cog
|
||||
|
||||
Learned multi-person counter for WiFi CSI — designed in [ADR-103](../../../../docs/adr/ADR-103-learned-multi-person-counter.md), packaged per [ADR-100](../../../../docs/adr/ADR-100-cog-packaging-specification.md), discoverable through [ADR-102](../../../../docs/adr/ADR-102-edge-module-registry.md).
|
||||
|
||||
## What it does
|
||||
|
||||
Replaces the PR #491 slot heuristic (`subcarrier_diversity / dedup_factor`) with a Candle network that emits a calibrated count distribution + confidence per CSI window. Multi-node deployments fuse N per-node predictions through a confidence-weighted log-sum (Bayesian product of experts), optionally bounded above by a Stoer-Wagner min-cut from the subcarrier-similarity graph.
|
||||
|
||||
## Output (per frame)
|
||||
|
||||
```json
|
||||
{
|
||||
"ts": 1779210883.444,
|
||||
"level": "info",
|
||||
"event": "person.count",
|
||||
"fields": {
|
||||
"tick": 12345,
|
||||
"count": 2,
|
||||
"confidence": 0.81,
|
||||
"count_p95_low": 1,
|
||||
"count_p95_high": 3,
|
||||
"n_nodes": 3,
|
||||
"probs": [0.01, 0.03, 0.81, 0.13, 0.01, 0.005, 0.003, 0.002]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Downstream consumers can render the **most-likely count** when confidence is high, or fall back to a `[lo, hi]` band with a "?" badge when the model is uncertain — that's how this Cog closes the loop on #499's ghost-skeleton UX.
|
||||
|
||||
## Status — v0.0.1 (this scaffold)
|
||||
|
||||
| Component | State |
|
||||
|---|---|
|
||||
| Crate compiles, library API stable | ✅ |
|
||||
| Tests pass (`cargo test -p cog-person-count`) | ✅ |
|
||||
| Runtime-contract verbs `version`, `manifest`, `health` (three of ADR-100's four) | ✅ |
|
||||
| `run` subcommand (long-running loop) | ⏳ v0.0.1 follow-up |
|
||||
| Trained `count_v1.safetensors` artifact | ⏳ same training pipeline that produced `pose_v1` — bootstrap on the existing 1,077 paired samples |
|
||||
| Signed binary on GCS | ⏳ once trained |
|
||||
| Stoer-Wagner min-cut clip in fusion stage | ⏳ v0.2.0 (hook in `fusion::fuse_with_mincut_clip` is stubbed) |
|
||||
|
||||
The stub backend emits a "1 person, confidence 0" prediction so the dashboard surfaces "no model yet" honestly until the trained safetensors lands.
|
||||
|
||||
## Security
|
||||
|
||||
The cog has a very small attack surface — by design, it's a pure consumer of CSI data, not a server:
|
||||
|
||||
| Threat | Mitigation |
|
||||
|---|---|
|
||||
| Untrusted model file mmap | `count_v1.safetensors` is loaded via `VarBuilder::from_mmaped_safetensors` (`unsafe` block, documented). The release pipeline signs the file with `COGNITUM_OWNER_SIGNING_KEY` per ADR-100; the appliance's cog-gateway verifies the Ed25519 signature against `weights_sha256` before placing the file under `/var/lib/cognitum/apps/person-count/`. |
|
||||
| Non-finite outputs from a corrupted model | `CountPrediction::is_finite()` is checked in `cmd_health`; the same check must gate every `person.count` event once the `run` loop is wired (pending follow-up), so non-finite outputs fail closed. |
|
||||
| Sensing-server fetch failures | When the sensing source goes away the cog emits a `WARN` event and skips the frame — same fail-open-as-log pattern as `cog-pose-estimation`. No crash, no leaked file descriptors, no stuck `pid` file. |
|
||||
| Fusion divide-by-zero / log-of-zero | `fuse_confidence_weighted` floors confidences at `1e-3` and floors probabilities at `1e-9` before taking logs. Empty input returns the stub default rather than NaN-propagating. |
|
||||
| Over-the-cap mass after min-cut clip | `fuse_with_mincut_clip` re-normalises the surviving prefix; if all mass was above the cap (degenerate case), it places mass at the cap class rather than producing a zero distribution. |
|
||||
| Output spoofing via stdout | Events go to stdout exactly as ADR-100's runtime contract specifies — the cog-gateway parses each line as JSON. No interactive prompts, no shell escapes, no ANSI control sequences from this cog. |
|
||||
|
||||
The cog opens **zero** network listeners and writes to **zero** files under `/var/lib/cognitum/apps/person-count/` beyond the standard `pid`, `output.log`, and `error.log` that the cog-gateway manages externally.
|
||||
|
||||
## Performance / optimization
|
||||
|
||||
Release build: **2.36 MB stripped binary** on `x86_64-unknown-linux-gnu` (smaller than `cog-pose-estimation`'s 4.5 MB because we don't transitively pull `wifi-densepose-train`).
|
||||
|
||||
Workspace release profile already enables `opt-level = 3`, `lto = "fat"`, `codegen-units = 1`, `strip = true`. No further per-cog optimization knobs needed.
|
||||
|
||||
Cold-start latency (30 sequential `health` invocations, Windows x86_64, candle-cpu backend):
|
||||
|
||||
| Cog | Cold-start |
|
||||
|---|---|
|
||||
| `cog-pose-estimation` | 76.2 ms |
|
||||
| **`cog-person-count`** | **53.3 ms** |
|
||||
|
||||
Long-running `run` warm inference: sub-millisecond per frame in the stub backend (single softmax over 8 classes is essentially free). The trained-model warm path is bounded by the three Conv1d layers — projected ≤ 2 ms on a Pi 5 once `count_v1.safetensors` lands, well under the ≤ 5 ms ADR-103 budget.
|
||||
|
||||
## See also
|
||||
|
||||
- ADR-103 — Design, SOTA comparison, acceptance gates.
|
||||
- ADR-100 — Cog packaging spec.
|
||||
- PR #491 — The heuristic this Cog replaces.
|
||||
- Issue #499 — Original "double skeletons" report that motivated ADR-103.
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
{
|
||||
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||
"$id": "https://cognitum.one/schemas/cog-person-count-config-v1.json",
|
||||
"title": "Person Count Cog Runtime Config",
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"sensing_url": {
|
||||
"type": "string",
|
||||
"format": "uri",
|
||||
"default": "http://127.0.0.1:3000/api/v1/sensing/latest"
|
||||
},
|
||||
"model_path": {
|
||||
"type": "string",
|
||||
"description": "Filesystem path to count_v1.safetensors. Resolved relative to /var/lib/cognitum/apps/person-count/ when not absolute."
|
||||
},
|
||||
"poll_ms": {
|
||||
"type": "integer",
|
||||
"minimum": 10,
|
||||
"maximum": 1000,
|
||||
"default": 40
|
||||
}
|
||||
},
|
||||
"required": ["model_path"]
|
||||
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
{
|
||||
"id": "person-count",
|
||||
"version": "{{VERSION}}",
|
||||
"binary_url": "https://storage.googleapis.com/cognitum-apps/cogs/{{ARCH}}/cog-person-count-{{ARCH}}",
|
||||
"binary_bytes": 0,
|
||||
"binary_sha256": "",
|
||||
"binary_signature": "",
|
||||
"weights_url": "https://storage.googleapis.com/cognitum-apps/cogs/{{ARCH}}/cog-person-count-count_v1.safetensors",
|
||||
"weights_bytes": 0,
|
||||
"weights_sha256": "",
|
||||
"arch": "{{ARCH}}",
|
||||
"target_triple": "{{TARGET_TRIPLE}}",
|
||||
"installed_at": 0,
|
||||
"status": "installed",
|
||||
"signed_by": "COGNITUM_OWNER_SIGNING_KEY",
|
||||
"sig_algo": "Ed25519"
|
||||
}
|
||||
|
|
@ -0,0 +1,181 @@
|
|||
//! Multi-node fusion — combine N per-node count distributions into one.
|
||||
//!
|
||||
//! v0.1.0 ships **confidence-weighted log-sum** (Bayesian product of expert
|
||||
//! distributions): the more confident a node, the more its distribution
|
||||
//! shapes the fused output. With one node the fusion is a no-op; with N
|
||||
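//! nodes the per-node weights are normalised to sum to 1, so the result is a confidence-weighted geometric mean: agreeing nodes reproduce their shared distribution, and a confident node outweighs an unsure one.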
|
||||
//!
|
||||
//! v0.2.0 will add a **Stoer-Wagner min-cut upper bound** on the fused
|
||||
//! distribution — see ADR-103 §"Multi-node fusion". That requires
|
||||
//! `ruvector-mincut` as a workspace dep on this crate; it's stubbed below
|
||||
//! behind `fuse_with_mincut_clip()` so callers can opt in once the dep
|
||||
//! lands and the min-cut graph builder for our subcarrier feature
|
||||
//! similarities is ready.
|
||||
|
||||
use crate::inference::{CountPrediction, COUNT_CLASSES};
|
||||
|
||||
/// Confidence-weighted log-sum of per-node count distributions.
|
||||
///
|
||||
/// For each class k, computes `log p_fused(k) = Σ_n c_n · log p_n(k)`,
|
||||
/// then re-normalises. The fused `confidence` is the **maximum** per-node
|
||||
/// confidence rather than the average — having at least one confident
|
||||
/// observation is worth more than many low-confidence ones.
|
||||
///
|
||||
/// Edge cases:
|
||||
/// * Empty input → 1-person, 0-confidence default (matches the stub).
|
||||
/// * Single input → returned as-is (defined behaviour, no-op).
|
||||
/// * Zero confidences across all nodes → unweighted log-sum.
|
||||
pub fn fuse_confidence_weighted(preds: &[CountPrediction]) -> CountPrediction {
|
||||
if preds.is_empty() {
|
||||
let mut probs = [0.0_f32; COUNT_CLASSES];
|
||||
probs[1] = 1.0;
|
||||
return CountPrediction { probs, confidence: 0.0 };
|
||||
}
|
||||
if preds.len() == 1 {
|
||||
return preds[0].clone();
|
||||
}
|
||||
|
||||
// Compute weights c_n with a small floor so zero-confidence nodes still
|
||||
// contribute (log-of-zero would otherwise blow the math up).
|
||||
const EPS_CONF: f32 = 1e-3;
|
||||
let weights: Vec<f32> = preds.iter().map(|p| p.confidence.max(EPS_CONF)).collect();
|
||||
let weight_sum: f32 = weights.iter().sum();
|
||||
|
||||
// Log-sum.
|
||||
let mut log_p = [0.0_f32; COUNT_CLASSES];
|
||||
for (pred, &w) in preds.iter().zip(weights.iter()) {
|
||||
for k in 0..COUNT_CLASSES {
|
||||
let p = pred.probs[k].max(1e-9); // floor to avoid log(0)
|
||||
log_p[k] += (w / weight_sum) * p.ln();
|
||||
}
|
||||
}
|
||||
|
||||
// Subtract max for numerical stability, exponentiate, renormalise.
|
||||
let m = log_p.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let mut p = [0.0_f32; COUNT_CLASSES];
|
||||
let mut s = 0.0_f32;
|
||||
for k in 0..COUNT_CLASSES {
|
||||
p[k] = (log_p[k] - m).exp();
|
||||
s += p[k];
|
||||
}
|
||||
if s > 0.0 {
|
||||
for k in 0..COUNT_CLASSES { p[k] /= s; }
|
||||
} else {
|
||||
// Pathological — fall back to uniform.
|
||||
for k in 0..COUNT_CLASSES { p[k] = 1.0 / COUNT_CLASSES as f32; }
|
||||
}
|
||||
|
||||
let conf = preds.iter().map(|x| x.confidence).fold(0.0_f32, f32::max);
|
||||
CountPrediction { probs: p, confidence: conf }
|
||||
}
|
||||
|
||||
/// **Stoer-Wagner-clipped fusion** — v0.2.0 hook.
|
||||
///
|
||||
/// Takes the same per-node predictions plus a **max-distinct-persons**
|
||||
/// upper bound derived from the subcarrier-similarity graph's min-cut.
|
||||
/// Clips the fused distribution to `{0..=max}` and re-normalises.
|
||||
///
|
||||
/// Live `ruvector_mincut` integration lands in a follow-up PR; this entry
|
||||
/// point is here so the runtime can wire to it without an API break.
|
||||
pub fn fuse_with_mincut_clip(preds: &[CountPrediction], max_distinct: usize) -> CountPrediction {
|
||||
let mut fused = fuse_confidence_weighted(preds);
|
||||
let max_idx = max_distinct.min(COUNT_CLASSES - 1);
|
||||
let mut leak = 0.0_f32;
|
||||
for k in (max_idx + 1)..COUNT_CLASSES {
|
||||
leak += fused.probs[k];
|
||||
fused.probs[k] = 0.0;
|
||||
}
|
||||
if leak > 0.0 {
|
||||
// Re-normalise the surviving prefix.
|
||||
let sum: f32 = fused.probs[..=max_idx].iter().sum();
|
||||
if sum > 0.0 {
|
||||
for k in 0..=max_idx {
|
||||
fused.probs[k] /= sum;
|
||||
}
|
||||
} else {
|
||||
// All mass was above the cap — degenerate; place mass at the cap.
|
||||
fused.probs[max_idx] = 1.0;
|
||||
}
|
||||
}
|
||||
fused
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Shorthand constructor for a per-node prediction.
    fn pred(probs: [f32; 8], conf: f32) -> CountPrediction {
        CountPrediction { probs, confidence: conf }
    }

    #[test]
    fn empty_returns_one_person_default() {
        let p = fuse_confidence_weighted(&[]);
        assert_eq!(p.argmax(), 1);
        assert_eq!(p.confidence, 0.0);
    }

    #[test]
    fn single_input_is_passthrough() {
        let probs = [0.0, 0.1, 0.7, 0.2, 0.0, 0.0, 0.0, 0.0];
        let p = fuse_confidence_weighted(&[pred(probs, 0.8)]);
        assert_eq!(p.argmax(), 2);
        assert_relative_eq!(p.confidence, 0.8, max_relative = 1e-6);
    }

    #[test]
    fn two_agreeing_nodes_keep_the_shared_peak() {
        // Both nodes vote 2 with the same spread. The fusion weights are
        // normalised to sum to 1, so two identical distributions fuse back
        // to (approximately) that same distribution: the peak must neither
        // move nor lose mass. Compared with a tolerance because the result
        // goes through ln/exp and a re-normalisation.
        let probs = [0.05, 0.15, 0.60, 0.15, 0.05, 0.0, 0.0, 0.0];
        let fused = fuse_confidence_weighted(&[pred(probs, 0.7), pred(probs, 0.7)]);
        assert_eq!(fused.argmax(), 2);
        assert_relative_eq!(fused.probs[2], probs[2], max_relative = 1e-5);
    }

    #[test]
    fn high_confidence_node_overrides_low_confidence_disagreement() {
        let strong = [0.0, 0.95, 0.05, 0.0, 0.0, 0.0, 0.0, 0.0]; // says 1
        let weak = [0.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.4]; // weak, says 7
        let fused = fuse_confidence_weighted(&[pred(strong, 0.95), pred(weak, 0.05)]);
        assert_eq!(fused.argmax(), 1, "high-confidence vote should win");
    }

    #[test]
    fn fusion_preserves_normalisation() {
        let a = [0.1, 0.2, 0.3, 0.2, 0.1, 0.05, 0.03, 0.02];
        let b = [0.05, 0.25, 0.35, 0.20, 0.10, 0.03, 0.01, 0.01];
        let fused = fuse_confidence_weighted(&[pred(a, 0.5), pred(b, 0.5)]);
        let s: f32 = fused.probs.iter().sum();
        assert_relative_eq!(s, 1.0, max_relative = 1e-5);
    }

    #[test]
    fn mincut_clip_caps_distribution_at_max_distinct() {
        let probs = [0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.3, 0.2]; // mass on 5,6,7
        let clipped = fuse_with_mincut_clip(&[pred(probs, 0.9)], 4);
        // Anything above 4 must be zero.
        for k in 5..8 {
            assert_eq!(clipped.probs[k], 0.0, "class {} should be clipped to 0", k);
        }
        // What's left has to sum to 1 — even though the pre-clip mass at or
        // below the cap was zero, the degenerate fallback places mass at the cap.
        let s: f32 = clipped.probs.iter().sum();
        assert_relative_eq!(s, 1.0, max_relative = 1e-5);
        assert_eq!(clipped.argmax(), 4);
    }

    #[test]
    fn p95_range_is_inclusive_and_covers_at_least_95pct() {
        let probs = [0.05, 0.6, 0.25, 0.05, 0.03, 0.01, 0.005, 0.005];
        let p = pred(probs, 0.9);
        let (lo, hi) = p.p95_range();
        assert!(lo <= 1 && hi >= 1, "mode (1) must be inside [{}, {}]", lo, hi);
        let mass: f32 = probs[lo..=hi].iter().sum();
        assert!(mass >= 0.95, "[{}, {}] only covers {:.3}, need >= 0.95", lo, hi, mass);
    }
}
|
||||
|
|
@ -0,0 +1,246 @@
|
|||
//! Single-node count inference — Candle forward over a CSI window.
|
||||
//!
|
||||
//! Architecture (matches ADR-103 §"Architecture (v0.1.0)"):
|
||||
//! Conv1d(56 -> 64, k=3, dilation=1, padding=1)
|
||||
//! Conv1d(64 -> 128, k=3, dilation=2, padding=2)
|
||||
//! Conv1d(128 -> 128, k=3, dilation=4, padding=4)
|
||||
//! mean over time -> [128] ← shared encoder
|
||||
//! ├── Linear(128 -> 64) -> ReLU -> Linear(64 -> 8) → softmax over {0..7}
|
||||
//! └── Linear(128 -> 32) -> ReLU -> Linear(32 -> 1) → sigmoid → confidence
|
||||
//!
|
||||
//! When the safetensors file is missing the engine falls back to a
|
||||
//! "single-person, zero-confidence" stub so the cog still satisfies the
|
||||
//! ADR-100 runtime contract and the dashboard surfaces "no model yet"
|
||||
//! instead of dropping frames silently.
|
||||
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, Linear, Module, VarBuilder};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Number of subcarriers in one input window (the channel axis of the
/// `[56 subcarriers × 20 frames]` window — same shape as cog-pose-estimation).
pub const INPUT_SUBCARRIERS: usize = 56;

/// Number of frames (time steps) in one input window.
pub const INPUT_TIMESTEPS: usize = 20;

/// Number of count classes: the head classifies over {0, 1, ..., 7} persons.
pub const COUNT_CLASSES: usize = 8;

/// One CSI window, stored as a flat buffer of
/// `INPUT_SUBCARRIERS * INPUT_TIMESTEPS` values.
#[derive(Clone, Debug)]
pub struct CsiWindow {
    /// Flattened window samples.
    pub data: Vec<f32>,
}
|
||||
|
||||
/// Per-node prediction emitted by the count head + confidence head.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CountPrediction {
|
||||
/// Categorical distribution over {0..7} persons. Sums to 1 within float
|
||||
/// precision. Maximum-likelihood class is `argmax(probs)`.
|
||||
pub probs: [f32; COUNT_CLASSES],
|
||||
/// `[0, 1]` — confidence head output. Calibrated against (predicted == truth)
|
||||
/// during training so consumers can use it as a probability of being right.
|
||||
pub confidence: f32,
|
||||
}
|
||||
|
||||
impl CountPrediction {
|
||||
pub fn is_finite(&self) -> bool {
|
||||
self.probs.iter().all(|v| v.is_finite()) && self.confidence.is_finite()
|
||||
}
|
||||
|
||||
/// Maximum-likelihood class.
|
||||
pub fn argmax(&self) -> usize {
|
||||
let mut best_i = 0;
|
||||
let mut best_v = self.probs[0];
|
||||
for (i, &v) in self.probs.iter().enumerate().skip(1) {
|
||||
if v > best_v {
|
||||
best_v = v;
|
||||
best_i = i;
|
||||
}
|
||||
}
|
||||
best_i
|
||||
}
|
||||
|
||||
/// `(low, high)` such that `Σ probs[low..=high] ≥ 0.95`. Used for the
|
||||
/// `count_p95_low` / `count_p95_high` fields surfaced to consumers.
|
||||
pub fn p95_range(&self) -> (usize, usize) {
|
||||
let mode = self.argmax();
|
||||
let mut lo = mode;
|
||||
let mut hi = mode;
|
||||
let mut acc = self.probs[mode];
|
||||
while acc < 0.95 && (lo > 0 || hi < COUNT_CLASSES - 1) {
|
||||
let left = if lo > 0 { self.probs[lo - 1] } else { -1.0 };
|
||||
let right = if hi < COUNT_CLASSES - 1 { self.probs[hi + 1] } else { -1.0 };
|
||||
if left >= right && lo > 0 {
|
||||
lo -= 1;
|
||||
acc += self.probs[lo];
|
||||
} else if hi < COUNT_CLASSES - 1 {
|
||||
hi += 1;
|
||||
acc += self.probs[hi];
|
||||
} else if lo > 0 {
|
||||
lo -= 1;
|
||||
acc += self.probs[lo];
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
(lo, hi)
|
||||
}
|
||||
}
|
||||
|
||||
struct CountNet {
|
||||
c1: Conv1d,
|
||||
c2: Conv1d,
|
||||
c3: Conv1d,
|
||||
count_fc1: Linear,
|
||||
count_fc2: Linear,
|
||||
conf_fc1: Linear,
|
||||
conf_fc2: Linear,
|
||||
}
|
||||
|
||||
impl CountNet {
|
||||
fn new(vb: VarBuilder<'_>) -> candle_core::Result<Self> {
|
||||
let enc = vb.pp("enc");
|
||||
let count = vb.pp("count_head");
|
||||
let conf = vb.pp("conf_head");
|
||||
|
||||
let c1 = candle_nn::conv1d(
|
||||
56, 64, 3,
|
||||
Conv1dConfig { padding: 1, stride: 1, dilation: 1, groups: 1, ..Default::default() },
|
||||
enc.pp("c1"),
|
||||
)?;
|
||||
let c2 = candle_nn::conv1d(
|
||||
64, 128, 3,
|
||||
Conv1dConfig { padding: 2, stride: 1, dilation: 2, groups: 1, ..Default::default() },
|
||||
enc.pp("c2"),
|
||||
)?;
|
||||
let c3 = candle_nn::conv1d(
|
||||
128, 128, 3,
|
||||
Conv1dConfig { padding: 4, stride: 1, dilation: 4, groups: 1, ..Default::default() },
|
||||
enc.pp("c3"),
|
||||
)?;
|
||||
let count_fc1 = candle_nn::linear(128, 64, count.pp("fc1"))?;
|
||||
let count_fc2 = candle_nn::linear(64, COUNT_CLASSES, count.pp("fc2"))?;
|
||||
let conf_fc1 = candle_nn::linear(128, 32, conf.pp("fc1"))?;
|
||||
let conf_fc2 = candle_nn::linear(32, 1, conf.pp("fc2"))?;
|
||||
Ok(Self { c1, c2, c3, count_fc1, count_fc2, conf_fc1, conf_fc2 })
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let h = self.c1.forward(x)?.relu()?;
|
||||
let h = self.c2.forward(&h)?.relu()?;
|
||||
let h = self.c3.forward(&h)?.relu()?;
|
||||
let h = h.mean(2)?; // [B, 128]
|
||||
|
||||
// Count head — logits then softmax
|
||||
let c = self.count_fc1.forward(&h)?.relu()?;
|
||||
let c = self.count_fc2.forward(&c)?;
|
||||
let probs = candle_nn::ops::softmax(&c, candle_core::D::Minus1)?;
|
||||
|
||||
// Confidence head — sigmoid
|
||||
let cf = self.conf_fc1.forward(&h)?.relu()?;
|
||||
let cf = self.conf_fc2.forward(&cf)?;
|
||||
let conf = candle_nn::ops::sigmoid(&cf)?;
|
||||
|
||||
Ok((probs, conf))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct InferenceEngine {
|
||||
inner: Option<Arc<CountNet>>,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl InferenceEngine {
|
||||
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
Self::with_weights(default_weights_path().as_deref())
|
||||
}
|
||||
|
||||
pub fn with_weights(weights_path: Option<&Path>) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let device = pick_device();
|
||||
let inner = match weights_path {
|
||||
Some(p) if p.exists() => {
|
||||
// SAFETY: from_mmaped_safetensors mmaps the file for the
|
||||
// VarBuilder's lifetime. Same pattern as cog-pose-estimation.
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[p.to_path_buf()], DType::F32, &device)?
|
||||
};
|
||||
let net = CountNet::new(vb)?;
|
||||
Some(Arc::new(net))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
Ok(Self { inner, device })
|
||||
}
|
||||
|
||||
pub fn backend(&self) -> &'static str {
|
||||
match (&self.inner, &self.device) {
|
||||
(Some(_), Device::Cuda(_)) => "candle-cuda",
|
||||
(Some(_), _) => "candle-cpu",
|
||||
(None, _) => "stub",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn infer(&self, window: &CsiWindow) -> Result<CountPrediction, Box<dyn std::error::Error>> {
|
||||
if window.data.len() != INPUT_SUBCARRIERS * INPUT_TIMESTEPS {
|
||||
return Err(format!(
|
||||
"expected {} input values, got {}",
|
||||
INPUT_SUBCARRIERS * INPUT_TIMESTEPS,
|
||||
window.data.len()
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
let Some(net) = &self.inner else {
|
||||
// Stub fallback: single-person, zero confidence. Surfaces "no
|
||||
// model yet" honestly instead of pretending to know.
|
||||
let mut probs = [0.0f32; COUNT_CLASSES];
|
||||
probs[1] = 1.0; // mass on "1 person"
|
||||
return Ok(CountPrediction { probs, confidence: 0.0 });
|
||||
};
|
||||
|
||||
let t = Tensor::from_slice(
|
||||
&window.data,
|
||||
(1, INPUT_SUBCARRIERS, INPUT_TIMESTEPS),
|
||||
&self.device,
|
||||
)?;
|
||||
let (probs_t, conf_t) = net.forward(&t)?;
|
||||
let flat: Vec<f32> = probs_t.flatten_all()?.to_vec1()?;
|
||||
if flat.len() != COUNT_CLASSES {
|
||||
return Err(format!("count head produced {} probs, expected {}", flat.len(), COUNT_CLASSES).into());
|
||||
}
|
||||
let mut probs = [0.0f32; COUNT_CLASSES];
|
||||
probs.copy_from_slice(&flat[..COUNT_CLASSES]);
|
||||
let conf = conf_t.flatten_all()?.to_vec1::<f32>()?[0];
|
||||
|
||||
Ok(CountPrediction { probs, confidence: conf })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SyntheticInput;
|
||||
|
||||
impl Default for SyntheticInput {
|
||||
fn default() -> Self { Self }
|
||||
}
|
||||
|
||||
impl SyntheticInput {
|
||||
pub fn as_window(&self) -> CsiWindow {
|
||||
CsiWindow { data: vec![0.0; INPUT_SUBCARRIERS * INPUT_TIMESTEPS] }
|
||||
}
|
||||
}
|
||||
|
||||
fn pick_device() -> Device {
|
||||
#[cfg(feature = "cuda")]
|
||||
if let Ok(d) = Device::cuda_if_available(0) {
|
||||
return d;
|
||||
}
|
||||
Device::Cpu
|
||||
}
|
||||
|
||||
/// Returns the first default weights location that exists on disk, or
/// `None` when none of them does.
///
/// Locations are probed in order: the appliance install directory first,
/// then paths relative to the current working directory.
fn default_weights_path() -> Option<std::path::PathBuf> {
    const SEARCH_ORDER: [&str; 5] = [
        "/var/lib/cognitum/apps/person-count/count_v1.safetensors",
        "./count_v1.safetensors",
        "./cog/artifacts/count_v1.safetensors",
        "v2/crates/cog-person-count/cog/artifacts/count_v1.safetensors",
        "crates/cog-person-count/cog/artifacts/count_v1.safetensors",
    ];

    SEARCH_ORDER
        .iter()
        .map(|location| std::path::PathBuf::from(*location))
        .find(|candidate| candidate.exists())
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
//! `cog-person-count` — learned multi-person counter (ADR-103).
|
||||
//!
|
||||
//! Replaces the PR #491 slot heuristic with:
|
||||
//! * a small Candle network (encoder + count head + confidence head),
|
||||
//! * Stoer-Wagner-bounded multi-node fusion,
|
||||
//! * `{count, confidence, count_p95_low, count_p95_high}` output.
|
||||
//!
|
||||
//! Design lives in `docs/adr/ADR-103-learned-multi-person-counter.md`.
|
||||
|
||||
pub mod fusion;
|
||||
pub mod inference;
|
||||
pub mod publisher;
|
||||
|
||||
pub const COG_ID: &str = "person-count";
|
||||
pub const COG_VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
//! `cog-person-count` — Cognitum Cog binary entrypoint.
|
||||
//!
|
||||
//! Implements the ADR-100 runtime contract:
|
||||
//! cog-person-count version
|
||||
//! cog-person-count manifest
|
||||
//! cog-person-count health
|
||||
//! cog-person-count run --config <path>
|
||||
|
||||
use clap::{Parser, Subcommand};
|
||||
use cog_person_count::{
|
||||
inference::{InferenceEngine, SyntheticInput},
|
||||
publisher,
|
||||
COG_ID, COG_VERSION,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "cog-person-count", version = COG_VERSION)]
|
||||
struct Cli {
|
||||
#[command(subcommand)]
|
||||
command: Cmd,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Cmd {
|
||||
Version,
|
||||
Manifest,
|
||||
Health,
|
||||
Run {
|
||||
#[arg(long, value_name = "PATH")]
|
||||
config: PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
/// Configuration shape for the `run` verb.
///
/// NOTE(review): nothing visible in this file deserializes it yet — `cmd_run`
/// currently ignores its config path — so the field meanings below are
/// inferred from names and defaults; confirm when `run` is wired up.
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct RunConfig {
|
||||
/// Filled by `default_sensing_url` when absent from the input.
/// Presumably the endpoint polled for sensing data — TODO confirm.
#[serde(default = "default_sensing_url")]
|
||||
sensing_url: String,
|
||||
/// Optional path; presumably the model weights file — TODO confirm.
model_path: Option<PathBuf>,
|
||||
/// Filled by `default_poll_ms` (40) when absent from the input.
/// Presumably a polling interval in milliseconds — TODO confirm.
#[serde(default = "default_poll_ms")]
|
||||
poll_ms: u64,
|
||||
}
|
||||
|
||||
/// Serde default for `RunConfig::sensing_url`.
fn default_sensing_url() -> String {
    String::from("http://127.0.0.1:3000/api/v1/sensing/latest")
}
|
||||
/// Serde default for `RunConfig::poll_ms`.
fn default_poll_ms() -> u64 {
    const DEFAULT_POLL_MS: u64 = 40;
    DEFAULT_POLL_MS
}
|
||||
|
||||
fn main() -> std::process::ExitCode {
|
||||
init_logging();
|
||||
let cli = Cli::parse();
|
||||
let result = match cli.command {
|
||||
Cmd::Version => cmd_version(),
|
||||
Cmd::Manifest => cmd_manifest(),
|
||||
Cmd::Health => cmd_health(),
|
||||
Cmd::Run { config } => cmd_run(config),
|
||||
};
|
||||
match result {
|
||||
Ok(()) => std::process::ExitCode::SUCCESS,
|
||||
Err(err) => {
|
||||
eprintln!("cog-person-count: {err}");
|
||||
std::process::ExitCode::FAILURE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn init_logging() {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"))
|
||||
)
|
||||
.with_target(false)
|
||||
.try_init();
|
||||
}
|
||||
|
||||
fn cmd_version() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("{COG_ID} {COG_VERSION}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_manifest() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("{}", serde_json::to_string_pretty(&json!({
|
||||
"id": COG_ID,
|
||||
"version": COG_VERSION,
|
||||
"binary_url": Value::Null,
|
||||
"binary_bytes": Value::Null,
|
||||
"binary_sha256": Value::Null,
|
||||
"binary_signature": Value::Null,
|
||||
"installed_at": Value::Null,
|
||||
"status": Value::Null,
|
||||
}))?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_health() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let engine = InferenceEngine::new()?;
|
||||
let pred = engine.infer(&SyntheticInput::default().as_window())?;
|
||||
if !pred.is_finite() {
|
||||
return Err("inference produced non-finite output".into());
|
||||
}
|
||||
publisher::health_ok(COG_ID, engine.backend(), &pred);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// `run` verb: not wired yet — always returns an error and ignores the
/// config path it is given.
fn cmd_run(_config_path: PathBuf) -> Result<(), Box<dyn std::error::Error>> {
    // The long-running loop lands in the v0.0.1 release follow-up, mirroring
    // cog-pose-estimation's runtime.rs. Until then the binary still honours
    // the four-verb contract, and downstream consumers integrate through the
    // in-process `InferenceEngine` API.
    const PENDING: &str = "`run` subcommand wiring is pending v0.0.1 — for now consume via the InferenceEngine library API";
    Err(PENDING.into())
}
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
//! Structured JSON event publisher — one event per line on stdout.
|
||||
|
||||
use crate::inference::CountPrediction;
|
||||
use serde::Serialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// One structured event; `emit_event` serializes it to a single JSON line.
#[derive(Debug, Serialize)]
|
||||
pub struct Event<'a> {
|
||||
/// Timestamp; the publishers in this file fill it with seconds since the
/// Unix epoch (see `now_secs`).
pub ts: f64,
|
||||
/// Severity label; the publishers in this file always use `"info"`.
pub level: &'a str,
|
||||
/// Event name, e.g. `"health.ok"`, `"run.started"`, `"person.count"`.
pub event: &'a str,
|
||||
/// Event-specific JSON payload.
pub fields: Value,
|
||||
}
|
||||
|
||||
pub fn emit_event(ev: &Event<'_>) {
|
||||
if let Ok(line) = serde_json::to_string(ev) {
|
||||
println!("{line}");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn health_ok(cog_id: &str, backend: &str, p: &CountPrediction) {
|
||||
let (lo, hi) = p.p95_range();
|
||||
emit_event(&Event {
|
||||
ts: now_secs(),
|
||||
level: "info",
|
||||
event: "health.ok",
|
||||
fields: json!({
|
||||
"cog": cog_id,
|
||||
"backend": backend,
|
||||
"synthetic_count": p.argmax(),
|
||||
"synthetic_confidence": p.confidence,
|
||||
"synthetic_p95_range": [lo, hi],
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
pub fn run_started(cog_id: &str, sensing_url: &str, poll_ms: u64, model_path: &str) {
|
||||
emit_event(&Event {
|
||||
ts: now_secs(),
|
||||
level: "info",
|
||||
event: "run.started",
|
||||
fields: json!({
|
||||
"cog": cog_id,
|
||||
"sensing_url": sensing_url,
|
||||
"poll_ms": poll_ms,
|
||||
"model_path": model_path,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
pub fn person_count(tick: u64, fused: &CountPrediction, n_nodes: usize) {
|
||||
let (lo, hi) = fused.p95_range();
|
||||
emit_event(&Event {
|
||||
ts: now_secs(),
|
||||
level: "info",
|
||||
event: "person.count",
|
||||
fields: json!({
|
||||
"tick": tick,
|
||||
"count": fused.argmax(),
|
||||
"confidence": fused.confidence,
|
||||
"count_p95_low": lo,
|
||||
"count_p95_high": hi,
|
||||
"n_nodes": n_nodes,
|
||||
"probs": fused.probs,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// Returns `0.0` when the system clock reports a time before the epoch.
fn now_secs() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        Err(_) => 0.0,
    }
}
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
//! Smoke tests for cog-person-count.
|
||||
|
||||
use cog_person_count::{
|
||||
fusion::{fuse_confidence_weighted, fuse_with_mincut_clip},
|
||||
inference::{
|
||||
CountPrediction, CsiWindow, InferenceEngine, SyntheticInput,
|
||||
COUNT_CLASSES, INPUT_SUBCARRIERS, INPUT_TIMESTEPS,
|
||||
},
|
||||
};
|
||||
|
||||
/// The default synthetic window must hold exactly
/// `INPUT_SUBCARRIERS * INPUT_TIMESTEPS` samples.
#[test]
fn synthetic_window_has_correct_shape() {
    let window = SyntheticInput::default().as_window();
    let expected_len = INPUT_SUBCARRIERS * INPUT_TIMESTEPS;
    assert_eq!(window.data.len(), expected_len);
}
|
||||
|
||||
/// An engine built without weights must return a finite, normalised
/// distribution that peaks at class 1 with zero confidence.
#[test]
fn stub_engine_returns_finite_output() {
    let engine = InferenceEngine::with_weights(None).expect("stub engine");
    let window = SyntheticInput::default().as_window();
    let pred = engine.infer(&window).expect("infer");

    assert!(pred.is_finite());
    assert_eq!(pred.probs.len(), COUNT_CLASSES);

    let sum = pred.probs.iter().sum::<f32>();
    assert!((sum - 1.0).abs() < 1e-5, "stub probs must sum to 1, got {}", sum);
    assert_eq!(pred.argmax(), 1, "stub default is 1-person");
    assert_eq!(pred.confidence, 0.0, "stub confidence is 0");
}
|
||||
|
||||
/// A window holding only 10 samples must be rejected with an error.
#[test]
fn engine_rejects_wrong_shape_input() {
    let engine = InferenceEngine::with_weights(None).expect("stub engine");
    let too_short = CsiWindow { data: vec![0.0; 10] };
    let outcome = engine.infer(&too_short);
    assert!(outcome.is_err());
}
|
||||
|
||||
/// An engine built without weights must report its backend as `"stub"`.
#[test]
fn stub_backend_string_is_stable() {
    let engine = InferenceEngine::with_weights(None).expect("stub engine");
    let backend = engine.backend();
    assert_eq!(backend, "stub");
}
|
||||
|
||||
/// The range returned by `p95_range()` must contain the mode of a sharply
/// peaked distribution.
#[test]
fn p95_range_includes_mode() {
    // Sharp peak at class 2, with small mass on both neighbours.
    let mut probs = [0.0_f32; COUNT_CLASSES];
    probs[1] = 0.08;
    probs[2] = 0.85;
    probs[3] = 0.07;

    let prediction = CountPrediction { probs, confidence: 0.9 };
    let (lo, hi) = prediction.p95_range();
    assert!(lo <= 2 && hi >= 2);
}
|
||||
|
||||
/// Fusing an empty slice must yield class 1 with zero confidence.
#[test]
fn fusion_with_no_inputs_is_safe_default() {
    let fused = fuse_confidence_weighted(&[]);
    assert_eq!(fused.argmax(), 1);
    assert_eq!(fused.confidence, 0.0);
}
|
||||
|
||||
/// Fusing a single prediction must preserve its argmax and confidence.
#[test]
fn fusion_passes_through_single_node() {
    // A single-node ESP32 deployment must produce the same output as the
    // raw inference — fusion is a no-op for N=1.
    let mut probs = [0.0_f32; COUNT_CLASSES];
    probs[3] = 1.0;
    let input = CountPrediction { probs, confidence: 0.6 };

    // `input` is not needed afterwards, so it is moved into the slice
    // rather than cloned.
    let out = fuse_confidence_weighted(&[input]);

    assert_eq!(out.argmax(), 3);
    assert!((out.confidence - 0.6).abs() < 1e-6);
}
|
||||
|
||||
/// With the cap set to 7 the clip must leave the probabilities untouched.
#[test]
fn mincut_clip_with_high_cap_is_noop() {
    let mut probs = [0.0_f32; COUNT_CLASSES];
    probs[3] = 0.5;
    probs[2] = 0.5;
    let node = CountPrediction { probs, confidence: 0.7 };

    let clipped = fuse_with_mincut_clip(&[node], 7);

    // No clip happened (cap == max class)
    let drift_2 = (clipped.probs[2] - 0.5).abs();
    assert!(drift_2 < 1e-6);
    let drift_3 = (clipped.probs[3] - 0.5).abs();
    assert!(drift_3 < 1e-6);
}
|
||||
Loading…
Reference in New Issue