328 lines
13 KiB
Python
Executable File
328 lines
13 KiB
Python
Executable File
#!/usr/bin/env python3
"""Pose inference proxy for the wifi-densepose sensing server.

The sensing server's `--model` flag currently loads weights into the
ProgressiveLoader for the startup log line, but the runtime broadcast
loop never feeds those weights into the pose path: every emitted frame
has `pose_keypoints: None` and all keypoint confidences are `0.0`. This
script demonstrates the missing wiring as a sidecar so the dashboard
can render model-driven output without recompiling the Rust crate.

What it does:

1. Loads encoder, LoRA and presence-head weights from a local copy of
   the published `model.safetensors` bundle (as downloaded from
   HuggingFace). If the safetensors parser rejects the file's header, a
   repaired copy is written next to it as `*.repaired.safetensors` and
   loaded instead.
2. Subscribes to the live `/ws/sensing` broadcast.
3. For each upstream frame whose first node carries a CSI amplitude
   vector, runs a real forward pass on that vector after downsampling
   it to 8 features.
4. Re-broadcasts annotated frames on a sibling port with
   `pose_keypoints` populated, `classification.confidence` set to the
   presence sigmoid output, and a `__model_inference__` field carrying
   the embedding diagnostics. Also serves `/api/v1/stream/pose` so the
   dashboard's pose stream consumers receive model-driven output.

Note: no pose head is loaded. The presence confidence comes straight
from the model; the keypoints are fixed anchor positions displaced by a
deterministic (seeded) random projection of the model embedding, with
the displacement scaled by the frame's `motion_band_power`.

Run:

    python3 scripts/pose-inference-proxy.py \\
        --model models/wifi-densepose-pretrained/model.safetensors \\
        --upstream ws://localhost:3001/ws/sensing \\
        --port 3002
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import math
|
|
import struct
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
import websockets
|
|
from safetensors import safe_open
|
|
|
|
|
|
# Image-space (x, y) anchors for the 17 COCO keypoints on a 640x480 canvas.
# The layout reproduces the sensing server's existing heuristic so proxy
# output stays geometrically comparable to the legacy stream and the
# dashboard renderer needs no layout changes. Rows are index-aligned with
# KEYPOINT_NAMES, so each paired joint is listed as (left, right).
ANCHOR_KEYPOINTS = np.asarray(
    (
        (320, 130),               # nose
        (310, 125), (330, 125),   # eyes
        (300, 130), (340, 130),   # ears
        (285, 180), (355, 180),   # shoulders
        (275, 240), (365, 240),   # elbows
        (270, 290), (370, 290),   # wrists
        (305, 290), (335, 290),   # hips
        (305, 360), (335, 360),   # knees
        (305, 430), (335, 430),   # ankles
    ),
    dtype=np.float32,
)

# COCO keypoint names in canonical order: the nose first, then every paired
# joint from head to feet, each emitted as a left/right pair.
KEYPOINT_NAMES = ["nose"] + [
    f"{side}_{joint}"
    for joint in ("eye", "ear", "shoulder", "elbow", "wrist", "hip", "knee", "ankle")
    for side in ("left", "right")
]
|
|
|
|
|
|
def repair_safetensors_header(src: Path, dst: Path) -> None:
    """Rewrite *src* to *dst* with a header length that matches its JSON.

    The published ``model.safetensors`` declares a 1464-byte header while
    its JSON object ends at byte 1461, leaving trailing junk that the
    strict ``safetensors`` parser rejects. This helper parses the first
    JSON object in the declared header region (ignoring whatever follows
    it), re-serialises it compactly, pads it with spaces to a multiple of
    8 bytes, and writes it followed by the untouched tensor payload.

    Args:
        src: Path to the (possibly malformed) safetensors file.
        dst: Path the repaired file is written to; overwritten if present.

    Raises:
        ValueError: If *src* is shorter than its declared header, or the
            header region does not start with a JSON object.
    """
    raw = Path(src).read_bytes()
    if len(raw) < 8:
        raise ValueError(f"{src}: too short to be a safetensors file")
    (declared_len,) = struct.unpack_from("<Q", raw, 0)
    if declared_len > len(raw) - 8:
        raise ValueError(
            f"{src}: declared header length {declared_len} exceeds file size {len(raw)}"
        )
    header_region = raw[8 : 8 + declared_len]

    # raw_decode stops at the end of the first complete JSON value, so any
    # trailing junk is ignored -- even junk that itself contains "}", which
    # would fool a search for the last closing brace. errors="replace" only
    # matters for undecodable junk bytes after the JSON object.
    text = header_region.decode("utf-8", errors="replace").lstrip()
    try:
        header_json, _ = json.JSONDecoder().raw_decode(text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"{src}: header region does not start with JSON") from exc
    if not isinstance(header_json, dict):
        raise ValueError(f"{src}: header JSON is not an object")

    new_header = json.dumps(header_json, separators=(",", ":")).encode()
    # Space padding up to an 8-byte boundary keeps the payload aligned.
    new_header += b" " * (-len(new_header) % 8)

    # Tensor data_offsets are relative to the end of the header, so the
    # payload can be copied verbatim after the resized header.
    data_region = raw[8 + declared_len :]
    Path(dst).write_bytes(struct.pack("<Q", len(new_header)) + new_header + data_region)
|
|
|
|
|
|
class CsiEncoder:
    """NumPy forward pass for the published 8 -> 64 -> 128 CSI encoder.

    Loads the encoder, LoRA adapter and presence-head tensors from a
    safetensors bundle. :meth:`forward` maps an 8-element CSI feature
    vector to a presence probability and a unit-norm 128-element embedding.
    """

    def __init__(self, weights_path: Path):
        """Load all weights from *weights_path*.

        If ``safe_open`` cannot open the file (the published bundle ships
        with a malformed header), a repaired copy is written next to it
        with the suffix ``.repaired.safetensors`` and loaded instead.
        """
        weights_path = self._openable_weights_path(Path(weights_path))

        with safe_open(str(weights_path), framework="numpy") as f:
            # First dense layer + batch norm: 8 inputs -> 64 units.
            self.w1 = f.get_tensor("encoder.w1").reshape(8, 64)
            self.b1 = f.get_tensor("encoder.b1")
            self.bn1_gamma = f.get_tensor("encoder.bn1_gamma")
            self.bn1_beta = f.get_tensor("encoder.bn1_beta")
            self.bn1_mean = f.get_tensor("encoder.bn1_runMean")
            self.bn1_var = f.get_tensor("encoder.bn1_runVar")
            # Second dense layer + batch norm: 64 -> 128 units.
            self.w2 = f.get_tensor("encoder.w2").reshape(64, 128)
            self.b2 = f.get_tensor("encoder.b2")
            self.bn2_gamma = f.get_tensor("encoder.bn2_gamma")
            self.bn2_beta = f.get_tensor("encoder.bn2_beta")
            self.bn2_mean = f.get_tensor("encoder.bn2_runMean")
            self.bn2_var = f.get_tensor("encoder.bn2_runVar")
            # Presence head (dot product + bias) and LoRA adapter.
            self.head_w = f.get_tensor("presence_head.weights")
            self.head_b = self._scalar(f.get_tensor("presence_head.bias"))
            self.lora_a = f.get_tensor("lora.A")
            self.lora_b = f.get_tensor("lora.B")
            self.lora_scale = self._scalar(f.get_tensor("lora.scaling"))

    @staticmethod
    def _openable_weights_path(weights_path: Path) -> Path:
        """Return a path ``safe_open`` accepts, repairing the header if needed."""
        try:
            # Probe only: the handle is closed again immediately.
            with safe_open(str(weights_path), framework="numpy"):
                pass
        except Exception:
            # Header rejected (the published bundle's known defect): write a
            # length-corrected copy beside the original and use that.
            repaired = weights_path.with_suffix(".repaired.safetensors")
            repair_safetensors_header(weights_path, repaired)
            return repaired
        return weights_path

    @staticmethod
    def _scalar(tensor: np.ndarray) -> float:
        """Return the first element of a 0-d or 1-element tensor as a float."""
        return float(np.asarray(tensor).reshape(-1)[0])

    @staticmethod
    def _bn1d(x, gamma, beta, mean, var, eps=1e-5):
        """Inference-mode batch norm using stored running statistics."""
        return gamma * (x - mean) / np.sqrt(var + eps) + beta

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Numerically stable logistic function (no overflow for large |x|)."""
        if x >= 0.0:
            return 1.0 / (1.0 + math.exp(-x))
        z = math.exp(x)
        return z / (1.0 + z)

    def forward(self, csi8: np.ndarray) -> tuple[float, np.ndarray]:
        """Run the encoder and presence head on one feature vector.

        Args:
            csi8: 8-element CSI feature vector (the row size of ``w1``).

        Returns:
            ``(presence, embedding)``: the presence sigmoid in [0, 1] and
            the L2-normalised 128-element embedding.
        """
        h1 = csi8 @ self.w1 + self.b1
        h1 = self._bn1d(h1, self.bn1_gamma, self.bn1_beta, self.bn1_mean, self.bn1_var)
        h1 = np.maximum(h1, 0.0)
        h2 = h1 @ self.w2 + self.b2
        h2 = self._bn1d(h2, self.bn2_gamma, self.bn2_beta, self.bn2_mean, self.bn2_var)
        # Low-rank (LoRA) residual applied to the post-BN activations.
        h2 = h2 + (h2 @ self.lora_a) @ self.lora_b * self.lora_scale
        embedding = h2 / (np.linalg.norm(h2) + 1e-8)
        # .item() accepts both a 0-d result and a 1-element array.
        logit = float(np.asarray(embedding @ self.head_w).item()) + self.head_b
        return self._sigmoid(logit), embedding
