286 lines
11 KiB
Python
286 lines
11 KiB
Python
"""
|
|
Phase 5 — OccWorld VQVAE + Transformer retraining on RuView indoor occupancy.
|
|
|
|
Two-stage training pipeline:
|
|
Stage 1: Retrain VQVAE tokenizer on RuView snapshots
|
|
Stage 2: Retrain autoregressive transformer on tokenized sequences
|
|
|
|
Usage:
|
|
# Stage 1: VQVAE
|
|
python3 scripts/occworld_retrain.py vqvae \
|
|
--snapshots /tmp/snapshots/ \
|
|
--work-dir out/ruview_vqvae \
|
|
--epochs 200
|
|
|
|
# Stage 2: Transformer (requires Stage 1 checkpoint)
|
|
python3 scripts/occworld_retrain.py transformer \
|
|
--snapshots /tmp/snapshots/ \
|
|
--vqvae-checkpoint out/ruview_vqvae/latest.pth \
|
|
--work-dir out/ruview_occworld \
|
|
--epochs 200
|
|
|
|
# Generate training snapshots from the live sensing server
|
|
python3 scripts/occworld_retrain.py record \
|
|
--server http://localhost:8080 \
|
|
--out-dir /tmp/snapshots/scene_live \
|
|
--duration 3600
|
|
|
|
Requirements:
|
|
ml-env with OccWorld installed (see ADR-147 §3)
|
|
At least 16 GB VRAM for training (RTX 5080 sufficient at batch=1)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# Module-level logger named after this module.
log: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
# ── Stage 0: Record snapshots from the live sensing server ───────────────────
|
|
|
|
def cmd_record(args: argparse.Namespace) -> None:
    """Stream WorldGraph snapshots from the sensing server REST API.

    Polls ``<server>/api/v1/worldgraph/snapshot`` roughly every
    ``args.interval`` seconds until ``args.duration`` seconds have elapsed.
    Each decoded JSON response is written to ``<out_dir>/frame_NNNNNN.json``.
    Numbering starts at 000000, so frames with the same names from an
    earlier run are overwritten.

    Fetch and decode failures are logged and skipped, so a flaky server does
    not abort a long recording. Each frame goes to a hidden temp file first
    and is then renamed into place. An interrupted run therefore never leaves
    a truncated ``frame_*.json``. Ctrl-C stops recording early and the number
    of captured frames is still reported.

    Args:
        args: Parsed CLI namespace with ``server``, ``out_dir``,
            ``duration`` (seconds) and ``interval`` (seconds).

    Raises:
        OSError: If a frame cannot be written into ``out_dir``.
    """
    import json
    import urllib.request

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    url = f"{args.server.rstrip('/')}/api/v1/worldgraph/snapshot"
    # Monotonic clock: the deadline is unaffected by wall-clock jumps (NTP, manual changes).
    deadline = time.monotonic() + args.duration
    interval = args.interval
    frame_idx = 0

    log.info("Recording snapshots from %s → %s for %ds", url, out_dir, args.duration)

    try:
        while time.monotonic() < deadline:
            try:
                with urllib.request.urlopen(url, timeout=5) as resp:
                    snap = json.loads(resp.read())
            except Exception as exc:  # best-effort: tolerate transient server/network/JSON errors
                log.warning("Snapshot fetch failed: %s", exc)
            else:
                out_path = out_dir / f"frame_{frame_idx:06d}.json"
                # Hidden name with a different suffix, so it never matches frame_*.json.
                tmp_path = out_path.with_name(f".{out_path.name}.tmp")
                try:
                    tmp_path.write_text(json.dumps(snap))
                    tmp_path.replace(out_path)  # same-directory rename: atomic on POSIX
                finally:
                    tmp_path.unlink(missing_ok=True)  # no-op once the rename succeeded
                frame_idx += 1
                if frame_idx % 100 == 0:
                    log.info("Recorded %d frames", frame_idx)
            # Never sleep past the deadline.
            time.sleep(max(0.0, min(interval, deadline - time.monotonic())))
    except KeyboardInterrupt:
        log.info("Interrupted — stopping recording early")

    log.info("Done — recorded %d frames to %s", frame_idx, out_dir)
|
|
|
|
|
|
# ── Stage 1: VQVAE retraining ────────────────────────────────────────────────
|
|
|
|
def cmd_vqvae(args: argparse.Namespace) -> None:
    """Retrain the OccWorld VQVAE tokenizer on RuView indoor occupancy.

    Builds only ``cfg.model.vae`` from the OccWorld config. It is trained
    with a per-voxel cross-entropy reconstruction loss plus the quantizer
    loss. Whenever the epoch-average loss improves, the weights are written
    to ``<work_dir>/latest.pth``. The file name is what the Stage 2 usage
    points at, but it holds the *best* epoch, not necessarily the last one.

    Args:
        args: Parsed CLI namespace. Reads ``occworld_dir``, ``config``,
            ``work_dir``, ``snapshots``, ``epochs``, ``voxel_m``, ``x_min``,
            ``y_min`` and ``no_shuffle``.

    Exits with status 1 if the OccWorld package cannot be imported or the
    dataset yields no training windows.
    """
    # Make OccWorld's top-level ``model`` package importable.
    sys.path.insert(0, str(Path(args.occworld_dir).resolve()))

    import torch
    from mmengine.config import Config
    from mmengine.registry import MODELS

    try:
        import model as occmodel  # noqa: F401 — registers custom MODELS
    except ImportError:
        log.error("Could not import OccWorld model package. Set --occworld-dir correctly.")
        sys.exit(1)

    from ruview_occ_dataset import RuViewOccDataset

    cfg = Config.fromfile(args.config)
    work_dir = Path(args.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = work_dir / "latest.pth"

    # Build VQVAE only
    vae = MODELS.build(cfg.model.vae).cuda()
    log.info("VQVAE params: %.1fM", sum(p.numel() for p in vae.parameters()) / 1e6)

    ds = RuViewOccDataset(
        args.snapshots,
        return_len=cfg.model.get("num_frames", 15) + 1,
        voxel_m=args.voxel_m,
        x_min=args.x_min,
        y_min=args.y_min,
    )
    log.info("Dataset: %d windows from %s", len(ds), args.snapshots)

    if len(ds) == 0:
        log.error("No training windows found in %s — record snapshots first.", args.snapshots)
        sys.exit(1)

    loader = torch.utils.data.DataLoader(
        ds, batch_size=1, shuffle=not args.no_shuffle, num_workers=0,
        collate_fn=lambda b: b[0],  # batch_size=1: pass the single sample dict through
    )

    opt = torch.optim.AdamW(vae.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)

    best_loss = float("inf")
    last_epoch = args.epochs - 1
    for epoch in range(args.epochs):
        vae.train()
        epoch_loss = 0.0
        for batch in loader:
            occ = torch.from_numpy(batch["target_occs"]).long().unsqueeze(0).cuda()  # (1,F,H,W,D)
            # VQVAE forward: encode -> quantize -> decode.
            z, shape = vae.forward_encoder(occ)
            z = vae.vqvae.quant_conv(z)
            z_q, vq_loss, _ = vae.vqvae.forward_quantizer(z, is_voxel=False)
            z_q = vae.vqvae.post_quant_conv(z_q)
            recon = vae.forward_decoder(z_q, shape, occ.shape)
            # Per-voxel classification: 2-D logits (N, classes) against labels (N,).
            recon_loss = torch.nn.functional.cross_entropy(
                recon.flatten(0, -2),
                occ.flatten(),
            )
            loss = recon_loss + vq_loss

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
            opt.step()
            epoch_loss += loss.item()

        scheduler.step()
        avg = epoch_loss / max(len(loader), 1)
        # Log every 10th epoch and always the final one.
        if epoch % 10 == 0 or epoch == last_epoch:
            log.info("Epoch %d/%d loss=%.4f lr=%.2e", epoch + 1, args.epochs, avg, scheduler.get_last_lr()[0])

        # NaN/inf averages never compare below best_loss, so they are never saved.
        if avg < best_loss:
            best_loss = avg
            torch.save({"epoch": epoch, "state_dict": vae.state_dict(), "loss": avg}, ckpt_path)

    if best_loss == float("inf"):
        log.warning("VQVAE training finished without writing a checkpoint "
                    "(no epoch ran or every epoch loss was non-finite).")
    else:
        log.info("VQVAE training complete. Best loss=%.4f checkpoint: %s/latest.pth",
                 best_loss, work_dir)
|
|
|
|
|
|
# ── Stage 2: Transformer retraining ─────────────────────────────────────────
|
|
|
|
def cmd_transformer(args: argparse.Namespace) -> None:
    """Retrain the OccWorld autoregressive transformer on tokenized RuView sequences.

    Builds the full OccWorld model from the config and optionally loads
    Stage 1 VQVAE weights into ``full_model.vae``. The VQVAE is then frozen.
    Only ``full_model.transformer`` is trained, to predict the VQVAE codebook
    indices of the frames from ``full_model.offset`` onward.

    The full model state dict is written to ``<work_dir>/latest.pth`` every
    10th epoch and after the final epoch.

    Args:
        args: Parsed CLI namespace. Reads ``occworld_dir``, ``config``,
            ``work_dir``, ``snapshots``, ``epochs`` and ``vqvae_checkpoint``.
            When the parser defines ``voxel_m``, ``x_min``, ``y_min`` and
            ``no_shuffle``, they are honoured too, so Stage 2 can reuse the
            Stage 1 voxel grid.

    Exits with status 1 if the OccWorld package cannot be imported or the
    dataset yields no training windows.
    """
    # Make OccWorld's top-level ``model`` package importable.
    sys.path.insert(0, str(Path(args.occworld_dir).resolve()))

    import torch
    from einops import rearrange
    from mmengine.config import Config
    from mmengine.registry import MODELS

    try:
        import model as occmodel  # noqa: F401 — registers custom MODELS
    except ImportError:
        log.error("OccWorld model package not found.")
        sys.exit(1)

    from ruview_occ_dataset import RuViewOccDataset

    cfg = Config.fromfile(args.config)
    work_dir = Path(args.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = work_dir / "latest.pth"

    full_model = MODELS.build(cfg.model).cuda()

    # Load VQVAE checkpoint if provided
    if args.vqvae_checkpoint:
        ck = torch.load(args.vqvae_checkpoint, map_location="cuda")
        full_model.vae.load_state_dict(ck["state_dict"])
        log.info("Loaded VQVAE checkpoint: %s", args.vqvae_checkpoint)
    else:
        log.warning("No --vqvae-checkpoint given; the frozen VQVAE keeps the weights "
                    "produced by building the config.")
    # The tokenizer is frozen: it only supplies transformer inputs and targets.
    full_model.vae.eval()
    for p in full_model.vae.parameters():
        p.requires_grad_(False)

    log.info("Transformer params: %.1fM",
             sum(p.numel() for p in full_model.transformer.parameters()) / 1e6)

    # Reuse the Stage 1 voxel grid when the CLI exposes it. Otherwise fall back
    # to the dataset's own defaults.
    grid_kwargs = {
        key: getattr(args, key)
        for key in ("voxel_m", "x_min", "y_min")
        if getattr(args, key, None) is not None
    }
    ds = RuViewOccDataset(
        args.snapshots,
        return_len=cfg.model.get("num_frames", 15) + 1,
        **grid_kwargs,
    )
    log.info("Dataset: %d windows from %s", len(ds), args.snapshots)
    if len(ds) == 0:
        log.error("No training windows found in %s — record snapshots first.", args.snapshots)
        sys.exit(1)

    loader = torch.utils.data.DataLoader(
        ds, batch_size=1, shuffle=not getattr(args, "no_shuffle", False), num_workers=0,
        collate_fn=lambda b: b[0],  # batch_size=1: pass the single sample dict through
    )

    opt = torch.optim.AdamW(full_model.transformer.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)

    last_epoch = args.epochs - 1
    for epoch in range(args.epochs):
        full_model.transformer.train()
        epoch_loss = 0.0
        for batch in loader:
            occ = torch.from_numpy(batch["target_occs"]).long().unsqueeze(0).cuda()
            bs = occ.shape[0]
            # Tokenize with the frozen VQVAE; no autograd graph is needed here.
            with torch.no_grad():
                z, _ = full_model.vae.forward_encoder(occ)
                z = full_model.vae.vqvae.quant_conv(z)
                z_q, _, (_, _, indices) = full_model.vae.vqvae.forward_quantizer(z, is_voxel=False)
                z_q = rearrange(z_q, "(b f) c h w -> b f c h w", b=bs)

            _, _, C, _, _ = z_q.shape
            # Zero pose conditioning — NOTE(review): presumably no ego-motion for this sensor; confirm.
            pose_tokens = torch.zeros(bs, full_model.num_frames, C, device=z_q.device)
            pred_tokens, _ = full_model.transformer(z_q[:, :full_model.num_frames], pose_tokens)
            # Targets are the codebook indices of frames from ``offset`` onward.
            indices_target = rearrange(indices, "(b f) h w -> b f h w", b=bs)[:, full_model.offset:]
            loss = torch.nn.functional.cross_entropy(
                pred_tokens.flatten(0, 1),
                indices_target.flatten(0, 1).flatten(1),
            )

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(full_model.transformer.parameters(), 1.0)
            opt.step()
            epoch_loss += loss.item()

        scheduler.step()
        avg = epoch_loss / max(len(loader), 1)
        # Checkpoint every 10th epoch and always after the last one, so the
        # weights from the final epochs are never discarded.
        if epoch % 10 == 0 or epoch == last_epoch:
            log.info("Epoch %d/%d loss=%.4f", epoch + 1, args.epochs, avg)
            torch.save({"epoch": epoch, "state_dict": full_model.state_dict(), "loss": avg},
                       ckpt_path)

    log.info("Transformer training complete. Checkpoint: %s/latest.pth", work_dir)
|
|
|
|
|
|
# ── CLI ──────────────────────────────────────────────────────────────────────
|
|
|
|
def _add_dataset_args(parser: argparse.ArgumentParser) -> None:
    """Add the dataset options shared by both training stages.

    The VQVAE and the transformer should see snapshots voxelized on the
    same grid. Both subcommands therefore expose identical options with
    identical defaults.
    """
    parser.add_argument("--voxel-m", type=float, dest="voxel_m", default=0.4,
                        help="RuViewOccDataset voxel_m (use the same value in both stages)")
    parser.add_argument("--x-min", type=float, dest="x_min", default=-40.0,
                        help="RuViewOccDataset x_min (use the same value in both stages)")
    parser.add_argument("--y-min", type=float, dest="y_min", default=-40.0,
                        help="RuViewOccDataset y_min (use the same value in both stages)")
    parser.add_argument("--no-shuffle", action="store_true",
                        help="Iterate training windows in dataset order")


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with ``record``, ``vqvae`` and ``transformer`` subcommands.

    ``--occworld-dir`` and ``--config`` live on the top-level parser, so they
    must precede the subcommand on the command line.
    """
    p = argparse.ArgumentParser(description="OccWorld retraining pipeline for RuView (ADR-147 Phase 5)")
    p.add_argument("--occworld-dir", default=os.path.expanduser("~/projects/OccWorld"),
                   help="Path to OccWorld repo root")
    p.add_argument("--config", default=os.path.expanduser("~/projects/OccWorld/config/occworld.py"),
                   help="OccWorld config file")

    sub = p.add_subparsers(dest="cmd", required=True)

    # record
    rec = sub.add_parser("record", help="Record WorldGraph snapshots from sensing server")
    rec.add_argument("--server", default="http://localhost:8080")
    rec.add_argument("--out-dir", required=True)
    rec.add_argument("--duration", type=int, default=3600, help="Recording duration (s)")
    rec.add_argument("--interval", type=float, default=0.5, help="Poll interval (s)")

    # vqvae
    vae = sub.add_parser("vqvae", help="Retrain VQVAE tokenizer")
    vae.add_argument("--snapshots", required=True)
    vae.add_argument("--work-dir", default="out/ruview_vqvae")
    vae.add_argument("--epochs", type=int, default=200)
    _add_dataset_args(vae)

    # transformer
    xfm = sub.add_parser("transformer", help="Retrain autoregressive transformer")
    xfm.add_argument("--snapshots", required=True)
    xfm.add_argument("--vqvae-checkpoint", default=None)
    xfm.add_argument("--work-dir", default="out/ruview_occworld")
    xfm.add_argument("--epochs", type=int, default=200)
    _add_dataset_args(xfm)

    return p
|
|
|
|
|
|
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )
    args = _build_parser().parse_args()
    # Map each subcommand name to its entry point. The subparsers are
    # required, so args.cmd is always one of these keys.
    handler = {
        "record": cmd_record,
        "vqvae": cmd_vqvae,
        "transformer": cmd_transformer,
    }[args.cmd]
    handler(args)
|