security(occworld-candle): int32-checkpoint crash + degenerate-input guards + ADR-179 (closes Milestone #9) (#1101)

* fix(occworld-candle): security review fixes — int32 checkpoint crash + predict input validation

Beyond-SOTA security + correctness review of wifi-densepose-occworld-candle
(Milestone #9, crate 4/4 — the last ungated crate).

Findings fixed:

1. HIGH (MEASURED) — checkpoint-load crash on any int32 tensor.
   model.rs mapped safetensors I32 -> candle DType::I64 and passed the raw
   int32 byte buffer (4 bytes/elem) to Tensor::from_raw_buffer(.., I64, ..).
   Candle derives elem_count = data.len() / dtype.size_in_bytes(), so the I64 path
   halved the count while keeping the original shape -> a tensor whose shape
   claims 2x its storage. Reading it PANICS (slice OOB: "range end index 6
   out of range for slice of length 3") on any checkpoint containing an int32
   tensor. Fixed: I32 -> DType::I32, I16 -> DType::I16 (both first-class
   candle dtypes). Reproduced on old code; pinned in tests/checkpoint_loading.rs.

2. LOW (MEASURED) — predict() lacked frame/batch validation at the input
   boundary. f_in > num_frames*2 over-indexed the temporal embedding (cryptic
   candle "gather" error); zero frame/batch fed a zero-element tensor in. Now
   rejected with a clear ShapeMismatch. Pinned in tests/input_validation.rs.

3. LOW (MEASURED) — divide-by-zero panic in the public VQCodebook::encode on a
   rank-0 / empty-last-dim tensor (last == 0). Now fails closed with a clear
   error. Pinned in vqvae.rs unit tests.

Dimensions confirmed clean with evidence: panic surface (no unwrap/expect/
panic in prod paths), NaN-state-poisoning (N/A — stateless engine, u8 input),
unbounded-alloc/shape-data mismatch (defended upstream by safetensors::
validate), secrets (none). unsafe_code = forbid.

Validation (MEASURED, Windows): crate 31/31 pass; workspace 0 failed (lone
desktop api_integration "Access is denied" file-lock flake passes 21/21 in
isolation); Python proof VERDICT PASS, hash f8e76f21…446f7a unchanged.

Warrants ADR slot 179 (parent to author).

Co-Authored-By: claude-flow <ruv@ruv.net>

* docs(adr): ADR-179 — occworld-candle checkpoint-load hardening (closes Milestone #9)

Records the HIGH int32-checkpoint crash fix (I32→I64 dtype-widening → slice-OOB
panic on load = DoS) + 2 LOW degenerate-input fixes from 5e77f47e5. Stateless
engine (NaN-poisoning N/A), unsafe forbidden, safetensors validate() bounds
allocation size upstream. occworld 31/31. Final ungated crate; Milestone #9 complete.

Co-Authored-By: claude-flow <ruv@ruv.net>
rUv 2026-06-15 12:35:29 -04:00 committed by GitHub
parent 10c813fde3
commit c859f6f743
7 changed files with 471 additions and 1 deletions

@ -0,0 +1,81 @@
# ADR-179: `wifi-densepose-occworld-candle` Checkpoint-Load Hardening
| Field | Value |
|-------|-------|
| **Status** | Accepted — 1 HIGH + 2 LOW bugs fixed + pinned (MEASURED on Windows) |
| **Date** | 2026-06-15 |
| **Deciders** | ruv |
| **Codename** | **OCCWORLD-DTYPE** |
| **Reviews** | `wifi-densepose-occworld-candle` (Candle occupancy-world model) |
| **Milestone** | #9 (ungated-crate security sweep) — crate 4 of 4 — **CLOSES the milestone** |
## Context
`wifi-densepose-occworld-candle` is a Candle-based occupancy-world model
(VQ-VAE + transformer over occupancy tokens). The real risk surface for an ML
crate is degenerate-input / malformed-weights handling: an `unsafe_code = "forbid"`
crate can still **panic** (a DoS, and under WASM an abort) when a tensor op hits an
inconsistent shape. The crate **builds and tests on Windows**, so all findings are
MEASURED.
## Decision
Fix the three reachable bugs, each pinned by a fails-on-old test; attest the rest
clean with evidence.
### Findings fixed (all MEASURED)
| # | Severity | Location | Issue | Fix |
|---|----------|----------|-------|-----|
| 1 | **HIGH** | `model.rs:95` (`Dtype::I32 => Some(DType::I64)`) | **Crash on any int32-tensor checkpoint.** An I32 byte buffer (4 B/elem) is handed to `from_raw_buffer(.., I64, shape, ..)`; candle derives `elem_count = data.len()/8`, **halving** the count while keeping the original shape → a tensor that claims 2× its storage. Reading it **panics** with a slice-OOB (`range end index 6 out of range for slice of length 3`) inside candle-core. A checkpoint with any int32 tensor (index/buffer tensors are common in PyTorch exports) → **DoS on load**. | Map `I32 → DType::I32`, `I16 → DType::I16` (both first-class candle dtypes). Pinned by `int32_tensor_loads_with_consistent_shape_and_values` (panics on old, passes on new). |
| 2 | LOW | `inference.rs::predict` | Frame/batch dims weren't validated (only H/W/D were): `f_in > num_frames*2` over-indexes the temporal embedding → a cryptic candle `InvalidIndex` *error* (not a panic — candle bounds-checks); zero frame/batch feeds a zero-element tensor. | Boundary guard rejects zero / over-capacity frame+batch with a clear `ShapeMismatch`. 5 pins. |
| 3 | LOW | `vqvae.rs:141` (`z.elem_count() / last`) | **Divide-by-zero panic** in public `VQCodebook::encode` on a rank-0 / empty-last-dim tensor (`last == 0`). | Fail-closed guard returns a clear error. Pinned by `encode_rejects_scalar_without_panicking`. |

The HIGH finding is the notable one: the crate's own dtype mapping **defeated**
the upstream `safetensors::validate()` byte-length guarantee by misdeclaring the
dtype — the one place malformed/widened weights could reach a panicking candle op.
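
The arithmetic behind finding #1 can be reproduced without candle. The following is a minimal, std-only sketch: `SketchDType`, `storage_elems`, `shape_elems`, and `shape_matches_storage` are hypothetical names introduced here for illustration and are not part of the crate or of candle. It only models the element-count derivation described above (byte length divided by element size) and compares it with the element count the shape claims.

```rust
/// Illustrative stand-in for the two dtypes involved in finding #1.
/// (Hypothetical type; this is not candle's `DType`.)
#[derive(Clone, Copy, Debug, PartialEq)]
enum SketchDType {
    I32,
    I64,
}

impl SketchDType {
    /// Element size in bytes: 4 for int32, 8 for int64.
    fn size_in_bytes(self) -> usize {
        match self {
            SketchDType::I32 => 4,
            SketchDType::I64 => 8,
        }
    }
}

/// Number of elements a raw byte buffer holds when read as `dtype`.
fn storage_elems(data_len: usize, dtype: SketchDType) -> usize {
    data_len / dtype.size_in_bytes()
}

/// Number of elements the declared shape claims.
fn shape_elems(shape: &[usize]) -> usize {
    shape.iter().product()
}

/// True when the storage holds exactly as many elements as the shape claims.
fn shape_matches_storage(data_len: usize, dtype: SketchDType, shape: &[usize]) -> bool {
    storage_elems(data_len, dtype) == shape_elems(shape)
}
```

For the regression case in the test (six int32 values, 24 bytes, shape `[2, 3]`), reading the buffer as I32 gives 6 storage elements against a shape of 6, while reading it as I64 gives 3 storage elements against the same shape of 6. That 6-versus-3 disagreement matches the numbers in the reported slice-OOB message.
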
### Dimensions confirmed clean (with evidence)
- **Panic surface** — grep for `unwrap()/expect()/panic!/unreachable!` across `src/`
→ **zero in production paths**; all ops use `?`/`map_err`; the `last().unwrap_or(&0)`
is now guarded. `as` casts operate only on config-bounded/internal values.
- **NaN-state-poisoning (the named class) — N/A.** The engine is **stateless between
`predict` calls** (no persistent world-model buffer to latch into), and input is
`u8` class indices (non-finite input structurally impossible). NaN weights flow to
`argmax` (deterministic, bounded to a valid class index) — no panic, no persistence.
- **Unbounded alloc / shape-data mismatch from malformed weights** — defended upstream
by `safetensors::validate()` (overflow-checked `nelements*dtype.size()` vs declared
byte range + contiguous-offset + buffer-length checks), rejected before reaching
candle. Finding #1 was the one place the crate defeated that guarantee.
- **Model/path loading**: `load`/`load_safetensors` check `path.exists()` → typed
`CheckpointNotFound`; corrupt bytes → `CheckpointParse` (pinned). No path-traversal
surface (caller-supplied path, opened read-only, never joined with untrusted segments).
- **Secrets** — grep clean (only `token_h`/`token_w` config fields match `token`).
- **Determinism** — the crate's central honesty claim, verified by the pre-existing
`tests/predict_honesty.rs` (3 tests, still pass).
- `unsafe_code = "forbid"` in the manifest.
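
The upstream size check that the list above relies on can be sketched as follows. This is an illustration of the kind of overflow-checked comparison described (element count times element size versus the declared byte range), not the `safetensors` source. `check_tensor_extent` and its parameters are hypothetical names, and the contiguous-offset check across multiple tensors is omitted.

```rust
/// Checks that a tensor's declared shape and element size agree with its
/// declared byte range `[begin, end)` inside a data buffer of `buffer_len`
/// bytes. Returns the tensor's byte length on success.
/// (Hypothetical helper; a sketch of the idea, not the safetensors code.)
fn check_tensor_extent(
    shape: &[usize],
    elem_size: usize,
    begin: usize,
    end: usize,
    buffer_len: usize,
) -> Result<usize, String> {
    // Multiply the dims with overflow checking, so a huge declared shape
    // cannot wrap around to a small product.
    let nelements = shape
        .iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(d))
        .ok_or_else(|| "shape product overflows usize".to_string())?;
    let nbytes = nelements
        .checked_mul(elem_size)
        .ok_or_else(|| "byte size overflows usize".to_string())?;

    // The declared range must be ordered and lie inside the buffer.
    if begin > end || end > buffer_len {
        return Err(format!(
            "offsets [{begin}, {end}) fall outside a buffer of {buffer_len} bytes"
        ));
    }
    // The declared range must be exactly as long as the shape requires.
    if end - begin != nbytes {
        return Err(format!(
            "shape needs {nbytes} bytes but the declared range holds {}",
            end - begin
        ));
    }
    Ok(nbytes)
}
```

Under this sketch, a `[2, 3]` tensor with 4-byte elements and a 24-byte range passes, while the same range checked with an 8-byte element size fails. That is the sense in which finding #1 bypassed the guarantee: the file was consistent for int32, and the mismatch appeared only once the crate redeclared the dtype as int64.
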
## Validation
- `cargo test -p wifi-densepose-occworld-candle --no-default-features` → **31/31**
(lib 17, checkpoint_loading 4, input_validation 5, predict_honesty 3, doctests 2),
0 failed.
- `cargo test --workspace --no-default-features` → 0 failed across every crate (a lone
`wifi-densepose-desktop --test api_integration` "Access is denied (os error 5)" was a
Windows file-lock/AV flake — re-ran isolated 21/21, unrelated).
- `python archive/v1/data/proof/verify.py` → **VERDICT: PASS**, hash `f8e76f21…46f7a`
unchanged (occworld off the signal proof path).
## Consequences
### Positive
- A checkpoint-load DoS (the int32 dtype-widening panic) and two degenerate-input
panics are closed in the world-model crate, each pinned. **Milestone #9 (all 4
ungated crates) is complete.**
### Negative / Neutral
- None. Guards reject only malformed/degenerate inputs.
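
The two LOW guards can be modelled in isolation to show that they reject only degenerate values and still accept the boundary case. This is a std-only sketch under stated assumptions: `SketchError`, `check_batch_and_frames`, and `flattened_rows` are hypothetical names that mirror the conditions described in findings #2 and #3, not the crate's actual functions or error type.

```rust
/// Illustrative error type. (Hypothetical; the crate uses `OccWorldError`
/// and `candle_core::Error` for the real guards.)
#[derive(Debug, PartialEq)]
enum SketchError {
    ShapeMismatch(String),
}

/// Mirrors the finding #2 conditions: reject zero batch or frame counts, and
/// frame counts above the temporal embedding capacity (`num_frames * 2`).
fn check_batch_and_frames(
    batch: usize,
    frames: usize,
    num_frames: usize,
) -> Result<(), SketchError> {
    if batch == 0 || frames == 0 {
        return Err(SketchError::ShapeMismatch(format!(
            "batch and frame dims must be non-zero, got batch={batch}, frames={frames}"
        )));
    }
    let capacity = num_frames.saturating_mul(2);
    if frames > capacity {
        return Err(SketchError::ShapeMismatch(format!(
            "frame count {frames} exceeds capacity {capacity}"
        )));
    }
    Ok(())
}

/// Mirrors the finding #3 condition: a rank-0 or empty-last-dim shape yields
/// `last == 0`, which must be rejected before dividing by it.
fn flattened_rows(elem_count: usize, dims: &[usize]) -> Result<usize, SketchError> {
    let last = dims.last().copied().unwrap_or(0);
    if last == 0 {
        return Err(SketchError::ShapeMismatch(format!(
            "expected a non-zero last dim, got shape {dims:?}"
        )));
    }
    Ok(elem_count / last)
}
```

With `num_frames = 2`, a frame count of 4 (exactly at capacity) is accepted and 5 is rejected, which corresponds to the at-capacity and over-capacity tests in `tests/input_validation.rs`. An empty dims slice, standing in for a scalar tensor, returns an error instead of dividing by zero.
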
## Links
- ADR-176 / ADR-177 / ADR-178 — sibling Milestone-#9 reviews (ruview-swarm, nvsim, desktop)

@ -206,6 +206,27 @@ impl OccWorldCandle {
)));
}
// Validate the externally-supplied frame and batch counts at this
// system boundary. The temporal positional embedding has only
// `num_frames * 2` rows, so a larger `f_in` would over-index the
// embedding table deep inside the transformer and surface as a cryptic
// "gather" index error; a zero frame/batch count would feed a
// zero-element tensor into the reshape/conv pipeline. Reject both here
// with a clear, domain-level error instead.
if f_in == 0 || b == 0 {
return Err(OccWorldError::ShapeMismatch(format!(
"past_occupancy must have non-zero batch and frame dims, got \
batch={b}, frames={f_in}"
)));
}
if f_in > cfg.num_frames * 2 {
return Err(OccWorldError::ShapeMismatch(format!(
"past_occupancy frame count {f_in} exceeds the temporal embedding \
capacity ({} = num_frames*2)",
cfg.num_frames * 2
)));
}
// ── Step 1: VQVAE encode each past frame ──────────────────────────
// Flatten batch*frames: (B, F, H, W, D) → (B*F, H, W, D)
let occ_flat = past_occupancy
@ -455,4 +476,8 @@ mod tests {
"expected CheckpointNotFound, got {result:?}"
);
}
// The `predict` input-validation boundary guards (zero/over-capacity frame
// counts) live in `tests/input_validation.rs` so they exercise only the
// public API and keep this file under the 500-line limit.
}

@ -92,8 +92,21 @@ fn safetensor_dtype_to_candle(dt: safetensors::Dtype) -> Option<candle_core::DTy
Dtype::F64 => Some(DType::F64),
Dtype::F16 => Some(DType::F16),
Dtype::BF16 => Some(DType::BF16),
- Dtype::I32 => Some(DType::I64), // widen for Candle compatibility
// I32 MUST map to DType::I32, not I64. `Tensor::from_raw_buffer`
// derives its element count from `data.len() / dtype.size_in_bytes()`;
// handing an int32 byte buffer (4 bytes/elem) to the I64 path
// (8 bytes/elem) halves the element count while keeping the original
// shape, producing a tensor whose declared shape claims twice as many
// elements as its storage holds. That silent shape/storage mismatch
// panics (slice OOB) the moment the tensor is read — a crash on any
// checkpoint containing an int32 tensor. See
// `tests/checkpoint_loading.rs::int32_tensor_loads_with_consistent_shape_and_values`.
Dtype::I32 => Some(DType::I32),
Dtype::I64 => Some(DType::I64),
// I16 is also a first-class Candle dtype (2 bytes/elem); map it
// directly rather than rejecting it, for the same byte-size-correctness
// reason as I32 above.
Dtype::I16 => Some(DType::I16),
Dtype::U8 => Some(DType::U8),
Dtype::U32 => Some(DType::U32),
_ => None,

@ -137,6 +137,17 @@ impl VQCodebook {
let orig_shape = z.shape().clone();
let orig_dims = orig_shape.dims().to_vec();
let last = *orig_shape.dims().last().unwrap_or(&0);
// Guard the divide below: a scalar (rank-0) or empty-last-dim tensor
// would make `last == 0` and panic on the `elem_count() / last`
// division. `encode` is a `pub fn` on a `pub struct`, so this is a
// reachable public boundary — fail closed with a clear error instead.
if last == 0 {
return Err(candle_core::Error::Msg(format!(
"VQCodebook::encode expects a tensor with a non-zero last dim of \
size embed_dim={}, got shape {orig_dims:?}",
self.embed_dim
)));
}
// Flatten to (N, embed_dim)
let n = z.elem_count() / last;
let z_flat = z.reshape((n, last))?; // (N, D)
@ -339,6 +350,21 @@ mod tests {
Ok(())
}
#[test]
fn encode_rejects_scalar_without_panicking() {
// A rank-0 (scalar) tensor has an empty dims list → `last == 0`.
// Before the guard this divided by zero and panicked; now it returns
// a clean error. `encode` is public, so this is a reachable boundary.
let device = Device::Cpu;
let codebook = VQCodebook::dummy(4, 8, &device).unwrap();
let scalar = Tensor::from_vec(vec![1.0f32], (), &device).unwrap();
let result = codebook.encode(&scalar);
assert!(
result.is_err(),
"scalar input must error, not panic; got {result:?}"
);
}
#[test]
fn test_fold_unfold_roundtrip() -> candle_core::Result<()> {
let device = Device::Cpu;

@ -0,0 +1,185 @@
//! Checkpoint-loading robustness tests for `crate::model::load_safetensors`.
//!
//! Security review (Milestone #9, crate 4/4). These tests pin the behaviour of
//! the SafeTensors weight-loading path against malformed / degenerate
//! checkpoints — the only externally-controlled file-input surface in the crate.
//!
//! The headline regression is the **int32 dtype-widening byte-size bug**
//! (`security/occworld-candle` finding #1): `model.rs` mapped
//! `safetensors::Dtype::I32` → `candle_core::DType::I64` and then handed the
//! raw *int32* byte buffer (4 bytes/elem) to `Tensor::from_raw_buffer(.., I64,
//! shape, ..)`. Candle's `from_raw_buffer` computes `elem_count =
//! data.len() / 8`, producing a tensor whose declared shape claims twice as
//! many elements as the backing storage actually holds — a silent
//! shape/storage inconsistency on attacker-supplied checkpoints.
//!
//! `build_safetensors` hand-assembles the binary container
//! (`<u64 LE header_len><JSON header><raw data>`) so the test states exactly
//! what bytes reach the loader, independent of the `safetensors` writer API.
use candle_core::Device;
use wifi_densepose_occworld_candle::model::load_safetensors;
/// Hand-build a single-tensor SafeTensors buffer.
///
/// `dtype` is the safetensors dtype string (e.g. `"I32"`, `"F32"`).
/// `shape` is the declared shape. `data` is the raw little-endian tensor bytes
/// — the caller is responsible for making `data.len()` consistent with
/// `shape × dtype_size` (safetensors itself validates this, so an inconsistent
/// pair is rejected before reaching the candle conversion).
fn build_safetensors(name: &str, dtype: &str, shape: &[usize], data: &[u8]) -> Vec<u8> {
let shape_json: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
let header = format!(
"{{\"{name}\":{{\"dtype\":\"{dtype}\",\"shape\":[{}],\"data_offsets\":[0,{}]}}}}",
shape_json.join(","),
data.len()
);
let header_bytes = header.into_bytes();
let mut buf = Vec::new();
buf.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
buf.extend_from_slice(&header_bytes);
buf.extend_from_slice(data);
buf
}
fn write_temp(bytes: &[u8], stem: &str) -> std::path::PathBuf {
let mut p = std::env::temp_dir();
p.push(format!(
"occworld_ckpt_{stem}_{}_{}.safetensors",
std::process::id(),
// nanosecond-ish disambiguator so parallel tests never collide
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos())
.unwrap_or(0)
));
std::fs::write(&p, bytes).expect("write temp checkpoint");
p
}
/// REGRESSION (finding #1): an int32 tensor in a checkpoint must load into a
/// tensor whose element count matches its declared shape.
///
/// On the OLD code (`I32 -> DType::I64`) the 6-element int32 tensor below was
/// handed to `from_raw_buffer(.., I64, [2,3], ..)`, which derived
/// `elem_count = 24 bytes / 8 = 3` and built a 3-element storage carrying a
/// shape claiming 6 elements — reading it panicked with a slice-OOB
/// (`range end index 6 out of range for slice of length 3`). On the FIXED code
/// (`I32 -> DType::I32`) the tensor round-trips: dtype I32, 6 elements,
/// values `[1,2,3,4,5,6]`.
#[test]
fn int32_tensor_loads_with_consistent_shape_and_values() {
let device = Device::Cpu;
let shape = [2usize, 3];
let vals: [i32; 6] = [1, 2, 3, 4, 5, 6];
let mut data = Vec::with_capacity(24);
for v in vals {
data.extend_from_slice(&v.to_le_bytes());
}
let bytes = build_safetensors("quantize.embedding.weight", "I32", &shape, &data);
let path = write_temp(&bytes, "i32");
let map = load_safetensors(&path, &device).expect("int32 checkpoint must load");
let t = map
.get("quantize.embedding.weight")
.expect("mapped key present");
// The declared shape's element count MUST equal the storage's element
// count. On the old code these disagreed (6 vs 3).
assert_eq!(
t.dims(),
&[2, 3],
"int32 tensor must preserve its declared shape"
);
assert_eq!(
t.elem_count(),
6,
"element count must match shape — storage/shape consistency"
);
// The dtype must be I32 — the int32 byte buffer is interpreted as int32,
// not reinterpreted as half as many int64 lanes.
assert_eq!(
t.dtype(),
candle_core::DType::I32,
"int32 checkpoint tensor must load as DType::I32"
);
// And the values must be exactly recovered (no reinterpretation of two
// int32 lanes as one int64). This is the strongest proof the dtype is
// handled correctly end-to-end.
let flat = t.flatten_all().expect("flatten");
let got: Vec<i32> = flat.to_vec1::<i32>().expect("to_vec i32");
assert_eq!(
got,
vec![1i32, 2, 3, 4, 5, 6],
"int32 values must be recovered exactly"
);
let _ = std::fs::remove_file(&path);
}
/// A well-formed F32 tensor must round-trip unchanged (control case — proves
/// the fix does not regress the common float path).
#[test]
fn f32_tensor_round_trips() {
let device = Device::Cpu;
let shape = [4usize];
let vals: [f32; 4] = [0.5, -1.0, 2.25, 3.0];
let mut data = Vec::with_capacity(16);
for v in vals {
data.extend_from_slice(&v.to_le_bytes());
}
let bytes = build_safetensors("post_quant_conv.bias", "F32", &shape, &data);
let path = write_temp(&bytes, "f32");
let map = load_safetensors(&path, &device).expect("f32 checkpoint must load");
let t = map.get("post_quant_conv.bias").expect("key present");
assert_eq!(t.dims(), &[4]);
let got: Vec<f32> = t.to_vec1::<f32>().expect("to_vec f32");
assert_eq!(got, vec![0.5, -1.0, 2.25, 3.0]);
let _ = std::fs::remove_file(&path);
}
/// A truncated / corrupt header must produce a parse error, never a panic.
/// (Defense-in-depth: the loader is fed an untrusted file.)
#[test]
fn corrupt_checkpoint_errors_cleanly() {
let device = Device::Cpu;
// Garbage that is not a valid SafeTensors container.
let bytes = vec![0xFFu8; 32];
let path = write_temp(&bytes, "corrupt");
let result = load_safetensors(&path, &device);
assert!(
result.is_err(),
"corrupt checkpoint must error, got Ok: {result:?}"
);
let _ = std::fs::remove_file(&path);
}
/// An int64 tensor must still load correctly (proves the fix narrows only the
/// I32 mapping and leaves the genuine I64 path intact).
#[test]
fn int64_tensor_round_trips() {
let device = Device::Cpu;
let shape = [3usize];
let vals: [i64; 3] = [10, -20, 30];
let mut data = Vec::with_capacity(24);
for v in vals {
data.extend_from_slice(&v.to_le_bytes());
}
let bytes = build_safetensors("transformer.output_head.bias", "I64", &shape, &data);
let path = write_temp(&bytes, "i64");
let map = load_safetensors(&path, &device).expect("i64 checkpoint must load");
let t = map.get("transformer.output_head.bias").expect("key present");
assert_eq!(t.dims(), &[3]);
assert_eq!(t.elem_count(), 3);
let got: Vec<i64> = t.to_vec1::<i64>().expect("to_vec i64");
assert_eq!(got, vec![10, -20, 30]);
let _ = std::fs::remove_file(&path);
}

@ -0,0 +1,139 @@
//! Input-validation boundary tests for `OccWorldCandle::predict`.
//!
//! Security review (Milestone #9, crate 4/4). `predict` takes an
//! externally-supplied occupancy tensor; per the project's "validate input at
//! system boundaries" rule it must reject degenerate / out-of-capacity shapes
//! with a clear domain error rather than surfacing a cryptic deep-pipeline
//! Candle error (over-capacity frame counts over-index the temporal positional
//! embedding) or processing a zero-element tensor.
//!
//! These exercise only the public API and live here (not inline in
//! `inference.rs`) to keep that module under the 500-line cap.
use candle_core::{DType, Device, Tensor};
use wifi_densepose_occworld_candle::config::OccWorldConfig;
use wifi_densepose_occworld_candle::inference::OccWorldCandle;
use wifi_densepose_occworld_candle::error::OccWorldError;
fn small_cfg() -> OccWorldConfig {
OccWorldConfig {
grid_h: 8,
grid_w: 8,
grid_d: 4,
num_classes: 4,
free_class: 3,
base_channels: 8,
z_channels: 8,
codebook_size: 4,
embed_dim: 8,
num_frames: 2,
token_h: 4,
token_w: 4,
num_heads: 2,
num_layers: 1,
ffn_hidden: 16,
}
}
/// Zero frames is a degenerate input that would otherwise feed a zero-element
/// tensor into the reshape/conv pipeline. Must be rejected at the boundary.
#[test]
fn predict_rejects_zero_frames() {
let device = Device::Cpu;
let cfg = small_cfg();
let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();
let past = Tensor::zeros(
(1usize, 0usize, cfg.grid_h, cfg.grid_w, cfg.grid_d),
DType::U8,
&device,
)
.unwrap();
let result = engine.predict(&past);
assert!(
matches!(result, Err(OccWorldError::ShapeMismatch(_))),
"zero-frame input must be rejected with ShapeMismatch"
);
}
/// Zero batch must also be rejected (same zero-element-tensor hazard).
#[test]
fn predict_rejects_zero_batch() {
let device = Device::Cpu;
let cfg = small_cfg();
let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();
let past = Tensor::zeros(
(0usize, cfg.num_frames, cfg.grid_h, cfg.grid_w, cfg.grid_d),
DType::U8,
&device,
)
.unwrap();
let result = engine.predict(&past);
assert!(
matches!(result, Err(OccWorldError::ShapeMismatch(_))),
"zero-batch input must be rejected with ShapeMismatch"
);
}
/// More frames than the temporal embedding can index (`> num_frames*2`).
///
/// On the old code this over-indexed the temporal positional embedding deep in
/// the transformer and surfaced as a cryptic Candle "gather" `InvalidIndex`
/// error. The boundary guard now rejects it cleanly with `ShapeMismatch`.
#[test]
fn predict_rejects_too_many_frames() {
let device = Device::Cpu;
let cfg = small_cfg(); // num_frames = 2 → temporal capacity = 4
let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();
let too_many = cfg.num_frames * 2 + 1;
let past = Tensor::zeros(
(1usize, too_many, cfg.grid_h, cfg.grid_w, cfg.grid_d),
DType::U8,
&device,
)
.unwrap();
let result = engine.predict(&past);
assert!(
matches!(result, Err(OccWorldError::ShapeMismatch(_))),
"over-capacity frame count must be rejected with ShapeMismatch"
);
}
/// A frame count exactly at capacity (`num_frames*2`) must still succeed —
/// the guard rejects only *over*-capacity, not the boundary value.
#[test]
fn predict_accepts_frame_count_at_capacity() {
let device = Device::Cpu;
let cfg = small_cfg();
let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();
let at_cap = cfg.num_frames * 2;
let past = Tensor::zeros(
(1usize, at_cap, cfg.grid_h, cfg.grid_w, cfg.grid_d),
DType::U8,
&device,
)
.unwrap();
let out = engine
.predict(&past)
.expect("at-capacity frame count must predict");
assert_eq!(out.sem_pred.dims()[1], at_cap, "frame dim preserved");
}
/// Wrong spatial geometry (H/W/D) is still rejected — pins the pre-existing
/// guard alongside the new frame/batch ones.
#[test]
fn predict_rejects_wrong_grid_dims() {
let device = Device::Cpu;
let cfg = small_cfg();
let engine = OccWorldCandle::dummy(cfg.clone(), device.clone()).unwrap();
let past = Tensor::zeros(
(1usize, cfg.num_frames, cfg.grid_h + 1, cfg.grid_w, cfg.grid_d),
DType::U8,
&device,
)
.unwrap();
let result = engine.predict(&past);
assert!(
matches!(result, Err(OccWorldError::ShapeMismatch(_))),
"wrong grid dims must be rejected with ShapeMismatch"
);
}