|
|
|
|
|
|
def extract_csi_features(amplitude_56: list[float]) -> np.ndarray:
    """Downsample a subcarrier amplitude vector to the 8-dim encoder input.

    The nominal input has 56 subcarriers, grouped into 8 consecutive bins
    of 7. Any other length of at least 8 is split into 8 consecutive,
    near-equal bins, so every subcarrier contributes exactly once. Inputs
    shorter than 8 are first tiled cyclically to 56 entries so that every
    bin has data; an empty input yields zeros.

    The bin means are centred on their own mean and scaled by 0.01. Per
    the original author, this scaling matches the magnitude band the
    encoder saw during training, so the batch-norm statistics stay within
    their useful range.

    Args:
        amplitude_56: Per-subcarrier amplitudes (nominally 56 values).

    Returns:
        A float32 array of shape (8,).
    """
    a = np.asarray(amplitude_56, dtype=np.float32).ravel()
    if a.size < 8:
        # Too few values to give every bin at least one sample.
        a = np.resize(a, 56)
    # array_split keeps the bins contiguous, unlike cyclic np.resize, which
    # would mix tail and head subcarriers or drop values beyond index 56.
    binned = np.array([chunk.mean() for chunk in np.array_split(a, 8)], dtype=np.float32)
    return ((binned - binned.mean()) * 0.01).astype(np.float32)
|
|
|
|
|
|
# Fixed projection from the 128-dim embedding to 17 (dx, dy) offsets. It is
# seeded, so the mapping is identical across runs, and it is built once at
# import time rather than on every frame.
_KEYPOINT_PROJECTION = (
    np.random.default_rng(seed=1337).standard_normal((128, 17 * 2)).astype(np.float32) * 8.0
)


def embedding_to_keypoints(
    embedding: np.ndarray,
    motion_band_power: float,
    anchors: np.ndarray | None = None,
) -> np.ndarray:
    """Convert the 128-dim model embedding into 17 anchor-relative keypoints.

    A fixed, seeded Gaussian projection maps the embedding to per-keypoint
    (dx, dy) offsets, so the dashboard reflects model state changes
    deterministically across runs. The motion band power gates the offset
    magnitude: gain = clamp(power / 40, 0, 1). A still scene therefore
    shows small displacements and an active scene shows large ones.

    Args:
        embedding: 128-element embedding (rows of the projection).
        motion_band_power: Motion power of the frame. ``None``, NaN and
            negative values give zero gain, so the anchors are returned
            unchanged.
        anchors: Optional (17, 2) base positions; defaults to
            ``ANCHOR_KEYPOINTS``.

    Returns:
        A (17, 2) array of (x, y) keypoint positions.
    """
    if anchors is None:
        anchors = ANCHOR_KEYPOINTS
    offsets = (embedding @ _KEYPOINT_PROJECTION).reshape(17, 2)

    power = 0.0 if motion_band_power is None else float(motion_band_power)
    gain = power / 40.0
    # Clamp to [0, 1]; a negative gain would mirror the pose around the anchors.
    gain = 0.0 if math.isnan(gain) else min(1.0, max(0.0, gain))
    return anchors + offsets * gain
