wifi-densepose/benchmarks/wiflow-std/onnx_bench.py

221 lines
8.6 KiB
Python

"""ADR-152 edge optimization: ONNX export + onnxruntime CPU benchmark for the
retrained WiFlow-STD checkpoint.
- Exports fp32 to ONNX. The axial attention reshapes with python ints taken
from tensor.size() (view(N*W, C, H)), so a traced graph bakes the batch
size; we first try a dynamic-batch export and verify it actually works at
batch sizes 1/2/64 -- if not, we fall back to fixed-batch exports.
- Verifies output parity vs torch on the stored fixture
(results/parity_fixture.npz, batch 2, seed 42): max abs diff < 1e-4.
- Measures onnxruntime CPU latency at batch 1 and 64 (median of N runs).
- Supplementary: onnxruntime dynamic int8 quantization of the exported model
(weight size datapoint for the paper's "~2.2 MB int8" claim).
Usage:
.venv/Scripts/python.exe onnx_bench.py
Writes/merges into results/edge_optimization.json under key "onnx".
"""
import json
import os
import platform
import statistics
import time
import traceback
import numpy as np
import torch
from _bench_common import RESULTS, import_upstream, load_wiflow_model
import_upstream() # sys.path + models stub + >1GB np.load mmap patch
# Weights file under RESULTS; load_fp32_model() hands this path to
# load_wiflow_model().
CHECKPOINT = os.path.join(RESULTS, "retrained_best_pose_model.pth")
# Shared results file under RESULTS; main() merges this script's output into it
# under the "onnx" key, keeping whatever other keys are already there.
OUT_JSON = os.path.join(RESULTS, "edge_optimization.json")
def load_fp32_model():
    """Load the model stored at ``CHECKPOINT`` via ``load_wiflow_model``.

    Returns:
        Whatever ``load_wiflow_model(CHECKPOINT)`` returns (the model that the
        rest of this script exports and runs).
    """
    fp32_model = load_wiflow_model(CHECKPOINT)
    return fp32_model
def try_export(model, path, batch, dynamic, opset=17):
    """Export ``model`` to ONNX at ``path``, trying two exporters in turn.

    The example input is a random ``(batch, 540, 20)`` tensor. The first
    exporter that completes without raising wins; a failed attempt prints its
    traceback and the next exporter is tried.

    Args:
        model: Module to export; it is called with the single example tensor.
        path: Destination of the ``.onnx`` file.
        batch: Batch size of the example input.
        dynamic: If true, request a dynamic batch axis (dynamo exporter first,
            then TorchScript). If false, export with the example's fixed
            shapes (TorchScript first, then dynamo).
        opset: ONNX opset version forwarded to ``torch.onnx.export``.

    Returns:
        ``(ok, exporter_used, error)``. On success this is
        ``(True, <exporter name>, None)``. When every attempt fails it is
        ``(False, None, error)``, where ``error`` lists *all* failed attempts
        in order, joined with ``" | "``, so the first exporter's failure
        reason is not lost.
    """
    x = torch.rand(batch, 540, 20)

    if dynamic:
        attempts = [
            # NOTE(review): dynamic_shapes is keyed by "x", presumably the
            # name of the model's forward() parameter -- confirm against the
            # model definition.
            ("dynamo", dict(dynamo=True,
                            dynamic_shapes={"x": {0: "batch"}})),
            ("torchscript", dict(dynamo=False,
                                 dynamic_axes={"input": {0: "batch"},
                                               "output": {0: "batch"}})),
        ]
    else:
        attempts = [
            ("torchscript", dict(dynamo=False)),
            ("dynamo", dict(dynamo=True)),
        ]

    errors = []
    for name, kw in attempts:
        try:
            with torch.no_grad():
                torch.onnx.export(model, (x,), path, opset_version=opset,
                                  input_names=["input"],
                                  output_names=["output"],
                                  **kw)
        except Exception as e:  # noqa: BLE001 - any exporter failure means "try the next one"
            errors.append(f"{name}: {type(e).__name__}: {e}")
            traceback.print_exc()
        else:
            return True, name, None

    return False, None, " | ".join(errors) or None
def ort_session(path):
    """Open the ONNX model at ``path`` as an onnxruntime session.

    onnxruntime is imported lazily, inside the function. The session is
    created with ``CPUExecutionProvider`` as its only provider.

    Args:
        path: Path of the ``.onnx`` file to load.

    Returns:
        The ``onnxruntime.InferenceSession`` for that model.
    """
    from onnxruntime import InferenceSession

    cpu_only = ["CPUExecutionProvider"]
    return InferenceSession(path, providers=cpu_only)
def ort_run(sess, x):
    """Run ``sess`` on ``x`` and return the session's first output.

    The feed name is looked up from the session's first declared input, so the
    caller does not need to know what the exporter called it.

    Args:
        sess: Session object exposing ``get_inputs()`` and ``run()``.
        x: Value fed to the session's first input.

    Returns:
        The first element of the list returned by ``sess.run``.
    """
    first_input = sess.get_inputs()[0]
    outputs = sess.run(None, {first_input.name: x})
    return outputs[0]
def bench_ort(sess, batch, n_runs):
    """Time ``n_runs`` calls of ``sess`` at one batch size and report the median.

    The input is a fixed-seed (123) random float32 array of shape
    ``(batch, 540, 20)``, reused for every call. Before timing, the session is
    run untimed ``max(5, n_runs // 10)`` times as warm-up.

    Args:
        sess: Session passed to ``ort_run``.
        batch: Batch size of the benchmark input.
        n_runs: Number of timed calls.

    Returns:
        Dict with ``batch_size``, ``runs``, the median latency per batch and
        per window in milliseconds, and windows per second derived from the
        median.
    """
    inputs = np.random.default_rng(123).random((batch, 540, 20),
                                               dtype=np.float32)

    # Untimed warm-up calls.
    warmup_calls = max(5, n_runs // 10)
    for _ in range(warmup_calls):
        ort_run(sess, inputs)

    elapsed = []
    for _ in range(n_runs):
        started = time.perf_counter()
        ort_run(sess, inputs)
        finished = time.perf_counter()
        elapsed.append(finished - started)

    median_s = statistics.median(elapsed)
    median_ms = median_s * 1e3
    return dict(
        batch_size=batch,
        runs=n_runs,
        median_ms_per_batch=median_ms,
        median_ms_per_window=median_ms / batch,
        windows_per_second=batch / median_s,
    )
def main():
    """Export, verify, benchmark and quantize the model, then store the results.

    Steps, in order:

    1. Try a dynamic-batch ONNX export and check that it really runs at batch
       sizes 1, 2 and 64. If the export or that check fails, fall back to one
       fixed-batch export per batch size and record why.
    2. Compare onnxruntime output with the stored fixture and with a fresh
       torch forward pass (needs a batch-2 session).
    3. Measure onnxruntime CPU latency at batch 1 and batch 64.
    4. Supplementary: dynamic int8 quantization of the model exported in this
       run, recording file size and deviation from the fixture.
    5. Merge everything into ``OUT_JSON`` under the ``"onnx"`` key.

    Takes no command-line options; argparse is used only for ``--help``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="ONNX export + onnxruntime CPU benchmark for the "
                    "retrained WiFlow-STD checkpoint (no options; see "
                    "module docstring). NB: the published "
                    "retrained_fp32_dynamic.onnx came from the TorchScript "
                    "exporter; on newer torch the dynamo attempt may succeed "
                    "first and produce a different (external-data) artifact.")
    # No options are declared; this only provides --help and rejects stray args.
    parser.parse_args()

    import onnxruntime

    model = load_fp32_model()
    results = {
        "env": {
            "torch": torch.__version__,
            "onnxruntime": onnxruntime.__version__,
            "platform": platform.platform(),
        },
    }

    fixture = np.load(os.path.join(RESULTS, "parity_fixture.npz"))
    fx, fy = fixture["input"], fixture["output"]  # (2,540,20) -> (2,15,2)

    # ---- export: dynamic batch first, fall back to fixed --------------------
    dyn_path = os.path.join(RESULTS, "retrained_fp32_dynamic.onnx")
    ok, exporter, err = try_export(model, dyn_path, batch=2, dynamic=True)

    dyn_sess = None
    if ok:
        # Verify the dynamic graph really runs at other batch sizes.
        try:
            candidate = ort_session(dyn_path)
            for b in (1, 2, 64):
                y = ort_run(candidate,
                            np.zeros((b, 540, 20), dtype=np.float32))
                # Explicit raise instead of assert so the check survives -O.
                if y.shape != (b, 15, 2):
                    raise ValueError(
                        f"batch {b} produced output shape {y.shape}, "
                        f"expected {(b, 15, 2)}")
            dyn_sess = candidate
        except Exception as e:  # noqa: BLE001
            print(f"dynamic-batch model does not generalize: {e}")
            # The export itself succeeded, so try_export returned no error;
            # record the verification failure as the fallback reason instead
            # of leaving it null.
            err = (f"dynamic export via {exporter} did not generalize: "
                   f"{type(e).__name__}: {e}")
    dynamic_works = dyn_sess is not None

    sessions = {}
    if dynamic_works:
        results["export"] = {"mode": "dynamic-batch", "exporter": exporter,
                             "file": os.path.basename(dyn_path),
                             "size_mb": os.path.getsize(dyn_path) / 1e6}
        # Reuse the session that was just verified for every batch size.
        sessions = {1: dyn_sess, 2: dyn_sess, 64: dyn_sess}
        print(f"dynamic-batch export OK via {exporter}")
    else:
        results["export"] = {"mode": "fixed-batch", "fallback_reason": err,
                             "files": {}}
        for b in (1, 2, 64):
            p = os.path.join(RESULTS, f"retrained_fp32_b{b}.onnx")
            ok, exporter, err = try_export(model, p, batch=b, dynamic=False)
            if not ok:
                results["export"]["files"][str(b)] = {"error": err}
                print(f"EXPORT FAILED at batch {b}: {err}")
                continue
            results["export"]["files"][str(b)] = {
                "exporter": exporter, "file": os.path.basename(p),
                "size_mb": os.path.getsize(p) / 1e6}
            sessions[b] = ort_session(p)
            print(f"fixed-batch {b} export OK via {exporter}")

    # ---- parity vs torch on the fixture -------------------------------------
    if 2 in sessions:
        y_ort = ort_run(sessions[2], fx)
        with torch.no_grad():
            y_torch = model(torch.from_numpy(fx)).numpy()
        diff_vs_torch = float(np.abs(y_ort - y_torch).max())
        results["parity"] = {
            "fixture": "results/parity_fixture.npz (batch 2, seed 42)",
            "max_abs_diff_vs_stored_fixture": float(np.abs(y_ort - fy).max()),
            "max_abs_diff_vs_torch_now": diff_vs_torch,
            "pass_lt_1e-4": bool(diff_vs_torch < 1e-4),
        }
        print("parity:", json.dumps(results["parity"], indent=2))

    # ---- latency -------------------------------------------------------------
    results["latency"] = {}
    if 1 in sessions:
        results["latency"]["batch1"] = bench_ort(sessions[1], 1, 100)
        print(f"ORT batch 1: {results['latency']['batch1']['median_ms_per_window']:.2f} ms/window")
    if 64 in sessions:
        results["latency"]["batch64"] = bench_ort(sessions[64], 64, 30)
        print(f"ORT batch 64: {results['latency']['batch64']['median_ms_per_window']:.3f} ms/window")

    # ---- supplementary: ORT dynamic int8 (size datapoint for the 2.2MB claim)
    if dynamic_works:
        src = dyn_path
    else:
        src = os.path.join(RESULTS, "retrained_fp32_b1.onnx")
    # Require a batch-1 session from *this* run: a bare exists() check would
    # also accept a stale or partially written file left behind when the
    # batch-1 export failed, and report numbers for the wrong model.
    if 1 in sessions and os.path.exists(src):
        try:
            from onnxruntime.quantization import QuantType, quantize_dynamic

            q_path = os.path.join(RESULTS, "retrained_int8_ort_dynamic.onnx")
            quantize_dynamic(src, q_path, weight_type=QuantType.QInt8)
            entry = {"file": os.path.basename(q_path),
                     "size_mb": os.path.getsize(q_path) / 1e6}
            try:
                qs = ort_session(q_path)
                # The fixed-batch source is the batch-1 export, so only the
                # first fixture row can be fed to it.
                yq = ort_run(qs, fx[:1] if not dynamic_works else fx)
                ref = fy[:1] if not dynamic_works else fy
                entry["runs"] = True
                entry["max_abs_diff_vs_fp32_fixture"] = float(np.abs(yq - ref).max())
            except Exception as e:  # noqa: BLE001
                entry["runs"] = False
                entry["run_error"] = f"{type(e).__name__}: {e}"
            results["ort_int8_dynamic_supplementary"] = entry
            print("ORT int8:", json.dumps(entry, indent=2))
        except Exception as e:  # noqa: BLE001
            results["ort_int8_dynamic_supplementary"] = {
                "error": f"{type(e).__name__}: {e}"}

    # ---- merge into the shared results file, keeping other keys -------------
    merged = {}
    if os.path.exists(OUT_JSON):
        with open(OUT_JSON, encoding="utf-8") as f:
            merged = json.load(f)
    merged["onnx"] = results
    with open(OUT_JSON, "w", encoding="utf-8") as f:
        json.dump(merged, f, indent=2)
    print(f"wrote {OUT_JSON}")
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()