#!/usr/bin/env python3
"""Train the person-count head — ADR-103 v0.0.1.

Mirrors the Conv1d encoder architecture from cog-person-count's
`src/inference.rs::CountNet` exactly, so the learned weights load into
the Rust cog without translation.

Trains on data/paired/wiflow-p7-1779210883.paired.jsonl (1,077 samples
with n_persons_mode labels in {0, 1}).

Output: count_v1.safetensors + count_v1.onnx + train_results.json.
"""

from __future__ import annotations

import argparse
import json
import struct
import time
from collections import Counter
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Architecture constants — MUST match cog-person-count's src/inference.rs.
N_SUB = 56
N_FRAMES = 20
COUNT_CLASSES = 8


class CountNet(nn.Module):
    """Mirrors cog_person_count::inference::CountNet bit-for-bit.

    Three dilated Conv1d layers, a mean over the time axis, then two small
    MLP heads: one producing `COUNT_CLASSES` logits and one producing a
    single confidence logit.
    """

    def __init__(self) -> None:
        super().__init__()
        # Encoder — identical to the pose cog's encoder so future joint
        # training can share weights.  padding == dilation keeps the time
        # axis length unchanged for kernel_size=3.
        self.enc_c1 = nn.Conv1d(N_SUB, 64, kernel_size=3, padding=1, dilation=1)
        self.enc_c2 = nn.Conv1d(64, 128, kernel_size=3, padding=2, dilation=2)
        self.enc_c3 = nn.Conv1d(128, 128, kernel_size=3, padding=4, dilation=4)
        # Count head
        self.count_head_fc1 = nn.Linear(128, 64)
        self.count_head_fc2 = nn.Linear(64, COUNT_CLASSES)
        # Confidence head
        self.conf_head_fc1 = nn.Linear(128, 32)
        self.conf_head_fc2 = nn.Linear(32, 1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return `(count_logits [B, COUNT_CLASSES], conf_logits [B, 1])`.

        `x` is a `[B, 56, 20]` CSI window (subcarriers × frames).
        """
        h = F.relu(self.enc_c1(x))
        h = F.relu(self.enc_c2(h))
        h = F.relu(self.enc_c3(h))
        h = h.mean(dim=2)  # [B, 128]
        # Logits (un-normalised); softmax at inference + cross-entropy training.
        c = F.relu(self.count_head_fc1(h))
        count_logits = self.count_head_fc2(c)
        # Confidence head — sigmoid at inference; BCE-with-logits at training.
        cf = F.relu(self.conf_head_fc1(h))
        conf_logits = self.conf_head_fc2(cf)
        return count_logits, conf_logits


def load_paired(path: Path) -> tuple[np.ndarray, np.ndarray]:
    """Return (X, y) where X is [N, 56, 20] CSI and y is [N] integer counts.

    Blank lines are skipped, as are records whose `csi_shape` is present and
    differs from `[N_SUB, N_FRAMES]`.  A record without `n_persons_mode` is
    labelled 0.

    Raises:
        ValueError: if no usable record is found (previously this surfaced
            as an opaque `np.stack` error).
    """
    csis, ys = [], []
    with path.open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            d = json.loads(line)
            shape = d.get("csi_shape", [N_SUB, N_FRAMES])
            if shape != [N_SUB, N_FRAMES]:
                continue
            csi = np.asarray(d["csi"], dtype=np.float32).reshape(N_SUB, N_FRAMES)
            csis.append(csi)
            ys.append(int(d.get("n_persons_mode", 0)))
    if not csis:
        raise ValueError(
            f"{path}: no samples with csi_shape == [{N_SUB}, {N_FRAMES}] found"
        )
    X = np.stack(csis, axis=0)
    y = np.asarray(ys, dtype=np.int64)
    return X, y


def temporal_split(X: np.ndarray, y: np.ndarray, eval_frac: float = 0.2):
    """Held-out time-window eval (last `eval_frac` of samples, by index).

    Returns `(X_train, y_train, X_eval, y_eval)`.
    """
    n = X.shape[0]
    n_eval = int(round(n * eval_frac))
    n_train = n - n_eval
    return (
        X[:n_train],
        y[:n_train],
        X[n_train:],
        y[n_train:],
    )


def stratified_k_fold(X: np.ndarray, y: np.ndarray, k: int = 5):
    """Stratified k-fold cross-validation splits — hand-rolled, no sklearn.

    Per class: shuffle the indices (deterministic seed 42), split into k
    near-equal chunks, then assemble fold i by taking chunk i from every
    class. Yields (X_train, y_train, X_val, y_val) per fold, with class
    distribution preserved within ±1.

    Raises:
        ValueError: if `k < 2` (there would be no training chunk left).
            Being a generator, the check fires on the first iteration.
    """
    if k < 2:
        raise ValueError(f"stratified_k_fold needs k >= 2, got {k}")
    rng = np.random.default_rng(seed=42)
    classes = np.unique(y)
    per_class_folds = {}
    for c in classes:
        idx = np.where(y == c)[0]
        rng.shuffle(idx)
        per_class_folds[c] = np.array_split(idx, k)
    for fold in range(k):
        val_idx = np.concatenate([per_class_folds[c][fold] for c in classes])
        train_idx = np.concatenate(
            [per_class_folds[c][f] for c in classes for f in range(k) if f != fold]
        )
        yield X[train_idx], y[train_idx], X[val_idx], y[val_idx]


def standardise(X_train: np.ndarray, X_eval: np.ndarray):
    """Z-score by subcarrier across the time axis. Eval uses train stats."""
    mu = X_train.mean(axis=(0, 2), keepdims=True)
    sd = X_train.std(axis=(0, 2), keepdims=True) + 1e-6  # avoid divide-by-zero
    return (X_train - mu) / sd, (X_eval - mu) / sd


def write_safetensors(model: CountNet, path: Path):
    """Write the model's state in the same on-disk layout the Rust cog expects.

    Layout: 8-byte little-endian header length, UTF-8 JSON header mapping
    tensor name → dtype/shape/byte offsets, then the raw F32 payload.

    NOTE(review): the last lines of this function were lost in the source
    text (everything after `struct.pack("`).  The three writes below follow
    the public safetensors layout — confirm against the upstream script.
    """
    state = model.state_dict()
    # Map PyTorch param names → cog-person-count's VarBuilder paths.
    rename = {
        "enc_c1.weight": "enc.c1.weight",
        "enc_c1.bias": "enc.c1.bias",
        "enc_c2.weight": "enc.c2.weight",
        "enc_c2.bias": "enc.c2.bias",
        "enc_c3.weight": "enc.c3.weight",
        "enc_c3.bias": "enc.c3.bias",
        "count_head_fc1.weight": "count_head.fc1.weight",
        "count_head_fc1.bias": "count_head.fc1.bias",
        "count_head_fc2.weight": "count_head.fc2.weight",
        "count_head_fc2.bias": "count_head.fc2.bias",
        "conf_head_fc1.weight": "conf_head.fc1.weight",
        "conf_head_fc1.bias": "conf_head.fc1.bias",
        "conf_head_fc2.weight": "conf_head.fc2.weight",
        "conf_head_fc2.bias": "conf_head.fc2.bias",
    }
    header = {}
    payload = bytearray()
    offset = 0
    for torch_name, cog_name in rename.items():
        # "<f4" pins little-endian float32 regardless of host byte order.
        t = np.ascontiguousarray(
            state[torch_name].detach().cpu().numpy(), dtype="<f4"
        )
        n_bytes = t.nbytes
        header[cog_name] = {
            "dtype": "F32",
            "shape": list(t.shape),
            "data_offsets": [offset, offset + n_bytes],
        }
        payload.extend(t.tobytes())
        offset += n_bytes
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    with path.open("wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        f.write(payload)


def spearman_rank_correlation(a, b) -> float:
    """Spearman correlation (Pearson over average ranks) of two 1-D sequences.

    Ties receive the mean of the ranks they span, so the result does not
    depend on sample order.  This matters here because the "correct"
    indicator is binary and therefore almost entirely ties.

    Returns 0.0 for empty input or when either side is constant.

    Raises:
        ValueError: if the two inputs differ in length.
    """
    ra = _average_ranks(np.asarray(a, dtype=np.float64).ravel())
    rb = _average_ranks(np.asarray(b, dtype=np.float64).ravel())
    if ra.shape != rb.shape:
        raise ValueError("spearman_rank_correlation: inputs differ in length")
    if ra.size == 0:
        return 0.0
    ra = ra - ra.mean()
    rb = rb - rb.mean()
    denom = float(np.linalg.norm(ra) * np.linalg.norm(rb))
    return float((ra * rb).sum() / denom) if denom > 0 else 0.0


def _average_ranks(values: np.ndarray) -> np.ndarray:
    """Return 1-based ranks of `values`, tied entries sharing their mean rank."""
    _, inverse, counts = np.unique(values, return_inverse=True, return_counts=True)
    last_rank = np.cumsum(counts)  # rank of the last member of each tie group
    mean_rank = last_rank - (counts - 1) / 2.0
    return mean_rank[inverse]


def count_conf_loss(
    count_logits: torch.Tensor,
    conf_logits: torch.Tensor,
    targets: torch.Tensor,
    *,
    class_weight: torch.Tensor | None = None,
    label_smoothing: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return `(loss, pred)` for one batch.

    loss = cross_entropy(count) + 0.3 * bce(conf) + 0.1 * brier(conf), where
    the confidence target is the (detached) indicator `argmax == truth`.

    The Brier term is computed WITH gradients.  The original wrapped the
    sigmoid in `torch.no_grad()`, which made the term a constant that only
    inflated the logged loss and never trained the confidence head.
    """
    ce = F.cross_entropy(
        count_logits, targets, weight=class_weight, label_smoothing=label_smoothing
    )
    # Confidence head: train against `argmax == truth` indicator.
    with torch.no_grad():
        pred = count_logits.argmax(dim=1)
        correct_indicator = (pred == targets).float().unsqueeze(1)
    bce = F.binary_cross_entropy_with_logits(conf_logits, correct_indicator)
    # Brier-score calibration on the conf head.
    brier = ((torch.sigmoid(conf_logits) - correct_indicator) ** 2).mean()
    return ce + 0.3 * bce + 0.1 * brier, pred


def _class_counts(y_train: np.ndarray) -> np.ndarray:
    """Per-class sample counts (float32), with empty classes clamped to 1."""
    counts = np.bincount(y_train, minlength=COUNT_CLASSES).astype(np.float32)
    return np.where(counts > 0, counts, 1.0).astype(np.float32)


def _class_weights(y_train: np.ndarray) -> np.ndarray:
    """Inverse-frequency class weights, normalised to sum to COUNT_CLASSES."""
    inv = 1.0 / _class_counts(y_train)
    return (inv / inv.sum() * COUNT_CLASSES).astype(np.float32)


def _make_optimiser(model: CountNet, args: argparse.Namespace):
    """Build the AdamW optimiser and cosine-warm-restart schedule used by all modes."""
    opt = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        opt, T_0=50, T_mult=1
    )
    return opt, sched


def _snapshot(model: CountNet) -> dict:
    """Copy the model's state dict to CPU so later updates cannot alias it."""
    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}


def _evaluate(
    model: CountNet, X: torch.Tensor, y: torch.Tensor, temperature: float = 1.0
) -> dict:
    """Compute accuracy, within-±1, MAE, conf↔correct Spearman and per-class accuracy.

    `temperature` divides the confidence logits before the sigmoid.
    """
    model.eval()
    with torch.no_grad():
        count_logits, conf_logits = model(X)
    pred = count_logits.argmax(dim=1)
    hit = pred == y
    abs_err = (pred - y).abs()

    per_class = {}
    for k in range(COUNT_CLASSES):
        mask = y == k
        n = int(mask.sum().item())
        if n > 0:
            per_class[k] = {
                "support": n,
                "accuracy": (hit & mask).sum().item() / n,
            }

    # Confidence-accuracy calibration: Spearman over (confidence, correct).
    conf = torch.sigmoid(conf_logits.squeeze(-1) / temperature)
    return {
        "accuracy": hit.float().mean().item(),
        "within_pm1": (abs_err <= 1).float().mean().item(),
        "mae": abs_err.float().mean().item(),
        "spearman": spearman_rank_correlation(
            conf.cpu().numpy(), hit.float().cpu().numpy()
        ),
        "per_class_accuracy": per_class,
    }


def _fit_temperature(conf_logits: torch.Tensor, correct: torch.Tensor) -> float:
    """Fit a scalar T minimising BCE(sigmoid(conf_logits / T), correct) via LBFGS.

    T is parametrised as exp(log T) so it stays strictly positive; an
    unconstrained T could cross zero and flip or blow up the confidences.
    """
    logits = conf_logits.detach().squeeze(-1)
    log_t = torch.nn.Parameter(torch.zeros(1, device=logits.device))
    opt_t = torch.optim.LBFGS([log_t], lr=0.1, max_iter=50)

    def closure():
        opt_t.zero_grad()
        loss_t = F.binary_cross_entropy_with_logits(logits / log_t.exp(), correct)
        loss_t.backward()
        return loss_t

    opt_t.step(closure)
    return float(log_t.detach().exp().cpu().item())


def _export_onnx(model: CountNet, out_onnx: str, device: torch.device) -> None:
    """Best-effort ONNX export; a failure is reported as a warning, not raised."""
    dummy = torch.zeros(1, N_SUB, N_FRAMES, device=device)
    try:
        torch.onnx.export(
            model,
            dummy,
            out_onnx,
            opset_version=18,
            input_names=["csi_window"],
            output_names=["count_logits", "conf_logits"],
            dynamic_axes={
                "csi_window": {0: "batch"},
                "count_logits": {0: "batch"},
                "conf_logits": {0: "batch"},
            },
            export_params=True,
            do_constant_folding=True,
        )
        print(f"wrote {out_onnx} ({Path(out_onnx).stat().st_size} bytes)")
    except Exception as e:  # deliberate: ONNX is optional, safetensors is the product
        print(f"WARN: ONNX export failed: {e}")


def _backend_name(device: torch.device) -> str:
    """Backend label recorded in the results JSON."""
    return "pytorch-cuda" if device.type == "cuda" else "pytorch-cpu"


def _parse_args() -> argparse.Namespace:
    """Parse the command line.

    NOTE(review): the original parser definition was lost in the source
    text.  Attribute names below are recovered from the `args.*` usages that
    survived; flag spellings, the `--paired` name and every numeric default
    are placeholders — confirm against the upstream script before relying
    on them.
    """
    p = argparse.ArgumentParser(description="Train the person-count head (ADR-103).")
    p.add_argument(
        "--paired", default="data/paired/wiflow-p7-1779210883.paired.jsonl"
    )
    p.add_argument("--out-safetensors", default="count_v1.safetensors")
    p.add_argument("--out-onnx", default="count_v1.onnx")
    p.add_argument("--out-results", default="train_results.json")
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch-size", type=int, default=32)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--weight-decay", type=float, default=1e-4)
    p.add_argument("--k-fold", type=int, default=0)
    p.add_argument("--v2", action="store_true")
    p.add_argument("--label-smoothing", type=float, default=0.1)
    p.add_argument("--patience", type=int, default=30)
    return p.parse_args()


def _run_k_fold(
    args: argparse.Namespace, X: np.ndarray, y: np.ndarray, device: torch.device
) -> None:
    """Stratified k-fold CV: train one model per fold and write a summary JSON."""
    overall_t0 = time.perf_counter()
    fold_results = []
    for fold_idx, (X_train, y_train, X_val, y_val) in enumerate(
        stratified_k_fold(X, y, k=args.k_fold)
    ):
        X_train, X_val = standardise(X_train, X_val)
        # NOTE(review): the original per-fold header line was lost; this one
        # is a stand-in.
        print(f"fold {fold_idx + 1}/{args.k_fold}: train={len(y_train)} val={len(y_val)}")

        cls_weight_t = torch.from_numpy(_class_weights(y_train)).to(device)
        Xt = torch.from_numpy(X_train).to(device)
        yt = torch.from_numpy(y_train).to(device)
        Xv = torch.from_numpy(X_val).to(device)
        yv = torch.from_numpy(y_val).to(device)

        model = CountNet().to(device)
        opt, sched = _make_optimiser(model, args)
        n_train = X_train.shape[0]
        best_eval_acc = 0.0
        best_state = None
        for _epoch in range(args.epochs):
            model.train()
            perm = torch.randperm(n_train, device=device)
            for i in range(0, n_train, args.batch_size):
                idx = perm[i : i + args.batch_size]
                opt.zero_grad()
                count_logits, conf_logits = model(Xt[idx])
                loss, _ = count_conf_loss(
                    count_logits, conf_logits, yt[idx], class_weight=cls_weight_t
                )
                loss.backward()
                opt.step()
            sched.step()

            model.eval()
            with torch.no_grad():
                cl_v, _ = model(Xv)
            eval_acc = (cl_v.argmax(dim=1) == yv).float().mean().item()
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                best_state = _snapshot(model)

        # Restore best checkpoint and final eval
        if best_state is not None:
            model.load_state_dict(best_state)
        metrics = _evaluate(model, Xv, yv)
        fold_results.append({"fold": fold_idx + 1, **metrics})
        print(
            f" accuracy={metrics['accuracy']:.3f} within±1={metrics['within_pm1']:.3f} "
            f"mae={metrics['mae']:.3f} spearman={metrics['spearman']:.3f}"
        )

    # K-fold summary
    total_time = time.perf_counter() - overall_t0
    accs = [r["accuracy"] for r in fold_results]
    within1s = [r["within_pm1"] for r in fold_results]
    maes = [r["mae"] for r in fold_results]
    spears = [r["spearman"] for r in fold_results]
    print(f"\n=== {args.k_fold}-fold summary ({total_time:.1f} s) ===")
    print(f" accuracy: {np.mean(accs):.3f} ± {np.std(accs):.3f}")
    print(f" within ±1: {np.mean(within1s):.3f} ± {np.std(within1s):.3f}")
    print(f" MAE: {np.mean(maes):.3f} ± {np.std(maes):.3f}")
    print(f" conf↔correct Spearman: {np.mean(spears):.3f} ± {np.std(spears):.3f}")

    # Per-class summary across folds.  Only folds that actually contain the
    # class contribute to the mean; counting absent folds as 0.0 accuracy
    # would bias the figure downwards.
    for k in range(COUNT_CLASSES):
        accs_k = [
            r["per_class_accuracy"][k]["accuracy"]
            for r in fold_results
            if k in r["per_class_accuracy"]
        ]
        n_k = [r["per_class_accuracy"].get(k, {}).get("support", 0) for r in fold_results]
        if accs_k:
            print(f" class {k}: {np.mean(accs_k):.3f} mean accuracy (support: {n_k})")

    # Write k-fold results to JSON
    results = {
        "mode": "k_fold_cv",
        "k": args.k_fold,
        "backend": _backend_name(device),
        "total_time_s": total_time,
        "fold_results": fold_results,
        "summary": {
            "mean_accuracy": float(np.mean(accs)),
            "std_accuracy": float(np.std(accs)),
            "mean_within_pm1": float(np.mean(within1s)),
            "std_within_pm1": float(np.std(within1s)),
            "mean_mae": float(np.mean(maes)),
            "std_mae": float(np.std(maes)),
            "mean_spearman": float(np.mean(spears)),
            "std_spearman": float(np.std(spears)),
        },
        "hyperparameters": {
            "optimizer": "AdamW",
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "batch_size": args.batch_size,
            "schedule": "cosine_warm_restarts",
            "epochs": args.epochs,
        },
    }
    Path(args.out_results).write_text(json.dumps(results, indent=2))
    print(f"\nwrote {args.out_results}")


