334 lines
14 KiB
Python
334 lines
14 KiB
Python
"""ADR-152 edge optimization follow-up: ONNX Runtime STATIC post-training
|
|
quantization (calibration-based QDQ) of the retrained WiFlow-STD model, to
|
|
improve on the dynamic-int8 result (2.44 MB, PCK@20 96.52%, 6.5 ms/win b1).
|
|
|
|
Static PTQ pre-computes activation ranges from calibration data, so inference
|
|
uses QLinearConv/QDQ kernels instead of dynamic ConvInteger -- typically both
|
|
faster and (with good calibration) closer to fp32 accuracy.
|
|
|
|
Method:
|
|
- Calibration set: corruption-free windows drawn ONLY from the seed-42
|
|
file-level TRAINING split (same split as eval_repro.py; corrupted windows
|
|
excluded via results/nan_windows_mask.npy | big_windows_mask.npy), chosen
|
|
with np.random.default_rng(42). Never test windows.
|
|
- quantize_static, QuantFormat.QDQ, per-channel int8 weights, int8
|
|
activations; calibration methods MinMax / Entropy / Percentile(99.99);
|
|
scopes "all" (ORT default op set) vs "conv" (op_types_to_quantize=
|
|
["Conv"] -- leaves the attention path, which exports as Einsum/Softmax
|
|
and elementwise ops, in fp32).
|
|
- Model is pre-processed first (quant_pre_process: symbolic shape
|
|
inference + ORT graph optimization, folds BatchNormalization into Conv).
|
|
- Accuracy: identical protocol to eval_ort_accuracy.py -- the 10,000-window
|
|
seed-42 subset of the corruption-free test split (PCK@20/50, MPJPE).
|
|
- Latency: median ms/window at batch 1 (100 runs) and batch 64 (30 runs),
|
|
3 interleaved repetitions across all variants (fp32 and dynamic-int8
|
|
sessions included as same-session reference points).
|
|
|
|
Usage:
|
|
PYTHONUTF8=1 .venv/Scripts/python.exe static_ptq_bench.py \
|
|
[--data-dir <preprocessed_csi_data>] [--subset 10000]
|
|
[--calib-minmax 1000] [--calib-hist 512] [--skip-accuracy]
|
|
|
|
Writes/merges into results/edge_optimization.json under key "onnx_static_ptq".
|
|
"""
|
|
|
|
import argparse
|
|
import collections
|
|
import json
|
|
import os
|
|
import platform
|
|
import statistics
|
|
import sys
|
|
import time
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
# Directory containing this script. It is put first on sys.path so the
# benchmark helper modules imported below resolve regardless of the current
# working directory (presumably they live next to this file -- confirm).
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

from _bench_common import RESULTS  # noqa: E402
# quantize_bench sets up upstream imports + the np.load mmap patch
# (both via _bench_common.import_upstream)
from quantize_bench import build_test_subset  # noqa: E402
import quantize_bench as qb  # noqa: E402
from eval_ort_accuracy import evaluate_ort  # noqa: E402

# fp32 ONNX model: input to quant_pre_process and the fp32 latency reference.
FP32_ONNX = os.path.join(RESULTS, "retrained_fp32_dynamic.onnx")
# Dynamic-int8 ONNX model: used only as a latency reference point.
DYN_INT8_ONNX = os.path.join(RESULTS, "retrained_int8_ort_dynamic.onnx")
# Output of quant_pre_process; the source model for every static variant.
PREPROC_ONNX = os.path.join(RESULTS, "retrained_fp32_preproc.onnx")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# calibration data: corruption-free TRAINING-split windows only
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def build_calibration_windows(data_dir, n_windows):
    """Draw a reproducible calibration set from the clean TRAIN split.

    Builds the seed-42 file-level 70/15/15 split (exactly as eval_repro.py),
    drops every training window flagged in results/nan_windows_mask.npy or
    results/big_windows_mask.npy, then samples ``n_windows`` of the remaining
    windows without replacement using ``np.random.default_rng(42)``.

    Args:
        data_dir: directory of the preprocessed CSI dataset.
        n_windows: number of calibration windows to draw (at least 1).

    Returns:
        float32 array with the selected input windows stacked along axis 0,
        ordered by ascending dataset index.

    Raises:
        ValueError: if ``n_windows`` is below 1 or exceeds the number of
            clean training windows.
    """
    dataset = qb.PreprocessedCSIKeypointsDataset(
        data_dir=data_dir, keypoint_scale=1000.0, enable_temporal_clean=True)
    train_loader, _va, _te = qb.create_preprocessed_train_val_test_loaders(
        dataset=dataset, batch_size=64, num_workers=0, random_seed=42)
    train_indices = np.asarray(train_loader.dataset.indices)

    # Coerce both masks to bool so that "~" below is a logical NOT. If a mask
    # had been saved with an integer dtype, "~" would be a bitwise NOT and the
    # result would be used as integer indices instead of a filter.
    nan_mask = np.asarray(
        np.load(os.path.join(RESULTS, "nan_windows_mask.npy")), dtype=bool)
    big_mask = np.asarray(
        np.load(os.path.join(RESULTS, "big_windows_mask.npy")), dtype=bool)
    corrupted = nan_mask | big_mask
    clean = train_indices[~corrupted[train_indices]]
    print(f"train split: {len(train_indices)} windows, "
          f"{len(train_indices) - len(clean)} corrupted excluded, "
          f"{len(clean)} clean")

    # Fail with the actual population size instead of a generic numpy error.
    if not 1 <= n_windows <= len(clean):
        raise ValueError(
            f"n_windows must be between 1 and {len(clean)} "
            f"(clean TRAIN windows available), got {n_windows}")

    rng = np.random.default_rng(42)
    sel = np.sort(rng.choice(clean, size=n_windows, replace=False))
    # dataset[i][0] is the input tensor of window i (the target is ignored).
    xs = np.stack([dataset[int(i)][0].numpy() for i in sel]).astype(np.float32)
    print(f"calibration tensor: {xs.shape} from {n_windows} clean TRAIN windows")
    return xs
