#!/usr/bin/env python3 """Live CSI->pose inference bridge (ADR-180). Runs on the box with the live CSI. Loads the camera-supervised model (numpy, no torch needed), subscribes to /ws/sensing, runs a forward pass per frame, and broadcasts the predicted 17-keypoint pose to HTML clients on ws://:8770/pose. python wiflow_infer.py --model model/model.npz \ --in ws://localhost:8765/ws/sensing --port 8770 """ import argparse, asyncio, json, os import numpy as np import websockets # COCO skeleton edges (for the client; sent once in 'meta') EDGES = [[5,7],[7,9],[6,8],[8,10],[5,6],[11,12],[5,11],[6,12], [11,13],[13,15],[12,14],[14,16],[0,1],[0,2],[1,3],[2,4],[0,5],[0,6]] def csi_vector(frame): f = frame.get("features", {}) or {} feats = [f.get("mean_rssi",0.0), f.get("variance",0.0), f.get("motion_band_power",0.0), f.get("breathing_band_power",0.0)] pernode = {nf.get("node_id"): (nf.get("features") or {}) for nf in (frame.get("node_features") or [])} for nid in (9,13): nf = pernode.get(nid,{}); feats += [nf.get("mean_rssi",0.0), nf.get("variance",0.0), nf.get("motion_band_power",0.0)] field = (frame.get("signal_field",{}) or {}).get("values") or [] field = (field + [0.0]*400)[:400] return np.array(feats + field, np.float32) class Model: def __init__(self, path): z = np.load(path) self.mu, self.sd = z["mu"], z["sd"] self.W = [z["net_0_weight"], z["net_3_weight"], z["net_6_weight"], z["net_8_weight"]] self.b = [z["net_0_bias"], z["net_3_bias"], z["net_6_bias"], z["net_8_bias"]] def __call__(self, x): h = (x - self.mu) / self.sd for i in range(3): h = np.maximum(0.0, h @ self.W[i].T + self.b[i]) # Linear+ReLU out = 1.0/(1.0+np.exp(-(h @ self.W[3].T + self.b[3]))) # Linear+Sigmoid -> 34 return out.reshape(17,2) CLIENTS = set() LATEST = {"pose": None} async def serve_client(ws): CLIENTS.add(ws) try: await ws.send(json.dumps({"type":"meta","edges":EDGES})) async for _ in ws: # client is read-only; just keep alive pass except Exception: pass finally: CLIENTS.discard(ws) async def infer_loop(model, in_url): while True: try: async with websockets.connect(in_url, open_timeout=8, ping_interval=20) as ws: async for msg in ws: d = json.loads(msg) kp = model(csi_vector(d)) cls = d.get("classification",{}) payload = {"type":"pose","src":d.get("source"), "presence":bool(cls.get("presence")), "motion":(d.get("features",{}) or {}).get("motion_band_power"), "kps":[[round(float(x),4),round(float(y),4)] for x,y in kp], "nodes":sorted(n.get("node_id") for n in d.get("nodes",[]) if n.get("node_id") is not None)} LATEST["pose"]=payload if CLIENTS: dead=[] for c in list(CLIENTS): try: await c.send(json.dumps(payload)) except Exception: dead.append(c) for c in dead: CLIENTS.discard(c) except Exception as e: print(f"[infer] reconnect ({e})", flush=True); await asyncio.sleep(1.0) async def main(): ap = argparse.ArgumentParser() ap.add_argument("--model", default=os.path.join(os.path.dirname(__file__),"model","model.npz")) ap.add_argument("--in", dest="in_url", default="ws://localhost:8765/ws/sensing") ap.add_argument("--port", type=int, default=8770) args = ap.parse_args() model = Model(args.model) print(f"[infer] model {args.model} loaded; serving predicted poses on ws://0.0.0.0:{args.port}/pose") async with websockets.serve(serve_client, "0.0.0.0", args.port): await infer_loop(model, args.in_url) if __name__ == "__main__": asyncio.run(main())