113 lines
4.6 KiB
Python
113 lines
4.6 KiB
Python
"""Evaluate the retrained WiFlow-STD checkpoint (ADR-152 §2.2a fallback).

Scores the model produced by run.py (train_output/best_pose_model.pth or similar)
on the seed-42 test split: the full test set AND a NaN-free subset that excludes
the windows zero-filled by clean_nan.py (file indices 487-499).

NOTE: deployed to ruvultra (~/wiflow-std-bench) as a standalone single file,
so it deliberately inlines its helpers. The reference implementations (upstream
import shim, >1GB np.load mmap patch, key-remap loader, canonical evaluate
loop) live in benchmarks/wiflow-std/_bench_common.py — keep copies in sync.
"""
|
|
import json, os, random, sys
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import DataLoader, Subset
|
|
|
|
# csi_windows.npy is ~13 GB; mmap large arrays instead of eagerly loading
# ~15 GB into RAM (same patch as _bench_common._np_load_mmap).
_np_load = np.load


def _np_load_mmap(path, *a, **kw):
    """Drop-in replacement for ``np.load`` that memory-maps large ``.npy`` files.

    When *path* is a ``str`` or ``os.PathLike`` naming a ``.npy`` file larger
    than 1 GiB, and the caller has not chosen a ``mmap_mode`` (neither by
    keyword nor as the first positional argument after *path*), the file is
    opened read-only with ``mmap_mode='r'``. Every other call, including file
    objects and bytes paths, is forwarded to the original ``np.load`` unchanged.
    """
    # A positional argument after `path` is np.load's `mmap_mode`, so the
    # caller already decided; injecting the keyword too would raise TypeError.
    if isinstance(path, (str, os.PathLike)) and not a and 'mmap_mode' not in kw:
        fspath = os.fspath(path)
        if (isinstance(fspath, str) and fspath.endswith('.npy')
                and os.path.getsize(fspath) > 1 << 30):
            kw['mmap_mode'] = 'r'
    return _np_load(path, *a, **kw)


np.load = _np_load_mmap
|
|
|
|
sys.path.insert(0, os.path.expanduser('~/wiflow-std-bench/upstream'))
|
|
from dataset import PreprocessedCSIKeypointsDataset, create_preprocessed_train_val_test_loaders
|
|
from models.pose_model import WiFlowPoseModel
|
|
from utils.metrics import calculate_pck, calculate_mpjpe
|
|
|
|
|
|
def find_checkpoint(base_dir='~/wiflow-std-bench'):
    """Return the path of the most recently modified retrained checkpoint.

    Candidates are:
      * every ``.pth`` file anywhere under ``<base_dir>/train_output``;
      * every ``.pth`` file under ``<base_dir>/upstream`` whose file name
        contains ``best``, whose directory path does not contain
        ``cross_dataset``, and whose mtime is later than two days before the
        mtime of ``<base_dir>/train.log``. If ``train.log`` does not exist,
        recency cannot be judged and no upstream candidates are considered.

    Any candidate ending in ``upstream/best_pose_model.pth`` is excluded.

    Args:
        base_dir: benchmark root; ``~`` is expanded.

    Returns:
        Path of the candidate with the newest mtime.

    Raises:
        SystemExit: when no candidate is found.
    """
    base = os.path.expanduser(base_dir)
    cands = []
    for root, _, files in os.walk(os.path.join(base, 'train_output')):
        cands.extend(os.path.join(root, f) for f in files if f.endswith('.pth'))

    # Also look in the upstream/ tree (the training script's default output
    # dir), but only at files written around the retraining run. The cutoff
    # is computed once instead of once per candidate file.
    try:
        cutoff = os.path.getmtime(os.path.join(base, 'train.log')) - 86400 * 2
    except OSError:
        cutoff = None  # no training log: cannot tell fresh from stale, skip
    if cutoff is not None:
        for root, _, files in os.walk(os.path.join(base, 'upstream')):
            if 'cross_dataset' in root:
                continue
            for f in files:
                if f.endswith('.pth') and 'best' in f:
                    p = os.path.join(root, f)
                    if os.path.getmtime(p) > cutoff:
                        cands.append(p)

    # upstream/best_pose_model.pth is presumably the stock upstream weights,
    # not a retrained model, so it is never scored here.
    cands = [c for c in cands if not c.endswith('upstream/best_pose_model.pth')]
    if not cands:
        sys.exit('no retrained checkpoint found')
    return max(cands, key=os.path.getmtime)
|
|
|
|
|
|
def evaluate(model, loader, device, thresholds=(0.1, 0.2, 0.3, 0.4, 0.5)):
    """Run *model* over *loader* and return sample-weighted MPJPE and PCK.

    Args:
        model: pose model; switched to eval mode and run under ``torch.no_grad``.
        loader: iterable of ``(inputs, targets)`` batches; ``targets.size(0)``
            is used as the batch size.
        device: device each batch is moved to before the forward pass.
        thresholds: PCK thresholds forwarded to ``calculate_pck``; each result
            is reported under the key ``'pck@<round(100 * t)>'``.

    Returns:
        ``{'samples': n, 'mpjpe': ..., 'pck@10': ..., ...}``. Each batch's
        metric values are weighted by its batch size before averaging. This
        assumes ``calculate_mpjpe`` / ``calculate_pck`` return per-batch means
        (TODO confirm in utils.metrics). If the loader yields no samples, the
        metrics are ``None`` instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    totals = {t: 0.0 for t in thresholds}
    total_mpe, n = 0.0, 0
    with torch.no_grad():
        for bx, by in loader:
            bx, by = bx.to(device), by.to(device)
            out = model(bx)
            bs = by.size(0)
            total_mpe += calculate_mpjpe(out, by) * bs
            pck = calculate_pck(out, by, thresholds=list(totals))
            for t in totals:
                totals[t] += pck[t] * bs
            n += bs

    # round(), not int(): e.g. int(0.29 * 100) == 28.
    def key(t):
        return f'pck@{round(t * 100)}'

    if n == 0:
        return {'samples': 0, 'mpjpe': None, **{key(t): None for t in totals}}
    return {'samples': n, 'mpjpe': total_mpe / n,
            **{key(t): totals[t] / n for t in totals}}
|
|
|
|
|
|
# Seed Python, NumPy, torch (CPU and all CUDA devices) with the same value and
# request deterministic cuDNN kernels, so repeated runs score identically.
_SEED = 42
for _seed_all in (random.seed, np.random.seed, torch.manual_seed,
                  torch.cuda.manual_seed_all):
    _seed_all(_SEED)
torch.backends.cudnn.deterministic = True

data_dir = os.path.expanduser('~/wiflow-std-bench/preprocessed_csi_data')
dataset = PreprocessedCSIKeypointsDataset(
    data_dir=data_dir,
    keypoint_scale=1000.0,
    enable_temporal_clean=True,
)
# Only the seed-42 test split is scored; the train/val loaders are discarded.
_, _, test_loader = create_preprocessed_train_val_test_loaders(
    dataset=dataset,
    batch_size=256,
    num_workers=2,
    random_seed=_SEED,
)
|
|
|
|
def _remap_checkpoint_keys(state_dict, prefix_map):
    """Return a copy of *state_dict* with legacy key prefixes rewritten.

    For each key, the prefixes in *prefix_map* are tried in insertion order and
    the first one that matches is replaced by its mapped value. Keys that match
    no prefix are kept unchanged.
    """
    remapped = {}
    for key, tensor in state_dict.items():
        for old_prefix, new_prefix in prefix_map.items():
            if key.startswith(old_prefix):
                key = new_prefix + key[len(old_prefix):]
                break
        remapped[key] = tensor
    return remapped


device = torch.device('cuda')
ckpt = find_checkpoint()
print('checkpoint:', ckpt)
model = WiFlowPoseModel(dropout=0.5).to(device)
# The checkpoint's parameter names may use the ``att.`` / ``final_conv.``
# prefixes; map them onto ``attention.`` / ``decoder.`` so that the strict
# load below accepts them.
state = _remap_checkpoint_keys(
    torch.load(ckpt, map_location=device, weights_only=True),
    {'att.': 'attention.', 'final_conv.': 'decoder.'},
)
model.load_state_dict(state, strict=True)
|
|
|
|
# clean_nan.py zero-filled the windows of these file indices (inclusive range).
_NAN_FILES_FIRST, _NAN_FILES_LAST = 487, 499

results = {'checkpoint': ckpt}
print('=== full test set ===')
results['test_full'] = evaluate(model, test_loader, device)
print(json.dumps(results['test_full'], indent=2))

# NaN-free subset: exclude windows from corrupted files 487-499. Test the closed
# range explicitly; a bare "< 487" would also drop any file index past 499.
test_subset = test_loader.dataset  # Subset(dataset, test_indices)
w2f = dataset.window_to_file
clean_idx = [i for i in test_subset.indices
             if not _NAN_FILES_FIRST <= w2f[i] <= _NAN_FILES_LAST]
print(f'=== NaN-free test subset ({len(clean_idx)} of {len(test_subset.indices)}) ===')
# Same batch size and worker count as the create_preprocessed_train_val_test_loaders
# call that built the full test loader.
clean_loader = DataLoader(Subset(dataset, clean_idx), batch_size=256,
                          shuffle=False, num_workers=2)
results['test_clean'] = evaluate(model, clean_loader, device)
print(json.dumps(results['test_clean'], indent=2))

out = os.path.expanduser('~/wiflow-std-bench/eval_retrained.json')
with open(out, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)
print('wrote', out)
|