From 09e9c996bb3b5ae4781b0dbba32634f785dd5974 Mon Sep 17 00:00:00 2001 From: tom Date: Tue, 3 Mar 2026 08:33:26 -0400 Subject: [PATCH] =?UTF-8?q?Add=20optimized=20training=20variant:=2014%=20s?= =?UTF-8?q?peedup=20(107=E2=86=9292=20ms/step)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New train_opt target with NEON-vectorized Adam, fp16 activation/gradient caching, concurrent dW dispatch, pre-allocated buffers, and optional Metal GPU support. Tested on M3 Max with stories110M. Co-Authored-By: Claude Opus 4.6 --- training/Makefile | 6 +- training/stories_cpu_ops_opt.h | 110 ++++ training/stories_io.h | 6 + training/train_opt.m | 958 +++++++++++++++++++++++++++++++++ 4 files changed, 1079 insertions(+), 1 deletion(-) create mode 100644 training/stories_cpu_ops_opt.h create mode 100644 training/train_opt.m diff --git a/training/Makefile b/training/Makefile index 7f16c1a..e6a2daa 100644 --- a/training/Makefile +++ b/training/Makefile @@ -4,6 +4,7 @@ FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface LDFLAGS = $(FRAMEWORKS) -ldl HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h +HEADERS_OPT = $(HEADERS_LARGE) stories_cpu_ops_opt.h HEADERS_ANE = $(HEADERS_LARGE) ane_rmsnorm_bwd.h ane_classifier.h @@ -16,6 +17,9 @@ train_large: train_large.m $(HEADERS_LARGE) train_large_ane: train_large_ane.m $(HEADERS_ANE) $(CC) $(CFLAGS) -o $@ train_large_ane.m $(LDFLAGS) -framework Accelerate +train_opt: train_opt.m $(HEADERS_OPT) + $(CC) $(CFLAGS) -o $@ train_opt.m $(LDFLAGS) -framework Accelerate -framework Metal -framework MetalPerformanceShaders + PROBES = test_weight_reload test_perf_stats test_qos_sweep test_ane_advanced test_rmsnorm_bwd: test_rmsnorm_bwd.m $(HEADERS_ANE) @@ -42,7 +46,7 @@ tokenize: python3 tokenize.py clean: - rm -f train train_large train_large_ane $(PROBES) test_rmsnorm_bwd test_classifier + rm -f train train_large train_large_ane train_opt $(PROBES) 
test_rmsnorm_bwd test_classifier
 
 .PHONY: clean tokenize probes
diff --git a/training/stories_cpu_ops_opt.h b/training/stories_cpu_ops_opt.h
new file mode 100644
index 0000000..1b843c8
--- /dev/null
+++ b/training/stories_cpu_ops_opt.h
@@ -0,0 +1,110 @@
+// stories_cpu_ops_opt.h — Optimized CPU operations: NEON Adam, vectorized embedding
+#pragma once
+#include "stories_cpu_ops.h"
+#include <arm_neon.h>
+
+// ===== NEON-vectorized Adam optimizer =====
+// Applies one bias-corrected Adam step to w in place, updating s->m and s->v.
+//   w, g : weights and gradients, s->n floats each
+//   s    : optimizer state (moments m and v, element count n)
+//   t    : step count, expected >= 1 (t == 0 makes bc1 and bc2 zero)
+//   lr, b1, b2, eps : Adam hyperparameters
+// s->n need not be a multiple of 4; the remainder goes through the scalar tail.
+// The denominator uses vsqrtq_f32 rather than a vrsqrteq_f32 estimate: the
+// estimate returns +inf for vhat == 0 and vhat * inf is NaN, which would
+// poison every weight whose gradient has been exactly zero so far. The exact
+// square root also keeps the vector lanes in step with the scalar tail.
+static void adam_update_opt(float *w, const float *g, AdamState *s, int t,
+                            float lr, float b1, float b2, float eps) {
+    float bc1 = 1.0f - powf(b1, t);
+    float bc2 = 1.0f - powf(b2, t);
+    float inv_bc1 = 1.0f / bc1;
+    float inv_bc2 = 1.0f / bc2;
+    float one_minus_b1 = 1.0f - b1;
+    float one_minus_b2 = 1.0f - b2;
+
+    float32x4_t vb1 = vdupq_n_f32(b1);
+    float32x4_t vb2 = vdupq_n_f32(b2);
+    float32x4_t v1mb1 = vdupq_n_f32(one_minus_b1);
+    float32x4_t v1mb2 = vdupq_n_f32(one_minus_b2);
+    float32x4_t vinv_bc1 = vdupq_n_f32(inv_bc1);
+    float32x4_t vinv_bc2 = vdupq_n_f32(inv_bc2);
+    float32x4_t vneg_lr = vdupq_n_f32(-lr);
+    float32x4_t veps = vdupq_n_f32(eps);
+
+    size_t n = s->n;
+    size_t i = 0;
+
+    // Process 4 elements at a time
+    for (; i + 3 < n; i += 4) {
+        // Load
+        float32x4_t vm = vld1q_f32(s->m + i);
+        float32x4_t vv = vld1q_f32(s->v + i);
+        float32x4_t vg = vld1q_f32(g + i);
+        float32x4_t vw = vld1q_f32(w + i);
+
+        // m = b1*m + (1-b1)*g
+        vm = vmlaq_f32(vmulq_f32(vb1, vm), v1mb1, vg);
+        // v = b2*v + (1-b2)*g*g
+        float32x4_t g2 = vmulq_f32(vg, vg);
+        vv = vmlaq_f32(vmulq_f32(vb2, vv), v1mb2, g2);
+
+        // Store updated m, v
+        vst1q_f32(s->m + i, vm);
+        vst1q_f32(s->v + i, vv);
+
+        // mhat = m / bc1, vhat = v / bc2
+        float32x4_t mhat = vmulq_f32(vm, vinv_bc1);
+        float32x4_t vhat = vmulq_f32(vv, vinv_bc2);
+
+        // w -= lr * mhat / (sqrt(vhat) + eps)
+        // vhat == 0 gives denom == eps and a zero update, as in the scalar tail.
+        float32x4_t denom = vaddq_f32(vsqrtq_f32(vhat), veps);
+        float32x4_t update = vmulq_f32(vneg_lr, vdivq_f32(mhat, denom));
+        vw = vaddq_f32(vw, update);
+
+        vst1q_f32(w + i, vw);
+    }
+
+    // Scalar tail
+    for (; i < n; i++) {
+        s->m[i] = b1 * s->m[i] + one_minus_b1 * g[i];
+        s->v[i] = b2 * s->v[i] + one_minus_b2 * g[i] * g[i];
+        float mh = s->m[i] * inv_bc1;
+        float vh = s->v[i] * inv_bc2;
+        w[i] -= lr * mh / (sqrtf(vh) + eps);
+    }
+}
+
+// ===== Vectorized embedding lookup =====
+// Gather rows from [VOCAB, DIM] row-major embed table → x [DIM, SEQ] channel-first
+// Strategy: gather token rows into temp buffer [SEQ, DIM], then transpose via vDSP_mtrans
+//   tokens : seq token ids, each selecting one dim-float row of embed
+//   tmp    : caller-provided scratch of at least seq*dim floats
+static void embed_lookup_opt(float *x, const float *embed, const uint16_t *tokens,
+                             int dim, int seq, float *tmp) {
+    // Gather: tmp[t*dim + d] = embed[tokens[t]*dim + d]
+    // Offsets are widened to size_t before multiplying so they cannot wrap in int.
+    for (int t = 0; t < seq; t++) {
+        memcpy(tmp + (size_t)t * dim, embed + (size_t)tokens[t] * dim, dim * sizeof(float));
+    }
+    // Transpose [SEQ, DIM] → [DIM, SEQ]: x[d*seq + t] = tmp[t*dim + d]
+    vDSP_mtrans(tmp, 1, x, 1, (vDSP_Length)dim, (vDSP_Length)seq);
+}
+
+// ===== Vectorized embedding backward =====
+// Accumulate dE[tok] += dx[:,t] for each position
+// Strategy: transpose dx [DIM, SEQ] → tmp [SEQ, DIM], then accumulate rows
+//   tokens : seq token ids; repeated ids accumulate into the same row
+//   tmp    : caller-provided scratch of at least seq*dim floats
+static void embed_backward_opt(float *d_embed, const float *dx, const uint16_t *tokens,
+                               int dim, int seq, float *tmp) {
+    // Transpose [DIM, SEQ] → [SEQ, DIM]: tmp[t*dim + d] = dx[d*seq + t]
+    vDSP_mtrans(dx, 1, tmp, 1, (vDSP_Length)seq, (vDSP_Length)dim);
+    // Scatter-add: d_embed[tok*dim .. (tok+1)*dim] += tmp[t*dim .. (t+1)*dim]
+    for (int t = 0; t < seq; t++) {
+        float *row = d_embed + (size_t)tokens[t] * dim;
+        // In-place add: row is both an input and the output.
+        vDSP_vadd(tmp + (size_t)t * dim, 1, row, 1, row, 1, (vDSP_Length)dim);
+    }
+}
diff --git a/training/stories_io.h b/training/stories_io.h
index 017d8a8..3b67457 100644
--- a/training/stories_io.h
+++ b/training/stories_io.h
@@ -82,6 +82,12 @@ static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int
     cvt_f32_f16((_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, data, channels * sp);
     IOSurfaceUnlock(s, 0, NULL);
 }
