wifi-densepose/benchmarks/wiflow-std/_bench_common.py

201 lines
7.4 KiB
Python

"""Shared infrastructure for the LOCAL wiflow-std benchmark scripts (ADR-152).
This module is the single canonical implementation of the helpers that were
previously copy-pasted across eval_repro.py / quantize_bench.py /
onnx_bench.py / eval_ort_accuracy.py / export_to_safetensors.py:
- ``import_upstream()`` -- sys.path setup + the models-package stub that
works around the upstream import bug, plus the >1GB np.load mmap patch
- ``install_np_load_mmap_patch()`` -- the mmap patch on its own
- ``remap_legacy_keys()`` / ``load_remapped_state()`` -- checkpoint
key remap for the pre-rename released checkpoint
- ``load_wiflow_model()`` -- WiFlowPoseModel from a checkpoint, eval mode
- ``set_seed()`` -- mirrors upstream run.py seeding exactly
- ``evaluate()`` -- THE canonical batch-weighted PCK/MPJPE evaluation loop
(thresholds 0.1-0.5, upstream utils/metrics.py math); accepts either a
torch nn.Module or an onnxruntime InferenceSession
The scripts under remote/ deploy to ruvultra as standalone single files and
therefore intentionally inline private copies of these helpers; when editing
them, treat this module as the reference implementation and keep the copies
in sync.
"""
import os
import random
import sys
import time
import types
import numpy as np
import torch
HERE = os.path.dirname(os.path.abspath(__file__))
UPSTREAM = os.path.join(HERE, "upstream")
RESULTS = os.path.join(HERE, "results")
DEFAULT_THRESHOLDS = (0.1, 0.2, 0.3, 0.4, 0.5)
# ---------------------------------------------------------------------------
# >1GB np.load mmap patch
# ---------------------------------------------------------------------------
# csi_windows.npy is ~13 GB; mmap large arrays instead of loading into RAM
# (loading it eagerly needs ~15 GB).
_np_load = np.load
def _np_load_mmap(path, *a, **kw):
if (isinstance(path, str) and path.endswith(".npy")
and os.path.getsize(path) > 1 << 30 and "mmap_mode" not in kw):
kw["mmap_mode"] = "r"
return _np_load(path, *a, **kw)
def install_np_load_mmap_patch():
"""Globally patch np.load so .npy files >1GB are mmap'd read-only.
Idempotent. Patching the numpy module attribute is equivalent to the
historical ``upstream_dataset.np.load = _np_load_mmap`` (dataset.np IS
the numpy module), but works regardless of import order.
"""
np.load = _np_load_mmap
# ---------------------------------------------------------------------------
# upstream import shim
# ---------------------------------------------------------------------------
def import_upstream(mmap_patch=True):
"""Make the upstream WiFlow-STD clone importable; returns its path.
Upstream bug: models/__init__.py imports TemporalConvNet, which
models/tcn.py does not define -- the package fails to import as
published. Register a stub package so the broken __init__ never
executes; submodules (models.pose_model etc.) still resolve via
__path__. Idempotent.
"""
if UPSTREAM not in sys.path:
sys.path.insert(0, UPSTREAM)
if "models" not in sys.modules:
_models_pkg = types.ModuleType("models")
_models_pkg.__path__ = [os.path.join(UPSTREAM, "models")]
sys.modules["models"] = _models_pkg
if mmap_patch:
install_np_load_mmap_patch()
return UPSTREAM
# ---------------------------------------------------------------------------
# checkpoint loading
# ---------------------------------------------------------------------------
# The released checkpoint predates the published code: modules were renamed
# att -> attention, final_conv -> decoder (param count identical, 2.23M).
LEGACY_RENAMES = {"att.": "attention.", "final_conv.": "decoder."}
def remap_legacy_keys(state):
"""Remap pre-rename state_dict keys; no-op for already-new-style keys."""
return {next((new + k[len(old):] for old, new in LEGACY_RENAMES.items()
if k.startswith(old)), k): v
for k, v in state.items()}
def load_remapped_state(path, map_location="cpu"):
"""torch.load (weights_only) + legacy key remap."""
state = torch.load(path, map_location=map_location, weights_only=True)
return remap_legacy_keys(state)
def load_wiflow_model(checkpoint, map_location="cpu", dropout=0.5):
"""Full-size WiFlowPoseModel from a checkpoint, strict load, eval mode."""
import_upstream()
from models.pose_model import WiFlowPoseModel
model = WiFlowPoseModel(dropout=dropout)
model.load_state_dict(load_remapped_state(checkpoint, map_location),
strict=True)
model.eval()
return model
# ---------------------------------------------------------------------------
# seeding
# ---------------------------------------------------------------------------
def set_seed(seed=42):
# mirror upstream run.py exactly
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# ---------------------------------------------------------------------------
# THE canonical evaluation loop
# ---------------------------------------------------------------------------
def evaluate(model, loader, device=None, dtype=None, label="",
thresholds=DEFAULT_THRESHOLDS, progress_every=50):
"""Batch-weighted PCK/MPJPE over a DataLoader (upstream metrics math).
``model`` may be a torch nn.Module (optionally evaluated on ``device``
with inputs cast to ``dtype``) or an onnxruntime InferenceSession.
Per-threshold PCK values are independent in upstream calculate_pck, so
evaluating a superset of thresholds never changes any individual value.
Returns {"samples", "mpjpe", "pck@10".."pck@50", "wall_seconds"}.
"""
import_upstream()
from utils.metrics import calculate_mpjpe, calculate_pck
is_ort = hasattr(model, "get_inputs") # onnxruntime InferenceSession
if is_ort:
inp = model.get_inputs()[0].name
def forward(bx):
return torch.from_numpy(model.run(None, {inp: bx.numpy()})[0])
else:
model.eval()
def forward(bx):
if device is not None:
bx = bx.to(device)
if dtype is not None:
bx = bx.to(dtype)
return model(bx).float()
thresholds = list(thresholds)
totals = {t: 0.0 for t in thresholds}
total_mpe, n = 0.0, 0
t0 = time.time()
with torch.no_grad():
for batch_idx, (bx, by) in enumerate(loader):
out = forward(bx)
if device is not None and not is_ort:
by = by.to(device)
mpe = calculate_mpjpe(out, by)
pck = calculate_pck(out, by, thresholds=thresholds)
bs = by.size(0)
total_mpe += mpe * bs
for t in totals:
totals[t] += pck[t] * bs
n += bs
if batch_idx % progress_every == 0:
tag = f"[{label}] " if label else ""
pck20 = totals.get(0.2)
pck20_str = f"pck20={pck20 / n:.4f} " if pck20 is not None else ""
print(f" {tag}batch {batch_idx}: n={n} {pck20_str}"
f"mpjpe={total_mpe / n:.4f} ({time.time() - t0:.0f}s)",
flush=True)
return {
"samples": n,
"mpjpe": total_mpe / n,
**{f"pck@{int(t * 100)}": totals[t] / n for t in thresholds},
"wall_seconds": time.time() - t0,
}