diff --git a/training/dashboard.py b/training/dashboard.py index 55e8bb9..3506901 100644 --- a/training/dashboard.py +++ b/training/dashboard.py @@ -18,16 +18,48 @@ try: except ImportError: HAS_PSUTIL = False -DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 256, 32000, 12 -HD = DIM // HEADS -CKPT_PATH_STATIC = 'ane_stories110M_ckpt.bin' -CKPT_PATH_DYNAMIC = 'training_dynamic/ane_stories110M_dyn_ckpt.bin' -CKPT_PATH = CKPT_PATH_STATIC # set in main() based on --dynamic +try: + import wandb + HAS_WANDB = True +except ImportError: + HAS_WANDB = False + +# Model configs — set at startup based on --model flag +MODEL_CONFIGS = { + 'stories110m': { + 'dim': 768, 'hidden': 2048, 'heads': 12, 'kv_heads': 12, + 'hd': 64, 'seq': 256, 'vocab': 32000, 'nlayers': 12, + 'ckpt_static': 'ane_stories110M_ckpt.bin', + 'ckpt_dynamic': 'training_dynamic/ane_stories110M_dyn_ckpt.bin', + }, + 'qwen3_06b': { + 'dim': 1024, 'hidden': 3072, 'heads': 16, 'kv_heads': 8, + 'hd': 128, 'seq': 256, 'vocab': 151936, 'nlayers': 28, + 'ckpt_static': None, + 'ckpt_dynamic': 'training_dynamic/ane_qwen3_06b_dyn_ckpt.bin', + }, +} + +# Active model dims — set in main() +DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 12, 64, 256, 32000, 12 +Q_DIM, KV_DIM, GQA_RATIO = DIM, DIM, 1 +CKPT_PATH = 'ane_stories110M_ckpt.bin' TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent / 'assets' / 'models' / 'tokenizer.bin') +def set_model_config(name): + global DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS + global Q_DIM, KV_DIM, GQA_RATIO + cfg = MODEL_CONFIGS[name] + DIM, HIDDEN, HEADS, KV_HEADS = cfg['dim'], cfg['hidden'], cfg['heads'], cfg['kv_heads'] + HD, SEQ, VOCAB, NLAYERS = cfg['hd'], cfg['seq'], cfg['vocab'], cfg['nlayers'] + Q_DIM = HEADS * HD + KV_DIM = KV_HEADS * HD + GQA_RATIO = HEADS // KV_HEADS + class State: def __init__(self): + self.active_model = 'stories110m' self.model_config = {} self.params = {} self.kernels = {} @@ -62,6 +94,7 @@ class State: self.train_start = None # wall clock when first step seen self.compile_ms = 0.0 # total compile time + S = State() @@ -71,8 +104,12 @@ class Tokenizer: self.scores = [] with open(path, 'rb') as f: max_len = struct.unpack('i', f.read(4))[0] - for _ in range(VOCAB): - score = struct.unpack('f', f.read(4))[0] + # Read until EOF — works for any vocab size + while True: + data = f.read(4) + if len(data) < 4: + break + score = struct.unpack('f', data)[0] slen = struct.unpack('i', f.read(4))[0] tok = f.read(slen).decode('utf-8', errors='replace') self.vocab.append(tok) @@ -104,33 +141,32 @@ def get_tokenizer(): def load_weights_from_ckpt(path): try: with open(path, 'rb') as f: - # CkptHdr: 96 bytes (verified with sizeof) hdr = f.read(96) if len(hdr) < 96: return None - wq_sz = DIM * DIM - wo_sz = DIM * DIM + wq_sz = Q_DIM * DIM + wk_sz = KV_DIM * DIM + wv_sz = KV_DIM * DIM + wo_sz = DIM * Q_DIM w1_sz = HIDDEN * DIM w2_sz = DIM * HIDDEN w3_sz = HIDDEN * DIM - # Per-layer: weights + adam state (m,v for each) - adam_per_layer = (wq_sz*2 + wq_sz*2 + wq_sz*2 + wo_sz*2 + + adam_per_layer = (wq_sz*2 + wk_sz*2 + wv_sz*2 + wo_sz*2 + w1_sz*2 + w2_sz*2 + w3_sz*2 + DIM*2 + DIM*2) W = {} for L in range(NLAYERS): - W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() - W[f'Wk{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() - W[f'Wv{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() - W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() + W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(Q_DIM, DIM).copy() + W[f'Wk{L}'] = np.frombuffer(f.read(wk_sz * 4), dtype=np.float32).reshape(KV_DIM, DIM).copy() + W[f'Wv{L}'] = np.frombuffer(f.read(wv_sz * 4), dtype=np.float32).reshape(KV_DIM, DIM).copy() + W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, Q_DIM).copy() W[f'W1_{L}'] = np.frombuffer(f.read(w1_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy() W[f'W2_{L}'] = np.frombuffer(f.read(w2_sz * 4), dtype=np.float32).reshape(DIM, HIDDEN).copy() W[f'W3_{L}'] = np.frombuffer(f.read(w3_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy() W[f'rms1_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy() W[f'rms2_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy() - # Skip adam state for this layer f.seek(adam_per_layer * 4, 1) W['rms_final'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy() - f.seek(DIM * 2 * 4, 1) # skip rms_final adam + f.seek(DIM * 2 * 4, 1) W['embed'] = np.frombuffer(f.read(VOCAB * DIM * 4), dtype=np.float32).reshape(VOCAB, DIM).copy() return W except Exception as e: @@ -151,20 +187,21 @@ def generate_text(W, max_tokens=64, temperature=0.8): tokenizer = get_tokenizer() if tokenizer is None: return '[no tokenizer]' + if len(tokenizer.vocab) < VOCAB: + return f'[tokenizer has {len(tokenizer.vocab)} tokens, model needs {VOCAB}]' tokens = [1] text_parts = [] - # Precompute RoPE frequencies freqs = np.zeros((SEQ, HD // 2), dtype=np.float32) for pos in range(SEQ): for i in range(HD // 2): freq = 1.0 / (10000.0 ** (2.0 * i / HD)) freqs[pos, i] = pos * freq - # KV cache: per-layer, per-head arrays - k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)] - v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)] + # KV cache: per-layer, per KV head + k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(KV_HEADS)] for _ in range(NLAYERS)] + v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(KV_HEADS)] for _ in range(NLAYERS)] res_alpha = 1.0 / math.sqrt(2.0 * NLAYERS) @@ -177,13 +214,12 @@ def generate_text(W, max_tokens=64, temperature=0.8): pos = seq_len - 1 for L in range(NLAYERS): - # RMSNorm + QKV xn = rmsnorm(x, W[f'rms1_{L}']) - q = W[f'Wq{L}'] @ xn - k = W[f'Wk{L}'] @ xn - v = W[f'Wv{L}'] @ xn + q = W[f'Wq{L}'] @ xn # [Q_DIM] + k = W[f'Wk{L}'] @ xn # [KV_DIM] + v = W[f'Wv{L}'] @ xn # [KV_DIM] - # RoPE + # RoPE on Q (HEADS heads) and K (KV_HEADS heads) for h in range(HEADS): for i in range(HD // 2): freq = freqs[pos, i] @@ -191,31 +227,37 @@ def generate_text(W, max_tokens=64, temperature=0.8): qi, qi1 = q[h * HD + 2 * i], q[h * HD + 2 * i + 1] q[h * HD + 2 * i] = qi * cos_v - qi1 * sin_v q[h * HD + 2 * i + 1] = qi * sin_v + qi1 * cos_v + for h in range(KV_HEADS): + for i in range(HD // 2): + freq = freqs[pos, i] + cos_v, sin_v = math.cos(freq), math.sin(freq) ki, ki1 = k[h * HD + 2 * i], k[h * HD + 2 * i + 1] k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v - # Append to KV cache and compute attention - o = np.zeros(DIM, dtype=np.float32) - for h in range(HEADS): - qh = q[h * HD:(h + 1) * HD] - kh = k[h * HD:(h + 1) * HD].reshape(1, HD) - vh = v[h * HD:(h + 1) * HD].reshape(1, HD) - k_cache[L][h] = np.vstack([k_cache[L][h], kh]) - v_cache[L][h] = np.vstack([v_cache[L][h], vh]) - # scores: (1, HD) @ (HD, seq_len) -> (seq_len,) - scores = k_cache[L][h] @ qh / math.sqrt(HD) - attn = softmax(scores) - o[h * HD:(h + 1) * HD] = attn @ v_cache[L][h] + # Append to KV cache (KV_HEADS entries) + for kv in range(KV_HEADS): + kh = k[kv * HD:(kv + 1) * HD].reshape(1, HD) + vh = v[kv * HD:(kv + 1) * HD].reshape(1, HD) + k_cache[L][kv] = np.vstack([k_cache[L][kv], kh]) + v_cache[L][kv] = np.vstack([v_cache[L][kv], vh]) - # Residual + output projection (scaled residual, matches training) + # GQA attention: each Q head uses its corresponding KV head + o = np.zeros(Q_DIM, dtype=np.float32) + for h in range(HEADS): + kv = h // GQA_RATIO + qh = q[h * HD:(h + 1) * HD] + scores = k_cache[L][kv] @ qh / math.sqrt(HD) + attn = softmax(scores) + o[h * HD:(h + 1) * HD] = attn @ v_cache[L][kv] + + # Residual + output projection x2 = x + res_alpha * (W[f'Wo{L}'] @ o) # FFN x2n = rmsnorm(x2, W[f'rms2_{L}']) h1 = W[f'W1_{L}'] @ x2n h3 = W[f'W3_{L}'] @ x2n - # SiLU h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3 ffn_out = W[f'W2_{L}'] @ h1 @@ -230,8 +272,11 @@ def generate_text(W, max_tokens=64, temperature=0.8): next_tok = int(np.argmax(logits)) else: logits = logits / temperature - probs = softmax(logits) - next_tok = int(np.random.choice(VOCAB, p=probs)) + top_k = 50 + top_idx = np.argpartition(logits, -top_k)[-top_k:] + top_logits = logits[top_idx] + probs = softmax(top_logits) + next_tok = int(top_idx[np.random.choice(len(top_idx), p=probs)]) if next_tok == 2: break @@ -291,6 +336,8 @@ def sysmetrics_thread(): RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)') +RE_CONFIG_GQA = re.compile(r'dim=(\d+) q_dim=(\d+) kv_dim=(\d+) hd=(\d+) hidden=(\d+) seq=(\d+) vocab=(\d+)') +RE_MODEL_NAME = re.compile(r'ANE Dynamic Training: (.+?) \((\d+) layers') RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)') RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing') RE_KERNELS_DYN = re.compile(r'Kernels: (\d+) compiled, (\d+) weight-bearing') @@ -307,10 +354,67 @@ RE_ANE_TFLOPS = re.compile(r'ANE TFLOPS:\s+([\d.]+)') RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%') RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Compile|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)') RE_COMPILED = re.compile(r'Compiled (\d+) kernels in (\d+)ms') +RE_CKPT_SAVED = re.compile(r'\[ckpt saved, best_loss=([\d.]+)\]') RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW') RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW') RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW') +USE_WANDB = False + +def wandb_log_step(): + """Log current state to wandb. Called after each step update.""" + if not USE_WANDB: + return + d = {'step': S.step, 'loss': S.loss, 'best_loss': S.best_loss} + if S.ms_per_step > 0: + d['ms_per_step'] = S.ms_per_step + lr = S.training.get('lr') + if lr: + try: + d['lr'] = float(lr) + except ValueError: + pass + ct = S.component_timing + if ct: + for k, v in ct.items(): + if k != '_dynamic': + d[f'timing/{k}'] = v + fl = S.flops + if fl.get('ane_tflops'): + d['perf/ane_tflops'] = fl['ane_tflops'] + if fl.get('ane_util'): + d['perf/ane_util_pct'] = fl['ane_util'] + pw = S.power + if pw['ane'] > 0: + d['power/ane_w'] = pw['ane'] + if pw['cpu'] > 0: + d['power/cpu_w'] = pw['cpu'] + wandb.log(d, step=S.step) + +def _sync_globals_from_parsed(cfg): + """Sync dashboard globals from parsed binary output so text gen uses correct dims.""" + global DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS + global Q_DIM, KV_DIM, GQA_RATIO + if 'dim' in cfg: + DIM = cfg['dim'] + if 'hidden' in cfg: + HIDDEN = cfg['hidden'] + if 'heads' in cfg: + HEADS = cfg['heads'] + if 'kv_heads' in cfg: + KV_HEADS = cfg['kv_heads'] + if 'hd' in cfg: + HD = cfg['hd'] + if 'seq' in cfg: + SEQ = cfg['seq'] + if 'vocab' in cfg: + VOCAB = cfg['vocab'] + if 'layers' in cfg: + NLAYERS = cfg['layers'] + Q_DIM = HEADS * HD + KV_DIM = KV_HEADS * HD + GQA_RATIO = HEADS // KV_HEADS if KV_HEADS else 1 + def parse_line(line): S.logs.append(line) # Parse JSON lines from static pipeline ({"type":"step",...} or {"type":"batch",...}) @@ -339,6 +443,7 @@ def parse_line(line): ct[k[2:]] = j[k] # strip 't_' prefix if ct: S.component_timing = ct + wandb_log_step() return elif jt == 'batch': S.batch_num = j.get('batch', S.batch_num) @@ -356,9 +461,21 @@ def parse_line(line): return except (json.JSONDecodeError, KeyError): pass + m = RE_MODEL_NAME.search(line) + if m: + S.model_config['name'] = m[1] + S.model_config['layers'] = int(m[2]) + m = RE_CONFIG_GQA.search(line) + if m: + d, qd, kvd, hd, hid, seq, voc = map(int, m.groups()) + S.model_config.update(dim=d, q_dim=qd, kv_dim=kvd, hd=hd, hidden=hid, seq=seq, vocab=voc, + heads=qd//hd, kv_heads=kvd//hd) + _sync_globals_from_parsed(S.model_config) + return m = RE_CONFIG.search(line) if m: S.model_config = dict(zip(['dim', 'hidden', 'heads', 'seq', 'vocab', 'layers'], map(int, m.groups()))) + _sync_globals_from_parsed(S.model_config) return m = RE_PARAMS.search(line) if m: @@ -398,6 +515,7 @@ def parse_line(line): S.ms_per_step = dt * 1000 S.loss_history.append((S.step, S.loss)) S.best_loss = min(S.best_loss, S.loss) + wandb_log_step() return m = RE_BATCH.search(line) if m: @@ -434,6 +552,11 @@ def parse_line(line): S.compiles = int(m[1]) S.compile_ms += float(m[2]) return + m = RE_CKPT_SAVED.search(line) + if m: + if USE_WANDB: + wandb.log({'checkpoint/best_loss': float(m[1]), 'checkpoint/saved': True}, step=S.step) + return m = RE_EFFICIENCY.search(line) if m: S.efficiency[m[1].strip()] = m[2].strip() @@ -553,14 +676,17 @@ def draw(term): row = 0 - # Model Config header - hdr = '\u2500 Model Config ' - put(row, 0, '\u250c' + hdr + '\u2500' * max(0, w - len(hdr) - 2) + '\u2510', term.cyan) + # Model Config header — use parsed name from binary if available, else CLI arg + model_label = S.model_config.get('name', S.active_model) + keys_hint = '[r]estart [g]en [q]uit' + hdr_text = f'\u2500 {model_label} \u2500\u2500 {keys_hint} ' + put(row, 0, '\u250c' + hdr_text + '\u2500' * max(0, w - len(hdr_text) - 2) + '\u2510', term.cyan) row += 1 cfg = S.model_config if cfg: - line1 = f"stories110M dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}" + gqa_str = f" kv_heads={cfg.get('kv_heads', '')}" if cfg.get('kv_heads', cfg.get('heads', 0)) != cfg.get('heads', 0) else '' + line1 = f"dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')}{gqa_str} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}" put(row, 0, '\u2502', term.cyan) put(row, 2, line1) put(row, w - 1, '\u2502', term.cyan) @@ -778,9 +904,10 @@ def set_nonblock(fd): fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) def spawn_training(resume=False, steps=10000, dynamic=False, ane=False, scratch=False, - lr=None, accum=None, no_ane_extras=False, data=None): + lr=None, accum=None, no_ane_extras=False, data=None, model=None): if dynamic: - cmd = 'cd training_dynamic && make 2>&1 && ./train' + model_arg = f' MODEL={model}' if model else '' + cmd = f'cd training_dynamic && make{model_arg} 2>&1 && ./train' elif ane: cmd = 'make train_large_ane 2>&1 && ./train_large_ane' else: @@ -818,9 +945,12 @@ def spawn_powermetrics(): return None def main(): - parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)') + parser = argparse.ArgumentParser(description='ANE Training Dashboard') parser.add_argument('--resume', action='store_true', help='Resume from checkpoint') parser.add_argument('--dynamic', action='store_true', help='Dynamic weight pipeline (training_dynamic/)') + parser.add_argument('--model', type=str, default=None, + choices=list(MODEL_CONFIGS.keys()), + help='Model config (default: stories110m for static, qwen3_06b for dynamic)') parser.add_argument('--ane', action='store_true', help='PR#19: ANE-offloaded classifier/softmax/rmsnorm_bwd') parser.add_argument('--no-ane-extras', action='store_true', help='Disable ANE extras (use with --ane)') parser.add_argument('--scratch', action='store_true', help='Train from scratch (random init)') @@ -831,14 +961,53 @@ def main(): parser.add_argument('--no-generate', action='store_true', help='Disable text generation') parser.add_argument('--steps', type=int, default=10000, help='Total steps (default: 10000)') parser.add_argument('--data', type=str, default=None, help='Path to training data shard (.bin)') + parser.add_argument('--wandb', action='store_true', help='Log to Weights & Biases') + parser.add_argument('--wandb-project', type=str, default='ane-training', help='W&B project name') + parser.add_argument('--wandb-name', type=str, default=None, help='W&B run name') args = parser.parse_args() if args.infinite: args.steps = 999999999 S.total_steps = args.steps - global CKPT_PATH - CKPT_PATH = CKPT_PATH_DYNAMIC if args.dynamic else CKPT_PATH_STATIC + # Select model + if args.model is None: + args.model = 'qwen3_06b' if args.dynamic else 'stories110m' + cfg = MODEL_CONFIGS[args.model] + # Auto-enable dynamic for models without a static pipeline + if cfg['ckpt_static'] is None: + args.dynamic = True + set_model_config(args.model) + S.active_model = args.model + # For dynamic: default to --scratch when --resume not given + if args.dynamic and not args.resume: + args.scratch = True + + global CKPT_PATH, USE_WANDB + CKPT_PATH = cfg['ckpt_dynamic'] if args.dynamic else cfg['ckpt_static'] + + # Weights & Biases + if args.wandb: + if not HAS_WANDB: + print('pip install wandb') + sys.exit(1) + run_name = args.wandb_name or f'{args.model}-{"resume" if args.resume else "scratch"}' + wandb.init( + project=args.wandb_project, + name=run_name, + config={ + 'model': args.model, + 'dim': DIM, 'hidden': HIDDEN, 'heads': HEADS, + 'kv_heads': KV_HEADS, 'hd': HD, 'seq': SEQ, + 'vocab': VOCAB, 'nlayers': NLAYERS, + 'q_dim': Q_DIM, 'kv_dim': KV_DIM, + 'pipeline': 'dynamic' if args.dynamic else 'static', + 'resume': args.resume, + 'lr': args.lr, 'accum': args.accum, + 'steps': args.steps, + }, + ) + USE_WANDB = True term = Terminal() procs = [] @@ -846,7 +1015,7 @@ def main(): train_proc = spawn_training(resume=args.resume, steps=args.steps, dynamic=args.dynamic, scratch=args.scratch, lr=args.lr, accum=args.accum, ane=args.ane, no_ane_extras=args.no_ane_extras, - data=args.data) + data=args.data, model=args.model) S.train_pid = train_proc.pid procs.append(train_proc) @@ -874,6 +1043,8 @@ def main(): p.terminate() except Exception: pass + if USE_WANDB: + wandb.finish() signal.signal(signal.SIGINT, lambda *a: cleanup()) signal.signal(signal.SIGTERM, lambda *a: cleanup()) @@ -989,11 +1160,11 @@ def main(): train_proc = spawn_training(resume=True, steps=args.steps, dynamic=args.dynamic, lr=args.lr, accum=args.accum, ane=args.ane, no_ane_extras=args.no_ane_extras, - data=args.data) + data=args.data, model=S.active_model) S.train_pid = train_proc.pid procs = [p for p in procs if p.poll() is None] procs.append(train_proc) - S.logs.append('[dashboard] Restarted with --resume') + S.logs.append(f'[dashboard] Restarted {S.active_model} with --resume') need_draw = True elif key == 'g': with S.gen_lock: diff --git a/training/training_dynamic/train.m b/training/training_dynamic/train.m index 0c9f658..4249a5b 100644 --- a/training/training_dynamic/train.m +++ b/training/training_dynamic/train.m @@ -384,6 +384,7 @@ int main(int argc, char *argv[]) { dispatch_group_t dw_grp = dispatch_group_create(); float last_loss = 999.0f; + float best_loss = resume_loss > 0 ? resume_loss : 999.0f; double total_train_ms = 0; int total_steps_done = 0; uint64_t t_wall_start = mach_absolute_time(); @@ -875,12 +876,14 @@ int main(int argc, char *argv[]) { memset(gembed, 0, (size_t)VOCAB*DIM*4); memset(gcembed, 0, (size_t)CV*DIM*4); - // Checkpoint - if ((step+1) % 100 == 0) { + // Checkpoint — only save on best loss + if ((step+1) % 100 == 0 && last_loss < best_loss) { + best_loss = last_loss; double wall = tb_ms(mach_absolute_time() - t_wall_start); save_checkpoint(CKPT_PATH, step+1, total_steps, lr, last_loss, total_train_ms+cum_train, wall+cum_wall, total_steps_done+cum_steps, adam_t, lw, la, rms_final, &arms_final, embed, &aembed); + printf(" [ckpt saved, best_loss=%.4f]\n", best_loss); } } }