221 lines
8.6 KiB
Python
221 lines
8.6 KiB
Python
"""ADR-152 edge optimization: ONNX export + onnxruntime CPU benchmark for the
retrained WiFlow-STD checkpoint.

- Exports fp32 to ONNX. The axial attention reshapes with python ints taken
  from tensor.size() (view(N*W, C, H)), so a traced graph bakes the batch
  size; we first try a dynamic-batch export and verify it actually works at
  batch sizes 1/2/64 -- if not, we fall back to fixed-batch exports.
- Verifies output parity vs torch on the stored fixture
  (results/parity_fixture.npz, batch 2, seed 42): max abs diff < 1e-4.
- Measures onnxruntime CPU latency at batch 1 and 64 (median of N runs).
- Supplementary: onnxruntime dynamic int8 quantization of the exported model
  (weight size datapoint for the paper's "~2.2 MB int8" claim).

Usage:
    .venv/Scripts/python.exe onnx_bench.py

Writes/merges into results/edge_optimization.json under key "onnx".
"""
|
|
|
|
import json
|
|
import os
|
|
import platform
|
|
import statistics
|
|
import time
|
|
import traceback
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from _bench_common import RESULTS, import_upstream, load_wiflow_model
|
|
|
|
# Must run before any model loading: sets up sys.path, the models stub and
# the >1GB np.load mmap patch.
import_upstream()

# Input checkpoint and output report both live in the shared results folder.
CHECKPOINT = os.path.join(RESULTS, "retrained_best_pose_model.pth")
OUT_JSON = os.path.join(RESULTS, "edge_optimization.json")
|
|
|
|
|
|
def load_fp32_model():
    """Return the model that ``load_wiflow_model`` builds from ``CHECKPOINT``."""
    fp32_model = load_wiflow_model(CHECKPOINT)
    return fp32_model
|
|
|
|
|
|
def try_export(model, path, batch, dynamic, opset=17):
    """Export ``model`` to ONNX at ``path``, trying two exporters in turn.

    With ``dynamic`` set, the dynamo exporter is tried first and the
    TorchScript exporter second, both with a symbolic batch axis. Otherwise
    the order is TorchScript then dynamo, with the batch size fixed to
    ``batch``. The first attempt that does not raise wins.

    Args:
        model: object handed to ``torch.onnx.export``.
        path: destination file for the exported graph.
        batch: batch size of the random ``(batch, 540, 20)`` example input.
        dynamic: request a dynamic batch axis when true.
        opset: ONNX opset version passed to the exporter.

    Returns:
        ``(ok, exporter_used, error)``. On success ``error`` is ``None``.
        On failure ``exporter_used`` is ``None`` and ``error`` lists the
        failure of every attempt (joined with ``" | "``), so the caller can
        record why each exporter was rejected, not only the last one.
    """
    x = torch.rand(batch, 540, 20)

    if dynamic:
        attempts = [
            ("dynamo", dict(dynamo=True,
                            dynamic_shapes={"x": {0: "batch"}})),
            ("torchscript", dict(dynamo=False,
                                 dynamic_axes={"input": {0: "batch"},
                                               "output": {0: "batch"}})),
        ]
    else:
        attempts = [
            ("torchscript", dict(dynamo=False)),
            ("dynamo", dict(dynamo=True)),
        ]

    errors = []
    for name, kw in attempts:
        try:
            with torch.no_grad():
                torch.onnx.export(model, (x,), path, opset_version=opset,
                                  input_names=["input"],
                                  output_names=["output"],
                                  **kw)
        except Exception as e:  # noqa: BLE001 - any exporter failure means "try the next one"
            errors.append(f"{name}: {type(e).__name__}: {e}")
            traceback.print_exc()
        else:
            return True, name, None

    return False, None, " | ".join(errors) if errors else None
|
|
|
|
|
|
def ort_session(path):
    """Open the ONNX file at ``path`` as an onnxruntime session on the CPU provider."""
    from onnxruntime import InferenceSession

    cpu_only = ["CPUExecutionProvider"]
    return InferenceSession(path, providers=cpu_only)
|
|
|
|
|
|
def ort_run(sess, x):
    """Feed ``x`` to the session's first input and return its first output."""
    first_input = sess.get_inputs()[0]
    outputs = sess.run(None, {first_input.name: x})
    return outputs[0]
