diff --git a/v2/crates/cog-person-count/Cargo.toml b/v2/crates/cog-person-count/Cargo.toml index 64526f60..2b3a65ea 100644 --- a/v2/crates/cog-person-count/Cargo.toml +++ b/v2/crates/cog-person-count/Cargo.toml @@ -5,7 +5,7 @@ edition.workspace = true authors.workspace = true license.workspace = true repository.workspace = true -description = "Cognitum Cog: learned multi-person counter from WiFi CSI (ADR-103). Replaces the PR #491 slot heuristic with a Candle-based count head + Stoer-Wagner multi-node fusion." +description = "Cognitum Cog: WiFi-CSI presence detector + (data-gated) person count (ADR-103). Candle-based head trained on classes 0/1 (presence); the 8-class count head ships but counts above the trained range are flagged low_confidence. Stoer-Wagner multi-node fusion." [[bin]] name = "cog-person-count" diff --git a/v2/crates/cog-person-count/src/inference.rs b/v2/crates/cog-person-count/src/inference.rs index 96f82e89..fc810398 100644 --- a/v2/crates/cog-person-count/src/inference.rs +++ b/v2/crates/cog-person-count/src/inference.rs @@ -24,6 +24,17 @@ pub const INPUT_TIMESTEPS: usize = 20; /// Count classification over {0, 1, ..., 7} persons. pub const COUNT_CLASSES: usize = 8; +/// Highest class the shipped `count_v1` weights were actually **trained** on. +/// +/// The count head has 8 logits, but `count_train_results.json` only has support +/// for classes 0 and 1 (`per_class_accuracy` keys are `"0"` and `"1"`). The model +/// is a presence detector (0 vs ≥1 person), **not** a calibrated multi-occupant +/// counter. An argmax landing on classes 2..=7 is out-of-distribution: the logits +/// there were never supervised against labelled data. We flag such outputs +/// `low_confidence` so downstream consumers don't trust a fabricated headcount. +/// (Multi-occupant *accuracy* is DATA-GATED — not fabricated here.) +pub const MAX_TRAINED_CLASS: usize = 1; + #[derive(Debug, Clone)] pub struct CsiWindow { pub data: Vec, @@ -45,6 +56,23 @@ impl CountPrediction { self.probs.iter().all(|v| v.is_finite()) && self.confidence.is_finite() } + /// True when the maximum-likelihood class is beyond what the shipped weights + /// were trained on ([`MAX_TRAINED_CLASS`]). Such a prediction is out-of- + /// distribution — the count head's logits for classes 2..=7 were never + /// supervised, so the headcount is not trustworthy. Surfaced as the + /// `low_confidence` field on the `person.count` event (honest-clip pattern). + pub fn is_low_confidence(&self) -> bool { + self.argmax() > MAX_TRAINED_CLASS + } + + /// Argmax clamped to [`MAX_TRAINED_CLASS`]. When the raw argmax is an + /// untrained class we clamp the *reported* count to the highest trained + /// class rather than emit a fabricated multi-occupant headcount. The raw + /// distribution is still available in `probs` for diagnostics. + pub fn clamped_count(&self) -> usize { + self.argmax().min(MAX_TRAINED_CLASS) + } + /// Maximum-likelihood class. pub fn argmax(&self) -> usize { let mut best_i = 0; diff --git a/v2/crates/cog-person-count/src/publisher.rs b/v2/crates/cog-person-count/src/publisher.rs index 2287a775..677788af 100644 --- a/v2/crates/cog-person-count/src/publisher.rs +++ b/v2/crates/cog-person-count/src/publisher.rs @@ -45,20 +45,35 @@ pub fn run_started(cog_id: &str, sensing_url: &str, poll_ms: u64, model_path: &s "sensing_url": sensing_url, "poll_ms": poll_ms, "model_path": model_path, + // Honest disclosure: the count head has 8 classes but the shipped + // weights were only trained on classes 0..=MAX_TRAINED_CLASS + // (presence, not multi-occupant counting). Counts above this are + // flagged `low_confidence` on each person.count event. + "count_max_trained_class": crate::inference::MAX_TRAINED_CLASS, + "count_classes": crate::inference::COUNT_CLASSES, }), }); } pub fn person_count(tick: u64, fused: &CountPrediction, n_nodes: usize) { let (lo, hi) = fused.p95_range(); + let low_confidence = fused.is_low_confidence(); emit_event(&Event { ts: now_secs(), - level: "info", + // An out-of-distribution count (argmax beyond the trained classes) is + // a warning, not a clean info reading. + level: if low_confidence { "warn" } else { "info" }, event: "person.count", fields: json!({ "tick": tick, - "count": fused.argmax(), + // Reported count is clamped to the trained range — we never emit a + // fabricated multi-occupant headcount the weights can't back. + "count": fused.clamped_count(), + // Raw argmax kept for diagnostics/audit. + "raw_count": fused.argmax(), "confidence": fused.confidence, + // True when argmax > MAX_TRAINED_CLASS (untrained class). + "low_confidence": low_confidence, "count_p95_low": lo, "count_p95_high": hi, "n_nodes": n_nodes, diff --git a/v2/crates/cog-person-count/tests/smoke.rs b/v2/crates/cog-person-count/tests/smoke.rs index 433c7155..2447cca1 100644 --- a/v2/crates/cog-person-count/tests/smoke.rs +++ b/v2/crates/cog-person-count/tests/smoke.rs @@ -4,7 +4,7 @@ use cog_person_count::{ fusion::{fuse_confidence_weighted, fuse_with_mincut_clip}, inference::{ CountPrediction, CsiWindow, InferenceEngine, SyntheticInput, COUNT_CLASSES, - INPUT_SUBCARRIERS, INPUT_TIMESTEPS, + INPUT_SUBCARRIERS, INPUT_TIMESTEPS, MAX_TRAINED_CLASS, }, }; @@ -83,6 +83,51 @@ fn fusion_passes_through_single_node() { assert!((out.confidence - 0.6).abs() < 1e-6); } +/// ADR-159 §A2 — the 8-class count head ships, but the weights were only +/// trained on classes 0/1 (presence). A prediction whose argmax lands on an +/// UNTRAINED class (2..=7) must be flagged `low_confidence` and the reported +/// count clamped to the trained range, so we never emit a fabricated +/// multi-occupant headcount. Fails on old code (no such flag/clamp existed). +#[test] +fn untrained_class_argmax_is_flagged_low_confidence() { + // Sanity: the trained ceiling is below the head width. + assert!(MAX_TRAINED_CLASS < COUNT_CLASSES - 1); + + // Mass on an untrained class (5 persons) — out-of-distribution. + let mut probs = [0.0_f32; COUNT_CLASSES]; + probs[5] = 0.9; + probs[1] = 0.1; + let oodp = CountPrediction { + probs, + confidence: 0.95, // even a "confident" softmax must be flagged + }; + assert_eq!(oodp.argmax(), 5); + assert!( + oodp.is_low_confidence(), + "argmax beyond MAX_TRAINED_CLASS must be flagged low_confidence" + ); + assert_eq!( + oodp.clamped_count(), + MAX_TRAINED_CLASS, + "reported count must clamp to the trained ceiling, not fabricate a headcount" + ); + + // A trained-range prediction (1 person) is NOT flagged. + let mut probs2 = [0.0_f32; COUNT_CLASSES]; + probs2[1] = 0.8; + probs2[0] = 0.2; + let inp = CountPrediction { + probs: probs2, + confidence: 0.8, + }; + assert_eq!(inp.argmax(), 1); + assert!( + !inp.is_low_confidence(), + "a trained-range count must not be flagged" + ); + assert_eq!(inp.clamped_count(), 1); +} + #[test] fn mincut_clip_with_high_cap_is_noop() { let mut probs = [0.0_f32; COUNT_CLASSES];