fix(cog-person-count): flag untrained-class counts low_confidence (ADR-159 A2)

The count head has 8 classes, but count_train_results.json only shows
support for classes 0/1 (presence, not multi-occupant counting). An
argmax on classes 2..=7 is out-of-distribution, yet the cog emitted it
as a confident headcount and the crate billed itself as a "multi-person
counter".

- Add MAX_TRAINED_CLASS=1, CountPrediction::is_low_confidence() and
  clamped_count().
- person.count events now carry low_confidence + raw_count, downgrade to
  level "warn" when OOD, and clamp the reported count to the trained
  range (no fabricated headcount).
- run.started discloses count_max_trained_class / count_classes.
- Cargo.toml description: "multi-person counter" ->
  "presence detector + (data-gated) person count".

Multi-occupant accuracy stays DATA-GATED (not fabricated).

Regression test (fails on the old code): untrained_class_argmax_is_flagged_low_confidence.

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv 2026-06-11 23:10:01 -04:00
parent 98bf8c4726
commit 2400216920
4 changed files with 92 additions and 4 deletions

View File

@ -5,7 +5,7 @@ edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
description = "Cognitum Cog: learned multi-person counter from WiFi CSI (ADR-103). Replaces the PR #491 slot heuristic with a Candle-based count head + Stoer-Wagner multi-node fusion."
description = "Cognitum Cog: WiFi-CSI presence detector + (data-gated) person count (ADR-103). Candle-based head trained on classes 0/1 (presence); the 8-class count head ships but counts above the trained range are flagged low_confidence. Stoer-Wagner multi-node fusion."
[[bin]]
name = "cog-person-count"

View File

