// train_large.m — Train stories110M (12 layers, 768dim, 3072hidden) on ANE
// Uses pretokenized TinyStories data with cross-entropy loss
// 5 weight-bearing ANE kernels per layer × 12 layers = 60 per compile batch
#include "stories_io.h"
#include "stories_mil.h"
#include "stories_cpu_ops.h"

#define CKPT_PATH "ane_stories110M_ckpt.bin"
#define MODEL_PATH "../../assets/models/stories110M.bin"
#define DATA_PATH "tinystories_data00.bin"

// ===== Checked raw I/O helpers =====
// Every tensor in the model file and in the checkpoint is stored as a run of
// 4-byte elements. These helpers wrap fread/fwrite so that a short read or
// write is reported instead of being silently ignored.

// Read `count` 4-byte elements into dst; true only if all of them were read.
static bool tl_read4(FILE *f, void *dst, size_t count) {
    return fread(dst, 4, count, f) == count;
}

// Write `count` 4-byte elements from src; true only if all of them were written.
static bool tl_write4(FILE *f, const void *src, size_t count) {
    return fwrite(src, 4, count, f) == count;
}

// Read an Adam (m, v) pair, m first, each `count` elements long.
static bool tl_read_pair(FILE *f, void *m, void *v, size_t count) {
    return tl_read4(f, m, count) && tl_read4(f, v, count);
}

// Write an Adam (m, v) pair, m first, each `count` elements long.
static bool tl_write_pair(FILE *f, const void *m, const void *v, size_t count) {
    return tl_write4(f, m, count) && tl_write4(f, v, count);
}

