// train_opt.m — Optimized train_large with:
//   Phase 1: NEON Adam, vectorized embed ops, pre-allocated capture buffers
//   Phase 2: Concurrent dW dispatch, fp16 activation cache
//   Phase 3: Metal GPU for weight gradient computation (dW)
//
// Key perf wins:
//   - Pre-allocated LayerCaptures: eliminates ~132 malloc/free per step
//   - Concurrent dW queue: individual sgemms run in parallel (was serial)
//   - fp16 activation cache: skip fp16→fp32 on main thread for dW-only buffers
//   - Metal GPU dW: ~12ms for all weight gradients vs ~435ms serial CPU
//   - NEON Adam: ~3x faster optimizer step
//   - Vectorized embed: vDSP_mtrans instead of scalar scatter/gather
#include "stories_io.h"
#include "stories_mil.h"
#include "stories_cpu_ops_opt.h"
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>

#define CKPT_PATH_DEFAULT "ane_stories110M_ckpt.bin"
#define MODEL_PATH_DEFAULT "stories110M.bin"
#define DATA_PATH_DEFAULT "tinystories_data00.bin"

// Returns the value of environment variable `env_var` if it is set and
// non-empty, otherwise `default_val`. The caller does not own the result.
static const char *get_path(const char *env_var, const char *default_val) {
    const char *v = getenv(env_var);
    return (v && v[0]) ? v : default_val;
}

// ===== Pre-allocated capture buffers per layer (Phase 1) =====
// Eliminates malloc/free in dispatch blocks
typedef struct {
    // FFN dW captures
    float *dffn;      // [DIM, SEQ]
    float *silu_out;  // [HIDDEN, SEQ]
    float *dh1;       // [HIDDEN, SEQ]
    float *dh3;       // [HIDDEN, SEQ]
    float *x2norm;    // [DIM, SEQ]
    // Attn dW captures
    float *do_buf;    // [DIM, SEQ] (for dWo)
    float *attn_out;  // [DIM, SEQ]
    // QKV dW captures
    float *dq;        // [DIM, SEQ]
    float *dk;        // [DIM, SEQ]
    float *dv;        // [DIM, SEQ]
    float *xnorm;     // [DIM, SEQ]
    // fp16 backward gradient cache (read raw from IOSurface, convert in dispatch block)
    _Float16 *dh1_fp16;  // [HIDDEN, SEQ]
    _Float16 *dh3_fp16;  // [HIDDEN, SEQ]
    _Float16 *dq_fp16;   // [DIM, SEQ]
    _Float16 *dk_fp16;   // [DIM, SEQ]
    _Float16 *dv_fp16;   // [DIM, SEQ]
} LayerCaptures;

// Allocates every LayerCaptures buffer (contents uninitialized).
// Aborts with a message if any allocation fails, so an out-of-memory condition
// is reported here instead of surfacing later as a NULL dereference.
static LayerCaptures layer_captures_alloc(void) {
    const size_t dim_f32 = (size_t)SEQ * DIM * sizeof(float);
    const size_t hid_f32 = (size_t)SEQ * HIDDEN * sizeof(float);
    const size_t dim_f16 = (size_t)SEQ * DIM * sizeof(_Float16);
    const size_t hid_f16 = (size_t)SEQ * HIDDEN * sizeof(_Float16);
    LayerCaptures c;
    c.dffn     = (float*)malloc(dim_f32);
    c.silu_out = (float*)malloc(hid_f32);
    c.dh1      = (float*)malloc(hid_f32);
    c.dh3      = (float*)malloc(hid_f32);
    c.x2norm   = (float*)malloc(dim_f32);
    c.do_buf   = (float*)malloc(dim_f32);
    c.attn_out = (float*)malloc(dim_f32);
    c.dq       = (float*)malloc(dim_f32);
    c.dk       = (float*)malloc(dim_f32);
    c.dv       = (float*)malloc(dim_f32);
    c.xnorm    = (float*)malloc(dim_f32);
    c.dh1_fp16 = (_Float16*)malloc(hid_f16);
    c.dh3_fp16 = (_Float16*)malloc(hid_f16);
    c.dq_fp16  = (_Float16*)malloc(dim_f16);
    c.dk_fp16  = (_Float16*)malloc(dim_f16);
    c.dv_fp16  = (_Float16*)malloc(dim_f16);
    if (!c.dffn || !c.silu_out || !c.dh1 || !c.dh3 || !c.x2norm ||
        !c.do_buf || !c.attn_out || !c.dq || !c.dk || !c.dv || !c.xnorm ||
        !c.dh1_fp16 || !c.dh3_fp16 || !c.dq_fp16 || !c.dk_fp16 || !c.dv_fp16) {
        fprintf(stderr, "layer_captures_alloc: out of memory\n");
        abort();
    }
    return c;
}