def _run_v2(
    args: argparse.Namespace, X: np.ndarray, y: np.ndarray, device: torch.device
) -> None:
    """v0.0.2 training path.

    Random 80/20 split + label smoothing + early stopping + class-balanced
    batch sampling + temperature scaling.

    NOTE(review): checkpoint selection, early stopping and the temperature
    fit all use the same eval split that the final metrics are reported on,
    so those metrics are optimistic.
    """
    rng = np.random.default_rng(seed=42)
    idx = np.arange(X.shape[0])
    rng.shuffle(idx)
    n_eval = int(round(0.2 * X.shape[0]))
    eval_idx, train_idx = idx[:n_eval], idx[n_eval:]
    X_train, X_eval = X[train_idx], X[eval_idx]
    y_train, y_eval = y[train_idx], y[eval_idx]
    X_train, X_eval = standardise(X_train, X_eval)
    print(f"v0.0.2 mode — random 80/20 split: train={len(y_train)} eval={len(y_eval)}")
    print(f" train class dist: {dict(Counter(y_train.tolist()).most_common())}")
    print(f" eval class dist: {dict(Counter(y_eval.tolist()).most_common())}")

    Xt = torch.from_numpy(X_train).to(device)
    yt = torch.from_numpy(y_train).to(device)
    Xe = torch.from_numpy(X_eval).to(device)
    ye = torch.from_numpy(y_eval).to(device)

    # Class-balanced sampler: for each batch, sample with replacement
    # so each class has equal expected count regardless of dataset
    # distribution. With our ~533/544 split this is nearly a no-op
    # but it generalises to imbalanced multi-room data later.
    per_sample_weight = 1.0 / _class_counts(y_train)[y_train]
    per_sample_weight_t = torch.from_numpy(
        per_sample_weight.astype(np.float32)
    ).to(device)

    model = CountNet().to(device)
    opt, sched = _make_optimiser(model, args)
    n_train = X_train.shape[0]
    batches_per_epoch = max(1, n_train // args.batch_size)
    epoch_losses = []
    t0 = time.perf_counter()
    best_eval_acc = 0.0
    best_state = None
    epochs_without_improvement = 0
    epochs_run = 0  # tracked explicitly so `--epochs 0` cannot hit a NameError
    for epoch in range(args.epochs):
        epochs_run = epoch + 1
        model.train()
        train_loss = 0.0
        train_correct = 0
        n_batches = 0
        for _ in range(batches_per_epoch):
            # Balanced sample with replacement
            idx_t = torch.multinomial(
                per_sample_weight_t, args.batch_size, replacement=True
            )
            yb = yt[idx_t]
            opt.zero_grad()
            count_logits, conf_logits = model(Xt[idx_t])
            loss, pred = count_conf_loss(
                count_logits, conf_logits, yb, label_smoothing=args.label_smoothing
            )
            loss.backward()
            opt.step()
            train_loss += loss.item()
            train_correct += (pred == yb).sum().item()
            n_batches += 1
        sched.step()

        model.eval()
        with torch.no_grad():
            cl_e, _ = model(Xe)
            eval_loss = F.cross_entropy(cl_e, ye).item()
            eval_acc = (cl_e.argmax(dim=1) == ye).float().mean().item()
        mean_loss = train_loss / max(1, n_batches)
        train_acc = train_correct / max(1, n_batches * args.batch_size)
        epoch_losses.append({
            "epoch": epoch,
            "train_loss": mean_loss,
            "train_acc": train_acc,
            "eval_loss": eval_loss,
            "eval_acc": eval_acc,
        })
        if eval_acc > best_eval_acc:
            best_eval_acc = eval_acc
            best_state = _snapshot(model)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epoch < 5 or epoch % 25 == 0:
            print(f"epoch {epoch:3d} train_loss={mean_loss:.4f} "
                  f"train_acc={train_acc:.3f} "
                  f"eval_loss={eval_loss:.4f} eval_acc={eval_acc:.3f} "
                  f"epochs_no_improve={epochs_without_improvement}")
        if epochs_without_improvement >= args.patience:
            print(f"early stopping at epoch {epoch} (no improvement for {args.patience} epochs)")
            break
    train_time = time.perf_counter() - t0
    print(f"\ntrained {epochs_run} epochs in {train_time:.1f} s (best eval_acc {best_eval_acc:.3f})")
    if best_state is not None:
        model.load_state_dict(best_state)

    # Temperature scaling on the confidence head — fit a scalar T s.t.
    # sigmoid(conf_logits / T) is best-calibrated on the eval set.
    model.eval()
    with torch.no_grad():
        cl_e, conf_e = model(Xe)
        correct_indicator = (cl_e.argmax(dim=1) == ye).float()
    T_val = _fit_temperature(conf_e, correct_indicator)
    print(f" temperature scale T = {T_val:.4f}")

    # Final eval with temperature applied.
    metrics = _evaluate(model, Xe, ye, temperature=T_val)
    acc = metrics["accuracy"]
    within1 = metrics["within_pm1"]
    mae = metrics["mae"]
    spearman = metrics["spearman"]
    per_class = metrics["per_class_accuracy"]
    print("\n=== v0.0.2 final eval ===")
    print(f" accuracy: {acc:.3f}")
    print(f" within ±1: {within1:.3f}")
    print(f" MAE: {mae:.3f}")
    print(f" conf↔correct Spearman (post-temp): {spearman:.3f}")
    for k, v in per_class.items():
        print(f" class {k}: {v['accuracy']:.3f} accuracy on {v['support']} samples")

    write_safetensors(model, Path(args.out_safetensors))
    # The temperature scalar is NOT stored inside the safetensors file; it is
    # written to a sidecar named `<out_safetensors>.temperature` for the Rust
    # cog inference path to read.
    Path(args.out_safetensors + ".temperature").write_text(f"{T_val}\n")
    print(f"wrote {args.out_safetensors} ({Path(args.out_safetensors).stat().st_size} bytes)")
    print(f"wrote {args.out_safetensors}.temperature ({T_val})")

    # ONNX
    _export_onnx(model, args.out_onnx, device)

    results = {
        "mode": "v0.0.2",
        "backend": _backend_name(device),
        "epochs_trained": epochs_run,
        "train_time_s": train_time,
        "best_eval_acc": best_eval_acc,
        "final_eval_acc": acc,
        "final_eval_within_pm1": within1,
        "final_eval_mae": mae,
        "temperature_scale": T_val,
        "conf_correctness_spearman_post_temp": spearman,
        "per_class_accuracy": per_class,
        "hyperparameters": {
            "optimizer": "AdamW",
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "batch_size": args.batch_size,
            "schedule": "cosine_warm_restarts",
            "epochs_max": args.epochs,
            "label_smoothing": args.label_smoothing,
            "patience": args.patience,
            "split": "random_80_20_seed_42",
            "balanced_sampler": True,
            "temperature_scaling": True,
        },
        "epoch_losses": epoch_losses,
    }
    Path(args.out_results).write_text(json.dumps(results, indent=2))
    print(f"wrote {args.out_results}")


def _run_temporal(
    args: argparse.Namespace, X: np.ndarray, y: np.ndarray, device: torch.device
) -> None:
    """Original temporal-split mode (the v0.0.1 path).

    NOTE(review): because the Brier term now carries gradients (see
    `count_conf_loss`), weights produced here will differ from runs of the
    original script.
    """
    X_train, y_train, X_eval, y_eval = temporal_split(X, y, eval_frac=0.2)
    X_train, X_eval = standardise(X_train, X_eval)

    # Re-balance via class weights — handles the 50/50 split fine
    # but also makes the loss correct under future imbalanced data.
    cls_weight = _class_weights(y_train)
    cls_weight_t = torch.from_numpy(cls_weight).to(device)
    print(f"class weights: {cls_weight.tolist()}")

    Xt = torch.from_numpy(X_train).to(device)
    yt = torch.from_numpy(y_train).to(device)
    Xe = torch.from_numpy(X_eval).to(device)
    ye = torch.from_numpy(y_eval).to(device)

    model = CountNet().to(device)
    opt, sched = _make_optimiser(model, args)
    n_train = X_train.shape[0]
    epoch_losses = []
    t0 = time.perf_counter()
    best_eval_acc = 0.0
    best_state = None
    for epoch in range(args.epochs):
        model.train()
        perm = torch.randperm(n_train, device=device)
        train_loss = 0.0
        train_correct = 0
        n_batches = 0
        for i in range(0, n_train, args.batch_size):
            idx = perm[i : i + args.batch_size]
            yb = yt[idx]
            opt.zero_grad()
            count_logits, conf_logits = model(Xt[idx])
            loss, pred = count_conf_loss(
                count_logits, conf_logits, yb, class_weight=cls_weight_t
            )
            loss.backward()
            opt.step()
            train_loss += loss.item()
            train_correct += (pred == yb).sum().item()
            n_batches += 1
        sched.step()

        model.eval()
        with torch.no_grad():
            cl_e, _ = model(Xe)
            eval_loss = F.cross_entropy(cl_e, ye, weight=cls_weight_t).item()
            eval_pred = cl_e.argmax(dim=1)
            eval_acc = (eval_pred == ye).float().mean().item()
            eval_within1 = ((eval_pred - ye).abs() <= 1).float().mean().item()
        epoch_losses.append({
            "epoch": epoch,
            "train_loss": train_loss / n_batches,
            "train_acc": train_correct / n_train,
            "eval_loss": eval_loss,
            "eval_acc": eval_acc,
            "eval_within_pm1": eval_within1,
        })
        if eval_acc > best_eval_acc:
            best_eval_acc = eval_acc
            best_state = _snapshot(model)
        if epoch < 5 or epoch % 50 == 0 or epoch == args.epochs - 1:
            print(f"epoch {epoch:3d} train_loss={train_loss/n_batches:.4f} "
                  f"train_acc={train_correct/n_train:.3f} "
                  f"eval_loss={eval_loss:.4f} eval_acc={eval_acc:.3f} "
                  f"within±1={eval_within1:.3f}")
    train_time = time.perf_counter() - t0
    print(f"\ntrained {args.epochs} epochs in {train_time:.1f} s")
    print(f"best eval_acc: {best_eval_acc:.3f}")

    # Restore best checkpoint
    if best_state is not None:
        model.load_state_dict(best_state)

    # Eval breakdown
    metrics = _evaluate(model, Xe, ye)
    acc = metrics["accuracy"]
    within1 = metrics["within_pm1"]
    mae = metrics["mae"]
    spearman = metrics["spearman"]
    per_class = metrics["per_class_accuracy"]
    print("\n=== final eval ===")
    print(f" accuracy: {acc:.3f}")
    print(f" within ±1: {within1:.3f}")
    print(f" MAE: {mae:.3f}")
    print(f" conf↔correct Spearman: {spearman:.3f}")
    for k, v in per_class.items():
        print(f" class {k}: {v['accuracy']:.3f} accuracy on {v['support']} samples")

    # Save safetensors
    write_safetensors(model, Path(args.out_safetensors))
    print(f"\nwrote {args.out_safetensors} ({Path(args.out_safetensors).stat().st_size} bytes)")

    # ONNX export
    _export_onnx(model, args.out_onnx, device)

    # Results JSON.  The backend label was "candle-cuda"/"candle-cpu" here,
    # which contradicted the other two modes; this script trains with PyTorch.
    results = {
        "backend": _backend_name(device),
        "device": str(device),
        "epochs": args.epochs,
        "train_time_s": train_time,
        "best_eval_acc": best_eval_acc,
        "final_eval_acc": acc,
        "final_eval_within_pm1": within1,
        "final_eval_mae": mae,
        "conf_correctness_spearman": spearman,
        "per_class_accuracy": per_class,
        "hyperparameters": {
            "optimizer": "AdamW",
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "batch_size": args.batch_size,
            "schedule": "cosine_warm_restarts",
            "epochs": args.epochs,
            "loss": "cross_entropy(count) + 0.3*bce(conf) + 0.1*brier(conf)",
            "z_score_normalisation": True,
            "class_weights": cls_weight.tolist(),
        },
        "epoch_losses": epoch_losses,
    }
    Path(args.out_results).write_text(json.dumps(results, indent=2))
    print(f"wrote {args.out_results} ({Path(args.out_results).stat().st_size} bytes)")


def main() -> None:
    """Entry point: load the paired dataset and dispatch to one training mode.

    Precedence is k-fold CV, then `--v2`, then the original temporal split.

    NOTE(review): device selection and the `load_paired` call were lost in
    the source text and are reconstructed here — confirm against upstream.
    """
    args = _parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X, y = load_paired(Path(args.paired))

    if args.k_fold:
        _run_k_fold(args, X, y, device)
    elif args.v2:
        _run_v2(args, X, y, device)
    else:
        _run_temporal(args, X, y, device)


if __name__ == "__main__":
    main()