// ===== Weight loading from llama2.c format =====
// Loads pretrained weights from `path` into lw[0..NLAYERS-1], rms_final and embed.
//
// Returns false (after printing a diagnostic) when the file cannot be opened,
// is truncated, or its header does not match the compiled-in DIM / HIDDEN /
// NLAYERS / VOCAB. On a truncated file the destination buffers may already be
// partially overwritten, so the caller must not use them after a false return.
//
// A positive vocab_size in the header is reported as "shared embed/cls", a
// non-positive one as "separate cls"; in both cases only the embedding table
// is read here.
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
    FILE *f = fopen(path, "rb");
    if (!f) { printf("Cannot open %s\n", path); return false; }

    Llama2Config cfg;
    if (fread(&cfg, sizeof(cfg), 1, f) != 1) {
        printf(" ERROR: %s is too short to contain a model header\n", path);
        fclose(f);
        return false;
    }
    printf(" Model config: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
           cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);
    if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) {
        printf(" ERROR: Config mismatch! Expected dim=%d hidden=%d layers=%d\n", DIM, HIDDEN, NLAYERS);
        fclose(f);
        return false;
    }

    int V = abs(cfg.vocab_size);
    bool shared = cfg.vocab_size > 0;
    // The checkpoint code in this file treats embed as VOCAB*DIM elements, so a
    // different vocabulary size would either overrun or under-fill that buffer.
    if (V != VOCAB) {
        printf(" ERROR: Vocab mismatch! Expected vocab=%d\n", VOCAB);
        fclose(f);
        return false;
    }

    // Read in llama2.c order: embed, rms_att[all], wq[all], wk[all], wv[all], wo[all],
    // rms_ffn[all], w1[all], w2[all], w3[all], rms_final, [wcls]
    // Each tensor is stored contiguously for all layers, hence one loop per tensor.
    // `ok` short-circuits every later read once one of them comes up short.
    bool ok = tl_read4(f, embed, (size_t)V * DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = tl_read4(f, lw[L].rms_att, DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = tl_read4(f, lw[L].Wq, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = tl_read4(f, lw[L].Wk, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = tl_read4(f, lw[L].Wv, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = tl_read4(f, lw[L].Wo, WO_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = tl_read4(f, lw[L].rms_ffn, DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = tl_read4(f, lw[L].W1, W1_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = tl_read4(f, lw[L].W2, W2_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = tl_read4(f, lw[L].W3, W3_SZ);
    ok = ok && tl_read4(f, rms_final, DIM);
    // wcls = embed if shared (we just use embed pointer)
    fclose(f);

    if (!ok) {
        printf(" ERROR: %s is truncated or unreadable\n", path);
        return false;
    }
    printf(" Loaded pretrained weights (%s)\n", shared ? "shared embed/cls" : "separate cls");
    return true;
}

// ===== Compile one layer's kernels =====
// Builds the five weight-bearing kernels for one layer from the weights in `w`
// and stores them in `lk`. Forward kernels receive the weights via build_blob,
// backward kernels receive them via build_blob_t.
// The two trailing integer arguments of compile_kern_mil_w are passed through
// unchanged (presumably input/output buffer sizes in bytes — TODO confirm
// against compile_kern_mil_w).
// Returns true only if all five kernels are non-NULL. Kernels that did compile
// are left in `lk` on failure; this function does not release them.
static bool compile_layer_kernels(LayerKernels *lk, LayerWeights *w) {
    lk->fwdAttn = compile_kern_mil_w(gen_sdpa_fwd_taps(), (@{
        @"@model_path/weights/rms1.bin": @{@"offset":@0, @"data":build_blob(w->rms_att,1,DIM)},
        @"@model_path/weights/wq.bin": @{@"offset":@0, @"data":build_blob(w->Wq,DIM,DIM)},
        @"@model_path/weights/wk.bin": @{@"offset":@0, @"data":build_blob(w->Wk,DIM,DIM)},
        @"@model_path/weights/wv.bin": @{@"offset":@0, @"data":build_blob(w->Wv,DIM,DIM)},
        @"@model_path/weights/wo.bin": @{@"offset":@0, @"data":build_blob(w->Wo,DIM,DIM)},
        @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
    }), DIM*SEQ*2, 6*DIM*SEQ*2);

    lk->fwdFFN = compile_kern_mil_w(gen_ffn_fwd_taps(), (@{
        @"@model_path/weights/rms2.bin": @{@"offset":@0, @"data":build_blob(w->rms_ffn,1,DIM)},
        @"@model_path/weights/w1.bin": @{@"offset":@0, @"data":build_blob(w->W1,HIDDEN,DIM)},
        @"@model_path/weights/w3.bin": @{@"offset":@0, @"data":build_blob(w->W3,HIDDEN,DIM)},
        @"@model_path/weights/w2.bin": @{@"offset":@0, @"data":build_blob(w->W2,DIM,HIDDEN)},
    }), DIM*SEQ*2, (2*DIM+3*HIDDEN)*SEQ*2);

    lk->ffnBwd = compile_kern_mil_w(gen_ffn_bwd(), (@{
        @"@model_path/weights/w2t.bin": @{@"offset":@0, @"data":build_blob_t(w->W2,DIM,HIDDEN)},
        @"@model_path/weights/w1t.bin": @{@"offset":@0, @"data":build_blob_t(w->W1,HIDDEN,DIM)},
        @"@model_path/weights/w3t.bin": @{@"offset":@0, @"data":build_blob_t(w->W3,HIDDEN,DIM)},
    }), (DIM+2*HIDDEN)*SEQ*2, (DIM+2*HIDDEN)*SEQ*2);

    lk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1(), (@{
        @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
        @"@model_path/weights/wot.bin": @{@"offset":@0, @"data":build_blob_t(w->Wo,DIM,DIM)},
    }), 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);

    lk->qkvBwd = compile_kern_mil_w(gen_qkvb(), (@{
        @"@model_path/weights/wqt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wq,DIM,DIM)},
        @"@model_path/weights/wkt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wk,DIM,DIM)},
        @"@model_path/weights/wvt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wv,DIM,DIM)},
    }), 3*DIM*SEQ*2, DIM*SEQ*2);

    return lk->fwdAttn && lk->fwdFFN && lk->ffnBwd && lk->sdpaBwd1 && lk->qkvBwd;
}

// Compile weight-free sdpaBwd2 (only needs once, no weights)
// Returns whatever compile_kern_mil_w returns for an empty weight dictionary.
static Kern *compile_sdpa_bwd2(void) {
    return compile_kern_mil_w(gen_sdpa_bwd2(), @{},
        (2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
}

// Releases the five per-layer kernels via free_kern and clears the pointers so
// a stale handle cannot be reused.
static void free_layer_kernels(LayerKernels *lk) {
    free_kern(lk->fwdAttn);
    free_kern(lk->fwdFFN);
    free_kern(lk->ffnBwd);
    free_kern(lk->sdpaBwd1);
    free_kern(lk->qkvBwd);
    // sdpaBwd2 is shared, freed separately
    lk->fwdAttn = lk->fwdFFN = lk->ffnBwd = lk->sdpaBwd1 = lk->qkvBwd = NULL;
}

// ===== Checkpoint save/load =====
// Writes a version-2 checkpoint to `path`: a CkptHdr, then for every layer its
// weights followed by its Adam (m, v) pairs, then rms_final and embed, each
// followed by its own Adam pair.
//
// The signature stays `void`, so failures are reported on stdout: if the file
// cannot be opened nothing is written; if a write or the final fclose fails, a
// message says the file on disk is incomplete. load_checkpoint rejects such a
// truncated file.
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
                            double cc, double ct, double cw, int cs, int cb, int adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "wb");
    if (!f) {
        printf("Cannot open %s for writing checkpoint\n", path);
        return;
    }

    CkptHdr h = {0};
    h.magic = 0x424C5A54;
    h.version = 2;
    h.step = step;
    h.total_steps = total_steps;
    h.n_layers = NLAYERS;
    h.vocab_size = VOCAB;
    h.dim = DIM;
    h.hidden_dim = HIDDEN;
    h.n_heads = HEADS;
    h.seq_len = SEQ;
    h.lr = lr;
    h.loss = loss;
    h.cum_compile = cc;
    h.cum_train = ct;
    h.cum_wall = cw;
    h.cum_steps = cs;
    h.cum_batches = cb;
    h.adam_t = adam_t;
    bool ok = fwrite(&h, sizeof(h), 1, f) == 1;

    // Per-layer weights + adam. The order here must stay in lock-step with
    // load_checkpoint.
    for (int L = 0; ok && L < NLAYERS; L++) {
        ok = tl_write4(f, lw[L].Wq, WQ_SZ)
          && tl_write4(f, lw[L].Wk, WQ_SZ)
          && tl_write4(f, lw[L].Wv, WQ_SZ)
          && tl_write4(f, lw[L].Wo, WO_SZ)
          && tl_write4(f, lw[L].W1, W1_SZ)
          && tl_write4(f, lw[L].W2, W2_SZ)
          && tl_write4(f, lw[L].W3, W3_SZ)
          && tl_write4(f, lw[L].rms_att, DIM)
          && tl_write4(f, lw[L].rms_ffn, DIM)
          // Adam state
          && tl_write_pair(f, la[L].Wq.m, la[L].Wq.v, WQ_SZ)
          && tl_write_pair(f, la[L].Wk.m, la[L].Wk.v, WQ_SZ)
          && tl_write_pair(f, la[L].Wv.m, la[L].Wv.v, WQ_SZ)
          && tl_write_pair(f, la[L].Wo.m, la[L].Wo.v, WO_SZ)
          && tl_write_pair(f, la[L].W1.m, la[L].W1.v, W1_SZ)
          && tl_write_pair(f, la[L].W2.m, la[L].W2.v, W2_SZ)
          && tl_write_pair(f, la[L].W3.m, la[L].W3.v, W3_SZ)
          && tl_write_pair(f, la[L].rms_att.m, la[L].rms_att.v, DIM)
          && tl_write_pair(f, la[L].rms_ffn.m, la[L].rms_ffn.v, DIM);
    }

    ok = ok
      && tl_write4(f, rms_final, DIM)
      && tl_write_pair(f, arms_final->m, arms_final->v, DIM)
      && tl_write4(f, embed, (size_t)VOCAB * DIM)
      && tl_write_pair(f, aembed->m, aembed->v, (size_t)VOCAB * DIM);

    // fclose flushes buffered data, so a full disk can surface only here.
    if (fclose(f) != 0) ok = false;
    if (!ok) printf("ERROR: failed to write checkpoint %s (file is incomplete)\n", path);
}

// Restores training state from a version-2 checkpoint written by save_checkpoint.
//
// Returns false when the file cannot be opened, the header is short, the
// magic/version do not match, the header's n_layers / vocab_size / dim /
// hidden_dim differ from the compiled-in constants (the payload layout depends
// on them), or any tensor read comes up short.
//
// The scalar out-parameters (*step ... *adam_t) are written only on success.
// The weight and Adam buffers are filled in place, so after a false return
// caused by a truncated payload they may be partially overwritten and must be
// re-initialised by the caller.
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
                            double *cc, double *ct, double *cw, int *cs, int *cb, int *adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "rb");
    if (!f) return false;

    CkptHdr h;
    if (fread(&h, sizeof(h), 1, f) != 1) { fclose(f); return false; }
    if (h.magic != 0x424C5A54 || h.version != 2) { fclose(f); return false; }
    if (h.n_layers != NLAYERS || h.vocab_size != VOCAB || h.dim != DIM || h.hidden_dim != HIDDEN) {
        printf("Checkpoint %s was written for a different model configuration\n", path);
        fclose(f);
        return false;
    }

    // Same tensor order as save_checkpoint.
    bool ok = true;
    for (int L = 0; ok && L < NLAYERS; L++) {
        ok = tl_read4(f, lw[L].Wq, WQ_SZ)
          && tl_read4(f, lw[L].Wk, WQ_SZ)
          && tl_read4(f, lw[L].Wv, WQ_SZ)
          && tl_read4(f, lw[L].Wo, WO_SZ)
          && tl_read4(f, lw[L].W1, W1_SZ)
          && tl_read4(f, lw[L].W2, W2_SZ)
          && tl_read4(f, lw[L].W3, W3_SZ)
          && tl_read4(f, lw[L].rms_att, DIM)
          && tl_read4(f, lw[L].rms_ffn, DIM)
          && tl_read_pair(f, la[L].Wq.m, la[L].Wq.v, WQ_SZ)
          && tl_read_pair(f, la[L].Wk.m, la[L].Wk.v, WQ_SZ)
          && tl_read_pair(f, la[L].Wv.m, la[L].Wv.v, WQ_SZ)
          && tl_read_pair(f, la[L].Wo.m, la[L].Wo.v, WO_SZ)
          && tl_read_pair(f, la[L].W1.m, la[L].W1.v, W1_SZ)
          && tl_read_pair(f, la[L].W2.m, la[L].W2.v, W2_SZ)
          && tl_read_pair(f, la[L].W3.m, la[L].W3.v, W3_SZ)
          && tl_read_pair(f, la[L].rms_att.m, la[L].rms_att.v, DIM)
          && tl_read_pair(f, la[L].rms_ffn.m, la[L].rms_ffn.v, DIM);
    }

    ok = ok
      && tl_read4(f, rms_final, DIM)
      && tl_read_pair(f, arms_final->m, arms_final->v, DIM)
      && tl_read4(f, embed, (size_t)VOCAB * DIM)
      && tl_read_pair(f, aembed->m, aembed->v, (size_t)VOCAB * DIM);
    fclose(f);
    if (!ok) return false;

    // Publish the scalar state only once the whole payload has been read.
    *step = h.step;
    *total_steps = h.total_steps;
    *lr = h.lr;
    *loss = h.loss;
    *cc = h.cum_compile;
    *ct = h.cum_train;
    *cw = h.cum_wall;
    *cs = h.cum_steps;
    *cb = h.cum_batches;
    *adam_t = h.adam_t;
    return true;
}

// ===== Main =====
int main(int argc, char *argv[]) {
    @autoreleasepool {
        setbuf(stdout, NULL);
        ane_init();
        mach_timebase_info(&g_tb);

        int total_steps = 10000;
        float lr = 3e-4f;
        float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
        int adam_t = 0, start_step = 0;

        // Parse args
        bool do_resume = false;
        for (int i=1; i