// Frees every buffer in `c` and resets the pointers to NULL, so a repeated
// call is harmless.
static void layer_captures_free(LayerCaptures *c) {
    free(c->dffn);     c->dffn = NULL;
    free(c->silu_out); c->silu_out = NULL;
    free(c->dh1);      c->dh1 = NULL;
    free(c->dh3);      c->dh3 = NULL;
    free(c->x2norm);   c->x2norm = NULL;
    free(c->do_buf);   c->do_buf = NULL;
    free(c->attn_out); c->attn_out = NULL;
    free(c->dq);       c->dq = NULL;
    free(c->dk);       c->dk = NULL;
    free(c->dv);       c->dv = NULL;
    free(c->xnorm);    c->xnorm = NULL;
    free(c->dh1_fp16); c->dh1_fp16 = NULL;
    free(c->dh3_fp16); c->dh3_fp16 = NULL;
    free(c->dq_fp16);  c->dq_fp16 = NULL;
    free(c->dk_fp16);  c->dk_fp16 = NULL;
    free(c->dv_fp16);  c->dv_fp16 = NULL;
}

// ===== fp16 activation cache (Phase 2) =====
// Store activations that are only used for dW as fp16 (skip
// main-thread conversion).
// Each field is a malloc'd fp16 buffer; the shape comment gives its element count.
typedef struct {
    _Float16 *xnorm_fp16;     // [DIM, SEQ]
    _Float16 *attn_out_fp16;  // [DIM, SEQ]
    _Float16 *x2norm_fp16;    // [DIM, SEQ]
    _Float16 *silu_out_fp16;  // [HIDDEN, SEQ]
} LayerFP16Cache;

// Allocates every LayerFP16Cache buffer (contents uninitialized).
// Aborts with a message if any allocation fails, so an out-of-memory condition
// is reported here instead of surfacing later as a NULL dereference.
static LayerFP16Cache layer_fp16_cache_alloc(void) {
    const size_t dim_bytes = (size_t)SEQ * DIM * sizeof(_Float16);
    const size_t hidden_bytes = (size_t)SEQ * HIDDEN * sizeof(_Float16);
    LayerFP16Cache c;
    c.xnorm_fp16    = (_Float16*)malloc(dim_bytes);
    c.attn_out_fp16 = (_Float16*)malloc(dim_bytes);
    c.x2norm_fp16   = (_Float16*)malloc(dim_bytes);
    c.silu_out_fp16 = (_Float16*)malloc(hidden_bytes);
    if (!c.xnorm_fp16 || !c.attn_out_fp16 || !c.x2norm_fp16 || !c.silu_out_fp16) {
        fprintf(stderr, "layer_fp16_cache_alloc: out of memory\n");
        abort();
    }
    return c;
}

// Frees every buffer in `c` and resets the pointers to NULL, so a repeated
// call is harmless.
static void layer_fp16_cache_free(LayerFP16Cache *c) {
    free(c->xnorm_fp16);    c->xnorm_fp16 = NULL;
    free(c->attn_out_fp16); c->attn_out_fp16 = NULL;
    free(c->x2norm_fp16);   c->x2norm_fp16 = NULL;
    free(c->silu_out_fp16); c->silu_out_fp16 = NULL;
}

// ===== Metal GPU dW context (Phase 3) =====

// Weight matrix indices for Metal buffers (second index of dW_bufs).
enum { MW_Q=0, MW_K, MW_V, MW_O, MW_1, MW_2, MW_3, MW_RMSA, MW_RMSF, MW_COUNT };

