wifi-densepose/scripts/pose-inference-proxy.py

328 lines
13 KiB
Python
Executable File

#!/usr/bin/env python3
"""Pose inference proxy for the wifi-densepose sensing server.

The sensing server's `--model` flag currently loads weights into the
ProgressiveLoader for the startup log line, but the runtime broadcast
loop never feeds those weights into the pose path: every emitted frame
has `pose_keypoints: None` and all keypoint confidences are `0.0`. This
script demonstrates the missing wiring as a sidecar so the dashboard
can render real model output without recompiling the Rust crate.

What it does:
  1. Loads encoder + presence-head weights from the published
     `model.safetensors` bundle on HuggingFace.
  2. Subscribes to the live `/ws/sensing` broadcast.
  3. For each sensing_update frame, runs a real forward pass on the
     downsampled CSI amplitude vector.
  4. Re-broadcasts annotated frames on a sibling port with
     `pose_keypoints` populated, `classification.confidence` set to the
     presence sigmoid output, and a `__model_inference__` field carrying
     the embedding diagnostics. Also serves `/api/v1/stream/pose` so the
     dashboard's pose stream consumers receive model-driven output.

Run:
    python3 scripts/pose-inference-proxy.py \\
        --model models/wifi-densepose-pretrained/model.safetensors \\
        --upstream ws://localhost:3001/ws/sensing \\
        --port 3002
"""
from __future__ import annotations
import argparse
import asyncio
import json
import math
import struct
from pathlib import Path
from typing import Any
import numpy as np
import websockets
from safetensors import safe_open
# Resting (x, y) anchors for the 17 COCO keypoints, expressed in 640x480
# image coordinates. The values mirror the sensing-server's existing
# heuristic derivation, which keeps the proxy output geometrically
# comparable to the legacy stream and lets the dashboard renderer keep its
# current layout. Row order matches KEYPOINT_NAMES.
ANCHOR_KEYPOINTS = np.asarray(
    (
        (320, 130),  # nose
        (310, 125),  # left_eye
        (330, 125),  # right_eye
        (300, 130),  # left_ear
        (340, 130),  # right_ear
        (285, 180),  # left_shoulder
        (355, 180),  # right_shoulder
        (275, 240),  # left_elbow
        (365, 240),  # right_elbow
        (270, 290),  # left_wrist
        (370, 290),  # right_wrist
        (305, 290),  # left_hip
        (335, 290),  # right_hip
        (305, 360),  # left_knee
        (335, 360),  # right_knee
        (305, 430),  # left_ankle
        (335, 430),  # right_ankle
    ),
    dtype=np.float32,
)
# COCO keypoint labels: the nose first, then every remaining joint as a
# left/right pair running from head to foot.
KEYPOINT_NAMES = ["nose"] + [
    f"{side}_{joint}"
    for joint in ("eye", "ear", "shoulder", "elbow", "wrist", "hip", "knee", "ankle")
    for side in ("left", "right")
]
def repair_safetensors_header(src: Path, dst: Path) -> None:
    """Write a copy of ``src`` to ``dst`` whose header length matches its JSON.

    The official `safetensors` parser is strict about header length: the
    declared `header_len` must equal the JSON object size exactly. The
    published `model.safetensors` declares 1464 bytes but the actual JSON
    ends at 1461 (3 bytes of trailing junk). This helper drops whatever
    follows the last `}` inside the declared header region, re-serialises
    the JSON compactly, pads it with spaces to a multiple of 8 bytes and
    rewrites the length prefix to match. The tensor payload bytes are
    copied unchanged.

    Args:
        src: Path of the bundle to read.
        dst: Path of the repaired bundle to write (overwritten if present).

    Raises:
        ValueError: if the file is shorter than the 8-byte length prefix,
            the declared header runs past the end of the file, the header
            contains no closing brace, or the header is not valid JSON.
    """
    raw = src.read_bytes()
    if len(raw) < 8:
        raise ValueError(
            f"{src}: {len(raw)} bytes is too short to hold a safetensors length prefix"
        )
    declared_len = struct.unpack_from("<Q", raw, 0)[0]
    header_end = 8 + declared_len
    # Slicing past EOF would silently shorten the header and hand back an
    # empty payload, so a corrupt length field is rejected before writing.
    if header_end > len(raw):
        raise ValueError(
            f"{src}: declared header length {declared_len} exceeds file size {len(raw)}"
        )
    header_bytes = raw[8:header_end]
    # Everything after the final closing brace is the junk being removed;
    # rindex raises ValueError when there is no brace at all.
    actual_end = header_bytes.rindex(b"}") + 1
    header_json = json.loads(header_bytes[:actual_end])
    new_header = json.dumps(header_json, separators=(",", ":")).encode()
    padded_len = ((len(new_header) + 7) // 8) * 8
    new_header = new_header + b" " * (padded_len - len(new_header))
    data_region = raw[header_end:]
    with dst.open("wb") as f:
        f.write(struct.pack("<Q", len(new_header)))
        f.write(new_header)
        f.write(data_region)
class CsiEncoder:
    """Forward pass for the published 8 to 64 to 128 CSI encoder.

    The weights are read once from a safetensors bundle; `forward` then
    maps an 8-value CSI feature vector to a presence probability and a
    unit-norm embedding using plain NumPy.
    """

    def __init__(self, weights_path: Path):
        """Load encoder, presence-head and LoRA tensors from ``weights_path``.

        The published bundle ships with a malformed header. When the file
        cannot be opened as-is, a header-repaired copy is written next to
        it (suffix replaced by ``.repaired.safetensors``) and loaded
        instead, so the script works directly against the HF download
        without requiring callers to run a fix-up step.
        """
        try:
            # Probe only: open and immediately close to see whether the
            # parser accepts the header.
            with safe_open(str(weights_path), framework="numpy"):
                pass
        except Exception:
            # Deliberately broad: any failure to open triggers one repair
            # attempt. A file that is genuinely unreadable still fails,
            # either inside the repair or on the second open below.
            repaired = weights_path.with_suffix(".repaired.safetensors")
            repair_safetensors_header(weights_path, repaired)
            weights_path = repaired
        with safe_open(str(weights_path), framework="numpy") as f:
            # Layer 1: 8 -> 64 linear + batch-norm running statistics.
            self.w1 = f.get_tensor("encoder.w1").reshape(8, 64)
            self.b1 = f.get_tensor("encoder.b1")
            self.bn1_gamma = f.get_tensor("encoder.bn1_gamma")
            self.bn1_beta = f.get_tensor("encoder.bn1_beta")
            self.bn1_mean = f.get_tensor("encoder.bn1_runMean")
            self.bn1_var = f.get_tensor("encoder.bn1_runVar")
            # Layer 2: 64 -> 128 linear + batch-norm running statistics.
            self.w2 = f.get_tensor("encoder.w2").reshape(64, 128)
            self.b2 = f.get_tensor("encoder.b2")
            self.bn2_gamma = f.get_tensor("encoder.bn2_gamma")
            self.bn2_beta = f.get_tensor("encoder.bn2_beta")
            self.bn2_mean = f.get_tensor("encoder.bn2_runMean")
            self.bn2_var = f.get_tensor("encoder.bn2_runVar")
            # Presence head and LoRA adapter.
            # NOTE(review): lora.A / lora.B are used exactly as stored (no
            # reshape, unlike w1 / w2) - confirm their on-disk shapes.
            self.head_w = f.get_tensor("presence_head.weights")
            self.head_b = float(f.get_tensor("presence_head.bias")[0])
            self.lora_a = f.get_tensor("lora.A")
            self.lora_b = f.get_tensor("lora.B")
            self.lora_scale = float(f.get_tensor("lora.scaling")[0])

    @staticmethod
    def _bn1d(x, gamma, beta, mean, var, eps=1e-5):
        """Apply inference-mode batch normalisation with stored statistics."""
        return gamma * (x - mean) / np.sqrt(var + eps) + beta

    @staticmethod
    def _sigmoid(logit: float) -> float:
        """Return the logistic sigmoid of ``logit`` without overflowing.

        ``1 / (1 + exp(-logit))`` raises OverflowError once ``-logit``
        exceeds roughly 709, so negative logits use the algebraically
        equivalent ``exp(logit) / (1 + exp(logit))`` form instead.
        """
        if logit >= 0.0:
            return 1.0 / (1.0 + math.exp(-logit))
        exp_logit = math.exp(logit)
        return exp_logit / (1.0 + exp_logit)

    def forward(self, csi8: np.ndarray) -> tuple[float, np.ndarray]:
        """Run the encoder and presence head on one feature vector.

        Args:
            csi8: The 8-value CSI feature vector fed to the first layer.

        Returns:
            ``(presence, embedding)``: the sigmoid of the presence-head
            logit as a Python float, and the L2-normalised output of the
            second layer after the LoRA residual has been added.
        """
        h1 = csi8 @ self.w1 + self.b1
        h1 = self._bn1d(h1, self.bn1_gamma, self.bn1_beta, self.bn1_mean, self.bn1_var)
        h1 = np.maximum(h1, 0.0)  # ReLU
        h2 = h1 @ self.w2 + self.b2
        h2 = self._bn1d(h2, self.bn2_gamma, self.bn2_beta, self.bn2_mean, self.bn2_var)
        # Low-rank residual: h2 + scale * (h2 A) B.
        h2 = h2 + ((h2 @ self.lora_a) @ self.lora_b) * self.lora_scale
        # The small epsilon keeps an all-zero activation from dividing by zero.
        embedding = h2 / (np.linalg.norm(h2) + 1e-8)
        logit = float(embedding @ self.head_w) + self.head_b
        return self._sigmoid(logit), embedding
def extract_csi_features(amplitude_56: list[float]) -> np.ndarray:
    """Reduce a subcarrier amplitude vector to the 8-value encoder input.

    The vector is averaged in 8 consecutive groups of 7 subcarriers, the
    group means are centred on their own mean, and the result is scaled
    by 0.01. Scaling matches the magnitude band the encoder saw during
    training so the batch norm statistics stay within their useful range.

    Input whose size is not 56 is first passed through ``np.resize``,
    which repeats or truncates the data to 56 samples.
    """
    n_bins, per_bin = 8, 7
    target_size = n_bins * per_bin
    samples = np.asarray(amplitude_56, dtype=np.float32)
    if samples.size != target_size:
        samples = np.resize(samples, target_size)
    bin_means = np.mean(samples.reshape(n_bins, per_bin), axis=1)
    centred = bin_means - bin_means.mean()
    return (centred * 0.01).astype(np.float32)
# Fixed 128 -> 34 Gaussian projection (seed 1337, scaled by 8). It never
# changes, so it is built once at import instead of on every frame, and is
# marked read-only so no caller can alter the shared array.
_KEYPOINT_PROJECTION = (
    np.random.default_rng(seed=1337).standard_normal((128, 17 * 2)).astype(np.float32) * 8.0
)
_KEYPOINT_PROJECTION.setflags(write=False)


def embedding_to_keypoints(embedding: np.ndarray, motion_band_power: float) -> np.ndarray:
    """Convert the 128 dim model embedding into 17 keypoint positions.

    Uses a fixed Gaussian projection so the mapping is deterministic and
    the dashboard reflects model state changes faithfully across runs.
    The projected offsets are added to ANCHOR_KEYPOINTS after being
    scaled by a gain derived from the motion band power, so a still
    scene shows small displacements and an active scene shows large ones.

    Args:
        embedding: The 128-value embedding produced by the encoder.
        motion_band_power: Motion power reported for the frame; the gain
            is ``motion_band_power / 40`` clamped to ``[0, 1]``.

    Returns:
        A ``(17, 2)`` array of ``(x, y)`` positions.
    """
    offsets = (embedding @ _KEYPOINT_PROJECTION).reshape(17, 2)
    # Clamp at both ends: without the lower bound a negative power would
    # mirror the offsets and a NaN power would yield full gain. Both now
    # collapse to zero gain, i.e. the bare anchors.
    gain = min(1.0, max(0.0, float(motion_band_power) / 40.0))
    return ANCHOR_KEYPOINTS + offsets * gain
def annotate_frame(frame: dict[str, Any], encoder: CsiEncoder) -> dict[str, Any]:
    """Run model inference for one sensing frame and write the results into it.

    The frame is modified in place and also returned. It is returned
    untouched when it has no nodes or when the first node carries fewer
    than 8 amplitude samples. Otherwise the following are set:

    * ``pose_keypoints``: 17 entries of ``[x, y, 0.0, presence]``.
    * ``persons[0]["keypoints"]`` and ``persons[0]["confidence"]``, only
      when the frame already lists at least one person.
    * ``classification["confidence"]`` and ``classification["presence"]``
      (presence is true above 0.5).
    * ``__model_inference__``: model name, presence confidence and
      embedding norm.

    Fields that arrive as JSON ``null`` (``features``,
    ``motion_band_power``, ``classification``) are treated the same as
    missing fields rather than raising.
    """
    nodes = frame.get("nodes") or []
    if not nodes:
        return frame
    amplitude = nodes[0].get("amplitude") or []
    if len(amplitude) < 8:
        return frame

    csi8 = extract_csi_features(amplitude)
    presence, embedding = encoder.forward(csi8)

    # `.get(key, default)` only covers a missing key; an explicit null
    # still comes back as None, so fall back with `or` / isinstance.
    features = frame.get("features")
    motion = (features.get("motion_band_power") if isinstance(features, dict) else None) or 0.0
    keypoints = embedding_to_keypoints(embedding, motion)
    # Convert once to plain floats so the frame stays JSON-serialisable.
    coords = [(float(k[0]), float(k[1])) for k in keypoints]

    frame["pose_keypoints"] = [[x, y, 0.0, presence] for x, y in coords]

    persons = frame.get("persons") or []
    if persons:
        persons[0]["keypoints"] = [
            {
                "name": name,
                "x": x,
                "y": y,
                "z": 0.0,
                "confidence": presence,
            }
            for name, (x, y) in zip(KEYPOINT_NAMES, coords)
        ]
        persons[0]["confidence"] = presence

    cls = frame.get("classification")
    if not isinstance(cls, dict):
        cls = frame["classification"] = {}
    cls["confidence"] = presence
    cls["presence"] = presence > 0.5

    frame["__model_inference__"] = {
        "model": "model.safetensors",
        "presence_confidence": presence,
        "embedding_norm": float(np.linalg.norm(embedding)),
    }
    return frame
def build_pose_data_frame(frame: dict[str, Any]) -> str:
    """Serialise an annotated sensing frame as a ``pose_data`` JSON message.

    Args:
        frame: A sensing frame. ``persons`` and ``classification`` may be
            missing or ``null``; they then default to an empty list and
            to confidence ``0.0`` / activity ``null`` respectively.

    Returns:
        The JSON text of the ``pose_data`` message.
    """
    persons = frame.get("persons") or []
    # `.get("classification", {})` only covers a missing key; an explicit
    # null would come back as None and break the chained `.get` calls.
    classification = frame.get("classification")
    if not isinstance(classification, dict):
        classification = {}
    return json.dumps(
        {
            "type": "pose_data",
            "zone_id": "zone_1",
            "timestamp": frame.get("timestamp"),
            "payload": {
                "pose": {"persons": persons},
                "confidence": classification.get("confidence", 0.0),
                "activity": classification.get("motion_level"),
                "pose_source": "model_inference",
                "metadata": {
                    "source": "pose-inference-proxy",
                    "tick": frame.get("tick"),
                    # Fixed placeholder value, not a measured duration.
                    "processing_time_ms": 1,
                    "model_inference": frame.get("__model_inference__"),
                },
            },
        }
    )
async def main_async(args: argparse.Namespace) -> None:
    """Load the model, then relay annotated upstream frames to subscribers.

    Two coroutines run side by side until cancelled: a client that
    subscribes to ``args.upstream`` and annotates each JSON frame, and a
    WebSocket server on ``args.host:args.port``. Clients connecting to a
    path starting with ``/api/v1/stream/pose`` receive ``pose_data``
    messages; every other path receives the annotated sensing frames.

    Args:
        args: Parsed command-line options (``model``, ``upstream``,
            ``host``, ``port``).
    """
    encoder = CsiEncoder(Path(args.model))
    print(f"[boot] loaded {args.model}")
    print(f"[boot] presence head bias = {encoder.head_b:.3f}, lora scaling = {encoder.lora_scale}")

    sensing_subscribers: set[Any] = set()
    pose_subscribers: set[Any] = set()

    async def broadcast(subscribers: set[Any], payload: str) -> None:
        """Send ``payload`` to every subscriber, ignoring per-client failures."""
        # The send list is built before the first await, so subscribers
        # joining or leaving during the send cannot break the iteration.
        await asyncio.gather(
            *[s.send(payload) for s in subscribers],
            return_exceptions=True,
        )

    async def relay_frames(ws) -> None:
        """Annotate and fan out frames until the upstream connection ends."""
        async for raw_msg in ws:
            try:
                frame = json.loads(raw_msg)
            except ValueError:
                # Covers JSONDecodeError and, for binary messages,
                # UnicodeDecodeError (both are ValueError subclasses).
                continue
            if not isinstance(frame, dict):
                continue
            try:
                annotate_frame(frame, encoder)
                sensing_payload = json.dumps(frame)
                pose_payload = build_pose_data_frame(frame) if pose_subscribers else None
            except Exception as exc:
                # One bad frame must not tear down the upstream
                # connection; drop it and keep relaying.
                print(f"[inference] skipped frame ({exc!r})")
                continue
            if sensing_subscribers:
                await broadcast(sensing_subscribers, sensing_payload)
            # A pose client may have joined during the await above, after
            # the payload decision was made; it simply gets the next frame.
            if pose_subscribers and pose_payload is not None:
                await broadcast(pose_subscribers, pose_payload)

    async def upstream_relay():
        """Keep a subscription to the upstream server alive with backoff."""
        backoff = 1.0
        while True:
            try:
                async with websockets.connect(args.upstream) as ws:
                    print(f"[upstream] connected to {args.upstream}")
                    backoff = 1.0
                    await relay_frames(ws)
                print(f"[upstream] connection closed; retry in {backoff:.1f}s")
            except Exception as exc:
                print(f"[upstream] disconnected ({exc!r}); retry in {backoff:.1f}s")
            # Sleep after a clean close as well as after an error, so an
            # upstream that accepts and immediately closes cannot drive a
            # tight reconnect loop.
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, 30.0)

    async def serve(websocket):
        """Register a client under the stream its request path selects."""
        # Depending on the websockets version the path is exposed either
        # as `websocket.path` or as `websocket.request.path`.
        path = getattr(websocket, "path", None)
        if path is None and hasattr(websocket, "request"):
            path = websocket.request.path
        path = path or "/ws/sensing"
        if path.startswith("/api/v1/stream/pose"):
            pose_subscribers.add(websocket)
            try:
                await websocket.send(json.dumps({"type": "connection_established"}))
                await websocket.wait_closed()
            finally:
                pose_subscribers.discard(websocket)
        else:
            sensing_subscribers.add(websocket)
            try:
                await websocket.wait_closed()
            finally:
                sensing_subscribers.discard(websocket)

    server = await websockets.serve(serve, args.host, args.port)
    print(f"[server] listening on ws://{args.host}:{args.port}")
    print("[server] /ws/sensing annotated sensing_update broadcast")
    print("[server] /api/v1/stream/pose pose_data with model_inference source")
    await asyncio.gather(upstream_relay(), server.wait_closed())
def main() -> None:
    """Parse command-line options and run the proxy until interrupted."""
    # `__doc__` is None when Python runs with -OO (docstrings stripped),
    # so fall back to a literal summary instead of crashing on None.
    summary = (
        __doc__ or "Pose inference proxy for the wifi-densepose sensing server."
    ).splitlines()[0]
    parser = argparse.ArgumentParser(description=summary)
    parser.add_argument(
        "--model",
        required=True,
        help="Path to model.safetensors from ruvnet/wifi-densepose-pretrained",
    )
    parser.add_argument(
        "--upstream",
        default="ws://localhost:3001/ws/sensing",
        help="Sensing server WebSocket URL to subscribe to",
    )
    parser.add_argument("--host", default="127.0.0.1", help="Bind host")
    parser.add_argument("--port", type=int, default=3002, help="Bind port")
    args = parser.parse_args()
    try:
        asyncio.run(main_async(args))
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the proxy; exit without a traceback.
        print("[shutdown] interrupted, exiting")


# Start the proxy only when executed as a script, not when imported.
if __name__ == "__main__":
    main()