mirror of https://github.com/maderix/ANE.git
Multi-model dashboard with GQA, W&B integration, and best-loss checkpointing
Dashboard: multi-model support (Stories110M + Qwen3-0.6B) with GQA-aware text generation and KV cache. Weights & Biases logging (--wandb flag) for loss, timing, power, and checkpoint events. Top-k=50 sampling to eliminate garbage tokens from untrained vocab entries. Tokenizer reads any vocab size. train.m: only save checkpoint when loss improves (best_loss tracking).
This commit is contained in:
parent
475348ad14
commit
7d61ee4d25
|
|
@ -18,16 +18,48 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
HAS_PSUTIL = False
|
HAS_PSUTIL = False
|
||||||
|
|
||||||
DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 256, 32000, 12
|
try:
|
||||||
HD = DIM // HEADS
|
import wandb
|
||||||
CKPT_PATH_STATIC = 'ane_stories110M_ckpt.bin'
|
HAS_WANDB = True
|
||||||
CKPT_PATH_DYNAMIC = 'training_dynamic/ane_stories110M_dyn_ckpt.bin'
|
except ImportError:
|
||||||
CKPT_PATH = CKPT_PATH_STATIC # set in main() based on --dynamic
|
HAS_WANDB = False
|
||||||
|
|
||||||
|
# Model configs — set at startup based on --model flag
|
||||||
|
MODEL_CONFIGS = {
|
||||||
|
'stories110m': {
|
||||||
|
'dim': 768, 'hidden': 2048, 'heads': 12, 'kv_heads': 12,
|
||||||
|
'hd': 64, 'seq': 256, 'vocab': 32000, 'nlayers': 12,
|
||||||
|
'ckpt_static': 'ane_stories110M_ckpt.bin',
|
||||||
|
'ckpt_dynamic': 'training_dynamic/ane_stories110M_dyn_ckpt.bin',
|
||||||
|
},
|
||||||
|
'qwen3_06b': {
|
||||||
|
'dim': 1024, 'hidden': 3072, 'heads': 16, 'kv_heads': 8,
|
||||||
|
'hd': 128, 'seq': 256, 'vocab': 151936, 'nlayers': 28,
|
||||||
|
'ckpt_static': None,
|
||||||
|
'ckpt_dynamic': 'training_dynamic/ane_qwen3_06b_dyn_ckpt.bin',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Active model dims — set in main()
|
||||||
|
DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 12, 64, 256, 32000, 12
|
||||||
|
Q_DIM, KV_DIM, GQA_RATIO = DIM, DIM, 1
|
||||||
|
CKPT_PATH = 'ane_stories110M_ckpt.bin'
|
||||||
TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent / 'assets' / 'models' / 'tokenizer.bin')
|
TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent / 'assets' / 'models' / 'tokenizer.bin')
|
||||||
|
|
||||||
|
def set_model_config(name):
|
||||||
|
global DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS
|
||||||
|
global Q_DIM, KV_DIM, GQA_RATIO
|
||||||
|
cfg = MODEL_CONFIGS[name]
|
||||||
|
DIM, HIDDEN, HEADS, KV_HEADS = cfg['dim'], cfg['hidden'], cfg['heads'], cfg['kv_heads']
|
||||||
|
HD, SEQ, VOCAB, NLAYERS = cfg['hd'], cfg['seq'], cfg['vocab'], cfg['nlayers']
|
||||||
|
Q_DIM = HEADS * HD
|
||||||
|
KV_DIM = KV_HEADS * HD
|
||||||
|
GQA_RATIO = HEADS // KV_HEADS
|
||||||
|
|
||||||
|
|
||||||
class State:
|
class State:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
self.active_model = 'stories110m'
|
||||||
self.model_config = {}
|
self.model_config = {}
|
||||||
self.params = {}
|
self.params = {}
|
||||||
self.kernels = {}
|
self.kernels = {}
|
||||||
|
|
@ -62,6 +94,7 @@ class State:
|
||||||
self.train_start = None # wall clock when first step seen
|
self.train_start = None # wall clock when first step seen
|
||||||
self.compile_ms = 0.0 # total compile time
|
self.compile_ms = 0.0 # total compile time
|
||||||
|
|
||||||
|
|
||||||
S = State()
|
S = State()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -71,8 +104,12 @@ class Tokenizer:
|
||||||
self.scores = []
|
self.scores = []
|
||||||
with open(path, 'rb') as f:
|
with open(path, 'rb') as f:
|
||||||
max_len = struct.unpack('i', f.read(4))[0]
|
max_len = struct.unpack('i', f.read(4))[0]
|
||||||
for _ in range(VOCAB):
|
# Read until EOF — works for any vocab size
|
||||||
score = struct.unpack('f', f.read(4))[0]
|
while True:
|
||||||
|
data = f.read(4)
|
||||||
|
if len(data) < 4:
|
||||||
|
break
|
||||||
|
score = struct.unpack('f', data)[0]
|
||||||
slen = struct.unpack('i', f.read(4))[0]
|
slen = struct.unpack('i', f.read(4))[0]
|
||||||
tok = f.read(slen).decode('utf-8', errors='replace')
|
tok = f.read(slen).decode('utf-8', errors='replace')
|
||||||
self.vocab.append(tok)
|
self.vocab.append(tok)
|
||||||
|
|
@ -104,33 +141,32 @@ def get_tokenizer():
|
||||||
def load_weights_from_ckpt(path):
|
def load_weights_from_ckpt(path):
|
||||||
try:
|
try:
|
||||||
with open(path, 'rb') as f:
|
with open(path, 'rb') as f:
|
||||||
# CkptHdr: 96 bytes (verified with sizeof)
|
|
||||||
hdr = f.read(96)
|
hdr = f.read(96)
|
||||||
if len(hdr) < 96:
|
if len(hdr) < 96:
|
||||||
return None
|
return None
|
||||||
wq_sz = DIM * DIM
|
wq_sz = Q_DIM * DIM
|
||||||
wo_sz = DIM * DIM
|
wk_sz = KV_DIM * DIM
|
||||||
|
wv_sz = KV_DIM * DIM
|
||||||
|
wo_sz = DIM * Q_DIM
|
||||||
w1_sz = HIDDEN * DIM
|
w1_sz = HIDDEN * DIM
|
||||||
w2_sz = DIM * HIDDEN
|
w2_sz = DIM * HIDDEN
|
||||||
w3_sz = HIDDEN * DIM
|
w3_sz = HIDDEN * DIM
|
||||||
# Per-layer: weights + adam state (m,v for each)
|
adam_per_layer = (wq_sz*2 + wk_sz*2 + wv_sz*2 + wo_sz*2 +
|
||||||
adam_per_layer = (wq_sz*2 + wq_sz*2 + wq_sz*2 + wo_sz*2 +
|
|
||||||
w1_sz*2 + w2_sz*2 + w3_sz*2 + DIM*2 + DIM*2)
|
w1_sz*2 + w2_sz*2 + w3_sz*2 + DIM*2 + DIM*2)
|
||||||
W = {}
|
W = {}
|
||||||
for L in range(NLAYERS):
|
for L in range(NLAYERS):
|
||||||
W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
|
W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(Q_DIM, DIM).copy()
|
||||||
W[f'Wk{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
|
W[f'Wk{L}'] = np.frombuffer(f.read(wk_sz * 4), dtype=np.float32).reshape(KV_DIM, DIM).copy()
|
||||||
W[f'Wv{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
|
W[f'Wv{L}'] = np.frombuffer(f.read(wv_sz * 4), dtype=np.float32).reshape(KV_DIM, DIM).copy()
|
||||||
W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
|
W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, Q_DIM).copy()
|
||||||
W[f'W1_{L}'] = np.frombuffer(f.read(w1_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
|
W[f'W1_{L}'] = np.frombuffer(f.read(w1_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
|
||||||
W[f'W2_{L}'] = np.frombuffer(f.read(w2_sz * 4), dtype=np.float32).reshape(DIM, HIDDEN).copy()
|
W[f'W2_{L}'] = np.frombuffer(f.read(w2_sz * 4), dtype=np.float32).reshape(DIM, HIDDEN).copy()
|
||||||
W[f'W3_{L}'] = np.frombuffer(f.read(w3_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
|
W[f'W3_{L}'] = np.frombuffer(f.read(w3_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
|
||||||
W[f'rms1_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
|
W[f'rms1_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
|
||||||
W[f'rms2_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
|
W[f'rms2_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
|
||||||
# Skip adam state for this layer
|
|
||||||
f.seek(adam_per_layer * 4, 1)
|
f.seek(adam_per_layer * 4, 1)
|
||||||
W['rms_final'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
|
W['rms_final'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
|
||||||
f.seek(DIM * 2 * 4, 1) # skip rms_final adam
|
f.seek(DIM * 2 * 4, 1)
|
||||||
W['embed'] = np.frombuffer(f.read(VOCAB * DIM * 4), dtype=np.float32).reshape(VOCAB, DIM).copy()
|
W['embed'] = np.frombuffer(f.read(VOCAB * DIM * 4), dtype=np.float32).reshape(VOCAB, DIM).copy()
|
||||||
return W
|
return W
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -151,20 +187,21 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
||||||
tokenizer = get_tokenizer()
|
tokenizer = get_tokenizer()
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
return '[no tokenizer]'
|
return '[no tokenizer]'
|
||||||
|
if len(tokenizer.vocab) < VOCAB:
|
||||||
|
return f'[tokenizer has {len(tokenizer.vocab)} tokens, model needs {VOCAB}]'
|
||||||
|
|
||||||
tokens = [1]
|
tokens = [1]
|
||||||
text_parts = []
|
text_parts = []
|
||||||
|
|
||||||
# Precompute RoPE frequencies
|
|
||||||
freqs = np.zeros((SEQ, HD // 2), dtype=np.float32)
|
freqs = np.zeros((SEQ, HD // 2), dtype=np.float32)
|
||||||
for pos in range(SEQ):
|
for pos in range(SEQ):
|
||||||
for i in range(HD // 2):
|
for i in range(HD // 2):
|
||||||
freq = 1.0 / (10000.0 ** (2.0 * i / HD))
|
freq = 1.0 / (10000.0 ** (2.0 * i / HD))
|
||||||
freqs[pos, i] = pos * freq
|
freqs[pos, i] = pos * freq
|
||||||
|
|
||||||
# KV cache: per-layer, per-head arrays
|
# KV cache: per-layer, per KV head
|
||||||
k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)]
|
k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(KV_HEADS)] for _ in range(NLAYERS)]
|
||||||
v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)]
|
v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(KV_HEADS)] for _ in range(NLAYERS)]
|
||||||
|
|
||||||
res_alpha = 1.0 / math.sqrt(2.0 * NLAYERS)
|
res_alpha = 1.0 / math.sqrt(2.0 * NLAYERS)
|
||||||
|
|
||||||
|
|
@ -177,13 +214,12 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
||||||
pos = seq_len - 1
|
pos = seq_len - 1
|
||||||
|
|
||||||
for L in range(NLAYERS):
|
for L in range(NLAYERS):
|
||||||
# RMSNorm + QKV
|
|
||||||
xn = rmsnorm(x, W[f'rms1_{L}'])
|
xn = rmsnorm(x, W[f'rms1_{L}'])
|
||||||
q = W[f'Wq{L}'] @ xn
|
q = W[f'Wq{L}'] @ xn # [Q_DIM]
|
||||||
k = W[f'Wk{L}'] @ xn
|
k = W[f'Wk{L}'] @ xn # [KV_DIM]
|
||||||
v = W[f'Wv{L}'] @ xn
|
v = W[f'Wv{L}'] @ xn # [KV_DIM]
|
||||||
|
|
||||||
# RoPE
|
# RoPE on Q (HEADS heads) and K (KV_HEADS heads)
|
||||||
for h in range(HEADS):
|
for h in range(HEADS):
|
||||||
for i in range(HD // 2):
|
for i in range(HD // 2):
|
||||||
freq = freqs[pos, i]
|
freq = freqs[pos, i]
|
||||||
|
|
@ -191,31 +227,37 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
||||||
qi, qi1 = q[h * HD + 2 * i], q[h * HD + 2 * i + 1]
|
qi, qi1 = q[h * HD + 2 * i], q[h * HD + 2 * i + 1]
|
||||||
q[h * HD + 2 * i] = qi * cos_v - qi1 * sin_v
|
q[h * HD + 2 * i] = qi * cos_v - qi1 * sin_v
|
||||||
q[h * HD + 2 * i + 1] = qi * sin_v + qi1 * cos_v
|
q[h * HD + 2 * i + 1] = qi * sin_v + qi1 * cos_v
|
||||||
|
for h in range(KV_HEADS):
|
||||||
|
for i in range(HD // 2):
|
||||||
|
freq = freqs[pos, i]
|
||||||
|
cos_v, sin_v = math.cos(freq), math.sin(freq)
|
||||||
ki, ki1 = k[h * HD + 2 * i], k[h * HD + 2 * i + 1]
|
ki, ki1 = k[h * HD + 2 * i], k[h * HD + 2 * i + 1]
|
||||||
k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v
|
k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v
|
||||||
k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v
|
k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v
|
||||||
|
|
||||||
# Append to KV cache and compute attention
|
# Append to KV cache (KV_HEADS entries)
|
||||||
o = np.zeros(DIM, dtype=np.float32)
|
for kv in range(KV_HEADS):
|
||||||
for h in range(HEADS):
|
kh = k[kv * HD:(kv + 1) * HD].reshape(1, HD)
|
||||||
qh = q[h * HD:(h + 1) * HD]
|
vh = v[kv * HD:(kv + 1) * HD].reshape(1, HD)
|
||||||
kh = k[h * HD:(h + 1) * HD].reshape(1, HD)
|
k_cache[L][kv] = np.vstack([k_cache[L][kv], kh])
|
||||||
vh = v[h * HD:(h + 1) * HD].reshape(1, HD)
|
v_cache[L][kv] = np.vstack([v_cache[L][kv], vh])
|
||||||
k_cache[L][h] = np.vstack([k_cache[L][h], kh])
|
|
||||||
v_cache[L][h] = np.vstack([v_cache[L][h], vh])
|
|
||||||
# scores: (1, HD) @ (HD, seq_len) -> (seq_len,)
|
|
||||||
scores = k_cache[L][h] @ qh / math.sqrt(HD)
|
|
||||||
attn = softmax(scores)
|
|
||||||
o[h * HD:(h + 1) * HD] = attn @ v_cache[L][h]
|
|
||||||
|
|
||||||
# Residual + output projection (scaled residual, matches training)
|
# GQA attention: each Q head uses its corresponding KV head
|
||||||
|
o = np.zeros(Q_DIM, dtype=np.float32)
|
||||||
|
for h in range(HEADS):
|
||||||
|
kv = h // GQA_RATIO
|
||||||
|
qh = q[h * HD:(h + 1) * HD]
|
||||||
|
scores = k_cache[L][kv] @ qh / math.sqrt(HD)
|
||||||
|
attn = softmax(scores)
|
||||||
|
o[h * HD:(h + 1) * HD] = attn @ v_cache[L][kv]
|
||||||
|
|
||||||
|
# Residual + output projection
|
||||||
x2 = x + res_alpha * (W[f'Wo{L}'] @ o)
|
x2 = x + res_alpha * (W[f'Wo{L}'] @ o)
|
||||||
|
|
||||||
# FFN
|
# FFN
|
||||||
x2n = rmsnorm(x2, W[f'rms2_{L}'])
|
x2n = rmsnorm(x2, W[f'rms2_{L}'])
|
||||||
h1 = W[f'W1_{L}'] @ x2n
|
h1 = W[f'W1_{L}'] @ x2n
|
||||||
h3 = W[f'W3_{L}'] @ x2n
|
h3 = W[f'W3_{L}'] @ x2n
|
||||||
# SiLU
|
|
||||||
h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3
|
h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3
|
||||||
ffn_out = W[f'W2_{L}'] @ h1
|
ffn_out = W[f'W2_{L}'] @ h1
|
||||||
|
|
||||||
|
|
@ -230,8 +272,11 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
||||||
next_tok = int(np.argmax(logits))
|
next_tok = int(np.argmax(logits))
|
||||||
else:
|
else:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
probs = softmax(logits)
|
top_k = 50
|
||||||
next_tok = int(np.random.choice(VOCAB, p=probs))
|
top_idx = np.argpartition(logits, -top_k)[-top_k:]
|
||||||
|
top_logits = logits[top_idx]
|
||||||
|
probs = softmax(top_logits)
|
||||||
|
next_tok = int(top_idx[np.random.choice(len(top_idx), p=probs)])
|
||||||
|
|
||||||
if next_tok == 2:
|
if next_tok == 2:
|
||||||
break
|
break
|
||||||
|
|
@ -291,6 +336,8 @@ def sysmetrics_thread():
|
||||||
|
|
||||||
|
|
||||||
RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)')
|
RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)')
|
||||||
|
RE_CONFIG_GQA = re.compile(r'dim=(\d+) q_dim=(\d+) kv_dim=(\d+) hd=(\d+) hidden=(\d+) seq=(\d+) vocab=(\d+)')
|
||||||
|
RE_MODEL_NAME = re.compile(r'ANE Dynamic Training: (.+?) \((\d+) layers')
|
||||||
RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)')
|
RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)')
|
||||||
RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing')
|
RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing')
|
||||||
RE_KERNELS_DYN = re.compile(r'Kernels: (\d+) compiled, (\d+) weight-bearing')
|
RE_KERNELS_DYN = re.compile(r'Kernels: (\d+) compiled, (\d+) weight-bearing')
|
||||||
|
|
@ -307,10 +354,67 @@ RE_ANE_TFLOPS = re.compile(r'ANE TFLOPS:\s+([\d.]+)')
|
||||||
RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%')
|
RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%')
|
||||||
RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Compile|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)')
|
RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Compile|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)')
|
||||||
RE_COMPILED = re.compile(r'Compiled (\d+) kernels in (\d+)ms')
|
RE_COMPILED = re.compile(r'Compiled (\d+) kernels in (\d+)ms')
|
||||||
|
RE_CKPT_SAVED = re.compile(r'\[ckpt saved, best_loss=([\d.]+)\]')
|
||||||
RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW')
|
RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW')
|
||||||
RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW')
|
RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW')
|
||||||
RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW')
|
RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW')
|
||||||
|
|
||||||
|
USE_WANDB = False
|
||||||
|
|
||||||
|
def wandb_log_step():
|
||||||
|
"""Log current state to wandb. Called after each step update."""
|
||||||
|
if not USE_WANDB:
|
||||||
|
return
|
||||||
|
d = {'step': S.step, 'loss': S.loss, 'best_loss': S.best_loss}
|
||||||
|
if S.ms_per_step > 0:
|
||||||
|
d['ms_per_step'] = S.ms_per_step
|
||||||
|
lr = S.training.get('lr')
|
||||||
|
if lr:
|
||||||
|
try:
|
||||||
|
d['lr'] = float(lr)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
ct = S.component_timing
|
||||||
|
if ct:
|
||||||
|
for k, v in ct.items():
|
||||||
|
if k != '_dynamic':
|
||||||
|
d[f'timing/{k}'] = v
|
||||||
|
fl = S.flops
|
||||||
|
if fl.get('ane_tflops'):
|
||||||
|
d['perf/ane_tflops'] = fl['ane_tflops']
|
||||||
|
if fl.get('ane_util'):
|
||||||
|
d['perf/ane_util_pct'] = fl['ane_util']
|
||||||
|
pw = S.power
|
||||||
|
if pw['ane'] > 0:
|
||||||
|
d['power/ane_w'] = pw['ane']
|
||||||
|
if pw['cpu'] > 0:
|
||||||
|
d['power/cpu_w'] = pw['cpu']
|
||||||
|
wandb.log(d, step=S.step)
|
||||||
|
|
||||||
|
def _sync_globals_from_parsed(cfg):
|
||||||
|
"""Sync dashboard globals from parsed binary output so text gen uses correct dims."""
|
||||||
|
global DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS
|
||||||
|
global Q_DIM, KV_DIM, GQA_RATIO
|
||||||
|
if 'dim' in cfg:
|
||||||
|
DIM = cfg['dim']
|
||||||
|
if 'hidden' in cfg:
|
||||||
|
HIDDEN = cfg['hidden']
|
||||||
|
if 'heads' in cfg:
|
||||||
|
HEADS = cfg['heads']
|
||||||
|
if 'kv_heads' in cfg:
|
||||||
|
KV_HEADS = cfg['kv_heads']
|
||||||
|
if 'hd' in cfg:
|
||||||
|
HD = cfg['hd']
|
||||||
|
if 'seq' in cfg:
|
||||||
|
SEQ = cfg['seq']
|
||||||
|
if 'vocab' in cfg:
|
||||||
|
VOCAB = cfg['vocab']
|
||||||
|
if 'layers' in cfg:
|
||||||
|
NLAYERS = cfg['layers']
|
||||||
|
Q_DIM = HEADS * HD
|
||||||
|
KV_DIM = KV_HEADS * HD
|
||||||
|
GQA_RATIO = HEADS // KV_HEADS if KV_HEADS else 1
|
||||||
|
|
||||||
def parse_line(line):
|
def parse_line(line):
|
||||||
S.logs.append(line)
|
S.logs.append(line)
|
||||||
# Parse JSON lines from static pipeline ({"type":"step",...} or {"type":"batch",...})
|
# Parse JSON lines from static pipeline ({"type":"step",...} or {"type":"batch",...})
|
||||||
|
|
@ -339,6 +443,7 @@ def parse_line(line):
|
||||||
ct[k[2:]] = j[k] # strip 't_' prefix
|
ct[k[2:]] = j[k] # strip 't_' prefix
|
||||||
if ct:
|
if ct:
|
||||||
S.component_timing = ct
|
S.component_timing = ct
|
||||||
|
wandb_log_step()
|
||||||
return
|
return
|
||||||
elif jt == 'batch':
|
elif jt == 'batch':
|
||||||
S.batch_num = j.get('batch', S.batch_num)
|
S.batch_num = j.get('batch', S.batch_num)
|
||||||
|
|
@ -356,9 +461,21 @@ def parse_line(line):
|
||||||
return
|
return
|
||||||
except (json.JSONDecodeError, KeyError):
|
except (json.JSONDecodeError, KeyError):
|
||||||
pass
|
pass
|
||||||
|
m = RE_MODEL_NAME.search(line)
|
||||||
|
if m:
|
||||||
|
S.model_config['name'] = m[1]
|
||||||
|
S.model_config['layers'] = int(m[2])
|
||||||
|
m = RE_CONFIG_GQA.search(line)
|
||||||
|
if m:
|
||||||
|
d, qd, kvd, hd, hid, seq, voc = map(int, m.groups())
|
||||||
|
S.model_config.update(dim=d, q_dim=qd, kv_dim=kvd, hd=hd, hidden=hid, seq=seq, vocab=voc,
|
||||||
|
heads=qd//hd, kv_heads=kvd//hd)
|
||||||
|
_sync_globals_from_parsed(S.model_config)
|
||||||
|
return
|
||||||
m = RE_CONFIG.search(line)
|
m = RE_CONFIG.search(line)
|
||||||
if m:
|
if m:
|
||||||
S.model_config = dict(zip(['dim', 'hidden', 'heads', 'seq', 'vocab', 'layers'], map(int, m.groups())))
|
S.model_config = dict(zip(['dim', 'hidden', 'heads', 'seq', 'vocab', 'layers'], map(int, m.groups())))
|
||||||
|
_sync_globals_from_parsed(S.model_config)
|
||||||
return
|
return
|
||||||
m = RE_PARAMS.search(line)
|
m = RE_PARAMS.search(line)
|
||||||
if m:
|
if m:
|
||||||
|
|
@ -398,6 +515,7 @@ def parse_line(line):
|
||||||
S.ms_per_step = dt * 1000
|
S.ms_per_step = dt * 1000
|
||||||
S.loss_history.append((S.step, S.loss))
|
S.loss_history.append((S.step, S.loss))
|
||||||
S.best_loss = min(S.best_loss, S.loss)
|
S.best_loss = min(S.best_loss, S.loss)
|
||||||
|
wandb_log_step()
|
||||||
return
|
return
|
||||||
m = RE_BATCH.search(line)
|
m = RE_BATCH.search(line)
|
||||||
if m:
|
if m:
|
||||||
|
|
@ -434,6 +552,11 @@ def parse_line(line):
|
||||||
S.compiles = int(m[1])
|
S.compiles = int(m[1])
|
||||||
S.compile_ms += float(m[2])
|
S.compile_ms += float(m[2])
|
||||||
return
|
return
|
||||||
|
m = RE_CKPT_SAVED.search(line)
|
||||||
|
if m:
|
||||||
|
if USE_WANDB:
|
||||||
|
wandb.log({'checkpoint/best_loss': float(m[1]), 'checkpoint/saved': True}, step=S.step)
|
||||||
|
return
|
||||||
m = RE_EFFICIENCY.search(line)
|
m = RE_EFFICIENCY.search(line)
|
||||||
if m:
|
if m:
|
||||||
S.efficiency[m[1].strip()] = m[2].strip()
|
S.efficiency[m[1].strip()] = m[2].strip()
|
||||||
|
|
@ -553,14 +676,17 @@ def draw(term):
|
||||||
|
|
||||||
row = 0
|
row = 0
|
||||||
|
|
||||||
# Model Config header
|
# Model Config header — use parsed name from binary if available, else CLI arg
|
||||||
hdr = '\u2500 Model Config '
|
model_label = S.model_config.get('name', S.active_model)
|
||||||
put(row, 0, '\u250c' + hdr + '\u2500' * max(0, w - len(hdr) - 2) + '\u2510', term.cyan)
|
keys_hint = '[r]estart [g]en [q]uit'
|
||||||
|
hdr_text = f'\u2500 {model_label} \u2500\u2500 {keys_hint} '
|
||||||
|
put(row, 0, '\u250c' + hdr_text + '\u2500' * max(0, w - len(hdr_text) - 2) + '\u2510', term.cyan)
|
||||||
row += 1
|
row += 1
|
||||||
|
|
||||||
cfg = S.model_config
|
cfg = S.model_config
|
||||||
if cfg:
|
if cfg:
|
||||||
line1 = f"stories110M dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}"
|
gqa_str = f" kv_heads={cfg.get('kv_heads', '')}" if cfg.get('kv_heads', cfg.get('heads', 0)) != cfg.get('heads', 0) else ''
|
||||||
|
line1 = f"dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')}{gqa_str} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}"
|
||||||
put(row, 0, '\u2502', term.cyan)
|
put(row, 0, '\u2502', term.cyan)
|
||||||
put(row, 2, line1)
|
put(row, 2, line1)
|
||||||
put(row, w - 1, '\u2502', term.cyan)
|
put(row, w - 1, '\u2502', term.cyan)
|
||||||
|
|
@ -778,9 +904,10 @@ def set_nonblock(fd):
|
||||||
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
|
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
|
||||||
|
|
||||||
def spawn_training(resume=False, steps=10000, dynamic=False, ane=False, scratch=False,
|
def spawn_training(resume=False, steps=10000, dynamic=False, ane=False, scratch=False,
|
||||||
lr=None, accum=None, no_ane_extras=False, data=None):
|
lr=None, accum=None, no_ane_extras=False, data=None, model=None):
|
||||||
if dynamic:
|
if dynamic:
|
||||||
cmd = 'cd training_dynamic && make 2>&1 && ./train'
|
model_arg = f' MODEL={model}' if model else ''
|
||||||
|
cmd = f'cd training_dynamic && make{model_arg} 2>&1 && ./train'
|
||||||
elif ane:
|
elif ane:
|
||||||
cmd = 'make train_large_ane 2>&1 && ./train_large_ane'
|
cmd = 'make train_large_ane 2>&1 && ./train_large_ane'
|
||||||
else:
|
else:
|
||||||
|
|
@ -818,9 +945,12 @@ def spawn_powermetrics():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)')
|
parser = argparse.ArgumentParser(description='ANE Training Dashboard')
|
||||||
parser.add_argument('--resume', action='store_true', help='Resume from checkpoint')
|
parser.add_argument('--resume', action='store_true', help='Resume from checkpoint')
|
||||||
parser.add_argument('--dynamic', action='store_true', help='Dynamic weight pipeline (training_dynamic/)')
|
parser.add_argument('--dynamic', action='store_true', help='Dynamic weight pipeline (training_dynamic/)')
|
||||||
|
parser.add_argument('--model', type=str, default=None,
|
||||||
|
choices=list(MODEL_CONFIGS.keys()),
|
||||||
|
help='Model config (default: stories110m for static, qwen3_06b for dynamic)')
|
||||||
parser.add_argument('--ane', action='store_true', help='PR#19: ANE-offloaded classifier/softmax/rmsnorm_bwd')
|
parser.add_argument('--ane', action='store_true', help='PR#19: ANE-offloaded classifier/softmax/rmsnorm_bwd')
|
||||||
parser.add_argument('--no-ane-extras', action='store_true', help='Disable ANE extras (use with --ane)')
|
parser.add_argument('--no-ane-extras', action='store_true', help='Disable ANE extras (use with --ane)')
|
||||||
parser.add_argument('--scratch', action='store_true', help='Train from scratch (random init)')
|
parser.add_argument('--scratch', action='store_true', help='Train from scratch (random init)')
|
||||||
|
|
@ -831,14 +961,53 @@ def main():
|
||||||
parser.add_argument('--no-generate', action='store_true', help='Disable text generation')
|
parser.add_argument('--no-generate', action='store_true', help='Disable text generation')
|
||||||
parser.add_argument('--steps', type=int, default=10000, help='Total steps (default: 10000)')
|
parser.add_argument('--steps', type=int, default=10000, help='Total steps (default: 10000)')
|
||||||
parser.add_argument('--data', type=str, default=None, help='Path to training data shard (.bin)')
|
parser.add_argument('--data', type=str, default=None, help='Path to training data shard (.bin)')
|
||||||
|
parser.add_argument('--wandb', action='store_true', help='Log to Weights & Biases')
|
||||||
|
parser.add_argument('--wandb-project', type=str, default='ane-training', help='W&B project name')
|
||||||
|
parser.add_argument('--wandb-name', type=str, default=None, help='W&B run name')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.infinite:
|
if args.infinite:
|
||||||
args.steps = 999999999
|
args.steps = 999999999
|
||||||
S.total_steps = args.steps
|
S.total_steps = args.steps
|
||||||
|
|
||||||
global CKPT_PATH
|
# Select model
|
||||||
CKPT_PATH = CKPT_PATH_DYNAMIC if args.dynamic else CKPT_PATH_STATIC
|
if args.model is None:
|
||||||
|
args.model = 'qwen3_06b' if args.dynamic else 'stories110m'
|
||||||
|
cfg = MODEL_CONFIGS[args.model]
|
||||||
|
# Auto-enable dynamic for models without a static pipeline
|
||||||
|
if cfg['ckpt_static'] is None:
|
||||||
|
args.dynamic = True
|
||||||
|
set_model_config(args.model)
|
||||||
|
S.active_model = args.model
|
||||||
|
# For dynamic: default to --scratch when --resume not given
|
||||||
|
if args.dynamic and not args.resume:
|
||||||
|
args.scratch = True
|
||||||
|
|
||||||
|
global CKPT_PATH, USE_WANDB
|
||||||
|
CKPT_PATH = cfg['ckpt_dynamic'] if args.dynamic else cfg['ckpt_static']
|
||||||
|
|
||||||
|
# Weights & Biases
|
||||||
|
if args.wandb:
|
||||||
|
if not HAS_WANDB:
|
||||||
|
print('pip install wandb')
|
||||||
|
sys.exit(1)
|
||||||
|
run_name = args.wandb_name or f'{args.model}-{"resume" if args.resume else "scratch"}'
|
||||||
|
wandb.init(
|
||||||
|
project=args.wandb_project,
|
||||||
|
name=run_name,
|
||||||
|
config={
|
||||||
|
'model': args.model,
|
||||||
|
'dim': DIM, 'hidden': HIDDEN, 'heads': HEADS,
|
||||||
|
'kv_heads': KV_HEADS, 'hd': HD, 'seq': SEQ,
|
||||||
|
'vocab': VOCAB, 'nlayers': NLAYERS,
|
||||||
|
'q_dim': Q_DIM, 'kv_dim': KV_DIM,
|
||||||
|
'pipeline': 'dynamic' if args.dynamic else 'static',
|
||||||
|
'resume': args.resume,
|
||||||
|
'lr': args.lr, 'accum': args.accum,
|
||||||
|
'steps': args.steps,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
USE_WANDB = True
|
||||||
|
|
||||||
term = Terminal()
|
term = Terminal()
|
||||||
procs = []
|
procs = []
|
||||||
|
|
@ -846,7 +1015,7 @@ def main():
|
||||||
train_proc = spawn_training(resume=args.resume, steps=args.steps, dynamic=args.dynamic,
|
train_proc = spawn_training(resume=args.resume, steps=args.steps, dynamic=args.dynamic,
|
||||||
scratch=args.scratch, lr=args.lr, accum=args.accum,
|
scratch=args.scratch, lr=args.lr, accum=args.accum,
|
||||||
ane=args.ane, no_ane_extras=args.no_ane_extras,
|
ane=args.ane, no_ane_extras=args.no_ane_extras,
|
||||||
data=args.data)
|
data=args.data, model=args.model)
|
||||||
S.train_pid = train_proc.pid
|
S.train_pid = train_proc.pid
|
||||||
procs.append(train_proc)
|
procs.append(train_proc)
|
||||||
|
|
||||||
|
|
@ -874,6 +1043,8 @@ def main():
|
||||||
p.terminate()
|
p.terminate()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
if USE_WANDB:
|
||||||
|
wandb.finish()
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, lambda *a: cleanup())
|
signal.signal(signal.SIGINT, lambda *a: cleanup())
|
||||||
signal.signal(signal.SIGTERM, lambda *a: cleanup())
|
signal.signal(signal.SIGTERM, lambda *a: cleanup())
|
||||||
|
|
@ -989,11 +1160,11 @@ def main():
|
||||||
train_proc = spawn_training(resume=True, steps=args.steps, dynamic=args.dynamic,
|
train_proc = spawn_training(resume=True, steps=args.steps, dynamic=args.dynamic,
|
||||||
lr=args.lr, accum=args.accum,
|
lr=args.lr, accum=args.accum,
|
||||||
ane=args.ane, no_ane_extras=args.no_ane_extras,
|
ane=args.ane, no_ane_extras=args.no_ane_extras,
|
||||||
data=args.data)
|
data=args.data, model=S.active_model)
|
||||||
S.train_pid = train_proc.pid
|
S.train_pid = train_proc.pid
|
||||||
procs = [p for p in procs if p.poll() is None]
|
procs = [p for p in procs if p.poll() is None]
|
||||||
procs.append(train_proc)
|
procs.append(train_proc)
|
||||||
S.logs.append('[dashboard] Restarted with --resume')
|
S.logs.append(f'[dashboard] Restarted {S.active_model} with --resume')
|
||||||
need_draw = True
|
need_draw = True
|
||||||
elif key == 'g':
|
elif key == 'g':
|
||||||
with S.gen_lock:
|
with S.gen_lock:
|
||||||
|
|
|
||||||
|
|
@ -384,6 +384,7 @@ int main(int argc, char *argv[]) {
|
||||||
dispatch_group_t dw_grp = dispatch_group_create();
|
dispatch_group_t dw_grp = dispatch_group_create();
|
||||||
|
|
||||||
float last_loss = 999.0f;
|
float last_loss = 999.0f;
|
||||||
|
float best_loss = resume_loss > 0 ? resume_loss : 999.0f;
|
||||||
double total_train_ms = 0;
|
double total_train_ms = 0;
|
||||||
int total_steps_done = 0;
|
int total_steps_done = 0;
|
||||||
uint64_t t_wall_start = mach_absolute_time();
|
uint64_t t_wall_start = mach_absolute_time();
|
||||||
|
|
@ -875,12 +876,14 @@ int main(int argc, char *argv[]) {
|
||||||
memset(gembed, 0, (size_t)VOCAB*DIM*4);
|
memset(gembed, 0, (size_t)VOCAB*DIM*4);
|
||||||
memset(gcembed, 0, (size_t)CV*DIM*4);
|
memset(gcembed, 0, (size_t)CV*DIM*4);
|
||||||
|
|
||||||
// Checkpoint
|
// Checkpoint — only save on best loss
|
||||||
if ((step+1) % 100 == 0) {
|
if ((step+1) % 100 == 0 && last_loss < best_loss) {
|
||||||
|
best_loss = last_loss;
|
||||||
double wall = tb_ms(mach_absolute_time() - t_wall_start);
|
double wall = tb_ms(mach_absolute_time() - t_wall_start);
|
||||||
save_checkpoint(CKPT_PATH, step+1, total_steps, lr, last_loss,
|
save_checkpoint(CKPT_PATH, step+1, total_steps, lr, last_loss,
|
||||||
total_train_ms+cum_train, wall+cum_wall, total_steps_done+cum_steps, adam_t,
|
total_train_ms+cum_train, wall+cum_wall, total_steps_done+cum_steps, adam_t,
|
||||||
lw, la, rms_final, &arms_final, embed, &aembed);
|
lw, la, rms_final, &arms_final, embed, &aembed);
|
||||||
|
printf(" [ckpt saved, best_loss=%.4f]\n", best_loss);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue