wifi-densepose/benchmarks/wiflow-std/eval_ort_accuracy.py

68 lines
2.4 KiB
Python

"""ADR-152 edge optimization: accuracy of the ONNX fp32 and ORT-dynamic-int8
models on the same corruption-free 10k test subset used by quantize_bench.py.
The torch dynamic-int8 path quantizes nothing (no nn.Linear in the model), so
the only real int8 datapoint for the paper's "~2.2 MB int8" claim is the
onnxruntime dynamically quantized model -- this script measures what that
quantization costs in PCK/MPJPE.
Usage:
.venv/Scripts/python.exe eval_ort_accuracy.py \
--data-dir <preprocessed_csi_data> [--subset 10000]
Writes/merges into results/edge_optimization.json under key "onnx_accuracy".
"""
import argparse
import json
import os
import sys
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)
from _bench_common import RESULTS, evaluate # noqa: E402
from quantize_bench import build_test_subset # noqa: E402 (sets up upstream imports)
def evaluate_ort(sess, loader, label):
"""ORT-session evaluation via the canonical _bench_common.evaluate loop."""
return evaluate(sess, loader, label=label)
def main():
import onnxruntime as ort
parser = argparse.ArgumentParser()
parser.add_argument("--data-dir", default=os.path.join(
os.path.expanduser("~"), ".cache", "kagglehub", "datasets", "kaka2434",
"wiflow-dataset", "versions", "1", "preprocessed_csi_data"))
parser.add_argument("--subset", type=int, default=10000)
parser.add_argument("--out", default=os.path.join(RESULTS, "edge_optimization.json"))
args = parser.parse_args()
loader, _n_clean = build_test_subset(args.data_dir, args.subset)
results = {}
for label, fname in (("onnx_fp32", "retrained_fp32_dynamic.onnx"),
("onnx_int8_ort_dynamic", "retrained_int8_ort_dynamic.onnx")):
path = os.path.join(RESULTS, fname)
if not os.path.exists(path):
results[label] = {"error": f"{fname} not found; run onnx_bench.py first"}
continue
sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
print(f"=== accuracy: {label} ({fname}) ===")
results[label] = evaluate_ort(sess, loader, label)
print(json.dumps(results[label], indent=2))
merged = {}
if os.path.exists(args.out):
with open(args.out) as f:
merged = json.load(f)
merged["onnx_accuracy"] = results
with open(args.out, "w") as f:
json.dump(merged, f, indent=2)
print(f"wrote {args.out}")
if __name__ == "__main__":
main()