"""ADR-152 efficiency-sweep follow-up: edge pipeline for the TINY compact
WiFlow-STD variant (56,290 params, results/tiny_best.pth, trained overnight
2026-06-10/11 -- see RESULTS.md "Efficiency sweep").

Headline question: what does the smallest deployable WiFlow-class model look
like (KB + ms + PCK)? Reuses the onnx_bench.py / static_ptq_bench.py machinery
on the tiny checkpoint:

  1. Load tiny_best.pth with remote/sweep/model_compact.py (depthwise TCN
     groups, input_pw_groups=4, conv [2,4,8,16], attn groups 2).
  2. Export ONNX: dynamic batch, opset 17, TorchScript exporter
     (dynamo=False) -- same recipe that worked for the full model; verified
     at batch 1/2/64. One forced deviation: tiny's stride schedule [2,1,1,1]
     leaves final_width 16, and the TorchScript exporter cannot export
     AdaptiveAvgPool2d((15,1)) when 15 is not a factor of the input height
     (the full model never hit this -- its width was exactly 15). The
     adaptive pool over a fixed-size feature map is a fixed linear map, so
     the export wrapper replaces it with an exact matmul equivalent (PyTorch
     adaptive-pool bin semantics: bin i averages rows
     floor(i*H/K)..ceil((i+1)*H/K), end exclusive); the W axis (20->1, a
     factor) becomes mean(-1). Exactness is proven by the parity check
     below, which compares against the ORIGINAL torch model with the real
     AdaptiveAvgPool2d.
  3. Torch-vs-ORT parity on the stored fixture input
     (results/parity_fixture.npz, batch 2, seed 42 -- same 540x20 input
     layout; reference output recomputed with the tiny torch model).
     PASS < 1e-4.
  4. Static QDQ conv-only int8 (quant_pre_process + quantize_static,
     per-channel QInt8 weights+activations, Percentile(99.99) calibration on
     512 corruption-free TRAIN-split windows -- the winning recipe and
     calibration count from static_ptq_bench.py. 512, not "about 500": ORT
     1.26's histogram collector np.asarray()'s the per-batch maxima, so the
     calibration count must be a multiple of the batch size 64 or the ragged
     last batch crashes it).
  5. Disk size + CPU latency b1/b64 (3 interleaved reps, median ms/window)
     for tiny fp32 + tiny int8, with the full-model ONNX fp32 + static-int8
     sessions interleaved as same-session references.
  6. Accuracy (PCK@20/50 + MPJPE) on the identical 10k-window seed-42
     corruption-free test subset for tiny fp32 + tiny int8.

Usage:
    PYTHONUTF8=1 .venv/Scripts/python.exe tiny_edge_bench.py \
        [--data-dir DIR] [--subset 10000] [--calib 512] \
        [--skip-accuracy] [--out PATH]
    (--calib must be a multiple of 64; see step 4 above)

Writes/merges into results/edge_optimization.json (or the --out path) under
key "tiny_variant".
"""
import argparse
import json
import os
import platform
import sys
import time

import numpy as np
import torch

# Directory of this script and its results/ folder; every artifact path below
# is derived from these two.
HERE = os.path.dirname(os.path.abspath(__file__))
RESULTS = os.path.join(HERE, "results")
# Make the sibling benchmark scripts and remote/sweep/ importable; this has to
# run before the project imports below (hence their E402 suppressions).
sys.path.insert(0, HERE)
sys.path.insert(0, os.path.join(HERE, "remote", "sweep"))

# quantize_bench sets up upstream imports + the np.load mmap patch
# NOTE(review): presumably this is why it is the first project import --
# keep it first unless quantize_bench says otherwise.
from quantize_bench import build_test_subset  # noqa: E402
from eval_ort_accuracy import evaluate_ort  # noqa: E402
from static_ptq_bench import (  # noqa: E402
    build_calibration_windows,
    interleaved_latency,
    make_reader,
    ort_session,
)
from model_compact import CompactWiFlowPoseModel, describe  # noqa: E402

# Input checkpoint and the ONNX artifacts this script writes (fp32 export,
# quant_pre_process output, static int8 model).
TINY_CKPT = os.path.join(RESULTS, "tiny_best.pth")
TINY_FP32_ONNX = os.path.join(RESULTS, "tiny_fp32_dynamic.onnx")
TINY_PREPROC_ONNX = os.path.join(RESULTS, "tiny_fp32_preproc.onnx")
TINY_INT8_ONNX = os.path.join(RESULTS, "tiny_int8_static_percentile_conv.onnx")
# Full-model ONNX files; only loaded as latency references, never rewritten
# by this script.
FULL_FP32_ONNX = os.path.join(RESULTS, "retrained_fp32_dynamic.onnx")
FULL_INT8_ONNX = os.path.join(RESULTS, "retrained_int8_static_percentile_conv.onnx")

# Exact tiny config from remote/sweep/run_sweep.py VARIANTS (measured 56,290
# params, clean-test PCK@20 94.11% -- results/efficiency_sweep.jsonl).
TINY = dict(tcn=[68, 56, 44, 32], conv=[2, 4, 8, 16], attn_groups=2,
            groups_mode="depthwise", input_pw_groups=4)

# Calibration windows are consumed in batches of this size (per the --calib
# help text); the requested count has to split into whole batches.
_CALIB_BATCH = 64

# Batch sizes the exported dynamic-batch graph is smoke-tested at.
_VERIFY_BATCHES = (1, 2, 64)


def load_tiny_model():
    """Build the tiny CompactWiFlowPoseModel and load its trained weights.

    The architecture comes from ``TINY``; the state dict is read from
    ``TINY_CKPT`` onto the CPU with ``weights_only=True`` and loaded strictly
    (any missing or unexpected key raises).

    Returns:
        The model in eval mode.
    """
    model = CompactWiFlowPoseModel(
        tcn_channels=TINY["tcn"], conv_channels=TINY["conv"],
        attn_groups=TINY["attn_groups"], groups_mode=TINY["groups_mode"],
        input_pw_groups=TINY["input_pw_groups"], dropout=0.5)
    state = torch.load(TINY_CKPT, map_location="cpu", weights_only=True)
    model.load_state_dict(state, strict=True)
    model.eval()
    return model