@ -24,6 +24,17 @@ pub const INPUT_TIMESTEPS: usize = 20;
/// Number of logits in the count head: classification over {0, 1, ..., 7} persons.
pub const COUNT_CLASSES: usize = 8;
/// Highest class the shipped `count_v1` weights were actually **trained** on.
///
/// The count head has [`COUNT_CLASSES`] (8) logits, but `count_train_results.json`
/// only has support for classes 0 and 1 (`per_class_accuracy` keys are `"0"` and
/// `"1"`). The model is therefore a presence detector (0 vs ≥1 person), **not** a
/// calibrated multi-occupant counter.
///
/// An argmax landing on classes 2..=7 is out-of-distribution: the logits there
/// were never supervised against labelled data. Such outputs are flagged
/// `low_confidence` so downstream consumers don't trust a fabricated headcount.
/// (Multi-occupant *accuracy* is DATA-GATED — it is not fabricated here.)
pub const MAX_TRAINED_CLASS: usize = 1;
#[derive(Debug, Clone)]
pub struct CsiWindow {
pub data: Vec<f32>,
@ -45,6 +56,23 @@ impl CountPrediction {
self.probs.iter().all(|v| v.is_finite()) && self.confidence.is_finite()
}
/// Reports whether this prediction falls outside the trained class range.
///
/// Returns `true` when the maximum-likelihood class (see `argmax`) is greater
/// than [`MAX_TRAINED_CLASS`]. Such a prediction is out-of-distribution: the
/// count head's logits for classes 2..=7 were never supervised, so the
/// headcount is not trustworthy. This value is surfaced as the
/// `low_confidence` field on the `person.count` event (honest-clip pattern).
pub fn is_low_confidence(&self) -> bool {
    let top_class = self.argmax();
    MAX_TRAINED_CLASS < top_class
}
/// Maximum-likelihood class, capped at [`MAX_TRAINED_CLASS`].
///
/// When the raw argmax is an untrained class, the *reported* count is limited
/// to the highest trained class instead of emitting a fabricated
/// multi-occupant headcount. The raw distribution remains available in
/// `probs` for diagnostics.
pub fn clamped_count(&self) -> usize {
    let raw_class = self.argmax();
    std::cmp::min(raw_class, MAX_TRAINED_CLASS)
}
/// Maximum-likelihood class.
pub fn argmax(&self) -> usize {
let mut best_i = 0;

View File

@ -45,20 +45,35 @@ pub fn run_started(cog_id: &str, sensing_url: &str, poll_ms: u64, model_path: &s
"sensing_url": sensing_url,
"poll_ms": poll_ms,
"model_path": model_path,
// Honest disclosure: the count head has 8 classes but the shipped
// weights were only trained on classes 0..=MAX_TRAINED_CLASS
// (presence, not multi-occupant counting). Counts above this are
// flagged `low_confidence` on each person.count event.
"count_max_trained_class": crate::inference::MAX_TRAINED_CLASS,
"count_classes": crate::inference::COUNT_CLASSES,
}),
});
}
pub fn person_count(tick: u64, fused: &CountPrediction, n_nodes: usize) {
let (lo, hi) = fused.p95_range();
let low_confidence = fused.is_low_confidence();
emit_event(&Event {
ts: now_secs(),
level: "info",
// An out-of-distribution count (argmax beyond the trained classes) is
// a warning, not a clean info reading.
level: if low_confidence { "warn" } else { "info" },
event: "person.count",
fields: json!({
"tick": tick,
"count": fused.argmax(),
// Reported count is clamped to the trained range — we never emit a
// fabricated multi-occupant headcount the weights can't back.
"count": fused.clamped_count(),
// Raw argmax kept for diagnostics/audit.
"raw_count": fused.argmax(),
"confidence": fused.confidence,
// True when argmax > MAX_TRAINED_CLASS (untrained class).
"low_confidence": low_confidence,
"count_p95_low": lo,
"count_p95_high": hi,
"n_nodes": n_nodes,

View File

@ -4,7 +4,7 @@ use cog_person_count::{
fusion::{fuse_confidence_weighted, fuse_with_mincut_clip},
inference::{
CountPrediction, CsiWindow, InferenceEngine, SyntheticInput, COUNT_CLASSES,
INPUT_SUBCARRIERS, INPUT_TIMESTEPS,
INPUT_SUBCARRIERS, INPUT_TIMESTEPS, MAX_TRAINED_CLASS,
},
};
@ -83,6 +83,51 @@ fn fusion_passes_through_single_node() {
assert!((out.confidence - 0.6).abs() < 1e-6);
}
/// ADR-159 §A2 — the 8-class count head ships, but the weights were only
/// trained on classes `0..=MAX_TRAINED_CLASS` (presence). A prediction whose
/// argmax lands on an UNTRAINED class must be flagged `low_confidence` and the
/// reported count clamped to the trained range, so we never emit a fabricated
/// multi-occupant headcount. Fails on old code (no such flag/clamp existed).
///
/// Every class of the head is exercised, including the boundary class
/// `MAX_TRAINED_CLASS + 1` (where an off-by-one in the threshold would show up)
/// and class 0. The class ranges are derived from the constants so the test
/// stays meaningful if the trained ceiling is ever raised.
#[test]
fn untrained_class_argmax_is_flagged_low_confidence() {
    // Sanity: the trained ceiling is below the head width, i.e. at least one
    // untrained class exists to test against.
    assert!(MAX_TRAINED_CLASS < COUNT_CLASSES - 1);

    // Builds a prediction with most of the mass on `class` and the remainder
    // on a neighbouring class, so the argmax is unambiguous.
    let peaked_at = |class: usize, confidence: f32| {
        let mut probs = [0.0_f32; COUNT_CLASSES];
        probs[class] = 0.9;
        probs[(class + 1) % COUNT_CLASSES] = 0.1;
        CountPrediction { probs, confidence }
    };

    // Mass on any untrained class is out-of-distribution.
    for class in (MAX_TRAINED_CLASS + 1)..COUNT_CLASSES {
        // Even a "confident" softmax must be flagged.
        let ood = peaked_at(class, 0.95);
        assert_eq!(ood.argmax(), class);
        assert!(
            ood.is_low_confidence(),
            "argmax {} beyond MAX_TRAINED_CLASS must be flagged low_confidence",
            class
        );
        assert_eq!(
            ood.clamped_count(),
            MAX_TRAINED_CLASS,
            "reported count for raw class {} must clamp to the trained ceiling, not fabricate a headcount",
            class
        );
    }

    // A trained-range prediction is NOT flagged and is reported unchanged.
    for class in 0..=MAX_TRAINED_CLASS {
        let trained = peaked_at(class, 0.8);
        assert_eq!(trained.argmax(), class);
        assert!(
            !trained.is_low_confidence(),
            "trained-range count {} must not be flagged",
            class
        );
        assert_eq!(trained.clamped_count(), class);
    }
}
#[test]
fn mincut_clip_with_high_cap_is_noop() {
let mut probs = [0.0_f32; COUNT_CLASSES];