#!/usr/bin/env python3
"""Live CSI->pose inference bridge (ADR-180).

Runs on the box with the live CSI. Loads the camera-supervised model (numpy,
no torch needed), subscribes to /ws/sensing, runs a forward pass per frame, and
broadcasts the predicted 17-keypoint pose to HTML clients on
ws://<host>:8770/pose.

    python wiflow_infer.py --model model/model.npz \
        --in ws://localhost:8765/ws/sensing --port 8770
"""
|
|
import argparse, asyncio, json, os
|
|
import numpy as np
|
|
import websockets
|
|
|
|
# Skeleton connectivity for the 17-keypoint pose, as [from, to] keypoint-index
# pairs (COCO skeleton). Clients get this once, in the 'meta' message, and use
# it to draw limbs between the predicted keypoints.
# NOTE(review): the per-line labels assume the standard COCO-17 index order
# (0 nose, 5/6 shoulders, 11/12 hips, ...) -- confirm against the training data.
EDGES = [
    [5, 7], [7, 9],            # left arm
    [6, 8], [8, 10],           # right arm
    [5, 6],                    # shoulder line
    [11, 12],                  # hip line
    [5, 11], [6, 12],          # torso sides
    [11, 13], [13, 15],        # left leg
    [12, 14], [14, 16],        # right leg
    [0, 1], [0, 2],            # nose -> eyes
    [1, 3], [2, 4],            # eyes -> ears
    [0, 5], [0, 6],            # nose -> shoulders
]
|
|
|
|
def csi_vector(frame, node_ids=(9, 13), field_len=400):
    """Flatten one sensing frame (a decoded JSON object) into the model input.

    Layout of the returned 1-D float32 vector:
      * 4 global features from ``frame["features"]``: mean_rssi, variance,
        motion_band_power, breathing_band_power;
      * for each id in ``node_ids`` (in order), 3 per-node features taken from
        the matching ``frame["node_features"]`` entry: mean_rssi, variance,
        motion_band_power;
      * ``frame["signal_field"]["values"]`` zero-padded or truncated to
        exactly ``field_len`` entries.

    Missing keys, missing nodes and JSON ``null`` values all contribute 0.0,
    so the vector length is always ``4 + 3 * len(node_ids) + field_len``
    (410 with the defaults).

    Args:
        frame: dict decoded from a sensing message.
        node_ids: node ids whose per-node features are included, in order.
            The default matches the ids that were hard-coded originally.
        field_len: number of signal-field values in the vector.

    Returns:
        ``np.ndarray`` of dtype float32.
    """
    def num(mapping, key):
        # A key that is present but null must behave like a missing key;
        # otherwise None reaches np.array(..., float32) and raises TypeError.
        value = mapping.get(key)
        return 0.0 if value is None else value

    overall = frame.get("features") or {}
    feats = [num(overall, key) for key in
             ("mean_rssi", "variance", "motion_band_power", "breathing_band_power")]

    # Index per-node feature dicts by node id; skip malformed (non-dict) entries.
    per_node = {
        entry.get("node_id"): (entry.get("features") or {})
        for entry in (frame.get("node_features") or [])
        if isinstance(entry, dict)
    }
    for node_id in node_ids:
        node = per_node.get(node_id, {})
        feats += [num(node, key) for key in ("mean_rssi", "variance", "motion_band_power")]

    values = (frame.get("signal_field") or {}).get("values") or []
    field = [0.0 if v is None else v for v in values[:field_len]]
    field += [0.0] * (field_len - len(field))  # pad short fields with zeros

    return np.array(feats + field, np.float32)
|
|
|
|
class Model:
    """NumPy-only forward pass of the exported pose network.

    The ``.npz`` archive must contain the input normalisation statistics
    (``mu``, ``sd``) and four linear layers stored as ``net_{0,3,6,8}_weight``
    / ``net_{0,3,6,8}_bias``. The first three layers are followed by ReLU, the
    last by a sigmoid; its 34 outputs are returned as a (17, 2) array of
    values in [0, 1].
    """

    _LAYER_IDS = (0, 3, 6, 8)

    def __init__(self, path):
        """Load normalisation statistics and layer parameters from ``path``."""
        # np.load on an .npz returns a lazy NpzFile that keeps the archive
        # open; read every array inside the context so the handle is closed.
        with np.load(path) as z:
            self.mu, self.sd = z["mu"], z["sd"]
            self.W = [z[f"net_{i}_weight"] for i in self._LAYER_IDS]
            self.b = [z[f"net_{i}_bias"] for i in self._LAYER_IDS]
        # A feature with zero spread would turn (x - mu) / sd into NaN/inf and
        # poison the whole pose; dividing by 1 leaves that feature centred.
        self._scale = np.where(self.sd == 0, np.ones_like(self.sd), self.sd)

    def __call__(self, x):
        """Map one input vector to 17 (x, y) keypoints.

        Args:
            x: 1-D feature vector with the same length as ``mu``.

        Returns:
            Array of shape (17, 2) holding the sigmoid outputs.
        """
        h = (x - self.mu) / self._scale
        *hidden, (w_out, b_out) = zip(self.W, self.b)
        for w, b in hidden:
            h = np.maximum(0.0, h @ w.T + b)  # Linear + ReLU
        logits = h @ w_out.T + b_out
        # exp() overflows to inf for very negative logits; 1 / (1 + inf) is
        # still the correct 0.0, so only the RuntimeWarning is suppressed.
        with np.errstate(over="ignore"):
            out = 1.0 / (1.0 + np.exp(-logits))  # Linear + Sigmoid -> 34
        return out.reshape(17, 2)
