# wifi-densepose/benchmarks/wiflow-std/remote/eval_retrained.py
#
# (listing metadata: 113 lines, 4.6 KiB, Python)

"""Evaluate the retrained WiFlow-STD checkpoint (ADR-152 §2.2a fallback).
Scores the model produced by run.py (train_output/best_pose_model.pth or similar)
on the seed-42 test split: full test set AND NaN-free subset (excluding windows
that were zero-filled by clean_nan.py — file indices 487-499).
NOTE: deployed to ruvultra (~/wiflow-std-bench) as a standalone single file,
so it deliberately inlines its helpers. The reference implementations (upstream
import shim, >1GB np.load mmap patch, key-remap loader, canonical evaluate
loop) live in benchmarks/wiflow-std/_bench_common.py — keep copies in sync.
"""
import json, os, random, sys
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
# csi_windows.npy is ~13 GB; mmap large arrays instead of eagerly loading
# ~15 GB into RAM (same patch as _bench_common._np_load_mmap).
_MMAP_THRESHOLD_BYTES = 1 << 30  # files strictly larger than 1 GiB are mmapped
_np_load = np.load  # keep the real loader; the wrapper below delegates to it


def _np_load_mmap(path, *a, **kw):
    """Drop-in replacement for ``np.load`` that memory-maps big ``.npy`` files.

    If *path* is a ``str`` or ``os.PathLike`` naming a ``.npy`` file larger
    than ``_MMAP_THRESHOLD_BYTES`` and the caller did not choose an
    ``mmap_mode`` (neither by keyword nor positionally), the file is opened
    with ``mmap_mode='r'``. Every other call is forwarded unchanged.

    Args:
        path: File argument exactly as accepted by ``np.load``.
        *a: Extra positional arguments for ``np.load``.
        **kw: Extra keyword arguments for ``np.load``.

    Returns:
        Whatever the original ``np.load`` returns for the final arguments.

    Raises:
        OSError: From ``os.path.getsize`` when a ``.npy`` path cannot be
            stat'ed (e.g. the file does not exist).
    """
    # mmap_mode is the parameter right after the file in np.load's signature,
    # so any extra positional argument means the caller already picked a mode;
    # injecting the keyword as well would raise "multiple values" TypeError.
    mode_chosen = bool(a) or 'mmap_mode' in kw
    if not mode_chosen and isinstance(path, (str, os.PathLike)):
        fs_path = os.fspath(path)
        # os.fspath may yield bytes for exotic PathLike objects; only text
        # paths are inspected, matching the original str-only suffix check.
        if (isinstance(fs_path, str) and fs_path.endswith('.npy')
                and os.path.getsize(fs_path) > _MMAP_THRESHOLD_BYTES):
            kw['mmap_mode'] = 'r'
    return _np_load(path, *a, **kw)


# Installed globally so loads performed inside modules imported afterwards
# go through the wrapper as well.
np.load = _np_load_mmap
sys.path.insert(0, os.path.expanduser('~/wiflow-std-bench/upstream'))
from dataset import PreprocessedCSIKeypointsDataset, create_preprocessed_train_val_test_loaders
from models.pose_model import WiFlowPoseModel
from utils.metrics import calculate_pck, calculate_mpjpe
def find_checkpoint(base_dir='~/wiflow-std-bench'):
    """Return the most recently modified retrained checkpoint under *base_dir*.

    Candidates are collected from two places:

    * every ``*.pth`` file anywhere below ``<base_dir>/train_output``;
    * ``*.pth`` files below ``<base_dir>/upstream`` whose file name contains
      ``best``, whose directory path does not contain ``cross_dataset``, and
      whose mtime is later than two days before the mtime of
      ``<base_dir>/train.log``.

    Any candidate whose path ends in ``upstream/best_pose_model.pth`` is
    dropped (presumably the checkpoint shipped with upstream rather than a
    retrained one -- confirm against the deployment layout).

    Args:
        base_dir: Benchmark root; ``~`` is expanded. Defaults to the
            deployment location used by the rest of this script.

    Returns:
        Path (str) of the candidate with the greatest mtime.

    Raises:
        SystemExit: With the message ``no retrained checkpoint found`` when
            no candidate survives the filters.
    """
    base = os.path.expanduser(base_dir)

    cands = []
    for root, _, files in os.walk(os.path.join(base, 'train_output')):
        cands.extend(os.path.join(root, f) for f in files if f.endswith('.pth'))

    # Stat train.log once instead of once per upstream candidate. Without the
    # log there is no reference time to judge freshness against, so the
    # upstream scan is skipped rather than aborting with a raw traceback.
    try:
        cutoff = os.path.getmtime(os.path.join(base, 'train.log')) - 86400 * 2
    except OSError:
        cutoff = None
        print('warning: train.log not found; ignoring checkpoints under upstream/',
              file=sys.stderr)

    if cutoff is not None:
        # also upstream/test default output dir
        for root, _, files in os.walk(os.path.join(base, 'upstream')):
            if 'cross_dataset' in root:
                continue
            for f in files:
                if not (f.endswith('.pth') and 'best' in f):
                    continue
                p = os.path.join(root, f)
                if os.path.getmtime(p) > cutoff:
                    cands.append(p)

    cands = [c for c in cands if not c.endswith('upstream/best_pose_model.pth')]
    if not cands:
        sys.exit('no retrained checkpoint found')
    return max(cands, key=os.path.getmtime)
