// train.m — Dynamic weight ANE training (model-agnostic GQA support)
// Model selected at compile time via: make MODEL=qwen3_06b (or stories110m)
// Compile kernels ONCE at startup, update weights via IOSurface every step.
#include "mil_dynamic.h"
#include "cpu_ops.h"

// Dynamic kernel set per layer
typedef struct {
    Kern *sdpaFwd;     // QKV matmul + RoPE + GQA tile + SDPA (no Wo)
    Kern *woFwd;       // attn_out @ Wo^T → o_out (Q_DIM → DIM)
    Kern *ffnFused;    // W1,W3 + SiLU + W2 + residual (fused)
    Kern *ffnBwdW2t;   // dffn @ W2^T → dsilu_raw (DIM → HIDDEN)
    Kern *ffnBwdW13t;  // dh1@W1^T + dh3@W3^T → dx_ffn (HIDDEN → DIM)
    Kern *wotBwd;      // dx2 @ Wo → da (DIM → Q_DIM)
    Kern *sdpaBwd1;    // Q,K,V,da → dV_full,probs,dp (weight-free, has mask)
    Kern *sdpaBwd2;    // probs,dp,Q,K → dQ,dK_full (weight-free)
    Kern *qBwd;        // dq @ Wq → dx_q (Q_DIM → DIM)
    Kern *kvBwd;       // dk@Wk + dv@Wv → dx_kv (KV_DIM → DIM)
} DynLayerKernels;

// Transpose W[rows,cols] → W^T[cols,rows] stored as [cols channels, rows spatial]
// `dst` and `src` must each hold rows*cols floats and must not alias.
static void transpose_weight(float *dst, const float *src, int rows, int cols) {
    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < cols; c++) {
            dst[c * rows + r] = src[r * cols + c];
        }
    }
}

// ===== Compile all dynamic kernels (ONCE) =====
// Fills every slot of *dk by compiling the matching MIL program.
// Returns false as soon as one compile yields NULL; slots compiled before the
// failure stay populated and are NOT released here.
// NOTE(review): the trailing "*2" on every size looks like bytes per fp16
// element — confirm against compile_kern_mil_w's contract.
static bool compile_dynamic_kernels(DynLayerKernels *dk) {
    NSDictionary *mask_w = @{
        @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}
    };
    NSDictionary *sdpa_fwd_w = @{
        @"@model_path/weights/mask.bin":     @{@"offset":@0, @"data":get_mask_blob()},
        @"@model_path/weights/rope_cos.bin": @{@"offset":@0, @"data":get_rope_cos_blob()},
        @"@model_path/weights/rope_sin.bin": @{@"offset":@0, @"data":get_rope_sin_blob()}
    };
    int sdpa_out_ch = Q_DIM + Q_DIM + KV_DIM + KV_DIM + DIM;
    int ffn_fused_och = DIM + 3*HIDDEN;

    // SDPA forward (no Wo): [1, DIM, 1, SDPA_FWD_SP] → [1, sdpa_out_ch, 1, SEQ]
    printf(" Compiling sdpaFwd (GQA)...\n");
    dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), sdpa_fwd_w,
                                     DIM*SDPA_FWD_SP*2, sdpa_out_ch*SEQ*2);
    if (!dk->sdpaFwd) return false;

    // Wo forward: [1, Q_DIM, 1, SEQ+DIM] → [1, DIM, 1, SEQ]
    printf(" Compiling woFwd...\n");
    dk->woFwd = compile_kern_mil_w(gen_wo_fwd_dynamic(), @{},
                                   Q_DIM*WO_FWD_SP*2, DIM*SEQ*2);
    if (!dk->woFwd) return false;

    // Fused FFN: [1, DIM, 1, FFN_FUSED_SP] → [1, DIM+3*HIDDEN, 1, SEQ]
    printf(" Compiling ffnFused...\n");
    dk->ffnFused = compile_kern_mil_w(gen_ffn_fused_dynamic(), @{},
                                      DIM*FFN_FUSED_SP*2, ffn_fused_och*SEQ*2);
    if (!dk->ffnFused) return false;

    // FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] → [1, HIDDEN, 1, SEQ]
    printf(" Compiling ffnBwdW2t...\n");
    dk->ffnBwdW2t = compile_kern_mil_w(gen_ffn_bwd_w2t_dynamic(), @{},
                                       DIM*FFN_BWD_W2T_SP*2, HIDDEN*SEQ*2);
    if (!dk->ffnBwdW2t) return false;

    // FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] → [1, DIM, 1, SEQ]
    printf(" Compiling ffnBwdW13t...\n");
    dk->ffnBwdW13t = compile_kern_mil_w(gen_ffn_bwd_w13t_dynamic(), @{},
                                        HIDDEN*FFN_BWD_W13T_SP*2, DIM*SEQ*2);
    if (!dk->ffnBwdW13t) return false;

    // Wo^T backward: [1, DIM, 1, SEQ+Q_DIM] → [1, Q_DIM, 1, SEQ]
    printf(" Compiling wotBwd...\n");
    dk->wotBwd = compile_kern_mil_w(gen_wot_dynamic(), @{},
                                    DIM*WOT_BWD_SP*2, Q_DIM*SEQ*2);
    if (!dk->wotBwd) return false;

    // SDPA bwd1 (weight-free, has mask): [1, 4*Q_DIM, 1, SEQ] → [1, Q_DIM+2*SCORE_CH, 1, SEQ]
    printf(" Compiling sdpaBwd1 (GQA)...\n");
    dk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1_noweight(), mask_w,
                                      4*Q_DIM*SEQ*2, (Q_DIM+2*SCORE_CH)*SEQ*2);
    if (!dk->sdpaBwd1) return false;

    // SDPA bwd2 (weight-free): [1, 2*SCORE_CH+2*Q_DIM, 1, SEQ] → [1, 2*Q_DIM, 1, SEQ]
    printf(" Compiling sdpaBwd2 (GQA)...\n");
    dk->sdpaBwd2 = compile_kern_mil_w(gen_sdpa_bwd2(), @{},
                                      (2*SCORE_CH+2*Q_DIM)*SEQ*2, 2*Q_DIM*SEQ*2);
    if (!dk->sdpaBwd2) return false;

    // Q backward: [1, Q_DIM, 1, SEQ+DIM] → [1, DIM, 1, SEQ]
    printf(" Compiling qBwd...\n");
    dk->qBwd = compile_kern_mil_w(gen_q_bwd_dynamic(), @{},
                                  Q_DIM*Q_BWD_SP*2, DIM*SEQ*2);
    if (!dk->qBwd) return false;

    // KV backward: [1, KV_DIM, 1, 2*SEQ+2*DIM] → [1, DIM, 1, SEQ]
    printf(" Compiling kvBwd...\n");
    dk->kvBwd = compile_kern_mil_w(gen_kv_bwd_dynamic(), @{},
                                   KV_DIM*KV_BWD_SP*2, DIM*SEQ*2);
    if (!dk->kvBwd) return false;

    return true;
}

