140 lines
4.7 KiB
Rust
140 lines
4.7 KiB
Rust
//! Input-validation boundary tests for `OccWorldCandle::predict`.
//!
//! Security review (Milestone #9, crate 4/4). `predict` takes an
//! externally-supplied occupancy tensor; per the project's "validate input at
//! system boundaries" rule it must reject degenerate / out-of-capacity shapes
//! with a clear domain error rather than surfacing a cryptic deep-pipeline
//! Candle error (over-capacity frame counts over-index the temporal positional
//! embedding) or processing a zero-element tensor.
//!
//! These exercise only the public API and live here (not inline in
//! `inference.rs`) to keep that module under the 500-line cap.
|
|
|
|
use candle_core::{DType, Device, Tensor};
|
|
use wifi_densepose_occworld_candle::config::OccWorldConfig;
|
|
use wifi_densepose_occworld_candle::inference::OccWorldCandle;
|
|
use wifi_densepose_occworld_candle::error::OccWorldError;
|
|
|
|
fn small_cfg() -> OccWorldConfig {
|
|
OccWorldConfig {
|
|
grid_h: 8,
|
|
grid_w: 8,
|
|
grid_d: 4,
|
|
num_classes: 4,
|
|
free_class: 3,
|
|
base_channels: 8,
|
|
z_channels: 8,
|
|
codebook_size: 4,
|
|
embed_dim: 8,
|
|
num_frames: 2,
|
|
token_h: 4,
|
|
token_w: 4,
|
|
num_heads: 2,
|
|
num_layers: 1,
|
|
ffn_hidden: 16,
|
|
}
|
|
}
|
|
|
|
/// An input whose frame axis has length zero contains no elements and would
/// otherwise flow into the reshape/conv pipeline as a zero-element tensor.
/// `predict` must refuse it at the boundary with `ShapeMismatch`.
#[test]
fn predict_rejects_zero_frames() {
    let device = Device::Cpu;
    let cfg = small_cfg();
    let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();

    // (batch, frames, H, W, D) with the frame axis left empty.
    let empty_frames_shape = (1usize, 0usize, cfg.grid_h, cfg.grid_w, cfg.grid_d);
    let past = Tensor::zeros(empty_frames_shape, DType::U8, &device).unwrap();

    let rejected = matches!(
        engine.predict(&past),
        Err(OccWorldError::ShapeMismatch(_))
    );
    assert!(
        rejected,
        "zero-frame input must be rejected with ShapeMismatch"
    );
}
|
|
|
|
/// An empty batch axis carries the same zero-element-tensor hazard as an
/// empty frame axis, so it must be rejected with `ShapeMismatch` as well.
#[test]
fn predict_rejects_zero_batch() {
    let device = Device::Cpu;
    let cfg = small_cfg();
    let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();

    // (batch, frames, H, W, D) with the batch axis left empty.
    let empty_batch_shape = (
        0usize,
        cfg.num_frames,
        cfg.grid_h,
        cfg.grid_w,
        cfg.grid_d,
    );
    let past = Tensor::zeros(empty_batch_shape, DType::U8, &device).unwrap();

    let rejected = matches!(
        engine.predict(&past),
        Err(OccWorldError::ShapeMismatch(_))
    );
    assert!(
        rejected,
        "zero-batch input must be rejected with ShapeMismatch"
    );
}
|
|
|
|
/// A frame count above what the temporal embedding can index
/// (`> num_frames*2`) must be rejected.
///
/// Previously this over-indexed the temporal positional embedding deep inside
/// the transformer and came back as a cryptic Candle "gather" `InvalidIndex`
/// error. The boundary guard now turns it into a clean `ShapeMismatch`.
#[test]
fn predict_rejects_too_many_frames() {
    let device = Device::Cpu;
    let cfg = small_cfg();
    let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();

    // num_frames = 2 → temporal capacity = 4; one frame beyond that is too many.
    let capacity = cfg.num_frames * 2;
    let too_many = capacity + 1;

    let over_capacity_shape = (1usize, too_many, cfg.grid_h, cfg.grid_w, cfg.grid_d);
    let past = Tensor::zeros(over_capacity_shape, DType::U8, &device).unwrap();

    let rejected = matches!(
        engine.predict(&past),
        Err(OccWorldError::ShapeMismatch(_))
    );
    assert!(
        rejected,
        "over-capacity frame count must be rejected with ShapeMismatch"
    );
}
|
|
|
|
/// The boundary value itself (`num_frames*2` frames) must still be accepted:
/// the guard is meant to reject only counts *above* capacity.
#[test]
fn predict_accepts_frame_count_at_capacity() {
    let device = Device::Cpu;
    let cfg = small_cfg();
    let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();

    let at_cap = cfg.num_frames * 2;
    let at_capacity_shape = (1usize, at_cap, cfg.grid_h, cfg.grid_w, cfg.grid_d);
    let past = Tensor::zeros(at_capacity_shape, DType::U8, &device).unwrap();

    let prediction = engine
        .predict(&past)
        .expect("at-capacity frame count must predict");

    // Axis 1 of the semantic prediction should carry the input frame count.
    let predicted_frames = prediction.sem_pred.dims()[1];
    assert_eq!(predicted_frames, at_cap, "frame dim preserved");
}
|
|
|
|
/// Spatial geometry (H/W/D) that does not match the config must still be
/// rejected; this pins the pre-existing guard next to the newer frame/batch
/// checks.
#[test]
fn predict_rejects_wrong_grid_dims() {
    let device = Device::Cpu;
    let cfg = small_cfg();
    let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();

    // Valid batch and frame counts, but the H axis is one larger than configured.
    let wrong_height = cfg.grid_h + 1;
    let bad_grid_shape = (
        1usize,
        cfg.num_frames,
        wrong_height,
        cfg.grid_w,
        cfg.grid_d,
    );
    let past = Tensor::zeros(bad_grid_shape, DType::U8, &device).unwrap();

    let rejected = matches!(
        engine.predict(&past),
        Err(OccWorldError::ShapeMismatch(_))
    );
    assert!(
        rejected,
        "wrong grid dims must be rejected with ShapeMismatch"
    );
}
|