feat: GCloud GPU training pipeline + data collection + benchmarking
- gcloud-train.sh: L4/A100/H100 VM provisioning, Rust build, training with --cuda, artifact download, auto-cleanup ($0.80-$8.50/hr) - training-config-sweep.json: 10 hyperparameter configs (LR, batch, backbone, windows, loss weights, warmup) - collect-training-data.py: UDP listener for 2-node ESP32 CSI recording to .csi.jsonl with interactive/batch labeling and manifest generation - benchmark-model.py: ONNX latency/throughput/PCK/FLOPs profiling with multi-model sweep comparison Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
9a2bc1839a
commit
c63cf2ee77
|
|
@ -0,0 +1,550 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
WiFi-DensePose Model Benchmarking
|
||||||
|
|
||||||
|
Loads trained ONNX models, runs inference on test data, and reports
|
||||||
|
performance metrics: latency, throughput, PCK@0.2, model size, and
|
||||||
|
estimated FLOPs.
|
||||||
|
|
||||||
|
Can compare multiple models from a hyperparameter sweep.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Benchmark a single model
|
||||||
|
python scripts/benchmark-model.py --model checkpoints/best.onnx
|
||||||
|
|
||||||
|
# Benchmark with recorded test data
|
||||||
|
python scripts/benchmark-model.py --model best.onnx --test-data data/recordings/test.csi.jsonl
|
||||||
|
|
||||||
|
# Compare models from a sweep
|
||||||
|
python scripts/benchmark-model.py --sweep-dir training-results/wdp-train-a100-*/checkpoints/
|
||||||
|
|
||||||
|
# Benchmark with synthetic data (no recordings needed)
|
||||||
|
python scripts/benchmark-model.py --model best.onnx --synthetic --num-samples 200
|
||||||
|
|
||||||
|
# Export results as JSON
|
||||||
|
python scripts/benchmark-model.py --model best.onnx --output results.json
|
||||||
|
|
||||||
|
Prerequisites:
|
||||||
|
pip install onnxruntime numpy
|
||||||
|
Optional: pip install onnx (for FLOPs estimation)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field, asdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
import onnxruntime as ort
|
||||||
|
except ImportError:
|
||||||
|
print("ERROR: onnxruntime not installed. Run: pip install onnxruntime")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Configuration ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Default model input shape (must match TrainingConfig defaults)
|
||||||
|
NUM_SUBCARRIERS = 56
|
||||||
|
NUM_ANTENNAS_TX = 3
|
||||||
|
NUM_ANTENNAS_RX = 3
|
||||||
|
WINDOW_FRAMES = 100
|
||||||
|
NUM_KEYPOINTS = 17
|
||||||
|
HEATMAP_SIZE = 56
|
||||||
|
|
||||||
|
# PCK threshold
|
||||||
|
PCK_THRESHOLD = 0.2
|
||||||
|
|
||||||
|
|
||||||
|
# ── Data classes ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkResult:
|
||||||
|
model_path: str
|
||||||
|
model_size_mb: float
|
||||||
|
num_parameters: Optional[int] = None
|
||||||
|
estimated_flops: Optional[int] = None
|
||||||
|
|
||||||
|
# Latency
|
||||||
|
warmup_runs: int = 10
|
||||||
|
benchmark_runs: int = 100
|
||||||
|
latency_mean_ms: float = 0.0
|
||||||
|
latency_std_ms: float = 0.0
|
||||||
|
latency_p50_ms: float = 0.0
|
||||||
|
latency_p95_ms: float = 0.0
|
||||||
|
latency_p99_ms: float = 0.0
|
||||||
|
throughput_fps: float = 0.0
|
||||||
|
|
||||||
|
# Accuracy (if ground truth available)
|
||||||
|
pck_at_02: Optional[float] = None
|
||||||
|
mean_per_joint_error: Optional[float] = None
|
||||||
|
num_test_samples: int = 0
|
||||||
|
|
||||||
|
# Input shape
|
||||||
|
input_shape: list = field(default_factory=list)
|
||||||
|
provider: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── ONNX model loading ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def load_model(model_path: str) -> ort.InferenceSession:
|
||||||
|
"""Load an ONNX model with the best available execution provider."""
|
||||||
|
providers = []
|
||||||
|
if "CUDAExecutionProvider" in ort.get_available_providers():
|
||||||
|
providers.append("CUDAExecutionProvider")
|
||||||
|
providers.append("CPUExecutionProvider")
|
||||||
|
|
||||||
|
sess_opts = ort.SessionOptions()
|
||||||
|
sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
sess_opts.intra_op_num_threads = os.cpu_count() or 4
|
||||||
|
|
||||||
|
session = ort.InferenceSession(model_path, sess_opts, providers=providers)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_info(model_path: str) -> dict:
|
||||||
|
"""Extract model metadata: size, parameter count, FLOPs estimate."""
|
||||||
|
path = Path(model_path)
|
||||||
|
size_mb = path.stat().st_size / (1024 * 1024)
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"size_mb": round(size_mb, 2),
|
||||||
|
"num_parameters": None,
|
||||||
|
"estimated_flops": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Try to count parameters via onnx
|
||||||
|
try:
|
||||||
|
import onnx
|
||||||
|
model = onnx.load(model_path)
|
||||||
|
total_params = 0
|
||||||
|
for initializer in model.graph.initializer:
|
||||||
|
shape = list(initializer.dims)
|
||||||
|
if shape:
|
||||||
|
total_params += int(np.prod(shape))
|
||||||
|
info["num_parameters"] = total_params
|
||||||
|
|
||||||
|
# Rough FLOPs estimate: ~2 * params (multiply-accumulate)
|
||||||
|
info["estimated_flops"] = total_params * 2
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not extract parameter count: {e}")
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
# ── Synthetic data generation ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def generate_synthetic_input(
|
||||||
|
batch_size: int = 1,
|
||||||
|
num_subcarriers: int = NUM_SUBCARRIERS,
|
||||||
|
num_tx: int = NUM_ANTENNAS_TX,
|
||||||
|
num_rx: int = NUM_ANTENNAS_RX,
|
||||||
|
window_frames: int = WINDOW_FRAMES,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Generate synthetic CSI input tensor matching the model's expected shape.
|
||||||
|
|
||||||
|
The WiFi-DensePose model expects input shape:
|
||||||
|
[batch, channels, height, width]
|
||||||
|
where channels = num_tx * num_rx, height = window_frames, width = num_subcarriers.
|
||||||
|
"""
|
||||||
|
channels = num_tx * num_rx # 3x3 = 9 MIMO streams
|
||||||
|
# Simulate CSI amplitude data with realistic distribution
|
||||||
|
rng = np.random.default_rng(42)
|
||||||
|
data = rng.normal(loc=0.0, scale=1.0, size=(batch_size, channels, window_frames, num_subcarriers))
|
||||||
|
return data.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_synthetic_keypoints(
|
||||||
|
num_samples: int,
|
||||||
|
num_keypoints: int = NUM_KEYPOINTS,
|
||||||
|
heatmap_size: int = HEATMAP_SIZE,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Generate synthetic ground truth keypoint coordinates for PCK evaluation."""
|
||||||
|
rng = np.random.default_rng(123)
|
||||||
|
# Keypoints as (x, y) in [0, heatmap_size) range
|
||||||
|
return rng.uniform(0, heatmap_size, size=(num_samples, num_keypoints, 2)).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Load test data from .csi.jsonl ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def load_test_data(
|
||||||
|
jsonl_path: str,
|
||||||
|
window_frames: int = WINDOW_FRAMES,
|
||||||
|
num_subcarriers: int = NUM_SUBCARRIERS,
|
||||||
|
max_samples: int = 500,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Load CSI frames from a .csi.jsonl file and window them into model inputs."""
|
||||||
|
frames = []
|
||||||
|
path = Path(jsonl_path)
|
||||||
|
|
||||||
|
with open(path, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
record = json.loads(line)
|
||||||
|
subs = record.get("subcarriers", [])
|
||||||
|
if len(subs) > 0:
|
||||||
|
frames.append(subs)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(frames) < window_frames:
|
||||||
|
print(f" Warning: Only {len(frames)} frames, need {window_frames}. Padding with zeros.")
|
||||||
|
while len(frames) < window_frames:
|
||||||
|
frames.append([0.0] * num_subcarriers)
|
||||||
|
|
||||||
|
# Normalize subcarrier count
|
||||||
|
normalized = []
|
||||||
|
for frame in frames:
|
||||||
|
if len(frame) < num_subcarriers:
|
||||||
|
frame = frame + [0.0] * (num_subcarriers - len(frame))
|
||||||
|
elif len(frame) > num_subcarriers:
|
||||||
|
# Downsample via linear interpolation
|
||||||
|
indices = np.linspace(0, len(frame) - 1, num_subcarriers)
|
||||||
|
frame = np.interp(indices, range(len(frame)), frame).tolist()
|
||||||
|
normalized.append(frame)
|
||||||
|
|
||||||
|
frames = normalized
|
||||||
|
|
||||||
|
# Create sliding windows
|
||||||
|
samples = []
|
||||||
|
stride = max(1, window_frames // 2)
|
||||||
|
for i in range(0, len(frames) - window_frames + 1, stride):
|
||||||
|
window = frames[i : i + window_frames]
|
||||||
|
# Shape: [channels=1, window_frames, num_subcarriers]
|
||||||
|
# Expand single stream to 9 channels (repeat for MIMO)
|
||||||
|
arr = np.array(window, dtype=np.float32)
|
||||||
|
arr = np.expand_dims(arr, axis=0) # [1, window_frames, num_subcarriers]
|
||||||
|
arr = np.repeat(arr, NUM_ANTENNAS_TX * NUM_ANTENNAS_RX, axis=0) # [9, window, subs]
|
||||||
|
samples.append(arr)
|
||||||
|
|
||||||
|
if len(samples) >= max_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not samples:
|
||||||
|
return generate_synthetic_input(1)
|
||||||
|
|
||||||
|
return np.stack(samples, axis=0) # [N, 9, window_frames, num_subcarriers]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Benchmarking ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def benchmark_latency(
|
||||||
|
session: ort.InferenceSession,
|
||||||
|
input_data: np.ndarray,
|
||||||
|
warmup: int = 10,
|
||||||
|
runs: int = 100,
|
||||||
|
) -> dict:
|
||||||
|
"""Measure inference latency over multiple runs."""
|
||||||
|
input_name = session.get_inputs()[0].name
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for _ in range(warmup):
|
||||||
|
session.run(None, {input_name: input_data[:1]})
|
||||||
|
|
||||||
|
# Timed runs
|
||||||
|
latencies = []
|
||||||
|
for _ in range(runs):
|
||||||
|
start = time.perf_counter()
|
||||||
|
session.run(None, {input_name: input_data[:1]})
|
||||||
|
end = time.perf_counter()
|
||||||
|
latencies.append((end - start) * 1000) # ms
|
||||||
|
|
||||||
|
latencies = np.array(latencies)
|
||||||
|
return {
|
||||||
|
"mean_ms": float(np.mean(latencies)),
|
||||||
|
"std_ms": float(np.std(latencies)),
|
||||||
|
"p50_ms": float(np.percentile(latencies, 50)),
|
||||||
|
"p95_ms": float(np.percentile(latencies, 95)),
|
||||||
|
"p99_ms": float(np.percentile(latencies, 99)),
|
||||||
|
"throughput_fps": 1000.0 / float(np.mean(latencies)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def compute_pck(
|
||||||
|
predictions: np.ndarray,
|
||||||
|
ground_truth: np.ndarray,
|
||||||
|
threshold: float = PCK_THRESHOLD,
|
||||||
|
normalize_by: float = HEATMAP_SIZE,
|
||||||
|
) -> float:
|
||||||
|
"""Compute Percentage of Correct Keypoints at a given threshold.
|
||||||
|
|
||||||
|
PCK@t = fraction of predicted keypoints within t * normalize_by of ground truth.
|
||||||
|
"""
|
||||||
|
if predictions.shape != ground_truth.shape:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Euclidean distance per keypoint
|
||||||
|
distances = np.linalg.norm(predictions - ground_truth, axis=-1) # [N, K]
|
||||||
|
threshold_pixels = threshold * normalize_by
|
||||||
|
correct = (distances < threshold_pixels).astype(float)
|
||||||
|
return float(np.mean(correct))
|
||||||
|
|
||||||
|
|
||||||
|
def extract_keypoints_from_heatmaps(heatmaps: np.ndarray) -> np.ndarray:
|
||||||
|
"""Convert heatmap outputs [N, K, H, W] to keypoint coordinates [N, K, 2]."""
|
||||||
|
n, k, h, w = heatmaps.shape
|
||||||
|
flat = heatmaps.reshape(n, k, -1)
|
||||||
|
max_idx = np.argmax(flat, axis=-1) # [N, K]
|
||||||
|
y = max_idx // w
|
||||||
|
x = max_idx % w
|
||||||
|
return np.stack([x, y], axis=-1).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_model(
|
||||||
|
model_path: str,
|
||||||
|
test_data: Optional[np.ndarray] = None,
|
||||||
|
gt_keypoints: Optional[np.ndarray] = None,
|
||||||
|
warmup: int = 10,
|
||||||
|
runs: int = 100,
|
||||||
|
) -> BenchmarkResult:
|
||||||
|
"""Run full benchmark on a single model."""
|
||||||
|
print(f"\nBenchmarking: {model_path}")
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
session = load_model(model_path)
|
||||||
|
provider = session.get_providers()[0]
|
||||||
|
print(f" Provider: {provider}")
|
||||||
|
|
||||||
|
# Model info
|
||||||
|
model_info = get_model_info(model_path)
|
||||||
|
print(f" Size: {model_info['size_mb']} MB")
|
||||||
|
if model_info["num_parameters"]:
|
||||||
|
print(f" Parameters: {model_info['num_parameters']:,}")
|
||||||
|
if model_info["estimated_flops"]:
|
||||||
|
print(f" Estimated FLOPs: {model_info['estimated_flops']:,}")
|
||||||
|
|
||||||
|
# Input shape
|
||||||
|
input_meta = session.get_inputs()[0]
|
||||||
|
input_shape = input_meta.shape
|
||||||
|
print(f" Input: {input_meta.name} {input_shape} ({input_meta.type})")
|
||||||
|
|
||||||
|
# Output shapes
|
||||||
|
for out in session.get_outputs():
|
||||||
|
print(f" Output: {out.name} {out.shape}")
|
||||||
|
|
||||||
|
# Generate or use provided test data
|
||||||
|
if test_data is None:
|
||||||
|
# Infer shape from model
|
||||||
|
if input_shape and all(isinstance(d, int) for d in input_shape):
|
||||||
|
batch = max(1, input_shape[0] if input_shape[0] > 0 else 1)
|
||||||
|
test_data = np.random.randn(*[batch if d <= 0 else d for d in input_shape]).astype(np.float32)
|
||||||
|
else:
|
||||||
|
test_data = generate_synthetic_input(1)
|
||||||
|
|
||||||
|
# Latency benchmark
|
||||||
|
print(f" Running {warmup} warmup + {runs} benchmark iterations...")
|
||||||
|
latency = benchmark_latency(session, test_data, warmup=warmup, runs=runs)
|
||||||
|
print(f" Latency: {latency['mean_ms']:.2f} +/- {latency['std_ms']:.2f} ms")
|
||||||
|
print(f" P50/P95/P99: {latency['p50_ms']:.2f} / {latency['p95_ms']:.2f} / {latency['p99_ms']:.2f} ms")
|
||||||
|
print(f" Throughput: {latency['throughput_fps']:.1f} fps")
|
||||||
|
|
||||||
|
# Accuracy (if ground truth provided or we can do synthetic evaluation)
|
||||||
|
pck = None
|
||||||
|
mpjpe = None
|
||||||
|
num_samples = 0
|
||||||
|
|
||||||
|
if gt_keypoints is not None and test_data is not None:
|
||||||
|
input_name = session.get_inputs()[0].name
|
||||||
|
all_preds = []
|
||||||
|
|
||||||
|
for i in range(len(test_data)):
|
||||||
|
outputs = session.run(None, {input_name: test_data[i : i + 1]})
|
||||||
|
# Assume first output is keypoint heatmaps [1, K, H, W]
|
||||||
|
heatmaps = outputs[0]
|
||||||
|
if heatmaps.ndim == 4:
|
||||||
|
kp = extract_keypoints_from_heatmaps(heatmaps)
|
||||||
|
all_preds.append(kp[0])
|
||||||
|
|
||||||
|
if all_preds:
|
||||||
|
predictions = np.stack(all_preds)
|
||||||
|
gt = gt_keypoints[: len(predictions)]
|
||||||
|
pck = compute_pck(predictions, gt)
|
||||||
|
distances = np.linalg.norm(predictions - gt, axis=-1)
|
||||||
|
mpjpe = float(np.mean(distances))
|
||||||
|
num_samples = len(predictions)
|
||||||
|
print(f" PCK@{PCK_THRESHOLD}: {pck:.4f}")
|
||||||
|
print(f" MPJPE: {mpjpe:.2f} px")
|
||||||
|
print(f" Samples: {num_samples}")
|
||||||
|
|
||||||
|
result = BenchmarkResult(
|
||||||
|
model_path=model_path,
|
||||||
|
model_size_mb=model_info["size_mb"],
|
||||||
|
num_parameters=model_info["num_parameters"],
|
||||||
|
estimated_flops=model_info["estimated_flops"],
|
||||||
|
warmup_runs=warmup,
|
||||||
|
benchmark_runs=runs,
|
||||||
|
latency_mean_ms=round(latency["mean_ms"], 3),
|
||||||
|
latency_std_ms=round(latency["std_ms"], 3),
|
||||||
|
latency_p50_ms=round(latency["p50_ms"], 3),
|
||||||
|
latency_p95_ms=round(latency["p95_ms"], 3),
|
||||||
|
latency_p99_ms=round(latency["p99_ms"], 3),
|
||||||
|
throughput_fps=round(latency["throughput_fps"], 1),
|
||||||
|
pck_at_02=round(pck, 4) if pck is not None else None,
|
||||||
|
mean_per_joint_error=round(mpjpe, 2) if mpjpe is not None else None,
|
||||||
|
num_test_samples=num_samples,
|
||||||
|
input_shape=list(input_shape) if input_shape else [],
|
||||||
|
provider=provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ── Comparison table ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def print_comparison_table(results: list[BenchmarkResult]):
|
||||||
|
"""Print a formatted comparison table of multiple models."""
|
||||||
|
if not results:
|
||||||
|
return
|
||||||
|
|
||||||
|
print("\n" + "=" * 100)
|
||||||
|
print(" Model Comparison")
|
||||||
|
print("=" * 100)
|
||||||
|
|
||||||
|
# Header
|
||||||
|
print(
|
||||||
|
f"{'Model':<35} {'Size(MB)':>8} {'Params':>10} "
|
||||||
|
f"{'Lat(ms)':>8} {'P95(ms)':>8} {'FPS':>7} {'PCK@0.2':>8}"
|
||||||
|
)
|
||||||
|
print("-" * 100)
|
||||||
|
|
||||||
|
for r in results:
|
||||||
|
name = Path(r.model_path).stem[:33]
|
||||||
|
params = f"{r.num_parameters:,}" if r.num_parameters else "?"
|
||||||
|
pck = f"{r.pck_at_02:.4f}" if r.pck_at_02 is not None else "N/A"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{name:<35} {r.model_size_mb:>8.2f} {params:>10} "
|
||||||
|
f"{r.latency_mean_ms:>8.2f} {r.latency_p95_ms:>8.2f} "
|
||||||
|
f"{r.throughput_fps:>7.1f} {pck:>8}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * 100)
|
||||||
|
|
||||||
|
# Best model by latency
|
||||||
|
best_latency = min(results, key=lambda r: r.latency_mean_ms)
|
||||||
|
print(f"\n Fastest: {Path(best_latency.model_path).stem} ({best_latency.latency_mean_ms:.2f} ms)")
|
||||||
|
|
||||||
|
# Best by PCK (if available)
|
||||||
|
pck_results = [r for r in results if r.pck_at_02 is not None]
|
||||||
|
if pck_results:
|
||||||
|
best_pck = max(pck_results, key=lambda r: r.pck_at_02)
|
||||||
|
print(f" Best accuracy: {Path(best_pck.model_path).stem} (PCK@0.2={best_pck.pck_at_02:.4f})")
|
||||||
|
|
||||||
|
# Smallest model
|
||||||
|
smallest = min(results, key=lambda r: r.model_size_mb)
|
||||||
|
print(f" Smallest: {Path(smallest.model_path).stem} ({smallest.model_size_mb:.2f} MB)")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Benchmark WiFi-DensePose ONNX models",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--model", type=str, help="Path to a single ONNX model")
|
||||||
|
parser.add_argument("--sweep-dir", type=str, help="Directory containing multiple ONNX models to compare")
|
||||||
|
parser.add_argument("--test-data", type=str, help="Path to .csi.jsonl test data file")
|
||||||
|
parser.add_argument("--synthetic", action="store_true", help="Use synthetic test data")
|
||||||
|
parser.add_argument("--num-samples", type=int, default=100, help="Number of synthetic samples (default: 100)")
|
||||||
|
parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations (default: 10)")
|
||||||
|
parser.add_argument("--runs", type=int, default=100, help="Benchmark iterations (default: 100)")
|
||||||
|
parser.add_argument("--output", type=str, help="Save results to JSON file")
|
||||||
|
parser.add_argument("--gpu", action="store_true", help="Force GPU execution provider")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not args.model and not args.sweep_dir:
|
||||||
|
parser.error("Specify --model or --sweep-dir")
|
||||||
|
|
||||||
|
# Prepare test data
|
||||||
|
test_data = None
|
||||||
|
gt_keypoints = None
|
||||||
|
|
||||||
|
if args.test_data:
|
||||||
|
print(f"Loading test data from: {args.test_data}")
|
||||||
|
test_data = load_test_data(args.test_data)
|
||||||
|
print(f" Loaded {len(test_data)} windowed samples")
|
||||||
|
elif args.synthetic:
|
||||||
|
print(f"Generating {args.num_samples} synthetic samples...")
|
||||||
|
test_data = generate_synthetic_input(args.num_samples)
|
||||||
|
gt_keypoints = generate_synthetic_keypoints(args.num_samples)
|
||||||
|
print(f" Input shape: {test_data.shape}")
|
||||||
|
|
||||||
|
# Collect models
|
||||||
|
model_paths = []
|
||||||
|
if args.model:
|
||||||
|
model_paths.append(args.model)
|
||||||
|
if args.sweep_dir:
|
||||||
|
sweep = Path(args.sweep_dir)
|
||||||
|
if sweep.is_dir():
|
||||||
|
model_paths.extend(sorted(str(p) for p in sweep.glob("**/*.onnx")))
|
||||||
|
else:
|
||||||
|
# Glob pattern
|
||||||
|
from glob import glob
|
||||||
|
model_paths.extend(sorted(glob(str(sweep))))
|
||||||
|
|
||||||
|
if not model_paths:
|
||||||
|
print("ERROR: No ONNX models found.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print(f"Found {len(model_paths)} model(s) to benchmark.")
|
||||||
|
|
||||||
|
# Benchmark each model
|
||||||
|
results = []
|
||||||
|
for path in model_paths:
|
||||||
|
if not Path(path).exists():
|
||||||
|
print(f" Skipping (not found): {path}")
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
result = benchmark_model(
|
||||||
|
path,
|
||||||
|
test_data=test_data,
|
||||||
|
gt_keypoints=gt_keypoints,
|
||||||
|
warmup=args.warmup,
|
||||||
|
runs=args.runs,
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ERROR benchmarking {path}: {e}")
|
||||||
|
|
||||||
|
# Comparison table
|
||||||
|
if len(results) > 1:
|
||||||
|
print_comparison_table(results)
|
||||||
|
|
||||||
|
# Save results
|
||||||
|
if args.output:
|
||||||
|
output_path = Path(args.output)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(output_path, "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"benchmark_results": [asdict(r) for r in results],
|
||||||
|
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||||
|
"num_models": len(results),
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
indent=2,
|
||||||
|
)
|
||||||
|
print(f"\nResults saved to: {output_path}")
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
print("No models were successfully benchmarked.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,483 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
WiFi-DensePose Training Data Collector
|
||||||
|
|
||||||
|
Listens on UDP for CSI data from ESP32 nodes and records to .csi.jsonl
|
||||||
|
files compatible with the Rust training pipeline (MmFiDataset / CsiDataset).
|
||||||
|
|
||||||
|
Supports two packet formats:
|
||||||
|
- ADR-069 feature vectors (magic 0xC5110003, 48 bytes) — 8-dim pre-extracted
|
||||||
|
- ADR-018 raw CSI frames (magic 0xC5110001, variable) — full subcarrier data
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Interactive — prompts for scenario labels
|
||||||
|
python scripts/collect-training-data.py --port 5006
|
||||||
|
|
||||||
|
# Scripted — fixed label, 60s per recording
|
||||||
|
python scripts/collect-training-data.py --port 5006 --label walking --duration 60
|
||||||
|
|
||||||
|
# Multiple scenarios in sequence
|
||||||
|
python scripts/collect-training-data.py --port 5006 --scenarios walking,standing,sitting --duration 30
|
||||||
|
|
||||||
|
# Dual-node collection (two ESP32s on different ports)
|
||||||
|
python scripts/collect-training-data.py --port 5005 --port2 5006 --label walking
|
||||||
|
|
||||||
|
# Generate manifest only from existing recordings
|
||||||
|
python scripts/collect-training-data.py --manifest-only --output-dir data/recordings
|
||||||
|
|
||||||
|
Prerequisites:
|
||||||
|
- ESP32 nodes streaming CSI on UDP (see firmware/esp32-csi-node)
|
||||||
|
- Python 3.9+
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import signal
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||||
|
datefmt="%H:%M:%S",
|
||||||
|
)
|
||||||
|
log = logging.getLogger("collect-data")
|
||||||
|
|
||||||
|
# ── Packet formats (must match firmware) ─────────────────────────────────────
|
||||||
|
|
||||||
|
# ADR-018 raw CSI frame header
|
||||||
|
MAGIC_CSI_RAW = 0xC5110001
|
||||||
|
# ADR-069 feature vector packet
|
||||||
|
MAGIC_FEATURES = 0xC5110003
|
||||||
|
FEATURE_PKT_FMT = "<IBBHq8f"
|
||||||
|
FEATURE_PKT_SIZE = struct.calcsize(FEATURE_PKT_FMT) # 48 bytes
|
||||||
|
|
||||||
|
# Raw CSI header: magic(4) + node_id(1) + antenna_cfg(1) + n_sub(2) + rssi(1) + noise(1) + channel(1) + reserved(1) + timestamp_ms(4)
|
||||||
|
RAW_CSI_HDR_FMT = "<IBBHbbBxI"
|
||||||
|
RAW_CSI_HDR_SIZE = struct.calcsize(RAW_CSI_HDR_FMT) # 16 bytes
|
||||||
|
|
||||||
|
|
||||||
|
# ── Packet parsing ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def parse_packet(data: bytes) -> Optional[dict]:
|
||||||
|
"""Parse a UDP packet into a frame dict, or None if unrecognized."""
|
||||||
|
if len(data) < 4:
|
||||||
|
return None
|
||||||
|
|
||||||
|
magic = struct.unpack_from("<I", data)[0]
|
||||||
|
|
||||||
|
if magic == MAGIC_FEATURES and len(data) >= FEATURE_PKT_SIZE:
|
||||||
|
return _parse_feature_packet(data)
|
||||||
|
elif magic == MAGIC_CSI_RAW and len(data) >= RAW_CSI_HDR_SIZE:
|
||||||
|
return _parse_raw_csi_packet(data)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_feature_packet(data: bytes) -> Optional[dict]:
|
||||||
|
"""Parse ADR-069 feature vector packet (48 bytes)."""
|
||||||
|
try:
|
||||||
|
magic, node_id, _, seq, ts_us, *features = struct.unpack_from(FEATURE_PKT_FMT, data)
|
||||||
|
except struct.error:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if magic != MAGIC_FEATURES:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Reject NaN/inf
|
||||||
|
import math
|
||||||
|
if any(math.isnan(f) or math.isinf(f) for f in features):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "features",
|
||||||
|
"node_id": node_id,
|
||||||
|
"seq": seq,
|
||||||
|
"timestamp_us": ts_us,
|
||||||
|
"timestamp": ts_us / 1_000_000.0,
|
||||||
|
"features": features,
|
||||||
|
"subcarriers": features, # Use features as subcarrier proxy for training
|
||||||
|
"rssi": 0.0,
|
||||||
|
"noise_floor": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_raw_csi_packet(data: bytes) -> Optional[dict]:
|
||||||
|
"""Parse ADR-018 raw CSI frame with full subcarrier data."""
|
||||||
|
try:
|
||||||
|
magic, node_id, ant_cfg, n_sub, rssi, noise, channel, ts_ms = struct.unpack_from(
|
||||||
|
RAW_CSI_HDR_FMT, data
|
||||||
|
)
|
||||||
|
except struct.error:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if magic != MAGIC_CSI_RAW:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Subcarrier data follows header as int16 I/Q pairs
|
||||||
|
payload_offset = RAW_CSI_HDR_SIZE
|
||||||
|
expected_bytes = n_sub * 2 * 2 # n_sub * (I + Q) * int16
|
||||||
|
if len(data) < payload_offset + expected_bytes:
|
||||||
|
return None
|
||||||
|
|
||||||
|
iq_data = struct.unpack_from(f"<{n_sub * 2}h", data, payload_offset)
|
||||||
|
# Convert I/Q pairs to amplitude
|
||||||
|
subcarriers = []
|
||||||
|
for i in range(0, len(iq_data), 2):
|
||||||
|
real, imag = iq_data[i], iq_data[i + 1]
|
||||||
|
amplitude = (real ** 2 + imag ** 2) ** 0.5
|
||||||
|
subcarriers.append(amplitude)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "raw_csi",
|
||||||
|
"node_id": node_id,
|
||||||
|
"antenna_config": ant_cfg,
|
||||||
|
"n_subcarriers": n_sub,
|
||||||
|
"channel": channel,
|
||||||
|
"timestamp": ts_ms / 1000.0,
|
||||||
|
"subcarriers": subcarriers,
|
||||||
|
"rssi": float(rssi),
|
||||||
|
"noise_floor": float(noise),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── JSONL recording ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class CsiRecorder:
|
||||||
|
"""Records CSI frames to .csi.jsonl files compatible with the Rust pipeline."""
|
||||||
|
|
||||||
|
def __init__(self, output_dir: str, session_name: str, label: Optional[str] = None):
|
||||||
|
self.output_dir = Path(output_dir)
|
||||||
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||||
|
safe_name = session_name.replace(" ", "_").replace("/", "_")
|
||||||
|
self.session_id = f"{safe_name}-{ts}"
|
||||||
|
self.label = label
|
||||||
|
self.file_path = self.output_dir / f"{self.session_id}.csi.jsonl"
|
||||||
|
self.meta_path = self.output_dir / f"{self.session_id}.csi.meta.json"
|
||||||
|
self.frame_count = 0
|
||||||
|
self.start_time = time.time()
|
||||||
|
self.started_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
self._file = None
|
||||||
|
|
||||||
|
def open(self):
|
||||||
|
self._file = open(self.file_path, "a", encoding="utf-8")
|
||||||
|
log.info(f"Recording to: {self.file_path}")
|
||||||
|
|
||||||
|
def write_frame(self, frame: dict):
|
||||||
|
"""Write a single frame as a JSONL line."""
|
||||||
|
if self._file is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"timestamp": frame.get("timestamp", time.time()),
|
||||||
|
"subcarriers": frame.get("subcarriers", []),
|
||||||
|
"rssi": frame.get("rssi", 0.0),
|
||||||
|
"noise_floor": frame.get("noise_floor", 0.0),
|
||||||
|
"features": {
|
||||||
|
k: v for k, v in frame.items()
|
||||||
|
if k not in ("timestamp", "subcarriers", "rssi", "noise_floor", "type")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
line = json.dumps(record, separators=(",", ":"))
|
||||||
|
self._file.write(line + "\n")
|
||||||
|
self.frame_count += 1
|
||||||
|
|
||||||
|
if self.frame_count % 500 == 0:
|
||||||
|
self._file.flush()
|
||||||
|
|
||||||
|
def close(self) -> dict:
|
||||||
|
"""Close the recording and write metadata. Returns session info."""
|
||||||
|
if self._file:
|
||||||
|
self._file.flush()
|
||||||
|
self._file.close()
|
||||||
|
self._file = None
|
||||||
|
|
||||||
|
ended_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
elapsed = time.time() - self.start_time
|
||||||
|
file_size = self.file_path.stat().st_size if self.file_path.exists() else 0
|
||||||
|
|
||||||
|
meta = {
|
||||||
|
"id": self.session_id,
|
||||||
|
"name": self.session_id,
|
||||||
|
"label": self.label,
|
||||||
|
"started_at": self.started_at,
|
||||||
|
"ended_at": ended_at,
|
||||||
|
"duration_secs": round(elapsed, 2),
|
||||||
|
"frame_count": self.frame_count,
|
||||||
|
"file_size_bytes": file_size,
|
||||||
|
"file_path": str(self.file_path),
|
||||||
|
"fps": round(self.frame_count / elapsed, 1) if elapsed > 0 else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(self.meta_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
f"Recording stopped: {self.frame_count} frames in {elapsed:.1f}s "
|
||||||
|
f"({meta['fps']} fps, {file_size / 1024:.1f} KB)"
|
||||||
|
)
|
||||||
|
return meta
|
||||||
|
|
||||||
|
|
||||||
|
# ── Manifest generation ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def generate_manifest(output_dir: str) -> dict:
|
||||||
|
"""Scan recordings directory and generate a dataset manifest JSON."""
|
||||||
|
rec_dir = Path(output_dir)
|
||||||
|
sessions = []
|
||||||
|
|
||||||
|
for meta_file in sorted(rec_dir.glob("*.csi.meta.json")):
|
||||||
|
try:
|
||||||
|
with open(meta_file, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
sessions.append(meta)
|
||||||
|
except (json.JSONDecodeError, OSError) as e:
|
||||||
|
log.warning(f"Skipping {meta_file}: {e}")
|
||||||
|
|
||||||
|
# Aggregate stats
|
||||||
|
total_frames = sum(s.get("frame_count", 0) for s in sessions)
|
||||||
|
total_bytes = sum(s.get("file_size_bytes", 0) for s in sessions)
|
||||||
|
labels = sorted(set(s.get("label", "unlabeled") or "unlabeled" for s in sessions))
|
||||||
|
|
||||||
|
manifest = {
|
||||||
|
"dataset": "wifi-densepose-csi",
|
||||||
|
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"directory": str(rec_dir),
|
||||||
|
"num_sessions": len(sessions),
|
||||||
|
"total_frames": total_frames,
|
||||||
|
"total_size_bytes": total_bytes,
|
||||||
|
"total_size_mb": round(total_bytes / (1024 * 1024), 2),
|
||||||
|
"labels": labels,
|
||||||
|
"sessions": sessions,
|
||||||
|
}
|
||||||
|
|
||||||
|
manifest_path = rec_dir / "manifest.json"
|
||||||
|
with open(manifest_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(manifest, f, indent=2)
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
f"Manifest: {len(sessions)} sessions, {total_frames} frames, "
|
||||||
|
f"{manifest['total_size_mb']} MB, labels={labels}"
|
||||||
|
)
|
||||||
|
log.info(f"Written to: {manifest_path}")
|
||||||
|
return manifest
|
||||||
|
|
||||||
|
|
||||||
|
# ── UDP listener ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def collect_session(
|
||||||
|
port: int,
|
||||||
|
port2: Optional[int],
|
||||||
|
output_dir: str,
|
||||||
|
label: str,
|
||||||
|
duration: float,
|
||||||
|
session_name: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Run a single collection session. Returns session metadata."""
|
||||||
|
name = session_name or label or "session"
|
||||||
|
recorder = CsiRecorder(output_dir, name, label)
|
||||||
|
recorder.open()
|
||||||
|
|
||||||
|
# Bind primary socket
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
sock.bind(("0.0.0.0", port))
|
||||||
|
sock.settimeout(1.0)
|
||||||
|
sockets = [sock]
|
||||||
|
|
||||||
|
# Bind secondary socket if specified
|
||||||
|
if port2:
|
||||||
|
sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
sock2.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
sock2.bind(("0.0.0.0", port2))
|
||||||
|
sock2.settimeout(0.1)
|
||||||
|
sockets.append(sock2)
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
f"Collecting '{label}' for {duration}s on port(s) "
|
||||||
|
f"{port}{f', {port2}' if port2 else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
dropped = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
while time.time() - start < duration:
|
||||||
|
for s in sockets:
|
||||||
|
try:
|
||||||
|
data, addr = s.recvfrom(4096)
|
||||||
|
except socket.timeout:
|
||||||
|
continue
|
||||||
|
|
||||||
|
frame = parse_packet(data)
|
||||||
|
if frame:
|
||||||
|
recorder.write_frame(frame)
|
||||||
|
else:
|
||||||
|
dropped += 1
|
||||||
|
|
||||||
|
# Progress update every 5s
|
||||||
|
elapsed = time.time() - start
|
||||||
|
if recorder.frame_count > 0 and int(elapsed) % 5 == 0 and int(elapsed) > 0:
|
||||||
|
remaining = duration - elapsed
|
||||||
|
if remaining > 0 and int(elapsed * 10) % 50 == 0:
|
||||||
|
log.info(
|
||||||
|
f" {recorder.frame_count} frames collected, "
|
||||||
|
f"{remaining:.0f}s remaining..."
|
||||||
|
)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
log.info("Interrupted by user.")
|
||||||
|
finally:
|
||||||
|
for s in sockets:
|
||||||
|
s.close()
|
||||||
|
|
||||||
|
if dropped > 0:
|
||||||
|
log.warning(f" {dropped} unrecognized packets dropped")
|
||||||
|
|
||||||
|
return recorder.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Collect CSI training data from ESP32 nodes via UDP",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
Examples:
|
||||||
|
# Interactive label input
|
||||||
|
python scripts/collect-training-data.py --port 5006
|
||||||
|
|
||||||
|
# Fixed label, 60 seconds
|
||||||
|
python scripts/collect-training-data.py --port 5006 --label walking --duration 60
|
||||||
|
|
||||||
|
# Multiple scenarios
|
||||||
|
python scripts/collect-training-data.py --port 5006 --scenarios walking,standing,sitting --duration 30
|
||||||
|
|
||||||
|
# Dual ESP32 nodes
|
||||||
|
python scripts/collect-training-data.py --port 5005 --port2 5006 --label test
|
||||||
|
|
||||||
|
# Generate manifest from existing recordings
|
||||||
|
python scripts/collect-training-data.py --manifest-only
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--port", type=int, default=5006, help="Primary UDP port (default: 5006)")
|
||||||
|
parser.add_argument("--port2", type=int, default=None, help="Secondary UDP port for dual-node")
|
||||||
|
parser.add_argument("--output-dir", default="data/recordings", help="Output directory (default: data/recordings)")
|
||||||
|
parser.add_argument("--label", default=None, help="Activity label for the recording")
|
||||||
|
parser.add_argument("--duration", type=float, default=30.0, help="Recording duration in seconds (default: 30)")
|
||||||
|
parser.add_argument("--scenarios", default=None, help="Comma-separated list of scenarios to record sequentially")
|
||||||
|
parser.add_argument("--pause", type=float, default=5.0, help="Pause between scenarios in seconds (default: 5)")
|
||||||
|
parser.add_argument("--manifest-only", action="store_true", help="Only generate manifest from existing recordings")
|
||||||
|
parser.add_argument("--repeats", type=int, default=1, help="Number of repeats per scenario (default: 1)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Manifest-only mode
|
||||||
|
if args.manifest_only:
|
||||||
|
generate_manifest(args.output_dir)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Collect scenarios
|
||||||
|
all_sessions = []
|
||||||
|
|
||||||
|
if args.scenarios:
|
||||||
|
# Multi-scenario sequential collection
|
||||||
|
scenarios = [s.strip() for s in args.scenarios.split(",") if s.strip()]
|
||||||
|
total = len(scenarios) * args.repeats
|
||||||
|
idx = 0
|
||||||
|
|
||||||
|
for repeat in range(args.repeats):
|
||||||
|
for scenario in scenarios:
|
||||||
|
idx += 1
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f" Scenario {idx}/{total}: '{scenario}' (repeat {repeat+1}/{args.repeats})")
|
||||||
|
print(f" Duration: {args.duration}s")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
|
||||||
|
if idx > 1:
|
||||||
|
print(f" Starting in {args.pause}s... (get into position)")
|
||||||
|
time.sleep(args.pause)
|
||||||
|
|
||||||
|
meta = collect_session(
|
||||||
|
port=args.port,
|
||||||
|
port2=args.port2,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
label=scenario,
|
||||||
|
duration=args.duration,
|
||||||
|
session_name=f"{scenario}_r{repeat+1:02d}",
|
||||||
|
)
|
||||||
|
all_sessions.append(meta)
|
||||||
|
|
||||||
|
elif args.label:
|
||||||
|
# Single labeled recording
|
||||||
|
meta = collect_session(
|
||||||
|
port=args.port,
|
||||||
|
port2=args.port2,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
label=args.label,
|
||||||
|
duration=args.duration,
|
||||||
|
)
|
||||||
|
all_sessions.append(meta)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Interactive mode — prompt for labels
|
||||||
|
print("\nInteractive data collection mode.")
|
||||||
|
print("Type a label for each recording, or 'q' to quit.\n")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
label = input("Label (or 'q' to quit): ").strip()
|
||||||
|
if label.lower() in ("q", "quit", "exit"):
|
||||||
|
break
|
||||||
|
if not label:
|
||||||
|
print(" Empty label. Try again.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
duration = args.duration
|
||||||
|
try:
|
||||||
|
dur_input = input(f"Duration in seconds [{duration}]: ").strip()
|
||||||
|
if dur_input:
|
||||||
|
duration = float(dur_input)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
print(f" Recording '{label}' for {duration}s — starting now...")
|
||||||
|
meta = collect_session(
|
||||||
|
port=args.port,
|
||||||
|
port2=args.port2,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
label=label,
|
||||||
|
duration=duration,
|
||||||
|
)
|
||||||
|
all_sessions.append(meta)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Generate manifest
|
||||||
|
if all_sessions:
|
||||||
|
print(f"\nCollected {len(all_sessions)} session(s).")
|
||||||
|
manifest = generate_manifest(args.output_dir)
|
||||||
|
|
||||||
|
total_frames = sum(s.get("frame_count", 0) for s in all_sessions)
|
||||||
|
print(f"\nSummary:")
|
||||||
|
print(f" Sessions: {len(all_sessions)}")
|
||||||
|
print(f" Total frames: {total_frames}")
|
||||||
|
print(f" Output: {args.output_dir}/")
|
||||||
|
print(f" Manifest: {args.output_dir}/manifest.json")
|
||||||
|
else:
|
||||||
|
print("No sessions recorded.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,469 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# ==============================================================================
|
||||||
|
# GCloud GPU Training Script for WiFi-DensePose
|
||||||
|
# ==============================================================================
|
||||||
|
#
|
||||||
|
# Creates a GCloud VM with GPU, runs the Rust training pipeline, downloads
|
||||||
|
# the trained model artifacts, and tears down the VM to avoid ongoing costs.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# bash scripts/gcloud-train.sh [OPTIONS]
|
||||||
|
#
|
||||||
|
# Options:
|
||||||
|
# --gpu l4|a100|h100 GPU type (default: l4)
|
||||||
|
# --zone ZONE GCloud zone (default: us-central1-a)
|
||||||
|
# --hours N Max VM lifetime in hours (default: 2)
|
||||||
|
# --config FILE Training config JSON (default: scripts/training-config-sweep.json entry 0)
|
||||||
|
# --data-dir DIR Local data directory to upload (default: data/recordings)
|
||||||
|
# --dry-run Run smoke test with synthetic data
|
||||||
|
# --sweep Run full hyperparameter sweep (all configs)
|
||||||
|
# --keep-vm Do not delete VM after training
|
||||||
|
# --instance NAME Custom VM instance name
|
||||||
|
#
|
||||||
|
# Prerequisites:
|
||||||
|
# - gcloud CLI authenticated: gcloud auth login
|
||||||
|
# - Project set: gcloud config set project cognitum-20260110
|
||||||
|
# - Quota for GPUs in the selected zone
|
||||||
|
#
|
||||||
|
# Cost estimates:
|
||||||
|
# L4 (~$0.80/hr) — good for prototyping and small sweeps
|
||||||
|
# A100 40GB (~$3.60/hr) — full training runs
|
||||||
|
# H100 80GB (~$11.00/hr) — large batch / fast iteration
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# ── Defaults ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
PROJECT="cognitum-20260110"
|
||||||
|
GPU_TYPE="l4"
|
||||||
|
ZONE="us-central1-a"
|
||||||
|
MAX_HOURS=2
|
||||||
|
CONFIG_FILE=""
|
||||||
|
DATA_DIR="data/recordings"
|
||||||
|
DRY_RUN=false
|
||||||
|
SWEEP=false
|
||||||
|
KEEP_VM=false
|
||||||
|
INSTANCE_NAME=""
|
||||||
|
REPO_URL="https://github.com/ruvnet/wifi-densepose.git"
|
||||||
|
BRANCH="main"
|
||||||
|
|
||||||
|
# ── Parse arguments ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case "$1" in
|
||||||
|
--gpu) GPU_TYPE="$2"; shift 2 ;;
|
||||||
|
--zone) ZONE="$2"; shift 2 ;;
|
||||||
|
--hours) MAX_HOURS="$2"; shift 2 ;;
|
||||||
|
--config) CONFIG_FILE="$2"; shift 2 ;;
|
||||||
|
--data-dir) DATA_DIR="$2"; shift 2 ;;
|
||||||
|
--dry-run) DRY_RUN=true; shift ;;
|
||||||
|
--sweep) SWEEP=true; shift ;;
|
||||||
|
--keep-vm) KEEP_VM=true; shift ;;
|
||||||
|
--instance) INSTANCE_NAME="$2"; shift 2 ;;
|
||||||
|
--branch) BRANCH="$2"; shift 2 ;;
|
||||||
|
-h|--help)
|
||||||
|
head -35 "$0" | tail -30
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "ERROR: Unknown option: $1"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# ── GPU configuration map ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
declare -A GPU_ACCELERATOR=(
|
||||||
|
[l4]="nvidia-l4"
|
||||||
|
[a100]="nvidia-tesla-a100"
|
||||||
|
[h100]="nvidia-h100-80gb"
|
||||||
|
)
|
||||||
|
|
||||||
|
declare -A GPU_MACHINE_TYPE=(
|
||||||
|
[l4]="g2-standard-8"
|
||||||
|
[a100]="a2-highgpu-1g"
|
||||||
|
[h100]="a3-highgpu-1g"
|
||||||
|
)
|
||||||
|
|
||||||
|
declare -A GPU_BOOT_DISK=(
|
||||||
|
[l4]="200"
|
||||||
|
[a100]="300"
|
||||||
|
[h100]="300"
|
||||||
|
)
|
||||||
|
|
||||||
|
if [[ -z "${GPU_ACCELERATOR[$GPU_TYPE]+x}" ]]; then
|
||||||
|
echo "ERROR: Unknown GPU type '$GPU_TYPE'. Choose: l4, a100, h100"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
ACCELERATOR="${GPU_ACCELERATOR[$GPU_TYPE]}"
|
||||||
|
MACHINE_TYPE="${GPU_MACHINE_TYPE[$GPU_TYPE]}"
|
||||||
|
BOOT_DISK_GB="${GPU_BOOT_DISK[$GPU_TYPE]}"
|
||||||
|
|
||||||
|
# ── Instance naming ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
TIMESTAMP=$(date +%Y%m%d-%H%M%S)
|
||||||
|
if [[ -z "$INSTANCE_NAME" ]]; then
|
||||||
|
INSTANCE_NAME="wdp-train-${GPU_TYPE}-${TIMESTAMP}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Announce plan ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
echo "============================================================"
|
||||||
|
echo " WiFi-DensePose GCloud GPU Training"
|
||||||
|
echo "============================================================"
|
||||||
|
echo " Project: $PROJECT"
|
||||||
|
echo " Instance: $INSTANCE_NAME"
|
||||||
|
echo " Zone: $ZONE"
|
||||||
|
echo " GPU: $GPU_TYPE ($ACCELERATOR)"
|
||||||
|
echo " Machine: $MACHINE_TYPE"
|
||||||
|
echo " Boot disk: ${BOOT_DISK_GB}GB"
|
||||||
|
echo " Max runtime: ${MAX_HOURS}h"
|
||||||
|
echo " Data dir: $DATA_DIR"
|
||||||
|
echo " Dry run: $DRY_RUN"
|
||||||
|
echo " Sweep: $SWEEP"
|
||||||
|
echo " Branch: $BRANCH"
|
||||||
|
echo "============================================================"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# ── Verify gcloud auth ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
if ! gcloud auth list --filter=status:ACTIVE --format="value(account)" 2>/dev/null | head -1 | grep -q '@'; then
|
||||||
|
echo "ERROR: No active gcloud account. Run: gcloud auth login"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
gcloud config set project "$PROJECT" --quiet
|
||||||
|
|
||||||
|
# ── Build startup script ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
STARTUP_SCRIPT=$(cat <<'STARTUP_EOF'
|
||||||
|
#!/bin/bash
|
||||||
|
set -euo pipefail
|
||||||
|
exec > /var/log/wdp-setup.log 2>&1
|
||||||
|
|
||||||
|
echo "=== WiFi-DensePose GPU VM Setup ==="
|
||||||
|
echo "Started: $(date)"
|
||||||
|
|
||||||
|
# Wait for GPU driver
|
||||||
|
echo "Waiting for NVIDIA driver..."
|
||||||
|
for i in $(seq 1 60); do
|
||||||
|
if nvidia-smi &>/dev/null; then
|
||||||
|
echo "GPU ready after ${i}s"
|
||||||
|
nvidia-smi
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
sleep 5
|
||||||
|
done
|
||||||
|
|
||||||
|
if ! nvidia-smi &>/dev/null; then
|
||||||
|
echo "ERROR: GPU driver not available after 300s"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install Rust toolchain
|
||||||
|
echo "Installing Rust toolchain..."
|
||||||
|
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
|
||||||
|
source "$HOME/.cargo/env"
|
||||||
|
rustc --version
|
||||||
|
cargo --version
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
echo "Installing system dependencies..."
|
||||||
|
apt-get update -qq
|
||||||
|
apt-get install -y -qq pkg-config libssl-dev cmake clang
|
||||||
|
|
||||||
|
# Find libtorch from the Deep Learning VM's PyTorch installation
|
||||||
|
echo "Locating libtorch..."
|
||||||
|
PYTORCH_LIB=$(python3 -c "import torch; print(torch.__path__[0] + '/lib')" 2>/dev/null || echo "")
|
||||||
|
if [[ -n "$PYTORCH_LIB" && -d "$PYTORCH_LIB" ]]; then
|
||||||
|
export LIBTORCH="$PYTORCH_LIB"
|
||||||
|
export LD_LIBRARY_PATH="${LIBTORCH}:${LD_LIBRARY_PATH:-}"
|
||||||
|
echo "Found libtorch at: $LIBTORCH"
|
||||||
|
else
|
||||||
|
echo "WARNING: PyTorch not found in system Python. Installing via pip..."
|
||||||
|
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
PYTORCH_LIB=$(python3 -c "import torch; print(torch.__path__[0] + '/lib')")
|
||||||
|
export LIBTORCH="$PYTORCH_LIB"
|
||||||
|
export LD_LIBRARY_PATH="${LIBTORCH}:${LD_LIBRARY_PATH:-}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Persist env vars
|
||||||
|
cat >> /etc/environment <<ENV_VARS
|
||||||
|
LIBTORCH=$LIBTORCH
|
||||||
|
LD_LIBRARY_PATH=$LIBTORCH:\$LD_LIBRARY_PATH
|
||||||
|
PATH=$HOME/.cargo/bin:\$PATH
|
||||||
|
ENV_VARS
|
||||||
|
|
||||||
|
echo "=== Setup complete: $(date) ==="
|
||||||
|
touch /tmp/wdp-setup-done
|
||||||
|
STARTUP_EOF
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Step 1: Create the VM ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
echo "[1/7] Creating VM instance: $INSTANCE_NAME ..."
|
||||||
|
|
||||||
|
gcloud compute instances create "$INSTANCE_NAME" \
|
||||||
|
--project="$PROJECT" \
|
||||||
|
--zone="$ZONE" \
|
||||||
|
--machine-type="$MACHINE_TYPE" \
|
||||||
|
--accelerator="type=$ACCELERATOR,count=1" \
|
||||||
|
--image-family="common-cu121-ubuntu-2204" \
|
||||||
|
--image-project="deeplearning-platform-release" \
|
||||||
|
--boot-disk-size="${BOOT_DISK_GB}GB" \
|
||||||
|
--boot-disk-type="pd-ssd" \
|
||||||
|
--maintenance-policy=TERMINATE \
|
||||||
|
--metadata="install-nvidia-driver=True" \
|
||||||
|
--metadata-from-file="startup-script=<(echo "$STARTUP_SCRIPT")" \
|
||||||
|
--scopes="default,storage-rw" \
|
||||||
|
--labels="purpose=wdp-training,gpu=${GPU_TYPE}" \
|
||||||
|
--quiet
|
||||||
|
|
||||||
|
echo " VM created. Waiting for startup script to complete..."
|
||||||
|
|
||||||
|
# ── Step 2: Wait for setup ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
echo "[2/7] Waiting for setup to complete (GPU driver + Rust toolchain)..."
|
||||||
|
|
||||||
|
for i in $(seq 1 60); do
|
||||||
|
if gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="test -f /tmp/wdp-setup-done" --quiet 2>/dev/null; then
|
||||||
|
echo " Setup complete after $((i * 15))s"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
if [[ $i -eq 60 ]]; then
|
||||||
|
echo "ERROR: Setup timed out after 15 minutes."
|
||||||
|
echo "Check logs: gcloud compute ssh $INSTANCE_NAME --zone=$ZONE --command='cat /var/log/wdp-setup.log'"
|
||||||
|
if [[ "$KEEP_VM" == "false" ]]; then
|
||||||
|
echo "Cleaning up VM..."
|
||||||
|
gcloud compute instances delete "$INSTANCE_NAME" --zone="$ZONE" --quiet
|
||||||
|
fi
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
sleep 15
|
||||||
|
done
|
||||||
|
|
||||||
|
# ── Step 3: Clone repo and build ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
echo "[3/7] Cloning repository and building training binary..."
|
||||||
|
|
||||||
|
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<CLONE_EOF
|
||||||
|
set -euo pipefail
|
||||||
|
source \$HOME/.cargo/env
|
||||||
|
|
||||||
|
# Clone the repo
|
||||||
|
if [[ ! -d ~/wifi-densepose ]]; then
|
||||||
|
git clone --depth 1 --branch "$BRANCH" "$REPO_URL" ~/wifi-densepose
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Set libtorch environment
|
||||||
|
export LIBTORCH=\$(python3 -c "import torch; print(torch.__path__[0] + '/lib')")
|
||||||
|
export LD_LIBRARY_PATH="\${LIBTORCH}:\${LD_LIBRARY_PATH:-}"
|
||||||
|
|
||||||
|
# Build the training binary with tch-backend
|
||||||
|
cd ~/wifi-densepose/rust-port/wifi-densepose-rs
|
||||||
|
echo "Building with LIBTORCH=\$LIBTORCH ..."
|
||||||
|
cargo build --release --features tch-backend --bin train 2>&1 | tail -5
|
||||||
|
|
||||||
|
echo "Build complete."
|
||||||
|
ls -lh target/release/train
|
||||||
|
CLONE_EOF
|
||||||
|
)"
|
||||||
|
|
||||||
|
# ── Step 4: Upload training data ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
echo "[4/7] Uploading training data..."
|
||||||
|
|
||||||
|
if [[ -d "$DATA_DIR" ]] && [[ "$(ls -A "$DATA_DIR" 2>/dev/null)" ]]; then
|
||||||
|
# Create a tarball of the data directory
|
||||||
|
DATA_TAR="/tmp/wdp-training-data-${TIMESTAMP}.tar.gz"
|
||||||
|
tar czf "$DATA_TAR" -C "$(dirname "$DATA_DIR")" "$(basename "$DATA_DIR")"
|
||||||
|
DATA_SIZE=$(du -h "$DATA_TAR" | cut -f1)
|
||||||
|
echo " Uploading ${DATA_SIZE} of training data..."
|
||||||
|
|
||||||
|
gcloud compute scp "$DATA_TAR" "${INSTANCE_NAME}:~/training-data.tar.gz" --zone="$ZONE" --quiet
|
||||||
|
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="
|
||||||
|
mkdir -p ~/wifi-densepose/data
|
||||||
|
tar xzf ~/training-data.tar.gz -C ~/wifi-densepose/data/
|
||||||
|
echo 'Data extracted:'
|
||||||
|
find ~/wifi-densepose/data -name '*.jsonl' -o -name '*.csi.jsonl' | head -20
|
||||||
|
"
|
||||||
|
rm -f "$DATA_TAR"
|
||||||
|
else
|
||||||
|
echo " No local data at '$DATA_DIR'. Training will use --dry-run or MM-Fi."
|
||||||
|
if [[ "$DRY_RUN" == "false" && "$SWEEP" == "false" ]]; then
|
||||||
|
echo " WARNING: No data and --dry-run not set. Forcing --dry-run."
|
||||||
|
DRY_RUN=true
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Step 5: Upload config and run training ────────────────────────────────────
|
||||||
|
|
||||||
|
echo "[5/7] Running training..."
|
||||||
|
|
||||||
|
# Upload sweep config if doing a sweep
|
||||||
|
if [[ "$SWEEP" == "true" ]]; then
|
||||||
|
SWEEP_FILE="scripts/training-config-sweep.json"
|
||||||
|
if [[ -f "$SWEEP_FILE" ]]; then
|
||||||
|
gcloud compute scp "$SWEEP_FILE" "${INSTANCE_NAME}:~/sweep-configs.json" --zone="$ZONE" --quiet
|
||||||
|
else
|
||||||
|
echo "ERROR: Sweep config not found at $SWEEP_FILE"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Upload single config if specified
|
||||||
|
if [[ -n "$CONFIG_FILE" ]]; then
|
||||||
|
gcloud compute scp "$CONFIG_FILE" "${INSTANCE_NAME}:~/train-config.json" --zone="$ZONE" --quiet
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build the training command
|
||||||
|
TRAIN_CMD_BASE="
|
||||||
|
set -euo pipefail
|
||||||
|
source \$HOME/.cargo/env
|
||||||
|
export LIBTORCH=\$(python3 -c \"import torch; print(torch.__path__[0] + '/lib')\")
|
||||||
|
export LD_LIBRARY_PATH=\"\${LIBTORCH}:\${LD_LIBRARY_PATH:-}\"
|
||||||
|
cd ~/wifi-densepose/rust-port/wifi-densepose-rs
|
||||||
|
|
||||||
|
# Set auto-shutdown timer (safety net)
|
||||||
|
sudo shutdown -P +$((MAX_HOURS * 60)) &
|
||||||
|
|
||||||
|
TRAIN_BIN=./target/release/train
|
||||||
|
"
|
||||||
|
|
||||||
|
if [[ "$SWEEP" == "true" ]]; then
|
||||||
|
# Run all configs in the sweep file
|
||||||
|
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<SWEEP_EOF
|
||||||
|
$TRAIN_CMD_BASE
|
||||||
|
|
||||||
|
echo "=== Hyperparameter Sweep ==="
|
||||||
|
SWEEP_FILE=~/sweep-configs.json
|
||||||
|
NUM_CONFIGS=\$(python3 -c "import json; print(len(json.load(open('\$SWEEP_FILE'))['configs']))")
|
||||||
|
echo "Running \$NUM_CONFIGS configurations..."
|
||||||
|
|
||||||
|
mkdir -p ~/results
|
||||||
|
|
||||||
|
for i in \$(seq 0 \$((NUM_CONFIGS - 1))); do
|
||||||
|
echo ""
|
||||||
|
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||||
|
echo " Config \$((i+1)) / \$NUM_CONFIGS"
|
||||||
|
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||||
|
|
||||||
|
# Extract single config to temp file
|
||||||
|
python3 -c "
|
||||||
|
import json, sys
|
||||||
|
sweep = json.load(open('\$SWEEP_FILE'))
|
||||||
|
cfg = sweep['configs'][\$i]
|
||||||
|
# Merge with base config
|
||||||
|
base = sweep.get('base', {})
|
||||||
|
merged = {**base, **cfg}
|
||||||
|
# Set checkpoint dir per config
|
||||||
|
merged['checkpoint_dir'] = f'checkpoints/sweep_{i:02d}'
|
||||||
|
merged['log_dir'] = f'logs/sweep_{i:02d}'
|
||||||
|
json.dump(merged, open('/tmp/sweep_config_\${i}.json', 'w'), indent=2)
|
||||||
|
print(f\"Config \${i}: lr={merged.get('learning_rate', '?')}, bs={merged.get('batch_size', '?')}, bb={merged.get('backbone_channels', '?')}\")
|
||||||
|
"
|
||||||
|
|
||||||
|
START_TIME=\$(date +%s)
|
||||||
|
|
||||||
|
\$TRAIN_BIN --config /tmp/sweep_config_\${i}.json --cuda $( [[ "$DRY_RUN" == "true" ]] && echo "--dry-run" ) 2>&1 | tee ~/results/sweep_\${i}.log || true
|
||||||
|
|
||||||
|
END_TIME=\$(date +%s)
|
||||||
|
ELAPSED=\$(( END_TIME - START_TIME ))
|
||||||
|
echo " Completed in \${ELAPSED}s"
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== Sweep Complete ==="
|
||||||
|
echo "Results in ~/results/"
|
||||||
|
ls -lh ~/results/
|
||||||
|
SWEEP_EOF
|
||||||
|
)"
|
||||||
|
elif [[ -n "$CONFIG_FILE" ]]; then
|
||||||
|
# Single config run
|
||||||
|
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<SINGLE_EOF
|
||||||
|
$TRAIN_CMD_BASE
|
||||||
|
echo "=== Training with custom config ==="
|
||||||
|
\$TRAIN_BIN --config ~/train-config.json --cuda $( [[ "$DRY_RUN" == "true" ]] && echo "--dry-run" ) 2>&1 | tee ~/train.log
|
||||||
|
SINGLE_EOF
|
||||||
|
)"
|
||||||
|
else
|
||||||
|
# Default config run
|
||||||
|
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<DEFAULT_EOF
|
||||||
|
$TRAIN_CMD_BASE
|
||||||
|
echo "=== Training with default config ==="
|
||||||
|
\$TRAIN_BIN --cuda $( [[ "$DRY_RUN" == "true" ]] && echo "--dry-run --dry-run-samples 256" ) 2>&1 | tee ~/train.log
|
||||||
|
DEFAULT_EOF
|
||||||
|
)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Step 6: Download results ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
echo "[6/7] Downloading trained model artifacts..."
|
||||||
|
|
||||||
|
LOCAL_RESULTS="training-results/${INSTANCE_NAME}"
|
||||||
|
mkdir -p "$LOCAL_RESULTS"
|
||||||
|
|
||||||
|
# Package results on the VM
|
||||||
|
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="
|
||||||
|
cd ~/wifi-densepose/rust-port/wifi-densepose-rs
|
||||||
|
tar czf ~/training-artifacts.tar.gz \
|
||||||
|
checkpoints/ \
|
||||||
|
logs/ \
|
||||||
|
2>/dev/null || true
|
||||||
|
|
||||||
|
# Also grab sweep results if they exist
|
||||||
|
if [[ -d ~/results ]]; then
|
||||||
|
tar czf ~/sweep-results.tar.gz -C ~ results/ 2>/dev/null || true
|
||||||
|
fi
|
||||||
|
|
||||||
|
ls -lh ~/training-artifacts.tar.gz ~/sweep-results.tar.gz 2>/dev/null || true
|
||||||
|
"
|
||||||
|
|
||||||
|
# Download artifacts
|
||||||
|
gcloud compute scp "${INSTANCE_NAME}:~/training-artifacts.tar.gz" \
|
||||||
|
"${LOCAL_RESULTS}/training-artifacts.tar.gz" --zone="$ZONE" --quiet 2>/dev/null || true
|
||||||
|
|
||||||
|
if [[ "$SWEEP" == "true" ]]; then
|
||||||
|
gcloud compute scp "${INSTANCE_NAME}:~/sweep-results.tar.gz" \
|
||||||
|
"${LOCAL_RESULTS}/sweep-results.tar.gz" --zone="$ZONE" --quiet 2>/dev/null || true
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Download training log
|
||||||
|
gcloud compute scp "${INSTANCE_NAME}:~/train.log" \
|
||||||
|
"${LOCAL_RESULTS}/train.log" --zone="$ZONE" --quiet 2>/dev/null || true
|
||||||
|
|
||||||
|
# Extract locally
|
||||||
|
if [[ -f "${LOCAL_RESULTS}/training-artifacts.tar.gz" ]]; then
|
||||||
|
tar xzf "${LOCAL_RESULTS}/training-artifacts.tar.gz" -C "$LOCAL_RESULTS/"
|
||||||
|
echo " Artifacts extracted to: $LOCAL_RESULTS/"
|
||||||
|
find "$LOCAL_RESULTS" -name "*.pt" -o -name "*.onnx" -o -name "*.rvf" 2>/dev/null | head -20
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Step 7: Cleanup ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
if [[ "$KEEP_VM" == "true" ]]; then
|
||||||
|
echo "[7/7] Keeping VM alive (--keep-vm). Remember to delete it manually:"
|
||||||
|
echo " gcloud compute instances delete $INSTANCE_NAME --zone=$ZONE --quiet"
|
||||||
|
echo " SSH: gcloud compute ssh $INSTANCE_NAME --zone=$ZONE"
|
||||||
|
else
|
||||||
|
echo "[7/7] Deleting VM to avoid ongoing costs..."
|
||||||
|
gcloud compute instances delete "$INSTANCE_NAME" --zone="$ZONE" --quiet
|
||||||
|
echo " VM deleted."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Summary ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "============================================================"
|
||||||
|
echo " Training Complete"
|
||||||
|
echo "============================================================"
|
||||||
|
echo " Results: $LOCAL_RESULTS/"
|
||||||
|
echo " GPU: $GPU_TYPE ($ZONE)"
|
||||||
|
echo " Instance: $INSTANCE_NAME"
|
||||||
|
if [[ "$KEEP_VM" == "true" ]]; then
|
||||||
|
echo " VM: STILL RUNNING (delete manually!)"
|
||||||
|
fi
|
||||||
|
echo "============================================================"
|
||||||
|
|
@ -0,0 +1,155 @@
|
||||||
|
{
|
||||||
|
"description": "WiFi-DensePose hyperparameter sweep — 10 configurations exploring learning rate, batch size, backbone width, window length, loss ratios, and warmup schedules.",
|
||||||
|
"base": {
|
||||||
|
"num_subcarriers": 56,
|
||||||
|
"native_subcarriers": 114,
|
||||||
|
"num_antennas_tx": 3,
|
||||||
|
"num_antennas_rx": 3,
|
||||||
|
"heatmap_size": 56,
|
||||||
|
"num_keypoints": 17,
|
||||||
|
"num_body_parts": 24,
|
||||||
|
"weight_decay": 1e-4,
|
||||||
|
"num_epochs": 50,
|
||||||
|
"lr_gamma": 0.1,
|
||||||
|
"grad_clip_norm": 1.0,
|
||||||
|
"val_every_epochs": 1,
|
||||||
|
"early_stopping_patience": 10,
|
||||||
|
"save_top_k": 3,
|
||||||
|
"use_gpu": true,
|
||||||
|
"gpu_device_id": 0,
|
||||||
|
"num_workers": 4,
|
||||||
|
"seed": 42
|
||||||
|
},
|
||||||
|
"configs": [
|
||||||
|
{
|
||||||
|
"_name": "baseline",
|
||||||
|
"_description": "Default config — reference baseline",
|
||||||
|
"learning_rate": 1e-3,
|
||||||
|
"batch_size": 8,
|
||||||
|
"backbone_channels": 256,
|
||||||
|
"window_frames": 100,
|
||||||
|
"warmup_epochs": 5,
|
||||||
|
"lr_milestones": [30, 45],
|
||||||
|
"lambda_kp": 0.3,
|
||||||
|
"lambda_dp": 0.6,
|
||||||
|
"lambda_tr": 0.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_name": "low_lr_large_batch",
|
||||||
|
"_description": "Lower LR with larger batch — stable convergence",
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"batch_size": 16,
|
||||||
|
"backbone_channels": 256,
|
||||||
|
"window_frames": 100,
|
||||||
|
"warmup_epochs": 10,
|
||||||
|
"lr_milestones": [30, 45],
|
||||||
|
"lambda_kp": 0.3,
|
||||||
|
"lambda_dp": 0.6,
|
||||||
|
"lambda_tr": 0.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_name": "high_lr_small_batch",
|
||||||
|
"_description": "Higher LR with small batch — fast exploration",
|
||||||
|
"learning_rate": 2e-3,
|
||||||
|
"batch_size": 4,
|
||||||
|
"backbone_channels": 256,
|
||||||
|
"window_frames": 100,
|
||||||
|
"warmup_epochs": 3,
|
||||||
|
"lr_milestones": [20, 40],
|
||||||
|
"lambda_kp": 0.3,
|
||||||
|
"lambda_dp": 0.6,
|
||||||
|
"lambda_tr": 0.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_name": "narrow_backbone",
|
||||||
|
"_description": "128-channel backbone — faster training, lower VRAM",
|
||||||
|
"learning_rate": 1e-3,
|
||||||
|
"batch_size": 16,
|
||||||
|
"backbone_channels": 128,
|
||||||
|
"window_frames": 100,
|
||||||
|
"warmup_epochs": 5,
|
||||||
|
"lr_milestones": [30, 45],
|
||||||
|
"lambda_kp": 0.3,
|
||||||
|
"lambda_dp": 0.6,
|
||||||
|
"lambda_tr": 0.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_name": "short_window",
|
||||||
|
"_description": "50-frame window — lower latency, tests temporal sensitivity",
|
||||||
|
"learning_rate": 5e-4,
|
||||||
|
"batch_size": 16,
|
||||||
|
"backbone_channels": 256,
|
||||||
|
"window_frames": 50,
|
||||||
|
"warmup_epochs": 5,
|
||||||
|
"lr_milestones": [30, 45],
|
||||||
|
"lambda_kp": 0.3,
|
||||||
|
"lambda_dp": 0.6,
|
||||||
|
"lambda_tr": 0.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_name": "keypoint_heavy",
|
||||||
|
"_description": "Heavier keypoint loss — prioritize skeleton accuracy",
|
||||||
|
"learning_rate": 5e-4,
|
||||||
|
"batch_size": 8,
|
||||||
|
"backbone_channels": 256,
|
||||||
|
"window_frames": 100,
|
||||||
|
"warmup_epochs": 5,
|
||||||
|
"lr_milestones": [30, 45],
|
||||||
|
"lambda_kp": 0.5,
|
||||||
|
"lambda_dp": 0.4,
|
||||||
|
"lambda_tr": 0.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_name": "contrastive_heavy",
|
||||||
|
"_description": "Strong contrastive/transfer loss — self-supervised pretraining focus",
|
||||||
|
"learning_rate": 5e-4,
|
||||||
|
"batch_size": 8,
|
||||||
|
"backbone_channels": 256,
|
||||||
|
"window_frames": 100,
|
||||||
|
"warmup_epochs": 10,
|
||||||
|
"lr_milestones": [30, 45],
|
||||||
|
"lambda_kp": 0.2,
|
||||||
|
"lambda_dp": 0.3,
|
||||||
|
"lambda_tr": 0.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_name": "wide_backbone_long_warmup",
|
||||||
|
"_description": "256-ch backbone + long warmup + moderate LR",
|
||||||
|
"learning_rate": 5e-4,
|
||||||
|
"batch_size": 8,
|
||||||
|
"backbone_channels": 256,
|
||||||
|
"window_frames": 100,
|
||||||
|
"warmup_epochs": 10,
|
||||||
|
"lr_milestones": [35, 48],
|
||||||
|
"lambda_kp": 0.3,
|
||||||
|
"lambda_dp": 0.6,
|
||||||
|
"lambda_tr": 0.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_name": "narrow_short_aggressive",
|
||||||
|
"_description": "128-ch + 50-frame + high LR — fast cheap exploration",
|
||||||
|
"learning_rate": 2e-3,
|
||||||
|
"batch_size": 16,
|
||||||
|
"backbone_channels": 128,
|
||||||
|
"window_frames": 50,
|
||||||
|
"warmup_epochs": 3,
|
||||||
|
"lr_milestones": [20, 40],
|
||||||
|
"lambda_kp": 0.4,
|
||||||
|
"lambda_dp": 0.5,
|
||||||
|
"lambda_tr": 0.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_name": "balanced_medium",
|
||||||
|
"_description": "Balanced loss, medium LR, medium batch — robust default",
|
||||||
|
"learning_rate": 5e-4,
|
||||||
|
"batch_size": 8,
|
||||||
|
"backbone_channels": 256,
|
||||||
|
"window_frames": 100,
|
||||||
|
"warmup_epochs": 5,
|
||||||
|
"lr_milestones": [25, 40],
|
||||||
|
"lambda_kp": 0.35,
|
||||||
|
"lambda_dp": 0.45,
|
||||||
|
"lambda_tr": 0.2
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue