fix(cog-person-count): flag untrained-class counts low_confidence (ADR-159 A2)
The count head has 8 classes but count_train_results.json only has support for classes 0/1 (presence, not multi-occupant counting). An argmax on classes 2..=7 is out-of-distribution, yet the cog emitted it as a confident headcount and the crate billed itself a "multi-person counter". - Add MAX_TRAINED_CLASS=1, CountPrediction::is_low_confidence() and clamped_count(). - person.count events now carry low_confidence + raw_count, downgrade to level "warn" when OOD, and clamp the reported count to the trained range (no fabricated headcount). - run.started discloses count_max_trained_class / count_classes. - Cargo.toml description: "multi-person counter" -> "presence detector + (data-gated) person count". Multi-occupant accuracy stays DATA-GATED (not fabricated). Failing-on-old test: untrained_class_argmax_is_flagged_low_confidence. Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
98bf8c4726
commit
2400216920
|
|
@ -5,7 +5,7 @@ edition.workspace = true
|
|||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
description = "Cognitum Cog: learned multi-person counter from WiFi CSI (ADR-103). Replaces the PR #491 slot heuristic with a Candle-based count head + Stoer-Wagner multi-node fusion."
|
||||
description = "Cognitum Cog: WiFi-CSI presence detector + (data-gated) person count (ADR-103). Candle-based head trained on classes 0/1 (presence); the 8-class count head ships but counts above the trained range are flagged low_confidence. Stoer-Wagner multi-node fusion."
|
||||
|
||||
[[bin]]
|
||||
name = "cog-person-count"
|
||||
|
|
|
|||
|
|
@ -24,6 +24,17 @@ pub const INPUT_TIMESTEPS: usize = 20;
|
|||
/// Count classification over {0, 1, ..., 7} persons.
|
||||
pub const COUNT_CLASSES: usize = 8;
|
||||
|
||||
/// Highest class the shipped `count_v1` weights were actually **trained** on.
|
||||
///
|
||||
/// The count head has 8 logits, but `count_train_results.json` only has support
|
||||
/// for classes 0 and 1 (`per_class_accuracy` keys are `"0"` and `"1"`). The model
|
||||
/// is a presence detector (0 vs ≥1 person), **not** a calibrated multi-occupant
|
||||
/// counter. An argmax landing on classes 2..=7 is out-of-distribution: the logits
|
||||
/// there were never supervised against labelled data. We flag such outputs
|
||||
/// `low_confidence` so downstream consumers don't trust a fabricated headcount.
|
||||
/// (Multi-occupant *accuracy* is DATA-GATED — not fabricated here.)
|
||||
pub const MAX_TRAINED_CLASS: usize = 1;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CsiWindow {
|
||||
pub data: Vec<f32>,
|
||||
|
|
@ -45,6 +56,23 @@ impl CountPrediction {
|
|||
self.probs.iter().all(|v| v.is_finite()) && self.confidence.is_finite()
|
||||
}
|
||||
|
||||
/// True when the maximum-likelihood class is beyond what the shipped weights
|
||||
/// were trained on ([`MAX_TRAINED_CLASS`]). Such a prediction is out-of-
|
||||
/// distribution — the count head's logits for classes 2..=7 were never
|
||||
/// supervised, so the headcount is not trustworthy. Surfaced as the
|
||||
/// `low_confidence` field on the `person.count` event (honest-clip pattern).
|
||||
pub fn is_low_confidence(&self) -> bool {
|
||||
self.argmax() > MAX_TRAINED_CLASS
|
||||
}
|
||||
|
||||
/// Argmax clamped to [`MAX_TRAINED_CLASS`]. When the raw argmax is an
|
||||
/// untrained class we clamp the *reported* count to the highest trained
|
||||
/// class rather than emit a fabricated multi-occupant headcount. The raw
|
||||
/// distribution is still available in `probs` for diagnostics.
|
||||
pub fn clamped_count(&self) -> usize {
|
||||
self.argmax().min(MAX_TRAINED_CLASS)
|
||||
}
|
||||
|
||||
/// Maximum-likelihood class.
|
||||
pub fn argmax(&self) -> usize {
|
||||
let mut best_i = 0;
|
||||
|
|
|
|||
|
|
@ -45,20 +45,35 @@ pub fn run_started(cog_id: &str, sensing_url: &str, poll_ms: u64, model_path: &s
|
|||
"sensing_url": sensing_url,
|
||||
"poll_ms": poll_ms,
|
||||
"model_path": model_path,
|
||||
// Honest disclosure: the count head has 8 classes but the shipped
|
||||
// weights were only trained on classes 0..=MAX_TRAINED_CLASS
|
||||
// (presence, not multi-occupant counting). Counts above this are
|
||||
// flagged `low_confidence` on each person.count event.
|
||||
"count_max_trained_class": crate::inference::MAX_TRAINED_CLASS,
|
||||
"count_classes": crate::inference::COUNT_CLASSES,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
pub fn person_count(tick: u64, fused: &CountPrediction, n_nodes: usize) {
|
||||
let (lo, hi) = fused.p95_range();
|
||||
let low_confidence = fused.is_low_confidence();
|
||||
emit_event(&Event {
|
||||
ts: now_secs(),
|
||||
level: "info",
|
||||
// An out-of-distribution count (argmax beyond the trained classes) is
|
||||
// a warning, not a clean info reading.
|
||||
level: if low_confidence { "warn" } else { "info" },
|
||||
event: "person.count",
|
||||
fields: json!({
|
||||
"tick": tick,
|
||||
"count": fused.argmax(),
|
||||
// Reported count is clamped to the trained range — we never emit a
|
||||
// fabricated multi-occupant headcount the weights can't back.
|
||||
"count": fused.clamped_count(),
|
||||
// Raw argmax kept for diagnostics/audit.
|
||||
"raw_count": fused.argmax(),
|
||||
"confidence": fused.confidence,
|
||||
// True when argmax > MAX_TRAINED_CLASS (untrained class).
|
||||
"low_confidence": low_confidence,
|
||||
"count_p95_low": lo,
|
||||
"count_p95_high": hi,
|
||||
"n_nodes": n_nodes,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use cog_person_count::{
|
|||
fusion::{fuse_confidence_weighted, fuse_with_mincut_clip},
|
||||
inference::{
|
||||
CountPrediction, CsiWindow, InferenceEngine, SyntheticInput, COUNT_CLASSES,
|
||||
INPUT_SUBCARRIERS, INPUT_TIMESTEPS,
|
||||
INPUT_SUBCARRIERS, INPUT_TIMESTEPS, MAX_TRAINED_CLASS,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -83,6 +83,51 @@ fn fusion_passes_through_single_node() {
|
|||
assert!((out.confidence - 0.6).abs() < 1e-6);
|
||||
}
|
||||
|
||||
/// ADR-159 §A2 — the 8-class count head ships, but the weights were only
|
||||
/// trained on classes 0/1 (presence). A prediction whose argmax lands on an
|
||||
/// UNTRAINED class (2..=7) must be flagged `low_confidence` and the reported
|
||||
/// count clamped to the trained range, so we never emit a fabricated
|
||||
/// multi-occupant headcount. Fails on old code (no such flag/clamp existed).
|
||||
#[test]
|
||||
fn untrained_class_argmax_is_flagged_low_confidence() {
|
||||
// Sanity: the trained ceiling is below the head width.
|
||||
assert!(MAX_TRAINED_CLASS < COUNT_CLASSES - 1);
|
||||
|
||||
// Mass on an untrained class (5 persons) — out-of-distribution.
|
||||
let mut probs = [0.0_f32; COUNT_CLASSES];
|
||||
probs[5] = 0.9;
|
||||
probs[1] = 0.1;
|
||||
let oodp = CountPrediction {
|
||||
probs,
|
||||
confidence: 0.95, // even a "confident" softmax must be flagged
|
||||
};
|
||||
assert_eq!(oodp.argmax(), 5);
|
||||
assert!(
|
||||
oodp.is_low_confidence(),
|
||||
"argmax beyond MAX_TRAINED_CLASS must be flagged low_confidence"
|
||||
);
|
||||
assert_eq!(
|
||||
oodp.clamped_count(),
|
||||
MAX_TRAINED_CLASS,
|
||||
"reported count must clamp to the trained ceiling, not fabricate a headcount"
|
||||
);
|
||||
|
||||
// A trained-range prediction (1 person) is NOT flagged.
|
||||
let mut probs2 = [0.0_f32; COUNT_CLASSES];
|
||||
probs2[1] = 0.8;
|
||||
probs2[0] = 0.2;
|
||||
let inp = CountPrediction {
|
||||
probs: probs2,
|
||||
confidence: 0.8,
|
||||
};
|
||||
assert_eq!(inp.argmax(), 1);
|
||||
assert!(
|
||||
!inp.is_low_confidence(),
|
||||
"a trained-range count must not be flagged"
|
||||
);
|
||||
assert_eq!(inp.clamped_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mincut_clip_with_high_cap_is_noop() {
|
||||
let mut probs = [0.0_f32; COUNT_CLASSES];
|
||||
|
|
|
|||
Loading…
Reference in New Issue