// train.m — Dynamic weight ANE training for Stories110M
// Compile kernels ONCE at startup, update weights via IOSurface every step.
// No exec() restart needed — eliminates 76% compile overhead.
#include "mil_dynamic.h"
#include "cpu_ops.h"

#define CKPT_PATH "ane_stories110M_dyn_ckpt.bin"
#define MODEL_PATH "../../../assets/models/stories110M.bin"
#define DATA_PATH "../tinystories_data00.bin"

// Dynamic kernel set per layer
typedef struct {
    Kern *sdpaFwd;    // QKV matmul + SDPA + Wo matmul (dynamic weights via IOSurface)
    Kern *ffnW13;     // W1,W3 matmul (dynamic)
    Kern *ffnW2;      // W2 matmul (dynamic)
    Kern *ffnBwdW2t;  // dffn @ W2^T (dynamic)
    Kern *ffnBwdW13t; // dh1@W1^T + dh3@W3^T (dynamic)
    Kern *wotBwd;     // dx2 @ Wo^T (dynamic)
    Kern *sdpaBwd1;   // Q,K,V,da → dV,probs,dp (weight-free, has mask const)
    Kern *sdpaBwd2;   // probs,dp,Q,K → dQ,dK (weight-free)
    Kern *qkvBwd;     // dq@Wq^T + dk@Wk^T + dv@Wv^T (dynamic)
} DynLayerKernels;

// ===== Exact-count fp32 binary I/O =====

// Read exactly n floats from f into dst. Returns false on a short read or I/O error.
static bool fread_f32_exact(FILE *f, float *dst, size_t n) {
    return fread(dst, sizeof(float), n, f) == n;
}

// Write exactly n floats from src to f. Returns false on a short write or I/O error.
static bool fwrite_f32_exact(FILE *f, const float *src, size_t n) {
    return fwrite(src, sizeof(float), n, f) == n;
}

// ===== Weight loading from llama2.c format =====
// Reads a Llama2Config header, then the embedding table, then each per-layer
// tensor kind for all layers in turn (rms_att, Wq, Wk, Wv, Wo, rms_ffn, W1, W2, W3),
// then rms_final. Anything stored after rms_final is not read.
// Returns false (after printing a message) if the file can't be opened, its config
// doesn't match the compiled-in model shape, or it is truncated.
// NOTE: on a truncated file the output buffers may be partially overwritten.
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
    FILE *f = fopen(path, "rb");
    if (!f) { printf("Cannot open %s\n", path); return false; }

    Llama2Config cfg;
    if (fread(&cfg, sizeof(cfg), 1, f) != 1) {
        // Without this check cfg would be read uninitialized below.
        printf(" ERROR: %s is too short for a model header\n", path);
        fclose(f);
        return false;
    }
    // llama2.c stores vocab_size negated to flag an unshared classifier; only its magnitude matters here.
    int V = abs(cfg.vocab_size);
    printf(" Model: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
           cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, V, cfg.seq_len);

    // embed is treated as VOCAB*DIM floats elsewhere in this file (checkpoint I/O),
    // so any other vocab would overflow it or leave rows unset; heads must match
    // the compiled attention kernels.
    if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS ||
        cfg.n_heads != HEADS || V != VOCAB) {
        printf(" ERROR: Config mismatch!\n");
        fclose(f);
        return false;
    }

    // Each tensor kind is stored contiguously for all layers, in this order.
    bool ok = fread_f32_exact(f, embed, (size_t)V * DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = fread_f32_exact(f, lw[L].rms_att, DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = fread_f32_exact(f, lw[L].Wq, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = fread_f32_exact(f, lw[L].Wk, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = fread_f32_exact(f, lw[L].Wv, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = fread_f32_exact(f, lw[L].Wo, WO_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = fread_f32_exact(f, lw[L].rms_ffn, DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = fread_f32_exact(f, lw[L].W1, W1_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = fread_f32_exact(f, lw[L].W2, W2_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = fread_f32_exact(f, lw[L].W3, W3_SZ);
    if (ok) ok = fread_f32_exact(f, rms_final, DIM);
    fclose(f);

    if (!ok) {
        printf(" ERROR: %s is truncated\n", path);
        return false;
    }
    printf(" Loaded pretrained weights\n");
    return true;
}

// Transpose W[rows,cols] → W^T[cols,rows] stored as [cols channels, rows spatial]
static void transpose_weight(float *dst, const float *src, int rows, int cols) {
    for (int r = 0; r < rows; r++)
        for (int c = 0; c < cols; c++)
            dst[c * rows + r] = src[r * cols + c];
}

// ===== Compile all dynamic kernels (ONCE) =====
// Compiles every per-layer kernel into *dk, stopping at the first failure.
// The two trailing compile_kern_mil_w arguments match the byte sizes of the
// input/output tensors described in each comment (x4 for fp32, x2 for fp16).
// NOTE(review): kernels compiled before a failure are not released here.
static bool compile_dynamic_kernels(DynLayerKernels *dk) {
    NSDictionary *mask_w = @{@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}};

    // SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp32 → [1, 6*DIM, 1, SEQ] fp32
    printf(" Compiling sdpaFwd...\n");
    dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), mask_w, DIM*(SEQ+4*DIM)*4, 6*DIM*SEQ*4);
    if (!dk->sdpaFwd) return false;

    // FFN W1+W3: [1, DIM, 1, SEQ+2*HIDDEN] fp32 → [1, 3*HIDDEN, 1, SEQ] fp32
    printf(" Compiling ffnW13...\n");
    dk->ffnW13 = compile_kern_mil_w(gen_ffn_w13_dynamic(), @{}, DIM*(SEQ+2*HIDDEN)*4, 3*HIDDEN*SEQ*4);
    if (!dk->ffnW13) return false;

    // FFN W2: [1, HIDDEN, 1, SEQ+DIM] fp32 → [1, DIM, 1, SEQ] fp32
    printf(" Compiling ffnW2...\n");
    dk->ffnW2 = compile_kern_mil_w(gen_ffn_w2_dynamic(), @{}, HIDDEN*(SEQ+DIM)*4, DIM*SEQ*4);
    if (!dk->ffnW2) return false;

    // FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] fp32 → [1, HIDDEN, 1, SEQ] fp32
    printf(" Compiling ffnBwdW2t...\n");
    dk->ffnBwdW2t = compile_kern_mil_w(gen_ffn_bwd_w2t_dynamic(), @{}, DIM*(SEQ+HIDDEN)*4, HIDDEN*SEQ*4);
    if (!dk->ffnBwdW2t) return false;

    // FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp32 → [1, DIM, 1, SEQ] fp32
    printf(" Compiling ffnBwdW13t...\n");
    dk->ffnBwdW13t = compile_kern_mil_w(gen_ffn_bwd_w13t_dynamic(), @{}, HIDDEN*(2*SEQ+2*DIM)*4, DIM*SEQ*4);
    if (!dk->ffnBwdW13t) return false;

    // Wo^T backward: [1, DIM, 1, SEQ+DIM] fp32 → [1, DIM, 1, SEQ] fp32
    printf(" Compiling wotBwd...\n");
    dk->wotBwd = compile_kern_mil_w(gen_wot_dynamic(), @{}, DIM*(SEQ+DIM)*4, DIM*SEQ*4);
    if (!dk->wotBwd) return false;

    // SDPA bwd1 (no dynamic weights, has mask): [1, 4*DIM, 1, SEQ] fp16 → [1, DIM+2*SCORE_CH, 1, SEQ] fp16
    printf(" Compiling sdpaBwd1...\n");
    dk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1_noweight(), mask_w, 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);
    if (!dk->sdpaBwd1) return false;

    // SDPA bwd2 (no weights): [1, 2*SCORE_CH+2*DIM, 1, SEQ] fp16 → [1, 2*DIM, 1, SEQ] fp16
    printf(" Compiling sdpaBwd2...\n");
    dk->sdpaBwd2 = compile_kern_mil_w(gen_sdpa_bwd2(), @{}, (2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
    if (!dk->sdpaBwd2) return false;

    // QKV backward: [1, DIM, 1, 3*SEQ+3*DIM] fp32 → [1, DIM, 1, SEQ] fp32
    printf(" Compiling qkvBwd...\n");
    dk->qkvBwd = compile_kern_mil_w(gen_qkvb_dynamic(), @{}, DIM*(3*SEQ+3*DIM)*4, DIM*SEQ*4);
    if (!dk->qkvBwd) return false;

    return true;
}

// ===== Write dynamic weights into IOSurface =====
// sdpaFwd: [1, DIM, 1, SEQ+4*DIM] — xnorm at sp[0:S], Wq/Wk/Wv/Wo at sp[S:]
// Channel d receives xnorm[d*SEQ .. +SEQ) followed by row d of Wq, Wk, Wv, Wo.
static void write_sdpa_fwd_input(DynLayerKernels *dk, const float *xnorm, const float *Wq, const float *Wk, const float *Wv, const float *Wo) {
    IOSurfaceLock(dk->sdpaFwd->ioIn, 0, NULL);
    float *buf = (float*)IOSurfaceGetBaseAddress(dk->sdpaFwd->ioIn);
    int sp = SEQ + 4*DIM;  // floats per channel row
    for (int d = 0; d < DIM; d++) {
        memcpy(buf + d*sp, xnorm + d*SEQ, SEQ*4);
        memcpy(buf + d*sp + SEQ, Wq + d*DIM, DIM*4);
        memcpy(buf + d*sp + SEQ+DIM, Wk + d*DIM, DIM*4);
        memcpy(buf + d*sp + SEQ+2*DIM, Wv + d*DIM, DIM*4);
        memcpy(buf + d*sp + SEQ+3*DIM, Wo + d*DIM, DIM*4);
    }
    IOSurfaceUnlock(dk->sdpaFwd->ioIn, 0, NULL);
}

// ffnW13: [1, DIM, 1, SEQ+2*HIDDEN] — xnorm at sp[0:S], W1,W3 at sp[S:]
// Channel d receives xnorm[d*SEQ .. +SEQ) followed by row d of W1 and W3.
static void write_ffn_w13_input(DynLayerKernels *dk, const float *xnorm, const float *W1, const float *W3) {
    IOSurfaceLock(dk->ffnW13->ioIn, 0, NULL);
    float *buf = (float*)IOSurfaceGetBaseAddress(dk->ffnW13->ioIn);
    int sp = SEQ + 2*HIDDEN;  // floats per channel row
    for (int d = 0; d < DIM; d++) {
        memcpy(buf + d*sp, xnorm + d*SEQ, SEQ*4);
        memcpy(buf + d*sp + SEQ, W1 + d*HIDDEN, HIDDEN*4);
        memcpy(buf + d*sp + SEQ+HIDDEN, W3 + d*HIDDEN, HIDDEN*4);
    }
    IOSurfaceUnlock(dk->ffnW13->ioIn, 0, NULL);
}

// ffnW2: [1, HIDDEN, 1, SEQ+DIM] — gate at sp[0:S], W2 at sp[S:]
// Channel d receives gate[d*SEQ .. +SEQ) followed by row d of W2.
static void write_ffn_w2_input(DynLayerKernels *dk, const float *gate, const float *W2) {
    IOSurfaceLock(dk->ffnW2->ioIn, 0, NULL);
    float *buf = (float*)IOSurfaceGetBaseAddress(dk->ffnW2->ioIn);
    int sp = SEQ + DIM;  // floats per channel row
    for (int d = 0; d < HIDDEN; d++) {
        memcpy(buf + d*sp, gate + d*SEQ, SEQ*4);
        memcpy(buf + d*sp + SEQ, W2 + d*DIM, DIM*4);
    }
    IOSurfaceUnlock(dk->ffnW2->ioIn, 0, NULL);
}

// ===== Checkpoint =====
// Version-3 layout (fp32, native byte order):
//   CkptHdr
//   per layer: Wq Wk Wv Wo W1 W2 W3 rms_att rms_ffn, then (m, v) Adam moments
//              for each of those nine tensors in the same order
//   rms_final, its Adam m, v
//   embed, its Adam m, v
// ckpt_visit_tensors() is the single definition of this order; save and load
// both go through it so they cannot drift apart.

// Per-tensor I/O callback: transfer exactly n floats between f and buf.
typedef bool (*CkptTensorIOFn)(FILE *f, float *buf, size_t n);

// Adapter so the const-correct writer fits the shared callback signature.
static bool ckpt_tensor_write(FILE *f, float *buf, size_t n) {
    return fwrite_f32_exact(f, buf, n);
}

// Apply io to every checkpoint tensor in file order; stops at the first failure.
static bool ckpt_visit_tensors(FILE *f, CkptTensorIOFn io, LayerWeights *lw, LayerAdam *la,
                               float *rms_final, AdamState *arms_final, float *embed, AdamState *aembed) {
    for (int L = 0; L < NLAYERS; L++) {
        LayerWeights *w = &lw[L];
        LayerAdam *a = &la[L];
        const struct { float *p; size_t n; } t[] = {
            {w->Wq, WQ_SZ}, {w->Wk, WQ_SZ}, {w->Wv, WQ_SZ}, {w->Wo, WO_SZ},
            {w->W1, W1_SZ}, {w->W2, W2_SZ}, {w->W3, W3_SZ},
            {w->rms_att, DIM}, {w->rms_ffn, DIM},
            {a->Wq.m, WQ_SZ}, {a->Wq.v, WQ_SZ},
            {a->Wk.m, WQ_SZ}, {a->Wk.v, WQ_SZ},
            {a->Wv.m, WQ_SZ}, {a->Wv.v, WQ_SZ},
            {a->Wo.m, WO_SZ}, {a->Wo.v, WO_SZ},
            {a->W1.m, W1_SZ}, {a->W1.v, W1_SZ},
            {a->W2.m, W2_SZ}, {a->W2.v, W2_SZ},
            {a->W3.m, W3_SZ}, {a->W3.v, W3_SZ},
            {a->rms_att.m, DIM}, {a->rms_att.v, DIM},
            {a->rms_ffn.m, DIM}, {a->rms_ffn.v, DIM},
        };
        for (size_t i = 0; i < sizeof t / sizeof t[0]; i++)
            if (!io(f, t[i].p, t[i].n)) return false;
    }
    const size_t embed_n = (size_t)VOCAB * DIM;
    return io(f, rms_final, DIM) && io(f, arms_final->m, DIM) && io(f, arms_final->v, DIM) &&
           io(f, embed, embed_n) && io(f, aembed->m, embed_n) && io(f, aembed->v, embed_n);
}

// Save weights, Adam state and training counters to path.
// The data is written to "<path>.tmp" and then rename()d over path. An
// interrupted or failed write therefore leaves the previous checkpoint intact.
// Failures are reported on stdout; the function never aborts.
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
                            double ct, double cw, int cs, int adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    char tmp[1024];
    int len = snprintf(tmp, sizeof tmp, "%s.tmp", path);
    if (len < 0 || (size_t)len >= sizeof tmp) {
        printf(" ERROR: checkpoint path too long: %s\n", path);
        return;
    }
    FILE *f = fopen(tmp, "wb");
    if (!f) { perror(tmp); return; }

    CkptHdr h = {0};
    h.magic = 0x424C5A54; h.version = 3;
    h.step = step; h.total_steps = total_steps;
    h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
    h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
    h.lr = lr; h.loss = loss;
    h.cum_train = ct; h.cum_wall = cw; h.cum_steps = cs; h.adam_t = adam_t;

    bool ok = fwrite(&h, sizeof(h), 1, f) == 1 &&
              ckpt_visit_tensors(f, ckpt_tensor_write, lw, la, rms_final, arms_final, embed, aembed);
    // fclose flushes buffered data, so its result is part of write success.
    if (fclose(f) != 0) ok = false;

    if (!ok || rename(tmp, path) != 0) {
        printf(" ERROR: failed to write checkpoint %s\n", path);
        remove(tmp);
    }
}

// Load a version-3 checkpoint written by save_checkpoint.
// Returns false if the file is missing, has the wrong magic/version, was
// written for a different model shape, or is truncated. The scalar outputs
// (step, lr, ...) are only assigned after the whole body has been read.
// NOTE: on a truncated body the weight/Adam buffers may be partially overwritten.
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
                            double *ct, double *cw, int *cs, int *adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "rb");
    if (!f) return false;

    CkptHdr h;
    if (fread(&h, sizeof(h), 1, f) != 1 || h.magic != 0x424C5A54 || h.version != 3) {
        fclose(f);
        return false;
    }
    // The body is read using this build's compile-time tensor sizes, so the
    // fields that determine the byte layout must match exactly.
    if (h.n_layers != NLAYERS || h.dim != DIM || h.hidden_dim != HIDDEN || h.vocab_size != VOCAB) {
        printf(" ERROR: checkpoint %s was written for a different model shape\n", path);
        fclose(f);
        return false;
    }

    bool ok = ckpt_visit_tensors(f, fread_f32_exact, lw, la, rms_final, arms_final, embed, aembed);
    fclose(f);
    if (!ok) {
        printf(" ERROR: checkpoint %s is truncated\n", path);
        return false;
    }

    *step = h.step; *total_steps = h.total_steps;
    *lr = h.lr; *loss = h.loss;
    *ct = h.cum_train; *cw = h.cum_wall; *cs = h.cum_steps;
    *adam_t = h.adam_t;
    return true;
}

int main(int argc, char *argv[]) {
@autoreleasepool { setbuf(stdout, NULL); ane_init(); mach_timebase_info(&g_tb); int total_steps = 10000; float max_lr = 3e-4f; float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f; int adam_t = 0, start_step = 0; int accum_steps = 10; int warmup_steps = 100; float grad_clip = 1.0f; float min_lr_frac = 0.1f; // min_lr = max_lr * 0.1 bool do_resume = false, from_scratch = false; for (int i=1; i