ANE/training/training_dynamic/train.m

877 lines
47 KiB
Mathematica

// train.m Dynamic weight ANE training for Stories110M
// Compile kernels ONCE at startup, update weights via IOSurface every step.
// No exec() restart needed eliminates 76% compile overhead.
#include "mil_dynamic.h"
#include "cpu_ops.h"
#define CKPT_PATH "ane_stories110M_dyn_ckpt.bin"
#define MODEL_PATH "../../../assets/models/stories110M.bin"
#define DATA_PATH "../tinystories_data00.bin"
// Dynamic kernel set per layer
typedef struct {
Kern *sdpaFwd; // QKV matmul + SDPA + Wo matmul (dynamic weights via IOSurface)
Kern *ffnW13; // W1,W3 matmul (dynamic)
Kern *ffnW2; // W2 matmul (dynamic)
Kern *ffnBwdW2t; // dffn @ W2^T (dynamic)
Kern *ffnBwdW13t; // dh1@W1^T + dh3@W3^T (dynamic)
Kern *wotBwd; // dx2 @ Wo^T (dynamic)
Kern *sdpaBwd1; // Q,K,V,da dV,probs,dp (weight-free, has mask const)
Kern *sdpaBwd2; // probs,dp,Q,K dQ,dK (weight-free)
Kern *qkvBwd; // dq@Wq^T + dk@Wk^T + dv@Wv^T (dynamic)
} DynLayerKernels;
// ===== Weight loading from llama2.c format =====
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
FILE *f = fopen(path, "rb");
if (!f) { printf("Cannot open %s\n", path); return false; }
Llama2Config cfg;
fread(&cfg, sizeof(cfg), 1, f);
printf(" Model: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);
if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) {
printf(" ERROR: Config mismatch!\n"); fclose(f); return false;
}
int V = abs(cfg.vocab_size);
fread(embed, 4, V * DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_att, 4, DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wq, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wk, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wv, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wo, 4, WO_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_ffn, 4, DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W1, 4, W1_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W2, 4, W2_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W3, 4, W3_SZ, f);
fread(rms_final, 4, DIM, f);
fclose(f);
printf(" Loaded pretrained weights\n");
return true;
}
// Transpose W[rows,cols] W^T[cols,rows] stored as [cols channels, rows spatial]
static void transpose_weight(float *dst, const float *src, int rows, int cols) {
for (int r = 0; r < rows; r++)
for (int c = 0; c < cols; c++)
dst[c * rows + r] = src[r * cols + c];
}
// ===== Compile all dynamic kernels (ONCE) =====
static bool compile_dynamic_kernels(DynLayerKernels *dk) {
NSDictionary *mask_w = @{@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}};
// SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp32 [1, 6*DIM, 1, SEQ] fp32
printf(" Compiling sdpaFwd...\n");
dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), mask_w,
DIM*(SEQ+4*DIM)*4, 6*DIM*SEQ*4);
if (!dk->sdpaFwd) return false;
// FFN W1+W3: [1, DIM, 1, SEQ+2*HIDDEN] fp32 [1, 3*HIDDEN, 1, SEQ] fp32
printf(" Compiling ffnW13...\n");
dk->ffnW13 = compile_kern_mil_w(gen_ffn_w13_dynamic(), @{},
DIM*(SEQ+2*HIDDEN)*4, 3*HIDDEN*SEQ*4);
if (!dk->ffnW13) return false;
// FFN W2: [1, HIDDEN, 1, SEQ+DIM] fp32 [1, DIM, 1, SEQ] fp32
printf(" Compiling ffnW2...\n");
dk->ffnW2 = compile_kern_mil_w(gen_ffn_w2_dynamic(), @{},
HIDDEN*(SEQ+DIM)*4, DIM*SEQ*4);
if (!dk->ffnW2) return false;
// FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] fp32 [1, HIDDEN, 1, SEQ] fp32
printf(" Compiling ffnBwdW2t...\n");
dk->ffnBwdW2t = compile_kern_mil_w(gen_ffn_bwd_w2t_dynamic(), @{},
DIM*(SEQ+HIDDEN)*4, HIDDEN*SEQ*4);
if (!dk->ffnBwdW2t) return false;
// FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp32 [1, DIM, 1, SEQ] fp32
printf(" Compiling ffnBwdW13t...\n");
dk->ffnBwdW13t = compile_kern_mil_w(gen_ffn_bwd_w13t_dynamic(), @{},
HIDDEN*(2*SEQ+2*DIM)*4, DIM*SEQ*4);
if (!dk->ffnBwdW13t) return false;
// Wo^T backward: [1, DIM, 1, SEQ+DIM] fp32 [1, DIM, 1, SEQ] fp32
printf(" Compiling wotBwd...\n");
dk->wotBwd = compile_kern_mil_w(gen_wot_dynamic(), @{},
DIM*(SEQ+DIM)*4, DIM*SEQ*4);
if (!dk->wotBwd) return false;
// SDPA bwd1 (no dynamic weights, has mask): [1, 4*DIM, 1, SEQ] fp16 [1, DIM+2*SCORE_CH, 1, SEQ] fp16
printf(" Compiling sdpaBwd1...\n");
dk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1_noweight(), mask_w,
4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);
if (!dk->sdpaBwd1) return false;
// SDPA bwd2 (no weights): [1, 2*SCORE_CH+2*DIM, 1, SEQ] fp16 [1, 2*DIM, 1, SEQ] fp16
printf(" Compiling sdpaBwd2...\n");
dk->sdpaBwd2 = compile_kern_mil_w(gen_sdpa_bwd2(), @{},
(2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
if (!dk->sdpaBwd2) return false;
// QKV backward: [1, DIM, 1, 3*SEQ+3*DIM] fp32 [1, DIM, 1, SEQ] fp32
printf(" Compiling qkvBwd...\n");
dk->qkvBwd = compile_kern_mil_w(gen_qkvb_dynamic(), @{},
DIM*(3*SEQ+3*DIM)*4, DIM*SEQ*4);
if (!dk->qkvBwd) return false;
return true;
}
// ===== Write dynamic weights into IOSurface =====
// sdpaFwd: [1, DIM, 1, SEQ+4*DIM] xnorm at sp[0:S], Wq/Wk/Wv/Wo at sp[S:]
static void write_sdpa_fwd_input(DynLayerKernels *dk, const float *xnorm,
const float *Wq, const float *Wk, const float *Wv, const float *Wo) {
IOSurfaceLock(dk->sdpaFwd->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk->sdpaFwd->ioIn);
int sp = SEQ + 4*DIM;
for (int d = 0; d < DIM; d++) {
memcpy(buf + d*sp, xnorm + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, Wq + d*DIM, DIM*4);
memcpy(buf + d*sp + SEQ+DIM, Wk + d*DIM, DIM*4);
memcpy(buf + d*sp + SEQ+2*DIM, Wv + d*DIM, DIM*4);
memcpy(buf + d*sp + SEQ+3*DIM, Wo + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk->sdpaFwd->ioIn, 0, NULL);
}
// ffnW13: [1, DIM, 1, SEQ+2*HIDDEN] xnorm at sp[0:S], W1,W3 at sp[S:]
static void write_ffn_w13_input(DynLayerKernels *dk, const float *xnorm,
const float *W1, const float *W3) {
IOSurfaceLock(dk->ffnW13->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk->ffnW13->ioIn);
int sp = SEQ + 2*HIDDEN;
for (int d = 0; d < DIM; d++) {
memcpy(buf + d*sp, xnorm + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, W1 + d*HIDDEN, HIDDEN*4);
memcpy(buf + d*sp + SEQ+HIDDEN, W3 + d*HIDDEN, HIDDEN*4);
}
IOSurfaceUnlock(dk->ffnW13->ioIn, 0, NULL);
}
// ffnW2: [1, HIDDEN, 1, SEQ+DIM] gate at sp[0:S], W2 at sp[S:]
static void write_ffn_w2_input(DynLayerKernels *dk, const float *gate, const float *W2) {
IOSurfaceLock(dk->ffnW2->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk->ffnW2->ioIn);
int sp = SEQ + DIM;
for (int d = 0; d < HIDDEN; d++) {
memcpy(buf + d*sp, gate + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, W2 + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk->ffnW2->ioIn, 0, NULL);
}
// ===== Checkpoint =====
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
double ct, double cw, int cs, int adam_t,
LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
float *embed, AdamState *aembed) {
FILE *f = fopen(path, "wb");
CkptHdr h = {0};
h.magic = 0x424C5A54; h.version = 3;
h.step = step; h.total_steps = total_steps;
h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
h.lr = lr; h.loss = loss;
h.cum_train = ct; h.cum_wall = cw; h.cum_steps = cs; h.adam_t = adam_t;
fwrite(&h, sizeof(h), 1, f);
for (int L = 0; L < NLAYERS; L++) {
fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f);
fwrite(lw[L].Wv,4,WQ_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f);
fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f);
fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f);
fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f);
fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f);
fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f);
fwrite(la[L].Wo.m,4,WO_SZ,f); fwrite(la[L].Wo.v,4,WO_SZ,f);
fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f);
fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f);
fwrite(la[L].W3.m,4,W3_SZ,f); fwrite(la[L].W3.v,4,W3_SZ,f);
fwrite(la[L].rms_att.m,4,DIM,f); fwrite(la[L].rms_att.v,4,DIM,f);
fwrite(la[L].rms_ffn.m,4,DIM,f); fwrite(la[L].rms_ffn.v,4,DIM,f);
}
fwrite(rms_final,4,DIM,f);
fwrite(arms_final->m,4,DIM,f); fwrite(arms_final->v,4,DIM,f);
fwrite(embed,4,VOCAB*DIM,f);
fwrite(aembed->m,4,VOCAB*DIM,f); fwrite(aembed->v,4,VOCAB*DIM,f);
fclose(f);
}
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
double *ct, double *cw, int *cs, int *adam_t,
LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
float *embed, AdamState *aembed) {
FILE *f = fopen(path, "rb");
if (!f) return false;
CkptHdr h;
fread(&h, sizeof(h), 1, f);
if (h.magic != 0x424C5A54 || h.version != 3) { fclose(f); return false; }
*step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
*ct = h.cum_train; *cw = h.cum_wall; *cs = h.cum_steps; *adam_t = h.adam_t;
for (int L = 0; L < NLAYERS; L++) {
fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WQ_SZ,f);
fread(lw[L].Wv,4,WQ_SZ,f); fread(lw[L].Wo,4,WO_SZ,f);
fread(lw[L].W1,4,W1_SZ,f); fread(lw[L].W2,4,W2_SZ,f); fread(lw[L].W3,4,W3_SZ,f);
fread(lw[L].rms_att,4,DIM,f); fread(lw[L].rms_ffn,4,DIM,f);
fread(la[L].Wq.m,4,WQ_SZ,f); fread(la[L].Wq.v,4,WQ_SZ,f);
fread(la[L].Wk.m,4,WQ_SZ,f); fread(la[L].Wk.v,4,WQ_SZ,f);
fread(la[L].Wv.m,4,WQ_SZ,f); fread(la[L].Wv.v,4,WQ_SZ,f);
fread(la[L].Wo.m,4,WO_SZ,f); fread(la[L].Wo.v,4,WO_SZ,f);
fread(la[L].W1.m,4,W1_SZ,f); fread(la[L].W1.v,4,W1_SZ,f);
fread(la[L].W2.m,4,W2_SZ,f); fread(la[L].W2.v,4,W2_SZ,f);
fread(la[L].W3.m,4,W3_SZ,f); fread(la[L].W3.v,4,W3_SZ,f);
fread(la[L].rms_att.m,4,DIM,f); fread(la[L].rms_att.v,4,DIM,f);
fread(la[L].rms_ffn.m,4,DIM,f); fread(la[L].rms_ffn.v,4,DIM,f);
}
fread(rms_final,4,DIM,f);
fread(arms_final->m,4,DIM,f); fread(arms_final->v,4,DIM,f);
fread(embed,4,VOCAB*DIM,f);
fread(aembed->m,4,VOCAB*DIM,f); fread(aembed->v,4,VOCAB*DIM,f);
fclose(f);
return true;
}
int main(int argc, char *argv[]) {
@autoreleasepool {
setbuf(stdout, NULL);
ane_init();
mach_timebase_info(&g_tb);
int total_steps = 10000;
float max_lr = 3e-4f;
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
int adam_t = 0, start_step = 0;
int accum_steps = 10;
int warmup_steps = 100;
float grad_clip = 1.0f;
float min_lr_frac = 0.1f; // min_lr = max_lr * 0.1
bool do_resume = false, from_scratch = false;
for (int i=1; i<argc; i++) {
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
else if (strcmp(argv[i], "--scratch") == 0) from_scratch = true;
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) max_lr = atof(argv[++i]);
else if (strcmp(argv[i], "--accum") == 0 && i+1<argc) accum_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--warmup") == 0 && i+1<argc) warmup_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--clip") == 0 && i+1<argc) grad_clip = atof(argv[++i]);
}
float lr = max_lr;
// Allocate per-layer state
LayerWeights lw[NLAYERS]; LayerAdam la[NLAYERS];
LayerActs acts[NLAYERS]; LayerGrads grads[NLAYERS];
for (int L=0; L<NLAYERS; L++) {
lw[L] = layer_weights_alloc(); la[L] = layer_adam_alloc();
acts[L] = layer_acts_alloc(); grads[L] = layer_grads_alloc();
}
float *rms_final = (float*)malloc(DIM*4);
float *embed = (float*)malloc(VOCAB*DIM*4);
float *grms_final = (float*)calloc(DIM, 4);
float *gembed = (float*)calloc(VOCAB*DIM, 4);
AdamState arms_final = adam_alloc(DIM);
AdamState aembed = adam_alloc((size_t)VOCAB*DIM);
double cum_train=0, cum_wall=0; int cum_steps=0;
float resume_loss = 0;
bool resuming = false;
if (do_resume) {
resuming = load_checkpoint(CKPT_PATH, &start_step, &total_steps, &lr, &resume_loss,
&cum_train, &cum_wall, &cum_steps, &adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
}
if (!resuming) {
printf("=== ANE Dynamic Training: Stories110M (12 layers) ===\n");
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
// Param counts for dashboard
double xformer_m = (double)NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ + 2.0*DIM) / 1e6;
double embed_m = (double)VOCAB*DIM / 1e6;
printf("Params: %.1fM (transformer %.1fM + embed %.1fM)\n", xformer_m+embed_m, xformer_m, embed_m);
printf("Kernels: 9 compiled, 9 weight-bearing\n");
printf("Accum %d steps, LR=%g\n", accum_steps, max_lr);
// FLOPs estimate: 6*N*B*T for transformer (forward+backward 3x forward)
double fwd_flops = 2.0*NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ) * SEQ;
double total_flops = 3.0 * fwd_flops; // fwd + bwd 3x fwd
printf("FLOPs/step: fwd=%.1fM bwd_dx=%.1fM bwd_dW=%.1fM sdpa_bwd=0.0M total=%.1fM\n",
fwd_flops/1e6, fwd_flops/1e6, fwd_flops/1e6, total_flops/1e6);
printf("ANE FLOPs/step: %.1fM\n", total_flops/1e6);
if (from_scratch || !load_pretrained(lw, rms_final, embed, MODEL_PATH)) {
if (from_scratch) printf(" Training from scratch (random init)\n");
else printf(" Pretrained load failed, using random init\n");
srand48(42);
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
for (int L=0; L<NLAYERS; L++) {
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*(2*drand48()-1);
for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
for(int i=0;i<DIM;i++){lw[L].rms_att[i]=1.0f; lw[L].rms_ffn[i]=1.0f;}
}
for(int i=0;i<DIM;i++) rms_final[i]=1.0f;
float escale = 0.02f;
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) embed[i]=escale*(2*drand48()-1);
}
}
// Precompute transposed weights (for backward pass kernels)
// These get updated after each Adam step
float *Wqt_buf[NLAYERS], *Wkt_buf[NLAYERS], *Wvt_buf[NLAYERS], *Wot_buf[NLAYERS];
float *W1t_buf[NLAYERS], *W2t_buf[NLAYERS], *W3t_buf[NLAYERS];
for (int L=0; L<NLAYERS; L++) {
Wqt_buf[L]=(float*)malloc(WQ_SZ*4); Wkt_buf[L]=(float*)malloc(WQ_SZ*4);
Wvt_buf[L]=(float*)malloc(WQ_SZ*4); Wot_buf[L]=(float*)malloc(WO_SZ*4);
W1t_buf[L]=(float*)malloc(W1_SZ*4); W2t_buf[L]=(float*)malloc(W2_SZ*4);
W3t_buf[L]=(float*)malloc(W3_SZ*4);
transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM);
transpose_weight(Wkt_buf[L], lw[L].Wk, DIM, DIM);
transpose_weight(Wvt_buf[L], lw[L].Wv, DIM, DIM);
transpose_weight(Wot_buf[L], lw[L].Wo, DIM, DIM);
transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM);
transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN);
transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM);
}
// mmap token data
int data_fd = open(DATA_PATH, O_RDONLY);
if (data_fd < 0) { printf("Cannot open %s\n", DATA_PATH); return 1; }
struct stat st; fstat(data_fd, &st);
size_t data_len = st.st_size;
uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; }
size_t n_tokens = data_len / 2;
printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
// Vocab compaction: map 32K sparse vocab ~9K compact
VocabMap vm = vocab_map_build(token_data, n_tokens, VOCAB);
int CV = vm.compact_vocab;
printf("Vocab compaction: %d → %d active tokens (%.1fx reduction)\n", VOCAB, CV, (float)VOCAB/CV);
// Create compact embedding + adam state
float *cembed = vocab_compact_embed(embed, &vm, DIM);
float *gcembed = (float*)calloc((size_t)CV*DIM, 4);
AdamState acembed = adam_alloc((size_t)CV*DIM);
// ===== Compile all kernels ONCE =====
printf("Compiling %d dynamic kernels (one-time)...\n", 9);
uint64_t tc = mach_absolute_time();
DynLayerKernels dk;
if (!compile_dynamic_kernels(&dk)) {
printf("Compilation failed!\n"); return 1;
}
double compile_ms = tb_ms(mach_absolute_time() - tc);
printf("Compiled 9 kernels in %.0fms (shared across all %d layers)\n\n", compile_ms, NLAYERS);
// Gradient + work buffers
float *dy = (float*)malloc(SEQ*DIM*4);
float *dffn = (float*)malloc(SEQ*DIM*4);
float *dx_ffn = (float*)malloc(SEQ*DIM*4);
float *dx2 = (float*)malloc(SEQ*DIM*4);
float *dx_attn = (float*)malloc(SEQ*DIM*4);
float *dq = (float*)malloc(SEQ*DIM*4);
float *dk_buf = (float*)malloc(SEQ*DIM*4);
float *dv = (float*)malloc(SEQ*DIM*4);
float *x_cur = (float*)malloc(SEQ*DIM*4);
float *x_final = (float*)malloc(SEQ*DIM*4);
float *xnorm_buf = (float*)malloc(SEQ*DIM*4);
float *logits = (float*)malloc(SEQ*CV*4);
float *dlogits = (float*)malloc(SEQ*CV*4);
float *gate_buf = (float*)malloc(SEQ*HIDDEN*4);
float *dh1 = (float*)malloc(SEQ*HIDDEN*4);
float *dh3 = (float*)malloc(SEQ*HIDDEN*4);
float *dsilu = (float*)malloc(SEQ*HIDDEN*4);
float *silu_tmp = (float*)malloc(SEQ*HIDDEN*4);
float *silu_tmp2 = (float*)malloc(SEQ*HIDDEN*4);
dispatch_queue_t dw_q = dispatch_queue_create("dw_cblas", DISPATCH_QUEUE_SERIAL);
dispatch_group_t dw_grp = dispatch_group_create();
float last_loss = 999.0f;
double total_train_ms = 0;
int total_steps_done = 0;
uint64_t t_wall_start = mach_absolute_time();
srand48(42 + start_step);
for (int step = start_step; step < total_steps; step++) {
uint64_t t0, t1, t_step = mach_absolute_time();
// Sample data
size_t max_pos = n_tokens - SEQ - 1;
size_t pos = (size_t)(drand48() * max_pos);
uint16_t *input_tokens = token_data + pos;
uint16_t *target_tokens_raw = token_data + pos + 1;
// Map targets to compact vocab IDs
uint16_t ctargets[SEQ];
for (int t = 0; t < SEQ; t++) ctargets[t] = (uint16_t)vm.full_to_compact[target_tokens_raw[t]];
// Embedding lookup (uses full embed for now input tokens are full IDs)
embed_lookup(x_cur, embed, input_tokens, DIM, SEQ);
// Timing accumulators (reset each step)
double t_rms=0, t_ane_fwd=0, t_io_fwd=0, t_cblas_wait=0;
double t_ane_bwd=0, t_io_bwd=0, t_silu=0, t_rms_bwd=0, t_cls=0, t_dw_copy=0;
// ===== FORWARD (12 layers) =====
for (int L=0; L<NLAYERS; L++) {
LayerActs *ac = &acts[L];
memcpy(ac->layer_in, x_cur, SEQ*DIM*4);
// RMSNorm1 (CPU)
t0 = mach_absolute_time();
rmsnorm(xnorm_buf, x_cur, lw[L].rms_att, DIM, SEQ);
memcpy(ac->xnorm, xnorm_buf, SEQ*DIM*4);
t_rms += tb_ms(mach_absolute_time() - t0);
// Wait for any pending dW cblas
t0 = mach_absolute_time();
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
t_cblas_wait += tb_ms(mach_absolute_time() - t0);
// SDPA forward (ANE): xnorm + Wq,Wk,Wv,Wo o_out,Q,K,V,attn_out,xnorm
t0 = mach_absolute_time();
write_sdpa_fwd_input(&dk, xnorm_buf, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.sdpaFwd);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
// Read output: [1, 6*DIM, 1, SEQ] fp32
t0 = mach_absolute_time();
IOSurfaceLock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL);
float *fwd_out = (float*)IOSurfaceGetBaseAddress(dk.sdpaFwd->ioOut);
memcpy(ac->o_out, fwd_out + 0*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->Q, fwd_out + 1*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->K, fwd_out + 2*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->V, fwd_out + 3*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->attn_out, fwd_out + 4*DIM*SEQ, DIM*SEQ*4);
IOSurfaceUnlock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// Residual: x2 = x_cur + o_out
vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
// RMSNorm2 (CPU)
t0 = mach_absolute_time();
rmsnorm(xnorm_buf, ac->x2, lw[L].rms_ffn, DIM, SEQ);
memcpy(ac->x2norm, xnorm_buf, SEQ*DIM*4);
t_rms += tb_ms(mach_absolute_time() - t0);
// FFN W1+W3 (ANE): xnorm h1, h3, gate
t0 = mach_absolute_time();
write_ffn_w13_input(&dk, xnorm_buf, W1t_buf[L], W3t_buf[L]);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnW13);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
// Read h1, h3, gate from output [1, 3*HIDDEN, 1, SEQ]
t0 = mach_absolute_time();
IOSurfaceLock(dk.ffnW13->ioOut, kIOSurfaceLockReadOnly, NULL);
float *ffn13_out = (float*)IOSurfaceGetBaseAddress(dk.ffnW13->ioOut);
memcpy(ac->h1, ffn13_out, HIDDEN*SEQ*4);
memcpy(ac->h3, ffn13_out + HIDDEN*SEQ, HIDDEN*SEQ*4);
memcpy(gate_buf, ffn13_out + 2*HIDDEN*SEQ, HIDDEN*SEQ*4);
memcpy(ac->silu_out, gate_buf, HIDDEN*SEQ*4);
IOSurfaceUnlock(dk.ffnW13->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// FFN W2 (ANE): gate @ W2 ffn_out
t0 = mach_absolute_time();
write_ffn_w2_input(&dk, gate_buf, W2t_buf[L]);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnW2);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
IOSurfaceLock(dk.ffnW2->ioOut, kIOSurfaceLockReadOnly, NULL);
memcpy(ac->ffn_out, (float*)IOSurfaceGetBaseAddress(dk.ffnW2->ioOut), DIM*SEQ*4);
IOSurfaceUnlock(dk.ffnW2->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// Residual: x_cur = x2 + ffn_out
vDSP_vadd(ac->x2, 1, ac->ffn_out, 1, x_cur, 1, (vDSP_Length)(SEQ*DIM));
}
// Final RMSNorm + classifier + loss (CPU)
t0 = mach_absolute_time();
rmsnorm(x_final, x_cur, rms_final, DIM, SEQ);
t_rms += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
// Classifier: logits[CV, SEQ] = cembed[CV, DIM] @ x_final[DIM, SEQ]
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
CV, SEQ, DIM, 1.0f, cembed, DIM, x_final, SEQ, 0.0f, logits, SEQ);
float loss = cross_entropy_loss(dlogits, logits, ctargets, CV, SEQ);
t_cls += tb_ms(mach_absolute_time() - t0);
last_loss = loss;
// ===== BACKWARD =====
// Classifier backward: dy[DIM, SEQ] = cembed^T[DIM, CV] @ dlogits[CV, SEQ]
t0 = mach_absolute_time();
cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
DIM, SEQ, CV, 1.0f, cembed, DIM, dlogits, SEQ, 0.0f, dy, SEQ);
t_cls += tb_ms(mach_absolute_time() - t0);
// dEmbed async: gcembed[CV, DIM] += dlogits[CV, SEQ] @ x_final^T[SEQ, DIM]
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
CV, DIM, SEQ, 1.0f, dlogits, SEQ, x_final, SEQ, 1.0f, gcembed, DIM);
});
// Final RMSNorm backward
float *dx_rms_final = (float*)calloc(SEQ*DIM, 4);
rmsnorm_bwd(dx_rms_final, grms_final, dy, x_cur, rms_final, DIM, SEQ);
memcpy(dy, dx_rms_final, SEQ*DIM*4);
free(dx_rms_final);
// ===== BACKWARD (12 layers, reverse) =====
for (int L=NLAYERS-1; L>=0; L--) {
LayerActs *ac = &acts[L];
LayerGrads *gr = &grads[L];
memcpy(dffn, dy, SEQ*DIM*4);
// FFN backward: dffn @ W2^T dsilu_raw
t0 = mach_absolute_time();
io_write_dyn(dk.ffnBwdW2t->ioIn, dffn, DIM, SEQ, lw[L].W2, HIDDEN);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnBwdW2t);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.ffnBwdW2t->ioOut, dsilu, HIDDEN, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// SiLU derivative (vectorized): dsilu dh1, dh3
// silu(h1) = h1*sig(h1), dsilu_dh1 = sig*(1+h1*(1-sig))
// dh1 = dsilu * h3 * dsilu_dh1, dh3 = dsilu * silu(h1)
t0 = mach_absolute_time();
{
int n = HIDDEN*SEQ;
// sig = 1/(1+exp(-h1))
float minus1 = -1.0f, one = 1.0f;
vDSP_vsmul(ac->h1, 1, &minus1, silu_tmp, 1, (vDSP_Length)n);
vvexpf(silu_tmp, silu_tmp, &n);
vDSP_vsadd(silu_tmp, 1, &one, silu_tmp, 1, (vDSP_Length)n);
vvrecf(silu_tmp, silu_tmp, &n); // silu_tmp = sig
// dh3 = dsilu * h1 * sig (= dsilu * silu(h1))
vDSP_vmul(ac->h1, 1, silu_tmp, 1, dh3, 1, (vDSP_Length)n);
vDSP_vmul(dsilu, 1, dh3, 1, dh3, 1, (vDSP_Length)n);
// dsilu_dh1 = sig*(1+h1*(1-sig)), store in silu_tmp2
vDSP_vsadd(silu_tmp, 1, &minus1, silu_tmp2, 1, (vDSP_Length)n); // sig-1
vDSP_vneg(silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // 1-sig
vDSP_vmul(ac->h1, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // h1*(1-sig)
vDSP_vsadd(silu_tmp2, 1, &one, silu_tmp2, 1, (vDSP_Length)n); // 1+h1*(1-sig)
vDSP_vmul(silu_tmp, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // full dsilu_dh1
// dh1 = dsilu * h3 * dsilu_dh1
vDSP_vmul(dsilu, 1, ac->h3, 1, dh1, 1, (vDSP_Length)n);
vDSP_vmul(dh1, 1, silu_tmp2, 1, dh1, 1, (vDSP_Length)n);
}
t_silu += tb_ms(mach_absolute_time() - t0);
// dh1@W1^T + dh3@W3^T dx_ffn (ANE)
t0 = mach_absolute_time();
{
IOSurfaceLock(dk.ffnBwdW13t->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk.ffnBwdW13t->ioIn);
int sp = 2*SEQ + 2*DIM;
for (int d = 0; d < HIDDEN; d++) {
memcpy(buf + d*sp, dh1 + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, dh3 + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 2*SEQ, lw[L].W1 + d*DIM, DIM*4);
memcpy(buf + d*sp + 2*SEQ + DIM, lw[L].W3 + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk.ffnBwdW13t->ioIn, 0, NULL);
}
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnBwdW13t);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.ffnBwdW13t->ioOut, dx_ffn, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dW FFN async (cblas)
t0 = mach_absolute_time();
float *capt_dffn = (float*)malloc(SEQ*DIM*4); memcpy(capt_dffn, dffn, SEQ*DIM*4);
float *capt_silu = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_silu, ac->silu_out, SEQ*HIDDEN*4);
float *capt_dh1 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh1, dh1, SEQ*HIDDEN*4);
float *capt_dh3 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh3, dh3, SEQ*HIDDEN*4);
float *capt_x2n = (float*)malloc(SEQ*DIM*4); memcpy(capt_x2n, ac->x2norm, SEQ*DIM*4);
t_dw_copy += tb_ms(mach_absolute_time() - t0);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, HIDDEN, SEQ,
1.0f, capt_dffn, SEQ, capt_silu, SEQ, 1.0f, gr->W2, HIDDEN);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
1.0f, capt_dh1, SEQ, capt_x2n, SEQ, 1.0f, gr->W1, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
1.0f, capt_dh3, SEQ, capt_x2n, SEQ, 1.0f, gr->W3, DIM);
free(capt_dffn); free(capt_silu); free(capt_dh1); free(capt_dh3); free(capt_x2n);
});
// RMSNorm2 backward
t0 = mach_absolute_time();
memset(dx2, 0, SEQ*DIM*4);
rmsnorm_bwd(dx2, gr->rms_ffn, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ);
for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
t_rms_bwd += tb_ms(mach_absolute_time() - t0);
// Wo^T backward (ANE): dx2 @ Wo^T da
t0 = mach_absolute_time();
io_write_dyn(dk.wotBwd->ioIn, dx2, DIM, SEQ, lw[L].Wo, DIM);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.wotBwd);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
float *da_buf = (float*)malloc(SEQ*DIM*4);
io_read_dyn(dk.wotBwd->ioOut, da_buf, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dWo async
t0 = mach_absolute_time();
float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, dx2, SEQ*DIM*4);
float *capt_attn = (float*)malloc(SEQ*DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*DIM*4);
t_dw_copy += tb_ms(mach_absolute_time() - t0);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_do, SEQ, capt_attn, SEQ, 1.0f, gr->Wo, DIM);
free(capt_do); free(capt_attn);
});
// SDPA backward part 1 (ANE, fp16): Q,K,V,da dV,probs,dp
t0 = mach_absolute_time();
io_write_fp16_at(dk.sdpaBwd1->ioIn, 0, ac->Q, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, DIM, ac->K, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, 2*DIM, ac->V, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, 3*DIM, da_buf, DIM, SEQ);
free(da_buf);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.sdpaBwd1);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
// SDPA backward part 2: probs,dp,Q,K dQ,dK
t0 = mach_absolute_time();
io_copy(dk.sdpaBwd2->ioIn, 0, dk.sdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ);
io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH, ac->Q, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH+DIM, ac->K, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.sdpaBwd2);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_fp16(dk.sdpaBwd2->ioOut, dq, 0, DIM, SEQ);
io_read_fp16(dk.sdpaBwd2->ioOut, dk_buf, DIM, DIM, SEQ);
io_read_fp16(dk.sdpaBwd1->ioOut, dv, 0, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dWq/dWk/dWv async
t0 = mach_absolute_time();
float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4);
float *capt_dk = (float*)malloc(SEQ*DIM*4); memcpy(capt_dk, dk_buf, SEQ*DIM*4);
float *capt_dv = (float*)malloc(SEQ*DIM*4); memcpy(capt_dv, dv, SEQ*DIM*4);
float *capt_xn = (float*)malloc(SEQ*DIM*4); memcpy(capt_xn, ac->xnorm, SEQ*DIM*4);
t_dw_copy += tb_ms(mach_absolute_time() - t0);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dq, SEQ, capt_xn, SEQ, 1.0f, gr->Wq, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dk, SEQ, capt_xn, SEQ, 1.0f, gr->Wk, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dv, SEQ, capt_xn, SEQ, 1.0f, gr->Wv, DIM);
free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn);
});
// QKV backward (ANE): dq,dk,dv @ Wq^T,Wk^T,Wv^T dx_attn
t0 = mach_absolute_time();
{
IOSurfaceLock(dk.qkvBwd->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk.qkvBwd->ioIn);
int sp = 3*SEQ + 3*DIM;
for (int d = 0; d < DIM; d++) {
memcpy(buf + d*sp, dq + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, dk_buf + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 2*SEQ, dv + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 3*SEQ, lw[L].Wq + d*DIM, DIM*4);
memcpy(buf + d*sp + 3*SEQ+DIM, lw[L].Wk + d*DIM, DIM*4);
memcpy(buf + d*sp + 3*SEQ+2*DIM, lw[L].Wv + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk.qkvBwd->ioIn, 0, NULL);
}
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.qkvBwd);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.qkvBwd->ioOut, dx_attn, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// RMSNorm1 backward
t0 = mach_absolute_time();
float *dx_rms1 = (float*)calloc(SEQ*DIM, 4);
rmsnorm_bwd(dx_rms1, gr->rms_att, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ);
for(int i=0;i<SEQ*DIM;i++) dy[i] = dx_rms1[i] + dx2[i];
free(dx_rms1);
t_rms_bwd += tb_ms(mach_absolute_time() - t0);
}
// Embedding backward
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
embed_backward(gembed, dy, input_tokens, DIM, SEQ);
double step_ms = tb_ms(mach_absolute_time() - t_step);
total_train_ms += step_ms;
total_steps_done++;
if (step % 10 == 0 || step == start_step) {
printf(" timing: ane_fwd=%.1f io_fwd=%.1f rms=%.1f ane_bwd=%.1f io_bwd=%.1f silu=%.1f rms_bwd=%.1f cls=%.1f cblas_wait=%.1f dw_copy=%.1f\n",
t_ane_fwd, t_io_fwd, t_rms, t_ane_bwd, t_io_bwd, t_silu, t_rms_bwd, t_cls, t_cblas_wait, t_dw_copy);
float xmx, xmn;
vDSP_maxv(x_cur,1,&xmx,(vDSP_Length)(SEQ*DIM));
vDSP_minv(x_cur,1,&xmn,(vDSP_Length)(SEQ*DIM));
float dmx, dmn;
vDSP_maxv(dy,1,&dmx,(vDSP_Length)(SEQ*DIM));
vDSP_minv(dy,1,&dmn,(vDSP_Length)(SEQ*DIM));
printf("step %-4d loss=%.4f lr=%.2e %.1fms/step x[%.2f,%.2f] dy[%.3e,%.3e]\n",
step, loss, lr, step_ms, xmn, xmx, dmn, dmx);
}
// Adam update every accum_steps
if ((step+1) % accum_steps == 0 || step == total_steps-1) {
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
float gsc = 1.0f / accum_steps;
adam_t++;
// Scale gradients by 1/accum_steps
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
for(size_t i=0;i<WQ_SZ;i++){g->Wq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;}
for(size_t i=0;i<W1_SZ;i++) g->W1[i]*=gsc;
for(size_t i=0;i<W2_SZ;i++) g->W2[i]*=gsc;
for(size_t i=0;i<W3_SZ;i++) g->W3[i]*=gsc;
for(int i=0;i<DIM;i++){g->rms_att[i]*=gsc; g->rms_ffn[i]*=gsc;}
}
for(int i=0;i<DIM;i++) grms_final[i]*=gsc;
// Merge compact classifier grads into full embed grads
vocab_scatter_grads(gembed, gcembed, &vm, DIM);
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) gembed[i]*=gsc;
// Global gradient norm
float grad_norm_sq = 0;
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
float s;
vDSP_dotpr(g->Wq,1,g->Wq,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wo,1,g->Wo,1,&s,(vDSP_Length)WO_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->W1,1,g->W1,1,&s,(vDSP_Length)W1_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->W2,1,g->W2,1,&s,(vDSP_Length)W2_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->W3,1,g->W3,1,&s,(vDSP_Length)W3_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->rms_att,1,g->rms_att,1,&s,(vDSP_Length)DIM); grad_norm_sq+=s;
vDSP_dotpr(g->rms_ffn,1,g->rms_ffn,1,&s,(vDSP_Length)DIM); grad_norm_sq+=s;
}
{ float s;
vDSP_dotpr(grms_final,1,grms_final,1,&s,(vDSP_Length)DIM); grad_norm_sq+=s;
vDSP_dotpr(gembed,1,gembed,1,&s,(vDSP_Length)(VOCAB*DIM)); grad_norm_sq+=s;
}
float grad_norm = sqrtf(grad_norm_sq);
if ((step+1) % 10 == 0) printf(" grad_norm=%.4f\n", grad_norm);
// Gradient clipping
if (grad_clip > 0 && grad_norm > grad_clip) {
float clip_scale = grad_clip / grad_norm;
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
vDSP_vsmul(g->Wq,1,&clip_scale,g->Wq,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wk,1,&clip_scale,g->Wk,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wv,1,&clip_scale,g->Wv,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wo,1,&clip_scale,g->Wo,1,(vDSP_Length)WO_SZ);
vDSP_vsmul(g->W1,1,&clip_scale,g->W1,1,(vDSP_Length)W1_SZ);
vDSP_vsmul(g->W2,1,&clip_scale,g->W2,1,(vDSP_Length)W2_SZ);
vDSP_vsmul(g->W3,1,&clip_scale,g->W3,1,(vDSP_Length)W3_SZ);
vDSP_vsmul(g->rms_att,1,&clip_scale,g->rms_att,1,(vDSP_Length)DIM);
vDSP_vsmul(g->rms_ffn,1,&clip_scale,g->rms_ffn,1,(vDSP_Length)DIM);
}
vDSP_vsmul(grms_final,1,&clip_scale,grms_final,1,(vDSP_Length)DIM);
vDSP_vsmul(gembed,1,&clip_scale,gembed,1,(vDSP_Length)(VOCAB*DIM));
}
// Cosine LR schedule with warmup
if (step < warmup_steps) {
lr = max_lr * ((float)(step + 1)) / warmup_steps;
} else {
float decay_ratio = (float)(step - warmup_steps) / (float)(total_steps - warmup_steps);
float min_lr = max_lr * min_lr_frac;
lr = min_lr + 0.5f * (1.0f + cosf(M_PI * decay_ratio)) * (max_lr - min_lr);
}
// Adam update
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
adam_update(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps);
// Update transposed weight buffers
transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM);
transpose_weight(Wkt_buf[L], lw[L].Wk, DIM, DIM);
transpose_weight(Wvt_buf[L], lw[L].Wv, DIM, DIM);
transpose_weight(Wot_buf[L], lw[L].Wo, DIM, DIM);
transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM);
transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN);
transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM);
}
adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps);
// Re-extract compact embed from updated full embed
free(cembed);
cembed = vocab_compact_embed(embed, &vm, DIM);
// Zero grads
for (int L=0; L<NLAYERS; L++) layer_grads_zero(&grads[L]);
memset(grms_final, 0, DIM*4);
memset(gembed, 0, (size_t)VOCAB*DIM*4);
memset(gcembed, 0, (size_t)CV*DIM*4);
// Checkpoint
if ((step+1) % 100 == 0) {
double wall = tb_ms(mach_absolute_time() - t_wall_start);
save_checkpoint(CKPT_PATH, step+1, total_steps, lr, last_loss,
total_train_ms+cum_train, wall+cum_wall, total_steps_done+cum_steps, adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
}
}
}
// Report
double wall = tb_ms(mach_absolute_time() - t_wall_start);
printf("\n=== Efficiency Report ===\n");
printf("Total steps: %d\n", total_steps_done);
printf("Compile: %.0fms (one-time, %.1f%%)\n", compile_ms, 100*compile_ms/(wall+cum_wall));
printf("Train time: %.0fms (%.1fms/step)\n", total_train_ms, total_train_ms/total_steps_done);
printf("Wall time: %.1fs\n", (wall+cum_wall)/1000);
// Cleanup
for (int L=0; L<NLAYERS; L++) {
layer_weights_free(&lw[L]); layer_adam_free(&la[L]);
layer_acts_free(&acts[L]); layer_grads_free(&grads[L]);
free(Wqt_buf[L]); free(Wkt_buf[L]); free(Wvt_buf[L]); free(Wot_buf[L]);
free(W1t_buf[L]); free(W2t_buf[L]); free(W3t_buf[L]);
}
free_kern(dk.sdpaFwd); free_kern(dk.ffnW13); free_kern(dk.ffnW2);
free_kern(dk.ffnBwdW2t); free_kern(dk.ffnBwdW13t); free_kern(dk.wotBwd);
free_kern(dk.sdpaBwd1); free_kern(dk.sdpaBwd2); free_kern(dk.qkvBwd);
munmap(token_data, data_len); close(data_fd);
}
return 0;
}