201 lines
7.4 KiB
Python
201 lines
7.4 KiB
Python
"""Shared infrastructure for the LOCAL wiflow-std benchmark scripts (ADR-152).
|
|
|
|
This module is the single canonical implementation of the helpers that were
|
|
previously copy-pasted across eval_repro.py / quantize_bench.py /
|
|
onnx_bench.py / eval_ort_accuracy.py / export_to_safetensors.py:
|
|
|
|
- ``import_upstream()`` -- sys.path setup + the models-package stub that
|
|
works around the upstream import bug, plus the >1GB np.load mmap patch
|
|
- ``install_np_load_mmap_patch()`` -- the mmap patch on its own
|
|
- ``remap_legacy_keys()`` / ``load_remapped_state()`` -- checkpoint
|
|
key remap for the pre-rename released checkpoint
|
|
- ``load_wiflow_model()`` -- WiFlowPoseModel from a checkpoint, eval mode
|
|
- ``set_seed()`` -- mirrors upstream run.py seeding exactly
|
|
- ``evaluate()`` -- THE canonical batch-weighted PCK/MPJPE evaluation loop
|
|
(thresholds 0.1-0.5, upstream utils/metrics.py math); accepts either a
|
|
torch nn.Module or an onnxruntime InferenceSession
|
|
|
|
The scripts under remote/ deploy to ruvultra as standalone single files and
|
|
therefore intentionally inline private copies of these helpers; when editing
|
|
them, treat this module as the reference implementation and keep the copies
|
|
in sync.
|
|
"""
|
|
|
|
import os
|
|
import random
|
|
import sys
|
|
import time
|
|
import types
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
# Absolute directory containing this module; the paths below are anchored to
# it, so they do not depend on the caller's working directory.
HERE = os.path.dirname(os.path.abspath(__file__))

# "upstream" directory next to this module (put on sys.path by
# import_upstream()).
UPSTREAM = os.path.join(HERE, "upstream")

# "results" directory next to this module.
RESULTS = os.path.join(HERE, "results")

# PCK thresholds used by evaluate() when the caller does not pass its own.
DEFAULT_THRESHOLDS = (0.1, 0.2, 0.3, 0.4, 0.5)
|
|
|
|
# ---------------------------------------------------------------------------
# >1GB np.load mmap patch
# ---------------------------------------------------------------------------

# csi_windows.npy is ~13 GB; mmap large arrays instead of loading into RAM
# (loading it eagerly needs ~15 GB).
|
|
# Capture the real loader.  If np.load is already our wrapper (for example
# this module was reloaded after the patch had been installed), unwrap it so
# the wrapper can never end up calling itself recursively.
_np_load = getattr(np.load, "_wiflow_original", np.load)


def _np_load_mmap(path, *a, **kw):
    """Drop-in ``np.load`` that memory-maps ``.npy`` files larger than 1 GiB.

    Args:
        path: Forwarded to ``np.load`` unchanged.  Only ``str`` and
            ``os.PathLike`` values that resolve to a ``str`` ending in
            ``.npy`` are considered for mmap; anything else (file objects,
            bytes paths, ``.npz`` files) is passed straight through.
        *a, **kw: Forwarded to ``np.load``.

    Returns:
        Whatever ``np.load`` returns.  For qualifying files, ``mmap_mode="r"``
        is injected, so the result is a read-only memory map.

    An explicit ``mmap_mode`` from the caller is always respected.  It is
    ``np.load``'s first positional parameter after the file, so any extra
    positional argument means the caller already chose a mode; injecting the
    keyword as well would raise ``TypeError`` (multiple values).
    """
    caller_chose_mode = bool(a) or "mmap_mode" in kw
    if not caller_chose_mode and isinstance(path, (str, os.PathLike)):
        fs_path = os.fspath(path)
        # os.fspath may return bytes for exotic PathLike objects; those are
        # left alone, matching the str-only suffix check.
        if (isinstance(fs_path, str) and fs_path.endswith(".npy")
                and os.path.getsize(fs_path) > 1 << 30):
            kw["mmap_mode"] = "r"
    return _np_load(path, *a, **kw)


