ANE/training/dashboard.py

902 lines
34 KiB
Python

"""TUI dashboard for ANE training (train_large). Uses blessed for terminal UI."""
import argparse, fcntl, math, os, re, select, signal, struct, subprocess, sys, time, threading
from collections import deque
from pathlib import Path
import numpy as np
try:
from blessed import Terminal
except ImportError:
print('pip install blessed')
sys.exit(1)
try:
import psutil
HAS_PSUTIL = True
except ImportError:
HAS_PSUTIL = False
DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 256, 32000, 12
HD = DIM // HEADS
CKPT_PATH = 'ane_stories110M_ckpt.bin'
TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent.parent / 'assets' / 'models' / 'tokenizer.bin')
class State:
def __init__(self):
self.model_config = {}
self.params = {}
self.kernels = {}
self.training = {}
self.flops = {}
self.step = 0
self.total_steps = 0
self.loss = 0.0
self.best_loss = float('inf')
self.loss_history = []
self.ms_per_step = 0.0
self.compile_pct = 0.0
self.compiles = 0
self.component_timing = {}
self.power = {'ane': 0.0, 'cpu': 0.0, 'gpu': 0.0}
self.power_history_ane = deque(maxlen=300)
self.power_history_cpu = deque(maxlen=300)
self.logs = deque(maxlen=2000)
self.log_scroll = 0
self.auto_scroll = True
self.batch_num = 0
self.efficiency = {}
self.gen_text = ''
self.gen_step = 0
self.gen_status = 'idle'
self.gen_lock = threading.Lock()
self.cpu_pct_history = deque(maxlen=300)
self.mem_mb_history = deque(maxlen=300)
self.proc_mem_mb_history = deque(maxlen=300)
self.train_pid = None
S = State()
class Tokenizer:
def __init__(self, path):
self.vocab = []
self.scores = []
with open(path, 'rb') as f:
max_len = struct.unpack('i', f.read(4))[0]
for _ in range(VOCAB):
score = struct.unpack('f', f.read(4))[0]
slen = struct.unpack('i', f.read(4))[0]
tok = f.read(slen).decode('utf-8', errors='replace')
self.vocab.append(tok)
self.scores.append(score)
def decode(self, token_id):
if 0 <= token_id < len(self.vocab):
s = self.vocab[token_id]
if s.startswith('<0x') and s.endswith('>'):
try:
return chr(int(s[3:-1], 16))
except:
return s
return s
return ''
_tokenizer = None
def get_tokenizer():
global _tokenizer
if _tokenizer is None:
try:
_tokenizer = Tokenizer(TOKENIZER_PATH)
except Exception as e:
S.logs.append(f'[gen] tokenizer load failed: {e}')
return None
return _tokenizer
def load_weights_from_ckpt(path):
try:
with open(path, 'rb') as f:
# CkptHdr: 96 bytes (verified with sizeof)
hdr = f.read(96)
if len(hdr) < 96:
return None
wq_sz = DIM * DIM
wo_sz = DIM * DIM
w1_sz = HIDDEN * DIM
w2_sz = DIM * HIDDEN
w3_sz = HIDDEN * DIM
# Per-layer: weights + adam state (m,v for each)
adam_per_layer = (wq_sz*2 + wq_sz*2 + wq_sz*2 + wo_sz*2 +
w1_sz*2 + w2_sz*2 + w3_sz*2 + DIM*2 + DIM*2)
W = {}
for L in range(NLAYERS):
W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'Wk{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'Wv{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'W1_{L}'] = np.frombuffer(f.read(w1_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
W[f'W2_{L}'] = np.frombuffer(f.read(w2_sz * 4), dtype=np.float32).reshape(DIM, HIDDEN).copy()
W[f'W3_{L}'] = np.frombuffer(f.read(w3_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
W[f'rms1_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
W[f'rms2_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
# Skip adam state for this layer
f.seek(adam_per_layer * 4, 1)
W['rms_final'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
f.seek(DIM * 2 * 4, 1) # skip rms_final adam
W['embed'] = np.frombuffer(f.read(VOCAB * DIM * 4), dtype=np.float32).reshape(VOCAB, DIM).copy()
return W
except Exception as e:
S.logs.append(f'[gen] ckpt load failed: {e}')
return None
def rmsnorm(x, w):
ss = np.mean(x * x) + 1e-5
return x * (1.0 / math.sqrt(ss)) * w
def softmax(x):
x = x - np.max(x)
e = np.exp(x)
return e / np.sum(e)
def generate_text(W, tok, max_tokens=64, temperature=0.8):
tokenizer = get_tokenizer()
if tokenizer is None:
return '[no tokenizer]'
tokens = [1]
text_parts = []
# Precompute RoPE frequencies
freqs = np.zeros((SEQ, HD // 2), dtype=np.float32)
for pos in range(SEQ):
for i in range(HD // 2):
freq = 1.0 / (10000.0 ** (2.0 * i / HD))
freqs[pos, i] = pos * freq
for step in range(max_tokens):
seq_len = len(tokens)
if seq_len > SEQ:
break
x = W['embed'][tokens[-1]].copy()
for L in range(NLAYERS):
# RMSNorm + QKV
xn = rmsnorm(x, W[f'rms1_{L}'])
q = W[f'Wq{L}'] @ xn
k = W[f'Wk{L}'] @ xn
v = W[f'Wv{L}'] @ xn
# RoPE
pos = seq_len - 1
for h in range(HEADS):
for i in range(HD // 2):
freq = freqs[pos, i]
cos_v, sin_v = math.cos(freq), math.sin(freq)
qi, qi1 = q[h * HD + 2 * i], q[h * HD + 2 * i + 1]
q[h * HD + 2 * i] = qi * cos_v - qi1 * sin_v
q[h * HD + 2 * i + 1] = qi * sin_v + qi1 * cos_v
ki, ki1 = k[h * HD + 2 * i], k[h * HD + 2 * i + 1]
k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v
k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v
# Attention (single token)
o = np.zeros(DIM, dtype=np.float32)
for h in range(HEADS):
qh = q[h * HD:(h + 1) * HD]
kh = k[h * HD:(h + 1) * HD]
vh = v[h * HD:(h + 1) * HD]
score = np.dot(qh, kh) / math.sqrt(HD)
o[h * HD:(h + 1) * HD] = vh
# Residual + output projection
x2 = x + W[f'Wo{L}'] @ o
# FFN
x2n = rmsnorm(x2, W[f'rms2_{L}'])
h1 = W[f'W1_{L}'] @ x2n
h3 = W[f'W3_{L}'] @ x2n
# SiLU
h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3
ffn_out = W[f'W2_{L}'] @ h1
x = x2 + ffn_out
x = rmsnorm(x, W['rms_final'])
# Logits
logits = W['embed'] @ x
if temperature < 0.01:
next_tok = int(np.argmax(logits))
else:
logits = logits / temperature
probs = softmax(logits)
next_tok = int(np.random.choice(VOCAB, p=probs))
if next_tok == 2:
break
tokens.append(next_tok)
piece = tokenizer.decode(next_tok)
text_parts.append(piece)
return ''.join(text_parts)
def generation_thread():
last_gen_step = -1
while True:
time.sleep(5)
if S.step <= last_gen_step + 99:
continue
if not os.path.exists(CKPT_PATH):
continue
with S.gen_lock:
S.gen_status = 'generating'
S.gen_step = S.step
try:
W = load_weights_from_ckpt(CKPT_PATH)
if W is None:
with S.gen_lock:
S.gen_status = 'idle'
continue
text = generate_text(W, get_tokenizer(), max_tokens=64, temperature=0.8)
with S.gen_lock:
S.gen_text = text
S.gen_step = S.step
S.gen_status = 'done'
S.step # just to reference
except Exception as e:
with S.gen_lock:
S.gen_text = f'[error: {e}]'
S.gen_status = 'done'
last_gen_step = S.step
def sysmetrics_thread():
while True:
time.sleep(1)
if not HAS_PSUTIL:
continue
now = time.monotonic()
S.cpu_pct_history.append(psutil.cpu_percent(interval=None))
mem = psutil.virtual_memory()
S.mem_mb_history.append(mem.used / (1024 * 1024))
pid = S.train_pid
if pid:
try:
p = psutil.Process(pid)
S.proc_mem_mb_history.append(p.memory_info().rss / (1024 * 1024))
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)')
RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)')
RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing')
RE_ACCUM = re.compile(r'Accum (\d+).*LR=([\d.e+-]+)')
RE_STEP = re.compile(r'step\s+(\d+)\s+loss=([\d.]+)(?:\s+lr=([\d.e+-]+))?(?:\s+([\d.]+)ms/step)?')
RE_BATCH = re.compile(r'\[batch (\d+): compile=([\d.]+)ms train=([\d.]+)ms \(([\d.]+)ms/step\) compiles=(\d+)\]')
RE_TIMING = re.compile(r'ane=([\d.]+) io=([\d.]+) cls=([\d.]+) elem=([\d.]+) rms=([\d.]+) cblas_wait=([\d.]+)')
RE_RESTART = re.compile(r'\[exec\(\) restart step (\d+)')
RE_RESUME = re.compile(r'\[RESUMED step (\d+), loss=([\d.]+)\]')
RE_FLOPS = re.compile(r'FLOPs/step: fwd=([\d.]+)M bwd_dx=([\d.]+)M bwd_dW=([\d.]+)M sdpa_bwd=([\d.]+)M total=([\d.]+)M')
RE_ANE_FLOPS = re.compile(r'ANE FLOPs/step: ([\d.]+)M')
RE_ANE_TFLOPS = re.compile(r'ANE TFLOPS:\s+([\d.]+)')
RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%')
RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)')
RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW')
RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW')
RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW')
def parse_line(line):
S.logs.append(line)
m = RE_CONFIG.search(line)
if m:
S.model_config = dict(zip(['dim', 'hidden', 'heads', 'seq', 'vocab', 'layers'], map(int, m.groups())))
return
m = RE_PARAMS.search(line)
if m:
S.params = {'total': float(m[1]), 'transformer': float(m[2]), 'embed': float(m[3])}
return
m = RE_KERNELS.search(line)
if m:
S.kernels = {'total': int(m[1]), 'weight_bearing': int(m[2])}
return
m = RE_ACCUM.search(line)
if m:
S.training = {'accum': int(m[1]), 'lr': m[2]}
return
m = RE_FLOPS.search(line)
if m:
S.flops.update(fwd=float(m[1]), bwd_dx=float(m[2]), bwd_dw=float(m[3]),
sdpa_bwd=float(m[4]), total=float(m[5]))
return
m = RE_ANE_FLOPS.search(line)
if m:
S.flops['ane'] = float(m[1])
return
m = RE_STEP.search(line)
if m:
S.step, S.loss = int(m[1]), float(m[2])
if m[3]:
S.training['lr'] = m[3]
if m[4]:
S.ms_per_step = float(m[4])
S.loss_history.append((S.step, S.loss))
S.best_loss = min(S.best_loss, S.loss)
return
m = RE_BATCH.search(line)
if m:
S.batch_num = int(m[1])
compile_ms, train_ms = float(m[2]), float(m[3])
S.ms_per_step = float(m[4])
S.compiles = int(m[5])
S.compile_pct = 100 * compile_ms / (compile_ms + train_ms) if compile_ms + train_ms > 0 else 0
return
m = RE_TIMING.search(line)
if m:
S.component_timing = dict(zip(['ane', 'io', 'cls', 'elem', 'rms', 'cblas_wait'], map(float, m.groups())))
return
m = RE_ANE_TFLOPS.search(line)
if m:
S.flops['ane_tflops'] = float(m[1])
return
m = RE_ANE_UTIL.search(line)
if m:
S.flops['ane_util'] = float(m[1])
return
m = RE_EFFICIENCY.search(line)
if m:
S.efficiency[m[1].strip()] = m[2].strip()
return
def parse_powermetrics_text(text):
now = time.monotonic()
m = RE_ANE_POWER.search(text)
if m:
S.power['ane'] = float(m[1]) / 1000.0
S.power_history_ane.append((now, S.power['ane']))
m = RE_CPU_POWER.search(text)
if m:
S.power['cpu'] = float(m[1]) / 1000.0
S.power_history_cpu.append((now, S.power['cpu']))
m = RE_GPU_POWER.search(text)
if m:
S.power['gpu'] = float(m[1]) / 1000.0
BRAILLE_BASE = 0x2800
BRAILLE_MAP = [
[1, 8],
[2, 16],
[4, 32],
[64, 128],
]
def braille_chart(values, width, height, label_fmt='{:.1f}', y_range=None):
if not values or width < 8 or height < 2:
return ['(no data)'] * max(1, height)
chart_w = width - 6
if chart_w < 2:
return ['(no data)'] * max(1, height)
points_x = chart_w * 2
points_y = height * 4
data = values[-points_x:] if len(values) > points_x else values
lo, hi = min(data), max(data)
if y_range:
lo, hi = y_range
if hi - lo < 0.001:
lo, hi = lo - 0.5, hi + 0.5
margin = (hi - lo) * 0.05
lo -= margin
hi += margin
grid = [[0] * chart_w for _ in range(height)]
def plot(px, py):
px = max(0, min(points_x - 1, px))
py = max(0, min(points_y - 1, py))
grid[py // 4][px // 2] |= BRAILLE_MAP[py % 4][px % 2]
def val_to_y(v):
return int((1 - (v - lo) / (hi - lo)) * (points_y - 1))
for i in range(len(data)):
if i >= points_x:
break
y0 = val_to_y(data[i])
plot(i, y0)
if i > 0:
y_prev = val_to_y(data[i - 1])
y_lo, y_hi = min(y_prev, y0), max(y_prev, y0)
for yy in range(y_lo, y_hi + 1):
if y_hi != y_lo:
t = (yy - y_prev) / (y0 - y_prev)
xx = int(i - 1 + t)
else:
xx = i
plot(xx, yy)
lines = []
for r in range(height):
if r == 0:
label = label_fmt.format(hi)[:5].rjust(5)
elif r == height - 1:
label = label_fmt.format(lo)[:5].rjust(5)
elif r == height // 2:
label = label_fmt.format((hi + lo) / 2)[:5].rjust(5)
else:
label = ' '
row_str = ''.join(chr(BRAILLE_BASE | grid[r][c]) for c in range(chart_w))
lines.append(f'{label}\u2502{row_str}')
lines.append(' \u2514' + '\u2500' * chart_w)
return lines
def draw(term):
w, h = term.width, term.height
if w < 40 or h < 15:
print(term.home + term.clear + 'Terminal too small', end='', flush=True)
return
buf = []
def put(y, x, text, style=''):
if 0 <= y < h and x < w:
text = text[:w - x]
if style:
buf.append(term.move(y, x) + style + text + term.normal)
return
buf.append(term.move(y, x) + text)
buf.append(term.home + term.clear)
mid_x = w // 2
right_w = w - mid_x - 1
left_w = mid_x - 1
row = 0
# Model Config header
hdr = '\u2500 Model Config '
put(row, 0, '\u250c' + hdr + '\u2500' * max(0, w - len(hdr) - 2) + '\u2510', term.cyan)
row += 1
cfg = S.model_config
if cfg:
line1 = f"stories110M dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}"
put(row, 0, '\u2502', term.cyan)
put(row, 2, line1)
put(row, w - 1, '\u2502', term.cyan)
row += 1
p, k, t = S.params, S.kernels, S.training
line2 = f"{p.get('total', '?')}M params ({p.get('transformer', '?')}M xfmr + {p.get('embed', '?')}M embed)"
put(row, 0, '\u2502', term.cyan)
put(row, 2, line2)
put(row, w - 1, '\u2502', term.cyan)
row += 1
line3 = f"{k.get('total', '?')} kernels ({k.get('weight_bearing', '?')} wt-bearing) | Accum {t.get('accum', '?')} | Adam LR={t.get('lr', '?')}"
put(row, 0, '\u2502', term.cyan)
put(row, 2, line3)
put(row, w - 1, '\u2502', term.cyan)
row += 1
else:
put(row, 0, '\u2502', term.cyan)
put(row, 2, 'Waiting for model config...')
put(row, w - 1, '\u2502', term.cyan)
row += 1
remaining = h - row - 1
# Allocate: loss curve ~40%, logs ~30%, power/cpu/mem/gen share rest
power_h = max(3, remaining // 8)
gen_h = max(2, remaining // 10)
extra_panels = power_h + power_h + gen_h + 6 # power + cpu/mem + gen + dividers
log_h_min = max(5, remaining // 5)
curve_h = max(5, remaining - extra_panels - log_h_min)
# Loss Curve + Training Stats divider
put(row, 0, '\u251c\u2500 Loss Curve ' + '\u2500' * max(0, left_w - 13) + '\u252c\u2500 Training Stats ' + '\u2500' * max(0, right_w - 17) + '\u2524', term.cyan)
row += 1
# Loss curve
loss_vals = [l for _, l in S.loss_history]
curve_lines = braille_chart(loss_vals, left_w - 1, curve_h)
for i, cl in enumerate(curve_lines):
put(row + i, 0, '\u2502', term.cyan)
put(row + i, 1, cl, term.green)
put(row + i, mid_x, '\u2502', term.cyan)
put(row + i, w - 1, '\u2502', term.cyan)
# Training stats (right panel)
sr = row
step_str = f'{S.step}' + (f'/{S.total_steps}' if S.total_steps and S.total_steps < 999999 else '')
put(sr, mid_x + 1, f' Step: {step_str} Loss: {S.loss:.4f}' if S.loss else ' Step: --', term.yellow)
sr += 1
put(sr, mid_x + 1, f' Best: {S.best_loss:.4f} ms/step: {S.ms_per_step:.1f}' if S.best_loss < float('inf') else ' Best: --')
sr += 1
ane_tflops = S.flops.get('ane_tflops', 0)
ane_util = S.flops.get('ane_util', 0)
if ane_tflops:
put(sr, mid_x + 1, f' ANE: {ane_tflops:.2f}T Compile: {S.compile_pct:.0f}% Util: {ane_util:.1f}%')
else:
put(sr, mid_x + 1, f' Compile: {S.compile_pct:.0f}%')
sr += 1
ct = S.component_timing
if ct:
put(sr, mid_x + 1, f' ane={ct.get("ane", 0):.1f} io={ct.get("io", 0):.1f} cls={ct.get("cls", 0):.1f} elem={ct.get("elem", 0):.1f}')
sr += 1
put(sr, mid_x + 1, f' rms={ct.get("rms", 0):.1f} cblas_wait={ct.get("cblas_wait", 0):.1f} ms/step')
sr += 1
pw = S.power
if any(pw.values()):
put(sr, mid_x + 1, '\u2500 Power ' + '\u2500' * max(0, right_w - 9), term.cyan)
sr += 1
put(sr, mid_x + 1, f' ANE: {pw["ane"]:.1f}W CPU: {pw["cpu"]:.1f}W GPU: {pw["gpu"]:.1f}W', term.magenta)
sr += 1
if S.batch_num:
put(sr, mid_x + 1, f' Batch: {S.batch_num} Compiles: {S.compiles}')
sr += 1
# Fill vertical borders between loss curve and stats
top_end = row + len(curve_lines)
for r in range(row, max(top_end, sr)):
if r >= top_end:
put(r, 0, '\u2502', term.cyan)
if r >= sr:
put(r, mid_x, '\u2502', term.cyan)
put(r, w - 1, '\u2502', term.cyan)
row = max(top_end, sr)
# Power charts
has_power = len(S.power_history_ane) > 1 or len(S.power_history_cpu) > 1
if has_power:
put(row, 0, '\u251c\u2500 ANE Power (W) ' + '\u2500' * max(0, left_w - 16) + '\u252c\u2500 CPU Power (W) ' + '\u2500' * max(0, right_w - 17) + '\u2524', term.cyan)
row += 1
ane_vals = [v for _, v in S.power_history_ane]
cpu_vals = [v for _, v in S.power_history_cpu]
ane_lines = braille_chart(ane_vals, left_w - 1, power_h, label_fmt='{:.1f}')
cpu_lines = braille_chart(cpu_vals, right_w - 1, power_h, label_fmt='{:.1f}')
max_lines = max(len(ane_lines), len(cpu_lines))
while len(ane_lines) < max_lines:
ane_lines.append(' ' * (left_w - 1))
while len(cpu_lines) < max_lines:
cpu_lines.append(' ' * (right_w - 1))
for i in range(max_lines):
put(row + i, 0, '\u2502', term.cyan)
put(row + i, 1, ane_lines[i], term.red)
put(row + i, mid_x, '\u2502', term.cyan)
put(row + i, mid_x + 1, cpu_lines[i], term.blue)
put(row + i, w - 1, '\u2502', term.cyan)
row += max_lines
# CPU / Memory charts
has_sysmetrics = len(S.cpu_pct_history) > 0
if has_sysmetrics:
put(row, 0, '\u251c\u2500 CPU % ' + '\u2500' * max(0, left_w - 8) + '\u252c\u2500 Memory (MB) ' + '\u2500' * max(0, right_w - 15) + '\u2524', term.cyan)
row += 1
cpu_vals = list(S.cpu_pct_history)
mem_vals = list(S.proc_mem_mb_history) if S.proc_mem_mb_history else list(S.mem_mb_history)
mem_label = 'proc' if S.proc_mem_mb_history else 'sys'
cpu_lines = braille_chart(cpu_vals, left_w - 1, power_h, label_fmt='{:.0f}', y_range=(0, 100))
mem_lines = braille_chart(mem_vals, right_w - 1, power_h, label_fmt='{:.0f}')
max_lines = max(len(cpu_lines), len(mem_lines))
while len(cpu_lines) < max_lines:
cpu_lines.append(' ' * (left_w - 1))
while len(mem_lines) < max_lines:
mem_lines.append(' ' * (right_w - 1))
for i in range(max_lines):
put(row + i, 0, '\u2502', term.cyan)
put(row + i, 1, cpu_lines[i], term.yellow)
put(row + i, mid_x, '\u2502', term.cyan)
put(row + i, mid_x + 1, mem_lines[i], term.magenta)
put(row + i, w - 1, '\u2502', term.cyan)
row += max_lines
# Generated text
with S.gen_lock:
gen_text = S.gen_text
gen_step = S.gen_step
gen_status = S.gen_status
if gen_text or gen_status == 'generating':
status_tag = ' (generating...)' if gen_status == 'generating' else f' (step {gen_step})'
put(row, 0, '\u251c\u2500 Generated Text' + status_tag + ' ' + '\u2500' * max(0, w - 20 - len(status_tag)) + '\u2524', term.cyan)
row += 1
if gen_text:
line_w = w - 3
text = gen_text.replace('\n', ' ')
wrapped = [text[i:i + line_w] for i in range(0, len(text), line_w)]
for i, tl in enumerate(wrapped[:gen_h]):
put(row, 0, '\u2502', term.cyan)
put(row, 2, tl, term.white)
put(row, w - 1, '\u2502', term.cyan)
row += 1
else:
put(row, 0, '\u2502', term.cyan)
put(row, 2, '...')
put(row, w - 1, '\u2502', term.cyan)
row += 1
# Logs
log_h = h - row - 1
scroll_hint = ' (scroll) ' if not S.auto_scroll else ' '
put(row, 0, '\u251c\u2500 Logs' + scroll_hint + '\u2500' * max(0, w - 8 - len(scroll_hint)) + '\u2524', term.cyan)
row += 1
logs = list(S.logs)
if log_h > 0 and logs:
if S.auto_scroll:
start = max(0, len(logs) - log_h)
else:
start = max(0, min(S.log_scroll, len(logs) - log_h))
visible = logs[start:start + log_h]
for i, line in enumerate(visible):
put(row + i, 0, '\u2502', term.cyan)
if RE_STEP.search(line):
put(row + i, 1, line[:w - 2], term.yellow)
elif line.strip().startswith('[batch'):
put(row + i, 1, line[:w - 2], term.blue)
elif 'FAIL' in line or 'error' in line.lower():
put(row + i, 1, line[:w - 2], term.red)
else:
put(row + i, 1, line[:w - 2])
put(row + i, w - 1, '\u2502', term.cyan)
for i in range(len(visible), log_h):
put(row + i, 0, '\u2502', term.cyan)
put(row + i, w - 1, '\u2502', term.cyan)
# Bottom border
put(h - 1, 0, '\u2514' + '\u2500' * (w - 2) + '\u2518', term.cyan)
sys.stdout.write(''.join(buf))
sys.stdout.flush()
def set_nonblock(fd):
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
def spawn_training(resume=False, steps=10000, dynamic=False, scratch=False, lr=None, accum=None):
if dynamic:
cmd = 'cd training_dynamic && make 2>&1 && ./train'
else:
cmd = 'make train_large 2>&1 && ./train_large'
if resume:
cmd += ' --resume'
if scratch and dynamic:
cmd += ' --scratch'
if lr is not None:
cmd += f' --lr {lr}'
if accum is not None:
cmd += f' --accum {accum}'
cmd += f' --steps {steps}'
proc = subprocess.Popen(
['bash', '-c', cmd],
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
cwd=os.path.dirname(os.path.abspath(__file__)) or '.')
set_nonblock(proc.stdout.fileno())
return proc
def spawn_powermetrics():
try:
proc = subprocess.Popen(
['sudo', 'powermetrics', '--samplers', 'cpu_power,gpu_power,ane_power', '-i', '1000'],
stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
set_nonblock(proc.stdout.fileno())
return proc
except (FileNotFoundError, PermissionError):
return None
def main():
parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)')
parser.add_argument('--resume', action='store_true', help='Resume from checkpoint')
parser.add_argument('--dynamic', action='store_true', help='Use v2 dynamic weight pipeline (training_dynamic/)')
parser.add_argument('--scratch', action='store_true', help='Train from scratch (random init)')
parser.add_argument('--lr', type=float, default=None, help='Learning rate')
parser.add_argument('--accum', type=int, default=None, help='Gradient accumulation steps')
parser.add_argument('--infinite', action='store_true', help='Train indefinitely')
parser.add_argument('--no-powermetrics', action='store_true')
parser.add_argument('--no-generate', action='store_true', help='Disable text generation')
parser.add_argument('--steps', type=int, default=10000, help='Total steps (default: 10000)')
args = parser.parse_args()
if args.infinite:
args.steps = 999999999
S.total_steps = args.steps
term = Terminal()
procs = []
train_proc = spawn_training(resume=args.resume, steps=args.steps, dynamic=args.dynamic,
scratch=args.scratch, lr=args.lr, accum=args.accum)
S.train_pid = train_proc.pid
procs.append(train_proc)
if HAS_PSUTIL:
psutil.cpu_percent(interval=None) # prime the counter
sys_t = threading.Thread(target=sysmetrics_thread, daemon=True)
sys_t.start()
pm_proc = None
if not args.no_powermetrics:
pm_proc = spawn_powermetrics()
if pm_proc:
procs.append(pm_proc)
if not args.no_generate:
gen_t = threading.Thread(target=generation_thread, daemon=True)
gen_t.start()
pm_buf = ''
train_buf = ''
def cleanup():
for p in procs:
try:
p.terminate()
except Exception:
pass
signal.signal(signal.SIGINT, lambda *a: cleanup())
signal.signal(signal.SIGTERM, lambda *a: cleanup())
resized = [False]
def on_resize(*a):
resized[0] = True
signal.signal(signal.SIGWINCH, on_resize)
with term.fullscreen(), term.cbreak(), term.hidden_cursor():
draw(term)
last_draw = time.monotonic()
while True:
fds = []
fd_map = {}
if train_proc and train_proc.stdout:
fd = train_proc.stdout.fileno()
fds.append(fd)
fd_map[fd] = 'train'
if pm_proc and pm_proc.stdout:
fd = pm_proc.stdout.fileno()
fds.append(fd)
fd_map[fd] = 'pm'
fds.append(sys.stdin.fileno())
fd_map[sys.stdin.fileno()] = 'stdin'
try:
readable, _, _ = select.select(fds, [], [], 0.25)
except (ValueError, OSError):
continue
need_draw = resized[0]
resized[0] = False
train_finished = False
for fd in readable:
kind = fd_map.get(fd)
if kind == 'train':
try:
data = os.read(fd, 65536)
except BlockingIOError:
continue
except (OSError, ValueError):
data = b''
if not data:
if train_proc.poll() is not None:
try:
rest = train_proc.stdout.read()
if rest:
for line in rest.decode('utf-8', errors='replace').split('\n'):
if line:
parse_line(line)
except Exception:
pass
S.logs.append('[dashboard] Training finished. Press q to exit.')
train_finished = True
continue
train_buf += data.decode('utf-8', errors='replace')
while '\n' in train_buf:
line, train_buf = train_buf.split('\n', 1)
parse_line(line)
need_draw = True
elif kind == 'pm':
try:
data = os.read(fd, 65536).decode('utf-8', errors='replace')
except BlockingIOError:
continue
except (OSError, ValueError):
data = ''
if not data:
continue
pm_buf += data
while '\n\n' in pm_buf or '*** ' in pm_buf:
end = pm_buf.find('\n*** ', 1)
if end < 0:
end = pm_buf.find('\n\n', 1)
if end < 0:
break
chunk = pm_buf[:end]
pm_buf = pm_buf[end:]
parse_powermetrics_text(chunk)
if len(pm_buf) > 16384:
pm_buf = pm_buf[-8192:]
need_draw = True
elif kind == 'stdin':
key = term.inkey(timeout=0)
if not key:
continue
if key == 'q':
cleanup()
return
elif key.name == 'KEY_UP':
S.auto_scroll = False
S.log_scroll = max(0, S.log_scroll - 1)
need_draw = True
elif key.name == 'KEY_DOWN':
S.log_scroll += 1
need_draw = True
elif key == 'p':
S.auto_scroll = not S.auto_scroll
if S.auto_scroll:
S.log_scroll = max(0, len(S.logs) - 10)
need_draw = True
elif key == 'r':
if train_proc:
train_proc.terminate()
train_proc.wait()
train_proc = spawn_training(resume=True, steps=args.steps, dynamic=args.dynamic,
lr=args.lr, accum=args.accum)
S.train_pid = train_proc.pid
procs = [p for p in procs if p.poll() is None]
procs.append(train_proc)
S.logs.append('[dashboard] Restarted with --resume')
need_draw = True
elif key == 'g':
with S.gen_lock:
S.gen_status = 'generating'
S.gen_step = S.step
def force_gen():
try:
W = load_weights_from_ckpt(CKPT_PATH)
if W:
text = generate_text(W, get_tokenizer(), max_tokens=64, temperature=0.8)
with S.gen_lock:
S.gen_text = text
S.gen_step = S.step
S.gen_status = 'done'
except Exception as e:
with S.gen_lock:
S.gen_text = f'[error: {e}]'
S.gen_status = 'done'
threading.Thread(target=force_gen, daemon=True).start()
need_draw = True
now = time.monotonic()
if not need_draw and now - last_draw > 1.0:
need_draw = True
if need_draw and now - last_draw > 0.066:
draw(term)
last_draw = now
if train_finished:
draw(term)
while True:
key = term.inkey(timeout=1)
if key == 'q':
cleanup()
return
if __name__ == '__main__':
main()