wifi-densepose/scripts/occworld_retrain.py

286 lines
11 KiB
Python

"""
Phase 5 — OccWorld VQVAE + Transformer retraining on RuView indoor occupancy.

Two-stage training pipeline:
  Stage 1: Retrain VQVAE tokenizer on RuView snapshots
  Stage 2: Retrain autoregressive transformer on tokenized sequences

Usage:
  # Stage 1: VQVAE
  python3 scripts/occworld_retrain.py vqvae \
      --snapshots /tmp/snapshots/ \
      --work-dir out/ruview_vqvae \
      --epochs 200

  # Stage 2: Transformer (requires Stage 1 checkpoint)
  python3 scripts/occworld_retrain.py transformer \
      --snapshots /tmp/snapshots/ \
      --vqvae-checkpoint out/ruview_vqvae/latest.pth \
      --work-dir out/ruview_occworld \
      --epochs 200

  # Generate training snapshots from the live sensing server
  python3 scripts/occworld_retrain.py record \
      --server http://localhost:8080 \
      --out-dir /tmp/snapshots/scene_live \
      --duration 3600

Requirements:
  ml-env with OccWorld installed (see ADR-147 §3)
  At least 16 GB VRAM for training (RTX 5080 sufficient at batch=1)
"""
from __future__ import annotations
import argparse
import logging
import os
import sys
import time
from pathlib import Path
# Module-level logger shared by every sub-command; its level and output format
# are set by logging.basicConfig() when this file is run as a script.
log = logging.getLogger(__name__)
# ── Stage 0: Record snapshots from the live sensing server ───────────────────
def cmd_record(args: argparse.Namespace) -> None:
    """Poll the sensing server for WorldGraph snapshots and save each as JSON.

    For ``args.duration`` seconds the ``/api/v1/worldgraph/snapshot`` endpoint
    of ``args.server`` is fetched, sleeping ``args.interval`` seconds between
    polls.  Each decoded payload is written to ``args.out_dir`` as
    ``frame_NNNNNN.json``.  A failed poll is logged and skipped without
    advancing the frame counter, so file numbering stays contiguous.

    Args:
        args: Parsed CLI namespace providing ``server``, ``out_dir``,
            ``duration`` and ``interval``.
    """
    import json
    import urllib.request

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    url = f"{args.server.rstrip('/')}/api/v1/worldgraph/snapshot"
    interval = args.interval
    # Use the monotonic clock for the deadline: wall-clock adjustments (NTP
    # step, manual change) must not shorten or extend the recording.
    deadline = time.monotonic() + args.duration
    frame_idx = 0

    log.info("Recording snapshots from %s to %s for %ds", url, out_dir, args.duration)
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=5) as resp:
                snap = json.loads(resp.read())
            out_path = out_dir / f"frame_{frame_idx:06d}.json"
            out_path.write_text(json.dumps(snap), encoding="utf-8")
            frame_idx += 1
            if frame_idx % 100 == 0:
                log.info("Recorded %d frames", frame_idx)
        except Exception as exc:
            # Deliberately broad: recording is best-effort, so any network,
            # decode or write error costs one frame rather than the session.
            log.warning("Snapshot fetch failed: %s", exc)
        time.sleep(interval)
    log.info("Done — recorded %d frames to %s", frame_idx, out_dir)
# ── Stage 1: VQVAE retraining ────────────────────────────────────────────────
def cmd_vqvae(args: argparse.Namespace) -> None:
    """Stage 1: fit the OccWorld VQVAE tokenizer to recorded RuView occupancy.

    Only the ``cfg.model.vae`` sub-model is built.  It is trained with AdamW
    and a cosine learning-rate schedule on windows produced by
    ``RuViewOccDataset``; the loss is voxel-wise cross-entropy between the
    decoder output and the input occupancy, plus the quantizer loss.  Whenever
    the mean epoch loss improves, the weights are written to
    ``<work-dir>/latest.pth``.

    Args:
        args: Parsed CLI namespace providing ``occworld_dir``, ``config``,
            ``snapshots``, ``work_dir``, ``epochs``, ``voxel_m``, ``x_min``,
            ``y_min`` and ``no_shuffle``.

    Exits with status 1 if the OccWorld ``model`` package cannot be imported
    or the dataset yields no training windows.
    """
    # The OccWorld checkout must be importable before its packages are loaded.
    occworld_root = Path(args.occworld_dir).resolve()
    sys.path.insert(0, str(occworld_root))

    import torch
    from mmengine.config import Config
    from mmengine.registry import MODELS

    try:
        import model as occmodel  # noqa: F401 — registers custom MODELS
    except ImportError:
        log.error("Could not import OccWorld model package. Set --occworld-dir correctly.")
        sys.exit(1)
    from ruview_occ_dataset import RuViewOccDataset

    cfg = Config.fromfile(args.config)
    work_dir = Path(args.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = work_dir / "latest.pth"

    # Instantiate the tokenizer alone; the transformer is not needed here.
    tokenizer = MODELS.build(cfg.model.vae).cuda()
    param_count = sum(p.numel() for p in tokenizer.parameters())
    log.info("VQVAE params: %.1fM", param_count / 1e6)

    window_len = cfg.model.get("num_frames", 15) + 1
    dataset = RuViewOccDataset(
        args.snapshots,
        return_len=window_len,
        voxel_m=args.voxel_m,
        x_min=args.x_min,
        y_min=args.y_min,
    )
    log.info("Dataset: %d windows from %s", len(dataset), args.snapshots)
    if len(dataset) == 0:
        log.error("No training windows found in %s — record snapshots first.", args.snapshots)
        sys.exit(1)

    def _unwrap_single(samples):
        # batch_size is 1, so hand the lone sample dict through unchanged.
        return samples[0]

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=not args.no_shuffle,
        num_workers=0,
        collate_fn=_unwrap_single,
    )
    optimizer = torch.optim.AdamW(tokenizer.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    cross_entropy = torch.nn.functional.cross_entropy

    best_loss = float("inf")
    for epoch in range(args.epochs):
        tokenizer.train()
        running_loss = 0.0
        for sample in loader:
            # Original author's note gives the resulting shape as (1,F,H,W,D).
            occ = torch.from_numpy(sample["target_occs"]).long().unsqueeze(0).cuda()

            # Encode -> quantize -> decode, done step by step so the
            # quantizer loss is available alongside the reconstruction.
            latent, latent_shape = tokenizer.forward_encoder(occ)
            latent = tokenizer.vqvae.quant_conv(latent)
            quantized, vq_loss, _ = tokenizer.vqvae.forward_quantizer(latent, is_voxel=False)
            quantized = tokenizer.vqvae.post_quant_conv(quantized)
            logits = tokenizer.forward_decoder(quantized, latent_shape, occ.shape)

            recon_loss = cross_entropy(logits.flatten(0, -2), occ.flatten())
            total_loss = recon_loss + vq_loss

            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(tokenizer.parameters(), 1.0)
            optimizer.step()
            running_loss += total_loss.item()

        scheduler.step()
        mean_loss = running_loss / max(len(loader), 1)
        if epoch % 10 == 0:
            log.info("Epoch %d/%d loss=%.4f lr=%.2e", epoch + 1, args.epochs, mean_loss, scheduler.get_last_lr()[0])
        if mean_loss < best_loss:
            # Despite the file name, this holds the best-so-far weights.
            best_loss = mean_loss
            snapshot = {"epoch": epoch, "state_dict": tokenizer.state_dict(), "loss": mean_loss}
            torch.save(snapshot, checkpoint_path)

    log.info("VQVAE training complete. Best loss=%.4f checkpoint: %s/latest.pth",
             best_loss, work_dir)
# ── Stage 2: Transformer retraining ─────────────────────────────────────────
def cmd_transformer(args: argparse.Namespace) -> None:
    """Stage 2: train the OccWorld autoregressive transformer on RuView tokens.

    The full model from ``cfg.model`` is built, the Stage 1 VQVAE weights are
    loaded when ``args.vqvae_checkpoint`` is set, and the VQVAE is put in eval
    mode with gradients disabled.  Each training window is encoded and
    quantized without gradient tracking; the transformer is then optimised
    with cross-entropy against the quantizer's code indices.  A checkpoint of
    the whole model is written to ``<work-dir>/latest.pth`` every 10th epoch
    and after the final epoch.

    Args:
        args: Parsed CLI namespace providing ``occworld_dir``, ``config``,
            ``snapshots``, ``vqvae_checkpoint``, ``work_dir`` and ``epochs``.

    Exits with status 1 if the OccWorld ``model`` package cannot be imported
    or the dataset yields no training windows.
    """
    sys.path.insert(0, str(Path(args.occworld_dir).resolve()))

    import torch
    from einops import rearrange
    from mmengine.config import Config
    from mmengine.registry import MODELS

    try:
        import model as occmodel  # noqa: F401 — import registers custom MODELS
    except ImportError:
        log.error("OccWorld model package not found.")
        sys.exit(1)
    from ruview_occ_dataset import RuViewOccDataset

    cfg = Config.fromfile(args.config)
    work_dir = Path(args.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = work_dir / "latest.pth"

    full_model = MODELS.build(cfg.model).cuda()

    # Load VQVAE checkpoint if provided
    if args.vqvae_checkpoint:
        # NOTE(review): torch.load unpickles the file — only point this at
        # checkpoints produced by a trusted Stage 1 run.
        ck = torch.load(args.vqvae_checkpoint, map_location="cuda")
        full_model.vae.load_state_dict(ck["state_dict"])
        log.info("Loaded VQVAE checkpoint: %s", args.vqvae_checkpoint)
    else:
        log.warning("No --vqvae-checkpoint given; the VQVAE keeps the weights it was built with.")

    # The tokenizer is a fixed feature extractor in this stage.
    full_model.vae.eval()
    for p in full_model.vae.parameters():
        p.requires_grad_(False)
    log.info("Transformer params: %.1fM",
             sum(p.numel() for p in full_model.transformer.parameters()) / 1e6)

    # NOTE(review): Stage 1 forwards --voxel-m/--x-min/--y-min to the dataset,
    # whereas this stage relies on the dataset's own defaults — confirm the two
    # agree if Stage 1 was run with non-default grid values.
    ds = RuViewOccDataset(args.snapshots, return_len=cfg.model.get("num_frames", 15) + 1)
    log.info("Dataset: %d windows from %s", len(ds), args.snapshots)
    if len(ds) == 0:
        # Same guard as Stage 1: never "train" on nothing and then save an
        # untouched model as if it were a result.
        log.error("No training windows found in %s — record snapshots first.", args.snapshots)
        sys.exit(1)

    loader = torch.utils.data.DataLoader(
        ds, batch_size=1, shuffle=True, num_workers=0,
        collate_fn=lambda b: b[0],  # batch_size=1: pass the sample dict through
    )
    opt = torch.optim.AdamW(full_model.transformer.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)

    final_epoch = args.epochs - 1
    for epoch in range(args.epochs):
        full_model.transformer.train()
        epoch_loss = 0.0
        for batch in loader:
            occ = torch.from_numpy(batch["target_occs"]).long().unsqueeze(0).cuda()
            with torch.no_grad():
                # Tokenize with the frozen VQVAE; `indices` are the code ids
                # that serve as classification targets below.
                z, shape = full_model.vae.forward_encoder(occ)
                z = full_model.vae.vqvae.quant_conv(z)
                z_q, _, (_, _, indices) = full_model.vae.vqvae.forward_quantizer(z, is_voxel=False)
                z_q = rearrange(z_q, "(b f) c h w -> b f c h w", b=1)
            bs, _, C, _, _ = z_q.shape
            # All-zero pose tokens are fed in place of ego-pose input.
            pose_tokens = torch.zeros(bs, full_model.num_frames, C, device=z_q.device)
            pred_tokens, _ = full_model.transformer(z_q[:, :full_model.num_frames], pose_tokens)
            indices_target = rearrange(indices, "(b f) h w -> b f h w", b=bs)[:, full_model.offset:]
            loss = torch.nn.functional.cross_entropy(
                pred_tokens.flatten(0, 1),
                indices_target.flatten(0, 1).flatten(1),
            )
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(full_model.transformer.parameters(), 1.0)
            opt.step()
            epoch_loss += loss.item()
        scheduler.step()

        # Checkpoint every 10th epoch and always after the last one; otherwise
        # the tail of training (e.g. epochs 191-199 of 200) would never reach
        # latest.pth.
        if epoch % 10 == 0 or epoch == final_epoch:
            avg = epoch_loss / max(len(loader), 1)
            log.info("Epoch %d/%d loss=%.4f", epoch + 1, args.epochs, avg)
            torch.save({"epoch": epoch, "state_dict": full_model.state_dict(), "loss": avg},
                       checkpoint_path)
    log.info("Transformer training complete. Checkpoint: %s/latest.pth", work_dir)
# ── CLI ──────────────────────────────────────────────────────────────────────
def _build_parser() -> argparse.ArgumentParser:
    """Assemble the command-line interface for the retraining pipeline.

    Two global options (``--occworld-dir``, ``--config``) precede a mandatory
    sub-command stored in ``cmd``: ``record``, ``vqvae`` or ``transformer``.

    Returns:
        The configured top-level parser.
    """
    parser = argparse.ArgumentParser(
        description="OccWorld retraining pipeline for RuView (ADR-147 Phase 5)",
    )
    parser.add_argument(
        "--occworld-dir",
        default=os.path.expanduser("~/projects/OccWorld"),
        help="Path to OccWorld repo root",
    )
    parser.add_argument(
        "--config",
        default=os.path.expanduser("~/projects/OccWorld/config/occworld.py"),
        help="OccWorld config file",
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    # Each sub-command is described as (name, help, ordered option table);
    # options are registered in table order so --help output is unchanged.
    command_specs = (
        (
            "record",
            "Record WorldGraph snapshots from sensing server",
            (
                ("--server", {"default": "http://localhost:8080"}),
                ("--out-dir", {"required": True}),
                ("--duration", {"type": int, "default": 3600, "help": "Recording duration (s)"}),
                ("--interval", {"type": float, "default": 0.5, "help": "Poll interval (s)"}),
            ),
        ),
        (
            "vqvae",
            "Retrain VQVAE tokenizer",
            (
                ("--snapshots", {"required": True}),
                ("--work-dir", {"default": "out/ruview_vqvae"}),
                ("--epochs", {"type": int, "default": 200}),
                ("--voxel-m", {"type": float, "dest": "voxel_m", "default": 0.4}),
                ("--x-min", {"type": float, "dest": "x_min", "default": -40.0}),
                ("--y-min", {"type": float, "dest": "y_min", "default": -40.0}),
                ("--no-shuffle", {"action": "store_true"}),
            ),
        ),
        (
            "transformer",
            "Retrain autoregressive transformer",
            (
                ("--snapshots", {"required": True}),
                ("--vqvae-checkpoint", {"default": None}),
                ("--work-dir", {"default": "out/ruview_occworld"}),
                ("--epochs", {"type": int, "default": 200}),
            ),
        ),
    )
    for name, summary, option_table in command_specs:
        sub_parser = commands.add_parser(name, help=summary)
        for flag, settings in option_table:
            sub_parser.add_argument(flag, **settings)
    return parser
if __name__ == "__main__":
    # Script entry point: set up logging, parse the CLI, then hand the parsed
    # namespace to the handler registered for the chosen sub-command.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    args = _build_parser().parse_args()
    handlers = {
        "record": cmd_record,
        "vqvae": cmd_vqvae,
        "transformer": cmd_transformer,
    }
    run = handlers[args.cmd]
    run(args)