Multi-model dashboard with GQA, W&B integration, and best-loss checkpointing

Dashboard: multi-model support (Stories110M + Qwen3-0.6B) with GQA-aware
text generation and KV cache. Weights & Biases logging (--wandb flag) for
loss, timing, power, and checkpoint events. Top-k=50 sampling to eliminate
garbage tokens from untrained vocab entries. Tokenizer reads any vocab size.

train.m: only save checkpoint when loss improves (best_loss tracking).
This commit is contained in:
maderix 2026-03-07 02:56:27 -08:00
parent 475348ad14
commit 7d61ee4d25
2 changed files with 231 additions and 57 deletions

View File

@ -18,16 +18,48 @@ try:
except ImportError: except ImportError:
HAS_PSUTIL = False HAS_PSUTIL = False
DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 256, 32000, 12 try:
HD = DIM // HEADS import wandb
CKPT_PATH_STATIC = 'ane_stories110M_ckpt.bin' HAS_WANDB = True
CKPT_PATH_DYNAMIC = 'training_dynamic/ane_stories110M_dyn_ckpt.bin' except ImportError:
CKPT_PATH = CKPT_PATH_STATIC # set in main() based on --dynamic HAS_WANDB = False
# Model configs — set at startup based on --model flag
MODEL_CONFIGS = {
'stories110m': {
'dim': 768, 'hidden': 2048, 'heads': 12, 'kv_heads': 12,
'hd': 64, 'seq': 256, 'vocab': 32000, 'nlayers': 12,
'ckpt_static': 'ane_stories110M_ckpt.bin',
'ckpt_dynamic': 'training_dynamic/ane_stories110M_dyn_ckpt.bin',
},
'qwen3_06b': {
'dim': 1024, 'hidden': 3072, 'heads': 16, 'kv_heads': 8,
'hd': 128, 'seq': 256, 'vocab': 151936, 'nlayers': 28,
'ckpt_static': None,
'ckpt_dynamic': 'training_dynamic/ane_qwen3_06b_dyn_ckpt.bin',
},
}
# Active model dims — set in main()
DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 12, 64, 256, 32000, 12
Q_DIM, KV_DIM, GQA_RATIO = DIM, DIM, 1
CKPT_PATH = 'ane_stories110M_ckpt.bin'
TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent / 'assets' / 'models' / 'tokenizer.bin') TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent / 'assets' / 'models' / 'tokenizer.bin')
def set_model_config(name):
global DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS
global Q_DIM, KV_DIM, GQA_RATIO
cfg = MODEL_CONFIGS[name]
DIM, HIDDEN, HEADS, KV_HEADS = cfg['dim'], cfg['hidden'], cfg['heads'], cfg['kv_heads']
HD, SEQ, VOCAB, NLAYERS = cfg['hd'], cfg['seq'], cfg['vocab'], cfg['nlayers']
Q_DIM = HEADS * HD
KV_DIM = KV_HEADS * HD
GQA_RATIO = HEADS // KV_HEADS
class State: class State:
def __init__(self): def __init__(self):
self.active_model = 'stories110m'
self.model_config = {} self.model_config = {}
self.params = {} self.params = {}
self.kernels = {} self.kernels = {}
@ -62,6 +94,7 @@ class State:
self.train_start = None # wall clock when first step seen self.train_start = None # wall clock when first step seen
self.compile_ms = 0.0 # total compile time self.compile_ms = 0.0 # total compile time
S = State() S = State()
@ -71,8 +104,12 @@ class Tokenizer:
self.scores = [] self.scores = []
with open(path, 'rb') as f: with open(path, 'rb') as f:
max_len = struct.unpack('i', f.read(4))[0] max_len = struct.unpack('i', f.read(4))[0]
for _ in range(VOCAB): # Read until EOF — works for any vocab size
score = struct.unpack('f', f.read(4))[0] while True:
data = f.read(4)
if len(data) < 4:
break
score = struct.unpack('f', data)[0]
slen = struct.unpack('i', f.read(4))[0] slen = struct.unpack('i', f.read(4))[0]
tok = f.read(slen).decode('utf-8', errors='replace') tok = f.read(slen).decode('utf-8', errors='replace')
self.vocab.append(tok) self.vocab.append(tok)
@ -104,33 +141,32 @@ def get_tokenizer():
def load_weights_from_ckpt(path): def load_weights_from_ckpt(path):
try: try:
with open(path, 'rb') as f: with open(path, 'rb') as f:
# CkptHdr: 96 bytes (verified with sizeof)
hdr = f.read(96) hdr = f.read(96)
if len(hdr) < 96: if len(hdr) < 96:
return None return None
wq_sz = DIM * DIM wq_sz = Q_DIM * DIM
wo_sz = DIM * DIM wk_sz = KV_DIM * DIM
wv_sz = KV_DIM * DIM
wo_sz = DIM * Q_DIM
w1_sz = HIDDEN * DIM w1_sz = HIDDEN * DIM
w2_sz = DIM * HIDDEN w2_sz = DIM * HIDDEN
w3_sz = HIDDEN * DIM w3_sz = HIDDEN * DIM
# Per-layer: weights + adam state (m,v for each) adam_per_layer = (wq_sz*2 + wk_sz*2 + wv_sz*2 + wo_sz*2 +
adam_per_layer = (wq_sz*2 + wq_sz*2 + wq_sz*2 + wo_sz*2 +
w1_sz*2 + w2_sz*2 + w3_sz*2 + DIM*2 + DIM*2) w1_sz*2 + w2_sz*2 + w3_sz*2 + DIM*2 + DIM*2)
W = {} W = {}
for L in range(NLAYERS): for L in range(NLAYERS):
W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(Q_DIM, DIM).copy()
W[f'Wk{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() W[f'Wk{L}'] = np.frombuffer(f.read(wk_sz * 4), dtype=np.float32).reshape(KV_DIM, DIM).copy()
W[f'Wv{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() W[f'Wv{L}'] = np.frombuffer(f.read(wv_sz * 4), dtype=np.float32).reshape(KV_DIM, DIM).copy()
W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, Q_DIM).copy()
W[f'W1_{L}'] = np.frombuffer(f.read(w1_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy() W[f'W1_{L}'] = np.frombuffer(f.read(w1_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
W[f'W2_{L}'] = np.frombuffer(f.read(w2_sz * 4), dtype=np.float32).reshape(DIM, HIDDEN).copy() W[f'W2_{L}'] = np.frombuffer(f.read(w2_sz * 4), dtype=np.float32).reshape(DIM, HIDDEN).copy()
W[f'W3_{L}'] = np.frombuffer(f.read(w3_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy() W[f'W3_{L}'] = np.frombuffer(f.read(w3_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
W[f'rms1_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy() W[f'rms1_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
W[f'rms2_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy() W[f'rms2_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
# Skip adam state for this layer
f.seek(adam_per_layer * 4, 1) f.seek(adam_per_layer * 4, 1)
W['rms_final'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy() W['rms_final'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
f.seek(DIM * 2 * 4, 1) # skip rms_final adam f.seek(DIM * 2 * 4, 1)
W['embed'] = np.frombuffer(f.read(VOCAB * DIM * 4), dtype=np.float32).reshape(VOCAB, DIM).copy() W['embed'] = np.frombuffer(f.read(VOCAB * DIM * 4), dtype=np.float32).reshape(VOCAB, DIM).copy()
return W return W
except Exception as e: except Exception as e:
@ -151,20 +187,21 @@ def generate_text(W, max_tokens=64, temperature=0.8):
tokenizer = get_tokenizer() tokenizer = get_tokenizer()
if tokenizer is None: if tokenizer is None:
return '[no tokenizer]' return '[no tokenizer]'
if len(tokenizer.vocab) < VOCAB:
return f'[tokenizer has {len(tokenizer.vocab)} tokens, model needs {VOCAB}]'
tokens = [1] tokens = [1]
text_parts = [] text_parts = []
# Precompute RoPE frequencies
freqs = np.zeros((SEQ, HD // 2), dtype=np.float32) freqs = np.zeros((SEQ, HD // 2), dtype=np.float32)
for pos in range(SEQ): for pos in range(SEQ):
for i in range(HD // 2): for i in range(HD // 2):
freq = 1.0 / (10000.0 ** (2.0 * i / HD)) freq = 1.0 / (10000.0 ** (2.0 * i / HD))
freqs[pos, i] = pos * freq freqs[pos, i] = pos * freq
# KV cache: per-layer, per-head arrays # KV cache: per-layer, per KV head
k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)] k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(KV_HEADS)] for _ in range(NLAYERS)]
v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)] v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(KV_HEADS)] for _ in range(NLAYERS)]
res_alpha = 1.0 / math.sqrt(2.0 * NLAYERS) res_alpha = 1.0 / math.sqrt(2.0 * NLAYERS)
@ -177,13 +214,12 @@ def generate_text(W, max_tokens=64, temperature=0.8):
pos = seq_len - 1 pos = seq_len - 1
for L in range(NLAYERS): for L in range(NLAYERS):
# RMSNorm + QKV
xn = rmsnorm(x, W[f'rms1_{L}']) xn = rmsnorm(x, W[f'rms1_{L}'])
q = W[f'Wq{L}'] @ xn q = W[f'Wq{L}'] @ xn # [Q_DIM]
k = W[f'Wk{L}'] @ xn k = W[f'Wk{L}'] @ xn # [KV_DIM]
v = W[f'Wv{L}'] @ xn v = W[f'Wv{L}'] @ xn # [KV_DIM]
# RoPE # RoPE on Q (HEADS heads) and K (KV_HEADS heads)
for h in range(HEADS): for h in range(HEADS):
for i in range(HD // 2): for i in range(HD // 2):
freq = freqs[pos, i] freq = freqs[pos, i]
@ -191,31 +227,37 @@ def generate_text(W, max_tokens=64, temperature=0.8):
qi, qi1 = q[h * HD + 2 * i], q[h * HD + 2 * i + 1] qi, qi1 = q[h * HD + 2 * i], q[h * HD + 2 * i + 1]
q[h * HD + 2 * i] = qi * cos_v - qi1 * sin_v q[h * HD + 2 * i] = qi * cos_v - qi1 * sin_v
q[h * HD + 2 * i + 1] = qi * sin_v + qi1 * cos_v q[h * HD + 2 * i + 1] = qi * sin_v + qi1 * cos_v
for h in range(KV_HEADS):
for i in range(HD // 2):
freq = freqs[pos, i]
cos_v, sin_v = math.cos(freq), math.sin(freq)
ki, ki1 = k[h * HD + 2 * i], k[h * HD + 2 * i + 1] ki, ki1 = k[h * HD + 2 * i], k[h * HD + 2 * i + 1]
k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v
k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v
# Append to KV cache and compute attention # Append to KV cache (KV_HEADS entries)
o = np.zeros(DIM, dtype=np.float32) for kv in range(KV_HEADS):
for h in range(HEADS): kh = k[kv * HD:(kv + 1) * HD].reshape(1, HD)
qh = q[h * HD:(h + 1) * HD] vh = v[kv * HD:(kv + 1) * HD].reshape(1, HD)
kh = k[h * HD:(h + 1) * HD].reshape(1, HD) k_cache[L][kv] = np.vstack([k_cache[L][kv], kh])
vh = v[h * HD:(h + 1) * HD].reshape(1, HD) v_cache[L][kv] = np.vstack([v_cache[L][kv], vh])
k_cache[L][h] = np.vstack([k_cache[L][h], kh])
v_cache[L][h] = np.vstack([v_cache[L][h], vh])
# scores: (1, HD) @ (HD, seq_len) -> (seq_len,)
scores = k_cache[L][h] @ qh / math.sqrt(HD)
attn = softmax(scores)
o[h * HD:(h + 1) * HD] = attn @ v_cache[L][h]
# Residual + output projection (scaled residual, matches training) # GQA attention: each Q head uses its corresponding KV head
o = np.zeros(Q_DIM, dtype=np.float32)
for h in range(HEADS):
kv = h // GQA_RATIO
qh = q[h * HD:(h + 1) * HD]
scores = k_cache[L][kv] @ qh / math.sqrt(HD)
attn = softmax(scores)
o[h * HD:(h + 1) * HD] = attn @ v_cache[L][kv]
# Residual + output projection
x2 = x + res_alpha * (W[f'Wo{L}'] @ o) x2 = x + res_alpha * (W[f'Wo{L}'] @ o)
# FFN # FFN
x2n = rmsnorm(x2, W[f'rms2_{L}']) x2n = rmsnorm(x2, W[f'rms2_{L}'])
h1 = W[f'W1_{L}'] @ x2n h1 = W[f'W1_{L}'] @ x2n
h3 = W[f'W3_{L}'] @ x2n h3 = W[f'W3_{L}'] @ x2n
# SiLU
h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3 h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3
ffn_out = W[f'W2_{L}'] @ h1 ffn_out = W[f'W2_{L}'] @ h1
@ -230,8 +272,11 @@ def generate_text(W, max_tokens=64, temperature=0.8):
next_tok = int(np.argmax(logits)) next_tok = int(np.argmax(logits))
else: else:
logits = logits / temperature logits = logits / temperature
probs = softmax(logits) top_k = 50
next_tok = int(np.random.choice(VOCAB, p=probs)) top_idx = np.argpartition(logits, -top_k)[-top_k:]
top_logits = logits[top_idx]
probs = softmax(top_logits)
next_tok = int(top_idx[np.random.choice(len(top_idx), p=probs)])
if next_tok == 2: if next_tok == 2:
break break
@ -291,6 +336,8 @@ def sysmetrics_thread():
RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)') RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)')
RE_CONFIG_GQA = re.compile(r'dim=(\d+) q_dim=(\d+) kv_dim=(\d+) hd=(\d+) hidden=(\d+) seq=(\d+) vocab=(\d+)')
RE_MODEL_NAME = re.compile(r'ANE Dynamic Training: (.+?) \((\d+) layers')
RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)') RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)')
RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing') RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing')
RE_KERNELS_DYN = re.compile(r'Kernels: (\d+) compiled, (\d+) weight-bearing') RE_KERNELS_DYN = re.compile(r'Kernels: (\d+) compiled, (\d+) weight-bearing')
@ -307,10 +354,67 @@ RE_ANE_TFLOPS = re.compile(r'ANE TFLOPS:\s+([\d.]+)')
RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%') RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%')
RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Compile|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)') RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Compile|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)')
RE_COMPILED = re.compile(r'Compiled (\d+) kernels in (\d+)ms') RE_COMPILED = re.compile(r'Compiled (\d+) kernels in (\d+)ms')
RE_CKPT_SAVED = re.compile(r'\[ckpt saved, best_loss=([\d.]+)\]')
RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW') RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW')
RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW') RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW')
RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW') RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW')
USE_WANDB = False
def wandb_log_step():
"""Log current state to wandb. Called after each step update."""
if not USE_WANDB:
return
d = {'step': S.step, 'loss': S.loss, 'best_loss': S.best_loss}
if S.ms_per_step > 0:
d['ms_per_step'] = S.ms_per_step
lr = S.training.get('lr')
if lr:
try:
d['lr'] = float(lr)
except ValueError:
pass
ct = S.component_timing
if ct:
for k, v in ct.items():
if k != '_dynamic':
d[f'timing/{k}'] = v
fl = S.flops
if fl.get('ane_tflops'):
d['perf/ane_tflops'] = fl['ane_tflops']
if fl.get('ane_util'):
d['perf/ane_util_pct'] = fl['ane_util']
pw = S.power
if pw['ane'] > 0:
d['power/ane_w'] = pw['ane']
if pw['cpu'] > 0:
d['power/cpu_w'] = pw['cpu']
wandb.log(d, step=S.step)
def _sync_globals_from_parsed(cfg):
"""Sync dashboard globals from parsed binary output so text gen uses correct dims."""
global DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS
global Q_DIM, KV_DIM, GQA_RATIO
if 'dim' in cfg:
DIM = cfg['dim']
if 'hidden' in cfg:
HIDDEN = cfg['hidden']
if 'heads' in cfg:
HEADS = cfg['heads']
if 'kv_heads' in cfg:
KV_HEADS = cfg['kv_heads']
if 'hd' in cfg:
HD = cfg['hd']
if 'seq' in cfg:
SEQ = cfg['seq']
if 'vocab' in cfg:
VOCAB = cfg['vocab']
if 'layers' in cfg:
NLAYERS = cfg['layers']
Q_DIM = HEADS * HD
KV_DIM = KV_HEADS * HD
GQA_RATIO = HEADS // KV_HEADS if KV_HEADS else 1
def parse_line(line): def parse_line(line):
S.logs.append(line) S.logs.append(line)
# Parse JSON lines from static pipeline ({"type":"step",...} or {"type":"batch",...}) # Parse JSON lines from static pipeline ({"type":"step",...} or {"type":"batch",...})
@ -339,6 +443,7 @@ def parse_line(line):
ct[k[2:]] = j[k] # strip 't_' prefix ct[k[2:]] = j[k] # strip 't_' prefix
if ct: if ct:
S.component_timing = ct S.component_timing = ct
wandb_log_step()
return return
elif jt == 'batch': elif jt == 'batch':
S.batch_num = j.get('batch', S.batch_num) S.batch_num = j.get('batch', S.batch_num)
@ -356,9 +461,21 @@ def parse_line(line):
return return
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
pass pass
m = RE_MODEL_NAME.search(line)
if m:
S.model_config['name'] = m[1]
S.model_config['layers'] = int(m[2])
m = RE_CONFIG_GQA.search(line)
if m:
d, qd, kvd, hd, hid, seq, voc = map(int, m.groups())
S.model_config.update(dim=d, q_dim=qd, kv_dim=kvd, hd=hd, hidden=hid, seq=seq, vocab=voc,
heads=qd//hd, kv_heads=kvd//hd)
_sync_globals_from_parsed(S.model_config)
return
m = RE_CONFIG.search(line) m = RE_CONFIG.search(line)
if m: if m:
S.model_config = dict(zip(['dim', 'hidden', 'heads', 'seq', 'vocab', 'layers'], map(int, m.groups()))) S.model_config = dict(zip(['dim', 'hidden', 'heads', 'seq', 'vocab', 'layers'], map(int, m.groups())))
_sync_globals_from_parsed(S.model_config)
return return
m = RE_PARAMS.search(line) m = RE_PARAMS.search(line)
if m: if m:
@ -398,6 +515,7 @@ def parse_line(line):
S.ms_per_step = dt * 1000 S.ms_per_step = dt * 1000
S.loss_history.append((S.step, S.loss)) S.loss_history.append((S.step, S.loss))
S.best_loss = min(S.best_loss, S.loss) S.best_loss = min(S.best_loss, S.loss)
wandb_log_step()
return return
m = RE_BATCH.search(line) m = RE_BATCH.search(line)
if m: if m:
@ -434,6 +552,11 @@ def parse_line(line):
S.compiles = int(m[1]) S.compiles = int(m[1])
S.compile_ms += float(m[2]) S.compile_ms += float(m[2])
return return
m = RE_CKPT_SAVED.search(line)
if m:
if USE_WANDB:
wandb.log({'checkpoint/best_loss': float(m[1]), 'checkpoint/saved': True}, step=S.step)
return
m = RE_EFFICIENCY.search(line) m = RE_EFFICIENCY.search(line)
if m: if m:
S.efficiency[m[1].strip()] = m[2].strip() S.efficiency[m[1].strip()] = m[2].strip()
@ -553,14 +676,17 @@ def draw(term):
row = 0 row = 0
# Model Config header # Model Config header — use parsed name from binary if available, else CLI arg
hdr = '\u2500 Model Config ' model_label = S.model_config.get('name', S.active_model)
put(row, 0, '\u250c' + hdr + '\u2500' * max(0, w - len(hdr) - 2) + '\u2510', term.cyan) keys_hint = '[r]estart [g]en [q]uit'
hdr_text = f'\u2500 {model_label} \u2500\u2500 {keys_hint} '
put(row, 0, '\u250c' + hdr_text + '\u2500' * max(0, w - len(hdr_text) - 2) + '\u2510', term.cyan)
row += 1 row += 1
cfg = S.model_config cfg = S.model_config
if cfg: if cfg:
line1 = f"stories110M dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}" gqa_str = f" kv_heads={cfg.get('kv_heads', '')}" if cfg.get('kv_heads', cfg.get('heads', 0)) != cfg.get('heads', 0) else ''
line1 = f"dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')}{gqa_str} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}"
put(row, 0, '\u2502', term.cyan) put(row, 0, '\u2502', term.cyan)
put(row, 2, line1) put(row, 2, line1)
put(row, w - 1, '\u2502', term.cyan) put(row, w - 1, '\u2502', term.cyan)
@ -778,9 +904,10 @@ def set_nonblock(fd):
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
def spawn_training(resume=False, steps=10000, dynamic=False, ane=False, scratch=False, def spawn_training(resume=False, steps=10000, dynamic=False, ane=False, scratch=False,
lr=None, accum=None, no_ane_extras=False, data=None): lr=None, accum=None, no_ane_extras=False, data=None, model=None):
if dynamic: if dynamic:
cmd = 'cd training_dynamic && make 2>&1 && ./train' model_arg = f' MODEL={model}' if model else ''
cmd = f'cd training_dynamic && make{model_arg} 2>&1 && ./train'
elif ane: elif ane:
cmd = 'make train_large_ane 2>&1 && ./train_large_ane' cmd = 'make train_large_ane 2>&1 && ./train_large_ane'
else: else:
@ -818,9 +945,12 @@ def spawn_powermetrics():
return None return None
def main(): def main():
parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)') parser = argparse.ArgumentParser(description='ANE Training Dashboard')
parser.add_argument('--resume', action='store_true', help='Resume from checkpoint') parser.add_argument('--resume', action='store_true', help='Resume from checkpoint')
parser.add_argument('--dynamic', action='store_true', help='Dynamic weight pipeline (training_dynamic/)') parser.add_argument('--dynamic', action='store_true', help='Dynamic weight pipeline (training_dynamic/)')
parser.add_argument('--model', type=str, default=None,
choices=list(MODEL_CONFIGS.keys()),
help='Model config (default: stories110m for static, qwen3_06b for dynamic)')
parser.add_argument('--ane', action='store_true', help='PR#19: ANE-offloaded classifier/softmax/rmsnorm_bwd') parser.add_argument('--ane', action='store_true', help='PR#19: ANE-offloaded classifier/softmax/rmsnorm_bwd')
parser.add_argument('--no-ane-extras', action='store_true', help='Disable ANE extras (use with --ane)') parser.add_argument('--no-ane-extras', action='store_true', help='Disable ANE extras (use with --ane)')
parser.add_argument('--scratch', action='store_true', help='Train from scratch (random init)') parser.add_argument('--scratch', action='store_true', help='Train from scratch (random init)')
@ -831,14 +961,53 @@ def main():
parser.add_argument('--no-generate', action='store_true', help='Disable text generation') parser.add_argument('--no-generate', action='store_true', help='Disable text generation')
parser.add_argument('--steps', type=int, default=10000, help='Total steps (default: 10000)') parser.add_argument('--steps', type=int, default=10000, help='Total steps (default: 10000)')
parser.add_argument('--data', type=str, default=None, help='Path to training data shard (.bin)') parser.add_argument('--data', type=str, default=None, help='Path to training data shard (.bin)')
parser.add_argument('--wandb', action='store_true', help='Log to Weights & Biases')
parser.add_argument('--wandb-project', type=str, default='ane-training', help='W&B project name')
parser.add_argument('--wandb-name', type=str, default=None, help='W&B run name')
args = parser.parse_args() args = parser.parse_args()
if args.infinite: if args.infinite:
args.steps = 999999999 args.steps = 999999999
S.total_steps = args.steps S.total_steps = args.steps
global CKPT_PATH # Select model
CKPT_PATH = CKPT_PATH_DYNAMIC if args.dynamic else CKPT_PATH_STATIC if args.model is None:
args.model = 'qwen3_06b' if args.dynamic else 'stories110m'
cfg = MODEL_CONFIGS[args.model]
# Auto-enable dynamic for models without a static pipeline
if cfg['ckpt_static'] is None:
args.dynamic = True
set_model_config(args.model)
S.active_model = args.model
# For dynamic: default to --scratch when --resume not given
if args.dynamic and not args.resume:
args.scratch = True
global CKPT_PATH, USE_WANDB
CKPT_PATH = cfg['ckpt_dynamic'] if args.dynamic else cfg['ckpt_static']
# Weights & Biases
if args.wandb:
if not HAS_WANDB:
print('pip install wandb')
sys.exit(1)
run_name = args.wandb_name or f'{args.model}-{"resume" if args.resume else "scratch"}'
wandb.init(
project=args.wandb_project,
name=run_name,
config={
'model': args.model,
'dim': DIM, 'hidden': HIDDEN, 'heads': HEADS,
'kv_heads': KV_HEADS, 'hd': HD, 'seq': SEQ,
'vocab': VOCAB, 'nlayers': NLAYERS,
'q_dim': Q_DIM, 'kv_dim': KV_DIM,
'pipeline': 'dynamic' if args.dynamic else 'static',
'resume': args.resume,
'lr': args.lr, 'accum': args.accum,
'steps': args.steps,
},
)
USE_WANDB = True
term = Terminal() term = Terminal()
procs = [] procs = []
@ -846,7 +1015,7 @@ def main():
train_proc = spawn_training(resume=args.resume, steps=args.steps, dynamic=args.dynamic, train_proc = spawn_training(resume=args.resume, steps=args.steps, dynamic=args.dynamic,
scratch=args.scratch, lr=args.lr, accum=args.accum, scratch=args.scratch, lr=args.lr, accum=args.accum,
ane=args.ane, no_ane_extras=args.no_ane_extras, ane=args.ane, no_ane_extras=args.no_ane_extras,
data=args.data) data=args.data, model=args.model)
S.train_pid = train_proc.pid S.train_pid = train_proc.pid
procs.append(train_proc) procs.append(train_proc)
@ -874,6 +1043,8 @@ def main():
p.terminate() p.terminate()
except Exception: except Exception:
pass pass
if USE_WANDB:
wandb.finish()
signal.signal(signal.SIGINT, lambda *a: cleanup()) signal.signal(signal.SIGINT, lambda *a: cleanup())
signal.signal(signal.SIGTERM, lambda *a: cleanup()) signal.signal(signal.SIGTERM, lambda *a: cleanup())
@ -989,11 +1160,11 @@ def main():
train_proc = spawn_training(resume=True, steps=args.steps, dynamic=args.dynamic, train_proc = spawn_training(resume=True, steps=args.steps, dynamic=args.dynamic,
lr=args.lr, accum=args.accum, lr=args.lr, accum=args.accum,
ane=args.ane, no_ane_extras=args.no_ane_extras, ane=args.ane, no_ane_extras=args.no_ane_extras,
data=args.data) data=args.data, model=S.active_model)
S.train_pid = train_proc.pid S.train_pid = train_proc.pid
procs = [p for p in procs if p.poll() is None] procs = [p for p in procs if p.poll() is None]
procs.append(train_proc) procs.append(train_proc)
S.logs.append('[dashboard] Restarted with --resume') S.logs.append(f'[dashboard] Restarted {S.active_model} with --resume')
need_draw = True need_draw = True
elif key == 'g': elif key == 'g':
with S.gen_lock: with S.gen_lock:

View File

@ -384,6 +384,7 @@ int main(int argc, char *argv[]) {
dispatch_group_t dw_grp = dispatch_group_create(); dispatch_group_t dw_grp = dispatch_group_create();
float last_loss = 999.0f; float last_loss = 999.0f;
float best_loss = resume_loss > 0 ? resume_loss : 999.0f;
double total_train_ms = 0; double total_train_ms = 0;
int total_steps_done = 0; int total_steps_done = 0;
uint64_t t_wall_start = mach_absolute_time(); uint64_t t_wall_start = mach_absolute_time();
@ -875,12 +876,14 @@ int main(int argc, char *argv[]) {
memset(gembed, 0, (size_t)VOCAB*DIM*4); memset(gembed, 0, (size_t)VOCAB*DIM*4);
memset(gcembed, 0, (size_t)CV*DIM*4); memset(gcembed, 0, (size_t)CV*DIM*4);
// Checkpoint // Checkpoint only save on best loss
if ((step+1) % 100 == 0) { if ((step+1) % 100 == 0 && last_loss < best_loss) {
best_loss = last_loss;
double wall = tb_ms(mach_absolute_time() - t_wall_start); double wall = tb_ms(mach_absolute_time() - t_wall_start);
save_checkpoint(CKPT_PATH, step+1, total_steps, lr, last_loss, save_checkpoint(CKPT_PATH, step+1, total_steps, lr, last_loss,
total_train_ms+cum_train, wall+cum_wall, total_steps_done+cum_steps, adam_t, total_train_ms+cum_train, wall+cum_wall, total_steps_done+cum_steps, adam_t,
lw, la, rms_final, &arms_final, embed, &aembed); lw, la, rms_final, &arms_final, embed, &aembed);
printf(" [ckpt saved, best_loss=%.4f]\n", best_loss);
} }
} }
} }