|
|
|
|
|
|
def bench_ort(sess, batch, n_runs):
    """Time ``n_runs`` inferences of ``sess`` on a fixed random batch.

    The input is a float32 array of shape ``(batch, 540, 20)`` drawn from a
    generator seeded with 123. A warm-up of ``max(5, n_runs // 10)`` untimed
    calls precedes the measurement.

    Returns:
        A dict with the batch size, the run count, the median latency per
        batch and per window in milliseconds, and windows per second.
    """
    generator = np.random.default_rng(123)
    sample = generator.random((batch, 540, 20), dtype=np.float32)

    # Untimed warm-up calls before measuring.
    warmup_calls = max(5, n_runs // 10)
    for _ in range(warmup_calls):
        ort_run(sess, sample)

    durations = []
    for _ in range(n_runs):
        started = time.perf_counter()
        ort_run(sess, sample)
        finished = time.perf_counter()
        durations.append(finished - started)

    median_s = statistics.median(durations)
    median_ms = median_s * 1e3
    return {
        "batch_size": batch,
        "runs": n_runs,
        "median_ms_per_batch": median_ms,
        "median_ms_per_window": median_ms / batch,
        "windows_per_second": batch / median_s,
    }
|
|
|
|
|
|
def main():
    """Export the checkpoint to ONNX, check parity, benchmark and record results.

    Steps, in order:

    1. Parse the command line (no options; this only provides ``--help``).
    2. Load the fp32 model and the stored parity fixture.
    3. Try a dynamic-batch export and verify it at batch sizes 1, 2 and 64;
       if the export or the verification fails, export one fixed-batch file
       per batch size instead and record why the fallback happened.
    4. Compare the onnxruntime output with the stored fixture and with torch.
    5. Measure onnxruntime latency at batch 1 and batch 64.
    6. Quantize the exported model with onnxruntime dynamic int8 and record
       its file size and, if it runs, its deviation from the fixture.
    7. Merge everything into ``OUT_JSON`` under the key ``"onnx"``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="ONNX export + onnxruntime CPU benchmark for the "
                    "retrained WiFlow-STD checkpoint (no options; see "
                    "module docstring). NB: the published "
                    "retrained_fp32_dynamic.onnx came from the TorchScript "
                    "exporter; on newer torch the dynamo attempt may succeed "
                    "first and produce a different (external-data) artifact.")
    parser.parse_args()

    import onnxruntime

    model = load_fp32_model()
    results = {
        "env": {
            "torch": torch.__version__,
            "onnxruntime": onnxruntime.__version__,
            "platform": platform.platform(),
        },
    }

    # Close the .npz archive as soon as both arrays have been read.
    with np.load(os.path.join(RESULTS, "parity_fixture.npz")) as fixture:
        fx, fy = fixture["input"], fixture["output"]  # (2,540,20) -> (2,15,2)

    # ---- export: dynamic batch first, fall back to fixed --------------------
    batch_sizes = (1, 2, 64)
    dyn_path = os.path.join(RESULTS, "retrained_fp32_dynamic.onnx")
    ok, exporter, err = try_export(model, dyn_path, batch=2, dynamic=True)
    dynamic_works = False
    if ok:
        # verify the dynamic graph really runs at other batch sizes
        try:
            sess = ort_session(dyn_path)
            for b in batch_sizes:
                y = ort_run(sess, np.zeros((b, 540, 20), dtype=np.float32))
                # Explicit raise instead of assert: asserts vanish under -O,
                # which would let a broken graph pass verification.
                if y.shape != (b, 15, 2):
                    raise ValueError(
                        f"unexpected output shape {y.shape} at batch {b}")
            dynamic_works = True
        except Exception as e:  # noqa: BLE001
            print(f"dynamic-batch model does not generalize: {e}")
            # Keep the reason so the fallback is explained in the JSON report.
            err = (f"dynamic export via {exporter} does not generalize: "
                   f"{type(e).__name__}: {e}")

    sessions = {}
    if dynamic_works:
        results["export"] = {"mode": "dynamic-batch", "exporter": exporter,
                             "file": os.path.basename(dyn_path),
                             "size_mb": os.path.getsize(dyn_path) / 1e6}
        sess = ort_session(dyn_path)
        sessions = {b: sess for b in batch_sizes}
        print(f"dynamic-batch export OK via {exporter}")
    else:
        results["export"] = {"mode": "fixed-batch", "fallback_reason": err,
                             "files": {}}
        for b in batch_sizes:
            p = os.path.join(RESULTS, f"retrained_fp32_b{b}.onnx")
            b_ok, b_exporter, b_err = try_export(model, p, batch=b,
                                                 dynamic=False)
            if not b_ok:
                results["export"]["files"][str(b)] = {"error": b_err}
                print(f"EXPORT FAILED at batch {b}: {b_err}")
                continue
            results["export"]["files"][str(b)] = {
                "exporter": b_exporter, "file": os.path.basename(p),
                "size_mb": os.path.getsize(p) / 1e6}
            sessions[b] = ort_session(p)
            print(f"fixed-batch {b} export OK via {b_exporter}")

    # ---- parity vs torch on the fixture -------------------------------------
    if 2 in sessions:
        y_ort = ort_run(sessions[2], fx)
        with torch.no_grad():
            y_torch = model(torch.from_numpy(fx)).numpy()
        diff_vs_torch = np.abs(y_ort - y_torch).max()
        results["parity"] = {
            "fixture": "results/parity_fixture.npz (batch 2, seed 42)",
            "max_abs_diff_vs_stored_fixture": float(np.abs(y_ort - fy).max()),
            "max_abs_diff_vs_torch_now": float(diff_vs_torch),
            "pass_lt_1e-4": bool(diff_vs_torch < 1e-4),
        }
        print("parity:", json.dumps(results["parity"], indent=2))

    # ---- latency -------------------------------------------------------------
    results["latency"] = {}
    if 1 in sessions:
        results["latency"]["batch1"] = bench_ort(sessions[1], 1, 100)
        print(f"ORT batch 1: {results['latency']['batch1']['median_ms_per_window']:.2f} ms/window")
    if 64 in sessions:
        results["latency"]["batch64"] = bench_ort(sessions[64], 64, 30)
        print(f"ORT batch 64: {results['latency']['batch64']['median_ms_per_window']:.3f} ms/window")

    # ---- supplementary: ORT dynamic int8 (size datapoint for the 2.2MB claim)
    src = (dyn_path if dynamic_works
           else os.path.join(RESULTS, "retrained_fp32_b1.onnx"))
    # Only quantize a model exported by THIS run; a leftover file from an
    # earlier run must not be reported as a fresh datapoint.
    exported_now = dynamic_works or 1 in sessions
    if exported_now and os.path.exists(src):
        try:
            from onnxruntime.quantization import QuantType, quantize_dynamic

            q_path = os.path.join(RESULTS, "retrained_int8_ort_dynamic.onnx")
            quantize_dynamic(src, q_path, weight_type=QuantType.QInt8)
            entry = {"file": os.path.basename(q_path),
                     "size_mb": os.path.getsize(q_path) / 1e6}
            try:
                qs = ort_session(q_path)
                # The fixed-batch source is the batch-1 graph, so feed it
                # only the first fixture sample.
                yq = ort_run(qs, fx[:1] if not dynamic_works else fx)
                ref = fy[:1] if not dynamic_works else fy
                entry["runs"] = True
                entry["max_abs_diff_vs_fp32_fixture"] = float(np.abs(yq - ref).max())
            except Exception as e:  # noqa: BLE001
                entry["runs"] = False
                entry["run_error"] = f"{type(e).__name__}: {e}"
            results["ort_int8_dynamic_supplementary"] = entry
            print("ORT int8:", json.dumps(entry, indent=2))
        except Exception as e:  # noqa: BLE001
            results["ort_int8_dynamic_supplementary"] = {
                "error": f"{type(e).__name__}: {e}"}

    # Preserve whatever other keys an existing report already holds.
    merged = {}
    if os.path.exists(OUT_JSON):
        with open(OUT_JSON, encoding="utf-8") as f:
            merged = json.load(f)
    merged["onnx"] = results
    with open(OUT_JSON, "w", encoding="utf-8") as f:
        json.dump(merged, f, indent=2)
    print(f"wrote {OUT_JSON}")
|
|
|
|
|
|
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|