def evaluate(model, loader, device):
    """Score *model* on *loader* and return sample-weighted average metrics.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    For every batch the values returned by ``calculate_mpjpe`` and
    ``calculate_pck`` are multiplied by the batch size, summed, and finally
    divided by the total sample count (this assumes both helpers return
    per-batch means -- confirm against ``utils.metrics``).

    Args:
        model: Module called as ``model(bx)``.
        loader: Iterable yielding ``(bx, by)`` batches; ``by.size(0)`` is
            taken as the batch size.
        device: Device both tensors are moved to before the forward pass.

    Returns:
        dict with ``samples`` (int), ``mpjpe`` (float) and ``pck@10`` ...
        ``pck@50`` (float), all plain Python numbers so the result can be
        passed to ``json.dumps`` directly.

    Raises:
        ValueError: If *loader* yields no samples (previously an opaque
            ``ZeroDivisionError``).
    """
    model.eval()
    thresholds = (0.1, 0.2, 0.3, 0.4, 0.5)
    pck_sums = dict.fromkeys(thresholds, 0.0)
    mpjpe_sum, n = 0.0, 0
    with torch.no_grad():
        for bx, by in loader:
            bx, by = bx.to(device), by.to(device)
            out = model(bx)
            bs = by.size(0)
            # float() keeps the accumulators as Python floats even if a
            # metric helper hands back a 0-dim tensor or NumPy scalar.
            mpjpe_sum += float(calculate_mpjpe(out, by)) * bs
            pck = calculate_pck(out, by, thresholds=list(thresholds))
            for t in thresholds:
                pck_sums[t] += float(pck[t]) * bs
            n += bs
    if n == 0:
        raise ValueError('evaluate() got an empty loader: no samples to score')
    return {'samples': n, 'mpjpe': mpjpe_sum / n,
            **{f'pck@{int(t*100)}': pck_sums[t] / n for t in thresholds}}
# Seed Python, NumPy and PyTorch (CPU, then every CUDA device) with the same
# value, in that order, and ask cuDNN for deterministic kernels.
_SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed,
                  torch.cuda.manual_seed_all):
    _seed_rng(_SEED)
torch.backends.cudnn.deterministic = True
# --- data -------------------------------------------------------------------
# Build the dataset and split it with random_seed=42; only the test loader
# of the three returned loaders is used below.
d = os.path.expanduser('~/wiflow-std-bench/preprocessed_csi_data')
dataset = PreprocessedCSIKeypointsDataset(
    data_dir=d,
    keypoint_scale=1000.0,
    enable_temporal_clean=True,
)
_train_loader, _val_loader, test_loader = create_preprocessed_train_val_test_loaders(
    dataset=dataset,
    batch_size=256,
    num_workers=2,
    random_seed=42,
)

# --- model ------------------------------------------------------------------
device = torch.device('cuda')
ckpt = find_checkpoint()
print('checkpoint:', ckpt)
model = WiFlowPoseModel(dropout=0.5).to(device)

# Checkpoint key prefix -> prefix used for the strict load below. Only the
# first matching prefix is applied; keys without a match are kept verbatim.
renames = {'att.': 'attention.', 'final_conv.': 'decoder.'}
state = {}
for _key, _tensor in torch.load(ckpt, map_location=device, weights_only=True).items():
    for _old, _new in renames.items():
        if _key.startswith(_old):
            _key = _new + _key[len(_old):]
            break
    state[_key] = _tensor
# strict=True: any missing or unexpected key after remapping is an error.
model.load_state_dict(state, strict=True)
results = {'checkpoint': ckpt}

# Pass 1: the complete test split.
print('=== full test set ===')
_full_metrics = evaluate(model, test_loader, device)
results['test_full'] = _full_metrics
print(json.dumps(_full_metrics, indent=2))

# Pass 2: NaN-free subset -- exclude windows from corrupted files 487-499.
_FIRST_CORRUPT_FILE = 487
test_subset = test_loader.dataset  # Subset(dataset, test_indices)
w2f = dataset.window_to_file  # indexed by window index, compared as a file index
_all_test_idx = test_subset.indices
clean_idx = [idx for idx in _all_test_idx if w2f[idx] < _FIRST_CORRUPT_FILE]
_n_clean, _n_total = len(clean_idx), len(_all_test_idx)
print(f'=== NaN-free test subset ({_n_clean} of {_n_total}) ===')
clean_loader = DataLoader(Subset(dataset, clean_idx), batch_size=256, shuffle=False)
_clean_metrics = evaluate(model, clean_loader, device)
results['test_clean'] = _clean_metrics
print(json.dumps(_clean_metrics, indent=2))

# Persist both result sets, plus the checkpoint path, next to the benchmark data.
out = os.path.expanduser('~/wiflow-std-bench/eval_retrained.json')
with open(out, 'w') as _report:
    json.dump(results, _report, indent=2)
print('wrote', out)