wifi-densepose/examples/through-wall/wiflow_ab.py

127 lines
6.4 KiB
Python

#!/usr/bin/env python3
"""Rigorous A/B for WiFlow CSI->pose: is the held-out PCK real signal or split leakage?
For a dataset of {csi:[D], kps:17x[x,y,vis]} pairs, train the SAME small MLP under
several train/val SPLITS and report held-out PCK@0.10 vs the mean-pose baseline:
- chronological_80_20 : last 20% in time (val temporally ADJACENT to train -> leaks
via CSI/pose autocorrelation; this is what gave us +9.4)
- random_80_20 : shuffled (val frames interleaved with train -> MAX leak)
- blocked_gap : hold out a contiguous MIDDLE block with a time GAP buffer on
each side so val is NOT adjacent to any train frame -> the
honest, leakage-controlled test
If the model beats baseline on chronological/random but COLLAPSES to ~baseline on
blocked_gap, the apparent signal was temporal leakage, not generalizable CSI->pose.
Usage (ruvultra venv): python wiflow_ab.py --data ~/wiflow-room/dataset.jsonl
"""
import argparse, json, sys
import numpy as np, torch, torch.nn as nn
def _rec(r, X, Y, V, B):
X.append(r["csi"]); kp=r["kps"]
if kp and isinstance(kp[0], (list,tuple)): # 17 x [x,y(,vis)]
Y.append([c for k in kp for c in (k[0],k[1])]); V.append([(k[2] if len(k)>2 else 1.0) for k in kp])
else: # flat 34 (browser export, no vis)
Y.append(list(kp)); V.append([1.0]*17)
B.append(r.get("bucket"))
def load(path):
X,Y,V,B=[],[],[],[]
txt=open(path).read().strip()
if txt[:1] in "[{": # JSON (browser export: dict{samples:[]} or bare array)
d=json.loads(txt)
rows = d if isinstance(d,list) else d.get("samples", d.get("data", []))
for r in rows: _rec(r,X,Y,V,B)
else: # JSONL (python capture)
for line in txt.splitlines():
if line.strip(): _rec(json.loads(line),X,Y,V,B)
return np.array(X,np.float32), np.array(Y,np.float32), np.array(V,np.float32), B
class Net(nn.Module):
def __init__(s,din,dout):
super().__init__()
s.n=nn.Sequential(nn.Linear(din,384),nn.ReLU(),nn.Dropout(.35),
nn.Linear(384,192),nn.ReLU(),nn.Dropout(.35),
nn.Linear(192,96),nn.ReLU(),nn.Linear(96,dout),nn.Sigmoid())
def forward(s,x): return s.n(x)
def pck(pred,gt,vis,thr=0.10):
p=pred.reshape(-1,17,2); g=gt.reshape(-1,17,2)
d=np.linalg.norm(p-g,axis=2); m=vis>0.5
return float((d[m]<thr).mean()) if m.any() else 0.0
def split_idx(n, kind, B=None):
idx=np.arange(n)
if kind=="chronological_80_20":
c=int(n*.8); return idx[:c], idx[c:]
if kind=="random_80_20":
rng=np.random.default_rng(0); p=rng.permutation(n); c=int(n*.8); return p[:c], p[c:]
if kind=="blocked_gap":
# val = contiguous middle 20%; a WIDE 10% time gap each side guarantees no train
# frame is temporally adjacent to a val frame (kills frame-autocorrelation leakage).
v0=int(n*.4); v1=int(n*.6); gap=int(n*.10)
val=idx[v0:v1]; train=np.concatenate([idx[:max(0,v0-gap)], idx[min(n,v1+gap):]])
return train, val
if kind=="grouped_bucket":
# hold out ENTIRE activity buckets -> val poses/activities never seen in train.
# the strictest leakage-free test (only when bucket labels exist).
b=np.array([x if x is not None else -1 for x in B])
uniq=[u for u in sorted(set(b.tolist())) if u!=-1]
if len(uniq)<3: raise ValueError("too few buckets")
hold=set(uniq[::max(1,len(uniq)//3)][:max(1,len(uniq)//3)]) # ~1/3 of activities held out
val=idx[np.isin(b,list(hold))]; train=idx[~np.isin(b,list(hold))]
return train, val
raise ValueError(kind)
def run(X,Y,V,tr,va,epochs=250,seed=0):
torch.manual_seed(seed); np.random.seed(seed) # seed weight init + batch shuffle
dev="cuda" if torch.cuda.is_available() else "cpu"
mu,sd=X[tr].mean(0),X[tr].std(0)+1e-6
Xtr=torch.tensor((X[tr]-mu)/sd).to(dev); Ytr=torch.tensor(Y[tr]).to(dev)
Xva=torch.tensor((X[va]-mu)/sd).to(dev)
net=Net(X.shape[1],Y.shape[1]).to(dev)
opt=torch.optim.Adam(net.parameters(),lr=1e-3,weight_decay=1e-4); lf=nn.MSELoss()
best=(1e9,None)
for ep in range(epochs):
net.train(); perm=torch.randperm(len(Xtr),device=dev)
for i in range(0,len(Xtr),64):
j=perm[i:i+64]; opt.zero_grad(); loss=lf(net(Xtr[j]),Ytr[j]); loss.backward(); opt.step()
net.eval()
with torch.no_grad(): pv=net(Xva).cpu().numpy()
vl=float(((pv-Y[va])**2).mean())
if vl<best[0]: best=(vl,pv)
base=np.tile(Y[tr].mean(0),(len(va),1))
return pck(best[1],Y[va],V[va]), pck(base,Y[va],V[va])
def main():
ap=argparse.ArgumentParser(); ap.add_argument("--data",required=True)
ap.add_argument("--epochs",type=int,default=250); ap.add_argument("--seeds",type=int,default=3)
a=ap.parse_args()
X,Y,V,B=load(a.data); n=len(X)
has_buckets=any(x is not None for x in B)
print(f"[ab] {n} samples, X={X.shape}, buckets={'yes' if has_buckets else 'no'}, "
f"seeds={a.seeds}, epochs={a.epochs}\n")
print(f"{'split':<22}{'model PCK@0.10':>16}{'baseline':>11}{'delta (mean±sd)':>20} verdict")
print("-"*86)
splits=["chronological_80_20","random_80_20","blocked_gap"]+(["grouped_bucket"] if has_buckets else [])
for kind in splits:
try:
tr,va=split_idx(n,kind,B)
ms=[]; bs=[]
for s in range(a.seeds):
m,b=run(X,Y,V,tr,va,a.epochs,seed=s); ms.append(m); bs.append(b)
ms=np.array(ms)*100; bs=np.array(bs)*100; ds=ms-bs
dm,dsd=ds.mean(),ds.std()
# REAL only if the mean delta minus 1 sd still clears the 1.5pp threshold (robust to seed variance)
verdict = "REAL signal" if dm-dsd>1.5 else ("weak/uncertain" if dm>1.5 else "no signal (==baseline)")
print(f"{kind:<22}{ms.mean():>13.1f}±{ms.std():>3.1f}{bs.mean():>10.1f}%{dm:>+12.1f}±{dsd:>4.1f}pp {verdict}")
except Exception as e:
print(f"{kind:<22} skipped: {e}")
print(f"\nmean±sd over {a.seeds} seeds (weight init + batch order). blocked_gap = 10% time gap each")
print("side; grouped_bucket holds out ENTIRE activities (strictest). If only the LEAKY splits")
print("(chronological/random) beat baseline, the apparent signal is leakage, not generalizable pose.")
if __name__=="__main__": main()