|
|
|
|
|
|
def make_reader(windows, batch_size=64, input_name="input"):
    """Wrap calibration windows in an ONNX Runtime CalibrationDataReader.

    Args:
        windows: sliceable sequence/array of calibration inputs; it is cut
            into consecutive chunks of ``batch_size`` along its first axis.
        batch_size: number of windows fed per calibration step (at least 1).
        input_name: name of the model input used as the feed-dict key.
            Defaults to "input", the name this script has always used.

    Returns:
        A reader whose ``get_next()`` yields ``{input_name: batch}`` dicts and
        then ``None`` once the batches are exhausted.

    Raises:
        ValueError: if ``batch_size`` is below 1 (a negative value would
            otherwise silently produce a reader with no batches).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    from onnxruntime.quantization import CalibrationDataReader

    class WindowReader(CalibrationDataReader):
        """Serves the pre-sliced batches one at a time."""

        def __init__(self):
            self._batches = [windows[i:i + batch_size]
                             for i in range(0, len(windows), batch_size)]
            self._it = iter(self._batches)

        def get_next(self):
            # None signals "no more calibration data" to the calibrater.
            batch = next(self._it, None)
            return None if batch is None else {input_name: batch}

        def rewind(self):
            # Restart from the first batch so the reader can be reused.
            self._it = iter(self._batches)

        def __len__(self):
            return len(self._batches)

    return WindowReader()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# quantization variants
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def preprocess_model():
    """Run ONNX Runtime's quantization pre-processing on the fp32 model.

    Reads FP32_ONNX, writes the pre-processed graph to PREPROC_ONNX and
    returns that output path so callers can use it as the quantization
    source.
    """
    from onnxruntime.quantization import shape_inference

    source_path, target_path = FP32_ONNX, PREPROC_ONNX
    shape_inference.quant_pre_process(source_path, target_path)
    return target_path
|
|
|
|
|
|
def quantize_variant(src, dst, method, scope, calib_windows):
    """Statically quantize ``src`` into ``dst`` and describe the result.

    Uses QDQ format with per-channel int8 weights and int8 activations.

    Args:
        src: path of the (pre-processed) fp32 ONNX model.
        dst: path the quantized ONNX model is written to.
        method: calibration method, one of "minmax", "entropy", "percentile".
        scope: "conv" quantizes only Conv nodes; any other value leaves
            ``op_types_to_quantize`` as None (ORT's default op set).
        calib_windows: calibration inputs, passed to ``make_reader``.

    Returns:
        dict with the output file name and size, the calibration settings and
        duration, the quantization configuration, and a per-op-type node
        count of the quantized graph.

    Raises:
        KeyError: if ``method`` is not one of the supported names.
    """
    from onnxruntime.quantization import (CalibrationMethod, QuantFormat,
                                          QuantType, quantize_static)
    methods = {
        "minmax": CalibrationMethod.MinMax,
        "entropy": CalibrationMethod.Entropy,
        "percentile": CalibrationMethod.Percentile,
    }
    # Resolve first so an unknown method fails before any work is done.
    calibrate_method = methods[method]

    # NB: do NOT pass CalibMaxIntermediateOutputs -- in ORT 1.26 the MinMax
    # calibrater clears its buffer every N batches and then raises
    # "No data is collected" if the batch count is divisible by N.
    extra = {}
    if method == "percentile":
        extra["CalibPercentile"] = 99.99
    op_types = ["Conv"] if scope == "conv" else None

    # perf_counter is monotonic (wall-clock time.time() can jump) and is the
    # same clock bench_ort uses for latency.
    t0 = time.perf_counter()
    quantize_static(
        src, dst, make_reader(calib_windows),
        quant_format=QuantFormat.QDQ,
        op_types_to_quantize=op_types,
        per_channel=True,
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
        calibrate_method=calibrate_method,
        extra_options=extra,
    )
    secs = time.perf_counter() - t0

    import onnx
    ops = collections.Counter(n.op_type for n in onnx.load(dst).graph.node)
    size_bytes = os.path.getsize(dst)  # stat once; both size fields agree
    return {
        "file": os.path.basename(dst),
        "size_bytes": size_bytes,
        "size_mb": size_bytes / 1e6,
        "calibration": {"method": method,
                        "windows": int(len(calib_windows)),
                        "percentile": extra.get("CalibPercentile"),
                        "seconds": secs},
        "scope": scope,
        "per_channel": True,
        "activation_type": "QInt8",
        "weight_type": "QInt8",
        "node_counts": dict(sorted(ops.items())),
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# latency (3 interleaved reps, like the latency_controlled_rerun)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def ort_session(path):
    """Open an ONNX Runtime inference session for ``path`` on the CPU provider."""
    from onnxruntime import InferenceSession

    cpu_only = ["CPUExecutionProvider"]
    return InferenceSession(path, providers=cpu_only)
|
|
|
|
|
|
def bench_ort(sess, batch, n_runs):
    """Return the median latency of ``sess`` in milliseconds per window.

    Feeds a fixed random float32 input of shape (batch, 540, 20), generated
    with seed 123, to the session's first input. Runs ``max(5, n_runs // 10)``
    untimed warm-up calls, then times ``n_runs`` calls and divides the median
    call time by ``batch``.
    """
    generator = np.random.default_rng(123)
    sample = generator.random((batch, 540, 20), dtype=np.float32)
    input_name = sess.get_inputs()[0].name

    # Untimed warm-up.
    warmup_runs = max(5, n_runs // 10)
    for _ in range(warmup_runs):
        sess.run(None, {input_name: sample})

    durations = []
    for _ in range(n_runs):
        started = time.perf_counter()
        sess.run(None, {input_name: sample})
        finished = time.perf_counter()
        durations.append(finished - started)

    median_seconds = statistics.median(durations)
    return median_seconds * 1e3 / batch  # ms/window
|
|
|
|
|
|
def interleaved_latency(sessions, reps=3, runs_b1=100, runs_b64=30):
    """Benchmark every session at batch 1 and batch 64, interleaved by rep.

    Each repetition visits all sessions in turn before the next repetition
    starts. Per session the result holds the per-repetition values
    ("batch1_reps", "batch64_reps") and their medians
    ("batch1_ms_per_window_median", "batch64_ms_per_window_median").
    """
    lat = {}
    for name in sessions:
        lat[name] = {"batch1_reps": [], "batch64_reps": []}

    for rep_no in range(1, reps + 1):
        for name, sess in sessions.items():
            record = lat[name]
            b1_ms = bench_ort(sess, 1, runs_b1)
            record["batch1_reps"].append(b1_ms)
            b64_ms = bench_ort(sess, 64, runs_b64)
            record["batch64_reps"].append(b64_ms)
            print(f" rep {rep_no}/{reps} {name}: "
                  f"b1={b1_ms:.2f} "
                  f"b64={b64_ms:.3f} ms/win", flush=True)

    # Collapse the repetitions to one median per batch size.
    for record in lat.values():
        record["batch1_ms_per_window_median"] = statistics.median(
            record["batch1_reps"])
        record["batch64_ms_per_window_median"] = statistics.median(
            record["batch64_reps"])
    return lat
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def main():
    """Run the static-PTQ benchmark and merge the results into the JSON file.

    Steps, in order: parse CLI arguments; build the TRAIN-split calibration
    windows; pre-process the fp32 model; quantize every requested
    (method, scope) variant; check each variant against the parity fixture;
    measure interleaved latency (with fp32 and dynamic-int8 references);
    optionally evaluate accuracy on the test subset; merge everything into
    ``--out`` under the key "onnx_static_ptq".

    A variant that fails to quantize or to run is recorded with an "error" /
    "run_error" entry instead of aborting the whole benchmark.
    """
    import onnxruntime

    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", default=os.path.join(
        os.path.expanduser("~"), ".cache", "kagglehub", "datasets", "kaka2434",
        "wiflow-dataset", "versions", "1", "preprocessed_csi_data"))
    parser.add_argument("--subset", type=int, default=10000)
    parser.add_argument("--calib-minmax", type=int, default=1000)
    parser.add_argument("--calib-hist", type=int, default=512,
                        help="calibration windows for Entropy/Percentile "
                             "(histogram calibraters hold all intermediate "
                             "activations in RAM)")
    parser.add_argument("--skip-accuracy", action="store_true")
    parser.add_argument("--methods", default="minmax,entropy,percentile",
                        help="comma list of calibration methods to (re)run; "
                             "results merge into existing onnx_static_ptq")
    parser.add_argument("--out", default=os.path.join(RESULTS, "edge_optimization.json"))
    args = parser.parse_args()

    results = {
        "env": {
            "onnxruntime": onnxruntime.__version__,
            "torch": torch.__version__,
            "platform": platform.platform(),
            "source_model": os.path.basename(FP32_ONNX),
        },
        "variants": {},
    }

    # ---- calibration data (TRAIN split only) -------------------------------
    calib_mm = build_calibration_windows(args.data_dir, args.calib_minmax)
    # NOTE(review): build_calibration_windows returns its draw sorted by
    # dataset index, so this prefix is the lowest-index part of the draw
    # rather than a uniform sub-sample of it -- confirm this is intended.
    # It is also capped at --calib-minmax windows.
    calib_hist = calib_mm[:args.calib_hist]

    # ---- preprocess + quantize ---------------------------------------------
    print("\n=== quant_pre_process (shape inference + graph optimization) ===")
    src = preprocess_model()
    results["env"]["preprocessed_model"] = {
        "file": os.path.basename(src),
        "size_mb": os.path.getsize(src) / 1e6,
    }

    # Tolerate "minmax, entropy" and stray commas: strip names, drop empties.
    methods = [m.strip() for m in args.methods.split(",") if m.strip()]
    matrix = [(m, s) for m in methods
              for s in ("all", "conv")]
    for method, scope in matrix:
        name = f"{method}_{scope}"
        dst = os.path.join(RESULTS, f"retrained_int8_static_{name}.onnx")
        calib = calib_mm if method == "minmax" else calib_hist
        print(f"\n=== quantize_static: {name} "
              f"({len(calib)} calib windows) ===", flush=True)
        try:
            results["variants"][name] = quantize_variant(
                src, dst, method, scope, calib)
            print(f" {results['variants'][name]['size_mb']:.3f} MB")
        except Exception as e:  # noqa: BLE001
            # Best effort: record the failure and continue with other variants.
            results["variants"][name] = {"error": f"{type(e).__name__}: {e}"}
            print(f" FAILED: {e}")

    # ---- fixture parity (sanity, batch 2) ----------------------------------
    fixture = np.load(os.path.join(RESULTS, "parity_fixture.npz"))
    fx, fy = fixture["input"], fixture["output"]
    sessions = {}
    for name, info in results["variants"].items():
        if "error" in info:
            continue
        path = os.path.join(RESULTS, info["file"])
        try:
            sess = ort_session(path)
            yq = sess.run(None, {sess.get_inputs()[0].name: fx})[0]
            info["max_abs_diff_vs_fp32_fixture"] = float(np.abs(yq - fy).max())
            # Only variants that load and run are benchmarked further.
            sessions[name] = sess
        except Exception as e:  # noqa: BLE001
            info["run_error"] = f"{type(e).__name__}: {e}"
    print("\nfixture max-abs-diff vs fp32:",
          {n: round(results["variants"][n].get("max_abs_diff_vs_fp32_fixture",
                                               float("nan")), 5)
           for n in results["variants"]})

    # ---- latency: 3 interleaved reps incl. fp32 + dynamic-int8 reference ----
    print("\n=== latency (3 interleaved reps) ===")
    lat_sessions = {"onnx_fp32": ort_session(FP32_ONNX),
                    "onnx_int8_ort_dynamic": ort_session(DYN_INT8_ONNX)}
    lat_sessions.update(sessions)
    results["latency"] = {
        "note": "3 interleaved repetitions per variant, median ms/window; "
                "onnx_fp32 / onnx_int8_ort_dynamic are same-session references",
        **interleaved_latency(lat_sessions),
    }

    # ---- accuracy on the standard 10k corruption-free test subset ----------
    if not args.skip_accuracy:
        loader, n_clean = build_test_subset(args.data_dir, args.subset)
        results["accuracy_subset"] = {
            "description": "seed-42 file-level 70/15/15 test split, corrupted "
                           "windows excluded, seed-42 random subset (same as "
                           "quantize_bench/eval_ort_accuracy)",
            "subset_size": min(args.subset, n_clean) if args.subset else n_clean,
        }
        for name, sess in sessions.items():
            print(f"\n=== accuracy: {name} ===")
            results["variants"][name]["accuracy"] = evaluate_ort(
                sess, loader, name)
            print(json.dumps(results["variants"][name]["accuracy"], indent=2))

    # ---- merge into edge_optimization.json ----------------------------------
    merged = {}
    if os.path.exists(args.out):
        with open(args.out) as f:
            merged = json.load(f)
    prev = merged.get("onnx_static_ptq")
    if prev:  # nested merge so partial --methods reruns don't clobber
        prev["env"] = results["env"]
        prev.setdefault("variants", {}).update(results["variants"])
        prev.setdefault("latency", {}).update(results["latency"])
        if "accuracy_subset" in results:
            prev["accuracy_subset"] = results["accuracy_subset"]
    else:
        merged["onnx_static_ptq"] = results

    # Write to a temporary file and swap it in: opening args.out with "w"
    # would truncate the existing results before json.dump runs, so a
    # serialization error would destroy earlier benchmark data.
    tmp_out = args.out + ".tmp"
    with open(tmp_out, "w") as f:
        json.dump(merged, f, indent=2)
    os.replace(tmp_out, args.out)
    print(f"\nwrote {args.out}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|