mirror of https://github.com/maderix/ANE.git
Multi-model dashboard with GQA, W&B integration, and best-loss checkpointing
Dashboard: multi-model support (Stories110M + Qwen3-0.6B) with GQA-aware text generation and KV cache. Weights & Biases logging (--wandb flag) for loss, timing, power, and checkpoint events. Top-k=50 sampling to eliminate garbage tokens from untrained vocab entries. Tokenizer reads any vocab size. train.m: only save checkpoint when loss improves (best_loss tracking).
This commit is contained in:
parent
475348ad14
commit
7d61ee4d25
|
|
@ -18,16 +18,48 @@ try:
|
|||
except ImportError:
|
||||
HAS_PSUTIL = False
|
||||
|
||||
DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 256, 32000, 12
|
||||
HD = DIM // HEADS
|
||||
CKPT_PATH_STATIC = 'ane_stories110M_ckpt.bin'
|
||||
CKPT_PATH_DYNAMIC = 'training_dynamic/ane_stories110M_dyn_ckpt.bin'
|
||||
CKPT_PATH = CKPT_PATH_STATIC # set in main() based on --dynamic
|
||||
try:
|
||||
import wandb
|
||||
HAS_WANDB = True
|
||||
except ImportError:
|
||||
HAS_WANDB = False
|
||||
|
||||
# Model configs — set at startup based on --model flag
|
||||
MODEL_CONFIGS = {
|
||||
'stories110m': {
|
||||
'dim': 768, 'hidden': 2048, 'heads': 12, 'kv_heads': 12,
|
||||
'hd': 64, 'seq': 256, 'vocab': 32000, 'nlayers': 12,
|
||||
'ckpt_static': 'ane_stories110M_ckpt.bin',
|
||||
'ckpt_dynamic': 'training_dynamic/ane_stories110M_dyn_ckpt.bin',
|
||||
},
|
||||
'qwen3_06b': {
|
||||
'dim': 1024, 'hidden': 3072, 'heads': 16, 'kv_heads': 8,
|
||||
'hd': 128, 'seq': 256, 'vocab': 151936, 'nlayers': 28,
|
||||
'ckpt_static': None,
|
||||
'ckpt_dynamic': 'training_dynamic/ane_qwen3_06b_dyn_ckpt.bin',
|
||||
},
|
||||
}
|
||||
|
||||
# Active model dims — set in main()
|
||||
DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 12, 64, 256, 32000, 12
|
||||
Q_DIM, KV_DIM, GQA_RATIO = DIM, DIM, 1
|
||||
CKPT_PATH = 'ane_stories110M_ckpt.bin'
|
||||
TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent / 'assets' / 'models' / 'tokenizer.bin')
|
||||
|
||||
def set_model_config(name):
|
||||
global DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS
|
||||
global Q_DIM, KV_DIM, GQA_RATIO
|
||||
cfg = MODEL_CONFIGS[name]
|
||||
DIM, HIDDEN, HEADS, KV_HEADS = cfg['dim'], cfg['hidden'], cfg['heads'], cfg['kv_heads']
|
||||
HD, SEQ, VOCAB, NLAYERS = cfg['hd'], cfg['seq'], cfg['vocab'], cfg['nlayers']
|
||||
Q_DIM = HEADS * HD
|
||||
KV_DIM = KV_HEADS * HD
|
||||
GQA_RATIO = HEADS // KV_HEADS
|
||||
|
||||
|
||||
class State:
|
||||
def __init__(self):
|
||||
self.active_model = 'stories110m'
|
||||
self.model_config = {}
|
||||
self.params = {}
|
||||
self.kernels = {}
|
||||
|
|
@ -62,6 +94,7 @@ class State:
|
|||
self.train_start = None # wall clock when first step seen
|
||||
self.compile_ms = 0.0 # total compile time
|
||||
|
||||
|
||||
S = State()
|
||||
|
||||
|
||||
|
|
@ -71,8 +104,12 @@ class Tokenizer:
|
|||
self.scores = []
|
||||
with open(path, 'rb') as f:
|
||||
max_len = struct.unpack('i', f.read(4))[0]
|
||||
for _ in range(VOCAB):
|
||||
score = struct.unpack('f', f.read(4))[0]
|
||||
# Read until EOF — works for any vocab size
|
||||
while True:
|
||||
data = f.read(4)
|
||||
if len(data) < 4:
|
||||
break
|
||||
score = struct.unpack('f', data)[0]
|
||||
slen = struct.unpack('i', f.read(4))[0]
|
||||
tok = f.read(slen).decode('utf-8', errors='replace')
|
||||
self.vocab.append(tok)
|
||||
|
|
@ -104,33 +141,32 @@ def get_tokenizer():
|
|||
def load_weights_from_ckpt(path):
|
||||
try:
|
||||
with open(path, 'rb') as f:
|
||||
# CkptHdr: 96 bytes (verified with sizeof)
|
||||
hdr = f.read(96)
|
||||
if len(hdr) < 96:
|
||||
return None
|
||||
wq_sz = DIM * DIM
|
||||
wo_sz = DIM * DIM
|
||||
wq_sz = Q_DIM * DIM
|
||||
wk_sz = KV_DIM * DIM
|
||||
wv_sz = KV_DIM * DIM
|
||||
wo_sz = DIM * Q_DIM
|
||||
w1_sz = HIDDEN * DIM
|
||||
w2_sz = DIM * HIDDEN
|
||||
w3_sz = HIDDEN * DIM
|
||||
# Per-layer: weights + adam state (m,v for each)
|
||||
adam_per_layer = (wq_sz*2 + wq_sz*2 + wq_sz*2 + wo_sz*2 +
|
||||
adam_per_layer = (wq_sz*2 + wk_sz*2 + wv_sz*2 + wo_sz*2 +
|
||||
w1_sz*2 + w2_sz*2 + w3_sz*2 + DIM*2 + DIM*2)
|
||||
W = {}
|
||||
for L in range(NLAYERS):
|
||||
W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
|
||||
W[f'Wk{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
|
||||
W[f'Wv{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
|
||||
W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
|
||||
W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(Q_DIM, DIM).copy()
|
||||
W[f'Wk{L}'] = np.frombuffer(f.read(wk_sz * 4), dtype=np.float32).reshape(KV_DIM, DIM).copy()
|
||||
W[f'Wv{L}'] = np.frombuffer(f.read(wv_sz * 4), dtype=np.float32).reshape(KV_DIM, DIM).copy()
|
||||
W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, Q_DIM).copy()
|
||||
W[f'W1_{L}'] = np.frombuffer(f.read(w1_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
|
||||
W[f'W2_{L}'] = np.frombuffer(f.read(w2_sz * 4), dtype=np.float32).reshape(DIM, HIDDEN).copy()
|
||||
W[f'W3_{L}'] = np.frombuffer(f.read(w3_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
|
||||
W[f'rms1_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
|
||||
W[f'rms2_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
|
||||
# Skip adam state for this layer
|
||||
f.seek(adam_per_layer * 4, 1)
|
||||
W['rms_final'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
|
||||
f.seek(DIM * 2 * 4, 1) # skip rms_final adam
|
||||
f.seek(DIM * 2 * 4, 1)
|
||||
W['embed'] = np.frombuffer(f.read(VOCAB * DIM * 4), dtype=np.float32).reshape(VOCAB, DIM).copy()
|
||||
return W
|
||||
except Exception as e:
|
||||
|
|
@ -151,20 +187,21 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
|||
tokenizer = get_tokenizer()
|
||||
if tokenizer is None:
|
||||
return '[no tokenizer]'
|
||||
if len(tokenizer.vocab) < VOCAB:
|
||||
return f'[tokenizer has {len(tokenizer.vocab)} tokens, model needs {VOCAB}]'
|
||||
|
||||
tokens = [1]
|
||||
text_parts = []
|
||||
|
||||
# Precompute RoPE frequencies
|
||||
freqs = np.zeros((SEQ, HD // 2), dtype=np.float32)
|
||||
for pos in range(SEQ):
|
||||
for i in range(HD // 2):
|
||||
freq = 1.0 / (10000.0 ** (2.0 * i / HD))
|
||||
freqs[pos, i] = pos * freq
|
||||
|
||||
# KV cache: per-layer, per-head arrays
|
||||
k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)]
|
||||
v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)]
|
||||
# KV cache: per-layer, per KV head
|
||||
k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(KV_HEADS)] for _ in range(NLAYERS)]
|
||||
v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(KV_HEADS)] for _ in range(NLAYERS)]
|
||||
|
||||
res_alpha = 1.0 / math.sqrt(2.0 * NLAYERS)
|
||||
|
||||
|
|
@ -177,13 +214,12 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
|||
pos = seq_len - 1
|
||||
|
||||
for L in range(NLAYERS):
|
||||
# RMSNorm + QKV
|
||||
xn = rmsnorm(x, W[f'rms1_{L}'])
|
||||
q = W[f'Wq{L}'] @ xn
|
||||
k = W[f'Wk{L}'] @ xn
|
||||
v = W[f'Wv{L}'] @ xn
|
||||
q = W[f'Wq{L}'] @ xn # [Q_DIM]
|
||||
k = W[f'Wk{L}'] @ xn # [KV_DIM]
|
||||
v = W[f'Wv{L}'] @ xn # [KV_DIM]
|
||||
|
||||
# RoPE
|
||||
# RoPE on Q (HEADS heads) and K (KV_HEADS heads)
|
||||
for h in range(HEADS):
|
||||
for i in range(HD // 2):
|
||||
freq = freqs[pos, i]
|
||||
|
|
@ -191,31 +227,37 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
|||
qi, qi1 = q[h * HD + 2 * i], q[h * HD + 2 * i + 1]
|
||||
q[h * HD + 2 * i] = qi * cos_v - qi1 * sin_v
|
||||
q[h * HD + 2 * i + 1] = qi * sin_v + qi1 * cos_v
|
||||
for h in range(KV_HEADS):
|
||||
for i in range(HD // 2):
|
||||
freq = freqs[pos, i]
|
||||
cos_v, sin_v = math.cos(freq), math.sin(freq)
|
||||
ki, ki1 = k[h * HD + 2 * i], k[h * HD + 2 * i + 1]
|
||||
k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v
|
||||
k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v
|
||||
|
||||
# Append to KV cache and compute attention
|
||||
o = np.zeros(DIM, dtype=np.float32)
|
||||
for h in range(HEADS):
|
||||
qh = q[h * HD:(h + 1) * HD]
|
||||
kh = k[h * HD:(h + 1) * HD].reshape(1, HD)
|
||||
vh = v[h * HD:(h + 1) * HD].reshape(1, HD)
|
||||
k_cache[L][h] = np.vstack([k_cache[L][h], kh])
|
||||
v_cache[L][h] = np.vstack([v_cache[L][h], vh])
|
||||
# scores: (1, HD) @ (HD, seq_len) -> (seq_len,)
|
||||
scores = k_cache[L][h] @ qh / math.sqrt(HD)
|
||||
attn = softmax(scores)
|
||||
o[h * HD:(h + 1) * HD] = attn @ v_cache[L][h]
|
||||
# Append to KV cache (KV_HEADS entries)
|
||||
for kv in range(KV_HEADS):
|
||||
kh = k[kv * HD:(kv + 1) * HD].reshape(1, HD)
|
||||
vh = v[kv * HD:(kv + 1) * HD].reshape(1, HD)
|
||||
k_cache[L][kv] = np.vstack([k_cache[L][kv], kh])
|
||||
v_cache[L][kv] = np.vstack([v_cache[L][kv], vh])
|
||||
|
||||
# Residual + output projection (scaled residual, matches training)
|
||||
# GQA attention: each Q head uses its corresponding KV head
|
||||
o = np.zeros(Q_DIM, dtype=np.float32)
|
||||
for h in range(HEADS):
|
||||
kv = h // GQA_RATIO
|
||||
qh = q[h * HD:(h + 1) * HD]
|
||||
scores = k_cache[L][kv] @ qh / math.sqrt(HD)
|
||||
attn = softmax(scores)
|
||||
o[h * HD:(h + 1) * HD] = attn @ v_cache[L][kv]
|
||||
|
||||
# Residual + output projection
|
||||
x2 = x + res_alpha * (W[f'Wo{L}'] @ o)
|
||||
|
||||
# FFN
|
||||
x2n = rmsnorm(x2, W[f'rms2_{L}'])
|
||||
h1 = W[f'W1_{L}'] @ x2n
|
||||
h3 = W[f'W3_{L}'] @ x2n
|
||||
# SiLU
|
||||
h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3
|
||||
ffn_out = W[f'W2_{L}'] @ h1
|
||||
|
||||
|
|
@ -230,8 +272,11 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
|||
next_tok = int(np.argmax(logits))
|
||||
else:
|
||||
logits = logits / temperature
|
||||
probs = softmax(logits)
|
||||
next_tok = int(np.random.choice(VOCAB, p=probs))
|
||||
top_k = 50
|
||||
top_idx = np.argpartition(logits, -top_k)[-top_k:]
|
||||
top_logits = logits[top_idx]
|
||||
probs = softmax(top_logits)
|
||||
next_tok = int(top_idx[np.random.choice(len(top_idx), p=probs)])
|
||||
|
||||
if next_tok == 2:
|
||||
break
|
||||
|
|
@ -291,6 +336,8 @@ def sysmetrics_thread():
|
|||
|
||||
|
||||
RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)')
|
||||
RE_CONFIG_GQA = re.compile(r'dim=(\d+) q_dim=(\d+) kv_dim=(\d+) hd=(\d+) hidden=(\d+) seq=(\d+) vocab=(\d+)')
|
||||
RE_MODEL_NAME = re.compile(r'ANE Dynamic Training: (.+?) \((\d+) layers')
|
||||
RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)')
|
||||
RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing')
|
||||
RE_KERNELS_DYN = re.compile(r'Kernels: (\d+) compiled, (\d+) weight-bearing')
|
||||
|
|
@ -307,10 +354,67 @@ RE_ANE_TFLOPS = re.compile(r'ANE TFLOPS:\s+([\d.]+)')
|
|||
RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%')
|
||||
RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Compile|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)')
|
||||
RE_COMPILED = re.compile(r'Compiled (\d+) kernels in (\d+)ms')
|
||||
RE_CKPT_SAVED = re.compile(r'\[ckpt saved, best_loss=([\d.]+)\]')
|
||||
RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW')
|
||||
RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW')
|
||||
RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW')
|
||||
|
||||
USE_WANDB = False
|
||||
|
||||
def wandb_log_step():
|
||||
"""Log current state to wandb. Called after each step update."""
|
||||
if not USE_WANDB:
|
||||
return
|
||||
d = {'step': S.step, 'loss': S.loss, 'best_loss': S.best_loss}
|
||||
if S.ms_per_step > 0:
|
||||
d['ms_per_step'] = S.ms_per_step
|
||||
lr = S.training.get('lr')
|
||||
if lr:
|
||||
try:
|
||||
d['lr'] = float(lr)
|
||||
except ValueError:
|
||||
pass
|
||||
ct = S.component_timing
|
||||
if ct:
|
||||
for k, v in ct.items():
|
||||
if k != '_dynamic':
|
||||
d[f'timing/{k}'] = v
|
||||
fl = S.flops
|
||||
if fl.get('ane_tflops'):
|
||||
d['perf/ane_tflops'] = fl['ane_tflops']
|
||||
if fl.get('ane_util'):
|
||||
d['perf/ane_util_pct'] = fl['ane_util']
|
||||
pw = S.power
|
||||
if pw['ane'] > 0:
|
||||
d['power/ane_w'] = pw['ane']
|
||||
if pw['cpu'] > 0:
|
||||
d['power/cpu_w'] = pw['cpu']
|
||||
wandb.log(d, step=S.step)
|
||||
|
||||
def _sync_globals_from_parsed(cfg):
|
||||
"""Sync dashboard globals from parsed binary output so text gen uses correct dims."""
|
||||
global DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS
|
||||
global Q_DIM, KV_DIM, GQA_RATIO
|
||||
if 'dim' in cfg:
|
||||
DIM = cfg['dim']
|
||||
if 'hidden' in cfg:
|
||||
HIDDEN = cfg['hidden']
|
||||
if 'heads' in cfg:
|
||||
HEADS = cfg['heads']
|
||||
if 'kv_heads' in cfg:
|
||||
KV_HEADS = cfg['kv_heads']
|
||||
if 'hd' in cfg:
|
||||
HD = cfg['hd']
|
||||
if 'seq' in cfg:
|
||||
SEQ = cfg['seq']
|
||||
if 'vocab' in cfg:
|
||||
VOCAB = cfg['vocab']
|
||||
if 'layers' in cfg:
|
||||
NLAYERS = cfg['layers']
|
||||
Q_DIM = HEADS * HD
|
||||
KV_DIM = KV_HEADS * HD
|
||||
GQA_RATIO = HEADS // KV_HEADS if KV_HEADS else 1
|
||||
|
||||
def parse_line(line):
|
||||
S.logs.append(line)
|
||||
# Parse JSON lines from static pipeline ({"type":"step",...} or {"type":"batch",...})
|
||||
|
|
@ -339,6 +443,7 @@ def parse_line(line):
|
|||
ct[k[2:]] = j[k] # strip 't_' prefix
|
||||
if ct:
|
||||
S.component_timing = ct
|
||||
wandb_log_step()
|
||||
return
|
||||
elif jt == 'batch':
|
||||
S.batch_num = j.get('batch', S.batch_num)
|
||||
|
|
@ -356,9 +461,21 @@ def parse_line(line):
|
|||
return
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
m = RE_MODEL_NAME.search(line)
|
||||
if m:
|
||||
S.model_config['name'] = m[1]
|
||||
S.model_config['layers'] = int(m[2])
|
||||
m = RE_CONFIG_GQA.search(line)
|
||||
if m:
|
||||
d, qd, kvd, hd, hid, seq, voc = map(int, m.groups())
|
||||
S.model_config.update(dim=d, q_dim=qd, kv_dim=kvd, hd=hd, hidden=hid, seq=seq, vocab=voc,
|
||||
heads=qd//hd, kv_heads=kvd//hd)
|
||||
_sync_globals_from_parsed(S.model_config)
|
||||
return
|
||||
m = RE_CONFIG.search(line)
|
||||
if m:
|
||||
S.model_config = dict(zip(['dim', 'hidden', 'heads', 'seq', 'vocab', 'layers'], map(int, m.groups())))
|
||||
_sync_globals_from_parsed(S.model_config)
|
||||
return
|
||||
m = RE_PARAMS.search(line)
|
||||
if m:
|
||||
|
|
@ -398,6 +515,7 @@ def parse_line(line):
|
|||
S.ms_per_step = dt * 1000
|
||||
S.loss_history.append((S.step, S.loss))
|
||||
S.best_loss = min(S.best_loss, S.loss)
|
||||
wandb_log_step()
|
||||
return
|
||||
m = RE_BATCH.search(line)
|
||||
if m:
|
||||
|
|
@ -434,6 +552,11 @@ def parse_line(line):
|
|||
S.compiles = int(m[1])
|
||||
S.compile_ms += float(m[2])
|
||||
return
|
||||
m = RE_CKPT_SAVED.search(line)
|
||||
if m:
|
||||
if USE_WANDB:
|
||||
wandb.log({'checkpoint/best_loss': float(m[1]), 'checkpoint/saved': True}, step=S.step)
|
||||
return
|
||||
m = RE_EFFICIENCY.search(line)
|
||||
if m:
|
||||
S.efficiency[m[1].strip()] = m[2].strip()
|
||||
|
|
@ -553,14 +676,17 @@ def draw(term):
|
|||
|
||||
row = 0
|
||||
|
||||
# Model Config header
|
||||
hdr = '\u2500 Model Config '
|
||||
put(row, 0, '\u250c' + hdr + '\u2500' * max(0, w - len(hdr) - 2) + '\u2510', term.cyan)
|
||||
# Model Config header — use parsed name from binary if available, else CLI arg
|
||||
model_label = S.model_config.get('name', S.active_model)
|
||||
keys_hint = '[r]estart [g]en [q]uit'
|
||||
hdr_text = f'\u2500 {model_label} \u2500\u2500 {keys_hint} '
|
||||
put(row, 0, '\u250c' + hdr_text + '\u2500' * max(0, w - len(hdr_text) - 2) + '\u2510', term.cyan)
|
||||
row += 1
|
||||
|
||||
cfg = S.model_config
|
||||
if cfg:
|
||||
line1 = f"stories110M dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}"
|
||||
gqa_str = f" kv_heads={cfg.get('kv_heads', '')}" if cfg.get('kv_heads', cfg.get('heads', 0)) != cfg.get('heads', 0) else ''
|
||||
line1 = f"dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')}{gqa_str} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}"
|
||||
put(row, 0, '\u2502', term.cyan)
|
||||
put(row, 2, line1)
|
||||
put(row, w - 1, '\u2502', term.cyan)
|
||||
|
|
@ -778,9 +904,10 @@ def set_nonblock(fd):
|
|||
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
|
||||
|
||||
def spawn_training(resume=False, steps=10000, dynamic=False, ane=False, scratch=False,
|
||||
lr=None, accum=None, no_ane_extras=False, data=None):
|
||||
lr=None, accum=None, no_ane_extras=False, data=None, model=None):
|
||||
if dynamic:
|
||||
cmd = 'cd training_dynamic && make 2>&1 && ./train'
|
||||
model_arg = f' MODEL={model}' if model else ''
|
||||
cmd = f'cd training_dynamic && make{model_arg} 2>&1 && ./train'
|
||||
elif ane:
|
||||
cmd = 'make train_large_ane 2>&1 && ./train_large_ane'
|
||||
else:
|
||||
|
|
@ -818,9 +945,12 @@ def spawn_powermetrics():
|
|||
return None
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)')
|
||||
parser = argparse.ArgumentParser(description='ANE Training Dashboard')
|
||||
parser.add_argument('--resume', action='store_true', help='Resume from checkpoint')
|
||||
parser.add_argument('--dynamic', action='store_true', help='Dynamic weight pipeline (training_dynamic/)')
|
||||
parser.add_argument('--model', type=str, default=None,
|
||||
choices=list(MODEL_CONFIGS.keys()),
|
||||
help='Model config (default: stories110m for static, qwen3_06b for dynamic)')
|
||||
parser.add_argument('--ane', action='store_true', help='PR#19: ANE-offloaded classifier/softmax/rmsnorm_bwd')
|
||||
parser.add_argument('--no-ane-extras', action='store_true', help='Disable ANE extras (use with --ane)')
|
||||
parser.add_argument('--scratch', action='store_true', help='Train from scratch (random init)')
|
||||
|
|
@ -831,14 +961,53 @@ def main():
|
|||
parser.add_argument('--no-generate', action='store_true', help='Disable text generation')
|
||||
parser.add_argument('--steps', type=int, default=10000, help='Total steps (default: 10000)')
|
||||
parser.add_argument('--data', type=str, default=None, help='Path to training data shard (.bin)')
|
||||
parser.add_argument('--wandb', action='store_true', help='Log to Weights & Biases')
|
||||
parser.add_argument('--wandb-project', type=str, default='ane-training', help='W&B project name')
|
||||
parser.add_argument('--wandb-name', type=str, default=None, help='W&B run name')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.infinite:
|
||||
args.steps = 999999999
|
||||
S.total_steps = args.steps
|
||||
|
||||
global CKPT_PATH
|
||||
CKPT_PATH = CKPT_PATH_DYNAMIC if args.dynamic else CKPT_PATH_STATIC
|
||||
# Select model
|
||||
if args.model is None:
|
||||
args.model = 'qwen3_06b' if args.dynamic else 'stories110m'
|
||||
cfg = MODEL_CONFIGS[args.model]
|
||||
# Auto-enable dynamic for models without a static pipeline
|
||||
if cfg['ckpt_static'] is None:
|
||||
args.dynamic = True
|
||||
set_model_config(args.model)
|
||||
S.active_model = args.model
|
||||
# For dynamic: default to --scratch when --resume not given
|
||||
if args.dynamic and not args.resume:
|
||||
args.scratch = True
|
||||
|
||||
global CKPT_PATH, USE_WANDB
|
||||
CKPT_PATH = cfg['ckpt_dynamic'] if args.dynamic else cfg['ckpt_static']
|
||||
|
||||
# Weights & Biases
|
||||
if args.wandb:
|
||||
if not HAS_WANDB:
|
||||
print('pip install wandb')
|
||||
sys.exit(1)
|
||||
run_name = args.wandb_name or f'{args.model}-{"resume" if args.resume else "scratch"}'
|
||||
wandb.init(
|
||||
project=args.wandb_project,
|
||||
name=run_name,
|
||||
config={
|
||||
'model': args.model,
|
||||
'dim': DIM, 'hidden': HIDDEN, 'heads': HEADS,
|
||||
'kv_heads': KV_HEADS, 'hd': HD, 'seq': SEQ,
|
||||
'vocab': VOCAB, 'nlayers': NLAYERS,
|
||||
'q_dim': Q_DIM, 'kv_dim': KV_DIM,
|
||||
'pipeline': 'dynamic' if args.dynamic else 'static',
|
||||
'resume': args.resume,
|
||||
'lr': args.lr, 'accum': args.accum,
|
||||
'steps': args.steps,
|
||||
},
|
||||
)
|
||||
USE_WANDB = True
|
||||
|
||||
term = Terminal()
|
||||
procs = []
|
||||
|
|
@ -846,7 +1015,7 @@ def main():
|
|||
train_proc = spawn_training(resume=args.resume, steps=args.steps, dynamic=args.dynamic,
|
||||
scratch=args.scratch, lr=args.lr, accum=args.accum,
|
||||
ane=args.ane, no_ane_extras=args.no_ane_extras,
|
||||
data=args.data)
|
||||
data=args.data, model=args.model)
|
||||
S.train_pid = train_proc.pid
|
||||
procs.append(train_proc)
|
||||
|
||||
|
|
@ -874,6 +1043,8 @@ def main():
|
|||
p.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
if USE_WANDB:
|
||||
wandb.finish()
|
||||
|
||||
signal.signal(signal.SIGINT, lambda *a: cleanup())
|
||||
signal.signal(signal.SIGTERM, lambda *a: cleanup())
|
||||
|
|
@ -989,11 +1160,11 @@ def main():
|
|||
train_proc = spawn_training(resume=True, steps=args.steps, dynamic=args.dynamic,
|
||||
lr=args.lr, accum=args.accum,
|
||||
ane=args.ane, no_ane_extras=args.no_ane_extras,
|
||||
data=args.data)
|
||||
data=args.data, model=S.active_model)
|
||||
S.train_pid = train_proc.pid
|
||||
procs = [p for p in procs if p.poll() is None]
|
||||
procs.append(train_proc)
|
||||
S.logs.append('[dashboard] Restarted with --resume')
|
||||
S.logs.append(f'[dashboard] Restarted {S.active_model} with --resume')
|
||||
need_draw = True
|
||||
elif key == 'g':
|
||||
with S.gen_lock:
|
||||
|
|
|
|||
|
|
@ -384,6 +384,7 @@ int main(int argc, char *argv[]) {
|
|||
dispatch_group_t dw_grp = dispatch_group_create();
|
||||
|
||||
float last_loss = 999.0f;
|
||||
float best_loss = resume_loss > 0 ? resume_loss : 999.0f;
|
||||
double total_train_ms = 0;
|
||||
int total_steps_done = 0;
|
||||
uint64_t t_wall_start = mach_absolute_time();
|
||||
|
|
@ -875,12 +876,14 @@ int main(int argc, char *argv[]) {
|
|||
memset(gembed, 0, (size_t)VOCAB*DIM*4);
|
||||
memset(gcembed, 0, (size_t)CV*DIM*4);
|
||||
|
||||
// Checkpoint
|
||||
if ((step+1) % 100 == 0) {
|
||||
// Checkpoint — only save on best loss
|
||||
if ((step+1) % 100 == 0 && last_loss < best_loss) {
|
||||
best_loss = last_loss;
|
||||
double wall = tb_ms(mach_absolute_time() - t_wall_start);
|
||||
save_checkpoint(CKPT_PATH, step+1, total_steps, lr, last_loss,
|
||||
total_train_ms+cum_train, wall+cum_wall, total_steps_done+cum_steps, adam_t,
|
||||
lw, la, rms_final, &arms_final, embed, &aembed);
|
||||
printf(" [ckpt saved, best_loss=%.4f]\n", best_loss);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue