security(occworld-candle): int32-checkpoint crash + degenerate-input guards + ADR-179 (closes Milestone #9) (#1101)
* fix(occworld-candle): security review fixes: int32 checkpoint crash + predict input validation

  Beyond-SOTA security + correctness review of wifi-densepose-occworld-candle (Milestone #9, crate 4/4, the last ungated crate).

  Findings fixed:

  1. HIGH (MEASURED): checkpoint-load crash on any int32 tensor. model.rs mapped safetensors I32 -> candle DType::I64 and passed the raw int32 byte buffer (4 bytes/elem) to Tensor::from_raw_buffer(.., I64, ..). Candle derives elem_count = data.len() / dtype.size(), so the I64 path halved the count while keeping the original shape -> a tensor whose shape claims 2x its storage. Reading it PANICS (slice OOB: "range end index 6 out of range for slice of length 3") on any checkpoint containing an int32 tensor. Fixed: I32 -> DType::I32, I16 -> DType::I16 (both first-class candle dtypes). Reproduced on old code; pinned in tests/checkpoint_loading.rs.
  2. LOW (MEASURED): predict() lacked frame/batch validation at the input boundary. f_in > num_frames*2 over-indexed the temporal embedding (cryptic candle "gather" error); zero frame/batch fed a zero-element tensor in. Now rejected with a clear ShapeMismatch. Pinned in tests/input_validation.rs.
  3. LOW (MEASURED): divide-by-zero panic in the public VQCodebook::encode on a rank-0 / empty-last-dim tensor (last == 0). Now fails closed with a clear error. Pinned in vqvae.rs unit tests.

  Dimensions confirmed clean with evidence: panic surface (no unwrap/expect/panic in prod paths), NaN-state-poisoning (N/A: stateless engine, u8 input), unbounded-alloc/shape-data mismatch (defended upstream by safetensors::validate), secrets (none). unsafe_code = forbid.

  Validation (MEASURED, Windows): crate 31/31 pass; workspace 0 failed (lone desktop api_integration "Access is denied" file-lock flake passes 21/21 in isolation); Python proof VERDICT PASS, hash f8e76f21…446f7a unchanged.

  Warrants ADR slot 179 (parent to author).

  Co-Authored-By: claude-flow <ruv@ruv.net>

* docs(adr): ADR-179 - occworld-candle checkpoint-load hardening (closes Milestone #9)

  Records the HIGH int32-checkpoint crash fix (I32→I64 dtype-widening → slice-OOB panic on load = DoS) + 2 LOW degenerate-input fixes from 5e77f47e5. Stateless engine (NaN-poisoning N/A), unsafe forbidden, safetensors validate() defends malloc upstream. occworld 31/31. Final ungated crate; Milestone #9 complete.

  Co-Authored-By: claude-flow <ruv@ruv.net>
parent
10c813fde3
commit
c859f6f743
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,81 @@
# ADR-179: `wifi-densepose-occworld-candle` Checkpoint-Load Hardening

| Field | Value |
|-------|-------|
| **Status** | Accepted: 1 HIGH + 2 LOW bugs fixed + pinned (MEASURED on Windows) |
| **Date** | 2026-06-15 |
| **Deciders** | ruv |
| **Codename** | **OCCWORLD-DTYPE** |
| **Reviews** | `wifi-densepose-occworld-candle` (Candle occupancy-world model) |
| **Milestone** | #9 (ungated-crate security sweep), crate 4 of 4, **CLOSES the milestone** |

## Context

`wifi-densepose-occworld-candle` is a Candle-based occupancy-world model
(VQ-VAE + transformer over occupancy tokens). The real risk surface for an ML
crate is degenerate-input / malformed-weights handling: a crate that forbids
`unsafe_code` can still **panic** (a DoS, and under WASM an abort) when a tensor
op hits an inconsistent shape. The crate **builds and tests on Windows**, so all
findings are MEASURED.

## Decision

Fix the three reachable bugs, each pinned by a fails-on-old test; attest the rest
clean with evidence.

### Findings fixed (all MEASURED)

| # | Severity | Location | Issue | Fix |
|---|----------|----------|-------|-----|
| 1 | **HIGH** | `model.rs:95` (`Dtype::I32 => Some(DType::I64)`) | **Crash on any int32-tensor checkpoint.** An I32 byte buffer (4 B/elem) is handed to `from_raw_buffer(.., I64, shape, ..)`; candle derives `elem_count = data.len()/8`, **halving** the count while keeping the original shape → a tensor that claims 2× its storage. Reading it **panics** with a slice-OOB (`range end index 6 out of range for slice of length 3`) inside candle-core. A checkpoint with any int32 tensor (index/buffer tensors are common in PyTorch exports) → **DoS on load**. | Map `I32 → DType::I32`, `I16 → DType::I16` (both first-class candle dtypes). Pinned by `int32_tensor_loads_with_consistent_shape_and_values` (panics on old, passes on new). |
| 2 | LOW | `inference.rs::predict` | Frame/batch dims weren't validated (only H/W/D were): `f_in > num_frames*2` over-indexes the temporal embedding → a cryptic candle `InvalidIndex` *error* (not a panic; candle bounds-checks); zero frame/batch feeds a zero-element tensor. | Boundary guard rejects zero / over-capacity frame+batch with a clear `ShapeMismatch`. 5 pins. |
| 3 | LOW | `vqvae.rs:141` (`z.elem_count() / last`) | **Divide-by-zero panic** in public `VQCodebook::encode` on a rank-0 / empty-last-dim tensor (`last == 0`). | Fail-closed guard returns a clear error. Pinned by `encode_rejects_scalar_without_panicking`. |

The HIGH finding is the notable one: the crate's own dtype mapping **defeated**
the upstream `safetensors::validate()` byte-length guarantee by misdeclaring the
dtype. This was the one place malformed/widened weights could reach a panicking
candle op.

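The arithmetic behind finding #1 can be reproduced without candle. The following is a minimal, stdlib-only sketch: `elems_from_bytes` and `elems_from_shape` are hypothetical helpers written for illustration, not candle or crate APIs. They only mirror the `data.len() / dtype size` derivation described in the table, using the same 2×3 int32 tensor as the regression test.

```rust
/// Element count a raw-buffer constructor would derive from the byte length.
/// Hypothetical helper for illustration; not a candle API.
fn elems_from_bytes(byte_len: usize, dtype_size: usize) -> usize {
    byte_len / dtype_size
}

/// Element count that the declared shape claims.
fn elems_from_shape(shape: &[usize]) -> usize {
    shape.iter().product()
}

fn main() {
    let shape = [2usize, 3];

    // Six int32 values serialized little-endian: 6 * 4 = 24 bytes.
    let mut data: Vec<u8> = Vec::with_capacity(24);
    for v in 1i32..=6 {
        data.extend_from_slice(&v.to_le_bytes());
    }

    // Old mapping: the buffer is declared as I64 (8 bytes per element).
    let as_i64 = elems_from_bytes(data.len(), 8);
    // Fixed mapping: the buffer is declared as I32 (4 bytes per element).
    let as_i32 = elems_from_bytes(data.len(), 4);

    println!("shape claims {} elements", elems_from_shape(&shape));
    println!("I64 path stores {}, I32 path stores {}", as_i64, as_i32);
}
```

With the old mapping the storage holds 3 elements while the shape claims 6, which matches the `range end index 6 out of range for slice of length 3` panic message quoted above. With the fixed mapping both counts are 6.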

### Dimensions confirmed clean (with evidence)

- **Panic surface**: grep for `unwrap()/expect()/panic!/unreachable!` across `src/`
  → **zero in production paths**; all ops use `?`/`map_err`; the `last().unwrap_or(&0)`
  is now guarded. `as` casts operate only on config-bounded/internal values.
- **NaN-state-poisoning (the named class): N/A.** The engine is **stateless between
  `predict` calls** (no persistent world-model buffer to latch into), and input is
  `u8` class indices (non-finite input structurally impossible). NaN weights flow to
  `argmax` (deterministic, bounded to a valid class index): no panic, no persistence.
- **Unbounded alloc / shape-data mismatch from malformed weights**: defended upstream
  by `safetensors::validate()` (overflow-checked `nelements*dtype.size()` vs declared
  byte range + contiguous-offset + buffer-length checks), rejected before reaching
  candle. Finding #1 was the one place the crate defeated that guarantee.
- **Model/path loading**: `load`/`load_safetensors` check `path.exists()` → typed
  `CheckpointNotFound`; corrupt bytes → `CheckpointParse` (pinned). No path-traversal
  surface (caller-supplied path, opened read-only, never joined with untrusted segments).
- **Secrets**: grep clean (only `token_h`/`token_w` config fields match `token`).
- **Determinism**: the crate's central honesty claim, verified by the pre-existing
  `tests/predict_honesty.rs` (3 tests, still pass).
- `unsafe_code = "forbid"` in the manifest.

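The upstream byte-length check credited in the list above can be approximated as follows. This is a hedged sketch of the idea only (an overflow-checked `nelements * dtype_size` compared against the declared byte range and the buffer length); it is not the `safetensors` crate's actual code, and `TensorInfo` and `check_tensor` are hypothetical names introduced for illustration.

```rust
/// Hypothetical header entry: declared shape, per-element size, byte range.
struct TensorInfo {
    shape: Vec<usize>,
    dtype_size: usize,
    data_offsets: (usize, usize),
}

/// Accepts a tensor only when `shape product * dtype_size` equals the declared
/// byte range and that range fits inside the data buffer. All multiplications
/// are overflow-checked so a huge declared shape cannot wrap around.
fn check_tensor(info: &TensorInfo, buffer_len: usize) -> Result<(), String> {
    let (start, end) = info.data_offsets;
    if start > end || end > buffer_len {
        return Err(format!(
            "byte range {}..{} is outside a buffer of {} bytes",
            start, end, buffer_len
        ));
    }
    let nelements = info
        .shape
        .iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(d))
        .ok_or_else(|| "shape product overflows usize".to_string())?;
    let nbytes = nelements
        .checked_mul(info.dtype_size)
        .ok_or_else(|| "byte size overflows usize".to_string())?;
    if nbytes != end - start {
        return Err(format!(
            "shape needs {} bytes but the declared range holds {}",
            nbytes,
            end - start
        ));
    }
    Ok(())
}

fn main() {
    let ok = TensorInfo { shape: vec![2, 3], dtype_size: 4, data_offsets: (0, 24) };
    let short = TensorInfo { shape: vec![2, 3], dtype_size: 4, data_offsets: (0, 12) };
    let huge = TensorInfo { shape: vec![usize::MAX, 2], dtype_size: 4, data_offsets: (0, 24) };
    println!("{:?}", check_tensor(&ok, 24));
    println!("{:?}", check_tensor(&short, 24));
    println!("{:?}", check_tensor(&huge, 24));
}
```

A check of this kind passes for a well-formed int32 tensor because it uses the 4-byte element size from the file. The guarantee only holds downstream if the consumer declares the same element size, which is what the old `I32 → I64` mapping failed to do.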

## Validation

- `cargo test -p wifi-densepose-occworld-candle --no-default-features` → **31/31**
  (lib 17, checkpoint_loading 4, input_validation 5, predict_honesty 3, doctests 2),
  0 failed.
- `cargo test --workspace --no-default-features` → 0 failed across every crate (a lone
  `wifi-densepose-desktop --test api_integration` "Access is denied (os error 5)" was a
  Windows file-lock/AV flake; re-ran isolated 21/21, unrelated).
- `python archive/v1/data/proof/verify.py` → **VERDICT: PASS**, hash `f8e76f21…46f7a`
  unchanged (occworld off the signal proof path).


## Consequences

### Positive

- A checkpoint-load DoS (the int32 dtype-widening panic) and two degenerate-input
  bugs (one cryptic deep-pipeline error, one divide-by-zero panic) are closed in
  the world-model crate, each pinned. **Milestone #9 (all 4 ungated crates) is
  complete.**

### Negative / Neutral

- None. Guards reject only malformed/degenerate inputs.

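To illustrate why the guards do not narrow the set of valid inputs, here is a stdlib-only sketch of the two boundary checks. `check_frames` and `flattened_rows` are hypothetical stand-ins modeled on the `predict` and `VQCodebook::encode` guards; they use plain `String` errors rather than the crate's `OccWorldError` or candle's error type.

```rust
/// Frame/batch guard modeled on the `predict` boundary check.
/// The temporal embedding capacity is assumed to be `num_frames * 2`.
fn check_frames(batch: usize, frames: usize, num_frames: usize) -> Result<(), String> {
    if batch == 0 || frames == 0 {
        return Err(format!(
            "non-zero batch and frame dims required, got batch={}, frames={}",
            batch, frames
        ));
    }
    let capacity = num_frames * 2;
    if frames > capacity {
        return Err(format!("frame count {} exceeds capacity {}", frames, capacity));
    }
    Ok(())
}

/// Row count for flattening to (N, last), modeled on the `encode` guard.
/// A rank-0 shape has no last dim, so `last` falls back to 0 and is rejected
/// before the division can happen.
fn flattened_rows(dims: &[usize]) -> Result<usize, String> {
    let last = *dims.last().unwrap_or(&0);
    if last == 0 {
        return Err(format!("non-zero last dim required, got shape {:?}", dims));
    }
    let elem_count: usize = dims.iter().product();
    Ok(elem_count / last)
}

fn main() {
    // At capacity (num_frames = 2 gives capacity 4) is accepted; one more is not.
    println!("{:?}", check_frames(1, 4, 2));
    println!("{:?}", check_frames(1, 5, 2));
    // A scalar shape is rejected; a (2, 3, 8) tensor flattens to 6 rows of 8.
    println!("{:?}", flattened_rows(&[]));
    println!("{:?}", flattened_rows(&[2, 3, 8]));
}
```

The at-capacity case mirrors `predict_accepts_frame_count_at_capacity`: only zero-sized or over-capacity inputs are turned away.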

## Links

- ADR-176 / ADR-177 / ADR-178: sibling Milestone-#9 reviews (ruview-swarm, nvsim, desktop)
|
|
@ -206,6 +206,27 @@ impl OccWorldCandle {
|
|||
)));
|
||||
}
|
||||
|
||||
// Validate the externally-supplied frame and batch counts at this
|
||||
// system boundary. The temporal positional embedding has only
|
||||
// `num_frames * 2` rows, so a larger `f_in` would over-index the
|
||||
// embedding table deep inside the transformer and surface as a cryptic
|
||||
// "gather" index error; a zero frame/batch count would feed a
|
||||
// zero-element tensor into the reshape/conv pipeline. Reject both here
|
||||
// with a clear, domain-level error instead.
|
||||
if f_in == 0 || b == 0 {
|
||||
return Err(OccWorldError::ShapeMismatch(format!(
|
||||
"past_occupancy must have non-zero batch and frame dims, got \
|
||||
batch={b}, frames={f_in}"
|
||||
)));
|
||||
}
|
||||
if f_in > cfg.num_frames * 2 {
|
||||
return Err(OccWorldError::ShapeMismatch(format!(
|
||||
"past_occupancy frame count {f_in} exceeds the temporal embedding \
|
||||
capacity ({} = num_frames*2)",
|
||||
cfg.num_frames * 2
|
||||
)));
|
||||
}
|
||||
|
||||
// ── Step 1: VQVAE encode each past frame ──────────────────────────
|
||||
// Flatten batch*frames: (B, F, H, W, D) → (B*F, H, W, D)
|
||||
let occ_flat = past_occupancy
|
||||
|
|
@ -455,4 +476,8 @@ mod tests {
|
|||
"expected CheckpointNotFound, got {result:?}"
|
||||
);
|
||||
}
|
||||
|
||||
// The `predict` input-validation boundary guards (zero/over-capacity frame
|
||||
// counts) live in `tests/input_validation.rs` so they exercise only the
|
||||
// public API and keep this file under the 500-line limit.
|
||||
}
|
||||
|
|
|
|||
|
|
@ -92,8 +92,21 @@ fn safetensor_dtype_to_candle(dt: safetensors::Dtype) -> Option<candle_core::DTy
         Dtype::F64 => Some(DType::F64),
         Dtype::F16 => Some(DType::F16),
         Dtype::BF16 => Some(DType::BF16),
-        Dtype::I32 => Some(DType::I64), // widen for Candle compatibility
+        // I32 MUST map to DType::I32, not I64. `Tensor::from_raw_buffer`
+        // derives its element count from `data.len() / dtype.size_in_bytes()`;
+        // handing an int32 byte buffer (4 bytes/elem) to the I64 path
+        // (8 bytes/elem) halves the element count while keeping the original
+        // shape, producing a tensor whose declared shape claims twice as many
+        // elements as its storage holds. That silent shape/storage mismatch
+        // panics (slice OOB) the moment the tensor is read: a crash on any
+        // checkpoint containing an int32 tensor. See
+        // `tests/checkpoint_loading.rs::int32_tensor_loads_with_consistent_shape_and_values`.
+        Dtype::I32 => Some(DType::I32),
         Dtype::I64 => Some(DType::I64),
+        // I16 is also a first-class Candle dtype (2 bytes/elem); map it
+        // directly rather than rejecting it, for the same byte-size-correctness
+        // reason as I32 above.
+        Dtype::I16 => Some(DType::I16),
         Dtype::U8 => Some(DType::U8),
         Dtype::U32 => Some(DType::U32),
         _ => None,
|
|
|
|||
|
|
@ -137,6 +137,17 @@ impl VQCodebook {
|
|||
let orig_shape = z.shape().clone();
|
||||
let orig_dims = orig_shape.dims().to_vec();
|
||||
let last = *orig_shape.dims().last().unwrap_or(&0);
|
||||
// Guard the divide below: a scalar (rank-0) or empty-last-dim tensor
|
||||
// would make `last == 0` and panic on the `elem_count() / last`
|
||||
// division. `encode` is a `pub fn` on a `pub struct`, so this is a
|
||||
// reachable public boundary — fail closed with a clear error instead.
|
||||
if last == 0 {
|
||||
return Err(candle_core::Error::Msg(format!(
|
||||
"VQCodebook::encode expects a tensor with a non-zero last dim of \
|
||||
size embed_dim={}, got shape {orig_dims:?}",
|
||||
self.embed_dim
|
||||
)));
|
||||
}
|
||||
// Flatten to (N, embed_dim)
|
||||
let n = z.elem_count() / last;
|
||||
let z_flat = z.reshape((n, last))?; // (N, D)
|
||||
|
|
@ -339,6 +350,21 @@ mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_rejects_scalar_without_panicking() {
|
||||
// A rank-0 (scalar) tensor has an empty dims list → `last == 0`.
|
||||
// Before the guard this divided by zero and panicked; now it returns
|
||||
// a clean error. `encode` is public, so this is a reachable boundary.
|
||||
let device = Device::Cpu;
|
||||
let codebook = VQCodebook::dummy(4, 8, &device).unwrap();
|
||||
let scalar = Tensor::from_vec(vec![1.0f32], (), &device).unwrap();
|
||||
let result = codebook.encode(&scalar);
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"scalar input must error, not panic; got {result:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fold_unfold_roundtrip() -> candle_core::Result<()> {
|
||||
let device = Device::Cpu;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,185 @@
|
|||
//! Checkpoint-loading robustness tests for `crate::model::load_safetensors`.
|
||||
//!
|
||||
//! Security review (Milestone #9, crate 4/4). These tests pin the behaviour of
|
||||
//! the SafeTensors weight-loading path against malformed / degenerate
|
||||
//! checkpoints — the only externally-controlled file-input surface in the crate.
|
||||
//!
|
||||
//! The headline regression is the **int32 dtype-widening byte-size bug**
|
||||
//! (`security/occworld-candle` finding #1): `model.rs` mapped
|
||||
//! `safetensors::Dtype::I32` → `candle_core::DType::I64` and then handed the
|
||||
//! raw *int32* byte buffer (4 bytes/elem) to `Tensor::from_raw_buffer(.., I64,
|
||||
//! shape, ..)`. Candle's `from_raw_buffer` computes `elem_count =
|
||||
//! data.len() / 8`, producing a tensor whose declared shape claims twice as
|
||||
//! many elements as the backing storage actually holds — a silent
|
||||
//! shape/storage inconsistency on attacker-supplied checkpoints.
|
||||
//!
|
||||
//! `build_safetensors` hand-assembles the binary container
|
||||
//! (`<u64 LE header_len><JSON header><raw data>`) so the test states exactly
|
||||
//! what bytes reach the loader, independent of the `safetensors` writer API.
|
||||
|
||||
use candle_core::Device;
|
||||
use wifi_densepose_occworld_candle::model::load_safetensors;
|
||||
|
||||
/// Hand-build a single-tensor SafeTensors buffer.
|
||||
///
|
||||
/// `dtype` is the safetensors dtype string (e.g. `"I32"`, `"F32"`).
|
||||
/// `shape` is the declared shape. `data` is the raw little-endian tensor bytes
|
||||
/// — the caller is responsible for making `data.len()` consistent with
|
||||
/// `shape × dtype_size` (safetensors itself validates this, so an inconsistent
|
||||
/// pair is rejected before reaching the candle conversion).
|
||||
fn build_safetensors(name: &str, dtype: &str, shape: &[usize], data: &[u8]) -> Vec<u8> {
|
||||
let shape_json: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
|
||||
let header = format!(
|
||||
"{{\"{name}\":{{\"dtype\":\"{dtype}\",\"shape\":[{}],\"data_offsets\":[0,{}]}}}}",
|
||||
shape_json.join(","),
|
||||
data.len()
|
||||
);
|
||||
let header_bytes = header.into_bytes();
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
|
||||
buf.extend_from_slice(&header_bytes);
|
||||
buf.extend_from_slice(data);
|
||||
buf
|
||||
}
|
||||
|
||||
fn write_temp(bytes: &[u8], stem: &str) -> std::path::PathBuf {
|
||||
let mut p = std::env::temp_dir();
|
||||
p.push(format!(
|
||||
"occworld_ckpt_{stem}_{}_{}.safetensors",
|
||||
std::process::id(),
|
||||
// nanosecond-ish disambiguator so parallel tests never collide
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_nanos())
|
||||
.unwrap_or(0)
|
||||
));
|
||||
std::fs::write(&p, bytes).expect("write temp checkpoint");
|
||||
p
|
||||
}
|
||||
|
||||
/// REGRESSION (finding #1): an int32 tensor in a checkpoint must load into a
|
||||
/// tensor whose element count matches its declared shape.
|
||||
///
|
||||
/// On the OLD code (`I32 -> DType::I64`) the 6-element int32 tensor below was
|
||||
/// handed to `from_raw_buffer(.., I64, [2,3], ..)`, which derived
|
||||
/// `elem_count = 24 bytes / 8 = 3` and built a 3-element storage carrying a
|
||||
/// shape claiming 6 elements — reading it panicked with a slice-OOB
|
||||
/// (`range end index 6 out of range for slice of length 3`). On the FIXED code
|
||||
/// (`I32 -> DType::I32`) the tensor round-trips: dtype I32, 6 elements,
|
||||
/// values `[1,2,3,4,5,6]`.
|
||||
#[test]
|
||||
fn int32_tensor_loads_with_consistent_shape_and_values() {
|
||||
let device = Device::Cpu;
|
||||
let shape = [2usize, 3];
|
||||
let vals: [i32; 6] = [1, 2, 3, 4, 5, 6];
|
||||
let mut data = Vec::with_capacity(24);
|
||||
for v in vals {
|
||||
data.extend_from_slice(&v.to_le_bytes());
|
||||
}
|
||||
let bytes = build_safetensors("quantize.embedding.weight", "I32", &shape, &data);
|
||||
let path = write_temp(&bytes, "i32");
|
||||
|
||||
let map = load_safetensors(&path, &device).expect("int32 checkpoint must load");
|
||||
let t = map
|
||||
.get("quantize.embedding.weight")
|
||||
.expect("mapped key present");
|
||||
|
||||
// The declared shape's element count MUST equal the storage's element
|
||||
// count. On the old code these disagreed (6 vs 3).
|
||||
assert_eq!(
|
||||
t.dims(),
|
||||
&[2, 3],
|
||||
"int32 tensor must preserve its declared shape"
|
||||
);
|
||||
assert_eq!(
|
||||
t.elem_count(),
|
||||
6,
|
||||
"element count must match shape — storage/shape consistency"
|
||||
);
|
||||
|
||||
// The dtype must be I32 — the int32 byte buffer is interpreted as int32,
|
||||
// not reinterpreted as half as many int64 lanes.
|
||||
assert_eq!(
|
||||
t.dtype(),
|
||||
candle_core::DType::I32,
|
||||
"int32 checkpoint tensor must load as DType::I32"
|
||||
);
|
||||
|
||||
// And the values must be exactly recovered (no reinterpretation of two
|
||||
// int32 lanes as one int64). This is the strongest proof the dtype is
|
||||
// handled correctly end-to-end.
|
||||
let flat = t.flatten_all().expect("flatten");
|
||||
let got: Vec<i32> = flat.to_vec1::<i32>().expect("to_vec i32");
|
||||
assert_eq!(
|
||||
got,
|
||||
vec![1i32, 2, 3, 4, 5, 6],
|
||||
"int32 values must be recovered exactly"
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_file(&path);
|
||||
}
|
||||
|
||||
/// A well-formed F32 tensor must round-trip unchanged (control case — proves
|
||||
/// the fix does not regress the common float path).
|
||||
#[test]
|
||||
fn f32_tensor_round_trips() {
|
||||
let device = Device::Cpu;
|
||||
let shape = [4usize];
|
||||
let vals: [f32; 4] = [0.5, -1.0, 2.25, 3.0];
|
||||
let mut data = Vec::with_capacity(16);
|
||||
for v in vals {
|
||||
data.extend_from_slice(&v.to_le_bytes());
|
||||
}
|
||||
let bytes = build_safetensors("post_quant_conv.bias", "F32", &shape, &data);
|
||||
let path = write_temp(&bytes, "f32");
|
||||
|
||||
let map = load_safetensors(&path, &device).expect("f32 checkpoint must load");
|
||||
let t = map.get("post_quant_conv.bias").expect("key present");
|
||||
assert_eq!(t.dims(), &[4]);
|
||||
let got: Vec<f32> = t.to_vec1::<f32>().expect("to_vec f32");
|
||||
assert_eq!(got, vec![0.5, -1.0, 2.25, 3.0]);
|
||||
|
||||
let _ = std::fs::remove_file(&path);
|
||||
}
|
||||
|
||||
/// A truncated / corrupt header must produce a parse error, never a panic.
|
||||
/// (Defense-in-depth: the loader is fed an untrusted file.)
|
||||
#[test]
|
||||
fn corrupt_checkpoint_errors_cleanly() {
|
||||
let device = Device::Cpu;
|
||||
// Garbage that is not a valid SafeTensors container.
|
||||
let bytes = vec![0xFFu8; 32];
|
||||
let path = write_temp(&bytes, "corrupt");
|
||||
|
||||
let result = load_safetensors(&path, &device);
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"corrupt checkpoint must error, got Ok: {result:?}"
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_file(&path);
|
||||
}
|
||||
|
||||
/// An int64 tensor must still load correctly (proves the fix narrows only the
|
||||
/// I32 mapping and leaves the genuine I64 path intact).
|
||||
#[test]
|
||||
fn int64_tensor_round_trips() {
|
||||
let device = Device::Cpu;
|
||||
let shape = [3usize];
|
||||
let vals: [i64; 3] = [10, -20, 30];
|
||||
let mut data = Vec::with_capacity(24);
|
||||
for v in vals {
|
||||
data.extend_from_slice(&v.to_le_bytes());
|
||||
}
|
||||
let bytes = build_safetensors("transformer.output_head.bias", "I64", &shape, &data);
|
||||
let path = write_temp(&bytes, "i64");
|
||||
|
||||
let map = load_safetensors(&path, &device).expect("i64 checkpoint must load");
|
||||
let t = map.get("transformer.output_head.bias").expect("key present");
|
||||
assert_eq!(t.dims(), &[3]);
|
||||
assert_eq!(t.elem_count(), 3);
|
||||
let got: Vec<i64> = t.to_vec1::<i64>().expect("to_vec i64");
|
||||
assert_eq!(got, vec![10, -20, 30]);
|
||||
|
||||
let _ = std::fs::remove_file(&path);
|
||||
}
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
//! Input-validation boundary tests for `OccWorldCandle::predict`.
|
||||
//!
|
||||
//! Security review (Milestone #9, crate 4/4). `predict` takes an
|
||||
//! externally-supplied occupancy tensor; per the project's "validate input at
|
||||
//! system boundaries" rule it must reject degenerate / out-of-capacity shapes
|
||||
//! with a clear domain error rather than surfacing a cryptic deep-pipeline
|
||||
//! Candle error (over-capacity frame counts over-index the temporal positional
|
||||
//! embedding) or processing a zero-element tensor.
|
||||
//!
|
||||
//! These exercise only the public API and live here (not inline in
|
||||
//! `inference.rs`) to keep that module under the 500-line cap.
|
||||
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use wifi_densepose_occworld_candle::config::OccWorldConfig;
|
||||
use wifi_densepose_occworld_candle::inference::OccWorldCandle;
|
||||
use wifi_densepose_occworld_candle::error::OccWorldError;
|
||||
|
||||
fn small_cfg() -> OccWorldConfig {
|
||||
OccWorldConfig {
|
||||
grid_h: 8,
|
||||
grid_w: 8,
|
||||
grid_d: 4,
|
||||
num_classes: 4,
|
||||
free_class: 3,
|
||||
base_channels: 8,
|
||||
z_channels: 8,
|
||||
codebook_size: 4,
|
||||
embed_dim: 8,
|
||||
num_frames: 2,
|
||||
token_h: 4,
|
||||
token_w: 4,
|
||||
num_heads: 2,
|
||||
num_layers: 1,
|
||||
ffn_hidden: 16,
|
||||
}
|
||||
}
|
||||
|
||||
/// Zero frames is a degenerate input that would otherwise feed a zero-element
|
||||
/// tensor into the reshape/conv pipeline. Must be rejected at the boundary.
|
||||
#[test]
|
||||
fn predict_rejects_zero_frames() {
|
||||
let device = Device::Cpu;
|
||||
let cfg = small_cfg();
|
||||
let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();
|
||||
let past = Tensor::zeros(
|
||||
(1usize, 0usize, cfg.grid_h, cfg.grid_w, cfg.grid_d),
|
||||
DType::U8,
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
let result = engine.predict(&past);
|
||||
assert!(
|
||||
matches!(result, Err(OccWorldError::ShapeMismatch(_))),
|
||||
"zero-frame input must be rejected with ShapeMismatch"
|
||||
);
|
||||
}
|
||||
|
||||
/// Zero batch must also be rejected (same zero-element-tensor hazard).
|
||||
#[test]
|
||||
fn predict_rejects_zero_batch() {
|
||||
let device = Device::Cpu;
|
||||
let cfg = small_cfg();
|
||||
let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();
|
||||
let past = Tensor::zeros(
|
||||
(0usize, cfg.num_frames, cfg.grid_h, cfg.grid_w, cfg.grid_d),
|
||||
DType::U8,
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
let result = engine.predict(&past);
|
||||
assert!(
|
||||
matches!(result, Err(OccWorldError::ShapeMismatch(_))),
|
||||
"zero-batch input must be rejected with ShapeMismatch"
|
||||
);
|
||||
}
|
||||
|
||||
/// More frames than the temporal embedding can index (`> num_frames*2`).
|
||||
///
|
||||
/// On the old code this over-indexed the temporal positional embedding deep in
|
||||
/// the transformer and surfaced as a cryptic Candle "gather" `InvalidIndex`
|
||||
/// error. The boundary guard now rejects it cleanly with `ShapeMismatch`.
|
||||
#[test]
|
||||
fn predict_rejects_too_many_frames() {
|
||||
let device = Device::Cpu;
|
||||
let cfg = small_cfg(); // num_frames = 2 → temporal capacity = 4
|
||||
let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();
|
||||
let too_many = cfg.num_frames * 2 + 1;
|
||||
let past = Tensor::zeros(
|
||||
(1usize, too_many, cfg.grid_h, cfg.grid_w, cfg.grid_d),
|
||||
DType::U8,
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
let result = engine.predict(&past);
|
||||
assert!(
|
||||
matches!(result, Err(OccWorldError::ShapeMismatch(_))),
|
||||
"over-capacity frame count must be rejected with ShapeMismatch"
|
||||
);
|
||||
}
|
||||
|
||||
/// A frame count exactly at capacity (`num_frames*2`) must still succeed —
|
||||
/// the guard rejects only *over*-capacity, not the boundary value.
|
||||
#[test]
|
||||
fn predict_accepts_frame_count_at_capacity() {
|
||||
let device = Device::Cpu;
|
||||
let cfg = small_cfg();
|
||||
let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();
|
||||
let at_cap = cfg.num_frames * 2;
|
||||
let past = Tensor::zeros(
|
||||
(1usize, at_cap, cfg.grid_h, cfg.grid_w, cfg.grid_d),
|
||||
DType::U8,
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
let out = engine
|
||||
.predict(&past)
|
||||
.expect("at-capacity frame count must predict");
|
||||
assert_eq!(out.sem_pred.dims()[1], at_cap, "frame dim preserved");
|
||||
}
|
||||
|
||||
/// Wrong spatial geometry (H/W/D) is still rejected — pins the pre-existing
|
||||
/// guard alongside the new frame/batch ones.
|
||||
#[test]
|
||||
fn predict_rejects_wrong_grid_dims() {
|
||||
let device = Device::Cpu;
|
||||
let cfg = small_cfg();
|
||||
let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();
|
||||
let past = Tensor::zeros(
|
||||
(1usize, cfg.num_frames, cfg.grid_h + 1, cfg.grid_w, cfg.grid_d),
|
||||
DType::U8,
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
let result = engine.predict(&past);
|
||||
assert!(
|
||||
matches!(result, Err(OccWorldError::ShapeMismatch(_))),
|
||||
"wrong grid dims must be rejected with ShapeMismatch"
|
||||
);
|
||||
}
|
||||