def adaptive_pool_matrix(h_in, h_out):
    """Exact AdaptiveAvgPool1d as a (h_out, h_in) averaging matrix, using
    PyTorch's bin rule: bin i covers rows floor(i*h_in/h_out) ..
    ceil((i+1)*h_in/h_out), end exclusive.

    Args:
        h_in: length of the axis being pooled.
        h_out: number of output bins.

    Returns:
        A float tensor of shape (h_out, h_in); row i holds 1/len(bin i) on
        the columns of bin i and zeros elsewhere, so every row sums to 1.
    """
    w = torch.zeros(h_out, h_in)
    for i in range(h_out):
        s = (i * h_in) // h_out
        e = -((-(i + 1) * h_in) // h_out)  # ceil division
        w[i, s:e] = 1.0 / (e - s)
    return w


class ExportWrapper(torch.nn.Module):
    """CompactWiFlowPoseModel forward with the AdaptiveAvgPool2d((K,1))
    replaced by an exact fixed linear map (mean over the factor W axis, then a
    constant averaging matmul over the non-factor H axis) so the TorchScript
    ONNX exporter accepts it. Bit-equivalent up to float round-off; proven by
    the parity check against the original model.

    Args:
        m: the model to wrap; its ``tcn``, ``up``, ``residual_blocks``,
            ``attention``, ``decoder`` and ``final_width`` attributes are used.
        num_keypoints: number of output bins K of the replaced adaptive pool.
    """

    def __init__(self, m, num_keypoints=15):
        super().__init__()
        self.m = m
        # Stored transposed, (final_width, K), so forward() can right-multiply.
        # A buffer (not a parameter): it is a constant of the exported graph.
        self.register_buffer(
            "pool_w_t", adaptive_pool_matrix(m.final_width, num_keypoints).t())

    def forward(self, x):
        """Run the wrapped model's stages, then the matmul-based pool."""
        m = self.m
        x = m.tcn(x)
        x = x.transpose(1, 2).unsqueeze(1)
        x = m.up(x)
        for block in m.residual_blocks:
            x = block(x)
        x = x.permute(0, 1, 3, 2)
        x = m.attention(x)
        x = m.decoder(x)             # [B, 2, H=final_width, T=20]
        x = x.mean(-1)               # W-axis pool (20 -> 1, a factor)
        x = x.matmul(self.pool_w_t)  # exact adaptive H pool: [B, 2, K]
        return x.transpose(1, 2)     # [B, K, 2]


def export_onnx(model):
    """Dynamic-batch TorchScript export (the recipe that worked for the full
    model in onnx_bench.py), verified at batch 1/2/64.

    Uses ExportWrapper (see its docstring) because final_width 16 is not a
    multiple of 15. Writes ``TINY_FP32_ONNX``, reloads it with
    ``ort_session`` and checks that a zero input of each verification batch
    size yields a ``(batch, 15, 2)`` output.

    Args:
        model: the tiny torch model to export.

    Returns:
        A JSON-serialisable dict describing the export (exporter, opset, file
        name, size, verified batch sizes, note on the pool replacement).

    Raises:
        AssertionError: if an ORT output has an unexpected shape.
    """
    wrapper = ExportWrapper(model).eval()
    x = torch.rand(2, 540, 20)
    with torch.no_grad():
        torch.onnx.export(
            wrapper, (x,), TINY_FP32_ONNX, opset_version=17,
            input_names=["input"], output_names=["output"], dynamo=False,
            dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
    sess = ort_session(TINY_FP32_ONNX)
    inp = sess.get_inputs()[0].name
    for b in _VERIFY_BATCHES:
        y = sess.run(None, {inp: np.zeros((b, 540, 20), dtype=np.float32)})[0]
        # Explicit raise rather than `assert`: this gate must survive
        # `python -O`. AssertionError is kept so the exception type callers
        # see is unchanged.
        if y.shape != (b, 15, 2):
            raise AssertionError(
                f"ONNX output shape {y.shape} at batch {b}, "
                f"expected {(b, 15, 2)}")
    size_bytes = os.path.getsize(TINY_FP32_ONNX)
    return {
        "mode": "dynamic-batch", "exporter": "torchscript", "opset": 17,
        "file": os.path.basename(TINY_FP32_ONNX),
        "size_bytes": size_bytes,
        "size_mb": size_bytes / 1e6,
        "verified_batches": list(_VERIFY_BATCHES),
        "note": "AdaptiveAvgPool2d((15,1)) replaced at export by an exact "
                "mean(-1) + constant averaging matmul (final_width 16 is not "
                "a multiple of 15, which the TorchScript exporter rejects); "
                "exactness proven by the parity check vs the original torch "
                "model",
    }


def quantize_tiny(calib_windows):
    """quant_pre_process + static QDQ conv-only Percentile(99.99) int8 -- the
    winning recipe from static_ptq_bench.py.

    Reads ``TINY_FP32_ONNX``, writes ``TINY_PREPROC_ONNX`` and
    ``TINY_INT8_ONNX``.

    Args:
        calib_windows: calibration data handed to ``make_reader``; only its
            ``len()`` is used here directly.

    Returns:
        A JSON-serialisable dict with the int8 file name and size, the
        calibration settings (including the seconds spent in
        ``quantize_static``) and the quantization types.
    """
    from onnxruntime.quantization import (CalibrationMethod, QuantFormat,
                                          QuantType, quantize_static)
    from onnxruntime.quantization.shape_inference import quant_pre_process

    quant_pre_process(TINY_FP32_ONNX, TINY_PREPROC_ONNX)
    # perf_counter is monotonic, so the elapsed time cannot be skewed by a
    # wall-clock adjustment during the calibration pass.
    t0 = time.perf_counter()
    quantize_static(
        TINY_PREPROC_ONNX, TINY_INT8_ONNX, make_reader(calib_windows),
        quant_format=QuantFormat.QDQ,
        op_types_to_quantize=["Conv"],
        per_channel=True,
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
        calibrate_method=CalibrationMethod.Percentile,
        extra_options={"CalibPercentile": 99.99},
    )
    seconds = time.perf_counter() - t0
    size_bytes = os.path.getsize(TINY_INT8_ONNX)
    return {
        "file": os.path.basename(TINY_INT8_ONNX),
        "size_bytes": size_bytes,
        "size_mb": size_bytes / 1e6,
        "calibration": {"method": "percentile", "percentile": 99.99,
                        "windows": int(len(calib_windows)),
                        "scope": "conv-only TRAIN-split corruption-free",
                        "seconds": seconds},
        "per_channel": True, "activation_type": "QInt8", "weight_type": "QInt8",
    }


def _calib_error(calib, batch_size=_CALIB_BATCH):
    """Return an error message if *calib* is not a usable window count, else None.

    A usable count is a positive whole number of calibration batches. Zero and
    negative multiples are rejected too: they satisfy ``calib % batch_size ==
    0`` but leave nothing to calibrate on.
    """
    if calib > 0 and calib % batch_size == 0:
        return None
    return (
        f"--calib must be a positive multiple of {batch_size} (got {calib}): "
        "ORT 1.26's histogram calibration collector np.asarray()'s the "
        "per-batch maxima and crashes on a ragged final batch (calibration "
        f"batch size is {batch_size})")


def _merge_into_json(path, key, payload):
    """Store *payload* under *key* in the JSON object at *path*, keeping other keys.

    The merged document is serialised before any file is opened for writing
    and is then swapped in with ``os.replace``. A non-serialisable value or an
    interrupted write therefore cannot truncate the sections other benchmark
    scripts stored in the same file.
    """
    merged = {}
    if os.path.exists(path):
        with open(path, encoding="utf-8") as f:
            merged = json.load(f)
    merged[key] = payload
    text = json.dumps(merged, indent=2)  # may raise; nothing written yet
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp_path, path)


def main(argv=None):
    """Run the tiny-variant edge benchmark and merge the results into JSON.

    Steps: load the tiny model, export to ONNX, check torch-vs-ORT parity on
    the stored fixture, build the static int8 model, measure latency against
    the full-model reference sessions, optionally measure accuracy, then store
    everything under key ``"tiny_variant"`` in the ``--out`` file.

    Args:
        argv: argument list to parse; ``None`` (the default) uses
            ``sys.argv[1:]``.

    Raises:
        AssertionError: if the parameter count is not 56,290 or the parity
            check fails.
        SystemExit: on invalid command-line arguments.
    """
    import onnxruntime

    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", default=os.path.join(
        os.path.expanduser("~"), ".cache", "kagglehub", "datasets", "kaka2434",
        "wiflow-dataset", "versions", "1", "preprocessed_csi_data"))
    parser.add_argument("--subset", type=int, default=10000)
    parser.add_argument("--calib", type=int, default=512,
                        help="calibration windows; must be a multiple of the "
                             "64-window calibration batch (ORT histogram "
                             "collector rejects ragged batches)")
    parser.add_argument("--skip-accuracy", action="store_true")
    parser.add_argument("--out", default=os.path.join(RESULTS, "edge_optimization.json"))
    args = parser.parse_args(argv)
    calib_problem = _calib_error(args.calib)
    if calib_problem is not None:
        parser.error(calib_problem)

    model = load_tiny_model()
    info = describe(model)
    print(f"tiny model: {info['params']:,} params, tcn_groups={info['tcn_groups_per_block']}, "
          f"strides={info['conv_strides']}, final_width={info['final_width']}")
    # Guard against benchmarking a different architecture than the one the
    # results are labelled with. Explicit raise so `python -O` keeps the check.
    if info["params"] != 56290:
        raise AssertionError(
            f"tiny model has {info['params']} params, expected 56290")

    results = {
        "env": {
            "torch": torch.__version__,
            "onnxruntime": onnxruntime.__version__,
            "platform": platform.platform(),
            "num_threads": torch.get_num_threads(),
            "checkpoint": os.path.relpath(TINY_CKPT, HERE),
            "checkpoint_size_bytes": os.path.getsize(TINY_CKPT),
            "params": info["params"],
            "variant_config": TINY,
        },
    }

    # ---- export + parity ----------------------------------------------------
    print("\n=== ONNX export (dynamic batch, opset 17, torchscript) ===")
    results["export"] = export_onnx(model)
    print(f" {results['export']['size_mb']:.3f} MB, batches {results['export']['verified_batches']} OK")

    # Close the .npz archive as soon as the array has been read out of it.
    with np.load(os.path.join(RESULTS, "parity_fixture.npz")) as fixture:
        fx = fixture["input"]  # (2, 540, 20), seed 42 -- same input layout as full model
    sess_fp32 = ort_session(TINY_FP32_ONNX)
    y_ort = sess_fp32.run(None, {sess_fp32.get_inputs()[0].name: fx})[0]
    with torch.no_grad():
        # Reference is the ORIGINAL model (real AdaptiveAvgPool2d), not the
        # export wrapper, so this also validates the matmul replacement.
        y_torch = model(torch.from_numpy(fx)).numpy()
    max_diff = np.abs(y_ort - y_torch).max()
    results["parity"] = {
        "fixture": "results/parity_fixture.npz input (batch 2, seed 42); "
                   "reference output recomputed with the tiny torch model",
        "max_abs_diff_vs_torch": float(max_diff),
        "pass_lt_1e-4": bool(max_diff < 1e-4),
    }
    print("parity:", json.dumps(results["parity"], indent=2))
    # Must not be stripped by `python -O`: everything below would otherwise
    # report numbers for a graph that does not match the torch model.
    if not results["parity"]["pass_lt_1e-4"]:
        raise AssertionError("torch-vs-ORT parity FAILED")

    # ---- static PTQ int8 ------------------------------------------------------
    print(f"\n=== static QDQ int8 (Percentile conv-only, {args.calib} calib windows) ===")
    calib = build_calibration_windows(args.data_dir, args.calib)
    results["int8_static_percentile_conv"] = quantize_tiny(calib)
    print(f" {results['int8_static_percentile_conv']['size_mb']:.3f} MB")
    sess_int8 = ort_session(TINY_INT8_ONNX)
    yq = sess_int8.run(None, {sess_int8.get_inputs()[0].name: fx})[0]
    # Compared against the torch fp32 output on the same fixture input.
    results["int8_static_percentile_conv"]["max_abs_diff_vs_fp32_fixture"] = float(
        np.abs(yq - y_torch).max())

    # ---- latency (3 interleaved reps, full-model sessions as references) -----
    print("\n=== latency (3 interleaved reps) ===")
    lat_sessions = {
        "tiny_onnx_fp32": sess_fp32,
        "tiny_onnx_int8_static_percentile_conv": sess_int8,
        "full_onnx_fp32_reference": ort_session(FULL_FP32_ONNX),
        "full_onnx_int8_static_percentile_conv_reference": ort_session(FULL_INT8_ONNX),
    }
    results["latency"] = {
        "note": "3 interleaved repetitions per variant, median ms/window; "
                "full-model sessions are same-session references",
        **interleaved_latency(lat_sessions),
    }

    # ---- accuracy on the standard 10k corruption-free test subset ------------
    if not args.skip_accuracy:
        loader, n_clean = build_test_subset(args.data_dir, args.subset)
        results["accuracy_subset"] = {
            "description": "seed-42 file-level 70/15/15 test split, corrupted "
                           "windows excluded, seed-42 random subset (same as "
                           "quantize_bench/eval_ort_accuracy/static_ptq_bench)",
            "subset_size": min(args.subset, n_clean) if args.subset else n_clean,
        }
        results["accuracy"] = {}
        for name, sess in (("tiny_onnx_fp32", sess_fp32),
                           ("tiny_onnx_int8_static_percentile_conv", sess_int8)):
            print(f"\n=== accuracy: {name} ===")
            results["accuracy"][name] = evaluate_ort(sess, loader, name)
            print(json.dumps(results["accuracy"][name], indent=2))

    # ---- merge into edge_optimization.json -----------------------------------
    _merge_into_json(args.out, "tiny_variant", results)
    print(f"\nwrote {args.out}")


if __name__ == "__main__":
    main()