// train_large.m — Train stories110M (12 layers, 768dim, 3072hidden) on ANE
// Uses pretokenized TinyStories data with cross-entropy loss
// 5 weight-bearing ANE kernels per layer × 12 layers = 60 per compile batch
#include <stdio.h>
#include "stories_io.h"
#include "stories_mil.h"
#include "stories_cpu_ops.h"

#define CKPT_PATH "ane_stories110M_ckpt.bin"
#define MODEL_PATH "../../assets/models/stories110M.bin"
#define DATA_PATH "tinystories_data00.bin"

// ===== Small file helpers =====

// Read exactly n floats from f into dst. Returns false on a short read (EOF or I/O error).
static bool file_read_f32(FILE *f, float *dst, size_t n) {
    return fread(dst, sizeof(float), n, f) == n;
}

// Write exactly n floats from src to f. Returns false on a short write.
// The pointer is non-const only so this shares a signature with file_read_f32
// (see CkptTensorIO); the buffer is never modified.
static bool file_write_f32(FILE *f, float *src, size_t n) {
    return fwrite(src, sizeof(float), n, f) == n;
}

// ===== Weight loading from llama2.c format =====

// Load llama2.c-format weights from `path` into lw[0..NLAYERS), rms_final and embed.
// Returns false (after printing why) if the file cannot be opened, its header does
// not match the compiled-in model shape, or it ends before all tensors are read.
// On a short read the destination buffers may already be partially overwritten.
// NOTE(review): embed is assumed to hold VOCAB*DIM floats, which is why the
// vocab size must match exactly.
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
    FILE *f = fopen(path, "rb");
    if (!f) { printf("Cannot open %s\n", path); return false; }

    Llama2Config cfg;
    if (fread(&cfg, sizeof(cfg), 1, f) != 1) {
        printf(" ERROR: %s is too short to hold a model header\n", path);
        fclose(f);
        return false;
    }
    int V = abs(cfg.vocab_size);
    // llama2.c convention: a positive vocab_size means the classifier shares the embedding.
    bool shared = cfg.vocab_size > 0;
    printf(" Model config: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
           cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, V, cfg.seq_len);

    // Every read below is sized by the compile-time constants, so any mismatch
    // would mis-slice the file; a vocab larger than VOCAB would also overrun `embed`.
    if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS ||
        cfg.n_heads != HEADS || V != VOCAB) {
        printf(" ERROR: Config mismatch! Expected dim=%d hidden=%d layers=%d heads=%d vocab=%d\n",
               DIM, HIDDEN, NLAYERS, HEADS, VOCAB);
        fclose(f);
        return false;
    }

    // llama2.c stores each tensor kind for all layers contiguously, in this order:
    // embed, rms_att[], wq[], wk[], wv[], wo[], rms_ffn[], w1[], w2[], w3[], rms_final, [wcls]
    bool ok = file_read_f32(f, embed, (size_t)V * DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = file_read_f32(f, lw[L].rms_att, DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = file_read_f32(f, lw[L].Wq, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = file_read_f32(f, lw[L].Wk, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = file_read_f32(f, lw[L].Wv, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = file_read_f32(f, lw[L].Wo, WO_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = file_read_f32(f, lw[L].rms_ffn, DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = file_read_f32(f, lw[L].W1, W1_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = file_read_f32(f, lw[L].W2, W2_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = file_read_f32(f, lw[L].W3, W3_SZ);
    ok = ok && file_read_f32(f, rms_final, DIM);
    // A separate wcls (if present) is intentionally not read: the classifier uses `embed`.
    fclose(f);

    if (!ok) {
        printf(" ERROR: %s is truncated or unreadable\n", path);
        return false;
    }
    printf(" Loaded pretrained weights (%s)\n", shared ? "shared embed/cls" : "separate cls");
    return true;
}

// ===== Compile one layer's kernels =====

// Number of entries in a local input-size array.
#define NINS(a) ((int)(sizeof(a) / sizeof((a)[0])))

// Compile the five weight-bearing ANE kernels for one layer.
// Each *_ins array lists the byte size of every input surface (fp16, hence *2)
// in the order the kernel consumes them. The order must agree with the input
// indices used in update_ane_weights. The final argument is the output surface
// size in bytes. Returns true only if all five kernels compiled.
static bool compile_layer_kernels(LayerKernels *lk) {
    // fwdAttn: x, rms_att, Wq, Wk, Wv, Wo, mask
    int fwdAttn_ins[] = { DIM*SEQ*2, DIM*2, WQ_SZ*2, WQ_SZ*2, WQ_SZ*2, WO_SZ*2, SEQ*SEQ*2 };
    lk->fwdAttn = compile_kern_mil_w(gen_sdpa_fwd_flex(), @{}, fwdAttn_ins, NINS(fwdAttn_ins), 6*DIM*SEQ*2);

    // fwdFFN: x, rms_ffn, W1, W2, W3.
    // W2 is DIM x HIDDEN (W2_SZ), matching the DIM x HIDDEN write in update_ane_weights;
    // sizing it as WO_SZ (DIM x DIM) made that write overrun the surface.
    int fwdFFN_ins[] = { DIM*SEQ*2, DIM*2, W1_SZ*2, W2_SZ*2, W3_SZ*2 };
    lk->fwdFFN = compile_kern_mil_w(gen_ffn_fwd_flex(), @{}, fwdFFN_ins, NINS(fwdFFN_ins), (2*DIM+3*HIDDEN)*SEQ*2);

    // ffnBwd: [dffn | h1 | h3], W1^T, W2^T, W3^T
    int ffnBwd_ins[] = { (DIM+2*HIDDEN)*SEQ*2, W1_SZ*2, W2_SZ*2, W3_SZ*2 };
    lk->ffnBwd = compile_kern_mil_w(gen_ffn_bwd_flex(), @{}, ffnBwd_ins, NINS(ffnBwd_ins), (DIM+2*HIDDEN)*SEQ*2);

    // sdpaBwd1: [Q | K | V | dx2], Wo^T, mask
    int sdpaBwd1_ins[] = { 4*DIM*SEQ*2, WO_SZ*2, SEQ*SEQ*2 };
    lk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1_flex(), @{}, sdpaBwd1_ins, NINS(sdpaBwd1_ins), (DIM+2*SCORE_CH)*SEQ*2);

    // qkvBwd: [dq | dk | dv], Wq^T, Wk^T, Wv^T
    int qkvBwd_ins[] = { 3*DIM*SEQ*2, WQ_SZ*2, WQ_SZ*2, WQ_SZ*2 };
    lk->qkvBwd = compile_kern_mil_w(gen_qkvb_flex(), @{}, qkvBwd_ins, NINS(qkvBwd_ins), DIM*SEQ*2);

    return lk->fwdAttn && lk->fwdFFN && lk->ffnBwd && lk->sdpaBwd1 && lk->qkvBwd;
}

#undef NINS

// Copy the SEQ*SEQ*2-byte attention mask into `surf`. The mask blob is built
// once on first use. Its first 128 bytes are skipped (presumably a blob
// header — TODO confirm against get_mask_blob).
static void upload_attn_mask(IOSurfaceRef surf) {
    static NSData *m_blob = nil;
    if (!m_blob) m_blob = get_mask_blob();
    IOSurfaceLock(surf, 0, NULL);
    memcpy(IOSurfaceGetBaseAddress(surf), (const uint8_t *)[m_blob bytes] + 128, SEQ*SEQ*2);
    IOSurfaceUnlock(surf, 0, NULL);
}

// Write one layer's fp32 weights (converted to fp16) into the weight inputs of
// its compiled kernels. Forward kernels receive the weights as stored. Backward
// kernels receive transposed copies (io_write_fp16_t). Input 0 of each kernel
// (activations) is left untouched. Indices must match compile_layer_kernels.
static void update_ane_weights(LayerKernels *lk, LayerWeights *w) {
    // fwdAttn: x(0), rw(1), Wq(2), Wk(3), Wv(4), Wo(5), cm(6)
    io_write_fp16(lk->fwdAttn->inputs[1], w->rms_att, 1, DIM);
    io_write_fp16(lk->fwdAttn->inputs[2], w->Wq, DIM, DIM);
    io_write_fp16(lk->fwdAttn->inputs[3], w->Wk, DIM, DIM);
    io_write_fp16(lk->fwdAttn->inputs[4], w->Wv, DIM, DIM);
    io_write_fp16(lk->fwdAttn->inputs[5], w->Wo, DIM, DIM);
    upload_attn_mask(lk->fwdAttn->inputs[6]);

    // fwdFFN: x(0), rw(1), W1(2), W2(3), W3(4)
    io_write_fp16(lk->fwdFFN->inputs[1], w->rms_ffn, 1, DIM);
    io_write_fp16(lk->fwdFFN->inputs[2], w->W1, HIDDEN, DIM);
    io_write_fp16(lk->fwdFFN->inputs[3], w->W2, DIM, HIDDEN);
    io_write_fp16(lk->fwdFFN->inputs[4], w->W3, HIDDEN, DIM);

    // ffnBwd: x(0), W1t(1), W2t(2), W3t(3)
    io_write_fp16_t(lk->ffnBwd->inputs[1], w->W1, HIDDEN, DIM);
    io_write_fp16_t(lk->ffnBwd->inputs[2], w->W2, DIM, HIDDEN);
    io_write_fp16_t(lk->ffnBwd->inputs[3], w->W3, HIDDEN, DIM);

    // sdpaBwd1: x(0), Wot(1), cm(2)
    io_write_fp16_t(lk->sdpaBwd1->inputs[1], w->Wo, DIM, DIM);
    upload_attn_mask(lk->sdpaBwd1->inputs[2]);

    // qkvBwd: x(0), Wqt(1), Wkt(2), Wvt(3)
    io_write_fp16_t(lk->qkvBwd->inputs[1], w->Wq, DIM, DIM);
    io_write_fp16_t(lk->qkvBwd->inputs[2], w->Wk, DIM, DIM);
    io_write_fp16_t(lk->qkvBwd->inputs[3], w->Wv, DIM, DIM);
}

// Compile the sdpaBwd2 kernel. It has no weight inputs, so its surfaces never
// need refreshing by update_ane_weights.
// Input: [scores (2*SCORE_CH) | Q | K]; output: [dq | dk].
static Kern *compile_sdpa_bwd2(void) {
    int bwd2_ins[] = { (2*SCORE_CH+2*DIM)*SEQ*2 };
    return compile_kern_mil_w(gen_sdpa_bwd2_flex(), @{}, bwd2_ins, 1, 2*DIM*SEQ*2);
}

// Release one layer's five weight-bearing kernels and null the handles so a
// repeated free or a later compile_layer_kernels starts from a clean state.
static void free_layer_kernels(LayerKernels *lk) {
    free_kern(lk->fwdAttn);
    free_kern(lk->fwdFFN);
    free_kern(lk->ffnBwd);
    free_kern(lk->sdpaBwd1);
    free_kern(lk->qkvBwd);
    lk->fwdAttn = lk->fwdFFN = lk->ffnBwd = lk->sdpaBwd1 = lk->qkvBwd = NULL;
}

// ===== Checkpoint save/load =====
// File layout: CkptHdr, then for each layer
//   Wq Wk Wv Wo W1 W2 W3 rms_att rms_ffn, followed by Adam (m, v) for each of
//   those in the same order,
// then rms_final + its Adam (m, v), then embed + its Adam (m, v). All tensors are raw fp32.

enum { TRAIN_CKPT_MAGIC = 0x424C5A54, TRAIN_CKPT_VERSION = 2 };

// Direction-agnostic tensor transfer: file_read_f32 or file_write_f32.
typedef bool (*CkptTensorIO)(FILE *f, float *buf, size_t n);

// Transfer an Adam state's first (m) then second (v) moment buffers.
static bool ckpt_adam_io(FILE *f, CkptTensorIO io, AdamState *s, size_t n) {
    return io(f, s->m, n) && io(f, s->v, n);
}

// Walk every checkpoint tensor in file order, applying `io` to each.
// Save and load share this single walker, so their orders cannot drift apart.
// Returns false at the first short read/write.
static bool ckpt_tensors_io(FILE *f, CkptTensorIO io, LayerWeights *lw, LayerAdam *la,
                            float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    for (int L = 0; L < NLAYERS; L++) {
        LayerWeights *w = &lw[L];
        LayerAdam *a = &la[L];
        bool ok = io(f, w->Wq, WQ_SZ) && io(f, w->Wk, WQ_SZ) && io(f, w->Wv, WQ_SZ) &&
                  io(f, w->Wo, WO_SZ) && io(f, w->W1, W1_SZ) && io(f, w->W2, W2_SZ) &&
                  io(f, w->W3, W3_SZ) && io(f, w->rms_att, DIM) && io(f, w->rms_ffn, DIM) &&
                  ckpt_adam_io(f, io, &a->Wq, WQ_SZ) && ckpt_adam_io(f, io, &a->Wk, WQ_SZ) &&
                  ckpt_adam_io(f, io, &a->Wv, WQ_SZ) && ckpt_adam_io(f, io, &a->Wo, WO_SZ) &&
                  ckpt_adam_io(f, io, &a->W1, W1_SZ) && ckpt_adam_io(f, io, &a->W2, W2_SZ) &&
                  ckpt_adam_io(f, io, &a->W3, W3_SZ) && ckpt_adam_io(f, io, &a->rms_att, DIM) &&
                  ckpt_adam_io(f, io, &a->rms_ffn, DIM);
        if (!ok) return false;
    }
    return io(f, rms_final, DIM) && ckpt_adam_io(f, io, arms_final, DIM) &&
           io(f, embed, (size_t)VOCAB * DIM) && ckpt_adam_io(f, io, aembed, (size_t)VOCAB * DIM);
}

// Save the full training state (header counters, weights, Adam moments) to `path`.
// The data goes to "<path>.tmp", which is renamed over `path` only after every
// byte is written and the file is closed. A failed save (unwritable directory,
// full disk, crash mid-write) therefore leaves any previous checkpoint intact.
// Failures are reported on stdout and are otherwise non-fatal.
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
                            double cc, double ct, double cw, int cs, int cb, int adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    char tmp_path[1024];
    int n = snprintf(tmp_path, sizeof tmp_path, "%s.tmp", path);
    if (n < 0 || (size_t)n >= sizeof tmp_path) {
        printf("Checkpoint path too long: %s\n", path);
        return;
    }
    FILE *f = fopen(tmp_path, "wb");
    if (!f) {
        printf("Cannot open %s for writing\n", tmp_path);
        return;
    }

    CkptHdr h = {0};
    h.magic = TRAIN_CKPT_MAGIC; h.version = TRAIN_CKPT_VERSION;
    h.step = step; h.total_steps = total_steps;
    h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
    h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
    h.lr = lr; h.loss = loss;
    h.cum_compile = cc; h.cum_train = ct; h.cum_wall = cw;
    h.cum_steps = cs; h.cum_batches = cb; h.adam_t = adam_t;

    bool ok = fwrite(&h, sizeof(h), 1, f) == 1 &&
              ckpt_tensors_io(f, file_write_f32, lw, la, rms_final, arms_final, embed, aembed);
    // fclose flushes buffered data, so its result is part of the success check.
    if (fclose(f) != 0) ok = false;

    if (!ok || rename(tmp_path, path) != 0) {
        printf("Failed to save checkpoint %s\n", path);
        remove(tmp_path);
    }
}

// Load training state written by save_checkpoint.
// Returns false if the file is missing, has the wrong magic or version, was
// written for a different model shape, or is truncated. The scalar outputs are
// assigned only on success. On a truncated file the weight and Adam buffers may
// already be partially overwritten.
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
                            double *cc, double *ct, double *cw, int *cs, int *cb, int *adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "rb");
    if (!f) return false;

    CkptHdr h;
    if (fread(&h, sizeof(h), 1, f) != 1 ||
        h.magic != TRAIN_CKPT_MAGIC || h.version != TRAIN_CKPT_VERSION) {
        fclose(f);
        return false;
    }
    // Tensor sizes come from the compile-time constants, so a checkpoint written
    // for another model shape would be silently mis-sliced; reject it instead.
    if (h.n_layers != NLAYERS || h.dim != DIM || h.hidden_dim != HIDDEN || h.vocab_size != VOCAB) {
        printf("Checkpoint %s has a different model shape (layers=%d dim=%d hidden=%d vocab=%d)\n",
               path, (int)h.n_layers, (int)h.dim, (int)h.hidden_dim, (int)h.vocab_size);
        fclose(f);
        return false;
    }

    bool ok = ckpt_tensors_io(f, file_read_f32, lw, la, rms_final, arms_final, embed, aembed);
    fclose(f);
    if (!ok) {
        printf("Checkpoint %s is truncated or unreadable\n", path);
        return false;
    }

    *step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
    *cc = h.cum_compile; *ct = h.cum_train; *cw = h.cum_wall;
    *cs = h.cum_steps; *cb = h.cum_batches; *adam_t = h.adam_t;
    return true;
}
// ===== Main ===== int main(int argc, char *argv[]) { @autoreleasepool { setbuf(stdout, NULL); ane_init(); mach_timebase_info(&g_tb); int total_steps = 10000; float lr = 3e-4f; float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f; int adam_t = 0, start_step = 0; // Parse args bool do_resume = false; int cli_steps = -1; float cli_lr = -1; for (int i=1; i 0) total_steps = cli_steps; if (cli_lr > 0) lr = cli_lr; // Allocate per-layer state LayerWeights lw[NLAYERS]; LayerAdam la[NLAYERS]; LayerActs acts[NLAYERS]; LayerGrads grads[NLAYERS]; LayerKernels kern[NLAYERS]; for (int L=0; L 0) total_steps = cli_steps; if (cli_lr > 0) lr = cli_lr; } } if (!resuming) { printf("=== ANE Training: Stories110M (12 layers) ===\n"); printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS); if (!load_pretrained(lw, rms_final, embed, MODEL_PATH)) { printf("Pretrained load failed, using random init\n"); srand48(42); float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN); for (int L=0; Llayer_in, x_cur, SEQ*DIM*4); // Attention forward: x_cur → o_out,Q,K,V,attn_out,xnorm t0=mach_absolute_time(); dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER); t1=mach_absolute_time(); t_cblas_wait+=tb_ms(t1-t0); t0=t1; io_write_fp16(kern[L].fwdAttn->inputs[0], x_cur, DIM, SEQ); t1=mach_absolute_time(); 
t_io+=tb_ms(t1-t0); t0=t1; ane_eval(kern[L].fwdAttn); t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1; io_read_fp16(kern[L].fwdAttn->ioOut, ac->o_out, 0, DIM, SEQ); io_read_fp16(kern[L].fwdAttn->ioOut, ac->attn_out, 4*DIM, DIM, SEQ); io_read_fp16(kern[L].fwdAttn->ioOut, ac->xnorm, 5*DIM, DIM, SEQ); t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1; vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM)); t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1; // FFN forward io_write_fp16(kern[L].fwdFFN->inputs[0], ac->x2, DIM, SEQ); t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1; ane_eval(kern[L].fwdFFN); t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1; io_read_fp16(kern[L].fwdFFN->ioOut, ac->ffn_out, 0, DIM, SEQ); io_read_fp16(kern[L].fwdFFN->ioOut, ac->h1, DIM, HIDDEN, SEQ); io_read_fp16(kern[L].fwdFFN->ioOut, ac->h3, DIM+HIDDEN, HIDDEN, SEQ); io_read_fp16(kern[L].fwdFFN->ioOut, ac->silu_out, DIM+2*HIDDEN, HIDDEN, SEQ); io_read_fp16(kern[L].fwdFFN->ioOut, ac->x2norm, DIM+3*HIDDEN, DIM, SEQ); t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1; vDSP_vadd(ac->x2, 1, ac->ffn_out, 1, x_cur, 1, (vDSP_Length)(SEQ*DIM)); t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); } // Final RMSNorm (CPU) t0=mach_absolute_time(); rmsnorm(x_final, x_cur, rms_final, DIM, SEQ); t1=mach_absolute_time(); t_rms+=tb_ms(t1-t0); t0=t1; // Classifier: logits = embed^T @ x_final cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, VOCAB, SEQ, DIM, 1.0f, embed, DIM, x_final, SEQ, 0.0f, logits, SEQ); t1=mach_absolute_time(); t_cls+=tb_ms(t1-t0); t0=t1; // Cross-entropy loss float loss = cross_entropy_loss(dlogits, logits, target_tokens, VOCAB, SEQ); last_loss = loss; t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1; // ===== BACKWARD ===== // dlogits already computed by cross_entropy_loss // Classifier backward: dx_final = embed^T @ dlogits, dembed += dlogits @ x_final^T // dx_final[DIM,SEQ] = embed^T[DIM,VOCAB] @ dlogits[VOCAB,SEQ] cblas_sgemm(CblasRowMajor, 
CblasTrans, CblasNoTrans, DIM, SEQ, VOCAB, 1.0f, embed, DIM, dlogits, SEQ, 0.0f, dy, SEQ); // dembed[VOCAB,DIM] += dlogits[VOCAB,SEQ] @ x_final^T[SEQ,DIM] dispatch_group_async(dw_grp, dw_q, ^{ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, VOCAB, DIM, SEQ, 1.0f, dlogits, SEQ, x_final, SEQ, 1.0f, gembed, DIM); }); // Final RMSNorm backward float *dx_rms_final = (float*)calloc(SEQ*DIM, 4); rmsnorm_bwd(dx_rms_final, grms_final, dy, x_cur, rms_final, DIM, SEQ); memcpy(dy, dx_rms_final, SEQ*DIM*4); free(dx_rms_final); // ===== BACKWARD (12 layers, reverse) ===== for (int L=NLAYERS-1; L>=0; L--) { LayerActs *ac = &acts[L]; LayerGrads *gr = &grads[L]; // dy is the gradient at the output of this layer // dffn = dy (residual connection: d(x2 + ffn) = dy for both) memcpy(dffn, dy, SEQ*DIM*4); // FFN backward (ANE) io_write_fp16_at(kern[L].ffnBwd->inputs[0], 0, dffn, DIM, SEQ); io_copy(kern[L].ffnBwd->inputs[0], DIM, kern[L].fwdFFN->ioOut, DIM, 2*HIDDEN, SEQ); ane_eval(kern[L].ffnBwd); io_read_fp16(kern[L].ffnBwd->ioOut, dx_ffn, 0, DIM, SEQ); io_read_fp16(kern[L].ffnBwd->ioOut, dh1, DIM, HIDDEN, SEQ); io_read_fp16(kern[L].ffnBwd->ioOut, dh3, DIM+HIDDEN, HIDDEN, SEQ); // dW FFN async float *capt_dffn = (float*)malloc(SEQ*DIM*4); memcpy(capt_dffn, dffn, SEQ*DIM*4); float *capt_silu = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_silu, ac->silu_out, SEQ*HIDDEN*4); float *capt_dh1 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh1, dh1, SEQ*HIDDEN*4); float *capt_dh3 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh3, dh3, SEQ*HIDDEN*4); float *capt_x2n = (float*)malloc(SEQ*DIM*4); memcpy(capt_x2n, ac->x2norm, SEQ*DIM*4); dispatch_group_async(dw_grp, dw_q, ^{ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, HIDDEN, SEQ, 1.0f, capt_dffn, SEQ, capt_silu, SEQ, 1.0f, gr->W2, HIDDEN); cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ, 1.0f, capt_dh1, SEQ, capt_x2n, SEQ, 1.0f, gr->W1, DIM); cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ, 
1.0f, capt_dh3, SEQ, capt_x2n, SEQ, 1.0f, gr->W3, DIM); free(capt_dffn); free(capt_silu); free(capt_dh1); free(capt_dh3); free(capt_x2n); }); // RMSNorm2 backward memset(dx2, 0, SEQ*DIM*4); rmsnorm_bwd(dx2, gr->rms_ffn, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ); // Add residual: dx2 += dy (from skip connection) for(int i=0;iattn_out, SEQ*DIM*4); dispatch_group_async(dw_grp, dw_q, ^{ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ, 1.0f, capt_do, SEQ, capt_attn, SEQ, 1.0f, gr->Wo, DIM); free(capt_do); free(capt_attn); }); // SDPA backward (ANE) io_copy(kern[L].sdpaBwd1->inputs[0], 0, kern[L].fwdAttn->ioOut, DIM, 3*DIM, SEQ); io_write_fp16_at(kern[L].sdpaBwd1->inputs[0], 3*DIM, dx2, DIM, SEQ); ane_eval(kern[L].sdpaBwd1); io_copy(sdpaBwd2[L]->inputs[0], 0, kern[L].sdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ); io_copy(sdpaBwd2[L]->inputs[0], 2*SCORE_CH, kern[L].fwdAttn->ioOut, DIM, 2*DIM, SEQ); ane_eval(sdpaBwd2[L]); io_read_fp16(sdpaBwd2[L]->ioOut, dq, 0, DIM, SEQ); io_read_fp16(sdpaBwd2[L]->ioOut, dk, DIM, DIM, SEQ); io_read_fp16(kern[L].sdpaBwd1->ioOut, dv, 0, DIM, SEQ); // dWq/dWk/dWv async float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4); float *capt_dk = (float*)malloc(SEQ*DIM*4); memcpy(capt_dk, dk, SEQ*DIM*4); float *capt_dv = (float*)malloc(SEQ*DIM*4); memcpy(capt_dv, dv, SEQ*DIM*4); float *capt_xn = (float*)malloc(SEQ*DIM*4); memcpy(capt_xn, ac->xnorm, SEQ*DIM*4); dispatch_group_async(dw_grp, dw_q, ^{ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ, 1.0f, capt_dq, SEQ, capt_xn, SEQ, 1.0f, gr->Wq, DIM); cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ, 1.0f, capt_dk, SEQ, capt_xn, SEQ, 1.0f, gr->Wk, DIM); cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ, 1.0f, capt_dv, SEQ, capt_xn, SEQ, 1.0f, gr->Wv, DIM); free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn); }); // QKV backward (ANE) io_copy(kern[L].qkvBwd->inputs[0], 0, sdpaBwd2[L]->ioOut, 0, 2*DIM, SEQ); 
io_copy(kern[L].qkvBwd->inputs[0], 2*DIM, kern[L].sdpaBwd1->ioOut, 0, DIM, SEQ); ane_eval(kern[L].qkvBwd); io_read_fp16(kern[L].qkvBwd->ioOut, dx_attn, 0, DIM, SEQ); // RMSNorm1 backward (using saved layer input) float *dx_rms1 = (float*)calloc(SEQ*DIM, 4); rmsnorm_bwd(dx_rms1, gr->rms_att, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ); // dy for next layer (going backward) = dx_rms1 + dx2 residual // Actually: layer output = layer_input + o_out, and x2 = layer_input + o_out // So dx(layer_input) = dx_attn_rmsnorm + dx2 (residual from attn skip) // Wait, dx2 already includes the attn skip residual gradient. // dy = dx_rms1 (through rmsnorm1) is the gradient to the layer input // But there's also the skip connection: layer_input → x2 directly // So total gradient to layer_input = dx_rms1 + dx2_skip // dx2 was computed as rmsnorm2_bwd + dy(ffn_skip), which already flows to x2 // x2 = layer_input + o_out, so d(layer_input) from x2 path = dx2 // And d(layer_input) from attn path through rmsnorm1 = dx_rms1 // Total: dy_prev = dx_rms1 (attn rmsnorm path) // Wait no - dx2 = d(loss)/d(x2), not d(loss)/d(layer_input) // d(layer_input) = d(loss)/d(x2) * d(x2)/d(layer_input) = dx2 (since x2 = input + o_out, d(x2)/d(input) = 1) // Plus the path through rmsnorm1: dx_rms1 // Hmm but dx2 was already used as input to SDPA backward... let me reconsider. 
// // Actually the gradient flow is: // dy → split to (dffn, dy_skip) [dy_skip = dy due to residual] // dffn → ffnBwd → dx_ffn // dx_ffn → rmsnorm2_bwd → dx_rms2 // dx2 = dx_rms2 + dy (skip connection from residual x2 → output) // dx2 → sdpaBwd → dx_attn through Wo^T // dx_attn → qkvBwd → dx_qkv // dx_qkv → rmsnorm1_bwd → dx_rms1 // dy_prev_layer = dx_rms1 + dx2 (skip connection input → x2) // // So: dy for previous layer = dx_rms1 + dx2 for(int i=0;iWq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;} for(size_t i=0;iW1[i]*=gsc; for(size_t i=0;iW2[i]*=gsc; for(size_t i=0;iW3[i]*=gsc; for(int i=0;irms_att[i]*=gsc; g->rms_ffn[i]*=gsc;} adam_update(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps); } for(int i=0;i