// ===== Checkpoint =====
// File layout: one CkptHdr, then per layer the 9 weight tensors followed by
// their Adam (m, v) pairs, then rms_final + Adam state, then embed + Adam state.
// Every payload tensor is stored as 4-byte elements.
enum {
    TRAIN_CKPT_MAGIC      = 0x424C5A54,
    TRAIN_CKPT_VERSION    = 4,
    TRAIN_CKPT_ELEM_BYTES = 4
};

// Write `count` 4-byte elements from `src`. Clears *ok on a short write and
// becomes a no-op once *ok is false, so callers can chain calls and check once.
static void ckpt_write_block(FILE *f, const void *src, size_t count, bool *ok) {
    if (*ok && fwrite(src, TRAIN_CKPT_ELEM_BYTES, count, f) != count) *ok = false;
}

// Write an Adam moment pair (m then v), each `count` elements long.
static void ckpt_write_adam(FILE *f, const void *m, const void *v, size_t count, bool *ok) {
    ckpt_write_block(f, m, count, ok);
    ckpt_write_block(f, v, count, ok);
}

// Read `count` 4-byte elements into `dst`. Clears *ok on a short read and
// becomes a no-op once *ok is false.
static void ckpt_read_block(FILE *f, void *dst, size_t count, bool *ok) {
    if (*ok && fread(dst, TRAIN_CKPT_ELEM_BYTES, count, f) != count) *ok = false;
}

// Read an Adam moment pair (m then v), each `count` elements long.
static void ckpt_read_adam(FILE *f, void *m, void *v, size_t count, bool *ok) {
    ckpt_read_block(f, m, count, ok);
    ckpt_read_block(f, v, count, ok);
}

// True when the header's model dimensions agree with this build's compile-time
// configuration. The payload sizes are fixed at compile time, so a file written
// by a differently configured build cannot be read correctly.
// seq_len is deliberately not compared.
static bool ckpt_header_matches_build(const CkptHdr *h) {
    if (h->n_layers != NLAYERS || h->vocab_size != VOCAB || h->dim != DIM ||
        h->hidden_dim != HIDDEN || h->n_heads != HEADS) {
        return false;
    }
    // GQA fields: 0 is tolerated as "not recorded".
    // NOTE(review): confirm whether any version-4 files predate these fields;
    // if none do, the zero allowance can be dropped.
    if (h->kv_heads != 0 && h->kv_heads != KV_HEADS) return false;
    if (h->head_dim != 0 && h->head_dim != HD)       return false;
    if (h->q_dim    != 0 && h->q_dim    != Q_DIM)    return false;
    return true;
}

// Serialize the full training state to `path` (overwriting any existing file).
//   step/total_steps/lr/loss/ct/cw/cs/adam_t — scalars copied into the header
//   lw, la          — NLAYERS entries of weights and their Adam state
//   rms_final/arms_final, embed/aembed — final norm and embedding with Adam state
// The function returns void; if the file cannot be opened or any write comes up
// short, a diagnostic is printed to stderr and the on-disk file must be treated
// as unusable.
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
                            double ct, double cw, int cs, int adam_t,
                            LayerWeights *lw, LayerAdam *la,
                            float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "wb");
    if (!f) {
        // Previously fell through and crashed inside fwrite on a NULL stream.
        fprintf(stderr, "save_checkpoint: cannot open '%s' for writing\n", path);
        return;
    }

    CkptHdr h = {0};
    h.magic = TRAIN_CKPT_MAGIC;
    h.version = TRAIN_CKPT_VERSION;
    h.step = step;
    h.total_steps = total_steps;
    h.n_layers = NLAYERS;
    h.vocab_size = VOCAB;
    h.dim = DIM;
    h.hidden_dim = HIDDEN;
    h.n_heads = HEADS;
    h.seq_len = SEQ;
    h.lr = lr;
    h.loss = loss;
    h.cum_train = ct;
    h.cum_wall = cw;
    h.cum_steps = cs;
    h.adam_t = adam_t;
    h.kv_heads = KV_HEADS;
    h.head_dim = HD;
    h.q_dim = Q_DIM;

    bool ok = fwrite(&h, sizeof(h), 1, f) == 1;

    for (int L = 0; L < NLAYERS; L++) {
        ckpt_write_block(f, lw[L].Wq, WQ_SZ, &ok);
        ckpt_write_block(f, lw[L].Wk, WK_SZ, &ok);
        ckpt_write_block(f, lw[L].Wv, WV_SZ, &ok);
        ckpt_write_block(f, lw[L].Wo, WO_SZ, &ok);
        ckpt_write_block(f, lw[L].W1, W1_SZ, &ok);
        ckpt_write_block(f, lw[L].W2, W2_SZ, &ok);
        ckpt_write_block(f, lw[L].W3, W3_SZ, &ok);
        ckpt_write_block(f, lw[L].rms_att, DIM, &ok);
        ckpt_write_block(f, lw[L].rms_ffn, DIM, &ok);

        ckpt_write_adam(f, la[L].Wq.m, la[L].Wq.v, WQ_SZ, &ok);
        ckpt_write_adam(f, la[L].Wk.m, la[L].Wk.v, WK_SZ, &ok);
        ckpt_write_adam(f, la[L].Wv.m, la[L].Wv.v, WV_SZ, &ok);
        ckpt_write_adam(f, la[L].Wo.m, la[L].Wo.v, WO_SZ, &ok);
        ckpt_write_adam(f, la[L].W1.m, la[L].W1.v, W1_SZ, &ok);
        ckpt_write_adam(f, la[L].W2.m, la[L].W2.v, W2_SZ, &ok);
        ckpt_write_adam(f, la[L].W3.m, la[L].W3.v, W3_SZ, &ok);
        ckpt_write_adam(f, la[L].rms_att.m, la[L].rms_att.v, DIM, &ok);
        ckpt_write_adam(f, la[L].rms_ffn.m, la[L].rms_ffn.v, DIM, &ok);
    }

    ckpt_write_block(f, rms_final, DIM, &ok);
    ckpt_write_adam(f, arms_final->m, arms_final->v, DIM, &ok);
    ckpt_write_block(f, embed, VOCAB*DIM, &ok);
    ckpt_write_adam(f, aembed->m, aembed->v, VOCAB*DIM, &ok);

    // fclose flushes buffered data, so a failure here also means data loss.
    if (fclose(f) != 0) ok = false;
    if (!ok) {
        fprintf(stderr, "save_checkpoint: write to '%s' failed; checkpoint is incomplete\n", path);
    }
}