|
|
|
|
# Currently connected downstream websocket clients. serve_client adds and
# removes entries; infer_loop sends every pose payload to each member.
CLIENTS = set()
# Most recent pose payload built by infer_loop (None until the first frame).
# NOTE(review): nothing in the visible source reads this back -- presumably
# kept for late joiners or debugging; confirm before relying on it.
LATEST = {"pose": None}
|
|
|
|
async def serve_client(ws):
    """Handle one downstream pose client for the lifetime of its connection.

    Registers the socket in CLIENTS, sends the one-off 'meta' message carrying
    the skeleton EDGES, then idles until the client goes away. The socket is
    always removed from CLIENTS on exit.

    Args:
        ws: the server-side websocket connection passed in by websockets.serve.
    """
    CLIENTS.add(ws)
    try:
        await ws.send(json.dumps({"type": "meta", "edges": EDGES}))
        # Clients are read-only: anything they send is discarded. Iterating
        # simply keeps this handler alive until the connection closes.
        async for _ in ws:
            pass
    except websockets.ConnectionClosed:
        pass  # a client dropping its connection is the normal way to leave
    except Exception as e:
        # Still best-effort (never propagate), but do not hide real bugs.
        print(f"[serve] client error ({e})", flush=True)
    finally:
        CLIENTS.discard(ws)
|
|
|
|
def _pose_payload(model, d):
    """Run the model on one decoded sensing frame and build the 'pose' message."""
    kp = model(csi_vector(d))
    # `or {}` / `or []` also cover keys that are present but JSON null, the
    # same guard already applied to "features".
    cls = d.get("classification") or {}
    return {
        "type": "pose",
        "src": d.get("source"),
        "presence": bool(cls.get("presence")),
        "motion": (d.get("features") or {}).get("motion_band_power"),
        "kps": [[round(float(x), 4), round(float(y), 4)] for x, y in kp],
        "nodes": sorted(n.get("node_id") for n in (d.get("nodes") or [])
                        if n.get("node_id") is not None),
    }


async def _broadcast(message):
    """Send an already-serialised message to every client; drop the ones that fail."""
    dead = []
    for client in list(CLIENTS):  # snapshot: CLIENTS can change while we await
        try:
            await client.send(message)
        except Exception:
            dead.append(client)
    for client in dead:
        CLIENTS.discard(client)


async def infer_loop(model, in_url):
    """Consume sensing frames from ``in_url`` forever and broadcast predicted poses.

    For every JSON object received, the model is run on its feature vector,
    the resulting payload is stored in LATEST["pose"] and sent to all members
    of CLIENTS. Any error (connect failure, malformed frame, model error)
    drops the upstream connection; the loop then waits one second and
    reconnects. This coroutine never returns.

    Args:
        model: callable mapping a feature vector to a (17, 2) keypoint array.
        in_url: websocket URL of the sensing stream.
    """
    while True:
        try:
            async with websockets.connect(in_url, open_timeout=8, ping_interval=20) as ws:
                async for msg in ws:
                    d = json.loads(msg)
                    if not isinstance(d, dict):
                        continue  # not a sensing frame; nothing to infer from
                    payload = _pose_payload(model, d)
                    LATEST["pose"] = payload
                    if CLIENTS:
                        # Serialise once per frame, not once per client.
                        await _broadcast(json.dumps(payload))
        except Exception as e:
            print(f"[infer] reconnect ({e})", flush=True)
        # Back off after *every* disconnect. A clean upstream close raises
        # nothing, and without this pause the loop would reconnect in a
        # tight spin.
        await asyncio.sleep(1.0)
|
|
|
|
async def main():
    """Parse CLI options, load the model, then serve poses while inferring.

    Starts the downstream websocket server (serve_client handles each client)
    and runs infer_loop inside its context, so the server stays up for as long
    as inference is running.
    """
    ap = argparse.ArgumentParser(description="Live CSI->pose inference bridge")
    ap.add_argument("--model",
                    default=os.path.join(os.path.dirname(__file__), "model", "model.npz"),
                    help="path to the exported .npz model")
    ap.add_argument("--in", dest="in_url", default="ws://localhost:8765/ws/sensing",
                    help="websocket URL of the sensing stream to subscribe to")
    # The bind address used to be hard-coded; the default keeps that behaviour.
    ap.add_argument("--host", default="0.0.0.0",
                    help="interface to bind the pose server to")
    ap.add_argument("--port", type=int, default=8770,
                    help="port of the pose server")
    args = ap.parse_args()

    model = Model(args.model)
    # flush=True like the other log line, so the message is not stuck in a
    # buffer when stdout is a pipe or log file.
    print(f"[infer] model {args.model} loaded; serving predicted poses on "
          f"ws://{args.host}:{args.port}/pose", flush=True)
    async with websockets.serve(serve_client, args.host, args.port):
        await infer_loop(model, args.in_url)
|
|
|
|
# Script entry point: start an event loop and run main() in it. infer_loop
# loops forever, so in practice this runs until the process is interrupted.
if __name__ == "__main__":
    asyncio.run(main())
|