typedef struct {
    id<MTLDevice> device;
    id<MTLCommandQueue> queue;
    // Shared gradient accumulator buffers (one per weight matrix per layer),
    // indexed [layer][MW_*]: Wq,Wk,Wv,Wo,W1,W2,W3,rms_att,rms_ffn
    id<MTLBuffer> dW_bufs[NLAYERS][MW_COUNT];
    id<MTLCommandBuffer> lastCmdBuf;  // Track last submitted buffer for sync
} MetalDWContext;

// Byte size of the fp32 dW accumulator for weight index `w` (one of MW_*).
// Wk and Wv share Wq's size, matching how load_pretrained reads them.
// Returns 0 for an out-of-range index.
static size_t metal_dw_buf_bytes(int w) {
    switch (w) {
        case MW_Q: case MW_K: case MW_V: return (size_t)WQ_SZ * sizeof(float);
        case MW_O:                       return (size_t)WO_SZ * sizeof(float);
        case MW_1:                       return (size_t)W1_SZ * sizeof(float);
        case MW_2:                       return (size_t)W2_SZ * sizeof(float);
        case MW_3:                       return (size_t)W3_SZ * sizeof(float);
        case MW_RMSA: case MW_RMSF:      return (size_t)DIM * sizeof(float);
        default:                         return 0;
    }
}

// Creates the default Metal device, a command queue, and one shared-storage
// gradient buffer per (layer, weight). Prints a diagnostic and returns false
// on the first failure. Buffers are allocated but not zeroed here; see
// metal_dw_zero.
static bool metal_dw_init(MetalDWContext *ctx) {
    ctx->device = MTLCreateSystemDefaultDevice();
    if (!ctx->device) { printf("[Metal] No GPU device\n"); return false; }
    ctx->queue = [ctx->device newCommandQueue];
    if (!ctx->queue) { printf("[Metal] No command queue\n"); return false; }

    // Allocate shared-mode gradient accumulator buffers
    for (int L = 0; L < NLAYERS; L++) {
        for (int w = 0; w < MW_COUNT; w++) {
            ctx->dW_bufs[L][w] = [ctx->device newBufferWithLength:metal_dw_buf_bytes(w)
                                                          options:MTLResourceStorageModeShared];
            if (!ctx->dW_bufs[L][w]) {
                printf("[Metal] Buffer alloc failed L=%d w=%d\n", L, w);
                return false;
            }
        }
    }
    printf("[Metal] GPU: %s\n", [[ctx->device name] UTF8String]);
    return true;
}

// Zeroes every dW accumulator buffer from the CPU via memset on the shared
// contents pointer.
// NOTE: this is a plain CPU write; the caller must ensure no GPU command buffer
// that accumulates into dW_bufs is still in flight (e.g. wait on lastCmdBuf).
static void metal_dw_zero(MetalDWContext *ctx) {
    for (int L = 0; L < NLAYERS; L++) {
        for (int w = 0; w < MW_COUNT; w++) {
            memset([ctx->dW_bufs[L][w] contents], 0, metal_dw_buf_bytes(w));
        }
    }
}

// Wraps `bytes` bytes of host memory at `data` in a shared-storage MTLBuffer
// for read-only GPU use.
// newBufferWithBytesNoCopy requires a page-aligned pointer and a page-multiple
// length. When both hold, the memory is wrapped zero-copy with no deallocator,
// so `data` must stay alive until the command buffer completes. Otherwise, or
// if the wrap fails, the data is copied into a new buffer, which is a snapshot
// taken at encode time.
static id<MTLBuffer> metal_wrap_host_floats(id<MTLDevice> device, const float *data, size_t bytes) {
    const NSUInteger page = NSPageSize();
    if (((uintptr_t)data % page) == 0 && (bytes % page) == 0) {
        id<MTLBuffer> buf = [device newBufferWithBytesNoCopy:(void*)data
                                                      length:bytes
                                                     options:MTLResourceStorageModeShared
                                                 deallocator:nil];
        if (buf) return buf;
    }
    return [device newBufferWithBytes:data length:bytes options:MTLResourceStorageModeShared];
}

// Encode a single dW sgemm to a Metal command buffer using MPS:
//   C[M,N] += A[M,K] · (B[N,K])^T    (alpha = 1, beta = 1, so C accumulates)
// A and B are row-major host arrays; C is an existing row-major MTLBuffer.
// Only encodes; the caller commits `cmdBuf` and must keep a_data/b_data alive
// until it completes, in case they were wrapped zero-copy.
static void metal_encode_dw_sgemm(id<MTLCommandBuffer> cmdBuf, id<MTLDevice> device,
                                  const float *a_data, int M, int K,
                                  const float *b_data, int N,
                                  id<MTLBuffer> c_buf) {
    id<MTLBuffer> aBuf = metal_wrap_host_floats(device, a_data, (size_t)M * K * sizeof(float));
    id<MTLBuffer> bBuf = metal_wrap_host_floats(device, b_data, (size_t)N * K * sizeof(float));

    // MPSMatrix descriptors are row-major: A is [M, K], B is [N, K], C is [M, N].
    MPSMatrixDescriptor *descA = [MPSMatrixDescriptor matrixDescriptorWithRows:M columns:K
        rowBytes:K * sizeof(float) dataType:MPSDataTypeFloat32];
    MPSMatrixDescriptor *descB = [MPSMatrixDescriptor matrixDescriptorWithRows:N columns:K
        rowBytes:K * sizeof(float) dataType:MPSDataTypeFloat32];
    MPSMatrixDescriptor *descC = [MPSMatrixDescriptor matrixDescriptorWithRows:M columns:N
        rowBytes:N * sizeof(float) dataType:MPSDataTypeFloat32];

    MPSMatrix *matA = [[MPSMatrix alloc] initWithBuffer:aBuf descriptor:descA];
    MPSMatrix *matB = [[MPSMatrix alloc] initWithBuffer:bBuf descriptor:descB];
    MPSMatrix *matC = [[MPSMatrix alloc] initWithBuffer:c_buf descriptor:descC];

    // transposeRight:YES turns B[N,K] into the [K,N] operand; beta:1.0 accumulates into C.
    MPSMatrixMultiplication *mm = [[MPSMatrixMultiplication alloc] initWithDevice:device
        transposeLeft:NO transposeRight:YES
        resultRows:M resultColumns:N interiorColumns:K
        alpha:1.0 beta:1.0];
    [mm encodeToCommandBuffer:cmdBuf leftMatrix:matA rightMatrix:matB resultMatrix:matC];
}

// ===== Weight loading from llama2.c format =====
// Reads a llama2.c model file into `lw` (NLAYERS entries), `rms_final` [DIM] and
// `embed` [VOCAB*DIM]. The file layout is a Llama2Config header followed by raw
// fp32 tensors grouped by kind across all layers, in the order read below.
// Nothing after rms_final is read; a separate classifier (vocab_size < 0) is
// only reported, not loaded.
// Returns false (after printing why) if the file cannot be opened, its config
// does not match this build's DIM/HIDDEN/NLAYERS/HEADS/VOCAB, or it is truncated.
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
    FILE *f = fopen(path, "rb");
    if (!f) { printf("Cannot open %s\n", path); return false; }

    Llama2Config cfg;
    if (fread(&cfg, sizeof(cfg), 1, f) != 1) {
        printf(" ERROR: %s is too short to contain a model config\n", path);
        fclose(f);
        return false;
    }
    printf(" Model config: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
           cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);

    int V = abs(cfg.vocab_size);
    bool shared = cfg.vocab_size > 0;
    // Every tensor read below is sized from the compiled-in constants, so any
    // mismatch (including vocab, which sizes the first tensor) would misalign
    // all subsequent reads.
    if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS ||
        cfg.n_heads != HEADS || V != VOCAB) {
        printf(" ERROR: Config mismatch! Expected dim=%d hidden=%d layers=%d heads=%d vocab=%d\n",
               DIM, HIDDEN, NLAYERS, HEADS, VOCAB);
        fclose(f);
        return false;
    }

    // Reads exactly `n` floats into `dst`; the first short read marks the load
    // as failed and makes subsequent reads no-ops.
    __block bool ok = true;
    void (^read_floats)(float *, size_t) = ^(float *dst, size_t n) {
        if (ok && fread(dst, sizeof(float), n, f) != n) ok = false;
    };

    read_floats(embed, (size_t)VOCAB * DIM);
    for (int L = 0; L < NLAYERS; L++) read_floats(lw[L].rms_att, DIM);
    for (int L = 0; L < NLAYERS; L++) read_floats(lw[L].Wq, WQ_SZ);
    for (int L = 0; L < NLAYERS; L++) read_floats(lw[L].Wk, WQ_SZ);
    for (int L = 0; L < NLAYERS; L++) read_floats(lw[L].Wv, WQ_SZ);
    for (int L = 0; L < NLAYERS; L++) read_floats(lw[L].Wo, WO_SZ);
    for (int L = 0; L < NLAYERS; L++) read_floats(lw[L].rms_ffn, DIM);
    for (int L = 0; L < NLAYERS; L++) read_floats(lw[L].W1, W1_SZ);
    for (int L = 0; L < NLAYERS; L++) read_floats(lw[L].W2, W2_SZ);
    for (int L = 0; L < NLAYERS; L++) read_floats(lw[L].W3, W3_SZ);
    read_floats(rms_final, DIM);
    fclose(f);

    if (!ok) {
        printf(" ERROR: %s is truncated\n", path);
        return false;
    }
    printf(" Loaded pretrained weights (%s)\n", shared ? "shared embed/cls" : "separate cls");
    return true;
}

// ===== Compile one layer's kernels =====
// Compiles the five per-layer ANE kernels with this layer's weights baked in as
// named weight blobs: forward attention (fwdAttn), forward FFN (fwdFFN), FFN
// backward (ffnBwd), SDPA backward part 1 (sdpaBwd1) and QKV backward (qkvBwd).
// Backward kernels get transposed blobs via build_blob_t. The last two
// arguments to compile_kern_mil_w have the form channels*SEQ*2, presumably
// fp16 input/output byte sizes (TODO confirm against compile_kern_mil_w).
// Returns true only if all five kernels compiled. On failure, kernels that did
// compile are left in `lk`; NOTE(review): the caller presumably releases them
// via free_layer_kernels.
static bool compile_layer_kernels(LayerKernels *lk, LayerWeights *w) {
    lk->fwdAttn = compile_kern_mil_w(gen_sdpa_fwd_taps(), (@{
        @"@model_path/weights/rms1.bin": @{@"offset":@0, @"data":build_blob(w->rms_att,1,DIM)},
        @"@model_path/weights/wq.bin":   @{@"offset":@0, @"data":build_blob(w->Wq,DIM,DIM)},
        @"@model_path/weights/wk.bin":   @{@"offset":@0, @"data":build_blob(w->Wk,DIM,DIM)},
        @"@model_path/weights/wv.bin":   @{@"offset":@0, @"data":build_blob(w->Wv,DIM,DIM)},
        @"@model_path/weights/wo.bin":   @{@"offset":@0, @"data":build_blob(w->Wo,DIM,DIM)},
        @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
    }), DIM*SEQ*2, 6*DIM*SEQ*2);

    lk->fwdFFN = compile_kern_mil_w(gen_ffn_fwd_taps(), (@{
        @"@model_path/weights/rms2.bin": @{@"offset":@0, @"data":build_blob(w->rms_ffn,1,DIM)},
        @"@model_path/weights/w1.bin":   @{@"offset":@0, @"data":build_blob(w->W1,HIDDEN,DIM)},
        @"@model_path/weights/w3.bin":   @{@"offset":@0, @"data":build_blob(w->W3,HIDDEN,DIM)},
        @"@model_path/weights/w2.bin":   @{@"offset":@0, @"data":build_blob(w->W2,DIM,HIDDEN)},
    }), DIM*SEQ*2, (2*DIM+3*HIDDEN)*SEQ*2);

    lk->ffnBwd = compile_kern_mil_w(gen_ffn_bwd(), (@{
        @"@model_path/weights/w2t.bin": @{@"offset":@0, @"data":build_blob_t(w->W2,DIM,HIDDEN)},
        @"@model_path/weights/w1t.bin": @{@"offset":@0, @"data":build_blob_t(w->W1,HIDDEN,DIM)},
        @"@model_path/weights/w3t.bin": @{@"offset":@0, @"data":build_blob_t(w->W3,HIDDEN,DIM)},
    }), (DIM+2*HIDDEN)*SEQ*2, (DIM+2*HIDDEN)*SEQ*2);

    lk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1(), (@{
        @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
        @"@model_path/weights/wot.bin":  @{@"offset":@0, @"data":build_blob_t(w->Wo,DIM,DIM)},
    }), 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);

    lk->qkvBwd = compile_kern_mil_w(gen_qkvb(), (@{
        @"@model_path/weights/wqt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wq,DIM,DIM)},
        @"@model_path/weights/wkt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wk,DIM,DIM)},
        @"@model_path/weights/wvt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wv,DIM,DIM)},
    }), 3*DIM*SEQ*2, DIM*SEQ*2);

    return lk->fwdAttn && lk->fwdFFN && lk->ffnBwd && lk->sdpaBwd1 && lk->qkvBwd;
}

