wifi-densepose/examples/through-wall/wiflow_capture.py

162 lines
7.7 KiB
Python

#!/usr/bin/env python3
"""WiFlow-style camera-supervised capture (ADR-079 / ADR-180).
Runs on a box with BOTH a camera (ground truth) and reachable live CSI:
- opens a camera, runs MediaPipe Pose -> 17 COCO keypoints (the LABEL),
- subscribes to the sensing-server /ws/sensing (the INPUT: CSI features +
20x20 signal-field),
- writes timestamp-aligned (csi -> pose) pairs to a JSONL dataset.
This is the *collect* phase of camera-supervised CSI->pose training. The camera
and the CSI nodes MUST see the same person in the same space at the same time,
or the pairs are meaningless. Honest by construction: we only emit a pair when
BOTH a confident camera pose AND a live (source=esp32) CSI frame are present in
the same ~100 ms window.
Usage (on ruvultra, with the CSI tunneled to localhost:8765):
python3 wiflow_capture.py --ws ws://localhost:8765/ws/sensing \
--cam 0 --out ~/wiflow-room/dataset.jsonl --seconds 180
"""
import argparse, asyncio, json, time, threading, sys, os
from collections import deque
import urllib.request
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks.python import BaseOptions
from mediapipe.tasks.python.vision import PoseLandmarker, PoseLandmarkerOptions, RunningMode
import websockets
_MODEL_URL = ("https://storage.googleapis.com/mediapipe-models/pose_landmarker/"
"pose_landmarker_lite/float16/latest/pose_landmarker_lite.task")
def ensure_model(path: str) -> str:
if not os.path.exists(path):
os.makedirs(os.path.dirname(path), exist_ok=True)
print(f"[capture] downloading pose model -> {path}", flush=True)
urllib.request.urlretrieve(_MODEL_URL, path)
return path
# MediaPipe Pose (33 landmarks) -> 17 COCO keypoints (same mapping as
# scripts/collect-ground-truth.py, ADR-079).
COCO_FROM_MP = [0, 2, 5, 7, 8, 11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28]
COCO_NAMES = ["nose","l_eye","r_eye","l_ear","r_ear","l_sho","r_sho","l_elb",
"r_elb","l_wri","r_wri","l_hip","r_hip","l_knee","r_knee","l_ank","r_ank"]
# ---- shared state between the CSI (async) thread and the camera (sync) loop ----
_latest_csi = {"t": 0.0, "frame": None}
_csi_lock = threading.Lock()
_stop = threading.Event()
def csi_thread(ws_url: str):
"""Background thread: keep the most recent LIVE csi frame in _latest_csi."""
async def run():
while not _stop.is_set():
try:
async with websockets.connect(ws_url, open_timeout=8, ping_interval=20) as ws:
while not _stop.is_set():
msg = await asyncio.wait_for(ws.recv(), timeout=8)
d = json.loads(msg)
with _csi_lock:
_latest_csi["t"] = time.time()
_latest_csi["frame"] = d
except Exception as e:
print(f"[csi] reconnect ({e})", flush=True)
await asyncio.sleep(1.0)
asyncio.new_event_loop().run_until_complete(run())
def csi_vector(frame: dict):
"""Flatten a csi frame to a fixed-length input vector: features + field."""
f = frame.get("features", {}) or {}
feats = [f.get("mean_rssi", 0.0), f.get("variance", 0.0),
f.get("motion_band_power", 0.0), f.get("breathing_band_power", 0.0)]
# per-node mean_rssi/variance/motion for up to the 2 nodes (9, 13)
pernode = {nf.get("node_id"): (nf.get("features") or {}) for nf in (frame.get("node_features") or [])}
for nid in (9, 13):
nf = pernode.get(nid, {})
feats += [nf.get("mean_rssi", 0.0), nf.get("variance", 0.0), nf.get("motion_band_power", 0.0)]
field = (frame.get("signal_field", {}) or {}).get("values") or []
field = (field + [0.0] * 400)[:400]
return feats + field # 4 + 6 + 400 = 410-d
def main():
ap = argparse.ArgumentParser(description="WiFlow camera-supervised CSI<->pose capture (ADR-180).")
ap.add_argument("--ws", default="ws://localhost:8765/ws/sensing")
ap.add_argument("--cam", type=int, default=0)
ap.add_argument("--out", default=os.path.expanduser("~/wiflow-room/dataset.jsonl"))
ap.add_argument("--seconds", type=int, default=180)
ap.add_argument("--min-vis", type=float, default=0.5, help="min mean landmark visibility to accept a pose label")
ap.add_argument("--max-skew-ms", type=float, default=150, help="max csi/pose time skew to pair")
ap.add_argument("--require-esp32", action="store_true", default=True,
help="only pair when csi source==esp32 (real). Default on.")
args = ap.parse_args()
os.makedirs(os.path.dirname(args.out), exist_ok=True)
th = threading.Thread(target=csi_thread, args=(args.ws,), daemon=True)
th.start()
cap = cv2.VideoCapture(args.cam)
if not cap.isOpened():
print(f"ERROR: cannot open camera {args.cam}", file=sys.stderr); sys.exit(2)
W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 640
H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 480
model_path = ensure_model(os.path.expanduser("~/wiflow-room/pose_landmarker_lite.task"))
landmarker = PoseLandmarker.create_from_options(PoseLandmarkerOptions(
base_options=BaseOptions(model_asset_path=model_path),
running_mode=RunningMode.IMAGE, min_pose_detection_confidence=0.5))
n_pairs = 0; n_nopose = 0; n_nocsi = 0; n_skew = 0; n_sim = 0
t0 = time.time()
print(f"[capture] camera {args.cam} {W}x{H} -> {args.out} for {args.seconds}s")
print("[capture] stand in view AND in the CSI field; move/walk so poses vary. Ctrl-C to stop.")
with open(args.out, "a") as out:
try:
while time.time() - t0 < args.seconds:
ok, frame = cap.read()
if not ok:
continue
now = time.time()
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
res = landmarker.detect(mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb))
if not res.pose_landmarks:
n_nopose += 1; continue
lm = res.pose_landmarks[0]
kps = [[lm[i].x, lm[i].y, lm[i].visibility] for i in COCO_FROM_MP]
vis = float(np.mean([k[2] for k in kps]))
if vis < args.min_vis:
n_nopose += 1; continue
with _csi_lock:
ct = _latest_csi["t"]; cf = _latest_csi["frame"]
if cf is None:
n_nocsi += 1; continue
if (now - ct) * 1000.0 > args.max_skew_ms:
n_skew += 1; continue
if args.require_esp32 and cf.get("source") != "esp32":
n_sim += 1; continue
rec = {"t": now, "vis": round(vis, 3),
"kps": [[round(x, 4), round(y, 4), round(v, 3)] for x, y, v in kps],
"csi": csi_vector(cf),
"src": cf.get("source"),
"nodes": sorted(n.get("node_id") for n in cf.get("nodes", []) if n.get("node_id") is not None)}
out.write(json.dumps(rec) + "\n")
n_pairs += 1
if n_pairs % 30 == 0:
out.flush()
el = int(now - t0)
print(f"[capture] t+{el:3d}s pairs={n_pairs} (skip: nopose={n_nopose} nocsi={n_nocsi} skew={n_skew} sim={n_sim})", flush=True)
except KeyboardInterrupt:
print("\n[capture] stopped by user")
_stop.set(); cap.release()
print(f"[capture] DONE. wrote {n_pairs} paired samples to {args.out}")
print(f"[capture] skipped: no-pose={n_nopose} no-csi={n_nocsi} skew={n_skew} simulated={n_sim}")
if n_pairs == 0:
print("[capture] WARNING: 0 pairs — check camera sees you AND csi source==esp32 (live).")
if __name__ == "__main__":
main()