# Marker read by the ``_np_load`` capture above on a later (re)import.
_np_load_mmap._wiflow_original = _np_load
|
|
|
|
|
|
def install_np_load_mmap_patch():
    """Swap ``np.load`` for the mmap-aware wrapper ``_np_load_mmap``.

    After this call, .npy files >1GB loaded through ``np.load`` are mmap'd
    read-only.  The replacement is made on the numpy module attribute, which
    is equivalent to the historical
    ``upstream_dataset.np.load = _np_load_mmap`` (dataset.np IS the numpy
    module) but does not depend on import order.

    Idempotent: calling it again simply repeats the same assignment.
    """
    setattr(np, "load", _np_load_mmap)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# upstream import shim
# ---------------------------------------------------------------------------
|
|
|
|
def import_upstream(mmap_patch=True):
    """Put the upstream WiFlow-STD clone on ``sys.path`` and return its path.

    Works around an upstream bug: models/__init__.py imports
    TemporalConvNet, which models/tcn.py does not define, so the package
    cannot be imported as published.  A stub ``models`` package is
    registered in ``sys.modules`` so the broken __init__ never executes,
    while submodules (models.pose_model etc.) still resolve through the
    stub's ``__path__``.

    Args:
        mmap_patch: When true, also install the >1GB ``np.load`` mmap patch.

    Returns:
        The ``UPSTREAM`` directory path.

    Idempotent: the path entry and the stub are only added when missing.
    """
    search_path = sys.path
    if UPSTREAM not in search_path:
        search_path.insert(0, UPSTREAM)

    loaded = sys.modules
    if "models" not in loaded:
        stub = types.ModuleType("models")
        stub.__path__ = [os.path.join(UPSTREAM, "models")]
        loaded["models"] = stub

    if mmap_patch:
        install_np_load_mmap_patch()
    return UPSTREAM
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# checkpoint loading
# ---------------------------------------------------------------------------

# The released checkpoint predates the published code: modules were renamed
# att -> attention, final_conv -> decoder (param count identical, 2.23M).
|
|
# Legacy key prefix -> current key prefix.
LEGACY_RENAMES = {"att.": "attention.", "final_conv.": "decoder."}


def remap_legacy_keys(state):
    """Return a new dict with legacy key prefixes replaced by current ones.

    Each key is checked against the prefixes in ``LEGACY_RENAMES`` in order;
    the first prefix it starts with is swapped for its replacement.  Keys
    matching no legacy prefix (including already-new-style keys) are copied
    unchanged.  Values and key order are preserved; ``state`` itself is not
    modified.
    """
    remapped = {}
    for key, value in state.items():
        for old_prefix, new_prefix in LEGACY_RENAMES.items():
            if key.startswith(old_prefix):
                key = new_prefix + key[len(old_prefix):]
                break
        remapped[key] = value
    return remapped
|
|
|
|
|
|
def load_remapped_state(path, map_location="cpu"):
    """Load a checkpoint with ``torch.load`` and remap its legacy keys.

    The file is read with ``weights_only=True`` and ``map_location`` is
    forwarded to ``torch.load``; the loaded object is then passed through
    ``remap_legacy_keys`` and the result returned.
    """
    return remap_legacy_keys(
        torch.load(path, map_location=map_location, weights_only=True))
|
|
|
|
|
|
def load_wiflow_model(checkpoint, map_location="cpu", dropout=0.5):
    """Build a WiFlowPoseModel, load ``checkpoint`` strictly, return it in eval mode.

    Args:
        checkpoint: Checkpoint path, read via ``load_remapped_state`` (so
            legacy key names are remapped before loading).
        map_location: Forwarded to ``load_remapped_state``.
        dropout: Passed to the ``WiFlowPoseModel`` constructor.

    Returns:
        The model after ``load_state_dict(..., strict=True)`` and ``eval()``.
    """
    # The upstream package only becomes importable after the shim has run.
    import_upstream()
    from models.pose_model import WiFlowPoseModel

    net = WiFlowPoseModel(dropout=dropout)
    weights = load_remapped_state(checkpoint, map_location)
    net.load_state_dict(weights, strict=True)
    net.eval()
    return net
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# seeding
# ---------------------------------------------------------------------------
|
|
|
|
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN flags.

    Seeds ``random``, ``np.random`` and ``torch`` (plus the CUDA generators
    when CUDA is available), then sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False``.
    """
    # mirror upstream run.py exactly
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # NOTE(review): the cuDNN flags are set unconditionally here (no effect on
    # a CPU-only host) -- confirm this placement matches upstream run.py.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# THE canonical evaluation loop
# ---------------------------------------------------------------------------
|
|
|
|
def evaluate(model, loader, device=None, dtype=None, label="",
             thresholds=DEFAULT_THRESHOLDS, progress_every=50):
    """Batch-weighted PCK/MPJPE over a DataLoader (upstream metrics math).

    ``model`` may be a torch nn.Module (optionally evaluated on ``device``
    with inputs cast to ``dtype``) or an onnxruntime InferenceSession.
    Per-threshold PCK values are independent in upstream calculate_pck, so
    evaluating a superset of thresholds never changes any individual value.

    Args:
        model: Object to evaluate.  Anything exposing ``get_inputs`` is
            treated as an onnxruntime session and fed ``bx.numpy()`` under
            its first input name; otherwise it is called as a torch module
            after ``model.eval()``.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Torch path only -- inputs and targets are moved there.
            The model itself is not moved; that is the caller's job.
        dtype: Torch path only -- inputs are cast to this dtype; the
            model output is always converted back with ``.float()``.
        label: Optional tag prefixed to progress lines.
        thresholds: PCK thresholds to accumulate.
        progress_every: Print a progress line every this many batches
            (starting with batch 0); ``0`` or ``None`` disables progress.

    Returns:
        {"samples", "mpjpe", "pck@10".."pck@50", "wall_seconds"} -- one
        ``pck@<percent>`` entry per requested threshold.

    Raises:
        ValueError: If ``loader`` yields no samples, since every metric is
            an average over the sample count.
    """
    import_upstream()
    from utils.metrics import calculate_mpjpe, calculate_pck

    is_ort = hasattr(model, "get_inputs")  # onnxruntime InferenceSession
    if is_ort:
        inp = model.get_inputs()[0].name

        def forward(bx):
            return torch.from_numpy(model.run(None, {inp: bx.numpy()})[0])
    else:
        model.eval()

        def forward(bx):
            if device is not None:
                bx = bx.to(device)
            if dtype is not None:
                bx = bx.to(dtype)
            return model(bx).float()

    thresholds = list(thresholds)
    totals = {t: 0.0 for t in thresholds}
    total_mpe, n = 0.0, 0
    tag = f"[{label}] " if label else ""  # loop-invariant, built once
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments during a long evaluation.
    t0 = time.perf_counter()
    with torch.no_grad():
        for batch_idx, (bx, by) in enumerate(loader):
            out = forward(bx)
            if device is not None and not is_ort:
                by = by.to(device)
            mpe = calculate_mpjpe(out, by)
            pck = calculate_pck(out, by, thresholds=thresholds)
            # Weight each batch by its size so a short final batch does not
            # skew the averages.
            bs = by.size(0)
            total_mpe += mpe * bs
            for t in totals:
                totals[t] += pck[t] * bs
            n += bs
            # A falsy progress_every disables reporting (and avoids a
            # modulo-by-zero).
            if progress_every and batch_idx % progress_every == 0:
                pck20 = totals.get(0.2)
                pck20_str = f"pck20={pck20 / n:.4f} " if pck20 is not None else ""
                print(f"  {tag}batch {batch_idx}: n={n} {pck20_str}"
                      f"mpjpe={total_mpe / n:.4f} ({time.perf_counter() - t0:.0f}s)",
                      flush=True)

    if n == 0:
        raise ValueError("evaluate(): loader yielded no samples; "
                         "cannot compute averaged metrics")

    return {
        "samples": n,
        "mpjpe": total_mpe / n,
        # round(), not int(): t * 100 is inexact in floating point
        # (0.29 * 100 == 28.999...), and truncation would mislabel the key.
        **{f"pck@{int(round(t * 100))}": totals[t] / n for t in thresholds},
        "wall_seconds": time.perf_counter() - t0,
    }
|