// Compiles the weight-free SDPA backward part 2 kernel (no weight blobs).
static Kern *compile_sdpa_bwd2(void) {
    return compile_kern_mil_w(gen_sdpa_bwd2(), @{}, (2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
}

// Releases all five per-layer kernels and clears the pointers.
// free_kern is called on every slot unconditionally — assumes it tolerates
// NULL (TODO confirm).
static void free_layer_kernels(LayerKernels *lk) {
    free_kern(lk->fwdAttn);
    free_kern(lk->fwdFFN);
    free_kern(lk->ffnBwd);
    free_kern(lk->sdpaBwd1);
    free_kern(lk->qkvBwd);
    lk->fwdAttn = lk->fwdFFN = lk->ffnBwd = lk->sdpaBwd1 = lk->qkvBwd = NULL;
}

// ===== Checkpoint save/load =====
// Writes a version-2 checkpoint to `path`, overwriting any existing file. The
// file holds a CkptHdr (training counters and model dimensions), then for each
// layer its 9 weight tensors followed by their Adam m/v moments, then
// rms_final + m/v and embed + m/v, all as raw fp32. Open and write failures
// are reported on stdout; there is no return value.
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
                            double cc, double ct, double cw, int cs, int cb, int adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "wb");
    if (!f) {
        printf(" ERROR: cannot open checkpoint %s for writing\n", path);
        return;
    }

    CkptHdr h = {0};
    h.magic = 0x424C5A54; h.version = 2;
    h.step = step; h.total_steps = total_steps;
    h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
    h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
    h.lr = lr; h.loss = loss;
    h.cum_compile = cc; h.cum_train = ct; h.cum_wall = cw;
    h.cum_steps = cs; h.cum_batches = cb; h.adam_t = adam_t;
    fwrite(&h, sizeof(h), 1, f);

    for (int L = 0; L < NLAYERS; L++) {
        fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f); fwrite(lw[L].Wv,4,WQ_SZ,f);
        fwrite(lw[L].Wo,4,WO_SZ,f);
        fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f);
        fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f);
        fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f);
        fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f);
        fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f);
        fwrite(la[L].Wo.m,4,WO_SZ,f); fwrite(la[L].Wo.v,4,WO_SZ,f);
        fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f);
        fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f);
        fwrite(la[L].W3.m,4,W3_SZ,f); fwrite(la[L].W3.v,4,W3_SZ,f);
        fwrite(la[L].rms_att.m,4,DIM,f); fwrite(la[L].rms_att.v,4,DIM,f);
        fwrite(la[L].rms_ffn.m,4,DIM,f); fwrite(la[L].rms_ffn.v,4,DIM,f);
    }
    fwrite(rms_final,4,DIM,f); fwrite(arms_final->m,4,DIM,f); fwrite(arms_final->v,4,DIM,f);
    fwrite(embed,4,VOCAB*DIM,f); fwrite(aembed->m,4,VOCAB*DIM,f); fwrite(aembed->v,4,VOCAB*DIM,f);

    // The stream error flag is sticky, so one check covers every fwrite above;
    // fclose can still fail while flushing buffered data.
    bool failed = ferror(f) != 0;
    if (fclose(f) != 0) failed = true;
    if (failed) printf(" ERROR: failed to write checkpoint %s\n", path);
}

