295 lines
13 KiB
Python
295 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
"""ADR-175: int8 quantization of the WiFlow-STD "half" pose model + MEASURED accuracy/size trade-off.
|
|
|
|
Sub-deliverable 8.2 of the benchmark/optimization milestone. Quantizes the 843,834-param
|
|
"half" WiFlow-STD pose model to int8 (QAT primary, static-PTQ fallback) and MEASURES the
|
|
accuracy delta against the fp32 baseline under ONE locked PCK normalization.
|
|
|
|
LOCKED NORMALIZATION (ADR-173): torso-diameter PCK — neck(idx 2)->pelvis(idx 12) distance,
|
|
exactly the default `use_torso_norm=True` path of upstream `utils/metrics.calculate_pck`,
|
|
which is the standard MM-Fi/GraphPose-Fi convention. The SAME `calculate_pck` /
|
|
`calculate_mpjpe` from the upstream harness scores BOTH fp32 and int8 so the comparison is
|
|
metric-locked. The test split is the seed-42 file-level 70/15/15 test partition (54,000
|
|
windows full / 52,560 NaN-free) produced by the SAME loader that produced half_best.pth.
|
|
|
|
int8 backend: FX graph-mode quantization, fbgemm engine (server x86 int8). Quantized int8
|
|
kernels execute on CPU, so int8 eval is CPU; an fp32-CPU baseline is also measured so the
|
|
accuracy delta is device-matched (CPU fp32 vs CPU int8), and an fp32-GPU number is reported
|
|
for continuity with the sweep's recorded numbers.
|
|
|
|
REPRODUCE (exact command run for ADR-175, run date 2026-06-15, on host ruvultra / RTX 5080):
|
|
ssh ruvultra 'cd ~/wiflow-std-bench && source venv/bin/activate && \
|
|
python ~/quantize_half_int8.py --mode both --qat-epochs 3 2>&1'
|
|
|
|
(the script lives in-repo at v2/crates/wifi-densepose-train/scripts/quantize_half_int8.py;
|
|
it was scp'd to ~/quantize_half_int8.py on ruvultra and invoked as above. It is read-only
|
|
to everything under ~/wiflow-std-bench except that it WRITES its int8 artifacts + a JSON
|
|
results file into ~/wiflow-std-bench/sweep/int8/ — it never modifies half_best.pth or any
|
|
upstream file.)
|
|
|
|
Everything this script prints to stdout is MEASURED. Nothing is estimated.
|
|
"""
|
|
import argparse
|
|
import copy
|
|
import json
|
|
import os
|
|
import random
|
|
import sys
|
|
import time
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader, Subset
|
|
|
|
BENCH = os.path.expanduser('~/wiflow-std-bench')
|
|
SWEEP = os.path.join(BENCH, 'sweep')
|
|
OUTDIR = os.path.join(SWEEP, 'int8')
|
|
sys.path.insert(0, os.path.join(BENCH, 'upstream'))
|
|
sys.path.insert(0, SWEEP)
|
|
|
|
from dataset import (PreprocessedCSIKeypointsDataset, # noqa: E402
|
|
create_preprocessed_train_val_test_loaders)
|
|
from losses.pose_loss import PoseLoss # noqa: E402
|
|
from utils.metrics import calculate_pck, calculate_mpjpe # noqa: E402 LOCKED metric (torso norm)
|
|
from model_compact import CompactWiFlowPoseModel, describe # noqa: E402
|
|
|
|
# half variant config — IDENTICAL to sweep/run_sweep.py VARIANTS[0] that produced half_best.pth
|
|
HALF = dict(tcn=[270, 220, 170, 120], conv=[4, 8, 16, 32], attn_groups=4,
|
|
groups_mode='gcd20', input_pw_groups=1)
|
|
HALF_CKPT = os.path.join(SWEEP, 'half_best.pth')
|
|
CORRUPT_FILE_START = 487 # files 487-499 were zero-filled by clean_nan.py (same as sweep)
|
|
SEED = 42
|
|
THRESHOLDS = (0.1, 0.2, 0.3, 0.4, 0.5) # PCK@10..50
|
|
|
|
|
|
def set_seed(seed=SEED):
|
|
random.seed(seed)
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
torch.cuda.manual_seed_all(seed)
|
|
torch.backends.cudnn.deterministic = True
|
|
torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
def build_half(dropout=0.5):
|
|
return CompactWiFlowPoseModel(
|
|
tcn_channels=HALF['tcn'], conv_channels=HALF['conv'],
|
|
attn_groups=HALF['attn_groups'], groups_mode=HALF['groups_mode'],
|
|
input_pw_groups=HALF['input_pw_groups'], dropout=dropout)
|
|
|
|
|
|
@torch.no_grad()
|
|
def evaluate(model, loader, device):
|
|
"""MEASURED PCK@10..50 + MPJPE under the LOCKED torso-diameter normalization."""
|
|
model.eval()
|
|
totals = {t: 0.0 for t in THRESHOLDS}
|
|
total_mpe, n = 0.0, 0
|
|
for bx, by in loader:
|
|
bx, by = bx.to(device), by.to(device)
|
|
out = model(bx)
|
|
bs = by.size(0)
|
|
total_mpe += calculate_mpjpe(out, by) * bs
|
|
pck = calculate_pck(out, by, thresholds=list(totals)) # use_torso_norm=True default
|
|
for t in totals:
|
|
totals[t] += pck[t] * bs
|
|
n += bs
|
|
return {'samples': n, 'mpjpe': total_mpe / n,
|
|
**{f'pck@{int(t * 100)}': totals[t] / n for t in totals}}
|
|
|
|
|
|
def file_size_mb(path):
|
|
return os.path.getsize(path) / (1024 * 1024)
|
|
|
|
|
|
def state_dict_size_mb(model, path):
|
|
"""On-disk size of the *quantized* checkpoint (int8 weights are packed by fbgemm)."""
|
|
torch.save(model.state_dict(), path)
|
|
return file_size_mb(path)
|
|
|
|
|
|
def loaders():
|
|
set_seed(SEED)
|
|
data_dir = os.path.join(BENCH, 'preprocessed_csi_data')
|
|
dataset = PreprocessedCSIKeypointsDataset(data_dir=data_dir, keypoint_scale=1000.0,
|
|
enable_temporal_clean=True)
|
|
train_loader, val_loader, test_loader = create_preprocessed_train_val_test_loaders(
|
|
dataset=dataset, batch_size=64, num_workers=2, random_seed=SEED)
|
|
return dataset, train_loader, val_loader, test_loader
|
|
|
|
|
|
def clean_loader_from(dataset, test_loader, bs=256):
|
|
w2f = dataset.window_to_file
|
|
clean_idx = [i for i in test_loader.dataset.indices if w2f[i] < CORRUPT_FILE_START]
|
|
return DataLoader(Subset(dataset, clean_idx), batch_size=bs, shuffle=False, num_workers=2)
|
|
|
|
|
|
def eval_loaders(dataset, test_loader, bs=256):
|
|
full = DataLoader(test_loader.dataset, batch_size=bs, shuffle=False, num_workers=2)
|
|
clean = clean_loader_from(dataset, test_loader, bs=bs)
|
|
return full, clean
|
|
|
|
|
|
# --------------------------------------------------------------- int8 paths (FX graph mode)
|
|
def ptq_static(fp32_model, train_loader, calib_batches=64):
|
|
"""Static post-training quantization, FX graph mode, fbgemm. CPU int8."""
|
|
from torch.ao.quantization import get_default_qconfig, QConfigMapping
|
|
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
|
|
torch.backends.quantized.engine = 'fbgemm'
|
|
m = copy.deepcopy(fp32_model).cpu().eval()
|
|
qconfig = get_default_qconfig('fbgemm')
|
|
qmap = QConfigMapping().set_global(qconfig)
|
|
example = torch.randn(1, 540, 20)
|
|
prepared = prepare_fx(m, qmap, example_inputs=(example,))
|
|
prepared.eval()
|
|
with torch.no_grad():
|
|
for i, (bx, _) in enumerate(train_loader):
|
|
prepared(bx.cpu())
|
|
if i + 1 >= calib_batches:
|
|
break
|
|
return convert_fx(prepared)
|
|
|
|
|
|
def qat(fp32_model, train_loader, val_loader, device, epochs=3, lr=2e-5):
|
|
"""Quantization-aware training, FX graph mode, fbgemm. Fine-tune fake-quant from fp32, convert. CPU int8."""
|
|
from torch.ao.quantization import get_default_qat_qconfig, QConfigMapping
|
|
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
|
|
torch.backends.quantized.engine = 'fbgemm'
|
|
set_seed(SEED)
|
|
m = copy.deepcopy(fp32_model).to(device).train()
|
|
qconfig = get_default_qat_qconfig('fbgemm')
|
|
qmap = QConfigMapping().set_global(qconfig)
|
|
example = torch.randn(1, 540, 20).to(device)
|
|
prepared = prepare_qat_fx(m, qmap, example_inputs=(example,))
|
|
prepared.to(device)
|
|
|
|
criterion = PoseLoss(position_weight=1.0, bone_weight=0.2, loss_type='smooth_l1')
|
|
opt = torch.optim.AdamW(prepared.parameters(), lr=lr, weight_decay=5e-5, betas=(0.9, 0.999))
|
|
|
|
best_val = float('inf')
|
|
best_state = None
|
|
for ep in range(1, epochs + 1):
|
|
prepared.train()
|
|
t0 = time.time()
|
|
ep_loss, nb = 0.0, 0
|
|
for bx, by in train_loader:
|
|
bx, by = bx.to(device), by.to(device)
|
|
opt.zero_grad(set_to_none=True)
|
|
out = prepared(bx)
|
|
loss, _ = criterion(out, by)
|
|
if not torch.isfinite(loss):
|
|
continue
|
|
loss.backward()
|
|
opt.step()
|
|
ep_loss += loss.item()
|
|
nb += 1
|
|
# eval the fake-quant model on GPU (proxy for int8) to pick the best epoch
|
|
prepared.eval()
|
|
v = evaluate(prepared, val_loader, device)
|
|
print(f"[qat] epoch {ep}/{epochs} train_loss={ep_loss / max(nb,1):.5f} "
|
|
f"val_mpjpe(fakequant)={v['mpjpe']:.5f} val_pck20={v['pck@20']*100:.2f}% "
|
|
f"({time.time()-t0:.0f}s)", flush=True)
|
|
if v['mpjpe'] < best_val:
|
|
best_val = v['mpjpe']
|
|
best_state = copy.deepcopy(prepared.state_dict())
|
|
if best_state is not None:
|
|
prepared.load_state_dict(best_state)
|
|
prepared.cpu().eval()
|
|
return convert_fx(prepared)
|
|
|
|
|
|
def main():
|
|
ap = argparse.ArgumentParser()
|
|
ap.add_argument('--mode', choices=['ptq', 'qat', 'both'], default='both')
|
|
ap.add_argument('--qat-epochs', type=int, default=3)
|
|
ap.add_argument('--calib-batches', type=int, default=64)
|
|
args = ap.parse_args()
|
|
os.makedirs(OUTDIR, exist_ok=True)
|
|
|
|
cuda = torch.device('cuda')
|
|
cpu = torch.device('cpu')
|
|
print(f"torch {torch.__version__} | cuda {torch.cuda.get_device_name(0)} | "
|
|
f"quantized.engine candidates {torch.backends.quantized.supported_engines}", flush=True)
|
|
|
|
dataset, train_loader, val_loader, test_loader = loaders()
|
|
test_full, test_clean = eval_loaders(dataset, test_loader)
|
|
|
|
# ---------- fp32 baseline (loads half_best.pth strict; same arch as sweep) ----------
|
|
fp32 = build_half().eval()
|
|
state = torch.load(HALF_CKPT, map_location='cpu', weights_only=True)
|
|
fp32.load_state_dict(state, strict=True)
|
|
fp32_size = file_size_mb(HALF_CKPT)
|
|
params = describe(fp32)['params']
|
|
print(f"\n=== fp32 baseline: half_best.pth | params={params:,} | "
|
|
f"on-disk={fp32_size:.3f} MB ===", flush=True)
|
|
|
|
results = {
|
|
'host': os.uname().nodename, 'gpu': torch.cuda.get_device_name(0),
|
|
'torch': torch.__version__, 'date_utc': time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()),
|
|
'locked_normalization': 'torso-diameter (neck idx2 -> pelvis idx12), '
|
|
'upstream calculate_pck use_torso_norm=True (ADR-173 standard)',
|
|
'checkpoint': HALF_CKPT, 'params': params, 'fp32_size_mb': fp32_size,
|
|
'test_split': 'seed-42 file-level 70/15/15 test (full 54000 / clean 52560)',
|
|
'fp32': {}, 'int8': {},
|
|
}
|
|
|
|
fp32_gpu = build_half().to(cuda).eval()
|
|
fp32_gpu.load_state_dict(state, strict=True)
|
|
print('[fp32/gpu] full ...', flush=True)
|
|
results['fp32']['gpu_full'] = evaluate(fp32_gpu, test_full, cuda)
|
|
print(json.dumps(results['fp32']['gpu_full']), flush=True)
|
|
print('[fp32/gpu] clean ...', flush=True)
|
|
results['fp32']['gpu_clean'] = evaluate(fp32_gpu, test_clean, cuda)
|
|
print(json.dumps(results['fp32']['gpu_clean']), flush=True)
|
|
|
|
print('[fp32/cpu] full (device-matched ref for int8) ...', flush=True)
|
|
results['fp32']['cpu_full'] = evaluate(fp32.to(cpu), test_full, cpu)
|
|
print(json.dumps(results['fp32']['cpu_full']), flush=True)
|
|
print('[fp32/cpu] clean ...', flush=True)
|
|
results['fp32']['cpu_clean'] = evaluate(fp32.to(cpu), test_clean, cpu)
|
|
print(json.dumps(results['fp32']['cpu_clean']), flush=True)
|
|
|
|
# ---------- int8 ----------
|
|
def measure_int8(label, qmodel):
|
|
path = os.path.join(OUTDIR, f'half_int8_{label}.pth')
|
|
size = state_dict_size_mb(qmodel, path)
|
|
print(f"[int8/{label}] on-disk={size:.3f} MB | full ...", flush=True)
|
|
full = evaluate(qmodel, test_full, cpu)
|
|
print(json.dumps(full), flush=True)
|
|
print(f"[int8/{label}] clean ...", flush=True)
|
|
clean = evaluate(qmodel, test_clean, cpu)
|
|
print(json.dumps(clean), flush=True)
|
|
results['int8'][label] = {'size_mb': size, 'checkpoint': path,
|
|
'cpu_full': full, 'cpu_clean': clean}
|
|
|
|
if args.mode in ('ptq', 'both'):
|
|
print("\n=== int8 PTQ (static, FX, fbgemm) ===", flush=True)
|
|
qp = ptq_static(fp32.to(cpu).eval(), train_loader, calib_batches=args.calib_batches)
|
|
measure_int8('ptq_static', qp)
|
|
|
|
if args.mode in ('qat', 'both'):
|
|
print(f"\n=== int8 QAT (FX, fbgemm, {args.qat_epochs} epochs from half_best) ===", flush=True)
|
|
qq = qat(fp32, train_loader, val_loader, cuda, epochs=args.qat_epochs)
|
|
measure_int8('qat', qq)
|
|
|
|
out = os.path.join(OUTDIR, 'int8_results.json')
|
|
with open(out, 'w') as f:
|
|
json.dump(results, f, indent=2)
|
|
print('\nwrote', out, flush=True)
|
|
|
|
# ---------- comparison table (MEASURED) ----------
|
|
print("\n================= MEASURED COMPARISON (clean test subset, torso-PCK) =================", flush=True)
|
|
base = results['fp32']['cpu_clean']
|
|
print(f"{'model':16s} {'size_MB':>8s} {'pck@20':>8s} {'pck@50':>8s} {'mpjpe':>9s}", flush=True)
|
|
print(f"{'fp32 (cpu)':16s} {fp32_size:8.3f} {base['pck@20']*100:7.2f}% {base['pck@50']*100:7.2f}% {base['mpjpe']:9.6f}", flush=True)
|
|
for label, r in results['int8'].items():
|
|
c = r['cpu_clean']
|
|
d20 = (c['pck@20'] - base['pck@20']) * 100
|
|
d50 = (c['pck@50'] - base['pck@50']) * 100
|
|
print(f"{'int8 '+label:16s} {r['size_mb']:8.3f} {c['pck@20']*100:7.2f}% {c['pck@50']*100:7.2f}% {c['mpjpe']:9.6f} "
|
|
f"(d_pck20={d20:+.2f}pp d_pck50={d50:+.2f}pp size={fp32_size/r['size_mb']:.2f}x smaller)", flush=True)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|