diff --git a/v2/crates/wifi-densepose-nn/src/lib.rs b/v2/crates/wifi-densepose-nn/src/lib.rs index 44264f47..87a97290 100644 --- a/v2/crates/wifi-densepose-nn/src/lib.rs +++ b/v2/crates/wifi-densepose-nn/src/lib.rs @@ -35,6 +35,8 @@ pub mod error; pub mod inference; #[cfg(feature = "onnx")] pub mod onnx; +/// ADR-146 — RF encoder multi-task heads + uncertainty + contrastive batcher. +pub mod rf_encoder; pub mod tensor; pub mod translator; diff --git a/v2/crates/wifi-densepose-nn/src/rf_encoder.rs b/v2/crates/wifi-densepose-nn/src/rf_encoder.rs new file mode 100644 index 00000000..1ff8340d --- /dev/null +++ b/v2/crates/wifi-densepose-nn/src/rf_encoder.rs @@ -0,0 +1,347 @@ +//! ADR-146 — RF encoder multi-task heads + uncertainty quantification. +//! +//! Extends ADR-024 (AETHER contrastive embedding) with seven task-specific head +//! branches over a shared RF embedding, per-head uncertainty, a +//! calibration-robustness loss tying invariance to the ADR-135 `calibration_id`, +//! and a `ContrastiveBatcher` sampling contract. The tensor ABI is **pure-Rust +//! `f32`** (no backend-specific tensor type at this boundary) so inference is +//! deterministic and witnessable (ADR-136 §2.5) and a head can be toggled by the +//! ADR-145 ablation matrix. + +/// Shared RF embedding dimension (ADR-146 / ADR-024 AETHER). +pub const EMBEDDING_DIM: usize = 256; + +/// A 256-d shared RF embedding (pure-Rust f32 ABI). +#[derive(Debug, Clone, PartialEq)] +pub struct RfEmbedding(pub Vec); + +impl RfEmbedding { + /// Wrap a vector, asserting it is [`EMBEDDING_DIM`] long. + #[must_use] + pub fn new(v: Vec) -> Self { + debug_assert_eq!(v.len(), EMBEDDING_DIM, "embedding must be {EMBEDDING_DIM}-d"); + Self(v) + } + + /// Squared L2 distance to another embedding. + #[must_use] + pub fn sq_dist(&self, other: &RfEmbedding) -> f32 { + self.0.iter().zip(&other.0).map(|(a, b)| (a - b).powi(2)).sum() + } +} + +/// The seven task heads over the shared encoder (ADR-146 §2.1). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum TaskKind { + /// 17-keypoint pose. + Pose, + /// Binary presence. + Presence, + /// Person count. + Count, + /// Activity class. + Activity, + /// Vital signs (HR/BR). + Vitals, + /// Gait signature. + Gait, + /// Identity embedding (AETHER re-ID). + IdentityEmbedding, +} + +impl TaskKind { + /// All seven heads. + pub const ALL: [TaskKind; 7] = [ + TaskKind::Pose, + TaskKind::Presence, + TaskKind::Count, + TaskKind::Activity, + TaskKind::Vitals, + TaskKind::Gait, + TaskKind::IdentityEmbedding, + ]; +} + +/// One head's output: task values plus a scalar predictive uncertainty +/// (ADR-146 §2.2). `uncertainty` mirrors the spirit of the ADR-136 +/// `QualityScored` trait — lower is more confident. +#[derive(Debug, Clone, PartialEq)] +pub struct HeadOutput { + /// Which head produced this. + pub task: TaskKind, + /// Raw output activations. + pub values: Vec, + /// Predictive uncertainty in [0, ∞); softplus of a learned log-variance. + pub uncertainty: f32, +} + +impl HeadOutput { + /// Confidence in [0, 1] derived from uncertainty (`1 / (1 + uncertainty)`), + /// matching the ADR-136 `QualityScored::quality_score` contract shape. + #[must_use] + pub fn confidence(&self) -> f32 { + 1.0 / (1.0 + self.uncertainty) + } +} + +/// A linear task head: `out = W·emb + b`, plus a separate scalar log-variance +/// projection `lv = wᵥ·emb + bᵥ` whose softplus is the predictive uncertainty. +#[derive(Debug, Clone)] +pub struct LinearHead { + task: TaskKind, + /// Row-major `[out_dim × EMBEDDING_DIM]` weights. + w: Vec, + b: Vec, + out_dim: usize, + /// Uncertainty (log-variance) projection over the embedding. + var_w: Vec, + var_b: f32, +} + +impl LinearHead { + /// Build a head with given weights. `w.len()` must be `out_dim * EMBEDDING_DIM`. + #[must_use] + pub fn new(task: TaskKind, out_dim: usize, w: Vec, b: Vec, var_w: Vec, var_b: f32) -> Self { + assert_eq!(w.len(), out_dim * EMBEDDING_DIM, "weight shape mismatch"); + assert_eq!(b.len(), out_dim, "bias shape mismatch"); + assert_eq!(var_w.len(), EMBEDDING_DIM, "var weight shape mismatch"); + Self { task, w, b, out_dim, var_w, var_b } + } + + /// A zero-initialised head (uncertainty = softplus(0) ≈ 0.693). + #[must_use] + pub fn zeros(task: TaskKind, out_dim: usize) -> Self { + Self::new( + task, + out_dim, + vec![0.0; out_dim * EMBEDDING_DIM], + vec![0.0; out_dim], + vec![0.0; EMBEDDING_DIM], + 0.0, + ) + } + + /// Forward pass over a shared embedding. + #[must_use] + pub fn forward(&self, emb: &RfEmbedding) -> HeadOutput { + let mut values = vec![0.0f32; self.out_dim]; + for o in 0..self.out_dim { + let row = &self.w[o * EMBEDDING_DIM..(o + 1) * EMBEDDING_DIM]; + let dot: f32 = row.iter().zip(&emb.0).map(|(wi, xi)| wi * xi).sum(); + values[o] = dot + self.b[o]; + } + let log_var: f32 = self.var_w.iter().zip(&emb.0).map(|(wi, xi)| wi * xi).sum::() + self.var_b; + let uncertainty = softplus(log_var); + HeadOutput { task: self.task, values, uncertainty } + } +} + +fn softplus(x: f32) -> f32 { + // Numerically stable softplus. + if x > 20.0 { + x + } else { + (1.0 + x.exp()).ln() + } +} + +/// Multi-task encoder: a shared embedding feeding a set of [`LinearHead`]s +/// (ADR-146 §2.1). Heads can be subset for ADR-145 ablation. +#[derive(Debug, Clone, Default)] +pub struct MultiTaskHeads { + heads: Vec, +} + +impl MultiTaskHeads { + /// Empty head set. + #[must_use] + pub fn new() -> Self { + Self { heads: Vec::new() } + } + + /// Add a head. + pub fn push(&mut self, head: LinearHead) { + self.heads.push(head); + } + + /// Number of active heads. + #[must_use] + pub fn len(&self) -> usize { + self.heads.len() + } + + /// Whether no heads are configured. + #[must_use] + pub fn is_empty(&self) -> bool { + self.heads.is_empty() + } + + /// Run every head on the shared embedding. + #[must_use] + pub fn forward(&self, emb: &RfEmbedding) -> Vec { + self.heads.iter().map(|h| h.forward(emb)).collect() + } + + /// Run only the heads in `enabled` (ADR-145 ablation toggle). + #[must_use] + pub fn forward_subset(&self, emb: &RfEmbedding, enabled: &[TaskKind]) -> Vec { + self.heads + .iter() + .filter(|h| enabled.contains(&h.task)) + .map(|h| h.forward(emb)) + .collect() + } +} + +/// Calibration-robustness loss (ADR-146 §2.3): the encoder should produce the +/// same embedding for the same physical input under two different ADR-135 +/// calibration baselines. Returns the mean squared embedding difference — a +/// penalty that is 0 under perfect calibration invariance. +#[must_use] +pub fn calibration_robustness_loss(under_cal_a: &RfEmbedding, under_cal_b: &RfEmbedding) -> f32 { + under_cal_a.sq_dist(under_cal_b) / EMBEDDING_DIM as f32 +} + +/// Triplet contrastive loss (ADR-024 / ADR-146 §2.4): pull `anchor` toward +/// `positive` (same physical state), push from `negative` (different), with a +/// margin. `max(0, d(a,p) - d(a,n) + margin)`. +#[must_use] +pub fn triplet_loss(anchor: &RfEmbedding, positive: &RfEmbedding, negative: &RfEmbedding, margin: f32) -> f32 { + (anchor.sq_dist(positive) - anchor.sq_dist(negative) + margin).max(0.0) +} + +/// A contrastive training triplet over the shared embedding space. +#[derive(Debug, Clone)] +pub struct Triplet { + /// Anchor sample index. + pub anchor: usize, + /// Positive (same state, different environment) index. + pub positive: usize, + /// Negative (different state) index. + pub negative: usize, +} + +/// Formalised contrastive pair/triplet sampler (ADR-146 §2.4): positives are the +/// *same physical state across different environments* (cross-room invariance, +/// ADR-027 MERIDIAN); negatives are *different states*. +#[derive(Debug, Clone)] +pub struct ContrastiveBatcher { + /// `state_of[i]` = the physical-state label of sample `i`. + state_of: Vec, + /// `env_of[i]` = the environment/room label of sample `i`. + env_of: Vec, +} + +impl ContrastiveBatcher { + /// Build from per-sample (state, environment) labels. + #[must_use] + pub fn new(state_of: Vec, env_of: Vec) -> Self { + assert_eq!(state_of.len(), env_of.len(), "label vectors must align"); + Self { state_of, env_of } + } + + /// Deterministically enumerate triplets: for each anchor, the first sample + /// with the *same state but a different environment* is the positive, and + /// the first sample with a *different state* is the negative. Anchors with + /// no valid positive or negative are skipped. Determinism (lowest-index + /// choice) keeps the batch witnessable (ADR-136 §2.5). + #[must_use] + pub fn triplets(&self) -> Vec { + let n = self.state_of.len(); + let mut out = Vec::new(); + for a in 0..n { + let positive = (0..n).find(|&p| { + p != a && self.state_of[p] == self.state_of[a] && self.env_of[p] != self.env_of[a] + }); + let negative = (0..n).find(|&q| self.state_of[q] != self.state_of[a]); + if let (Some(positive), Some(negative)) = (positive, negative) { + out.push(Triplet { anchor: a, positive, negative }); + } + } + out + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn emb(fill: f32) -> RfEmbedding { + RfEmbedding::new(vec![fill; EMBEDDING_DIM]) + } + + #[test] + fn head_forward_produces_values_and_finite_uncertainty() { + let head = LinearHead::zeros(TaskKind::Presence, 2); + let out = head.forward(&emb(1.0)); + assert_eq!(out.values, vec![0.0, 0.0]); // zero weights + assert!(out.uncertainty.is_finite() && out.uncertainty > 0.0); + assert!((out.confidence() - 1.0 / (1.0 + out.uncertainty)).abs() < 1e-6); + } + + #[test] + fn uncertainty_responds_to_log_variance_weights() { + // var_w all 1 → log_var = sum(emb) = 256 → softplus ≈ 256 (clamped path). + let head = LinearHead::new( + TaskKind::Vitals, + 1, + vec![0.0; EMBEDDING_DIM], + vec![0.0], + vec![1.0; EMBEDDING_DIM], + 0.0, + ); + let out = head.forward(&emb(1.0)); + assert!(out.uncertainty > 100.0, "high log-var → high uncertainty"); + assert!(out.confidence() < 0.02); + } + + #[test] + fn calibration_robustness_loss_zero_for_identical() { + assert_eq!(calibration_robustness_loss(&emb(0.5), &emb(0.5)), 0.0); + assert!(calibration_robustness_loss(&emb(0.0), &emb(1.0)) > 0.0); + } + + #[test] + fn triplet_loss_properties() { + let a = emb(0.0); + let p = emb(0.1); // close + let n = emb(5.0); // far + // d(a,p) << d(a,n) → loss should be 0 with a modest margin. + assert_eq!(triplet_loss(&a, &p, &n, 0.5), 0.0); + // Swap: positive far, negative close → positive loss. + assert!(triplet_loss(&a, &n, &p, 0.5) > 0.0); + } + + #[test] + fn multitask_subset_ablation() { + let mut heads = MultiTaskHeads::new(); + heads.push(LinearHead::zeros(TaskKind::Presence, 1)); + heads.push(LinearHead::zeros(TaskKind::Pose, 51)); + heads.push(LinearHead::zeros(TaskKind::Vitals, 2)); + assert_eq!(heads.forward(&emb(1.0)).len(), 3); + // Ablate to just presence + vitals. + let sub = heads.forward_subset(&emb(1.0), &[TaskKind::Presence, TaskKind::Vitals]); + assert_eq!(sub.len(), 2); + assert!(sub.iter().all(|o| o.task != TaskKind::Pose)); + } + + #[test] + fn contrastive_batcher_samples_cross_env_positives() { + // samples: 0=(stateA,room0) 1=(stateA,room1) 2=(stateB,room0) + let b = ContrastiveBatcher::new(vec![0, 0, 1], vec![0, 1, 0]); + let trips = b.triplets(); + // Anchor 0: positive=1 (same state, diff room), negative=2 (diff state). + let t0 = trips.iter().find(|t| t.anchor == 0).unwrap(); + assert_eq!(t0.positive, 1); + assert_eq!(t0.negative, 2); + // Anchor 2 (stateB) has no same-state-diff-env positive → skipped. + assert!(trips.iter().all(|t| t.anchor != 2)); + // Deterministic. + assert_eq!(b.triplets().len(), trips.len()); + } + + #[test] + fn seven_task_heads() { + assert_eq!(TaskKind::ALL.len(), 7); + } +}