|
|
|
|
|
|
def annotate_frame(frame: dict[str, Any], encoder: CsiEncoder) -> dict[str, Any]:
    """Run the encoder on one sensing frame and write its output into it.

    The frame is modified in place and also returned. Frames without
    nodes, or whose first node carries fewer than 8 amplitude values, are
    returned unchanged. Otherwise the following fields are written:

    * ``pose_keypoints``: 17 ``[x, y, 0.0, presence]`` rows.
    * ``persons[0].keypoints`` / ``persons[0].confidence``, only when the
      frame already lists at least one person.
    * ``classification.confidence`` / ``classification.presence``
      (presence is ``confidence > 0.5``).
    * ``__model_inference__``: diagnostics (presence, embedding norm).

    JSON ``null`` for ``features``, ``motion_band_power``, ``persons`` or
    ``classification`` is treated like a missing field.
    """
    nodes = frame.get("nodes") or []
    if not nodes:
        return frame
    amplitude = nodes[0].get("amplitude") or []
    if len(amplitude) < 8:
        return frame

    presence, embedding = encoder.forward(extract_csi_features(amplitude))
    features = frame.get("features") or {}
    motion = features.get("motion_band_power") or 0.0
    keypoints = embedding_to_keypoints(embedding, motion)

    frame["pose_keypoints"] = [[float(x), float(y), 0.0, presence] for x, y in keypoints]

    persons = frame.get("persons") or []
    if persons:
        persons[0]["keypoints"] = [
            {
                "name": name,
                "x": float(x),
                "y": float(y),
                "z": 0.0,
                "confidence": presence,
            }
            for name, (x, y) in zip(KEYPOINT_NAMES, keypoints)
        ]
        persons[0]["confidence"] = presence

    cls = frame.get("classification")
    if not isinstance(cls, dict):
        # Missing or null classification: replace with a fresh dict.
        cls = frame["classification"] = {}
    cls["confidence"] = presence
    cls["presence"] = presence > 0.5

    frame["__model_inference__"] = {
        "model": "model.safetensors",
        "presence_confidence": presence,
        "embedding_norm": float(np.linalg.norm(embedding)),
    }
    return frame
|
|
|
|
|
|
def build_pose_data_frame(frame: dict[str, Any]) -> str:
    """Serialise an (annotated) sensing frame as a ``pose_data`` message.

    Args:
        frame: Sensing frame, typically already passed through
            ``annotate_frame``. Missing or null ``persons`` and
            ``classification`` fields are treated as empty.

    Returns:
        A JSON string for the ``/api/v1/stream/pose`` consumers.
    """
    persons = frame.get("persons") or []
    classification = frame.get("classification") or {}
    return json.dumps(
        {
            "type": "pose_data",
            # Fixed zone id: the proxy has no zone information.
            "zone_id": "zone_1",
            "timestamp": frame.get("timestamp"),
            "payload": {
                "pose": {"persons": persons},
                "confidence": classification.get("confidence", 0.0),
                "activity": classification.get("motion_level"),
                "pose_source": "model_inference",
                "metadata": {
                    "source": "pose-inference-proxy",
                    "tick": frame.get("tick"),
                    # Constant placeholder, not a measured duration.
                    "processing_time_ms": 1,
                    "model_inference": frame.get("__model_inference__"),
                },
            },
        }
    )
