"""ADR-152 "optimize beyond SOTA": edge-optimization benchmark for the
retrained WiFlow-STD checkpoint (results/retrained_best_pose_model.pth,
~96% PCK@20, fp32 params 2,225,042).

Measures, for fp32 / fp16 / dynamic-int8 torch variants:
  (a) serialized state_dict size on disk,
  (b) CPU inference latency per window at batch 1 and batch 64
      (median of repeated runs, this Windows box),
  (c) accuracy (PCK@20/50 + MPJPE, upstream metrics) on a corruption-free
      random subset of the seed-42 file-level 70/15/15 test split
      (same split as eval_repro.py; corrupted windows 487-499 excluded via
      results/nan_windows_mask.npy | results/big_windows_mask.npy).

Also verifies the paper's "~2.2 MB int8" size claim: reports which layer
types torch dynamic quantization actually converts (the model contains NO
nn.Linear -- it is Conv1d/Conv2d/BatchNorm only) and the real on-disk size.

Usage:
    .venv/Scripts/python.exe quantize_bench.py \
        --data-dir C:/Users/ruv/.cache/kagglehub/datasets/kaka2434/wiflow-dataset/versions/1/preprocessed_csi_data \
        [--subset 10000] [--skip-accuracy]

Writes/merges into results/edge_optimization.json under key "torch".
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import platform
|
|
import statistics
|
|
import time
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader
|
|
|
|
from _bench_common import HERE, RESULTS, evaluate, import_upstream, load_wiflow_model
|
|
|
|
import_upstream() # sys.path + models stub + >1GB np.load mmap patch
|
|
|
|
from dataset import ( # noqa: E402
|
|
PreprocessedCSIKeypointsDataset,
|
|
create_preprocessed_train_val_test_loaders,
|
|
)
|
|
|
|
# Path of the retrained checkpoint (inside RESULTS) that load_fp32_model() loads.
CHECKPOINT = os.path.join(RESULTS, "retrained_best_pose_model.pth")
|
|
|
|
|
|
def load_fp32_model():
    """Load the model stored at ``CHECKPOINT`` via ``load_wiflow_model``.

    Returns:
        Whatever ``load_wiflow_model(CHECKPOINT)`` returns; callers in this
        file treat it as the fp32 baseline model.
    """
    # The loader applies a legacy upstream key remap; on this checkpoint that
    # remap is a harmless no-op.
    model = load_wiflow_model(CHECKPOINT)
    return model
|
|
|
|
|
|
def state_dict_size_bytes(model, path):
    """Save ``model.state_dict()`` to ``path`` and report the file size.

    Args:
        model: object exposing ``state_dict()``.
        path: destination file; it is written with ``torch.save`` and left
            on disk afterwards.

    Returns:
        Size of the written file in bytes.
    """
    weights = model.state_dict()
    torch.save(weights, path)
    # Measure the real serialized artifact rather than estimating from numel.
    return os.stat(path).st_size
|
|
|
|
|
|
def bench_latency(model, batch_size, n_runs, dtype=torch.float32):
    """Measure forward-pass latency of ``model`` on one fixed random batch.

    The input is a seeded ``torch.rand`` tensor of shape
    ``(batch_size, 540, 20)`` cast to ``dtype``, so every variant is timed on
    identical data. ``max(5, n_runs // 10)`` untimed warmup passes precede
    the ``n_runs`` timed passes; all passes run under ``torch.no_grad()``.

    Args:
        model: callable invoked as ``model(x)``.
        batch_size: number of windows in the batch; must be >= 1.
        n_runs: number of timed forward passes; must be >= 1.
        dtype: dtype the random input is converted to.

    Returns:
        dict with ``batch_size``, ``runs``, ``median_ms_per_batch``,
        ``median_ms_per_window`` and ``windows_per_second``, all derived
        from the median of the timed passes.

    Raises:
        ValueError: if ``batch_size`` or ``n_runs`` is smaller than 1.
    """
    # Fail fast with a clear message: otherwise n_runs=0 only surfaces after
    # the warmup as an opaque StatisticsError from median([]), and
    # batch_size=0 as a ZeroDivisionError in the per-window math.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    gen = torch.Generator().manual_seed(123)
    x = torch.rand(batch_size, 540, 20, generator=gen).to(dtype)
    with torch.no_grad():
        for _ in range(max(5, n_runs // 10)):  # warmup
            model(x)
        times = []
        for _ in range(n_runs):
            t0 = time.perf_counter()
            model(x)
            times.append(time.perf_counter() - t0)
    # Median rather than mean, so a few slow outlier runs do not skew the result.
    med = statistics.median(times)
    return {
        "batch_size": batch_size,
        "runs": n_runs,
        "median_ms_per_batch": med * 1e3,
        "median_ms_per_window": med * 1e3 / batch_size,
        "windows_per_second": batch_size / med,
    }
|
|
|
|
|
|
def build_test_subset(data_dir, subset_size, batch_size=64):
    """Seed-42 file-level 70/15/15 test split (exactly as eval_repro.py),
    minus corrupted windows, then a seed-42 random subset.

    Args:
        data_dir: directory passed to ``PreprocessedCSIKeypointsDataset``.
        subset_size: number of clean test windows to sample without
            replacement; a falsy value, or a value not smaller than the
            number of clean windows, keeps every clean window.
        batch_size: batch size for both the upstream split loaders and the
            returned loader.

    Returns:
        ``(loader, n_clean)``: an unshuffled ``DataLoader`` over the selected
        windows, and the number of clean test windows available BEFORE
        subsampling (the caller records it as the clean test total).
    """
    dataset = PreprocessedCSIKeypointsDataset(
        data_dir=data_dir, keypoint_scale=1000.0, enable_temporal_clean=True)
    _tr, _va, test_loader = create_preprocessed_train_val_test_loaders(
        dataset=dataset, batch_size=batch_size, num_workers=0, random_seed=42)
    test_indices = np.asarray(test_loader.dataset.indices)

    # NOTE(review): assumes both masks are boolean arrays indexed by dataset
    # window index, so `|` is a union and `~` a logical not -- confirm.
    corrupted = (np.load(os.path.join(RESULTS, "nan_windows_mask.npy"))
                 | np.load(os.path.join(RESULTS, "big_windows_mask.npy")))
    clean = test_indices[~corrupted[test_indices]]
    # Capture the total now: the previous code returned len() of the array
    # after it had been replaced by the sampled subset, so the caller's
    # "clean_test_total" silently reported the subset size instead.
    n_clean = len(clean)
    print(f"test split: {len(test_indices)} windows, "
          f"{len(test_indices) - n_clean} corrupted excluded, "
          f"{n_clean} clean")

    selected = clean
    if subset_size and subset_size < n_clean:
        rng = np.random.default_rng(42)
        # Sorted so the loader visits the windows in ascending index order.
        selected = np.sort(rng.choice(clean, size=subset_size, replace=False))
    subset = torch.utils.data.Subset(dataset, selected.tolist())
    loader = DataLoader(subset, batch_size=batch_size, shuffle=False,
                        num_workers=0)
    return loader, n_clean
|
|
|
|
|
|
def quantize_int8_dynamic(fp32_model):
    """Apply ``torch.ao.quantization.quantize_dynamic`` and report its effect.

    Linear, Conv1d and Conv2d are requested for qint8 dynamic quantization;
    the report records which modules were actually converted, so a size
    claim can be checked against what torch really quantized.

    Args:
        fp32_model: the float model to quantize. ``quantize_dynamic`` is
            called with its default (non in-place) mode.

    Returns:
        ``(qmodel, report)``. ``report`` holds the count of eligible module
        types in ``fp32_model``, one entry per converted module (name, class
        path, weight element count), the number of converted modules, and
        total / quantized parameter counts with their ratio. The ratio is
        ``0.0`` for a model without parameters.
    """
    qmodel = torch.ao.quantization.quantize_dynamic(
        fp32_model, {nn.Linear, nn.Conv1d, nn.Conv2d}, dtype=torch.qint8)

    quantized = []
    quant_params = 0
    for name, mod in qmodel.named_modules():
        cls = type(mod).__module__ + "." + type(mod).__name__
        # A converted module is recognised by "quantized" in its class path.
        if "quantized" not in cls:
            continue
        # Converted modules expose the packed weight through a weight()
        # method rather than a Parameter attribute.
        weight_fn = getattr(mod, "weight", None)
        w = weight_fn() if callable(weight_fn) else None
        numel = w.numel() if w is not None else 0
        quant_params += numel
        quantized.append({"module": name, "class": cls, "params": numel})

    total_params = sum(p.numel() for p in fp32_model.parameters())

    # Walk the module list once and count each eligible type from it.
    modules = list(fp32_model.modules())
    eligible = {"nn.Linear": nn.Linear, "nn.Conv1d": nn.Conv1d,
                "nn.Conv2d": nn.Conv2d}
    eligible_counts = {
        label: sum(isinstance(m, kind) for m in modules)
        for label, kind in eligible.items()
    }

    report = {
        "eligible_module_counts": eligible_counts,
        "modules_actually_quantized": quantized,
        "n_modules_quantized": len(quantized),
        "params_total": total_params,
        "params_quantized": quant_params,
        # Guard the parameter-free case instead of raising ZeroDivisionError.
        "params_quantized_fraction": (
            quant_params / total_params if total_params else 0.0),
    }
    return qmodel, report
|
|
|
|
|
|
def main():
    """Run the size / latency / accuracy benchmark and merge the results.

    Steps:
      1. Parse CLI options (data dir, subset size, run counts, output path).
      2. Build the fp32, fp16 (``.half()``) and dynamic-int8 variants, each
         from its own ``load_fp32_model()`` call.
      3. For every variant, save its state_dict under ``RESULTS`` and record
         the file size plus batch-1 / batch-64 latency.
      4. Unless ``--skip-accuracy`` is given, evaluate every variant on the
         loader from ``build_test_subset``.
      5. Store everything under key ``"torch"`` of the JSON file at
         ``--out``, preserving any other top-level keys already in it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", default=os.path.join(
        os.path.expanduser("~"), ".cache", "kagglehub", "datasets", "kaka2434",
        "wiflow-dataset", "versions", "1", "preprocessed_csi_data"))
    parser.add_argument("--subset", type=int, default=10000)
    parser.add_argument("--runs-b1", type=int, default=100)
    parser.add_argument("--runs-b64", type=int, default=30)
    parser.add_argument("--skip-accuracy", action="store_true")
    parser.add_argument("--out", default=os.path.join(RESULTS, "edge_optimization.json"))
    args = parser.parse_args()

    torch.manual_seed(42)
    results = {
        "env": {
            "torch": torch.__version__,
            "platform": platform.platform(),
            "processor": platform.processor(),
            "num_threads": torch.get_num_threads(),
            "checkpoint": os.path.relpath(CHECKPOINT, HERE),
        },
        "variants": {},
    }

    # ---- build variants ---------------------------------------------------
    fp32 = load_fp32_model()
    n_params = sum(p.numel() for p in fp32.parameters())
    results["env"]["params"] = n_params
    print(f"fp32 model: {n_params:,} params")

    fp16 = load_fp32_model().half()

    int8, q_report = quantize_int8_dynamic(load_fp32_model())
    results["int8_dynamic_quant_report"] = q_report
    print(f"int8 dynamic: {q_report['n_modules_quantized']} modules quantized, "
          f"{q_report['params_quantized_fraction']*100:.1f}% of params")

    # name -> (model, input dtype fed to the model, state_dict file name)
    variants = {
        "fp32": (fp32, torch.float32, "retrained_fp32_resaved.pth"),
        "fp16": (fp16, torch.float16, "retrained_fp16.pth"),
        "int8_dynamic": (int8, torch.float32, "retrained_int8_dynamic.pth"),
    }

    # ---- (a) size + (b) latency -------------------------------------------
    for name, (model, dtype, fname) in variants.items():
        path = os.path.join(RESULTS, fname)
        size = state_dict_size_bytes(model, path)
        print(f"\n=== {name}: {size/1e6:.3f} MB on disk ({fname}) ===")
        lat1 = bench_latency(model, 1, args.runs_b1, dtype)
        lat64 = bench_latency(model, 64, args.runs_b64, dtype)
        print(f" batch 1: {lat1['median_ms_per_window']:.2f} ms/window "
              f"({lat1['windows_per_second']:.0f}/s)")
        print(f" batch 64: {lat64['median_ms_per_window']:.3f} ms/window "
              f"({lat64['windows_per_second']:.0f}/s)")
        results["variants"][name] = {
            "file": fname,
            "size_bytes": size,
            "size_mb": size / 1e6,
            "latency_batch1": lat1,
            "latency_batch64": lat64,
        }

    # ---- (c) accuracy ------------------------------------------------------
    if not args.skip_accuracy:
        loader, n_clean = build_test_subset(args.data_dir, args.subset)
        results["accuracy_subset"] = {
            "description": "seed-42 file-level 70/15/15 test split, corrupted "
                           "windows (files 487-499) excluded, seed-42 random "
                           "subset",
            "subset_size": min(args.subset, n_clean) if args.subset else n_clean,
            "clean_test_total": n_clean,
        }
        for name, (model, dtype, _f) in variants.items():
            print(f"\n=== accuracy: {name} ===")
            results["variants"][name]["accuracy"] = evaluate(
                model, loader, dtype=dtype, label=name)
            print(json.dumps(results["variants"][name]["accuracy"], indent=2))

    # ---- merge into edge_optimization.json ---------------------------------
    # Explicit UTF-8: the locale default (e.g. cp1252 on Windows) can fail to
    # decode a file that another tool wrote as UTF-8 with non-ASCII text.
    merged = {}
    if os.path.exists(args.out):
        with open(args.out, encoding="utf-8") as f:
            merged = json.load(f)
    merged["torch"] = results
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(merged, f, indent=2)
    print(f"\nwrote {args.out}")
|
|
|
|
|
|
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
|