// Loads a checkpoint written by save_checkpoint (magic 0x424C5A54, version 2),
// restoring the training counters, per-layer weights and Adam moments, the
// final RMSNorm weights and the embedding table (each with Adam m/v).
// Returns false, without touching any output, if the file cannot be opened,
// the header cannot be read, or the header's magic, version or model
// dimensions do not match this build.
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
                            double *cc, double *ct, double *cw, int *cs, int *cb, int *adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "rb");
    if (!f) return false;
    CkptHdr h;
    if (fread(&h, sizeof(h), 1, f) != 1 || h.magic != 0x424C5A54 || h.version != 2) {
        fclose(f);
        return false;
    }
    // Tensor sizes below come from the compiled-in constants, so a checkpoint
    // from a differently-shaped model must be rejected before reading any data.
    if (h.n_layers != NLAYERS || h.dim != DIM || h.hidden_dim != HIDDEN || h.vocab_size != VOCAB) {
        printf(" ERROR: checkpoint %s model dimensions do not match this build\n", path);
        fclose(f);
        return false;
    }
    *step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
    *cc = h.cum_compile; *ct = h.cum_train; *cw = h.cum_wall;
    *cs = h.cum_steps; *cb = h.cum_batches; *adam_t = h.adam_t;
    for (int L = 0; L < NLAYERS; L++) {
        fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WQ_SZ,f); fread(lw[L].Wv,4,WQ_SZ,f);
        fread(lw[L].Wo,4,WO_SZ,f);
        fread(lw[L].W1,4,W1_SZ,f); fread(lw[L].W2,4,W2_SZ,f); fread(lw[L].W3,4,W3_SZ,f);
        fread(lw[L].rms_att,4,DIM,f); fread(lw[L].rms_ffn,4,DIM,f);
        fread(la[L].Wq.m,4,WQ_SZ,f); fread(la[L].Wq.v,4,WQ_SZ,f);
        fread(la[L].Wk.m,4,WQ_SZ,f); fread(la[L].Wk.v,4,WQ_SZ,f);
        fread(la[L].Wv.m,4,WQ_SZ,f); fread(la[L].Wv.v,4,WQ_SZ,f);
        fread(la[L].Wo.m,4,WO_SZ,f); fread(la[L].Wo.v,4,WO_SZ,f);
        fread(la[L].W1.m,4,W1_SZ,f); fread(la[L].W1.v,4,W1_SZ,f);
        fread(la[L].W2.m,4,W2_SZ,f); fread(la[L].W2.v,4,W2_SZ,f);
        fread(la[L].W3.m,4,W3_SZ,f); fread(la[L].W3.v,4,W3_SZ,f);
        fread(la[L].rms_att.m,4,DIM,f); fread(la[L].rms_att.v,4,DIM,f);
        fread(la[L].rms_ffn.m,4,DIM,f); fread(la[L].rms_ffn.v,4,DIM,f);
    }
    fread(rms_final,4,DIM,f); fread(arms_final->m,4,DIM,f); fread(arms_final->v,4,DIM,f);
    fread(embed,4,VOCAB*DIM,f);
    fread(aembed->m,4,VOCAB*DIM,f);
fread(aembed->v,4,VOCAB*DIM,f); fclose(f); return true; } // ===== Main ===== int main(int argc, char *argv[]) { @autoreleasepool { setbuf(stdout, NULL); // Phase 2: Limit BLAS thread count to prevent oversubscription with concurrent dispatch setenv("VECLIB_MAXIMUM_THREADS", "2", 1); ane_init(); mach_timebase_info(&g_tb); int total_steps = 10000; float lr = 3e-4f; float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f; int adam_t = 0, start_step = 0; // Parse args const char *model_path = get_path("ANE_MODEL_PATH", MODEL_PATH_DEFAULT); const char *ckpt_path = get_path("ANE_CKPT_PATH", CKPT_PATH_DEFAULT); const char *data_path = get_path("ANE_DATA_PATH", DATA_PATH_DEFAULT); bool do_resume = false; bool use_metal = false; // default off: Metal dW contends with ANE for memory bandwidth int pos = 0; for (int i=1; i