|
|
|
|
|
|
async def main_async(args: argparse.Namespace) -> None:
    """Run the proxy: relay upstream frames annotated with model output.

    Loads the encoder from ``args.model`` and serves on
    ``ws://{args.host}:{args.port}``. It keeps a reconnecting subscription
    to ``args.upstream`` with exponential backoff capped at 30 s.

    Clients whose request path starts with ``/api/v1/stream/pose`` receive
    a ``connection_established`` greeting followed by ``pose_data``
    messages. Every other path receives the annotated sensing frames.
    Runs until cancelled.
    """
    encoder = CsiEncoder(Path(args.model))
    print(f"[boot] loaded {args.model}")
    print(f"[boot] presence head bias = {encoder.head_b:.3f}, lora scaling = {encoder.lora_scale}")

    # Connected downstream clients, grouped by the stream they requested.
    sensing_subscribers: set[Any] = set()
    pose_subscribers: set[Any] = set()

    async def broadcast(subscribers: set[Any], payload: str) -> None:
        """Send *payload* to every client in *subscribers*, ignoring send errors."""
        # The list is a snapshot taken before awaiting, so clients joining or
        # leaving mid-broadcast cannot mutate what is being iterated. Failed
        # sends (e.g. a client that just disconnected) are collected rather
        # than raised; each client's own handler removes it from its set.
        await asyncio.gather(
            *[client.send(payload) for client in list(subscribers)],
            return_exceptions=True,
        )

    async def relay_frames(ws: Any) -> None:
        """Annotate each upstream message and fan it out to subscribers."""
        async for raw_msg in ws:
            try:
                frame = json.loads(raw_msg)
            except json.JSONDecodeError:
                continue
            if not isinstance(frame, dict):
                continue
            try:
                annotate_frame(frame, encoder)
                sensing_payload = json.dumps(frame) if sensing_subscribers else None
                pose_payload = build_pose_data_frame(frame) if pose_subscribers else None
            except Exception as exc:
                # A single malformed frame must not tear down the upstream
                # connection; drop it and keep relaying.
                print(f"[upstream] skipping unprocessable frame ({exc!r})")
                continue
            if sensing_payload is not None:
                await broadcast(sensing_subscribers, sensing_payload)
            if pose_payload is not None:
                await broadcast(pose_subscribers, pose_payload)

    async def upstream_relay() -> None:
        """Keep an upstream subscription alive, reconnecting with backoff."""
        backoff = 1.0
        while True:
            try:
                async with websockets.connect(args.upstream) as ws:
                    print(f"[upstream] connected to {args.upstream}")
                    backoff = 1.0
                    await relay_frames(ws)
                reason = "closed by upstream"
            except Exception as exc:
                reason = repr(exc)
            # Back off after clean closes too, so an upstream that accepts and
            # immediately closes cannot drive a tight reconnect loop.
            print(f"[upstream] disconnected ({reason}); retry in {backoff:.1f}s")
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, 30.0)

    async def serve(websocket) -> None:
        """Register a downstream client on the stream selected by its path."""
        # Depending on the websockets version, the request path is exposed as
        # `websocket.path` or as `websocket.request.path`.
        path = getattr(websocket, "path", None)
        if path is None and hasattr(websocket, "request"):
            path = websocket.request.path
        path = path or "/ws/sensing"

        if path.startswith("/api/v1/stream/pose"):
            subscribers = pose_subscribers
            greeting = json.dumps({"type": "connection_established"})
        else:
            subscribers = sensing_subscribers
            greeting = None

        subscribers.add(websocket)
        try:
            if greeting is not None:
                await websocket.send(greeting)
            # Read and discard whatever the client sends until it disconnects,
            # so unread client messages cannot pile up in the receive queue.
            async for _ in websocket:
                pass
        except websockets.ConnectionClosed:
            pass
        finally:
            subscribers.discard(websocket)

    server = await websockets.serve(serve, args.host, args.port)
    print(f"[server] listening on ws://{args.host}:{args.port}")
    print("[server] /ws/sensing annotated sensing_update broadcast")
    print("[server] /api/v1/stream/pose pose_data with model_inference source")
    await asyncio.gather(upstream_relay(), server.wait_closed())
|
|
|
|
|
|
def main() -> None:
    """Parse command-line arguments and run the proxy until interrupted.

    Ctrl-C stops the proxy without a traceback and exits with status 130.
    """
    # __doc__ is None when running under `python -OO`; omit the description then.
    description = __doc__.splitlines()[0] if __doc__ else None
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--model",
        required=True,
        help="Path to model.safetensors from ruvnet/wifi-densepose-pretrained",
    )
    parser.add_argument(
        "--upstream",
        default="ws://localhost:3001/ws/sensing",
        help="Sensing server WebSocket URL to subscribe to",
    )
    parser.add_argument("--host", default="127.0.0.1", help="Bind host")
    parser.add_argument("--port", type=int, default=3002, help="Bind port")
    args = parser.parse_args()
    try:
        asyncio.run(main_async(args))
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the proxy; 130 is the
        # conventional exit status for termination by SIGINT.
        raise SystemExit(130) from None


if __name__ == "__main__":
    main()
|