// Restore training state from `path`.
// Returns true only if the file opens, the header has the expected magic and
// version, its dimensions match this build, and every tensor is read in full.
// The scalar out-parameters (*step … *adam_t) are written only on success.
// On a failure after the header checks, the tensor buffers may already be
// partially overwritten — the caller must re-initialise them before training.
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
                            double *ct, double *cw, int *cs, int *adam_t,
                            LayerWeights *lw, LayerAdam *la,
                            float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "rb");
    if (!f) return false;

    CkptHdr h;
    // Check the header read itself: a file shorter than CkptHdr used to leave
    // `h` uninitialised before its fields were inspected.
    if (fread(&h, sizeof(h), 1, f) != 1 ||
        h.magic != TRAIN_CKPT_MAGIC || h.version != TRAIN_CKPT_VERSION) {
        fclose(f);
        return false;
    }
    if (!ckpt_header_matches_build(&h)) {
        fprintf(stderr, "load_checkpoint: '%s' was written for a different model configuration\n", path);
        fclose(f);
        return false;
    }

    bool ok = true;
    for (int L = 0; L < NLAYERS; L++) {
        ckpt_read_block(f, lw[L].Wq, WQ_SZ, &ok);
        ckpt_read_block(f, lw[L].Wk, WK_SZ, &ok);
        ckpt_read_block(f, lw[L].Wv, WV_SZ, &ok);
        ckpt_read_block(f, lw[L].Wo, WO_SZ, &ok);
        ckpt_read_block(f, lw[L].W1, W1_SZ, &ok);
        ckpt_read_block(f, lw[L].W2, W2_SZ, &ok);
        ckpt_read_block(f, lw[L].W3, W3_SZ, &ok);
        ckpt_read_block(f, lw[L].rms_att, DIM, &ok);
        ckpt_read_block(f, lw[L].rms_ffn, DIM, &ok);

        ckpt_read_adam(f, la[L].Wq.m, la[L].Wq.v, WQ_SZ, &ok);
        ckpt_read_adam(f, la[L].Wk.m, la[L].Wk.v, WK_SZ, &ok);
        ckpt_read_adam(f, la[L].Wv.m, la[L].Wv.v, WV_SZ, &ok);
        ckpt_read_adam(f, la[L].Wo.m, la[L].Wo.v, WO_SZ, &ok);
        ckpt_read_adam(f, la[L].W1.m, la[L].W1.v, W1_SZ, &ok);
        ckpt_read_adam(f, la[L].W2.m, la[L].W2.v, W2_SZ, &ok);
        ckpt_read_adam(f, la[L].W3.m, la[L].W3.v, W3_SZ, &ok);
        ckpt_read_adam(f, la[L].rms_att.m, la[L].rms_att.v, DIM, &ok);
        ckpt_read_adam(f, la[L].rms_ffn.m, la[L].rms_ffn.v, DIM, &ok);
    }

    ckpt_read_block(f, rms_final, DIM, &ok);
    ckpt_read_adam(f, arms_final->m, arms_final->v, DIM, &ok);
    ckpt_read_block(f, embed, VOCAB*DIM, &ok);
    ckpt_read_adam(f, aembed->m, aembed->v, VOCAB*DIM, &ok);

    fclose(f);
    if (!ok) {
        fprintf(stderr, "load_checkpoint: '%s' is truncated or unreadable\n", path);
        return false;
    }

    // Publish the scalar state only once the whole payload is known to be good.
    *step = h.step;
    *total_steps = h.total_steps;
    *lr = h.lr;
    *loss = h.loss;
    *ct = h.cum_train;
    *cw = h.cum_wall;
    *cs = h.cum_steps;
    *adam_t = h.adam_t;
    return true;
}

int main(int argc, char *argv[]) {
    @autoreleasepool {
        setbuf(stdout, NULL);
        ane_init();
        mach_timebase_info(&g_tb);

        int total_steps = 10000;
        float max_lr = 3e-4f;
        float adam_b1=0.9f, adam_b2=0.95f, adam_eps=1e-8f, wd=0.1f;
        int adam_t = 0, start_step = 0;
        int accum_steps = 10;
        int warmup_steps = 100;
        float grad_clip = 1.0f;
        float loss_scale = 256.0f;
        float res_alpha = 1.0f / sqrtf(2.0f * NLAYERS);
        float min_lr_frac = 0.1f;
        bool do_resume = false, from_scratch = false;
        const char *data_path =
DEFAULT_DATA_PATH; for (int i=1; i