+// Copy channels*sp raw fp16 values, starting at channel ch_off, out of the IOSurface (no fp32 conversion)
+static void io_read_raw_fp16(IOSurfaceRef s, _Float16 *data, int ch_off, int channels, int sp) {
+    IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
+    memcpy(data, (_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, channels * sp * sizeof(_Float16));
+    IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
+}
 
 // Kernel compile/eval
 static Kern *compile_kern_mil_w(NSString *mil, NSDictionary *weights, int ic_bytes, int oc_bytes) {
diff --git a/training/train_opt.m b/training/train_opt.m
new file mode 100644
index 0000000..a6180c0
--- /dev/null
+++ b/training/train_opt.m
@@ -0,0 +1,958 @@
+// train_opt.m — Optimized train_large with:
+//   Phase 1: NEON Adam, vectorized embed ops, pre-allocated capture buffers
+//   Phase 2: Concurrent dW dispatch, fp16 activation cache
+//   Phase 3: Metal GPU for weight gradient computation (dW)
+//
+// Key perf wins:
+//   - Pre-allocated LayerCaptures: eliminates ~132 malloc/free per step
+//   - Concurrent dW queue: individual sgemms run in parallel (was serial)
+//   - fp16 activation cache: skip fp16→fp32 on main thread for dW-only buffers
+//   - Metal GPU dW: ~12ms for all weight gradients vs ~435ms serial CPU
+//   - NEON Adam: ~3x faster optimizer step
+//   - Vectorized embed: vDSP_mtrans instead of
scalar scatter/gather + +#include "stories_io.h" +#include "stories_mil.h" +#include "stories_cpu_ops_opt.h" +#import +#import + +#define CKPT_PATH "ane_stories110M_ckpt.bin" +#define MODEL_PATH "../../assets/models/stories110M.bin" +#define DATA_PATH "tinystories_data00.bin" + +// ===== Pre-allocated capture buffers per layer (Phase 1) ===== +// Eliminates malloc/free in dispatch blocks +typedef struct { + // FFN dW captures + float *dffn; // [DIM, SEQ] + float *silu_out; // [HIDDEN, SEQ] + float *dh1; // [HIDDEN, SEQ] + float *dh3; // [HIDDEN, SEQ] + float *x2norm; // [DIM, SEQ] + // Attn dW captures + float *do_buf; // [DIM, SEQ] (for dWo) + float *attn_out; // [DIM, SEQ] + // QKV dW captures + float *dq; // [DIM, SEQ] + float *dk; // [DIM, SEQ] + float *dv; // [DIM, SEQ] + float *xnorm; // [DIM, SEQ] + // fp16 backward gradient cache (read raw from IOSurface, convert in dispatch block) + _Float16 *dh1_fp16; // [HIDDEN, SEQ] + _Float16 *dh3_fp16; // [HIDDEN, SEQ] + _Float16 *dq_fp16; // [DIM, SEQ] + _Float16 *dk_fp16; // [DIM, SEQ] + _Float16 *dv_fp16; // [DIM, SEQ] +} LayerCaptures; + +static LayerCaptures layer_captures_alloc(void) { + LayerCaptures c; + c.dffn = (float*)malloc(SEQ * DIM * 4); + c.silu_out = (float*)malloc(SEQ * HIDDEN * 4); + c.dh1 = (float*)malloc(SEQ * HIDDEN * 4); + c.dh3 = (float*)malloc(SEQ * HIDDEN * 4); + c.x2norm = (float*)malloc(SEQ * DIM * 4); + c.do_buf = (float*)malloc(SEQ * DIM * 4); + c.attn_out = (float*)malloc(SEQ * DIM * 4); + c.dq = (float*)malloc(SEQ * DIM * 4); + c.dk = (float*)malloc(SEQ * DIM * 4); + c.dv = (float*)malloc(SEQ * DIM * 4); + c.xnorm = (float*)malloc(SEQ * DIM * 4); + c.dh1_fp16 = (_Float16*)malloc(SEQ * HIDDEN * 2); + c.dh3_fp16 = (_Float16*)malloc(SEQ * HIDDEN * 2); + c.dq_fp16 = (_Float16*)malloc(SEQ * DIM * 2); + c.dk_fp16 = (_Float16*)malloc(SEQ * DIM * 2); + c.dv_fp16 = (_Float16*)malloc(SEQ * DIM * 2); + return c; +} +static void layer_captures_free(LayerCaptures *c) { + free(c->dffn); 
free(c->silu_out); free(c->dh1); free(c->dh3); + free(c->x2norm); free(c->do_buf); free(c->attn_out); + free(c->dq); free(c->dk); free(c->dv); free(c->xnorm); + free(c->dh1_fp16); free(c->dh3_fp16); + free(c->dq_fp16); free(c->dk_fp16); free(c->dv_fp16); +} + +// ===== fp16 activation cache (Phase 2) ===== +// Store activations that are only used for dW as fp16 (skip main-thread conversion) +typedef struct { + _Float16 *xnorm_fp16; // [DIM, SEQ] + _Float16 *attn_out_fp16; // [DIM, SEQ] + _Float16 *x2norm_fp16; // [DIM, SEQ] + _Float16 *silu_out_fp16; // [HIDDEN, SEQ] +} LayerFP16Cache; + +static LayerFP16Cache layer_fp16_cache_alloc(void) { + LayerFP16Cache c; + c.xnorm_fp16 = (_Float16*)malloc(SEQ * DIM * 2); + c.attn_out_fp16 = (_Float16*)malloc(SEQ * DIM * 2); + c.x2norm_fp16 = (_Float16*)malloc(SEQ * DIM * 2); + c.silu_out_fp16 = (_Float16*)malloc(SEQ * HIDDEN * 2); + return c; +} +static void layer_fp16_cache_free(LayerFP16Cache *c) { + free(c->xnorm_fp16); free(c->attn_out_fp16); + free(c->x2norm_fp16); free(c->silu_out_fp16); +} + +// ===== Metal GPU dW context (Phase 3) ===== +typedef struct { + id device; + id queue; + // Shared gradient accumulator buffers (one per weight matrix per layer) + id dW_bufs[NLAYERS][9]; // Wq,Wk,Wv,Wo,W1,W2,W3,rms_att,rms_ffn + id lastCmdBuf; // Track last submitted buffer for sync +} MetalDWContext; + +// Weight matrix indices for Metal buffers +enum { MW_Q=0, MW_K, MW_V, MW_O, MW_1, MW_2, MW_3, MW_RMSA, MW_RMSF }; + +static bool metal_dw_init(MetalDWContext *ctx) { + ctx->device = MTLCreateSystemDefaultDevice(); + if (!ctx->device) { printf("[Metal] No GPU device\n"); return false; } + ctx->queue = [ctx->device newCommandQueue]; + if (!ctx->queue) { printf("[Metal] No command queue\n"); return false; } + + // Allocate shared-mode gradient accumulator buffers + size_t sizes[9] = {WQ_SZ*4, WQ_SZ*4, WQ_SZ*4, WO_SZ*4, + W1_SZ*4, W2_SZ*4, W3_SZ*4, DIM*4, DIM*4}; + for (int L = 0; L < NLAYERS; L++) { + for (int w = 0; w < 9; w++) 
{ + ctx->dW_bufs[L][w] = [ctx->device newBufferWithLength:sizes[w] + options:MTLResourceStorageModeShared]; + if (!ctx->dW_bufs[L][w]) { printf("[Metal] Buffer alloc failed L=%d w=%d\n", L, w); return false; } + } + } + printf("[Metal] GPU: %s\n", [[ctx->device name] UTF8String]); + return true; +} + +static void metal_dw_zero(MetalDWContext *ctx) { + size_t sizes[9] = {WQ_SZ*4, WQ_SZ*4, WQ_SZ*4, WO_SZ*4, + W1_SZ*4, W2_SZ*4, W3_SZ*4, DIM*4, DIM*4}; + for (int L = 0; L < NLAYERS; L++) { + for (int w = 0; w < 9; w++) { + memset([ctx->dW_bufs[L][w] contents], 0, sizes[w]); + } + } +} + +// Encode a single dW sgemm to Metal command buffer using MPS +// C[M,N] += A[M,K] @ B^T[N,K] (i.e., C += A @ B^T, accumulating into C) +static void metal_encode_dw_sgemm(id cmdBuf, + id device, + const float *a_data, int M, int K, + const float *b_data, int N, + id c_buf) { + // Create temporary input buffers (shared mode = zero-copy on Apple Silicon) + id aBuf = [device newBufferWithBytesNoCopy:(void*)a_data + length:M * K * sizeof(float) + options:MTLResourceStorageModeShared + deallocator:nil]; + id bBuf = [device newBufferWithBytesNoCopy:(void*)b_data + length:N * K * sizeof(float) + options:MTLResourceStorageModeShared + deallocator:nil]; + + // A is [M, K] row-major, B is [N, K] row-major + // We want C += A @ B^T, i.e., C[M, N] = A[M, K] * B[K, N]^T + // MPS uses row-major by default + MPSMatrixDescriptor *descA = [MPSMatrixDescriptor matrixDescriptorWithRows:M + columns:K rowBytes:K * sizeof(float) dataType:MPSDataTypeFloat32]; + MPSMatrixDescriptor *descB = [MPSMatrixDescriptor matrixDescriptorWithRows:N + columns:K rowBytes:K * sizeof(float) dataType:MPSDataTypeFloat32]; + MPSMatrixDescriptor *descC = [MPSMatrixDescriptor matrixDescriptorWithRows:M + columns:N rowBytes:N * sizeof(float) dataType:MPSDataTypeFloat32]; + + MPSMatrix *matA = [[MPSMatrix alloc] initWithBuffer:aBuf descriptor:descA]; + MPSMatrix *matB = [[MPSMatrix alloc] initWithBuffer:bBuf descriptor:descB]; + 
MPSMatrix *matC = [[MPSMatrix alloc] initWithBuffer:c_buf descriptor:descC]; + + MPSMatrixMultiplication *mm = [[MPSMatrixMultiplication alloc] + initWithDevice:device transposeLeft:NO transposeRight:YES + resultRows:M resultColumns:N interiorColumns:K alpha:1.0 beta:1.0]; + + [mm encodeToCommandBuffer:cmdBuf leftMatrix:matA rightMatrix:matB resultMatrix:matC]; +} + +// ===== Weight loading from llama2.c format ===== +static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) { + FILE *f = fopen(path, "rb"); + if (!f) { printf("Cannot open %s\n", path); return false; } + Llama2Config cfg; + fread(&cfg, sizeof(cfg), 1, f); + printf(" Model config: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n", + cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len); + if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) { + printf(" ERROR: Config mismatch! Expected dim=%d hidden=%d layers=%d\n", DIM, HIDDEN, NLAYERS); + fclose(f); return false; + } + int V = abs(cfg.vocab_size); + bool shared = cfg.vocab_size > 0; + (void)V; (void)shared; + + fread(embed, 4, VOCAB * DIM, f); + for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_att, 4, DIM, f); + for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wq, 4, WQ_SZ, f); + for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wk, 4, WQ_SZ, f); + for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wv, 4, WQ_SZ, f); + for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wo, 4, WO_SZ, f); + for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_ffn, 4, DIM, f); + for (int L = 0; L < NLAYERS; L++) fread(lw[L].W1, 4, W1_SZ, f); + for (int L = 0; L < NLAYERS; L++) fread(lw[L].W2, 4, W2_SZ, f); + for (int L = 0; L < NLAYERS; L++) fread(lw[L].W3, 4, W3_SZ, f); + fread(rms_final, 4, DIM, f); + fclose(f); + printf(" Loaded pretrained weights (%s)\n", shared ? 
"shared embed/cls" : "separate cls"); + return true; +} + +// ===== Compile one layer's kernels ===== +static bool compile_layer_kernels(LayerKernels *lk, LayerWeights *w) { + lk->fwdAttn = compile_kern_mil_w(gen_sdpa_fwd_taps(), (@{ + @"@model_path/weights/rms1.bin": @{@"offset":@0, @"data":build_blob(w->rms_att,1,DIM)}, + @"@model_path/weights/wq.bin": @{@"offset":@0, @"data":build_blob(w->Wq,DIM,DIM)}, + @"@model_path/weights/wk.bin": @{@"offset":@0, @"data":build_blob(w->Wk,DIM,DIM)}, + @"@model_path/weights/wv.bin": @{@"offset":@0, @"data":build_blob(w->Wv,DIM,DIM)}, + @"@model_path/weights/wo.bin": @{@"offset":@0, @"data":build_blob(w->Wo,DIM,DIM)}, + @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}, + }), DIM*SEQ*2, 6*DIM*SEQ*2); + + lk->fwdFFN = compile_kern_mil_w(gen_ffn_fwd_taps(), (@{ + @"@model_path/weights/rms2.bin": @{@"offset":@0, @"data":build_blob(w->rms_ffn,1,DIM)}, + @"@model_path/weights/w1.bin": @{@"offset":@0, @"data":build_blob(w->W1,HIDDEN,DIM)}, + @"@model_path/weights/w3.bin": @{@"offset":@0, @"data":build_blob(w->W3,HIDDEN,DIM)}, + @"@model_path/weights/w2.bin": @{@"offset":@0, @"data":build_blob(w->W2,DIM,HIDDEN)}, + }), DIM*SEQ*2, (2*DIM+3*HIDDEN)*SEQ*2); + + lk->ffnBwd = compile_kern_mil_w(gen_ffn_bwd(), (@{ + @"@model_path/weights/w2t.bin": @{@"offset":@0, @"data":build_blob_t(w->W2,DIM,HIDDEN)}, + @"@model_path/weights/w1t.bin": @{@"offset":@0, @"data":build_blob_t(w->W1,HIDDEN,DIM)}, + @"@model_path/weights/w3t.bin": @{@"offset":@0, @"data":build_blob_t(w->W3,HIDDEN,DIM)}, + }), (DIM+2*HIDDEN)*SEQ*2, (DIM+2*HIDDEN)*SEQ*2); + + lk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1(), (@{ + @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}, + @"@model_path/weights/wot.bin": @{@"offset":@0, @"data":build_blob_t(w->Wo,DIM,DIM)}, + }), 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2); + + lk->qkvBwd = compile_kern_mil_w(gen_qkvb(), (@{ + @"@model_path/weights/wqt.bin": @{@"offset":@0, 
@"data":build_blob_t(w->Wq,DIM,DIM)}, + @"@model_path/weights/wkt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wk,DIM,DIM)}, + @"@model_path/weights/wvt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wv,DIM,DIM)}, + }), 3*DIM*SEQ*2, DIM*SEQ*2); + + return lk->fwdAttn && lk->fwdFFN && lk->ffnBwd && lk->sdpaBwd1 && lk->qkvBwd; +} + +static Kern *compile_sdpa_bwd2(void) { + return compile_kern_mil_w(gen_sdpa_bwd2(), @{}, + (2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2); +} + +static void free_layer_kernels(LayerKernels *lk) { + free_kern(lk->fwdAttn); free_kern(lk->fwdFFN); free_kern(lk->ffnBwd); + free_kern(lk->sdpaBwd1); free_kern(lk->qkvBwd); + lk->fwdAttn = lk->fwdFFN = lk->ffnBwd = lk->sdpaBwd1 = lk->qkvBwd = NULL; +} + +// ===== Checkpoint save/load ===== +static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss, + double cc, double ct, double cw, int cs, int cb, int adam_t, + LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final, + float *embed, AdamState *aembed) { + FILE *f = fopen(path, "wb"); + CkptHdr h = {0}; + h.magic = 0x424C5A54; h.version = 2; + h.step = step; h.total_steps = total_steps; + h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM; + h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ; + h.lr = lr; h.loss = loss; + h.cum_compile = cc; h.cum_train = ct; h.cum_wall = cw; + h.cum_steps = cs; h.cum_batches = cb; h.adam_t = adam_t; + fwrite(&h, sizeof(h), 1, f); + for (int L = 0; L < NLAYERS; L++) { + fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f); + fwrite(lw[L].Wv,4,WQ_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f); + fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f); + fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f); + fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f); + fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f); + fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f); + fwrite(la[L].Wo.m,4,WO_SZ,f); 
fwrite(la[L].Wo.v,4,WO_SZ,f); + fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f); + fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f); + fwrite(la[L].W3.m,4,W3_SZ,f); fwrite(la[L].W3.v,4,W3_SZ,f); + fwrite(la[L].rms_att.m,4,DIM,f); fwrite(la[L].rms_att.v,4,DIM,f); + fwrite(la[L].rms_ffn.m,4,DIM,f); fwrite(la[L].rms_ffn.v,4,DIM,f); + } + fwrite(rms_final,4,DIM,f); + fwrite(arms_final->m,4,DIM,f); fwrite(arms_final->v,4,DIM,f); + fwrite(embed,4,VOCAB*DIM,f); + fwrite(aembed->m,4,VOCAB*DIM,f); fwrite(aembed->v,4,VOCAB*DIM,f); + fclose(f); +} + +static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss, + double *cc, double *ct, double *cw, int *cs, int *cb, int *adam_t, + LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final, + float *embed, AdamState *aembed) { + FILE *f = fopen(path, "rb"); + if (!f) return false; + CkptHdr h; + fread(&h, sizeof(h), 1, f); + if (h.magic != 0x424C5A54 || h.version != 2) { fclose(f); return false; } + *step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss; + *cc = h.cum_compile; *ct = h.cum_train; *cw = h.cum_wall; + *cs = h.cum_steps; *cb = h.cum_batches; *adam_t = h.adam_t; + for (int L = 0; L < NLAYERS; L++) { + fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WQ_SZ,f); + fread(lw[L].Wv,4,WQ_SZ,f); fread(lw[L].Wo,4,WO_SZ,f); + fread(lw[L].W1,4,W1_SZ,f); fread(lw[L].W2,4,W2_SZ,f); fread(lw[L].W3,4,W3_SZ,f); + fread(lw[L].rms_att,4,DIM,f); fread(lw[L].rms_ffn,4,DIM,f); + fread(la[L].Wq.m,4,WQ_SZ,f); fread(la[L].Wq.v,4,WQ_SZ,f); + fread(la[L].Wk.m,4,WQ_SZ,f); fread(la[L].Wk.v,4,WQ_SZ,f); + fread(la[L].Wv.m,4,WQ_SZ,f); fread(la[L].Wv.v,4,WQ_SZ,f); + fread(la[L].Wo.m,4,WO_SZ,f); fread(la[L].Wo.v,4,WO_SZ,f); + fread(la[L].W1.m,4,W1_SZ,f); fread(la[L].W1.v,4,W1_SZ,f); + fread(la[L].W2.m,4,W2_SZ,f); fread(la[L].W2.v,4,W2_SZ,f); + fread(la[L].W3.m,4,W3_SZ,f); fread(la[L].W3.v,4,W3_SZ,f); + fread(la[L].rms_att.m,4,DIM,f); 
fread(la[L].rms_att.v,4,DIM,f); + fread(la[L].rms_ffn.m,4,DIM,f); fread(la[L].rms_ffn.v,4,DIM,f); + } + fread(rms_final,4,DIM,f); + fread(arms_final->m,4,DIM,f); fread(arms_final->v,4,DIM,f); + fread(embed,4,VOCAB*DIM,f); + fread(aembed->m,4,VOCAB*DIM,f); fread(aembed->v,4,VOCAB*DIM,f); + fclose(f); + return true; +} + +// ===== Main ===== +int main(int argc, char *argv[]) { + @autoreleasepool { + setbuf(stdout, NULL); + + // Phase 2: Limit BLAS thread count to prevent oversubscription with concurrent dispatch + setenv("VECLIB_MAXIMUM_THREADS", "2", 1); + + ane_init(); + mach_timebase_info(&g_tb); + + int total_steps = 10000; + float lr = 3e-4f; + float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f; + int adam_t = 0, start_step = 0; + + // Parse args + bool do_resume = false; + bool use_metal = false; // default off: Metal dW contends with ANE for memory bandwidth + for (int i=1; i