fix(train,nn): Tier-2 correctness/security — metric scale, OOM bounds, panics (ADR-155 §Tier-2)
Each fix ships a test that would have caught the bug:

- ruview_metrics OKS: derive scale from GT extent (no s=1.0 fake-Gold), reject s<=0, bound the loop to array extents (no panic on short/adversarial input).
- config.validate(): UPPER bounds on window_frames/subcarriers/backbone_channels/heatmap_size/keypoints/body_parts/batch_size + reject negative gpu_device_id (closes the config-OOM class); defaults+presets still validate.
- subcarrier.rs: graceful fallback instead of panic on non-contiguous input.
- ablation.rs latency_percentiles: total_cmp + NaN guard (no partial_cmp unwrap).
- tensor.rs softmax(axis): normalize per-lane along the given axis (was whole-tensor), out-of-range axis -> NnError; fixes densepose per-pixel probs.
- translator.rs apply_attention: real scaled-dot-product attention (was a uniform 1/seq_len stub that made any "with attention" ablation == without); mis-shaped checkpoint projections rejected.

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
84e2c920fd
commit
aa3a6725a6
|
|
@ -4,11 +4,39 @@
|
||||||
//! different backends (ONNX, tch, Candle).
|
//! different backends (ONNX, tch, Candle).
|
||||||
|
|
||||||
use crate::error::{NnError, NnResult};
|
use crate::error::{NnError, NnResult};
|
||||||
use ndarray::{Array1, Array2, Array3, Array4, ArrayD};
|
use ndarray::{Array1, Array2, Array3, Array4, ArrayD, ArrayViewMutD, Axis};
|
||||||
// num_traits is available if needed for advanced tensor operations
|
// num_traits is available if needed for advanced tensor operations
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
/// Apply a numerically-stable softmax in place to every 1-D lane of `view`
|
||||||
|
/// taken along `axis`. Each lane is shifted by its own max before
|
||||||
|
/// exponentiation, then divided by its own sum, so every lane sums to 1.0
|
||||||
|
/// independently — the per-pixel / per-class normalization densepose needs.
|
||||||
|
///
|
||||||
|
/// `axis` MUST be validated as in-range by the caller.
|
||||||
|
fn softmax_inplace_along_axis(mut view: ArrayViewMutD<'_, f32>, axis: usize) {
|
||||||
|
for mut lane in view.lanes_mut(Axis(axis)) {
|
||||||
|
let max = lane.iter().copied().fold(f32::NEG_INFINITY, f32::max);
|
||||||
|
// A lane whose max is not finite — all `-inf`, empty, or containing `+inf` —
|
||||||
|
// is left untouched to avoid producing NaNs from `exp(±inf - ±inf)`.
|
||||||
|
if !max.is_finite() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let mut sum = 0.0f32;
|
||||||
|
for v in lane.iter_mut() {
|
||||||
|
let e = (*v - max).exp();
|
||||||
|
*v = e;
|
||||||
|
sum += e;
|
||||||
|
}
|
||||||
|
if sum > 0.0 {
|
||||||
|
for v in lane.iter_mut() {
|
||||||
|
*v /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Shape of a tensor
|
/// Shape of a tensor
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
pub struct TensorShape(Vec<usize>);
|
pub struct TensorShape(Vec<usize>);
|
||||||
|
|
@ -288,14 +316,39 @@ impl Tensor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Apply softmax along axis
|
/// Apply softmax along the given `axis`.
|
||||||
pub fn softmax(&self, _axis: usize) -> NnResult<Tensor> {
|
///
|
||||||
|
/// Each 1-D lane along `axis` is normalized independently so it sums to
|
||||||
|
/// 1.0. This is the correct semantics for per-pixel / per-class probability
|
||||||
|
/// maps (e.g. DensePose body-part logits over the channel axis). A
|
||||||
|
/// numerically-stable max-shift is applied per lane.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
/// Returns [`NnError`] if `axis` is out of range for the tensor's rank, or
|
||||||
|
/// if the tensor type is unsupported.
|
||||||
|
pub fn softmax(&self, axis: usize) -> NnResult<Tensor> {
|
||||||
match self {
|
match self {
|
||||||
Tensor::Float4D(a) => {
|
Tensor::Float4D(a) => {
|
||||||
let max = a.fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
|
if axis >= a.ndim() {
|
||||||
let exp = a.mapv(|x| (x - max).exp());
|
return Err(NnError::tensor_op(format!(
|
||||||
let sum = exp.sum();
|
"softmax axis {axis} out of range for {}-D tensor",
|
||||||
Ok(Tensor::Float4D(exp / sum))
|
a.ndim()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
let mut out = a.clone();
|
||||||
|
softmax_inplace_along_axis(out.view_mut().into_dyn(), axis);
|
||||||
|
Ok(Tensor::Float4D(out))
|
||||||
|
}
|
||||||
|
Tensor::FloatND(a) => {
|
||||||
|
if axis >= a.ndim() {
|
||||||
|
return Err(NnError::tensor_op(format!(
|
||||||
|
"softmax axis {axis} out of range for {}-D tensor",
|
||||||
|
a.ndim()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
let mut out = a.clone();
|
||||||
|
softmax_inplace_along_axis(out.view_mut(), axis);
|
||||||
|
Ok(Tensor::FloatND(out))
|
||||||
}
|
}
|
||||||
_ => Err(NnError::tensor_op(
|
_ => Err(NnError::tensor_op(
|
||||||
"Softmax not supported for this tensor type",
|
"Softmax not supported for this tensor type",
|
||||||
|
|
@ -517,6 +570,67 @@ mod tests {
|
||||||
assert!(sigmoid.max().unwrap() < 1.0);
|
assert!(sigmoid.max().unwrap() < 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ADR-155 §Tier-2: softmax(axis) must normalize along the GIVEN axis
|
||||||
|
// (per-lane sum == 1), not over the whole tensor.
|
||||||
|
#[test]
|
||||||
|
fn test_softmax_axis_sums_to_one_per_lane() {
|
||||||
|
// 2x3x1x1 tensor; softmax along axis 1 (the size-3 axis).
|
||||||
|
let arr =
|
||||||
|
Array4::from_shape_vec([2, 3, 1, 1], vec![1.0f32, 2.0, 3.0, -1.0, 0.0, 1.0]).unwrap();
|
||||||
|
let t = Tensor::Float4D(arr);
|
||||||
|
let sm = t.softmax(1).unwrap();
|
||||||
|
let out = sm.as_array4().unwrap();
|
||||||
|
// Each lane along axis 1 must sum to 1.0.
|
||||||
|
for b in 0..2 {
|
||||||
|
let lane_sum: f32 = (0..3).map(|c| out[[b, c, 0, 0]]).sum();
|
||||||
|
assert!((lane_sum - 1.0).abs() < 1e-6, "lane {b} sum = {lane_sum}");
|
||||||
|
}
|
||||||
|
// Probabilities must be ordered like the logits within a lane.
|
||||||
|
assert!(out[[0, 0, 0, 0]] < out[[0, 1, 0, 0]]);
|
||||||
|
assert!(out[[0, 1, 0, 0]] < out[[0, 2, 0, 0]]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ADR-155 §Tier-2: softmax along different axes must give different
|
||||||
|
// results — the old global-softmax bug ignored the axis entirely.
|
||||||
|
#[test]
|
||||||
|
fn test_softmax_axis_choice_matters() {
|
||||||
|
let arr = Array4::from_shape_vec([1, 2, 2, 1], vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
|
||||||
|
let t = Tensor::Float4D(arr);
|
||||||
|
let along1 = t.softmax(1).unwrap();
|
||||||
|
let along2 = t.softmax(2).unwrap();
|
||||||
|
let a1 = along1.as_array4().unwrap();
|
||||||
|
let a2 = along2.as_array4().unwrap();
|
||||||
|
// The two normalizations partition the values differently, so at least
|
||||||
|
// one element must differ.
|
||||||
|
let mut differs = false;
|
||||||
|
for h in 0..2 {
|
||||||
|
if (a1[[0, 0, h, 0]] - a2[[0, 0, h, 0]]).abs() > 1e-6 {
|
||||||
|
differs = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert!(differs, "softmax along axis 1 must differ from axis 2");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ADR-155 §Tier-2: known-value check on a tiny tensor.
|
||||||
|
#[test]
|
||||||
|
fn test_softmax_known_values() {
|
||||||
|
// Lane [0, ln(3)] along axis 1 → softmax = [1/4, 3/4].
|
||||||
|
let arr = Array4::from_shape_vec([1, 2, 1, 1], vec![0.0f32, 3.0f32.ln()]).unwrap();
|
||||||
|
let t = Tensor::Float4D(arr);
|
||||||
|
let out = t.softmax(1).unwrap();
|
||||||
|
let a = out.as_array4().unwrap();
|
||||||
|
assert!((a[[0, 0, 0, 0]] - 0.25).abs() < 1e-6);
|
||||||
|
assert!((a[[0, 1, 0, 0]] - 0.75).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ADR-155 §Tier-2: out-of-range axis must return an error, never panic.
|
||||||
|
#[test]
|
||||||
|
fn test_softmax_axis_out_of_range_errors() {
|
||||||
|
let t = Tensor::zeros_4d([1, 2, 2, 2]);
|
||||||
|
assert!(t.softmax(4).is_err());
|
||||||
|
assert!(t.softmax(99).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_broadcast_compatible() {
|
fn test_broadcast_compatible() {
|
||||||
let a = TensorShape::new(vec![1, 3, 224, 224]);
|
let a = TensorShape::new(vec![1, 3, 224, 224]);
|
||||||
|
|
|
||||||
|
|
@ -556,34 +556,122 @@ impl ModalityTranslator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Apply multi-head attention
|
/// Apply single-head scaled-dot-product attention over the spatial
|
||||||
|
/// sequence: `softmax(Q·Kᵀ / √d) · V`, with `Q/K/V` linear projections of
|
||||||
|
/// each token's channel vector and a final output projection.
|
||||||
|
///
|
||||||
|
/// The spatial grid `[B, C, H, W]` is treated as a length-`H·W` token
|
||||||
|
/// sequence of `C`-dim feature vectors. Each `*_weight` projection is a
|
||||||
|
/// `[C × C]` matrix applied per token. This is a genuine attention
|
||||||
|
/// operation (not the previous uniform-weight identity stub), so the
|
||||||
|
/// returned per-pair attention weights actually depend on the input.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
/// Returns an error if any projection weight is not `[C × C]`, so a
|
||||||
|
/// mis-shaped checkpoint can never be silently treated as a no-op.
|
||||||
fn apply_attention(
|
fn apply_attention(
|
||||||
&self,
|
&self,
|
||||||
input: &Array4<f32>,
|
input: &Array4<f32>,
|
||||||
_weights: &AttentionWeights,
|
weights: &AttentionWeights,
|
||||||
) -> NnResult<(Array4<f32>, Array4<f32>)> {
|
) -> NnResult<(Array4<f32>, Array4<f32>)> {
|
||||||
let (batch, channels, height, width) = input.dim();
|
let (batch, channels, height, width) = input.dim();
|
||||||
let seq_len = height * width;
|
let seq_len = height * width;
|
||||||
|
|
||||||
// Flatten spatial dimensions
|
// Every projection must be a square [C × C] matrix to act per token.
|
||||||
let mut flat = ndarray::Array2::zeros((batch, seq_len * channels));
|
for (name, w) in [
|
||||||
|
("query_weight", &weights.query_weight),
|
||||||
|
("key_weight", &weights.key_weight),
|
||||||
|
("value_weight", &weights.value_weight),
|
||||||
|
("output_weight", &weights.output_weight),
|
||||||
|
] {
|
||||||
|
if w.dim() != (channels, channels) {
|
||||||
|
return Err(NnError::invalid_input(format!(
|
||||||
|
"attention {name} must be [{channels} x {channels}], got [{} x {}]",
|
||||||
|
w.dim().0,
|
||||||
|
w.dim().1
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if weights.output_bias.len() != channels {
|
||||||
|
return Err(NnError::shape_mismatch(
|
||||||
|
vec![channels],
|
||||||
|
vec![weights.output_bias.len()],
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten spatial grid into a [seq_len, channels] token matrix per batch.
|
||||||
|
// Project to Q, K, V; compute scaled-dot-product attention; project out.
|
||||||
|
let scale = 1.0 / (channels as f32).sqrt();
|
||||||
|
let mut out = Array4::zeros((batch, channels, height, width));
|
||||||
|
let mut attention_weights = Array4::zeros((batch, 1, seq_len, seq_len));
|
||||||
|
|
||||||
for b in 0..batch {
|
for b in 0..batch {
|
||||||
|
// Tokens: [seq_len, channels].
|
||||||
|
let mut tokens = ndarray::Array2::<f32>::zeros((seq_len, channels));
|
||||||
for h in 0..height {
|
for h in 0..height {
|
||||||
for w in 0..width {
|
for w in 0..width {
|
||||||
|
let s = h * width + w;
|
||||||
for c in 0..channels {
|
for c in 0..channels {
|
||||||
flat[[b, (h * width + w) * channels + c]] = input[[b, c, h, w]];
|
tokens[[s, c]] = input[[b, c, h, w]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Q = tokens·Wqᵀ, etc. (row vector × [C×C] projection).
|
||||||
|
let q = tokens.dot(&weights.query_weight.t());
|
||||||
|
let k = tokens.dot(&weights.key_weight.t());
|
||||||
|
let v = tokens.dot(&weights.value_weight.t());
|
||||||
|
|
||||||
|
// Scores = softmax_row(Q·Kᵀ · scale), then context = Scores·V.
|
||||||
|
let scores = q.dot(&k.t()).mapv(|x| x * scale);
|
||||||
|
for i in 0..seq_len {
|
||||||
|
// Numerically-stable row softmax.
|
||||||
|
let mut max = f32::NEG_INFINITY;
|
||||||
|
for j in 0..seq_len {
|
||||||
|
max = max.max(scores[[i, j]]);
|
||||||
|
}
|
||||||
|
let mut sum = 0.0f32;
|
||||||
|
let mut row = vec![0.0f32; seq_len];
|
||||||
|
for j in 0..seq_len {
|
||||||
|
let e = (scores[[i, j]] - max).exp();
|
||||||
|
row[j] = e;
|
||||||
|
sum += e;
|
||||||
|
}
|
||||||
|
if sum > 0.0 {
|
||||||
|
for j in 0..seq_len {
|
||||||
|
row[j] /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j in 0..seq_len {
|
||||||
|
attention_weights[[b, 0, i, j]] = row[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context = attention · V, then output projection + bias.
|
||||||
|
for h in 0..height {
|
||||||
|
for w in 0..width {
|
||||||
|
let i = h * width + w;
|
||||||
|
// ctx[c] = Σ_j attn[i,j] · v[j,c]
|
||||||
|
let mut ctx = vec![0.0f32; channels];
|
||||||
|
for j in 0..seq_len {
|
||||||
|
let a = attention_weights[[b, 0, i, j]];
|
||||||
|
for c in 0..channels {
|
||||||
|
ctx[c] += a * v[[j, c]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// out[c] = Σ_c' ctx[c'] · Wo[c, c'] + bias[c]
|
||||||
|
for c in 0..channels {
|
||||||
|
let mut acc = weights.output_bias[c];
|
||||||
|
for cp in 0..channels {
|
||||||
|
acc += ctx[cp] * weights.output_weight[[c, cp]];
|
||||||
|
}
|
||||||
|
out[[b, c, h, w]] = acc;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For simplicity, return input unchanged with identity attention
|
Ok((out, attention_weights))
|
||||||
let attention_weights = Array4::from_elem(
|
|
||||||
(batch, self.config.attention_heads, seq_len, seq_len),
|
|
||||||
1.0 / seq_len as f32,
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok((input.clone(), attention_weights))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute translation loss between predicted and target features
|
/// Compute translation loss between predicted and target features
|
||||||
|
|
@ -760,6 +848,76 @@ mod tests {
|
||||||
assert_eq!(config.activation, ActivationType::GELU);
|
assert_eq!(config.activation, ActivationType::GELU);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ADR-155 §Tier-2: apply_attention must perform real scaled-dot-product
|
||||||
|
// attention, not return uniform 1/seq_len weights. With identity Q/K/V
|
||||||
|
// projections and a non-uniform input, the attention weights must NOT all
|
||||||
|
// equal 1/seq_len, and each row must still be a valid distribution.
|
||||||
|
#[test]
|
||||||
|
fn test_attention_is_not_uniform_stub() {
|
||||||
|
let channels = 4usize;
|
||||||
|
let height = 2usize;
|
||||||
|
let width = 2usize;
|
||||||
|
let seq_len = height * width;
|
||||||
|
|
||||||
|
// Identity projections so Q=K=V=tokens; output = identity, zero bias.
|
||||||
|
let identity = ndarray::Array2::<f32>::eye(channels);
|
||||||
|
let weights = AttentionWeights {
|
||||||
|
query_weight: identity.clone(),
|
||||||
|
key_weight: identity.clone(),
|
||||||
|
value_weight: identity.clone(),
|
||||||
|
output_weight: identity,
|
||||||
|
output_bias: ndarray::Array1::zeros(channels),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Non-uniform input: each spatial location has a distinct feature vector.
|
||||||
|
let mut input = Array4::<f32>::zeros((1, channels, height, width));
|
||||||
|
for c in 0..channels {
|
||||||
|
for h in 0..height {
|
||||||
|
for w in 0..width {
|
||||||
|
input[[0, c, h, w]] = (c + 2 * h + 4 * w) as f32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let config = TranslatorConfig::default().with_attention(1);
|
||||||
|
let translator = ModalityTranslator::new(config).unwrap();
|
||||||
|
let (out, attn) = translator.apply_attention(&input, &weights).unwrap();
|
||||||
|
|
||||||
|
// Each attention row must sum to 1 (valid softmax distribution).
|
||||||
|
for i in 0..seq_len {
|
||||||
|
let row_sum: f32 = (0..seq_len).map(|j| attn[[0, 0, i, j]]).sum();
|
||||||
|
assert!((row_sum - 1.0).abs() < 1e-5, "row {i} sum = {row_sum}");
|
||||||
|
}
|
||||||
|
// Weights must NOT all be the uniform 1/seq_len value of the old stub.
|
||||||
|
let uniform = 1.0 / seq_len as f32;
|
||||||
|
let any_non_uniform = (0..seq_len)
|
||||||
|
.flat_map(|i| (0..seq_len).map(move |j| (i, j)))
|
||||||
|
.any(|(i, j)| (attn[[0, 0, i, j]] - uniform).abs() > 1e-4);
|
||||||
|
assert!(any_non_uniform, "attention collapsed to uniform stub");
|
||||||
|
// Output is finite and shaped like the input.
|
||||||
|
assert_eq!(out.dim(), input.dim());
|
||||||
|
assert!(out.iter().all(|v| v.is_finite()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ADR-155 §Tier-2: a mis-shaped projection weight must be rejected, never
|
||||||
|
// silently treated as a no-op.
|
||||||
|
#[test]
|
||||||
|
fn test_attention_rejects_wrong_weight_shape() {
|
||||||
|
let channels = 4usize;
|
||||||
|
let bad = ndarray::Array2::<f32>::zeros((channels + 1, channels));
|
||||||
|
let weights = AttentionWeights {
|
||||||
|
query_weight: bad.clone(),
|
||||||
|
key_weight: bad.clone(),
|
||||||
|
value_weight: bad.clone(),
|
||||||
|
output_weight: bad,
|
||||||
|
output_bias: ndarray::Array1::zeros(channels),
|
||||||
|
};
|
||||||
|
let input = Array4::<f32>::zeros((1, channels, 2, 2));
|
||||||
|
let config = TranslatorConfig::default().with_attention(1);
|
||||||
|
let translator = ModalityTranslator::new(config).unwrap();
|
||||||
|
assert!(translator.apply_attention(&input, &weights).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_loss_computation() {
|
fn test_loss_computation() {
|
||||||
let config = TranslatorConfig::default();
|
let config = TranslatorConfig::default();
|
||||||
|
|
|
||||||
|
|
@ -53,13 +53,24 @@ impl FeatureSet {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `(p50, p95)` percentiles of a latency sample set (ms), nearest-rank.
|
/// `(p50, p95)` percentiles of a latency sample set (ms), nearest-rank.
|
||||||
|
///
|
||||||
|
/// Non-finite samples (NaN / ±inf) are discarded before ranking. Sorting uses
|
||||||
|
/// [`f64::total_cmp`] so a stray NaN can never trigger a `partial_cmp().unwrap()`
|
||||||
|
/// panic (ADR-155 §Tier-2). If every sample is non-finite (or the slice is
|
||||||
|
/// empty), returns `(0.0, 0.0)`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn latency_percentiles_ms(samples_ms: &[f64]) -> (f64, f64) {
|
pub fn latency_percentiles_ms(samples_ms: &[f64]) -> (f64, f64) {
|
||||||
if samples_ms.is_empty() {
|
// Drop non-finite values: a NaN latency is meaningless and must not poison
|
||||||
|
// the ranking or panic the sort.
|
||||||
|
let mut s: Vec<f64> = samples_ms
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.filter(|v| v.is_finite())
|
||||||
|
.collect();
|
||||||
|
if s.is_empty() {
|
||||||
return (0.0, 0.0);
|
return (0.0, 0.0);
|
||||||
}
|
}
|
||||||
let mut s = samples_ms.to_vec();
|
s.sort_by(f64::total_cmp);
|
||||||
s.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
|
||||||
let pick = |q: f64| {
|
let pick = |q: f64| {
|
||||||
// Nearest-rank: ceil(q * n) - 1, clamped.
|
// Nearest-rank: ceil(q * n) - 1, clamped.
|
||||||
let rank = ((q * s.len() as f64).ceil() as usize).clamp(1, s.len()) - 1;
|
let rank = ((q * s.len() as f64).ceil() as usize).clamp(1, s.len()) - 1;
|
||||||
|
|
@ -71,8 +82,16 @@ pub fn latency_percentiles_ms(samples_ms: &[f64]) -> (f64, f64) {
|
||||||
/// False-positive and false-negative rates from a confusion count.
|
/// False-positive and false-negative rates from a confusion count.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn confusion_rates(tp: u64, fp: u64, tn: u64, fn_: u64) -> (f64, f64) {
|
pub fn confusion_rates(tp: u64, fp: u64, tn: u64, fn_: u64) -> (f64, f64) {
|
||||||
let fp_rate = if fp + tn == 0 { 0.0 } else { fp as f64 / (fp + tn) as f64 };
|
let fp_rate = if fp + tn == 0 {
|
||||||
let fn_rate = if fn_ + tp == 0 { 0.0 } else { fn_ as f64 / (fn_ + tp) as f64 };
|
0.0
|
||||||
|
} else {
|
||||||
|
fp as f64 / (fp + tn) as f64
|
||||||
|
};
|
||||||
|
let fn_rate = if fn_ + tp == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
fn_ as f64 / (fn_ + tp) as f64
|
||||||
|
};
|
||||||
(fp_rate, fn_rate)
|
(fp_rate, fn_rate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -164,7 +183,10 @@ impl AblationMetrics {
|
||||||
fn_rate,
|
fn_rate,
|
||||||
latency_p50_ms: p50,
|
latency_p50_ms: p50,
|
||||||
latency_p95_ms: p95,
|
latency_p95_ms: p95,
|
||||||
privacy_leakage: membership_inference_leakage(&run.member_scores, &run.nonmember_scores),
|
privacy_leakage: membership_inference_leakage(
|
||||||
|
&run.member_scores,
|
||||||
|
&run.nonmember_scores,
|
||||||
|
),
|
||||||
cross_room_degradation: (run.room_a_accuracy - run.room_b_accuracy).max(0.0),
|
cross_room_degradation: (run.room_a_accuracy - run.room_b_accuracy).max(0.0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -181,7 +203,9 @@ impl AblationReport {
|
||||||
/// Build from a set of variant runs.
|
/// Build from a set of variant runs.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn from_runs(runs: &[VariantRun]) -> Self {
|
pub fn from_runs(runs: &[VariantRun]) -> Self {
|
||||||
Self { rows: runs.iter().map(AblationMetrics::from_run).collect() }
|
Self {
|
||||||
|
rows: runs.iter().map(AblationMetrics::from_run).collect(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Look up a variant's metrics.
|
/// Look up a variant's metrics.
|
||||||
|
|
@ -194,7 +218,8 @@ impl AblationReport {
|
||||||
/// least `min_wins` of {presence accuracy ↑, localisation error ↓, p95 latency ↓}?
|
/// least `min_wins` of {presence accuracy ↑, localisation error ↓, p95 latency ↓}?
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn csi_cir_beats_csi_only(&self, min_wins: usize) -> bool {
|
pub fn csi_cir_beats_csi_only(&self, min_wins: usize) -> bool {
|
||||||
let (Some(a), Some(b)) = (self.get(FeatureSet::CsiOnly), self.get(FeatureSet::CsiCir)) else {
|
let (Some(a), Some(b)) = (self.get(FeatureSet::CsiOnly), self.get(FeatureSet::CsiCir))
|
||||||
|
else {
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
let wins = [
|
let wins = [
|
||||||
|
|
@ -249,6 +274,30 @@ mod tests {
|
||||||
assert_eq!(latency_percentiles_ms(&[]), (0.0, 0.0));
|
assert_eq!(latency_percentiles_ms(&[]), (0.0, 0.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ADR-155 §Tier-2: a NaN in the latency samples must NOT panic the sort
|
||||||
|
// (the old `partial_cmp().unwrap()` did) and must yield a sane percentile
|
||||||
|
// computed over the finite values only.
|
||||||
|
#[test]
|
||||||
|
fn latency_percentiles_with_nan_does_not_panic() {
|
||||||
|
let s = vec![
|
||||||
|
10.0,
|
||||||
|
f64::NAN,
|
||||||
|
20.0,
|
||||||
|
30.0,
|
||||||
|
f64::INFINITY,
|
||||||
|
40.0,
|
||||||
|
f64::NEG_INFINITY,
|
||||||
|
50.0,
|
||||||
|
];
|
||||||
|
let (p50, p95) = latency_percentiles_ms(&s);
|
||||||
|
// Finite set is [10,20,30,40,50]; nearest-rank p50=30, p95=50.
|
||||||
|
assert!(p50.is_finite() && p95.is_finite());
|
||||||
|
assert!((p50 - 30.0).abs() < 1e-9);
|
||||||
|
assert!((p95 - 50.0).abs() < 1e-9);
|
||||||
|
// All-NaN input degrades gracefully to (0, 0).
|
||||||
|
assert_eq!(latency_percentiles_ms(&[f64::NAN, f64::NAN]), (0.0, 0.0));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn confusion_rates_basic() {
|
fn confusion_rates_basic() {
|
||||||
let (fp_rate, fn_rate) = confusion_rates(80, 10, 90, 20);
|
let (fp_rate, fn_rate) = confusion_rates(80, 10, 90, 20);
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,43 @@ use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
use crate::error::ConfigError;
|
use crate::error::ConfigError;
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Allocation-guard upper bounds (ADR-155 §Tier-2)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
//
|
||||||
|
// `validate()` historically only checked lower bounds, so a config with an
|
||||||
|
// absurd field (e.g. `window_frames = usize::MAX`) passed validation and only
|
||||||
|
// blew up later as an OOM / allocation-size overflow deep in the pipeline.
|
||||||
|
// These constants cap each dimensioning field individually at a value far
|
||||||
|
// above any real hardware configuration. The caps are per-field: the product
|
||||||
|
// of several fields at their caps can still exceed `usize::MAX` on a 64-bit
|
||||||
|
// target (e.g. MAX_BATCH_SIZE × MAX_KEYPOINTS × MAX_HEATMAP_SIZE² = 10^20),
|
||||||
|
// so they reject absurd single values, not every oversized combination. They
|
||||||
|
// are not a check for "sensible" configs — every real preset stays orders of
|
||||||
|
// magnitude under these caps.
|
||||||
|
|
||||||
|
/// Maximum temporal window length, in frames. Caps the time dimension of every
|
||||||
|
/// CSI window allocation. Real captures use ≤ a few thousand frames.
|
||||||
|
pub const MAX_WINDOW_FRAMES: usize = 100_000;
|
||||||
|
|
||||||
|
/// Maximum subcarrier count (model or native). Real Wi-Fi captures top out in
|
||||||
|
/// the low hundreds; this leaves vast headroom while preventing overflow.
|
||||||
|
pub const MAX_SUBCARRIERS: usize = 100_000;
|
||||||
|
|
||||||
|
/// Maximum backbone feature-map channel count. Even large vision backbones use
|
||||||
|
/// a few thousand channels.
|
||||||
|
pub const MAX_BACKBONE_CHANNELS: usize = 1_000_000;
|
||||||
|
|
||||||
|
/// Maximum heatmap side length (H = W). Caps the square heatmap allocation.
|
||||||
|
pub const MAX_HEATMAP_SIZE: usize = 100_000;
|
||||||
|
|
||||||
|
/// Maximum number of keypoints. COCO uses 17; this is a wide safety margin.
|
||||||
|
pub const MAX_KEYPOINTS: usize = 10_000;
|
||||||
|
|
||||||
|
/// Maximum number of DensePose body-part classes. DensePose uses 24.
|
||||||
|
pub const MAX_BODY_PARTS: usize = 10_000;
|
||||||
|
|
||||||
|
/// Maximum mini-batch size. Guards the batch dimension of every allocation.
|
||||||
|
pub const MAX_BATCH_SIZE: usize = 1_000_000;
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// TrainingConfig
|
// TrainingConfig
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
@ -317,17 +354,36 @@ impl TrainingConfig {
|
||||||
/// increasing.
|
/// increasing.
|
||||||
/// - `save_top_k` must be at least 1.
|
/// - `save_top_k` must be at least 1.
|
||||||
/// - `val_every_epochs` must be at least 1.
|
/// - `val_every_epochs` must be at least 1.
|
||||||
|
/// - Dimensioning fields (`window_frames`, subcarrier counts,
|
||||||
|
/// `backbone_channels`, `heatmap_size`, `num_keypoints`,
|
||||||
|
/// `num_body_parts`, `batch_size`) must not exceed their
|
||||||
|
/// allocation-guard upper bounds (see `MAX_*` constants), so an absurd
|
||||||
|
/// value is rejected here rather than causing an OOM / allocation
|
||||||
|
/// overflow later in the pipeline.
|
||||||
|
/// - `gpu_device_id` must be non-negative.
|
||||||
pub fn validate(&self) -> Result<(), ConfigError> {
|
pub fn validate(&self) -> Result<(), ConfigError> {
|
||||||
// Subcarrier counts
|
// Subcarrier counts
|
||||||
if self.num_subcarriers == 0 {
|
if self.num_subcarriers == 0 {
|
||||||
return Err(ConfigError::invalid_value("num_subcarriers", "must be > 0"));
|
return Err(ConfigError::invalid_value("num_subcarriers", "must be > 0"));
|
||||||
}
|
}
|
||||||
|
if self.num_subcarriers > MAX_SUBCARRIERS {
|
||||||
|
return Err(ConfigError::invalid_value(
|
||||||
|
"num_subcarriers",
|
||||||
|
format!("must be <= {MAX_SUBCARRIERS} (allocation guard)"),
|
||||||
|
));
|
||||||
|
}
|
||||||
if self.native_subcarriers == 0 {
|
if self.native_subcarriers == 0 {
|
||||||
return Err(ConfigError::invalid_value(
|
return Err(ConfigError::invalid_value(
|
||||||
"native_subcarriers",
|
"native_subcarriers",
|
||||||
"must be > 0",
|
"must be > 0",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
if self.native_subcarriers > MAX_SUBCARRIERS {
|
||||||
|
return Err(ConfigError::invalid_value(
|
||||||
|
"native_subcarriers",
|
||||||
|
format!("must be <= {MAX_SUBCARRIERS} (allocation guard)"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
// Antenna counts
|
// Antenna counts
|
||||||
if self.num_antennas_tx == 0 {
|
if self.num_antennas_tx == 0 {
|
||||||
|
|
@ -341,30 +397,66 @@ impl TrainingConfig {
|
||||||
if self.window_frames == 0 {
|
if self.window_frames == 0 {
|
||||||
return Err(ConfigError::invalid_value("window_frames", "must be > 0"));
|
return Err(ConfigError::invalid_value("window_frames", "must be > 0"));
|
||||||
}
|
}
|
||||||
|
if self.window_frames > MAX_WINDOW_FRAMES {
|
||||||
|
return Err(ConfigError::invalid_value(
|
||||||
|
"window_frames",
|
||||||
|
format!("must be <= {MAX_WINDOW_FRAMES} (allocation guard)"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
// Heatmap
|
// Heatmap
|
||||||
if self.heatmap_size == 0 {
|
if self.heatmap_size == 0 {
|
||||||
return Err(ConfigError::invalid_value("heatmap_size", "must be > 0"));
|
return Err(ConfigError::invalid_value("heatmap_size", "must be > 0"));
|
||||||
}
|
}
|
||||||
|
if self.heatmap_size > MAX_HEATMAP_SIZE {
|
||||||
|
return Err(ConfigError::invalid_value(
|
||||||
|
"heatmap_size",
|
||||||
|
format!("must be <= {MAX_HEATMAP_SIZE} (allocation guard)"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
// Model dims
|
// Model dims
|
||||||
if self.num_keypoints == 0 {
|
if self.num_keypoints == 0 {
|
||||||
return Err(ConfigError::invalid_value("num_keypoints", "must be > 0"));
|
return Err(ConfigError::invalid_value("num_keypoints", "must be > 0"));
|
||||||
}
|
}
|
||||||
|
if self.num_keypoints > MAX_KEYPOINTS {
|
||||||
|
return Err(ConfigError::invalid_value(
|
||||||
|
"num_keypoints",
|
||||||
|
format!("must be <= {MAX_KEYPOINTS} (allocation guard)"),
|
||||||
|
));
|
||||||
|
}
|
||||||
if self.num_body_parts == 0 {
|
if self.num_body_parts == 0 {
|
||||||
return Err(ConfigError::invalid_value("num_body_parts", "must be > 0"));
|
return Err(ConfigError::invalid_value("num_body_parts", "must be > 0"));
|
||||||
}
|
}
|
||||||
|
if self.num_body_parts > MAX_BODY_PARTS {
|
||||||
|
return Err(ConfigError::invalid_value(
|
||||||
|
"num_body_parts",
|
||||||
|
format!("must be <= {MAX_BODY_PARTS} (allocation guard)"),
|
||||||
|
));
|
||||||
|
}
|
||||||
if self.backbone_channels == 0 {
|
if self.backbone_channels == 0 {
|
||||||
return Err(ConfigError::invalid_value(
|
return Err(ConfigError::invalid_value(
|
||||||
"backbone_channels",
|
"backbone_channels",
|
||||||
"must be > 0",
|
"must be > 0",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
if self.backbone_channels > MAX_BACKBONE_CHANNELS {
|
||||||
|
return Err(ConfigError::invalid_value(
|
||||||
|
"backbone_channels",
|
||||||
|
format!("must be <= {MAX_BACKBONE_CHANNELS} (allocation guard)"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
// Optimisation
|
// Optimisation
|
||||||
if self.batch_size == 0 {
|
if self.batch_size == 0 {
|
||||||
return Err(ConfigError::invalid_value("batch_size", "must be > 0"));
|
return Err(ConfigError::invalid_value("batch_size", "must be > 0"));
|
||||||
}
|
}
|
||||||
|
if self.batch_size > MAX_BATCH_SIZE {
|
||||||
|
return Err(ConfigError::invalid_value(
|
||||||
|
"batch_size",
|
||||||
|
format!("must be <= {MAX_BATCH_SIZE} (allocation guard)"),
|
||||||
|
));
|
||||||
|
}
|
||||||
if self.learning_rate <= 0.0 {
|
if self.learning_rate <= 0.0 {
|
||||||
return Err(ConfigError::invalid_value("learning_rate", "must be > 0.0"));
|
return Err(ConfigError::invalid_value("learning_rate", "must be > 0.0"));
|
||||||
}
|
}
|
||||||
|
|
@ -443,6 +535,11 @@ impl TrainingConfig {
|
||||||
return Err(ConfigError::invalid_value("save_top_k", "must be > 0"));
|
return Err(ConfigError::invalid_value("save_top_k", "must be > 0"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device: a CUDA device index can never be negative.
|
||||||
|
if self.gpu_device_id < 0 {
|
||||||
|
return Err(ConfigError::invalid_value("gpu_device_id", "must be >= 0"));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -555,6 +652,96 @@ mod tests {
|
||||||
assert!(!cfg2.needs_subcarrier_interp());
|
assert!(!cfg2.needs_subcarrier_interp());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ADR-155 §Tier-2: every preset constructor must still validate after the
|
||||||
|
// upper-bound (allocation-guard) checks were added.
|
||||||
|
#[test]
|
||||||
|
fn presets_still_validate() {
|
||||||
|
TrainingConfig::default().validate().expect("default");
|
||||||
|
TrainingConfig::mmfi().validate().expect("mmfi");
|
||||||
|
TrainingConfig::ht40_192().validate().expect("ht40_192");
|
||||||
|
TrainingConfig::multiband_168()
|
||||||
|
.validate()
|
||||||
|
.expect("multiband_168");
|
||||||
|
TrainingConfig::for_subcarriers(168, 56)
|
||||||
|
.validate()
|
||||||
|
.expect("for_subcarriers");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ADR-155 §Tier-2: oversized dimensioning fields (config-OOM class) must be
|
||||||
|
// rejected, not passed through to an allocation that overflows / OOMs.
|
||||||
|
#[test]
|
||||||
|
fn oversized_window_frames_is_invalid() {
|
||||||
|
let cfg = TrainingConfig {
|
||||||
|
window_frames: MAX_WINDOW_FRAMES + 1,
|
||||||
|
..TrainingConfig::default()
|
||||||
|
};
|
||||||
|
assert!(cfg.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oversized_subcarriers_are_invalid() {
|
||||||
|
let cfg = TrainingConfig {
|
||||||
|
num_subcarriers: MAX_SUBCARRIERS + 1,
|
||||||
|
..TrainingConfig::default()
|
||||||
|
};
|
||||||
|
assert!(cfg.validate().is_err());
|
||||||
|
let cfg = TrainingConfig {
|
||||||
|
native_subcarriers: MAX_SUBCARRIERS + 1,
|
||||||
|
..TrainingConfig::default()
|
||||||
|
};
|
||||||
|
assert!(cfg.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oversized_backbone_channels_is_invalid() {
|
||||||
|
let cfg = TrainingConfig {
|
||||||
|
backbone_channels: MAX_BACKBONE_CHANNELS + 1,
|
||||||
|
..TrainingConfig::default()
|
||||||
|
};
|
||||||
|
assert!(cfg.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oversized_heatmap_size_is_invalid() {
|
||||||
|
let cfg = TrainingConfig {
|
||||||
|
heatmap_size: MAX_HEATMAP_SIZE + 1,
|
||||||
|
..TrainingConfig::default()
|
||||||
|
};
|
||||||
|
assert!(cfg.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oversized_keypoints_and_body_parts_are_invalid() {
|
||||||
|
let cfg = TrainingConfig {
|
||||||
|
num_keypoints: MAX_KEYPOINTS + 1,
|
||||||
|
..TrainingConfig::default()
|
||||||
|
};
|
||||||
|
assert!(cfg.validate().is_err());
|
||||||
|
let cfg = TrainingConfig {
|
||||||
|
num_body_parts: MAX_BODY_PARTS + 1,
|
||||||
|
..TrainingConfig::default()
|
||||||
|
};
|
||||||
|
assert!(cfg.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oversized_batch_size_is_invalid() {
|
||||||
|
let cfg = TrainingConfig {
|
||||||
|
batch_size: MAX_BATCH_SIZE + 1,
|
||||||
|
..TrainingConfig::default()
|
||||||
|
};
|
||||||
|
assert!(cfg.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn negative_gpu_device_id_is_invalid() {
|
||||||
|
let cfg = TrainingConfig {
|
||||||
|
gpu_device_id: -1,
|
||||||
|
..TrainingConfig::default()
|
||||||
|
};
|
||||||
|
assert!(cfg.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn config_fields_have_expected_defaults() {
|
fn config_fields_have_expected_defaults() {
|
||||||
let cfg = TrainingConfig::default();
|
let cfg = TrainingConfig::default();
|
||||||
|
|
|
||||||
|
|
@ -177,8 +177,13 @@ pub fn evaluate_joint_error(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OKS for this frame.
|
// OKS for this frame. ADR-155 §Tier-1.1/§Tier-2: never fall back to
|
||||||
let s = scale.get(i).copied().unwrap_or(1.0);
|
// s=1.0 on normalized [0,1] coordinates — that makes every distance ≈0
|
||||||
|
// and OKS ≈1.0 for any pose (the "fake Gold tier" bug). When no valid
|
||||||
|
// per-frame scale is supplied we derive it from the GT pose extent
|
||||||
|
// (`safe_diag`), exactly as the canonical OKS does.
|
||||||
|
let supplied = scale.get(i).copied().unwrap_or(0.0);
|
||||||
|
let s = if supplied > 0.0 { supplied } else { safe_diag };
|
||||||
let oks_frame = compute_single_oks(&pred_kpts[i], &gt_kpts[i], &visibility[i], s);
|
let oks_frame = compute_single_oks(&pred_kpts[i], &gt_kpts[i], &visibility[i], s);
|
||||||
oks_sum += oks_frame as f64;
|
oks_sum += oks_frame as f64;
|
||||||
}
|
}
|
||||||
|
|
@ -627,10 +632,18 @@ fn compute_bbox_diag(kp: &Array2<f32>, vis: &Array1<f32>) -> f32 {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compute_single_oks(pred: &Array2<f32>, gt: &Array2<f32>, vis: &Array1<f32>, s: f32) -> f32 {
|
fn compute_single_oks(pred: &Array2<f32>, gt: &Array2<f32>, vis: &Array1<f32>, s: f32) -> f32 {
|
||||||
|
// ADR-155 §Tier-2: a non-positive scale would divide by ≈0 (Inf/NaN OKS) —
|
||||||
|
// and on normalized coords s=1.0 was the fake-perfect bug. Reject it.
|
||||||
|
if !(s > 0.0) {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
let s_sq = s * s;
|
let s_sq = s * s;
|
||||||
|
// ADR-155 §Tier-2: bound the loop to the actual array extents so adversarial
|
||||||
|
// / short inputs (< 17 rows, mismatched vis length) cannot panic on `[j]`.
|
||||||
|
let n = pred.shape()[0].min(gt.shape()[0]).min(vis.len()).min(17);
|
||||||
let mut num = 0.0_f32;
|
let mut num = 0.0_f32;
|
||||||
let mut den = 0.0_f32;
|
let mut den = 0.0_f32;
|
||||||
for j in 0..17 {
|
for j in 0..n {
|
||||||
if vis[j] < 0.5 {
|
if vis[j] < 0.5 {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
@ -746,6 +759,59 @@ mod tests {
|
||||||
(pred, gt, vis)
|
(pred, gt, vis)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oks_rejects_nonpositive_scale() {
|
||||||
|
// ADR-155 §Tier-2: s<=0 must return 0.0, never Inf/NaN.
|
||||||
|
let (pred, gt, vis) = make_perfect_kpts();
|
||||||
|
assert_eq!(compute_single_oks(&pred, &gt, &vis, 0.0), 0.0);
|
||||||
|
assert_eq!(compute_single_oks(&pred, &gt, &vis, -1.0), 0.0);
|
||||||
|
assert!(compute_single_oks(&pred, &gt, &vis, 0.5).is_finite());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oks_does_not_panic_on_short_arrays() {
|
||||||
|
// ADR-155 §Tier-2: fewer than 17 rows / mismatched vis must not panic.
|
||||||
|
let pred = Array2::<f32>::zeros((5, 2));
|
||||||
|
let gt = Array2::<f32>::zeros((5, 2));
|
||||||
|
let vis = Array1::<f32>::ones(5);
|
||||||
|
let oks = compute_single_oks(&pred, &gt, &vis, 0.5);
|
||||||
|
assert!(oks.is_finite());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oks_not_perfect_for_wrong_pose_with_derived_scale() {
|
||||||
|
// ADR-155 §Tier-1.1/§Tier-2: a clearly wrong pose on normalized coords,
|
||||||
|
// evaluated with no supplied scale (derived from GT extent), must NOT
|
||||||
|
// look near-perfect — the old s=1.0 fallback would have returned ≈1.0.
|
||||||
|
let gt = Array2::from_shape_fn(
|
||||||
|
(17, 2),
|
||||||
|
|(j, d)| {
|
||||||
|
if d == 0 {
|
||||||
|
0.4 + j as f32 * 0.01
|
||||||
|
} else {
|
||||||
|
0.5
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
let mut pred = gt.clone();
|
||||||
|
for j in 0..17 {
|
||||||
|
pred[[j, 1]] += 0.3; // shift every joint far in y
|
||||||
|
}
|
||||||
|
let vis = Array1::<f32>::ones(17);
|
||||||
|
let result = evaluate_joint_error(
|
||||||
|
&[pred],
|
||||||
|
&[gt],
|
||||||
|
&[vis],
|
||||||
|
&[], // no supplied scale ⇒ derive from GT extent
|
||||||
|
&JointErrorThresholds::default(),
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
result.oks < 0.5,
|
||||||
|
"wrong pose must not yield near-perfect OKS, got {}",
|
||||||
|
result.oks
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn joint_error_perfect_predictions_pass() {
|
fn joint_error_perfect_predictions_pass() {
|
||||||
let (pred, gt, vis) = make_perfect_kpts();
|
let (pred, gt, vis) = make_perfect_kpts();
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,11 @@ use ruvector_solver::types::CsrMatrix;
|
||||||
/// # Panics
|
/// # Panics
|
||||||
///
|
///
|
||||||
/// Panics if `target_sc == 0` or the input has no subcarrier dimension.
|
/// Panics if `target_sc == 0` or the input has no subcarrier dimension.
|
||||||
|
///
|
||||||
|
/// Non-contiguous inputs (e.g. a transposed or strided view) are handled
|
||||||
|
/// gracefully: the subcarrier lane is copied into a contiguous scratch buffer
|
||||||
|
/// when the underlying storage is not contiguous, so this function never
|
||||||
|
/// panics on layout (ADR-155 §Tier-2).
|
||||||
pub fn interpolate_subcarriers(arr: &Array4<f32>, target_sc: usize) -> Array4<f32> {
|
pub fn interpolate_subcarriers(arr: &Array4<f32>, target_sc: usize) -> Array4<f32> {
|
||||||
assert!(target_sc > 0, "target_sc must be > 0");
|
assert!(target_sc > 0, "target_sc must be > 0");
|
||||||
|
|
||||||
|
|
@ -54,16 +59,23 @@ pub fn interpolate_subcarriers(arr: &Array4<f32>, target_sc: usize) -> Array4<f3
|
||||||
// Precompute interpolation weights once.
|
// Precompute interpolation weights once.
|
||||||
let weights = compute_interp_weights(n_sc, target_sc);
|
let weights = compute_interp_weights(n_sc, target_sc);
|
||||||
|
|
||||||
|
// Reusable scratch buffer for the non-contiguous fallback path.
|
||||||
|
let mut scratch: Vec<f32> = Vec::new();
|
||||||
|
|
||||||
for t in 0..n_t {
|
for t in 0..n_t {
|
||||||
for tx in 0..n_tx {
|
for tx in 0..n_tx {
|
||||||
for rx in 0..n_rx {
|
for rx in 0..n_rx {
|
||||||
let src = arr.slice(s![t, tx, rx, ..]);
|
let src = arr.slice(s![t, tx, rx, ..]);
|
||||||
let src_slice = src.as_slice().unwrap_or_else(|| {
|
// Prefer the contiguous fast path; fall back to an owned copy
|
||||||
// Fallback: copy to a contiguous slice
|
// for non-contiguous layouts instead of panicking.
|
||||||
// (this path is hit when the array has a non-contiguous layout)
|
let src_slice: &[f32] = match src.as_slice() {
|
||||||
// In practice ndarray arrays sliced along last dim are contiguous.
|
Some(s) => s,
|
||||||
panic!("Subcarrier slice is not contiguous");
|
None => {
|
||||||
});
|
scratch.clear();
|
||||||
|
scratch.extend(src.iter().copied());
|
||||||
|
&scratch
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
for (k, &(i0, i1, w)) in weights.iter().enumerate() {
|
for (k, &(i0, i1, w)) in weights.iter().enumerate() {
|
||||||
let v = src_slice[i0] * (1.0 - w) + src_slice[i1] * w;
|
let v = src_slice[i0] * (1.0 - w) + src_slice[i1] * w;
|
||||||
|
|
@ -420,6 +432,35 @@ mod tests {
|
||||||
assert_eq!(out.shape(), &[4, 1, 3, 56]);
|
assert_eq!(out.shape(), &[4, 1, 3, 56]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ADR-155 §Tier-2: a non-contiguous input (subcarrier axis strided after an
|
||||||
|
// axis permutation) must NOT panic — the old `.as_slice().unwrap_or_else(||
|
||||||
|
// panic!(...))` path crashed on any non-contiguous layout.
|
||||||
|
#[test]
|
||||||
|
fn non_contiguous_input_does_not_panic() {
|
||||||
|
// Build a [t, sc, tx, rx] array, then permute so subcarriers land in the
|
||||||
|
// last axis. The resulting owned Array4 has non-standard strides, so its
|
||||||
|
// last-axis lanes are non-contiguous in memory.
|
||||||
|
let base =
|
||||||
|
Array4::<f32>::from_shape_fn((4, 8, 3, 3), |(t, sc, tx, rx)| (t + sc + tx + rx) as f32);
|
||||||
|
// permuted_axes consumes the owned array and returns an owned Array4
|
||||||
|
// with swapped strides: logical shape [t, tx, rx, sc], sc axis strided.
|
||||||
|
let strided: Array4<f32> = base.permuted_axes([0, 2, 3, 1]);
|
||||||
|
// Sanity: a last-axis lane really is non-contiguous.
|
||||||
|
assert!(strided.slice(s![0, 0, 0, ..]).as_slice().is_none());
|
||||||
|
|
||||||
|
let out = interpolate_subcarriers(&strided, 4);
|
||||||
|
assert_eq!(out.shape(), &[4, 3, 3, 4]);
|
||||||
|
// Endpoints preserved exactly even via the fallback copy path.
|
||||||
|
for tx in 0..3 {
|
||||||
|
for rx in 0..3 {
|
||||||
|
let first = strided[[0, tx, rx, 0]];
|
||||||
|
let last = strided[[0, tx, rx, 7]];
|
||||||
|
assert_abs_diff_eq!(out[[0, tx, rx, 0]], first, epsilon = 1e-5);
|
||||||
|
assert_abs_diff_eq!(out[[0, tx, rx, 3]], last, epsilon = 1e-5);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn sparse_interpolation_identity() {
|
fn sparse_interpolation_identity() {
|
||||||
// For same source and target count, should return same array
|
// For same source and target count, should return same array
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue