diff --git a/training/README.md b/training/README.md index a3f33eb..91cede5 100644 --- a/training/README.md +++ b/training/README.md @@ -1,14 +1,22 @@ -# ANE Training — Stories110M on Apple Neural Engine +# ANE Training — On-Device Training on Apple Neural Engine -Training a 109M-parameter Llama2-architecture transformer (Stories110M) directly on Apple's Neural Engine using private ANE APIs. +Training transformer models directly on Apple's Neural Engine using private ANE APIs. Supports multiple architectures including GQA (Grouped-Query Attention). ![Dashboard](dashboard.gif) +## Supported Models + +| Model | Layers | Heads (Q/KV) | Dim | Hidden | Params | ms/step | +|-------|--------|--------------|-----|--------|--------|---------| +| Stories110M | 12 | 12/12 (MHA) | 768 | 2048 | 109M | ~115 | +| Qwen3-0.6B | 28 | 16/8 (GQA) | 1024 | 3072 | 596M | ~412 | + +Model configs live in `training_dynamic/models/*.h`. To add a new model, create a header with the architecture defines (see below). + ## Architecture -- **Model**: Stories110M — dim=768, hidden=2048, heads=12, layers=12, vocab=32000, seq=256 -- **109.53M params** (84.95M transformer + 24.58M embedding) - **SDPA causal mask workaround**: ANE hardware ignores attn_mask — decompose into Q@K^T (ANE conv) + mask+softmax (CPU) + scores@V (ANE conv) +- **GQA support**: K/V heads tiled to match Q heads for SDPA, reduced back after backward pass ## Three Training Pipelines @@ -27,10 +35,10 @@ Offloads classifier forward (32K conv), softmax, final RMSNorm, and RMSNorm back - Use `--no-ane-extras` to disable and fall back to CPU (for debugging) ### 3. Dynamic Weight Pipeline (`training_dynamic/`) -Weights passed via IOSurface spatial dimension — compile 9 kernels once at startup, no recompilation needed. +Weights passed via IOSurface spatial dimension — compile 10 kernels once at startup, no recompilation needed. Supports multiple models via `make MODEL=xxx`. -- 9 shared kernels across all 12 layers -- **111 ms/step**, 0.4s one-time compile +- 10 shared kernels across all layers (GQA-aware: split sdpaFwd/woFwd, split qBwd/kvBwd) +- **~115 ms/step** (Stories110M) / **~412 ms/step** (Qwen3-0.6B), 0.4s one-time compile - No exec() restart, no compile limit issues ## Performance Comparison (20 Steps) @@ -56,10 +64,11 @@ Weights passed via IOSurface spatial dimension — compile 9 kernels once at sta |------|-------------| | `train_large.m` | Static baseline — 72 kernels, classifier/softmax on CPU | | `train_large_ane.m` | PR#19 — 86 kernels, classifier/softmax/rmsnorm_bwd on ANE | -| `training_dynamic/train.m` | Dynamic pipeline — 9 kernels, weights via IOSurface | -| `training_dynamic/mil_dynamic.h` | MIL generators for dynamic weight kernels | -| `training_dynamic/config.h` | Model config (DIM=768, HIDDEN=2048, etc.) | -| `training_dynamic/io.h` | IOSurface I/O + MIL compilation helpers | +| `training_dynamic/train.m` | Dynamic pipeline — 10 kernels, weights via IOSurface | +| `training_dynamic/mil_dynamic.h` | MIL generators for dynamic weight kernels (GQA-aware) | +| `training_dynamic/config.h` | Derived sizes, structs, alloc helpers (model-agnostic) | +| `training_dynamic/models/*.h` | Per-model configs (stories110m.h, qwen3_06b.h) | +| `training_dynamic/io.h` | IOSurface I/O, weight staging, GQA tile/reduce | | `training_dynamic/cpu_ops.h` | CPU ops (SiLU backward, cross-entropy, Adam) | | `stories_config.h` | Static pipeline config, structs, alloc helpers | | `stories_io.h` | IOSurface I/O, NEON fp16 conversion, kernel compile/eval | @@ -83,33 +92,35 @@ Downloads pretokenized TinyStories (Llama 2 BPE, 32K vocab) from HuggingFace. Pr ### 2. Build & Train ```bash -# Static baseline (classifier + softmax on CPU) -make train_large -./train_large stories110M.bin 256 100 1e-4 -./train_large --model stories110M.bin --steps 100 --lr 1e-4 -./train_large --data ./tinystories_data00.bin --steps 100 --lr 1e-4 - -# PR#19: ANE-offloaded classifier + softmax + rmsnorm_bwd -make train_large_ane -./train_large_ane stories110M.bin 256 100 1e-4 -./train_large_ane --no-ane-extras --steps 100 # disable ANE extras -./train_large_ane --data ./tinystories_data00.bin --steps 100 --lr 1e-4 +# Static baseline (classifier + softmax on CPU) +make train_large +./train_large stories110M.bin 256 100 1e-4 +./train_large --model stories110M.bin --steps 100 --lr 1e-4 +./train_large --data ./tinystories_data00.bin --steps 100 --lr 1e-4 -# Dynamic pipeline (no recompilation) -cd training_dynamic && make train +# PR#19: ANE-offloaded classifier + softmax + rmsnorm_bwd +make train_large_ane +./train_large_ane stories110M.bin 256 100 1e-4 +./train_large_ane --no-ane-extras --steps 100 # disable ANE extras +./train_large_ane --data ./tinystories_data00.bin --steps 100 --lr 1e-4 + +# Dynamic pipeline (model selected at build time) +cd training_dynamic +make MODEL=qwen3_06b # default — Qwen3-0.6B (28L, GQA, 596M) +make MODEL=stories110m # Stories110M (12L, MHA, 109M) ./train --scratch # train from random init -./train # resume from checkpoint +./train --resume # resume from checkpoint ./train --steps 200 --lr 1e-4 # custom steps/lr ``` -**CLI flags (`train_large` / `train_large_ane`):** +**CLI flags (`train_large` / `train_large_ane`):** - `--steps N` (default 10000) -- `--lr F` (default 3e-4) -- `--model PATH` — pretrained weights file -- `--data PATH` — tokenized TinyStories `.bin` file (default: `tinystories_data00.bin`) -- `--ckpt PATH` — checkpoint file (preserved across exec() restarts) -- `--resume` — resume from checkpoint -- `--no-ane-extras` — (train_large_ane only) disable ANE classifier/softmax/rmsnorm_bwd +- `--lr F` (default 3e-4) +- `--model PATH` — pretrained weights file +- `--data PATH` — tokenized TinyStories `.bin` file (default: `tinystories_data00.bin`) +- `--ckpt PATH` — checkpoint file (preserved across exec() restarts) +- `--resume` — resume from checkpoint +- `--no-ane-extras` — (train_large_ane only) disable ANE classifier/softmax/rmsnorm_bwd ### 3. Monitor with Dashboard @@ -133,11 +144,42 @@ Avg train: 91.8 ms/step ANE TFLOPS: 1.15 sustained ``` +## Adding a New Model + +Create `training_dynamic/models/mymodel.h`: + +```c +#pragma once +#define MODEL_NAME "MyModel-1B" + +#define DIM 2048 // model hidden dim +#define HIDDEN 5504 // FFN intermediate dim +#define HEADS 32 // number of query heads +#define KV_HEADS 8 // number of KV heads (= HEADS for MHA) +#define HD 64 // head dim (can differ from DIM/HEADS) +#define SEQ 256 // sequence length +#define NLAYERS 22 // number of transformer layers +#define VOCAB 32000 // vocabulary size + +#define CKPT_PATH "ane_mymodel_dyn_ckpt.bin" +#define DEFAULT_DATA_PATH "../tinystories_data00.bin" +``` + +Everything else is derived automatically: `GQA_RATIO`, `Q_DIM`, `KV_DIM`, weight sizes, IOSurface layouts, MIL kernels. + +Build with: `make MODEL=mymodel` + +**Constraints:** +- `HEADS` must be divisible by `KV_HEADS` +- `HD` is explicit (not necessarily `DIM/HEADS` — Qwen3 uses HD=128 with DIM/HEADS=64) +- For MHA (no GQA), set `KV_HEADS = HEADS` + ## Key Techniques - **NEON vectorized fp16↔fp32**: ARM NEON intrinsics for fast IOSurface data transfer - **vDSP cross-entropy**: `vDSP_mtrans` + `vvexpf` + `vDSP_sve` — 8x faster than scalar - **Async weight gradients**: cblas_sgemm dispatched to background queue, overlapped with ANE -- **Vocab compaction** (dynamic): 32K → 9.2K active tokens, 3.5x reduction in classifier work -- **Dynamic weight packing**: Activations + weights concatenated in IOSurface spatial dimension — one kernel serves all 12 layers +- **Vocab compaction** (dynamic): 32K–152K → 9.2K active tokens, up to 16.5x reduction in classifier work +- **Dynamic weight packing**: Activations + weights concatenated in IOSurface spatial dimension — one kernel serves all layers +- **GQA tile/reduce**: K/V tiled from KV_HEADS→HEADS on CPU before SDPA backward, gradients reduced HEADS→KV_HEADS after - **exec() restart**: Workaround for ANE ~119 compile limit per process diff --git a/training/training_dynamic/Makefile b/training/training_dynamic/Makefile index 1105351..bbcaa80 100644 --- a/training/training_dynamic/Makefile +++ b/training/training_dynamic/Makefile @@ -2,8 +2,16 @@ CC = xcrun clang CFLAGS = -O2 -DACCELERATE_NEW_LAPACK -framework Foundation -framework IOSurface -framework Accelerate \ -isysroot $(shell xcrun --show-sdk-path) -fobjc-arc -train: train.m config.h io.h cpu_ops.h mil_dynamic.h - $(CC) $(CFLAGS) -o train train.m +# Model selection: make MODEL=qwen3_06b (default) +# Available: stories110m, qwen3_06b +MODEL ?= qwen3_06b +MODEL_HDR = models/$(MODEL).h + +train: train.m config.h io.h cpu_ops.h mil_dynamic.h $(MODEL_HDR) + @echo "Building for model: $(MODEL)" + $(CC) $(CFLAGS) -include $(MODEL_HDR) -o train train.m clean: rm -f train + +.PHONY: clean diff --git a/training/training_dynamic/config.h b/training/training_dynamic/config.h index d22b6f1..f4c2fe4 100644 --- a/training/training_dynamic/config.h +++ b/training/training_dynamic/config.h @@ -1,4 +1,5 @@ -// config.h — Stories110M model config, structs, ANE init +// config.h — Model-agnostic structs, derived sizes, ANE init +// Model-specific dims come from models/*.h, selected via -DMODEL_HEADER #pragma once #import #import @@ -15,22 +16,21 @@ #include #include -// Stories110M config -#define DIM 768 -#define HIDDEN 2048 -#define HEADS 12 -#define HD (DIM/HEADS) -#define SEQ 256 -#define NLAYERS 12 -#define VOCAB 32000 +// Include selected model config +// MODEL_HEADER is set by Makefile via -include models/xxx.h +#ifndef MODEL_NAME +#error "No model selected. Build with: make MODEL=qwen3_06b (or stories110m)" +#endif -// Weight sizes per layer -#define WQ_SZ (DIM*DIM) -#define WO_SZ (DIM*DIM) +// Derived weight sizes per layer (GQA-aware) +#define WQ_SZ (Q_DIM*DIM) +#define WK_SZ (KV_DIM*DIM) +#define WV_SZ (KV_DIM*DIM) +#define WO_SZ (DIM*Q_DIM) #define W1_SZ (HIDDEN*DIM) #define W2_SZ (DIM*HIDDEN) #define W3_SZ (HIDDEN*DIM) -#define LAYER_PARAMS (4*WQ_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM) +#define LAYER_PARAMS (WQ_SZ + WK_SZ + WV_SZ + WO_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM) // Attention score channels for SDPA backward #define SCORE_CH (HEADS*SEQ) @@ -64,14 +64,14 @@ typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmp // Per-layer IOSurfaces for pre-staged weights typedef struct { - IOSurfaceRef sdpaFwd_in, ffnFused_in; - IOSurfaceRef ffnBwdW2t_in, ffnBwdW13t_in, wotBwd_in, qkvBwd_in; + IOSurfaceRef sdpaFwd_in, woFwd_in, ffnFused_in; + IOSurfaceRef ffnBwdW2t_in, ffnBwdW13t_in, wotBwd_in, qBwd_in, kvBwd_in; } PerLayerSurfaces; // Per-layer ANE requests (bound to per-layer IOSurfaces) typedef struct { - void *sdpaFwd, *ffnFused; - void *ffnBwdW2t, *ffnBwdW13t, *wotBwd, *qkvBwd; + void *sdpaFwd, *woFwd, *ffnFused; + void *ffnBwdW2t, *ffnBwdW13t, *wotBwd, *qBwd, *kvBwd; } PerLayerRequests; // Checkpoint header @@ -81,14 +81,10 @@ typedef struct { float lr, loss; double cum_compile, cum_train, cum_wall; int cum_steps, cum_batches, adam_t; - int pad[3]; + int kv_heads, head_dim, q_dim; // GQA fields + // Note: was int pad[3] in v3, now stores GQA info in v4+ } CkptHdr; -// llama2.c model file header -typedef struct { - int dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len; -} Llama2Config; - // Globals static Class g_D, g_I, g_AR, g_AIO; static mach_timebase_info_data_t g_tb; @@ -109,8 +105,8 @@ static void adam_free(AdamState *s) { free(s->m); free(s->v); } static LayerWeights layer_weights_alloc(void) { LayerWeights w; - w.Wq=(float*)malloc(WQ_SZ*4); w.Wk=(float*)malloc(WQ_SZ*4); - w.Wv=(float*)malloc(WQ_SZ*4); w.Wo=(float*)malloc(WO_SZ*4); + w.Wq=(float*)malloc(WQ_SZ*4); w.Wk=(float*)malloc(WK_SZ*4); + w.Wv=(float*)malloc(WV_SZ*4); w.Wo=(float*)malloc(WO_SZ*4); w.W1=(float*)malloc(W1_SZ*4); w.W2=(float*)malloc(W2_SZ*4); w.W3=(float*)malloc(W3_SZ*4); w.rms_att=(float*)malloc(DIM*4); w.rms_ffn=(float*)malloc(DIM*4); return w; @@ -121,7 +117,7 @@ static void layer_weights_free(LayerWeights *w) { } static LayerAdam layer_adam_alloc(void) { LayerAdam a; - a.Wq=adam_alloc(WQ_SZ); a.Wk=adam_alloc(WQ_SZ); a.Wv=adam_alloc(WQ_SZ); a.Wo=adam_alloc(WO_SZ); + a.Wq=adam_alloc(WQ_SZ); a.Wk=adam_alloc(WK_SZ); a.Wv=adam_alloc(WV_SZ); a.Wo=adam_alloc(WO_SZ); a.W1=adam_alloc(W1_SZ); a.W2=adam_alloc(W2_SZ); a.W3=adam_alloc(W3_SZ); a.rms_att=adam_alloc(DIM); a.rms_ffn=adam_alloc(DIM); return a; @@ -135,8 +131,8 @@ static LayerActs layer_acts_alloc(void) { LayerActs a; a.layer_in=(float*)malloc(SEQ*DIM*4); a.xnorm=(float*)malloc(SEQ*DIM*4); - a.Q=(float*)malloc(SEQ*DIM*4); a.K=(float*)malloc(SEQ*DIM*4); a.V=(float*)malloc(SEQ*DIM*4); - a.attn_out=(float*)malloc(SEQ*DIM*4); a.o_out=(float*)malloc(SEQ*DIM*4); + a.Q=(float*)malloc(SEQ*Q_DIM*4); a.K=(float*)malloc(SEQ*KV_DIM*4); a.V=(float*)malloc(SEQ*KV_DIM*4); + a.attn_out=(float*)malloc(SEQ*Q_DIM*4); a.o_out=(float*)malloc(SEQ*DIM*4); a.x2=(float*)malloc(SEQ*DIM*4); a.x2norm=(float*)malloc(SEQ*DIM*4); a.h1=(float*)malloc(SEQ*HIDDEN*4); a.h3=(float*)malloc(SEQ*HIDDEN*4); a.silu_out=(float*)malloc(SEQ*HIDDEN*4); a.ffn_out=(float*)malloc(SEQ*DIM*4); @@ -150,15 +146,15 @@ static void layer_acts_free(LayerActs *a) { } static LayerGrads layer_grads_alloc(void) { LayerGrads g; - g.Wq=(float*)calloc(WQ_SZ,4); g.Wk=(float*)calloc(WQ_SZ,4); - g.Wv=(float*)calloc(WQ_SZ,4); g.Wo=(float*)calloc(WO_SZ,4); + g.Wq=(float*)calloc(WQ_SZ,4); g.Wk=(float*)calloc(WK_SZ,4); + g.Wv=(float*)calloc(WV_SZ,4); g.Wo=(float*)calloc(WO_SZ,4); g.W1=(float*)calloc(W1_SZ,4); g.W2=(float*)calloc(W2_SZ,4); g.W3=(float*)calloc(W3_SZ,4); g.rms_att=(float*)calloc(DIM,4); g.rms_ffn=(float*)calloc(DIM,4); return g; } static void layer_grads_zero(LayerGrads *g) { - memset(g->Wq,0,WQ_SZ*4);memset(g->Wk,0,WQ_SZ*4); - memset(g->Wv,0,WQ_SZ*4);memset(g->Wo,0,WO_SZ*4); + memset(g->Wq,0,WQ_SZ*4);memset(g->Wk,0,WK_SZ*4); + memset(g->Wv,0,WV_SZ*4);memset(g->Wo,0,WO_SZ*4); memset(g->W1,0,W1_SZ*4);memset(g->W2,0,W2_SZ*4);memset(g->W3,0,W3_SZ*4); memset(g->rms_att,0,DIM*4);memset(g->rms_ffn,0,DIM*4); } diff --git a/training/training_dynamic/io.h b/training/training_dynamic/io.h index 776e4b7..02d6920 100644 --- a/training/training_dynamic/io.h +++ b/training/training_dynamic/io.h @@ -1,4 +1,5 @@ // io.h — IOSurface helpers, NEON conversion, kernel compile/eval +// Updated for GQA (Qwen3-0.6B): Q_DIM != DIM, separate KV heads #pragma once #include "config.h" @@ -75,8 +76,6 @@ static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int } // fp16 IOSurface I/O (for dynamic matmul kernels with fp16 input/output) -// Layout: [1, IC, 1, SP] where SP = SEQ + OC -// Write activations at sp[0:SEQ] and weights at sp[SEQ:SEQ+OC] static void io_write_dyn(IOSurfaceRef s, const float *act, int ic, int seq, const float *W, int oc) { int sp = seq + oc; @@ -145,14 +144,10 @@ static void ane_eval(Kern *k) { id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; NSError *e = nil; ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e); } - -// Evaluate with a per-layer request (different ioIn, same model) static void ane_eval_req(Kern *k, void *request) { id mdl = (__bridge id)k->model; id req = (__bridge id)request; NSError *e = nil; ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e); } - -// Create an ANE request binding a custom ioIn to a kernel's model+ioOut static void *make_request(Kern *k, IOSurfaceRef ioIn) { id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), ioIn); id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut); @@ -162,172 +157,157 @@ static void *make_request(Kern *k, IOSurfaceRef ioIn) { return (void*)CFBridgingRetain(req); } -// ===== Per-layer weight staging (write once, reuse across steps) ===== -// All surfaces are now fp16 — staging converts fp32 weights to fp16 - -// sdpaFwd: [1, DIM, 1, SEQ+4*DIM] fp16 — weights at sp[SEQ:] -static void stage_sdpa_fwd_weights(IOSurfaceRef s, const float *Wq, const float *Wk, - const float *Wv, const float *Wo) { +// ===== Per-layer weight staging for GQA ===== +// sdpaFwd: [1, DIM, 1, SEQ + Q_DIM + KV_DIM + KV_DIM] fp16 — no Wo (separate kernel) +// Wq: [DIM, Q_DIM], Wk: [DIM, KV_DIM], Wv: [DIM, KV_DIM] +#define SDPA_FWD_SP (SEQ + Q_DIM + KV_DIM + KV_DIM) +static void stage_sdpa_fwd_weights(IOSurfaceRef s, const float *Wq, const float *Wk, const float *Wv) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = SEQ + 4*DIM; for (int d = 0; d < DIM; d++) { - cvt_f32_f16(buf + d*sp + SEQ, Wq + d*DIM, DIM); - cvt_f32_f16(buf + d*sp + SEQ+DIM, Wk + d*DIM, DIM); - cvt_f32_f16(buf + d*sp + SEQ+2*DIM, Wv + d*DIM, DIM); - cvt_f32_f16(buf + d*sp + SEQ+3*DIM, Wo + d*DIM, DIM); + cvt_f32_f16(buf + d*SDPA_FWD_SP + SEQ, Wq + d*Q_DIM, Q_DIM); + cvt_f32_f16(buf + d*SDPA_FWD_SP + SEQ+Q_DIM, Wk + d*KV_DIM, KV_DIM); + cvt_f32_f16(buf + d*SDPA_FWD_SP + SEQ+Q_DIM+KV_DIM, Wv + d*KV_DIM, KV_DIM); } IOSurfaceUnlock(s, 0, NULL); } static void write_sdpa_fwd_acts(IOSurfaceRef s, const float *xnorm) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = SEQ + 4*DIM; for (int d = 0; d < DIM; d++) - cvt_f32_f16(buf + d*sp, xnorm + d*SEQ, SEQ); + cvt_f32_f16(buf + d*SDPA_FWD_SP, xnorm + d*SEQ, SEQ); + IOSurfaceUnlock(s, 0, NULL); +} + +// woFwd: [1, Q_DIM, 1, SEQ + DIM] fp16 — Wo: [Q_DIM, DIM] +#define WO_FWD_SP (SEQ + DIM) +static void stage_wo_fwd_weights(IOSurfaceRef s, const float *Wo) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + for (int d = 0; d < Q_DIM; d++) + cvt_f32_f16(buf + d*WO_FWD_SP + SEQ, Wo + d*DIM, DIM); + IOSurfaceUnlock(s, 0, NULL); +} +static void write_wo_fwd_acts(IOSurfaceRef s, const float *attn_out) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + for (int d = 0; d < Q_DIM; d++) + cvt_f32_f16(buf + d*WO_FWD_SP, attn_out + d*SEQ, SEQ); IOSurfaceUnlock(s, 0, NULL); } // ffnFused: [1, DIM, 1, 2*SEQ+3*HIDDEN] fp16 +#define FFN_FUSED_SP (2*SEQ + 3*HIDDEN) static void stage_ffn_fused_weights(IOSurfaceRef s, const float *W1t, const float *W3t, const float *W2_orig) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = 2*SEQ + 3*HIDDEN; for (int d = 0; d < DIM; d++) { - cvt_f32_f16(buf + d*sp + 2*SEQ, W1t + d*HIDDEN, HIDDEN); - cvt_f32_f16(buf + d*sp + 2*SEQ+HIDDEN, W3t + d*HIDDEN, HIDDEN); - cvt_f32_f16(buf + d*sp + 2*SEQ+2*HIDDEN, W2_orig + d*HIDDEN, HIDDEN); + cvt_f32_f16(buf + d*FFN_FUSED_SP + 2*SEQ, W1t + d*HIDDEN, HIDDEN); + cvt_f32_f16(buf + d*FFN_FUSED_SP + 2*SEQ+HIDDEN, W3t + d*HIDDEN, HIDDEN); + cvt_f32_f16(buf + d*FFN_FUSED_SP + 2*SEQ+2*HIDDEN, W2_orig + d*HIDDEN, HIDDEN); } IOSurfaceUnlock(s, 0, NULL); } static void write_ffn_fused_acts(IOSurfaceRef s, const float *x2norm, const float *x2) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = 2*SEQ + 3*HIDDEN; for (int d = 0; d < DIM; d++) { - cvt_f32_f16(buf + d*sp, x2norm + d*SEQ, SEQ); - cvt_f32_f16(buf + d*sp + SEQ, x2 + d*SEQ, SEQ); + cvt_f32_f16(buf + d*FFN_FUSED_SP, x2norm + d*SEQ, SEQ); + cvt_f32_f16(buf + d*FFN_FUSED_SP + SEQ, x2 + d*SEQ, SEQ); } IOSurfaceUnlock(s, 0, NULL); } -// ffnW13: [1, DIM, 1, SEQ+2*HIDDEN] fp16 -static void stage_ffn_w13_weights(IOSurfaceRef s, const float *W1, const float *W3) { - IOSurfaceLock(s, 0, NULL); - _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = SEQ + 2*HIDDEN; - for (int d = 0; d < DIM; d++) { - cvt_f32_f16(buf + d*sp + SEQ, W1 + d*HIDDEN, HIDDEN); - cvt_f32_f16(buf + d*sp + SEQ+HIDDEN, W3 + d*HIDDEN, HIDDEN); - } - IOSurfaceUnlock(s, 0, NULL); -} -static void write_ffn_w13_acts(IOSurfaceRef s, const float *xnorm) { - IOSurfaceLock(s, 0, NULL); - _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = SEQ + 2*HIDDEN; - for (int d = 0; d < DIM; d++) - cvt_f32_f16(buf + d*sp, xnorm + d*SEQ, SEQ); - IOSurfaceUnlock(s, 0, NULL); -} - -// ffnW2: [1, HIDDEN, 1, SEQ+DIM] fp16 -static void stage_ffn_w2_weights(IOSurfaceRef s, const float *W2) { - IOSurfaceLock(s, 0, NULL); - _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = SEQ + DIM; - for (int d = 0; d < HIDDEN; d++) - cvt_f32_f16(buf + d*sp + SEQ, W2 + d*DIM, DIM); - IOSurfaceUnlock(s, 0, NULL); -} -static void write_ffn_w2_acts(IOSurfaceRef s, const float *gate) { - IOSurfaceLock(s, 0, NULL); - _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = SEQ + DIM; - for (int d = 0; d < HIDDEN; d++) - cvt_f32_f16(buf + d*sp, gate + d*SEQ, SEQ); - IOSurfaceUnlock(s, 0, NULL); -} - // ffnBwdW2t: [1, DIM, 1, SEQ+HIDDEN] fp16 +#define FFN_BWD_W2T_SP (SEQ + HIDDEN) static void stage_ffn_bwd_w2t_weights(IOSurfaceRef s, const float *W2) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = SEQ + HIDDEN; for (int d = 0; d < DIM; d++) - cvt_f32_f16(buf + d*sp + SEQ, W2 + d*HIDDEN, HIDDEN); + cvt_f32_f16(buf + d*FFN_BWD_W2T_SP + SEQ, W2 + d*HIDDEN, HIDDEN); IOSurfaceUnlock(s, 0, NULL); } static void write_ffn_bwd_w2t_acts(IOSurfaceRef s, const float *dffn) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = SEQ + HIDDEN; for (int d = 0; d < DIM; d++) - cvt_f32_f16(buf + d*sp, dffn + d*SEQ, SEQ); + cvt_f32_f16(buf + d*FFN_BWD_W2T_SP, dffn + d*SEQ, SEQ); IOSurfaceUnlock(s, 0, NULL); } // ffnBwdW13t: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp16 +#define FFN_BWD_W13T_SP (2*SEQ + 2*DIM) static void stage_ffn_bwd_w13t_weights(IOSurfaceRef s, const float *W1, const float *W3) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = 2*SEQ + 2*DIM; for (int d = 0; d < HIDDEN; d++) { - cvt_f32_f16(buf + d*sp + 2*SEQ, W1 + d*DIM, DIM); - cvt_f32_f16(buf + d*sp + 2*SEQ + DIM, W3 + d*DIM, DIM); + cvt_f32_f16(buf + d*FFN_BWD_W13T_SP + 2*SEQ, W1 + d*DIM, DIM); + cvt_f32_f16(buf + d*FFN_BWD_W13T_SP + 2*SEQ + DIM, W3 + d*DIM, DIM); } IOSurfaceUnlock(s, 0, NULL); } static void write_ffn_bwd_w13t_acts(IOSurfaceRef s, const float *dh1, const float *dh3) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = 2*SEQ + 2*DIM; for (int d = 0; d < HIDDEN; d++) { - cvt_f32_f16(buf + d*sp, dh1 + d*SEQ, SEQ); - cvt_f32_f16(buf + d*sp + SEQ, dh3 + d*SEQ, SEQ); + cvt_f32_f16(buf + d*FFN_BWD_W13T_SP, dh1 + d*SEQ, SEQ); + cvt_f32_f16(buf + d*FFN_BWD_W13T_SP + SEQ, dh3 + d*SEQ, SEQ); } IOSurfaceUnlock(s, 0, NULL); } -// wotBwd: [1, DIM, 1, SEQ+DIM] fp16 +// wotBwd: [1, DIM, 1, SEQ+Q_DIM] fp16 — Wo is [DIM, Q_DIM], matmul gives Wo^T @ dy +#define WOT_BWD_SP (SEQ + Q_DIM) static void stage_wot_bwd_weights(IOSurfaceRef s, const float *Wo) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = SEQ + DIM; for (int d = 0; d < DIM; d++) - cvt_f32_f16(buf + d*sp + SEQ, Wo + d*DIM, DIM); + cvt_f32_f16(buf + d*WOT_BWD_SP + SEQ, Wo + d*Q_DIM, Q_DIM); IOSurfaceUnlock(s, 0, NULL); } -static void write_wot_bwd_acts(IOSurfaceRef s, const float *dx2) { +static void write_wot_bwd_acts(IOSurfaceRef s, const float *dy) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = SEQ + DIM; for (int d = 0; d < DIM; d++) - cvt_f32_f16(buf + d*sp, dx2 + d*SEQ, SEQ); + cvt_f32_f16(buf + d*WOT_BWD_SP, dy + d*SEQ, SEQ); IOSurfaceUnlock(s, 0, NULL); } -// qkvBwd: [1, DIM, 1, 3*SEQ+3*DIM] fp16 -static void stage_qkv_bwd_weights(IOSurfaceRef s, const float *Wq, const float *Wk, const float *Wv) { +// qBwd: [1, Q_DIM, 1, SEQ+DIM] fp16 — Wq is [Q_DIM, DIM], matmul gives Wq^T @ dq +#define Q_BWD_SP (SEQ + DIM) +static void stage_q_bwd_weights(IOSurfaceRef s, const float *Wq) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = 3*SEQ + 3*DIM; - for (int d = 0; d < DIM; d++) { - cvt_f32_f16(buf + d*sp + 3*SEQ, Wq + d*DIM, DIM); - cvt_f32_f16(buf + d*sp + 3*SEQ + DIM, Wk + d*DIM, DIM); - cvt_f32_f16(buf + d*sp + 3*SEQ + 2*DIM, Wv + d*DIM, DIM); + for (int d = 0; d < Q_DIM; d++) + cvt_f32_f16(buf + d*Q_BWD_SP + SEQ, Wq + d*DIM, DIM); + IOSurfaceUnlock(s, 0, NULL); +} +static void write_q_bwd_acts(IOSurfaceRef s, const float *dq) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + for (int d = 0; d < Q_DIM; d++) + cvt_f32_f16(buf + d*Q_BWD_SP, dq + d*SEQ, SEQ); + IOSurfaceUnlock(s, 0, NULL); +} + +// kvBwd: [1, KV_DIM, 1, 2*SEQ+2*DIM] fp16 — dk @ Wk + dv @ Wv → dx_kv +#define KV_BWD_SP (2*SEQ + 2*DIM) +static void stage_kv_bwd_weights(IOSurfaceRef s, const float *Wk, const float *Wv) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + for (int d = 0; d < KV_DIM; d++) { + cvt_f32_f16(buf + d*KV_BWD_SP + 2*SEQ, Wk + d*DIM, DIM); + cvt_f32_f16(buf + d*KV_BWD_SP + 2*SEQ + DIM, Wv + d*DIM, DIM); } IOSurfaceUnlock(s, 0, NULL); } -static void write_qkv_bwd_acts(IOSurfaceRef s, const float *dq, const float *dk, const float *dv) { +static void write_kv_bwd_acts(IOSurfaceRef s, const float *dk, const float *dv) { IOSurfaceLock(s, 0, NULL); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); - int sp = 3*SEQ + 3*DIM; - for (int d = 0; d < DIM; d++) { - cvt_f32_f16(buf + d*sp, dq + d*SEQ, SEQ); - cvt_f32_f16(buf + d*sp + SEQ, dk + d*SEQ, SEQ); - cvt_f32_f16(buf + d*sp + 2*SEQ, dv + d*SEQ, SEQ); + for (int d = 0; d < KV_DIM; d++) { + cvt_f32_f16(buf + d*KV_BWD_SP, dk + d*SEQ, SEQ); + cvt_f32_f16(buf + d*KV_BWD_SP + SEQ, dv + d*SEQ, SEQ); } IOSurfaceUnlock(s, 0, NULL); } @@ -335,11 +315,37 @@ static void write_qkv_bwd_acts(IOSurfaceRef s, const float *dq, const float *dk, // Free per-layer surfaces and requests static void free_per_layer(PerLayerSurfaces *pls, PerLayerRequests *plr) { for (int L = 0; L < NLAYERS; L++) { - CFRelease(pls[L].sdpaFwd_in); CFRelease(pls[L].ffnFused_in); + CFRelease(pls[L].sdpaFwd_in); CFRelease(pls[L].woFwd_in); CFRelease(pls[L].ffnFused_in); CFRelease(pls[L].ffnBwdW2t_in); CFRelease(pls[L].ffnBwdW13t_in); - CFRelease(pls[L].wotBwd_in); CFRelease(pls[L].qkvBwd_in); - CFRelease(plr[L].sdpaFwd); CFRelease(plr[L].ffnFused); + CFRelease(pls[L].wotBwd_in); CFRelease(pls[L].qBwd_in); CFRelease(pls[L].kvBwd_in); + CFRelease(plr[L].sdpaFwd); CFRelease(plr[L].woFwd); CFRelease(plr[L].ffnFused); CFRelease(plr[L].ffnBwdW2t); CFRelease(plr[L].ffnBwdW13t); - CFRelease(plr[L].wotBwd); CFRelease(plr[L].qkvBwd); + CFRelease(plr[L].wotBwd); CFRelease(plr[L].qBwd); CFRelease(plr[L].kvBwd); + } +} + +// GQA helpers: tile KV from KV_HEADS to HEADS, and reduce HEADS to KV_HEADS +// tile_kv: input [KV_DIM, SEQ], output [Q_DIM, SEQ] +// Each KV head is duplicated GQA_RATIO times +static void gqa_tile_kv(float *out, const float *in, int seq) { + for (int kv = 0; kv < KV_HEADS; kv++) { + for (int r = 0; r < GQA_RATIO; r++) { + int q_head = kv * GQA_RATIO + r; + memcpy(out + q_head * HD * seq, in + kv * HD * seq, HD * seq * sizeof(float)); + } + } +} +// reduce_kv: input [Q_DIM, SEQ], output [KV_DIM, SEQ] +// Sum contributions from Q heads sharing each KV head +static void gqa_reduce_kv(float *out, const float *in, int seq) { + memset(out, 0, KV_DIM * seq * sizeof(float)); + for (int kv = 0; kv < KV_HEADS; kv++) { + for (int r = 0; r < GQA_RATIO; r++) { + int q_head = kv * GQA_RATIO + r; + const float *src = in + q_head * HD * seq; + float *dst = out + kv * HD * seq; + for (int i = 0; i < HD * seq; i++) + dst[i] += src[i]; + } } } diff --git a/training/training_dynamic/mil_dynamic.h b/training/training_dynamic/mil_dynamic.h index 4cbb418..854fa8e 100644 --- a/training/training_dynamic/mil_dynamic.h +++ b/training/training_dynamic/mil_dynamic.h @@ -1,7 +1,7 @@ -// mil_dynamic.h — MIL generators using dynamic matmul (weights via IOSurface) -// Instead of conv(const_weight, x), we use matmul(x, W) where both come from input. -// Input layout: [1, IC, 1, SP] fp32, SP = SEQ + total_weight_cols -// Activations in sp[0:SEQ], weight matrices packed sequentially in sp[SEQ:] +// mil_dynamic.h — MIL generators for Qwen3-0.6B with GQA +// Q_DIM=2048 != DIM=1024, KV_DIM=1024, GQA_RATIO=2 +// SDPA split: sdpaFwd (QKV proj + attention, no Wo) + woFwd (Wo matmul) +// Backward: qBwd + kvBwd (split from qkvBwd) #pragma once #include "io.h" @@ -11,42 +11,30 @@ "{\"coremltools-version\", \"9.0\"}})]\n{\n" // Helper: generate a dynamic matmul within a MIL function -// Slices activation [1,ic,1,seq] and weight [1,ic,1,oc] from input, does matmul -// act_sp_off: spatial offset for activations (usually 0) -// w_sp_off: spatial offset for weight block -// Returns variable name of result [1,oc,1,seq] in fp16 static void gen_dyn_matmul(NSMutableString *m, const char *prefix, int ic, int oc, int seq, int act_sp_off, int w_sp_off, const char *input_var) { - // Slice activations [m appendFormat:@" tensor %s_ba = const()[name=string(\"%s_ba\"), val=tensor([0,0,0,%d])];\n", prefix, prefix, act_sp_off]; [m appendFormat:@" tensor %s_sa = const()[name=string(\"%s_sa\"), val=tensor([1,%d,1,%d])];\n", prefix, prefix, ic, seq]; [m appendFormat:@" tensor %s_act = slice_by_size(x=%s,begin=%s_ba,size=%s_sa)[name=string(\"%s_act\")];\n", ic, seq, prefix, input_var, prefix, prefix, prefix]; - // Slice weight [m appendFormat:@" tensor %s_bw = const()[name=string(\"%s_bw\"), val=tensor([0,0,0,%d])];\n", prefix, prefix, w_sp_off]; [m appendFormat:@" tensor %s_sw = const()[name=string(\"%s_sw\"), val=tensor([1,%d,1,%d])];\n", prefix, prefix, ic, oc]; [m appendFormat:@" tensor %s_wt = slice_by_size(x=%s,begin=%s_bw,size=%s_sw)[name=string(\"%s_wt\")];\n", ic, oc, prefix, input_var, prefix, prefix, prefix]; - // Reshape act: [1,ic,1,seq] → [1,1,ic,seq] → transpose → [1,1,seq,ic] [m appendFormat:@" tensor %s_ra = const()[name=string(\"%s_ra\"), val=tensor([1,1,%d,%d])];\n", prefix, prefix, ic, seq]; [m appendFormat:@" tensor %s_a2 = reshape(shape=%s_ra,x=%s_act)[name=string(\"%s_a2\")];\n", ic, seq, prefix, prefix, prefix, prefix]; [m appendFormat:@" tensor %s_pm = const()[name=string(\"%s_pm\"), val=tensor([0,1,3,2])];\n", prefix, prefix]; [m appendFormat:@" tensor %s_a3 = transpose(perm=%s_pm,x=%s_a2)[name=string(\"%s_a3\")];\n", seq, ic, prefix, prefix, prefix, prefix]; - // Reshape weight: [1,ic,1,oc] → [1,1,ic,oc] [m appendFormat:@" tensor %s_rw = const()[name=string(\"%s_rw\"), val=tensor([1,1,%d,%d])];\n", prefix, prefix, ic, oc]; [m appendFormat:@" tensor %s_W = reshape(shape=%s_rw,x=%s_wt)[name=string(\"%s_W\")];\n", ic, oc, prefix, prefix, prefix, prefix]; - // matmul: [1,1,seq,ic] @ [1,1,ic,oc] → [1,1,seq,oc] [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; [m appendFormat:@" tensor %s_yh = matmul(transpose_x=bF,transpose_y=bF,x=%s_a3,y=%s_W)[name=string(\"%s_yh\")];\n", seq, oc, prefix, prefix, prefix, prefix]; - // Transpose back + reshape: [1,1,seq,oc] → [1,1,oc,seq] → [1,oc,1,seq] [m appendFormat:@" tensor %s_yt = transpose(perm=%s_pm,x=%s_yh)[name=string(\"%s_yt\")];\n", oc, seq, prefix, prefix, prefix, prefix]; [m appendFormat:@" tensor %s_ro = const()[name=string(\"%s_ro\"), val=tensor([1,%d,1,%d])];\n", prefix, prefix, oc, seq]; [m appendFormat:@" tensor %s_y = reshape(shape=%s_ro,x=%s_yt)[name=string(\"%s_y\")];\n", oc, seq, prefix, prefix, prefix, prefix]; } -// ===== Dynamic matmul kernel: y = x @ W ===== -// Input: [1, IC, 1, SEQ+OC] fp16 — act[0:SEQ] + W[SEQ:SEQ+OC] -// Output: [1, OC, 1, SEQ] fp16 +// Simple dynamic matmul kernel: y = x @ W, input [1,IC,1,SEQ+OC], output [1,OC,1,SEQ] static NSString *gen_dyn_matmul_mil(int ic, int oc, int seq) { NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; @@ -57,20 +45,18 @@ static NSString *gen_dyn_matmul_mil(int ic, int oc, int seq) { return m; } -// ===== SDPA forward (dynamic weights) ===== -// Replaces gen_sdpa_fwd_taps: RMSNorm done on CPU, this kernel does QKV matmul + SDPA + Wo matmul -// Input: [1, DIM, 1, SEQ + 4*DIM] fp16 -// sp[0:SEQ] = xnorm (rmsnorm output, DIM channels) -// sp[SEQ:SEQ+DIM] = Wq[DIM,DIM] -// sp[SEQ+DIM:SEQ+2D] = Wk[DIM,DIM] -// sp[SEQ+2D:SEQ+3D] = Wv[DIM,DIM] -// sp[SEQ+3D:SEQ+4D] = Wo[DIM,DIM] -// Output: [1, 6*DIM, 1, SEQ] fp16 = concat(o_out, Q, K, V, attn_out, xnorm_pass) -// NOTE: mask is still a const weight (it doesn't change) +// ===== SDPA forward with GQA (no Wo) ===== +// Input: [1, DIM, 1, SEQ + Q_DIM + KV_DIM + KV_DIM] fp16 +// sp[0:SEQ] = xnorm [DIM, SEQ] +// sp[SEQ:SEQ+Q_DIM] = Wq [DIM, Q_DIM] +// sp[SEQ+Q_DIM:SEQ+Q_DIM+KVD] = Wk [DIM, KV_DIM] +// sp[SEQ+Q_DIM+KVD:...] = Wv [DIM, KV_DIM] +// Output: [1, Q_DIM+Q_DIM+KV_DIM+KV_DIM+DIM, 1, SEQ] fp16 +// = concat(attn_out, Q_rope, K_rope, V, xnorm_pass) static NSString *gen_sdpa_fwd_dynamic(void) { float sc = 1.0f/sqrtf((float)HD); - int w_total = 4*DIM; // Wq+Wk+Wv+Wo - int sp_in = SEQ + w_total; + int sp_in = SDPA_FWD_SP; + int out_ch = Q_DIM + Q_DIM + KV_DIM + KV_DIM + DIM; NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; @@ -80,100 +66,126 @@ static NSString *gen_sdpa_fwd_dynamic(void) { [m appendFormat:@" tensor sx = const()[name=string(\"sx\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor xn = slice_by_size(x=x,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ]; - // Slice Wq [1,DIM,1,DIM] + // Slice Wq [1,DIM,1,Q_DIM] [m appendFormat:@" tensor bq = const()[name=string(\"bq\"), val=tensor([0,0,0,%d])];\n", SEQ]; - [m appendFormat:@" tensor sw = const()[name=string(\"sw\"), val=tensor([1,%d,1,%d])];\n", DIM, DIM]; - [m appendFormat:@" tensor Wq = slice_by_size(x=x,begin=bq,size=sw)[name=string(\"Wq\")];\n", DIM, DIM]; + [m appendFormat:@" tensor swq = const()[name=string(\"swq\"), val=tensor([1,%d,1,%d])];\n", DIM, Q_DIM]; + [m appendFormat:@" tensor Wq = slice_by_size(x=x,begin=bq,size=swq)[name=string(\"Wq\")];\n", DIM, Q_DIM]; - // Slice Wk - [m appendFormat:@" tensor bk = const()[name=string(\"bk\"), val=tensor([0,0,0,%d])];\n", SEQ+DIM]; - [m appendFormat:@" tensor Wk = slice_by_size(x=x,begin=bk,size=sw)[name=string(\"Wk\")];\n", DIM, DIM]; + // Slice Wk [1,DIM,1,KV_DIM] + [m appendFormat:@" tensor bk = const()[name=string(\"bk\"), val=tensor([0,0,0,%d])];\n", SEQ+Q_DIM]; + [m appendFormat:@" tensor swk = const()[name=string(\"swk\"), val=tensor([1,%d,1,%d])];\n", DIM, KV_DIM]; + [m appendFormat:@" tensor Wk = slice_by_size(x=x,begin=bk,size=swk)[name=string(\"Wk\")];\n", DIM, KV_DIM]; - // Slice Wv - [m appendFormat:@" tensor bv = const()[name=string(\"bv\"), val=tensor([0,0,0,%d])];\n", SEQ+2*DIM]; - [m appendFormat:@" tensor Wv = slice_by_size(x=x,begin=bv,size=sw)[name=string(\"Wv\")];\n", DIM, DIM]; + // Slice Wv [1,DIM,1,KV_DIM] + [m appendFormat:@" tensor bv = const()[name=string(\"bv\"), val=tensor([0,0,0,%d])];\n", SEQ+Q_DIM+KV_DIM]; + [m appendFormat:@" tensor Wv = slice_by_size(x=x,begin=bv,size=swk)[name=string(\"Wv\")];\n", DIM, KV_DIM]; - // Slice Wo - [m appendFormat:@" tensor bo = const()[name=string(\"bo\"), val=tensor([0,0,0,%d])];\n", SEQ+3*DIM]; - [m appendFormat:@" tensor Wo = slice_by_size(x=x,begin=bo,size=sw)[name=string(\"Wo\")];\n", DIM, DIM]; - - // Reshape for matmul: [1,D,1,S] → [1,1,D,S] → [1,1,S,D] + // Reshape xnorm for matmul: [1,DIM,1,SEQ] → [1,1,DIM,SEQ] → [1,1,SEQ,DIM] [m appendFormat:@" tensor r2 = const()[name=string(\"r2\"), val=tensor([1,1,%d,%d])];\n", DIM, SEQ]; [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; [m appendFormat:@" tensor xn2 = reshape(shape=r2,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ]; [m appendFormat:@" tensor xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM]; - // Reshape weights: [1,D,1,D] → [1,1,D,D] - [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor([1,1,%d,%d])];\n", DIM, DIM]; - [m appendFormat:@" tensor Wq2 = reshape(shape=rw,x=Wq)[name=string(\"Wq2\")];\n", DIM, DIM]; - [m appendFormat:@" tensor Wk2 = reshape(shape=rw,x=Wk)[name=string(\"Wk2\")];\n", DIM, DIM]; - [m appendFormat:@" tensor Wv2 = reshape(shape=rw,x=Wv)[name=string(\"Wv2\")];\n", DIM, DIM]; - [m appendFormat:@" tensor Wo2 = reshape(shape=rw,x=Wo)[name=string(\"Wo2\")];\n", DIM, DIM]; + // Reshape weights + [m appendFormat:@" tensor rwq = const()[name=string(\"rwq\"), val=tensor([1,1,%d,%d])];\n", DIM, Q_DIM]; + [m appendFormat:@" tensor rwk = const()[name=string(\"rwk\"), val=tensor([1,1,%d,%d])];\n", DIM, KV_DIM]; + [m appendFormat:@" tensor Wq2 = reshape(shape=rwq,x=Wq)[name=string(\"Wq2\")];\n", DIM, Q_DIM]; + [m appendFormat:@" tensor Wk2 = reshape(shape=rwk,x=Wk)[name=string(\"Wk2\")];\n", DIM, KV_DIM]; + [m appendFormat:@" tensor Wv2 = reshape(shape=rwk,x=Wv)[name=string(\"Wv2\")];\n", DIM, KV_DIM]; - // QKV matmul: [1,1,S,D] @ [1,1,D,D] → [1,1,S,D] + // QKV matmul [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; - [m appendFormat:@" tensor qm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wq2)[name=string(\"qm\")];\n", SEQ, DIM]; - [m appendFormat:@" tensor km = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wk2)[name=string(\"km\")];\n", SEQ, DIM]; - [m appendFormat:@" tensor vm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wv2)[name=string(\"vm\")];\n", SEQ, DIM]; + // Q: [1,1,SEQ,DIM] @ [1,1,DIM,Q_DIM] → [1,1,SEQ,Q_DIM] + [m appendFormat:@" tensor qm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wq2)[name=string(\"qm\")];\n", SEQ, Q_DIM]; + // K: [1,1,SEQ,DIM] @ [1,1,DIM,KV_DIM] → [1,1,SEQ,KV_DIM] + [m appendFormat:@" tensor km = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wk2)[name=string(\"km\")];\n", SEQ, KV_DIM]; + // V: same as K + [m appendFormat:@" tensor vm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wv2)[name=string(\"vm\")];\n", SEQ, KV_DIM]; - // Transpose back: [1,1,S,D] → [1,1,D,S] → reshape [1,D,1,S] - [m appendFormat:@" tensor qt = transpose(perm=pm,x=qm)[name=string(\"qt\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor kt = transpose(perm=pm,x=km)[name=string(\"kt\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor vt = transpose(perm=pm,x=vm)[name=string(\"vt\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor os = const()[name=string(\"os\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor qf = reshape(shape=os,x=qt)[name=string(\"qf\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor kf = reshape(shape=os,x=kt)[name=string(\"kf\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor vf = reshape(shape=os,x=vt)[name=string(\"vf\")];\n", DIM, SEQ]; + // Transpose back: [1,1,SEQ,X] → [1,1,X,SEQ] + [m appendFormat:@" tensor qt = transpose(perm=pm,x=qm)[name=string(\"qt\")];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor kt = transpose(perm=pm,x=km)[name=string(\"kt\")];\n", KV_DIM, SEQ]; + [m appendFormat:@" tensor vt = transpose(perm=pm,x=vm)[name=string(\"vt\")];\n", KV_DIM, SEQ]; - // SDPA: reshape to heads, matmul, mask, softmax, matmul - [m appendFormat:@" tensor qsh = const()[name=string(\"qsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS, HD, SEQ]; - [m appendFormat:@" tensor q4 = reshape(shape=qsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ]; + // Reshape to [1,X,1,SEQ] + [m appendFormat:@" tensor qsh = const()[name=string(\"qsh\"), val=tensor([1,%d,1,%d])];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor kvsh = const()[name=string(\"kvsh\"), val=tensor([1,%d,1,%d])];\n", KV_DIM, SEQ]; + [m appendFormat:@" tensor qf = reshape(shape=qsh,x=qt)[name=string(\"qf\")];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor kf = reshape(shape=kvsh,x=kt)[name=string(\"kf\")];\n", KV_DIM, SEQ]; + [m appendFormat:@" tensor vf = reshape(shape=kvsh,x=vt)[name=string(\"vf\")];\n", KV_DIM, SEQ]; + + // Reshape to heads for attention + // Q: [1,Q_DIM,1,SEQ] → [1,HEADS,HD,SEQ] → transpose → [1,HEADS,SEQ,HD] + [m appendFormat:@" tensor qhsh = const()[name=string(\"qhsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor q4 = reshape(shape=qhsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ]; [m appendFormat:@" tensor q = transpose(perm=pm,x=q4)[name=string(\"tq\")];\n", HEADS, SEQ, HD]; - [m appendFormat:@" tensor k4 = reshape(shape=qsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ]; - [m appendFormat:@" tensor k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", HEADS, SEQ, HD]; - [m appendFormat:@" tensor v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ]; - [m appendFormat:@" tensor v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS, SEQ, HD]; + // K: [1,KV_DIM,1,SEQ] → [1,KV_HEADS,HD,SEQ] → [1,KV_HEADS,SEQ,HD] + [m appendFormat:@" tensor khsh = const()[name=string(\"khsh\"), val=tensor([1,%d,%d,%d])];\n", KV_HEADS, HD, SEQ]; + [m appendFormat:@" tensor k4 = reshape(shape=khsh,x=kf)[name=string(\"rk\")];\n", KV_HEADS, HD, SEQ]; + [m appendFormat:@" tensor k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", KV_HEADS, SEQ, HD]; + // V: same reshape as K + [m appendFormat:@" tensor v4 = reshape(shape=khsh,x=vf)[name=string(\"rv\")];\n", KV_HEADS, HD, SEQ]; + [m appendFormat:@" tensor v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", KV_HEADS, SEQ, HD]; - // RoPE: q_rope = q * cos + rotate_half(q) * sin, same for k - int pairs = SEQ * HD / 2; + // RoPE on Q: [1,HEADS,SEQ,HD] + int pairs_q = SEQ * HD / 2; [m appendFormat:@" tensor rope_cos = const()[name=string(\"rc\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rope_cos.bin\"), offset=uint64(64)))];\n", SEQ, HD, SEQ, HD]; [m appendFormat:@" tensor rope_sin = const()[name=string(\"rs\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rope_sin.bin\"), offset=uint64(64)))];\n", SEQ, HD, SEQ, HD]; - [m appendFormat:@" tensor rp_sh = const()[name=string(\"rp_sh\"), val=tensor([1,%d,%d,2])];\n", HEADS, pairs]; - [m appendFormat:@" tensor rp_s1 = const()[name=string(\"rp_s1\"), val=tensor([1,%d,%d,1])];\n", HEADS, pairs]; + [m appendFormat:@" tensor rp_sh = const()[name=string(\"rp_sh\"), val=tensor([1,%d,%d,2])];\n", HEADS, pairs_q]; + [m appendFormat:@" tensor rp_s1 = const()[name=string(\"rp_s1\"), val=tensor([1,%d,%d,1])];\n", HEADS, pairs_q]; [m appendString:@" tensor rp_b0 = const()[name=string(\"rp_b0\"), val=tensor([0,0,0,0])];\n"]; [m appendString:@" tensor rp_b1 = const()[name=string(\"rp_b1\"), val=tensor([0,0,0,1])];\n"]; - [m appendFormat:@" tensor rp_bk = const()[name=string(\"rp_bk\"), val=tensor([1,%d,%d,%d])];\n", HEADS, SEQ, HD]; - // rotate_half(q): reshape to pairs, swap+negate, reshape back [m appendString:@" fp16 neg1 = const()[name=string(\"neg1\"), val=fp16(-1)];\n"]; - [m appendFormat:@" tensor q_p = reshape(shape=rp_sh,x=q)[name=string(\"q_p\")];\n", HEADS, pairs]; - [m appendFormat:@" tensor q_e = slice_by_size(x=q_p,begin=rp_b0,size=rp_s1)[name=string(\"q_e\")];\n", HEADS, pairs]; - [m appendFormat:@" tensor q_o = slice_by_size(x=q_p,begin=rp_b1,size=rp_s1)[name=string(\"q_o\")];\n", HEADS, pairs]; - [m appendFormat:@" tensor nq = mul(x=q_o,y=neg1)[name=string(\"nq\")];\n", HEADS, pairs]; [m appendString:@" int32 rpax = const()[name=string(\"rpax\"), val=int32(3)];\n"]; [m appendString:@" bool rpil = const()[name=string(\"rpil\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor qrp = concat(axis=rpax,interleave=rpil,values=(nq,q_e))[name=string(\"qrp\")];\n", HEADS, pairs]; - [m appendFormat:@" tensor q_rot = reshape(shape=rp_bk,x=qrp)[name=string(\"q_rot\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor rp_bk_q = const()[name=string(\"rp_bk_q\"), val=tensor([1,%d,%d,%d])];\n", HEADS, SEQ, HD]; + + // rotate_half(q) + [m appendFormat:@" tensor q_p = reshape(shape=rp_sh,x=q)[name=string(\"q_p\")];\n", HEADS, pairs_q]; + [m appendFormat:@" tensor q_e = slice_by_size(x=q_p,begin=rp_b0,size=rp_s1)[name=string(\"q_e\")];\n", HEADS, pairs_q]; + [m appendFormat:@" tensor q_o = slice_by_size(x=q_p,begin=rp_b1,size=rp_s1)[name=string(\"q_o\")];\n", HEADS, pairs_q]; + [m appendFormat:@" tensor nq = mul(x=q_o,y=neg1)[name=string(\"nq\")];\n", HEADS, pairs_q]; + [m appendFormat:@" tensor qrp = concat(axis=rpax,interleave=rpil,values=(nq,q_e))[name=string(\"qrp\")];\n", HEADS, pairs_q]; + [m appendFormat:@" tensor q_rot = reshape(shape=rp_bk_q,x=qrp)[name=string(\"q_rot\")];\n", HEADS, SEQ, HD]; [m appendFormat:@" tensor qc = mul(x=q,y=rope_cos)[name=string(\"qc\")];\n", HEADS, SEQ, HD]; [m appendFormat:@" tensor qrs = mul(x=q_rot,y=rope_sin)[name=string(\"qrs\")];\n", HEADS, SEQ, HD]; [m appendFormat:@" tensor q_rope = add(x=qc,y=qrs)[name=string(\"q_rope\")];\n", HEADS, SEQ, HD]; - // rotate_half(k) - [m appendFormat:@" tensor k_p = reshape(shape=rp_sh,x=k)[name=string(\"k_p\")];\n", HEADS, pairs]; - [m appendFormat:@" tensor k_e = slice_by_size(x=k_p,begin=rp_b0,size=rp_s1)[name=string(\"k_e\")];\n", HEADS, pairs]; - [m appendFormat:@" tensor k_o = slice_by_size(x=k_p,begin=rp_b1,size=rp_s1)[name=string(\"k_o\")];\n", HEADS, pairs]; - [m appendFormat:@" tensor nk = mul(x=k_o,y=neg1)[name=string(\"nk\")];\n", HEADS, pairs]; - [m appendFormat:@" tensor krp = concat(axis=rpax,interleave=rpil,values=(nk,k_e))[name=string(\"krp\")];\n", HEADS, pairs]; - [m appendFormat:@" tensor k_rot = reshape(shape=rp_bk,x=krp)[name=string(\"k_rot\")];\n", HEADS, SEQ, HD]; - [m appendFormat:@" tensor kc = mul(x=k,y=rope_cos)[name=string(\"kc\")];\n", HEADS, SEQ, HD]; - [m appendFormat:@" tensor krs = mul(x=k_rot,y=rope_sin)[name=string(\"krs\")];\n", HEADS, SEQ, HD]; - [m appendFormat:@" tensor k_rope = add(x=kc,y=krs)[name=string(\"k_rope\")];\n", HEADS, SEQ, HD]; - // Q_rope @ K_rope^T - [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q_rope,y=k_rope)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ]; + // RoPE on K: [1,KV_HEADS,SEQ,HD] + int pairs_k = SEQ * HD / 2; + [m appendFormat:@" tensor rp_sh_k = const()[name=string(\"rp_sh_k\"), val=tensor([1,%d,%d,2])];\n", KV_HEADS, pairs_k]; + [m appendFormat:@" tensor rp_s1_k = const()[name=string(\"rp_s1_k\"), val=tensor([1,%d,%d,1])];\n", KV_HEADS, pairs_k]; + [m appendFormat:@" tensor rp_bk_k = const()[name=string(\"rp_bk_k\"), val=tensor([1,%d,%d,%d])];\n", KV_HEADS, SEQ, HD]; + [m appendFormat:@" tensor k_p = reshape(shape=rp_sh_k,x=k)[name=string(\"k_p\")];\n", KV_HEADS, pairs_k]; + [m appendFormat:@" tensor k_e = slice_by_size(x=k_p,begin=rp_b0,size=rp_s1_k)[name=string(\"k_e\")];\n", KV_HEADS, pairs_k]; + [m appendFormat:@" tensor k_o = slice_by_size(x=k_p,begin=rp_b1,size=rp_s1_k)[name=string(\"k_o\")];\n", KV_HEADS, pairs_k]; + [m appendFormat:@" tensor nk = mul(x=k_o,y=neg1)[name=string(\"nk\")];\n", KV_HEADS, pairs_k]; + [m appendFormat:@" tensor krp = concat(axis=rpax,interleave=rpil,values=(nk,k_e))[name=string(\"krp\")];\n", KV_HEADS, pairs_k]; + [m appendFormat:@" tensor k_rot = reshape(shape=rp_bk_k,x=krp)[name=string(\"k_rot\")];\n", KV_HEADS, SEQ, HD]; + [m appendFormat:@" tensor kc = mul(x=k,y=rope_cos)[name=string(\"kc\")];\n", KV_HEADS, SEQ, HD]; + [m appendFormat:@" tensor krs = mul(x=k_rot,y=rope_sin)[name=string(\"krs\")];\n", KV_HEADS, SEQ, HD]; + [m appendFormat:@" tensor k_rope = add(x=kc,y=krs)[name=string(\"k_rope\")];\n", KV_HEADS, SEQ, HD]; + + // GQA: tile K,V from KV_HEADS to HEADS + [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; + [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; + // For GQA_RATIO=2: concat(k_rope, k_rope) along head dim + NSMutableString *k_vals = [NSMutableString string]; + NSMutableString *v_vals = [NSMutableString string]; + for (int r = 0; r < GQA_RATIO; r++) { + if (r > 0) { [k_vals appendString:@","]; [v_vals appendString:@","]; } + [k_vals appendString:@"k_rope"]; [v_vals appendString:@"v"]; + } + [m appendFormat:@" tensor k_tiled = concat(axis=cax,interleave=cid,values=(%@))[name=string(\"ktile\")];\n", HEADS, SEQ, HD, k_vals]; + [m appendFormat:@" tensor v_tiled = concat(axis=cax,interleave=cid,values=(%@))[name=string(\"vtile\")];\n", HEADS, SEQ, HD, v_vals]; + + // Q_rope @ K_tiled^T → [1,HEADS,SEQ,SEQ] + [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q_rope,y=k_tiled)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ]; [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ]; - // Causal mask (still const — doesn't change) + // Causal mask [m appendFormat:@" tensor cm = const()[name=string(\"cm\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ, SEQ, SEQ, SEQ]; [m appendFormat:@" tensor ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS, SEQ, SEQ]; @@ -181,90 +193,67 @@ static NSString *gen_sdpa_fwd_dynamic(void) { [m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"]; [m appendFormat:@" tensor aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS, SEQ, SEQ]; - // scores @ V - [m appendFormat:@" tensor a4 = matmul(transpose_x=bF,transpose_y=bF,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS, SEQ, HD]; + // scores @ V_tiled → [1,HEADS,SEQ,HD] + [m appendFormat:@" tensor a4 = matmul(transpose_x=bF,transpose_y=bF,x=aw,y=v_tiled)[name=string(\"mm2\")];\n", HEADS, SEQ, HD]; - // Reshape back to [1,DIM,1,SEQ] + // Reshape attn_out to [1,Q_DIM,1,SEQ] [m appendFormat:@" tensor at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS, HD, SEQ]; - [m appendFormat:@" tensor af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor af = reshape(shape=qsh,x=at)[name=string(\"ra\")];\n", Q_DIM, SEQ]; - // Wo matmul: af → [1,1,S,D] @ Wo[1,1,D,D] → [1,1,S,D] → [1,D,1,S] - [m appendFormat:@" tensor af2 = reshape(shape=r2,x=af)[name=string(\"af2\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor aft = transpose(perm=pm,x=af2)[name=string(\"aft\")];\n", SEQ, DIM]; - [m appendFormat:@" tensor om = matmul(transpose_x=bF,transpose_y=bF,x=aft,y=Wo2)[name=string(\"om\")];\n", SEQ, DIM]; - [m appendFormat:@" tensor ot = transpose(perm=pm,x=om)[name=string(\"ot\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor oo = reshape(shape=os,x=ot)[name=string(\"oo\")];\n", DIM, SEQ]; - - // Convert RoPE'd Q,K back to [1,DIM,1,SEQ] for backward pass output + // Convert RoPE'd Q,K back to flat layout for backward [m appendFormat:@" tensor qrt = transpose(perm=pm,x=q_rope)[name=string(\"qrt\")];\n", HEADS, HD, SEQ]; - [m appendFormat:@" tensor qrf = reshape(shape=os,x=qrt)[name=string(\"qrf\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor krt = transpose(perm=pm,x=k_rope)[name=string(\"krt\")];\n", HEADS, HD, SEQ]; - [m appendFormat:@" tensor krf = reshape(shape=os,x=krt)[name=string(\"krf\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor qrf = reshape(shape=qsh,x=qrt)[name=string(\"qrf\")];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor krt = transpose(perm=pm,x=k_rope)[name=string(\"krt\")];\n", KV_HEADS, HD, SEQ]; + [m appendFormat:@" tensor krf = reshape(shape=kvsh,x=krt)[name=string(\"krf\")];\n", KV_DIM, SEQ]; - // Output: concat(o_out, Q_rope, K_rope, V, attn_out, xnorm) for backward - [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; - [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(oo,qrf,krf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ]; + // Output: concat(attn_out[Q_DIM], Q_rope[Q_DIM], K_rope[KV_DIM], V[KV_DIM], xnorm[DIM]) + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(af,qrf,krf,vf,xn))[name=string(\"cat\")];\n", out_ch, SEQ]; [m appendString:@" } -> (out);\n}\n"]; return m; } -// ===== FFN forward (dynamic weights) ===== -// RMSNorm on CPU. This kernel: xnorm @ W1 → SiLU, xnorm @ W3 → gate, gate*silu @ W2 → out -// Input: [1, DIM, 1, SEQ + HIDDEN + HIDDEN + DIM] fp32 -// sp[0:SEQ] = xnorm [DIM,SEQ] -// sp[SEQ:SEQ+HIDDEN] = W1[DIM,HIDDEN] -// sp[SEQ+HIDDEN:SEQ+2*HIDDEN] = W3[DIM,HIDDEN] -// sp[SEQ+2*HIDDEN:SEQ+2*HIDDEN+DIM]= W2[HIDDEN→DIM] — but W2 is [DIM,HIDDEN], we need HIDDEN input channels -// PROBLEM: W2 has shape [DIM,HIDDEN] = HIDDEN input channels, but our kernel has DIM input channels. -// Solution: separate kernels for W1/W3 (DIM→HIDDEN) and W2 (HIDDEN→DIM) -// OR: do W1,W3 in one kernel, SiLU on CPU/ANE, W2 in another kernel. -// Simpler: 3 separate matmul kernels per FFN direction. But that's too many dispatches. -// Better: one kernel for W1+W3 (same input dim), CPU SiLU, one kernel for W2. +// woFwd: attn_out[Q_DIM,SEQ] @ Wo → o_out[DIM,SEQ] +// Simple dyn_matmul: IC=Q_DIM, OC=DIM +static NSString *gen_wo_fwd_dynamic(void) { + return gen_dyn_matmul_mil(Q_DIM, DIM, SEQ); +} -// FFN part 1: xnorm @ W1, xnorm @ W3 (both DIM→HIDDEN) -// Input: [1, DIM, 1, SEQ + 2*HIDDEN] fp32 -// sp[0:SEQ] = xnorm -// sp[SEQ:SEQ+HIDDEN] = W1[DIM,HIDDEN] -// sp[SEQ+HIDDEN:SEQ+2*HIDDEN]= W3[DIM,HIDDEN] -// Output: [1, 2*HIDDEN, 1, SEQ] fp32 = concat(h1, h3) -static NSString *gen_ffn_w13_dynamic(void) { - int sp_in = SEQ + 2*HIDDEN; +// ===== Fused FFN forward: W1,W3 + SiLU + W2 + residual ===== +// Same structure as before, just with Qwen3 DIM=1024, HIDDEN=3072 +static NSString *gen_ffn_fused_dynamic(void) { + int sp_in = FFN_FUSED_SP; + int out_ch = DIM + 3*HIDDEN; NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; - [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"]; - [m appendFormat:@" tensor xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in]; - - // Slice xnorm - [m appendString:@" tensor bx = const()[name=string(\"bx\"), val=tensor([0,0,0,0])];\n"]; - [m appendFormat:@" tensor sx = const()[name=string(\"sx\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ]; - - // Slice W1 - [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,0,0,%d])];\n", SEQ]; - [m appendFormat:@" tensor s1 = const()[name=string(\"s1\"), val=tensor([1,%d,1,%d])];\n", DIM, HIDDEN]; - [m appendFormat:@" tensor W1 = slice_by_size(x=xh,begin=b1,size=s1)[name=string(\"W1\")];\n", DIM, HIDDEN]; - - // Slice W3 - [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,0,0,%d])];\n", SEQ+HIDDEN]; - [m appendFormat:@" tensor W3 = slice_by_size(x=xh,begin=b3,size=s1)[name=string(\"W3\")];\n", DIM, HIDDEN]; - - // Reshape for matmul + [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; - [m appendFormat:@" tensor rd = const()[name=string(\"rd\"), val=tensor([1,1,%d,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor xn2 = reshape(shape=rd,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM]; + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; + // Slice x2norm, x2, W1, W3, W2_orig + [m appendString:@" tensor b_xn = const()[name=string(\"b_xn\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor s_ds = const()[name=string(\"s_ds\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor x2norm = slice_by_size(x=x,begin=b_xn,size=s_ds)[name=string(\"x2norm\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor b_x2 = const()[name=string(\"b_x2\"), val=tensor([0,0,0,%d])];\n", SEQ]; + [m appendFormat:@" tensor x2 = slice_by_size(x=x,begin=b_x2,size=s_ds)[name=string(\"x2\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor b_w1 = const()[name=string(\"b_w1\"), val=tensor([0,0,0,%d])];\n", 2*SEQ]; + [m appendFormat:@" tensor s_wh = const()[name=string(\"s_wh\"), val=tensor([1,%d,1,%d])];\n", DIM, HIDDEN]; + [m appendFormat:@" tensor W1 = slice_by_size(x=x,begin=b_w1,size=s_wh)[name=string(\"W1\")];\n", DIM, HIDDEN]; + [m appendFormat:@" tensor b_w3 = const()[name=string(\"b_w3\"), val=tensor([0,0,0,%d])];\n", 2*SEQ+HIDDEN]; + [m appendFormat:@" tensor W3 = slice_by_size(x=x,begin=b_w3,size=s_wh)[name=string(\"W3\")];\n", DIM, HIDDEN]; + [m appendFormat:@" tensor b_w2 = const()[name=string(\"b_w2\"), val=tensor([0,0,0,%d])];\n", 2*SEQ+2*HIDDEN]; + [m appendFormat:@" tensor W2r = slice_by_size(x=x,begin=b_w2,size=s_wh)[name=string(\"W2r\")];\n", DIM, HIDDEN]; + + // xnorm matmul + [m appendFormat:@" tensor rd = const()[name=string(\"rd\"), val=tensor([1,1,%d,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor xn2 = reshape(shape=rd,x=x2norm)[name=string(\"xn2\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM]; [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor([1,1,%d,%d])];\n", DIM, HIDDEN]; [m appendFormat:@" tensor W12 = reshape(shape=rw,x=W1)[name=string(\"W12\")];\n", DIM, HIDDEN]; [m appendFormat:@" tensor W32 = reshape(shape=rw,x=W3)[name=string(\"W32\")];\n", DIM, HIDDEN]; - - [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; [m appendFormat:@" tensor h1m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W12)[name=string(\"h1m\")];\n", SEQ, HIDDEN]; [m appendFormat:@" tensor h3m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W32)[name=string(\"h3m\")];\n", SEQ, HIDDEN]; - // Transpose back + // Reshape back [m appendFormat:@" tensor h1t = transpose(perm=pm,x=h1m)[name=string(\"h1t\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor h3t = transpose(perm=pm,x=h3m)[name=string(\"h3t\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor rh = const()[name=string(\"rh\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ]; @@ -276,107 +265,24 @@ static NSString *gen_ffn_w13_dynamic(void) { [m appendFormat:@" tensor silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN, SEQ]; - // Concat output: (h1, h3, gate) - [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; - [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(h1,h3,gate))[name=string(\"cat\")];\n", 2*HIDDEN+HIDDEN, SEQ]; - [m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; - [m appendFormat:@" tensor out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", 3*HIDDEN, SEQ]; - [m appendString:@" } -> (out32);\n}\n"]; - return m; -} - -// ===== Fused FFN forward: W1,W3 + SiLU + W2 + residual ===== -// RMSNorm stays on CPU (ANE can't handle RMS + 3 matmuls without BNNS fallback) -// Replaces: ffnW13 + CPU gate read + ffnW2 + CPU residual -// Input: [1, DIM, 1, 2*SEQ + 3*HIDDEN] fp16 -// sp[0:SEQ] = x2norm (RMSNorm output, from CPU) -// sp[SEQ:2*SEQ] = x2 (residual, for x_next = x2 + ffn_out) -// sp[2*SEQ : 2*SEQ+HIDDEN] = W1t[DIM,HIDDEN] -// sp[2*SEQ+HIDDEN : 2*SEQ+2*HIDDEN] = W3t[DIM,HIDDEN] -// sp[2*SEQ+2*HIDDEN : 2*SEQ+3*HIDDEN] = W2_orig[DIM,HIDDEN] (transposed inside kernel) -// Output: [1, DIM + 3*HIDDEN, 1, SEQ] fp16 -// = concat(x_next[DIM], h1[HIDDEN], h3[HIDDEN], silu_out[HIDDEN]) -static NSString *gen_ffn_fused_dynamic(void) { - int sp_in = 2*SEQ + 3*HIDDEN; - int out_ch = DIM + 3*HIDDEN; - NSMutableString *m = [NSMutableString string]; - [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; - [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; - [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; - - // Slice x2norm [DIM, SEQ] — RMSNorm output (computed on CPU) - [m appendString:@" tensor b_xn = const()[name=string(\"b_xn\"), val=tensor([0,0,0,0])];\n"]; - [m appendFormat:@" tensor s_ds = const()[name=string(\"s_ds\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor x2norm = slice_by_size(x=x,begin=b_xn,size=s_ds)[name=string(\"x2norm\")];\n", DIM, SEQ]; - - // Slice x2 [DIM, SEQ] — for residual: x_next = x2 + ffn_out - [m appendFormat:@" tensor b_x2 = const()[name=string(\"b_x2\"), val=tensor([0,0,0,%d])];\n", SEQ]; - [m appendFormat:@" tensor x2 = slice_by_size(x=x,begin=b_x2,size=s_ds)[name=string(\"x2\")];\n", DIM, SEQ]; - - // Slice W1 [DIM, HIDDEN] - [m appendFormat:@" tensor b_w1 = const()[name=string(\"b_w1\"), val=tensor([0,0,0,%d])];\n", 2*SEQ]; - [m appendFormat:@" tensor s_wh = const()[name=string(\"s_wh\"), val=tensor([1,%d,1,%d])];\n", DIM, HIDDEN]; - [m appendFormat:@" tensor W1 = slice_by_size(x=x,begin=b_w1,size=s_wh)[name=string(\"W1\")];\n", DIM, HIDDEN]; - - // Slice W3 [DIM, HIDDEN] - [m appendFormat:@" tensor b_w3 = const()[name=string(\"b_w3\"), val=tensor([0,0,0,%d])];\n", 2*SEQ+HIDDEN]; - [m appendFormat:@" tensor W3 = slice_by_size(x=x,begin=b_w3,size=s_wh)[name=string(\"W3\")];\n", DIM, HIDDEN]; - - // Slice W2_orig [DIM, HIDDEN] (transposed inside kernel) - [m appendFormat:@" tensor b_w2 = const()[name=string(\"b_w2\"), val=tensor([0,0,0,%d])];\n", 2*SEQ+2*HIDDEN]; - [m appendFormat:@" tensor W2r = slice_by_size(x=x,begin=b_w2,size=s_wh)[name=string(\"W2r\")];\n", DIM, HIDDEN]; - - // Reshape for matmul: x2norm [1,DIM,1,SEQ] → [1,1,DIM,SEQ] → [1,1,SEQ,DIM] - [m appendFormat:@" tensor rd = const()[name=string(\"rd\"), val=tensor([1,1,%d,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor xn2 = reshape(shape=rd,x=x2norm)[name=string(\"xn2\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM]; - - // Reshape weights: [1,DIM,1,HIDDEN] → [1,1,DIM,HIDDEN] - [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor([1,1,%d,%d])];\n", DIM, HIDDEN]; - [m appendFormat:@" tensor W12 = reshape(shape=rw,x=W1)[name=string(\"W12\")];\n", DIM, HIDDEN]; - [m appendFormat:@" tensor W32 = reshape(shape=rw,x=W3)[name=string(\"W32\")];\n", DIM, HIDDEN]; - - // h1 = x2norm_t @ W1, h3 = x2norm_t @ W3 [SEQ,DIM] @ [DIM,HIDDEN] → [SEQ,HIDDEN] - [m appendFormat:@" tensor h1m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W12)[name=string(\"h1m\")];\n", SEQ, HIDDEN]; - [m appendFormat:@" tensor h3m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W32)[name=string(\"h3m\")];\n", SEQ, HIDDEN]; - - // Reshape back: [1,1,SEQ,HIDDEN] → [1,1,HIDDEN,SEQ] → [1,HIDDEN,1,SEQ] - [m appendFormat:@" tensor h1t = transpose(perm=pm,x=h1m)[name=string(\"h1t\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor h3t = transpose(perm=pm,x=h3m)[name=string(\"h3t\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor rh = const()[name=string(\"rh\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor h1 = reshape(shape=rh,x=h1t)[name=string(\"h1\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor h3 = reshape(shape=rh,x=h3t)[name=string(\"h3\")];\n", HIDDEN, SEQ]; - - // SiLU + gate: gate = silu(h1) * h3 - [m appendFormat:@" tensor sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN, SEQ]; - - // gate @ W2: reshape gate [1,HIDDEN,1,SEQ] → [1,1,HIDDEN,SEQ] → [1,1,SEQ,HIDDEN] + // gate @ W2: W2 is [DIM, HIDDEN] stored as-is, transpose inside kernel [m appendFormat:@" tensor rg = const()[name=string(\"rg\"), val=tensor([1,1,%d,%d])];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor g2 = reshape(shape=rg,x=gate)[name=string(\"g2\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor gt = transpose(perm=pm,x=g2)[name=string(\"gtt\")];\n", SEQ, HIDDEN]; - - // W2: [1,DIM,1,HIDDEN] → [1,1,DIM,HIDDEN] → transpose → [1,1,HIDDEN,DIM] [m appendFormat:@" tensor W22 = reshape(shape=rw,x=W2r)[name=string(\"W22\")];\n", DIM, HIDDEN]; [m appendFormat:@" tensor W2t = transpose(perm=pm,x=W22)[name=string(\"W2t\")];\n", HIDDEN, DIM]; - - // matmul: [1,1,SEQ,HIDDEN] @ [1,1,HIDDEN,DIM] → [1,1,SEQ,DIM] [m appendFormat:@" tensor fm = matmul(transpose_x=bF,transpose_y=bF,x=gt,y=W2t)[name=string(\"fm\")];\n", SEQ, DIM]; - // Reshape: [1,1,SEQ,DIM] → [1,1,DIM,SEQ] → [1,DIM,1,SEQ] [m appendFormat:@" tensor ft = transpose(perm=pm,x=fm)[name=string(\"ft\")];\n", DIM, SEQ]; [m appendFormat:@" tensor rd2 = const()[name=string(\"rd2\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor ffn_out = reshape(shape=rd2,x=ft)[name=string(\"ffn_out\")];\n", DIM, SEQ]; - // Residual: x_next = x2 + alpha * ffn_out (residual scaling) + // Residual: x_next = x2 + alpha * ffn_out float alpha = 1.0f / sqrtf(2.0f * NLAYERS); [m appendFormat:@" fp16 res_alpha = const()[name=string(\"res_alpha\"), val=fp16(%g)];\n", alpha]; [m appendFormat:@" tensor ffn_scaled = mul(x=ffn_out,y=res_alpha)[name=string(\"ffn_sc\")];\n", DIM, SEQ]; [m appendFormat:@" tensor x_next = add(x=x2,y=ffn_scaled)[name=string(\"x_next\")];\n", DIM, SEQ]; - // Output: concat(x_next, h1, h3, gate) — gate=silu*h3 needed for dW2 + // Output: concat(x_next, h1, h3, gate) [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(x_next,h1,h3,gate))[name=string(\"cat\")];\n", out_ch, SEQ]; @@ -384,103 +290,37 @@ static NSString *gen_ffn_fused_dynamic(void) { return m; } -// FFN part 2: gate @ W2 (HIDDEN→DIM) -// Input: [1, HIDDEN, 1, SEQ + DIM] fp32 -// sp[0:SEQ] = gate [HIDDEN,SEQ] -// sp[SEQ:SEQ+DIM] = W2[HIDDEN,DIM] -// Output: [1, DIM, 1, SEQ] fp32 -static NSString *gen_ffn_w2_dynamic(void) { - int sp_in = SEQ + DIM; - NSMutableString *m = [NSMutableString string]; - [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", HIDDEN, sp_in]; - [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"]; - [m appendFormat:@" tensor xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, sp_in]; +// ===== Backward kernels ===== - [m appendString:@" tensor ba = const()[name=string(\"ba\"), val=tensor([0,0,0,0])];\n"]; - [m appendFormat:@" tensor sa = const()[name=string(\"sa\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor act = slice_by_size(x=xh,begin=ba,size=sa)[name=string(\"act\")];\n", HIDDEN, SEQ]; - - [m appendFormat:@" tensor bw = const()[name=string(\"bw\"), val=tensor([0,0,0,%d])];\n", SEQ]; - [m appendFormat:@" tensor sw = const()[name=string(\"sw\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, DIM]; - [m appendFormat:@" tensor W2 = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"W2\")];\n", HIDDEN, DIM]; - - [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; - [m appendFormat:@" tensor ra = const()[name=string(\"ra\"), val=tensor([1,1,%d,%d])];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor a2 = reshape(shape=ra,x=act)[name=string(\"a2\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor at = transpose(perm=pm,x=a2)[name=string(\"at\")];\n", SEQ, HIDDEN]; - - [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor([1,1,%d,%d])];\n", HIDDEN, DIM]; - [m appendFormat:@" tensor W22 = reshape(shape=rw,x=W2)[name=string(\"W22\")];\n", HIDDEN, DIM]; - - [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor ym = matmul(transpose_x=bF,transpose_y=bF,x=at,y=W22)[name=string(\"ym\")];\n", SEQ, DIM]; - [m appendFormat:@" tensor yt = transpose(perm=pm,x=ym)[name=string(\"yt\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor ro = const()[name=string(\"ro\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor yr = reshape(shape=ro,x=yt)[name=string(\"yr\")];\n", DIM, SEQ]; - [m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; - [m appendFormat:@" tensor y = cast(dtype=to32,x=yr)[name=string(\"cout\")];\n", DIM, SEQ]; - [m appendString:@" } -> (y);\n}\n"]; - return m; -} - -// ===== FFN backward (dynamic weights) ===== -// Input: [1, DIM+2*HIDDEN, 1, SEQ + HIDDEN + DIM + DIM] fp32 -// Actually simpler to split into separate backward kernels like forward. - -// FFN backward part 1: dffn @ W2^T → dsilu (HIDDEN), then SiLU derivative -// Input: [1, DIM, 1, SEQ + HIDDEN] fp32 -// sp[0:SEQ] = dffn [DIM, SEQ] -// sp[SEQ:SEQ+HIDDEN]= W2^T [DIM, HIDDEN] -// Output: [1, HIDDEN, 1, SEQ] fp16 = dsilu_raw +// ffnBwdW2t: dffn @ W2 → dsilu_raw (IC=DIM, OC=HIDDEN) static NSString *gen_ffn_bwd_w2t_dynamic(void) { return gen_dyn_matmul_mil(DIM, HIDDEN, SEQ); } -// FFN backward part 2: dh1 @ W1^T + dh3 @ W3^T → dx -// We need h1,h3 for SiLU derivative, but those are on CPU. -// Actually the SiLU derivative + gating is element-wise, do on CPU. -// Then: dh1 @ W1^T and dh3 @ W3^T are two separate matmuls (HIDDEN→DIM). -// Combine into one kernel: -// Input: [1, HIDDEN, 1, SEQ + SEQ + DIM + DIM] fp32 -// sp[0:SEQ] = dh1 [HIDDEN,SEQ] -// sp[SEQ:2*SEQ] = dh3 [HIDDEN,SEQ] -// sp[2*SEQ:2*SEQ+DIM] = W1^T [HIDDEN,DIM] -// sp[2*SEQ+DIM:2*SEQ+2D] = W3^T [HIDDEN,DIM] -// Output: [1, DIM, 1, SEQ] fp16 = dx1 + dx3 +// ffnBwdW13t: dh1 @ W1 + dh3 @ W3 → dx_ffn (IC=HIDDEN, two matmuls added) static NSString *gen_ffn_bwd_w13t_dynamic(void) { - int sp_in = 2*SEQ + 2*DIM; + int sp_in = FFN_BWD_W13T_SP; NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; [m appendFormat:@" func main(tensor x) {\n", HIDDEN, sp_in]; - // Slice dh1 [HIDDEN, SEQ] - [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; [m appendFormat:@" tensor sh = const()[name=string(\"sh\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ]; + [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; [m appendFormat:@" tensor dh1 = slice_by_size(x=x,begin=b0,size=sh)[name=string(\"dh1\")];\n", HIDDEN, SEQ]; - - // Slice dh3 [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,0,0,%d])];\n", SEQ]; [m appendFormat:@" tensor dh3 = slice_by_size(x=x,begin=b1,size=sh)[name=string(\"dh3\")];\n", HIDDEN, SEQ]; - - // Slice W1^T [HIDDEN, DIM] [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,0,0,%d])];\n", 2*SEQ]; [m appendFormat:@" tensor sw = const()[name=string(\"sw\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, DIM]; [m appendFormat:@" tensor W1t = slice_by_size(x=x,begin=b2,size=sw)[name=string(\"W1t\")];\n", HIDDEN, DIM]; - - // Slice W3^T [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,0,0,%d])];\n", 2*SEQ+DIM]; [m appendFormat:@" tensor W3t = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"W3t\")];\n", HIDDEN, DIM]; [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; - - // dh1 matmul: [S,H] @ [H,D] → [S,D] [m appendFormat:@" tensor ra = const()[name=string(\"ra\"), val=tensor([1,1,%d,%d])];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor dh12 = reshape(shape=ra,x=dh1)[name=string(\"dh12\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor dh1t = transpose(perm=pm,x=dh12)[name=string(\"dh1t\")];\n", SEQ, HIDDEN]; [m appendFormat:@" tensor dh32 = reshape(shape=ra,x=dh3)[name=string(\"dh32\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor dh3t = transpose(perm=pm,x=dh32)[name=string(\"dh3t\")];\n", SEQ, HIDDEN]; - [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor([1,1,%d,%d])];\n", HIDDEN, DIM]; [m appendFormat:@" tensor W1t2 = reshape(shape=rw,x=W1t)[name=string(\"W1t2\")];\n", HIDDEN, DIM]; [m appendFormat:@" tensor W3t2 = reshape(shape=rw,x=W3t)[name=string(\"W3t2\")];\n", HIDDEN, DIM]; @@ -488,10 +328,7 @@ static NSString *gen_ffn_bwd_w13t_dynamic(void) { [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; [m appendFormat:@" tensor dx1m = matmul(transpose_x=bF,transpose_y=bF,x=dh1t,y=W1t2)[name=string(\"dx1m\")];\n", SEQ, DIM]; [m appendFormat:@" tensor dx3m = matmul(transpose_x=bF,transpose_y=bF,x=dh3t,y=W3t2)[name=string(\"dx3m\")];\n", SEQ, DIM]; - - // Add [m appendFormat:@" tensor dxm = add(x=dx1m,y=dx3m)[name=string(\"dxm\")];\n", SEQ, DIM]; - [m appendFormat:@" tensor dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ]; [m appendFormat:@" tensor ro = const()[name=string(\"ro\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ]; @@ -499,40 +336,80 @@ static NSString *gen_ffn_bwd_w13t_dynamic(void) { return m; } -// ===== SDPA backward part 1 (dynamic Wo^T) ===== -// Same as original gen_sdpa_bwd1 but Wo^T comes from input instead of const -// Input: [1, 4*DIM, 1, SEQ + DIM] fp32 — Q,K,V,dx2 in channels, Wo^T in spatial -// Wait — channels must match for all data. Q,K,V are [DIM,SEQ], dx2 is [DIM,SEQ]. -// Total input channels = 4*DIM. But Wo^T is [DIM,DIM] = DIM channels of DIM spatial. -// Problem: can't mix 4*DIM channels for data with DIM channels for Wo^T. -// Solution: Wo^T matmul as separate kernel, then SDPA part purely element-wise on ANE. - -// Wo^T matmul: dx2 @ Wo^T → da (DIM→DIM) +// wotBwd: dy @ Wo → da (IC=DIM, OC=Q_DIM) static NSString *gen_wot_dynamic(void) { - return gen_dyn_matmul_mil(DIM, DIM, SEQ); + return gen_dyn_matmul_mil(DIM, Q_DIM, SEQ); } -// SDPA backward part 1 (no weights, all data): Q,K,V,da → dV,probs,dp -// Same as original but without Wo^T conv (already done) -// Input: [1, 4*DIM, 1, SEQ] fp16 -static NSString *gen_sdpa_bwd1_noweight(void) { - float sc = 1.0f/sqrtf((float)HD); +// qBwd: dq @ Wq → dx_q (IC=Q_DIM, OC=DIM) +static NSString *gen_q_bwd_dynamic(void) { + return gen_dyn_matmul_mil(Q_DIM, DIM, SEQ); +} + +// kvBwd: dk @ Wk + dv @ Wv → dx_kv (IC=KV_DIM) +// Input: [1, KV_DIM, 1, 2*SEQ+2*DIM] fp16 +// Same pattern as ffnBwdW13t but with KV_DIM channels +static NSString *gen_kv_bwd_dynamic(void) { + int sp_in = KV_BWD_SP; NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", 4*DIM, SEQ]; + [m appendFormat:@" func main(tensor x) {\n", KV_DIM, sp_in]; - // Slice Q,K,V,da - [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor sh = const()[name=string(\"sh\"), val=tensor([1,%d,1,%d])];\n", KV_DIM, SEQ]; [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; - [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM]; - [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM]; - [m appendFormat:@" tensor vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 3*DIM]; - [m appendFormat:@" tensor da = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor dk = slice_by_size(x=x,begin=b0,size=sh)[name=string(\"dk\")];\n", KV_DIM, SEQ]; + [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,0,0,%d])];\n", SEQ]; + [m appendFormat:@" tensor dv = slice_by_size(x=x,begin=b1,size=sh)[name=string(\"dv\")];\n", KV_DIM, SEQ]; + [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,0,0,%d])];\n", 2*SEQ]; + [m appendFormat:@" tensor sw = const()[name=string(\"sw\"), val=tensor([1,%d,1,%d])];\n", KV_DIM, DIM]; + [m appendFormat:@" tensor Wkt = slice_by_size(x=x,begin=b2,size=sw)[name=string(\"Wkt\")];\n", KV_DIM, DIM]; + [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,0,0,%d])];\n", 2*SEQ+DIM]; + [m appendFormat:@" tensor Wvt = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"Wvt\")];\n", KV_DIM, DIM]; - // Reshape to heads + [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; + [m appendFormat:@" tensor ra = const()[name=string(\"ra\"), val=tensor([1,1,%d,%d])];\n", KV_DIM, SEQ]; + [m appendFormat:@" tensor dk2 = reshape(shape=ra,x=dk)[name=string(\"dk2\")];\n", KV_DIM, SEQ]; + [m appendFormat:@" tensor dkt = transpose(perm=pm,x=dk2)[name=string(\"dkt\")];\n", SEQ, KV_DIM]; + [m appendFormat:@" tensor dv2 = reshape(shape=ra,x=dv)[name=string(\"dv2\")];\n", KV_DIM, SEQ]; + [m appendFormat:@" tensor dvt = transpose(perm=pm,x=dv2)[name=string(\"dvt\")];\n", SEQ, KV_DIM]; + [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor([1,1,%d,%d])];\n", KV_DIM, DIM]; + [m appendFormat:@" tensor Wkt2 = reshape(shape=rw,x=Wkt)[name=string(\"Wkt2\")];\n", KV_DIM, DIM]; + [m appendFormat:@" tensor Wvt2 = reshape(shape=rw,x=Wvt)[name=string(\"Wvt2\")];\n", KV_DIM, DIM]; + + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor dxk = matmul(transpose_x=bF,transpose_y=bF,x=dkt,y=Wkt2)[name=string(\"dxk\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor dxv = matmul(transpose_x=bF,transpose_y=bF,x=dvt,y=Wvt2)[name=string(\"dxv\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor dxm = add(x=dxk,y=dxv)[name=string(\"dxm\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor ro = const()[name=string(\"ro\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ]; + [m appendString:@" } -> (dx);\n}\n"]; + return m; +} + +// SDPA backward part 1: recompute attention + dV, dp +// Uses tiled K,V at HEADS dimension (CPU pre-tiles) +// Input: [1, 2*Q_DIM+2*Q_DIM, 1, SEQ] fp16 = (Q, K_tiled, V_tiled, da) +// Output: [1, Q_DIM+2*SCORE_CH, 1, SEQ] fp16 = (dV_full, probs, dp) +static NSString *gen_sdpa_bwd1_noweight(void) { + float sc = 1.0f/sqrtf((float)HD); + int in_ch = 4*Q_DIM; // Q + K_tiled + V_tiled + da, all at Q_DIM (HEADS*HD) + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", in_ch, SEQ]; + + // Slice Q,K_tiled,V_tiled,da — all [Q_DIM, SEQ] + [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", Q_DIM, SEQ]; + [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", Q_DIM]; + [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*Q_DIM]; + [m appendFormat:@" tensor vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 3*Q_DIM]; + [m appendFormat:@" tensor da = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", Q_DIM, SEQ]; + + // Reshape to heads [1,HEADS,HD,SEQ] → [1,HEADS,SEQ,HD] [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS, HD, SEQ]; [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ]; @@ -544,7 +421,7 @@ static NSString *gen_sdpa_bwd1_noweight(void) { [m appendFormat:@" tensor dr = reshape(shape=rsh,x=da)[name=string(\"rd\")];\n", HEADS, HD, SEQ]; [m appendFormat:@" tensor dat = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS, SEQ, HD]; - // Forward attention scores (recompute) + // Recompute attention scores [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ]; @@ -559,49 +436,57 @@ static NSString *gen_sdpa_bwd1_noweight(void) { [m appendFormat:@" tensor dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=dat)[name=string(\"dv\")];\n", HEADS, SEQ, HD]; [m appendFormat:@" tensor dp4 = matmul(transpose_x=bF,transpose_y=bT,x=dat,y=v)[name=string(\"dp\")];\n", HEADS, SEQ, SEQ]; - // Reshape dV back + // Reshape dV to [Q_DIM, SEQ] (will be reduced to KV_DIM on CPU) [m appendFormat:@" tensor dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS, HD, SEQ]; - [m appendFormat:@" tensor dvs = const()[name=string(\"dvs\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor dvs = const()[name=string(\"dvs\"), val=tensor([1,%d,1,%d])];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", Q_DIM, SEQ]; - // Flatten probs and dp for output + // Flatten probs and dp [m appendFormat:@" tensor scs = const()[name=string(\"scs\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH, SEQ]; [m appendFormat:@" tensor pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH, SEQ]; [m appendFormat:@" tensor dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH, SEQ]; [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", DIM+2*SCORE_CH, SEQ]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", Q_DIM+2*SCORE_CH, SEQ]; [m appendString:@" } -> (out);\n}\n"]; return m; } -// SDPA backward part 2: same as original (no weights, pure computation) +// SDPA backward part 2: probs, dp, Q, K_tiled → dQ, dK_full +// Input: [1, 2*SCORE_CH + 2*Q_DIM, 1, SEQ] +// Output: [1, 2*Q_DIM, 1, SEQ] = (dQ, dK_full) static NSString *gen_sdpa_bwd2(void) { float sc = 1.0f/sqrtf((float)HD); - int bwd2_in = 2*SCORE_CH + 2*DIM; + int bwd2_in = 2*SCORE_CH + 2*Q_DIM; NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; [m appendFormat:@" func main(tensor x) {\n", bwd2_in, SEQ]; + [m appendFormat:@" tensor sz_sc = const()[name=string(\"szsc\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH, SEQ]; [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; [m appendFormat:@" tensor pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH, SEQ]; [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", SCORE_CH]; [m appendFormat:@" tensor dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH, SEQ]; - [m appendFormat:@" tensor sz_d = const()[name=string(\"szd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + + [m appendFormat:@" tensor sz_q = const()[name=string(\"szq\"), val=tensor([1,%d,1,%d])];\n", Q_DIM, SEQ]; [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH]; - [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH+DIM]; - [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b2,size=sz_q)[name=string(\"s2\")];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH+Q_DIM]; + [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b3,size=sz_q)[name=string(\"s3\")];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor ssh = const()[name=string(\"ssh\"), val=tensor([1,%d,%d,%d])];\n", HEADS, SEQ, SEQ]; [m appendFormat:@" tensor probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS, SEQ, SEQ]; [m appendFormat:@" tensor dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS, SEQ, SEQ]; + [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS, HD, SEQ]; [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ]; [m appendFormat:@" tensor q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS, SEQ, HD]; [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ]; [m appendFormat:@" tensor k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS, SEQ, HD]; + + // Softmax backward: ds = (dp - sum(dp*probs)) * probs * scale [m appendFormat:@" tensor pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS, SEQ, SEQ]; [m appendString:@" tensor rax = const()[name=string(\"rax\"), val=tensor([-1])];\n"]; [m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"]; @@ -610,92 +495,26 @@ static NSString *gen_sdpa_bwd2(void) { [m appendFormat:@" tensor ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS, SEQ, SEQ]; [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; [m appendFormat:@" tensor ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS, SEQ, SEQ]; + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; [m appendFormat:@" tensor dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS, SEQ, HD]; [m appendFormat:@" tensor dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS, HD, SEQ]; [m appendFormat:@" tensor dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS, HD, SEQ]; - [m appendFormat:@" tensor fs = const()[name=string(\"fs\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor fs = const()[name=string(\"fs\"), val=tensor([1,%d,1,%d])];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", Q_DIM, SEQ]; + [m appendFormat:@" tensor dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", Q_DIM, SEQ]; + [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM, SEQ]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*Q_DIM, SEQ]; [m appendString:@" } -> (out);\n}\n"]; return m; } -// QKV backward (dynamic): dq @ Wq^T + dk @ Wk^T + dv @ Wv^T → dx -// Input: [1, DIM, 1, 3*SEQ + 3*DIM] fp32 -// sp[0:SEQ] = dq [DIM,SEQ] -// sp[SEQ:2*SEQ] = dk [DIM,SEQ] -// sp[2*SEQ:3*SEQ] = dv [DIM,SEQ] -// sp[3*SEQ:3*SEQ+DIM] = Wq^T [DIM,DIM] -// sp[3*SEQ+DIM:3*SEQ+2D] = Wk^T [DIM,DIM] -// sp[3*SEQ+2D:3*SEQ+3D] = Wv^T [DIM,DIM] -// Output: [1, DIM, 1, SEQ] fp32 = dxq + dxk + dxv -static NSString *gen_qkvb_dynamic(void) { - int sp_in = 3*SEQ + 3*DIM; - NSMutableString *m = [NSMutableString string]; - [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; - - // Slice dq, dk, dv - [m appendFormat:@" tensor sd = const()[name=string(\"sd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; - [m appendFormat:@" tensor dq = slice_by_size(x=x,begin=b0,size=sd)[name=string(\"dq\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,0,0,%d])];\n", SEQ]; - [m appendFormat:@" tensor dk = slice_by_size(x=x,begin=b1,size=sd)[name=string(\"dk\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,0,0,%d])];\n", 2*SEQ]; - [m appendFormat:@" tensor dv = slice_by_size(x=x,begin=b2,size=sd)[name=string(\"dv\")];\n", DIM, SEQ]; - - // Slice Wq^T, Wk^T, Wv^T - [m appendFormat:@" tensor sw = const()[name=string(\"sw\"), val=tensor([1,%d,1,%d])];\n", DIM, DIM]; - [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,0,0,%d])];\n", 3*SEQ]; - [m appendFormat:@" tensor Wqt = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"Wqt\")];\n", DIM, DIM]; - [m appendFormat:@" tensor b4 = const()[name=string(\"b4\"), val=tensor([0,0,0,%d])];\n", 3*SEQ+DIM]; - [m appendFormat:@" tensor Wkt = slice_by_size(x=x,begin=b4,size=sw)[name=string(\"Wkt\")];\n", DIM, DIM]; - [m appendFormat:@" tensor b5 = const()[name=string(\"b5\"), val=tensor([0,0,0,%d])];\n", 3*SEQ+2*DIM]; - [m appendFormat:@" tensor Wvt = slice_by_size(x=x,begin=b5,size=sw)[name=string(\"Wvt\")];\n", DIM, DIM]; - - [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; - [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; - - // Reshape and matmul for each - [m appendFormat:@" tensor rd = const()[name=string(\"rd\"), val=tensor([1,1,%d,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor([1,1,%d,%d])];\n", DIM, DIM]; - - // dq @ Wq^T - [m appendFormat:@" tensor dq2 = reshape(shape=rd,x=dq)[name=string(\"dq2\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor dqt = transpose(perm=pm,x=dq2)[name=string(\"dqt\")];\n", SEQ, DIM]; - [m appendFormat:@" tensor Wqt2 = reshape(shape=rw,x=Wqt)[name=string(\"Wqt2\")];\n", DIM, DIM]; - [m appendFormat:@" tensor dxq = matmul(transpose_x=bF,transpose_y=bF,x=dqt,y=Wqt2)[name=string(\"dxq\")];\n", SEQ, DIM]; - - // dk @ Wk^T - [m appendFormat:@" tensor dk2 = reshape(shape=rd,x=dk)[name=string(\"dk2\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor dkt = transpose(perm=pm,x=dk2)[name=string(\"dkt\")];\n", SEQ, DIM]; - [m appendFormat:@" tensor Wkt2 = reshape(shape=rw,x=Wkt)[name=string(\"Wkt2\")];\n", DIM, DIM]; - [m appendFormat:@" tensor dxk = matmul(transpose_x=bF,transpose_y=bF,x=dkt,y=Wkt2)[name=string(\"dxk\")];\n", SEQ, DIM]; - - // dv @ Wv^T - [m appendFormat:@" tensor dv2 = reshape(shape=rd,x=dv)[name=string(\"dv2\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor dvt = transpose(perm=pm,x=dv2)[name=string(\"dvt\")];\n", SEQ, DIM]; - [m appendFormat:@" tensor Wvt2 = reshape(shape=rw,x=Wvt)[name=string(\"Wvt2\")];\n", DIM, DIM]; - [m appendFormat:@" tensor dxv = matmul(transpose_x=bF,transpose_y=bF,x=dvt,y=Wvt2)[name=string(\"dxv\")];\n", SEQ, DIM]; - - // Sum: dxq + dxk + dxv - [m appendFormat:@" tensor dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", SEQ, DIM]; - [m appendFormat:@" tensor dxall = add(x=dxqk,y=dxv)[name=string(\"aall\")];\n", SEQ, DIM]; - - [m appendFormat:@" tensor dxt = transpose(perm=pm,x=dxall)[name=string(\"dxt\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor ro = const()[name=string(\"ro\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ]; - [m appendString:@" } -> (dx);\n}\n"]; - return m; -} - -// Causal mask blob (used by sdpa_fwd and sdpa_bwd1) +// Causal mask blob static NSData *g_mask_blob = nil; static NSData *get_mask_blob(void) { if (!g_mask_blob) { @@ -708,7 +527,7 @@ static NSData *get_mask_blob(void) { return g_mask_blob; } -// RoPE cos/sin blobs [1, 1, SEQ, HD] — rotary position encodings +// RoPE cos/sin blobs [1, 1, SEQ, HD] static NSData *g_rope_cos_blob = nil; static NSData *g_rope_sin_blob = nil; diff --git a/training/training_dynamic/models/qwen3_06b.h b/training/training_dynamic/models/qwen3_06b.h new file mode 100644 index 0000000..72b6137 --- /dev/null +++ b/training/training_dynamic/models/qwen3_06b.h @@ -0,0 +1,19 @@ +// qwen3_06b.h — Qwen3-0.6B (28 layers, GQA 16q/8kv, head_dim=128) +#pragma once + +#define MODEL_NAME "Qwen3-0.6B" + +#define DIM 1024 +#define HIDDEN 3072 +#define HEADS 16 +#define KV_HEADS 8 +#define HD 128 // explicit head_dim (NOT DIM/HEADS) +#define GQA_RATIO (HEADS / KV_HEADS) // = 2 +#define Q_DIM (HEADS * HD) // = 2048 +#define KV_DIM (KV_HEADS * HD) // = 1024 (= DIM for this model) +#define SEQ 256 +#define NLAYERS 28 +#define VOCAB 151936 + +#define CKPT_PATH "ane_qwen3_06b_dyn_ckpt.bin" +#define DEFAULT_DATA_PATH "../tinystories_data00.bin" diff --git a/training/training_dynamic/models/stories110m.h b/training/training_dynamic/models/stories110m.h new file mode 100644 index 0000000..578fc36 --- /dev/null +++ b/training/training_dynamic/models/stories110m.h @@ -0,0 +1,19 @@ +// stories110m.h — Stories110M (Llama2-style, 12 layers, MHA) +#pragma once + +#define MODEL_NAME "Stories110M" + +#define DIM 768 +#define HIDDEN 2048 +#define HEADS 12 +#define KV_HEADS 12 +#define HD (DIM/HEADS) // = 64 +#define GQA_RATIO 1 // MHA: no GQA +#define Q_DIM (HEADS * HD) // = 768 = DIM +#define KV_DIM (KV_HEADS * HD) // = 768 = DIM +#define SEQ 256 +#define NLAYERS 12 +#define VOCAB 32000 + +#define CKPT_PATH "ane_stories110M_dyn_ckpt.bin" +#define DEFAULT_DATA_PATH "../tinystories_data00.bin" diff --git a/training/training_dynamic/train.m b/training/training_dynamic/train.m index 685e075..0c9f658 100644 --- a/training/training_dynamic/train.m +++ b/training/training_dynamic/train.m @@ -1,53 +1,23 @@ -// train.m — Dynamic weight ANE training for Stories110M +// train.m — Dynamic weight ANE training (model-agnostic GQA support) +// Model selected at compile time via: make MODEL=qwen3_06b (or stories110m) // Compile kernels ONCE at startup, update weights via IOSurface every step. -// No exec() restart needed — eliminates 76% compile overhead. #include "mil_dynamic.h" #include "cpu_ops.h" -#define CKPT_PATH "ane_stories110M_dyn_ckpt.bin" -#define MODEL_PATH "../../../assets/models/stories110M.bin" -#define DEFAULT_DATA_PATH "../tinystories_data00.bin" - // Dynamic kernel set per layer typedef struct { - Kern *sdpaFwd; // QKV matmul + SDPA + Wo matmul (dynamic weights via IOSurface) - Kern *ffnFused; // residual + RMSNorm + W1,W3 + SiLU + W2 + residual (fused) - Kern *ffnBwdW2t; // dffn @ W2^T (dynamic) - Kern *ffnBwdW13t; // dh1@W1^T + dh3@W3^T (dynamic) - Kern *wotBwd; // dx2 @ Wo^T (dynamic) - Kern *sdpaBwd1; // Q,K,V,da → dV,probs,dp (weight-free, has mask const) - Kern *sdpaBwd2; // probs,dp,Q,K → dQ,dK (weight-free) - Kern *qkvBwd; // dq@Wq^T + dk@Wk^T + dv@Wv^T (dynamic) + Kern *sdpaFwd; // QKV matmul + RoPE + GQA tile + SDPA (no Wo) + Kern *woFwd; // attn_out @ Wo^T → o_out (Q_DIM → DIM) + Kern *ffnFused; // W1,W3 + SiLU + W2 + residual (fused) + Kern *ffnBwdW2t; // dffn @ W2^T → dsilu_raw (DIM → HIDDEN) + Kern *ffnBwdW13t; // dh1@W1^T + dh3@W3^T → dx_ffn (HIDDEN → DIM) + Kern *wotBwd; // dx2 @ Wo → da (DIM → Q_DIM) + Kern *sdpaBwd1; // Q,K,V,da → dV_full,probs,dp (weight-free, has mask) + Kern *sdpaBwd2; // probs,dp,Q,K → dQ,dK_full (weight-free) + Kern *qBwd; // dq @ Wq → dx_q (Q_DIM → DIM) + Kern *kvBwd; // dk@Wk + dv@Wv → dx_kv (KV_DIM → DIM) } DynLayerKernels; -// ===== Weight loading from llama2.c format ===== -static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) { - FILE *f = fopen(path, "rb"); - if (!f) { printf("Cannot open %s\n", path); return false; } - Llama2Config cfg; - fread(&cfg, sizeof(cfg), 1, f); - printf(" Model: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n", - cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len); - if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) { - printf(" ERROR: Config mismatch!\n"); fclose(f); return false; - } - int V = abs(cfg.vocab_size); - fread(embed, 4, V * DIM, f); - for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_att, 4, DIM, f); - for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wq, 4, WQ_SZ, f); - for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wk, 4, WQ_SZ, f); - for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wv, 4, WQ_SZ, f); - for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wo, 4, WO_SZ, f); - for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_ffn, 4, DIM, f); - for (int L = 0; L < NLAYERS; L++) fread(lw[L].W1, 4, W1_SZ, f); - for (int L = 0; L < NLAYERS; L++) fread(lw[L].W2, 4, W2_SZ, f); - for (int L = 0; L < NLAYERS; L++) fread(lw[L].W3, 4, W3_SZ, f); - fread(rms_final, 4, DIM, f); - fclose(f); - printf(" Loaded pretrained weights\n"); - return true; -} - // Transpose W[rows,cols] → W^T[cols,rows] stored as [cols channels, rows spatial] static void transpose_weight(float *dst, const float *src, int rows, int cols) { for (int r = 0; r < rows; r++) @@ -64,76 +34,72 @@ static bool compile_dynamic_kernels(DynLayerKernels *dk) { @"@model_path/weights/rope_sin.bin": @{@"offset":@0, @"data":get_rope_sin_blob()} }; - // SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp16 → [1, 6*DIM, 1, SEQ] fp16 - printf(" Compiling sdpaFwd...\n"); + int sdpa_out_ch = Q_DIM + Q_DIM + KV_DIM + KV_DIM + DIM; + + // SDPA forward (no Wo): [1, DIM, 1, SDPA_FWD_SP] → [1, sdpa_out_ch, 1, SEQ] + printf(" Compiling sdpaFwd (GQA)...\n"); dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), sdpa_fwd_w, - DIM*(SEQ+4*DIM)*2, 6*DIM*SEQ*2); + DIM*SDPA_FWD_SP*2, sdpa_out_ch*SEQ*2); if (!dk->sdpaFwd) return false; - // Fused FFN: W1,W3 + SiLU + W2 + residual (RMSNorm on CPU) + // Wo forward: [1, Q_DIM, 1, SEQ+DIM] → [1, DIM, 1, SEQ] + printf(" Compiling woFwd...\n"); + dk->woFwd = compile_kern_mil_w(gen_wo_fwd_dynamic(), @{}, + Q_DIM*WO_FWD_SP*2, DIM*SEQ*2); + if (!dk->woFwd) return false; + + // Fused FFN: [1, DIM, 1, FFN_FUSED_SP] → [1, DIM+3*HIDDEN, 1, SEQ] printf(" Compiling ffnFused...\n"); - int ffn_fused_sp = 2*SEQ + 3*HIDDEN; int ffn_fused_och = DIM + 3*HIDDEN; dk->ffnFused = compile_kern_mil_w(gen_ffn_fused_dynamic(), @{}, - DIM*ffn_fused_sp*2, ffn_fused_och*SEQ*2); + DIM*FFN_FUSED_SP*2, ffn_fused_och*SEQ*2); if (!dk->ffnFused) return false; - // FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] fp16 → [1, HIDDEN, 1, SEQ] fp16 + // FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] → [1, HIDDEN, 1, SEQ] printf(" Compiling ffnBwdW2t...\n"); dk->ffnBwdW2t = compile_kern_mil_w(gen_ffn_bwd_w2t_dynamic(), @{}, - DIM*(SEQ+HIDDEN)*2, HIDDEN*SEQ*2); + DIM*FFN_BWD_W2T_SP*2, HIDDEN*SEQ*2); if (!dk->ffnBwdW2t) return false; - // FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp16 → [1, DIM, 1, SEQ] fp16 + // FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] → [1, DIM, 1, SEQ] printf(" Compiling ffnBwdW13t...\n"); dk->ffnBwdW13t = compile_kern_mil_w(gen_ffn_bwd_w13t_dynamic(), @{}, - HIDDEN*(2*SEQ+2*DIM)*2, DIM*SEQ*2); + HIDDEN*FFN_BWD_W13T_SP*2, DIM*SEQ*2); if (!dk->ffnBwdW13t) return false; - // Wo^T backward: [1, DIM, 1, SEQ+DIM] fp16 → [1, DIM, 1, SEQ] fp16 + // Wo^T backward: [1, DIM, 1, SEQ+Q_DIM] → [1, Q_DIM, 1, SEQ] printf(" Compiling wotBwd...\n"); dk->wotBwd = compile_kern_mil_w(gen_wot_dynamic(), @{}, - DIM*(SEQ+DIM)*2, DIM*SEQ*2); + DIM*WOT_BWD_SP*2, Q_DIM*SEQ*2); if (!dk->wotBwd) return false; - // SDPA bwd1 (no dynamic weights, has mask): [1, 4*DIM, 1, SEQ] fp16 → [1, DIM+2*SCORE_CH, 1, SEQ] fp16 - printf(" Compiling sdpaBwd1...\n"); + // SDPA bwd1 (weight-free, has mask): [1, 4*Q_DIM, 1, SEQ] → [1, Q_DIM+2*SCORE_CH, 1, SEQ] + printf(" Compiling sdpaBwd1 (GQA)...\n"); dk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1_noweight(), mask_w, - 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2); + 4*Q_DIM*SEQ*2, (Q_DIM+2*SCORE_CH)*SEQ*2); if (!dk->sdpaBwd1) return false; - // SDPA bwd2 (no weights): [1, 2*SCORE_CH+2*DIM, 1, SEQ] fp16 → [1, 2*DIM, 1, SEQ] fp16 - printf(" Compiling sdpaBwd2...\n"); + // SDPA bwd2 (weight-free): [1, 2*SCORE_CH+2*Q_DIM, 1, SEQ] → [1, 2*Q_DIM, 1, SEQ] + printf(" Compiling sdpaBwd2 (GQA)...\n"); dk->sdpaBwd2 = compile_kern_mil_w(gen_sdpa_bwd2(), @{}, - (2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2); + (2*SCORE_CH+2*Q_DIM)*SEQ*2, 2*Q_DIM*SEQ*2); if (!dk->sdpaBwd2) return false; - // QKV backward: [1, DIM, 1, 3*SEQ+3*DIM] fp16 → [1, DIM, 1, SEQ] fp16 - printf(" Compiling qkvBwd...\n"); - dk->qkvBwd = compile_kern_mil_w(gen_qkvb_dynamic(), @{}, - DIM*(3*SEQ+3*DIM)*2, DIM*SEQ*2); - if (!dk->qkvBwd) return false; + // Q backward: [1, Q_DIM, 1, SEQ+DIM] → [1, DIM, 1, SEQ] + printf(" Compiling qBwd...\n"); + dk->qBwd = compile_kern_mil_w(gen_q_bwd_dynamic(), @{}, + Q_DIM*Q_BWD_SP*2, DIM*SEQ*2); + if (!dk->qBwd) return false; + + // KV backward: [1, KV_DIM, 1, 2*SEQ+2*DIM] → [1, DIM, 1, SEQ] + printf(" Compiling kvBwd...\n"); + dk->kvBwd = compile_kern_mil_w(gen_kv_bwd_dynamic(), @{}, + KV_DIM*KV_BWD_SP*2, DIM*SEQ*2); + if (!dk->kvBwd) return false; return true; } -// ===== Write dynamic weights into IOSurface ===== -// sdpaFwd: [1, DIM, 1, SEQ+4*DIM] — xnorm at sp[0:S], Wq/Wk/Wv/Wo at sp[S:] -static void write_sdpa_fwd_input(DynLayerKernels *dk, const float *xnorm, - const float *Wq, const float *Wk, const float *Wv, const float *Wo) { - IOSurfaceLock(dk->sdpaFwd->ioIn, 0, NULL); - float *buf = (float*)IOSurfaceGetBaseAddress(dk->sdpaFwd->ioIn); - int sp = SEQ + 4*DIM; - for (int d = 0; d < DIM; d++) { - memcpy(buf + d*sp, xnorm + d*SEQ, SEQ*4); - memcpy(buf + d*sp + SEQ, Wq + d*DIM, DIM*4); - memcpy(buf + d*sp + SEQ+DIM, Wk + d*DIM, DIM*4); - memcpy(buf + d*sp + SEQ+2*DIM, Wv + d*DIM, DIM*4); - memcpy(buf + d*sp + SEQ+3*DIM, Wo + d*DIM, DIM*4); - } - IOSurfaceUnlock(dk->sdpaFwd->ioIn, 0, NULL); -} - // ===== Checkpoint ===== static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss, double ct, double cw, int cs, int adam_t, @@ -141,21 +107,22 @@ static void save_checkpoint(const char *path, int step, int total_steps, float l float *embed, AdamState *aembed) { FILE *f = fopen(path, "wb"); CkptHdr h = {0}; - h.magic = 0x424C5A54; h.version = 3; + h.magic = 0x424C5A54; h.version = 4; h.step = step; h.total_steps = total_steps; h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM; h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ; h.lr = lr; h.loss = loss; h.cum_train = ct; h.cum_wall = cw; h.cum_steps = cs; h.adam_t = adam_t; + h.kv_heads = KV_HEADS; h.head_dim = HD; h.q_dim = Q_DIM; fwrite(&h, sizeof(h), 1, f); for (int L = 0; L < NLAYERS; L++) { - fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f); - fwrite(lw[L].Wv,4,WQ_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f); + fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WK_SZ,f); + fwrite(lw[L].Wv,4,WV_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f); fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f); fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f); fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f); - fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f); - fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f); + fwrite(la[L].Wk.m,4,WK_SZ,f); fwrite(la[L].Wk.v,4,WK_SZ,f); + fwrite(la[L].Wv.m,4,WV_SZ,f); fwrite(la[L].Wv.v,4,WV_SZ,f); fwrite(la[L].Wo.m,4,WO_SZ,f); fwrite(la[L].Wo.v,4,WO_SZ,f); fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f); fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f); @@ -178,17 +145,17 @@ static bool load_checkpoint(const char *path, int *step, int *total_steps, float if (!f) return false; CkptHdr h; fread(&h, sizeof(h), 1, f); - if (h.magic != 0x424C5A54 || h.version != 3) { fclose(f); return false; } + if (h.magic != 0x424C5A54 || h.version != 4) { fclose(f); return false; } *step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss; *ct = h.cum_train; *cw = h.cum_wall; *cs = h.cum_steps; *adam_t = h.adam_t; for (int L = 0; L < NLAYERS; L++) { - fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WQ_SZ,f); - fread(lw[L].Wv,4,WQ_SZ,f); fread(lw[L].Wo,4,WO_SZ,f); + fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WK_SZ,f); + fread(lw[L].Wv,4,WV_SZ,f); fread(lw[L].Wo,4,WO_SZ,f); fread(lw[L].W1,4,W1_SZ,f); fread(lw[L].W2,4,W2_SZ,f); fread(lw[L].W3,4,W3_SZ,f); fread(lw[L].rms_att,4,DIM,f); fread(lw[L].rms_ffn,4,DIM,f); fread(la[L].Wq.m,4,WQ_SZ,f); fread(la[L].Wq.v,4,WQ_SZ,f); - fread(la[L].Wk.m,4,WQ_SZ,f); fread(la[L].Wk.v,4,WQ_SZ,f); - fread(la[L].Wv.m,4,WQ_SZ,f); fread(la[L].Wv.v,4,WQ_SZ,f); + fread(la[L].Wk.m,4,WK_SZ,f); fread(la[L].Wk.v,4,WK_SZ,f); + fread(la[L].Wv.m,4,WV_SZ,f); fread(la[L].Wv.v,4,WV_SZ,f); fread(la[L].Wo.m,4,WO_SZ,f); fread(la[L].Wo.v,4,WO_SZ,f); fread(la[L].W1.m,4,W1_SZ,f); fread(la[L].W1.v,4,W1_SZ,f); fread(la[L].W2.m,4,W2_SZ,f); fread(la[L].W2.v,4,W2_SZ,f); @@ -217,9 +184,9 @@ int main(int argc, char *argv[]) { int accum_steps = 10; int warmup_steps = 100; float grad_clip = 1.0f; - float loss_scale = 256.0f; // fp16 loss scaling for ANE backward - float res_alpha = 1.0f / sqrtf(2.0f * NLAYERS); // residual scaling (DeepNet-style) - float min_lr_frac = 0.1f; // min_lr = max_lr * 0.1 + float loss_scale = 256.0f; + float res_alpha = 1.0f / sqrtf(2.0f * NLAYERS); + float min_lr_frac = 0.1f; bool do_resume = false, from_scratch = false; const char *data_path = DEFAULT_DATA_PATH; @@ -259,29 +226,28 @@ int main(int argc, char *argv[]) { if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss); } if (!resuming) { - printf("=== ANE Dynamic Training: Stories110M (12 layers) ===\n"); - printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS); - // Param counts for dashboard - double xformer_m = (double)NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ + 2.0*DIM) / 1e6; + printf("=== ANE Dynamic Training: %s (%d layers, GQA %d/%d heads) ===\n", + MODEL_NAME, NLAYERS, HEADS, KV_HEADS); + printf("dim=%d q_dim=%d kv_dim=%d hd=%d hidden=%d seq=%d vocab=%d\n", + DIM, Q_DIM, KV_DIM, HD, HIDDEN, SEQ, VOCAB); + double xformer_m = (double)NLAYERS*(WQ_SZ + WK_SZ + WV_SZ + (double)WO_SZ + W1_SZ + W2_SZ + W3_SZ + 2.0*DIM) / 1e6; double embed_m = (double)VOCAB*DIM / 1e6; printf("Params: %.1fM (transformer %.1fM + embed %.1fM)\n", xformer_m+embed_m, xformer_m, embed_m); - printf("Kernels: 8 compiled (ffnFused replaces ffnW13+ffnW2, RMSNorm on CPU)\n"); + printf("Kernels: 10 compiled (sdpaFwd+woFwd, ffnFused, ffnBwdW2t+W13t, wotBwd, sdpaBwd1+2, qBwd+kvBwd)\n"); printf("Accum %d steps, LR=%g\n", accum_steps, max_lr); - // FLOPs estimate: 6*N*B*T for transformer (forward+backward ≈ 3x forward) - double fwd_flops = 2.0*NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ) * SEQ; - double total_flops = 3.0 * fwd_flops; // fwd + bwd ≈ 3x fwd - printf("FLOPs/step: fwd=%.1fM bwd_dx=%.1fM bwd_dW=%.1fM sdpa_bwd=0.0M total=%.1fM\n", - fwd_flops/1e6, fwd_flops/1e6, fwd_flops/1e6, total_flops/1e6); - printf("ANE FLOPs/step: %.1fM\n", total_flops/1e6); - if (from_scratch || !load_pretrained(lw, rms_final, embed, MODEL_PATH)) { - if (from_scratch) printf(" Training from scratch (random init)\n"); - else printf(" Pretrained load failed, using random init\n"); + double fwd_flops = 2.0*NLAYERS*((double)WQ_SZ + WK_SZ + WV_SZ + WO_SZ + W1_SZ + W2_SZ + W3_SZ) * SEQ; + double total_flops = 3.0 * fwd_flops; + printf("FLOPs/step: fwd=%.1fM total=%.1fM\n", fwd_flops/1e6, total_flops/1e6); + if (from_scratch) { + printf(" Training from scratch (random init)\n"); srand48(42); - float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN); - float res_scale = 1.0f/sqrtf(2.0f*NLAYERS); // LLaMA-style output proj scaling + float scale_d=1.0f/sqrtf(DIM), scale_qd=1.0f/sqrtf(Q_DIM), scale_h=1.0f/sqrtf(HIDDEN); + float res_scale = 1.0f/sqrtf(2.0f*NLAYERS); for (int L=0; Llayer_in, x_cur, SEQ*DIM*4); @@ -441,7 +422,7 @@ int main(int argc, char *argv[]) { dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER); t_cblas_wait += tb_ms(mach_absolute_time() - t0); - // SDPA forward (ANE): xnorm + pre-staged Wq,Wk,Wv,Wo → o_out,Q,K,V,attn_out,xnorm + // SDPA forward (ANE): xnorm + Wq,Wk,Wv → attn_out[Q_DIM], Q_rope[Q_DIM], K_rope[KV_DIM], V[KV_DIM], xnorm[DIM] t0 = mach_absolute_time(); write_sdpa_fwd_acts(pls[L].sdpaFwd_in, xnorm_buf); t_io_fwd += tb_ms(mach_absolute_time() - t0); @@ -449,28 +430,37 @@ int main(int argc, char *argv[]) { ane_eval_req(dk.sdpaFwd, plr[L].sdpaFwd); t_ane_fwd += tb_ms(mach_absolute_time() - t0); - // Read output: [1, 6*DIM, 1, SEQ] fp16 + // Read SDPA output: [1, Q_DIM+Q_DIM+KV_DIM+KV_DIM+DIM, 1, SEQ] fp16 t0 = mach_absolute_time(); IOSurfaceLock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL); _Float16 *fwd_out = (_Float16*)IOSurfaceGetBaseAddress(dk.sdpaFwd->ioOut); - cvt_f16_f32(ac->o_out, fwd_out + 0*DIM*SEQ, DIM*SEQ); - cvt_f16_f32(ac->Q, fwd_out + 1*DIM*SEQ, DIM*SEQ); - cvt_f16_f32(ac->K, fwd_out + 2*DIM*SEQ, DIM*SEQ); - cvt_f16_f32(ac->V, fwd_out + 3*DIM*SEQ, DIM*SEQ); - cvt_f16_f32(ac->attn_out, fwd_out + 4*DIM*SEQ, DIM*SEQ); + int off = 0; + cvt_f16_f32(ac->attn_out, fwd_out + off, Q_DIM*SEQ); off += Q_DIM*SEQ; + cvt_f16_f32(ac->Q, fwd_out + off, Q_DIM*SEQ); off += Q_DIM*SEQ; + cvt_f16_f32(ac->K, fwd_out + off, KV_DIM*SEQ); off += KV_DIM*SEQ; + cvt_f16_f32(ac->V, fwd_out + off, KV_DIM*SEQ); off += KV_DIM*SEQ; + // xnorm passthrough (DIM*SEQ) — not needed, already saved IOSurfaceUnlock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL); t_io_fwd += tb_ms(mach_absolute_time() - t0); - // CPU: scaled residual + RMSNorm (ANE can't fuse RMS with 3 matmuls) + // Wo forward (ANE): attn_out[Q_DIM] → o_out[DIM] + t0 = mach_absolute_time(); + write_wo_fwd_acts(pls[L].woFwd_in, ac->attn_out); + t_io_fwd += tb_ms(mach_absolute_time() - t0); + t0 = mach_absolute_time(); + ane_eval_req(dk.woFwd, plr[L].woFwd); + t_ane_fwd += tb_ms(mach_absolute_time() - t0); + t0 = mach_absolute_time(); + io_read_dyn(dk.woFwd->ioOut, ac->o_out, DIM, SEQ); + t_io_fwd += tb_ms(mach_absolute_time() - t0); + + // CPU: scaled residual + RMSNorm t0 = mach_absolute_time(); - // x2 = x_cur + alpha * o_out (residual scaling keeps activations bounded) vDSP_vsma(ac->o_out, 1, &res_alpha, x_cur, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM)); rmsnorm(ac->x2norm, ac->x2, lw[L].rms_ffn, DIM, SEQ); t_rms += tb_ms(mach_absolute_time() - t0); - // Fused FFN (ANE): W1,W3 + SiLU + W2 + residual - // Input: x2norm + x2 (acts), W1t + W3t + W2 (pre-staged weights) - // Output: x_next, h1, h3, silu_out + // Fused FFN (ANE) t0 = mach_absolute_time(); write_ffn_fused_acts(pls[L].ffnFused_in, ac->x2norm, ac->x2); t_io_fwd += tb_ms(mach_absolute_time() - t0); @@ -478,21 +468,17 @@ int main(int argc, char *argv[]) { ane_eval_req(dk.ffnFused, plr[L].ffnFused); t_ane_fwd += tb_ms(mach_absolute_time() - t0); - // Read fused output: [1, DIM+3*HIDDEN, 1, SEQ] fp16 - // Layout: x_next[DIM], h1[HIDDEN], h3[HIDDEN], silu_out[HIDDEN] + // Read fused output: [1, DIM+3*HIDDEN, 1, SEQ] t0 = mach_absolute_time(); IOSurfaceLock(dk.ffnFused->ioOut, kIOSurfaceLockReadOnly, NULL); _Float16 *ffn_out = (_Float16*)IOSurfaceGetBaseAddress(dk.ffnFused->ioOut); - int off = 0; + off = 0; cvt_f16_f32(x_cur, ffn_out + off, DIM*SEQ); off += DIM*SEQ; cvt_f16_f32(ac->h1, ffn_out + off, HIDDEN*SEQ); off += HIDDEN*SEQ; cvt_f16_f32(ac->h3, ffn_out + off, HIDDEN*SEQ); off += HIDDEN*SEQ; cvt_f16_f32(ac->silu_out,ffn_out + off, HIDDEN*SEQ); IOSurfaceUnlock(dk.ffnFused->ioOut, kIOSurfaceLockReadOnly, NULL); t_io_fwd += tb_ms(mach_absolute_time() - t0); - - // (act_clip removed — was causing gradient explosion without backward, - // vanishing gradients with backward. RMSNorm keeps activations bounded.) } // Final RMSNorm + classifier + loss (CPU) @@ -500,7 +486,6 @@ int main(int argc, char *argv[]) { rmsnorm(x_final, x_cur, rms_final, DIM, SEQ); t_rms += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); - // Classifier: logits[CV, SEQ] = cembed[CV, DIM] @ x_final[DIM, SEQ] cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, CV, SEQ, DIM, 1.0f, cembed, DIM, x_final, SEQ, 0.0f, logits, SEQ); float loss = cross_entropy_loss(dlogits, logits, ctargets, CV, SEQ); @@ -508,17 +493,15 @@ int main(int argc, char *argv[]) { last_loss = loss; // ===== BACKWARD ===== - // Loss scaling: scale dlogits to prevent fp16 underflow in ANE backward kernels - // All gradients flow scaled; weight grads divided by loss_scale before Adam vDSP_vsmul(dlogits, 1, &loss_scale, dlogits, 1, (vDSP_Length)(SEQ*CV)); - // Classifier backward: dy[DIM, SEQ] = cembed^T[DIM, CV] @ dlogits[CV, SEQ] + // Classifier backward t0 = mach_absolute_time(); cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, DIM, SEQ, CV, 1.0f, cembed, DIM, dlogits, SEQ, 0.0f, dy, SEQ); t_cls += tb_ms(mach_absolute_time() - t0); - // dEmbed async: gcembed[CV, DIM] += dlogits[CV, SEQ] @ x_final^T[SEQ, DIM] + // dEmbed async dispatch_group_async(dw_grp, dw_q, ^{ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, CV, DIM, SEQ, 1.0f, dlogits, SEQ, x_final, SEQ, 1.0f, gcembed, DIM); @@ -530,15 +513,15 @@ int main(int argc, char *argv[]) { memcpy(dy, dx_rms_final, SEQ*DIM*4); free(dx_rms_final); - // ===== BACKWARD (12 layers, reverse) ===== + // ===== BACKWARD (28 layers, reverse) ===== for (int L=NLAYERS-1; L>=0; L--) { LayerActs *ac = &acts[L]; LayerGrads *gr = &grads[L]; - // dffn = alpha * dy (gradient into FFN branch scaled by residual alpha) + // dffn = alpha * dy vDSP_vsmul(dy, 1, &res_alpha, dffn, 1, (vDSP_Length)(SEQ*DIM)); - // FFN backward: dffn @ pre-staged W2^T → dsilu_raw + // FFN backward: dffn @ W2^T → dsilu_raw t0 = mach_absolute_time(); write_ffn_bwd_w2t_acts(pls[L].ffnBwdW2t_in, dffn); t_io_bwd += tb_ms(mach_absolute_time() - t0); @@ -549,34 +532,28 @@ int main(int argc, char *argv[]) { io_read_dyn(dk.ffnBwdW2t->ioOut, dsilu, HIDDEN, SEQ); t_io_bwd += tb_ms(mach_absolute_time() - t0); - // SiLU derivative (vectorized): dsilu → dh1, dh3 - // silu(h1) = h1*sig(h1), dsilu_dh1 = sig*(1+h1*(1-sig)) - // dh1 = dsilu * h3 * dsilu_dh1, dh3 = dsilu * silu(h1) + // SiLU derivative (vectorized) t0 = mach_absolute_time(); { int n = HIDDEN*SEQ; - // sig = 1/(1+exp(-h1)) float minus1 = -1.0f, one = 1.0f; vDSP_vsmul(ac->h1, 1, &minus1, silu_tmp, 1, (vDSP_Length)n); vvexpf(silu_tmp, silu_tmp, &n); vDSP_vsadd(silu_tmp, 1, &one, silu_tmp, 1, (vDSP_Length)n); - vvrecf(silu_tmp, silu_tmp, &n); // silu_tmp = sig - // dh3 = dsilu * h1 * sig (= dsilu * silu(h1)) + vvrecf(silu_tmp, silu_tmp, &n); // sig vDSP_vmul(ac->h1, 1, silu_tmp, 1, dh3, 1, (vDSP_Length)n); vDSP_vmul(dsilu, 1, dh3, 1, dh3, 1, (vDSP_Length)n); - // dsilu_dh1 = sig*(1+h1*(1-sig)), store in silu_tmp2 - vDSP_vsadd(silu_tmp, 1, &minus1, silu_tmp2, 1, (vDSP_Length)n); // sig-1 - vDSP_vneg(silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // 1-sig - vDSP_vmul(ac->h1, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // h1*(1-sig) - vDSP_vsadd(silu_tmp2, 1, &one, silu_tmp2, 1, (vDSP_Length)n); // 1+h1*(1-sig) - vDSP_vmul(silu_tmp, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // full dsilu_dh1 - // dh1 = dsilu * h3 * dsilu_dh1 + vDSP_vsadd(silu_tmp, 1, &minus1, silu_tmp2, 1, (vDSP_Length)n); + vDSP_vneg(silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); + vDSP_vmul(ac->h1, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); + vDSP_vsadd(silu_tmp2, 1, &one, silu_tmp2, 1, (vDSP_Length)n); + vDSP_vmul(silu_tmp, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); vDSP_vmul(dsilu, 1, ac->h3, 1, dh1, 1, (vDSP_Length)n); vDSP_vmul(dh1, 1, silu_tmp2, 1, dh1, 1, (vDSP_Length)n); } t_silu += tb_ms(mach_absolute_time() - t0); - // dh1@W1^T + dh3@W3^T → dx_ffn (ANE, pre-staged weights) + // dh1@W1^T + dh3@W3^T → dx_ffn (ANE) t0 = mach_absolute_time(); write_ffn_bwd_w13t_acts(pls[L].ffnBwdW13t_in, dh1, dh3); t_io_bwd += tb_ms(mach_absolute_time() - t0); @@ -587,7 +564,7 @@ int main(int argc, char *argv[]) { io_read_dyn(dk.ffnBwdW13t->ioOut, dx_ffn, DIM, SEQ); t_io_bwd += tb_ms(mach_absolute_time() - t0); - // dW FFN async (cblas) + // dW FFN async t0 = mach_absolute_time(); float *capt_dffn = (float*)malloc(SEQ*DIM*4); memcpy(capt_dffn, dffn, SEQ*DIM*4); float *capt_silu = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_silu, ac->silu_out, SEQ*HIDDEN*4); @@ -612,8 +589,7 @@ int main(int argc, char *argv[]) { for(int i=0;iioOut, da_buf, DIM, SEQ); + io_read_dyn(dk.wotBwd->ioOut, da_buf, Q_DIM, SEQ); t_io_bwd += tb_ms(mach_absolute_time() - t0); - // dWo async (uses alpha-scaled dx2) + // dWo async: gr->Wo[DIM,Q_DIM] += dx2_scaled[DIM,SEQ] @ attn_out^T[SEQ,Q_DIM] t0 = mach_absolute_time(); float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, dx2_scaled, SEQ*DIM*4); free(dx2_scaled); - float *capt_attn = (float*)malloc(SEQ*DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*DIM*4); + float *capt_attn = (float*)malloc(SEQ*Q_DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*Q_DIM*4); t_dw_copy += tb_ms(mach_absolute_time() - t0); dispatch_group_async(dw_grp, dw_q, ^{ - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ, - 1.0f, capt_do, SEQ, capt_attn, SEQ, 1.0f, gr->Wo, DIM); + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, Q_DIM, SEQ, + 1.0f, capt_do, SEQ, capt_attn, SEQ, 1.0f, gr->Wo, Q_DIM); free(capt_do); free(capt_attn); }); - if (L == 0 && step % 10 == 0) { - float damx, dx2mx, dx2mean; - vDSP_maxmgv(da_buf, 1, &damx, (vDSP_Length)(SEQ*DIM)); - vDSP_maxmgv(dx2, 1, &dx2mx, (vDSP_Length)(SEQ*DIM)); - vDSP_meamgv(dx2, 1, &dx2mean, (vDSP_Length)(SEQ*DIM)); - // Count how many dx2 values survive fp16 conversion - int nz = 0; - for (int i=0; iioIn, 0, ac->Q, DIM, SEQ); - io_write_fp16_at(dk.sdpaBwd1->ioIn, DIM, ac->K, DIM, SEQ); - io_write_fp16_at(dk.sdpaBwd1->ioIn, 2*DIM, ac->V, DIM, SEQ); - io_write_fp16_at(dk.sdpaBwd1->ioIn, 3*DIM, da_buf, DIM, SEQ); - free(da_buf); + gqa_tile_kv(k_tiled, ac->K, SEQ); + gqa_tile_kv(v_tiled, ac->V, SEQ); + t_io_bwd += tb_ms(mach_absolute_time() - t0); + + // SDPA backward part 1: Q[Q_DIM],K_tiled[Q_DIM],V_tiled[Q_DIM],da[Q_DIM] → dV_full[Q_DIM],probs,dp + t0 = mach_absolute_time(); + io_write_fp16_at(dk.sdpaBwd1->ioIn, 0, ac->Q, Q_DIM, SEQ); + io_write_fp16_at(dk.sdpaBwd1->ioIn, Q_DIM, k_tiled, Q_DIM, SEQ); + io_write_fp16_at(dk.sdpaBwd1->ioIn, 2*Q_DIM, v_tiled, Q_DIM, SEQ); + io_write_fp16_at(dk.sdpaBwd1->ioIn, 3*Q_DIM, da_buf, Q_DIM, SEQ); t_io_bwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); ane_eval(dk.sdpaBwd1); t_ane_bwd += tb_ms(mach_absolute_time() - t0); - // SDPA backward part 2: probs,dp,Q,K → dQ,dK + // SDPA backward part 2: probs,dp,Q[Q_DIM],K_tiled[Q_DIM] → dQ[Q_DIM],dK_full[Q_DIM] t0 = mach_absolute_time(); - io_copy(dk.sdpaBwd2->ioIn, 0, dk.sdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ); - io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH, ac->Q, DIM, SEQ); - io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH+DIM, ac->K, DIM, SEQ); + io_copy(dk.sdpaBwd2->ioIn, 0, dk.sdpaBwd1->ioOut, Q_DIM, 2*SCORE_CH, SEQ); + io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH, ac->Q, Q_DIM, SEQ); + io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH+Q_DIM, k_tiled, Q_DIM, SEQ); t_io_bwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); ane_eval(dk.sdpaBwd2); t_ane_bwd += tb_ms(mach_absolute_time() - t0); + // Read SDPA backward outputs t0 = mach_absolute_time(); - io_read_fp16(dk.sdpaBwd2->ioOut, dq, 0, DIM, SEQ); - io_read_fp16(dk.sdpaBwd2->ioOut, dk_buf, DIM, DIM, SEQ); - io_read_fp16(dk.sdpaBwd1->ioOut, dv, 0, DIM, SEQ); + io_read_fp16(dk.sdpaBwd2->ioOut, dq_full, 0, Q_DIM, SEQ); // dQ at full HEADS + io_read_fp16(dk.sdpaBwd2->ioOut, dk_full, Q_DIM, Q_DIM, SEQ); // dK at full HEADS + io_read_fp16(dk.sdpaBwd1->ioOut, dv_full, 0, Q_DIM, SEQ); // dV at full HEADS t_io_bwd += tb_ms(mach_absolute_time() - t0); - // RoPE backward: dq, dk are grads w.r.t. Q_rope, K_rope - // Inverse rotation to get grads w.r.t. pre-RoPE Q, K - rope_backward_inplace(dq, SEQ, DIM, HD); - rope_backward_inplace(dk_buf, SEQ, DIM, HD); + // GQA: reduce dK, dV from Q_DIM (HEADS) → KV_DIM (KV_HEADS) + gqa_reduce_kv(dk_buf, dk_full, SEQ); + gqa_reduce_kv(dv, dv_full, SEQ); + // dQ stays at Q_DIM — no reduction needed + memcpy(dq, dq_full, SEQ*Q_DIM*4); + + // RoPE backward on dQ[Q_DIM] and dK[KV_DIM] + rope_backward_inplace(dq, SEQ, Q_DIM, HD); + rope_backward_inplace(dk_buf, SEQ, KV_DIM, HD); - // Debug: check SDPA backward output magnitudes if (L == 0 && step % 10 == 0) { float dqmx, dkmx, dvmx; - vDSP_maxmgv(dq, 1, &dqmx, (vDSP_Length)(SEQ*DIM)); - vDSP_maxmgv(dk_buf, 1, &dkmx, (vDSP_Length)(SEQ*DIM)); - vDSP_maxmgv(dv, 1, &dvmx, (vDSP_Length)(SEQ*DIM)); + vDSP_maxmgv(dq, 1, &dqmx, (vDSP_Length)(SEQ*Q_DIM)); + vDSP_maxmgv(dk_buf, 1, &dkmx, (vDSP_Length)(SEQ*KV_DIM)); + vDSP_maxmgv(dv, 1, &dvmx, (vDSP_Length)(SEQ*KV_DIM)); printf(" L0 sdpa_bwd: |dq|=%.6f |dk|=%.6f |dv|=%.6f\n", dqmx, dkmx, dvmx); } // dWq/dWk/dWv async + // dWq[Q_DIM,DIM] += dq[Q_DIM,SEQ] @ xnorm^T[SEQ,DIM] + // dWk[KV_DIM,DIM] += dk[KV_DIM,SEQ] @ xnorm^T[SEQ,DIM] + // dWv[KV_DIM,DIM] += dv[KV_DIM,SEQ] @ xnorm^T[SEQ,DIM] t0 = mach_absolute_time(); - float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4); - float *capt_dk = (float*)malloc(SEQ*DIM*4); memcpy(capt_dk, dk_buf, SEQ*DIM*4); - float *capt_dv = (float*)malloc(SEQ*DIM*4); memcpy(capt_dv, dv, SEQ*DIM*4); + float *capt_dq = (float*)malloc(SEQ*Q_DIM*4); memcpy(capt_dq, dq, SEQ*Q_DIM*4); + float *capt_dk = (float*)malloc(SEQ*KV_DIM*4); memcpy(capt_dk, dk_buf, SEQ*KV_DIM*4); + float *capt_dv = (float*)malloc(SEQ*KV_DIM*4); memcpy(capt_dv, dv, SEQ*KV_DIM*4); float *capt_xn = (float*)malloc(SEQ*DIM*4); memcpy(capt_xn, ac->xnorm, SEQ*DIM*4); t_dw_copy += tb_ms(mach_absolute_time() - t0); dispatch_group_async(dw_grp, dw_q, ^{ - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ, + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Q_DIM, DIM, SEQ, 1.0f, capt_dq, SEQ, capt_xn, SEQ, 1.0f, gr->Wq, DIM); - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ, + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, KV_DIM, DIM, SEQ, 1.0f, capt_dk, SEQ, capt_xn, SEQ, 1.0f, gr->Wk, DIM); - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ, + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, KV_DIM, DIM, SEQ, 1.0f, capt_dv, SEQ, capt_xn, SEQ, 1.0f, gr->Wv, DIM); free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn); }); - // QKV backward (ANE): dq,dk,dv @ pre-staged Wq^T,Wk^T,Wv^T → dx_attn + // Q backward (ANE): dq[Q_DIM] @ Wq → dx_q[DIM] t0 = mach_absolute_time(); - write_qkv_bwd_acts(pls[L].qkvBwd_in, dq, dk_buf, dv); + write_q_bwd_acts(pls[L].qBwd_in, dq); t_io_bwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); - ane_eval_req(dk.qkvBwd, plr[L].qkvBwd); + ane_eval_req(dk.qBwd, plr[L].qBwd); t_ane_bwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); - io_read_dyn(dk.qkvBwd->ioOut, dx_attn, DIM, SEQ); + io_read_dyn(dk.qBwd->ioOut, dx_attn, DIM, SEQ); t_io_bwd += tb_ms(mach_absolute_time() - t0); + // KV backward (ANE): dk[KV_DIM]@Wk + dv[KV_DIM]@Wv → dx_kv[DIM] + float *dx_kv = (float*)malloc(SEQ*DIM*4); + t0 = mach_absolute_time(); + write_kv_bwd_acts(pls[L].kvBwd_in, dk_buf, dv); + t_io_bwd += tb_ms(mach_absolute_time() - t0); + t0 = mach_absolute_time(); + ane_eval_req(dk.kvBwd, plr[L].kvBwd); + t_ane_bwd += tb_ms(mach_absolute_time() - t0); + t0 = mach_absolute_time(); + io_read_dyn(dk.kvBwd->ioOut, dx_kv, DIM, SEQ); + t_io_bwd += tb_ms(mach_absolute_time() - t0); + + // dx_attn = dx_q + dx_kv + for(int i=0; iWq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;} + for(size_t i=0;iWq[i]*=gsc; + for(size_t i=0;iWk[i]*=gsc; + for(size_t i=0;iWv[i]*=gsc; + for(size_t i=0;iWo[i]*=gsc; for(size_t i=0;iW1[i]*=gsc; for(size_t i=0;iW2[i]*=gsc; for(size_t i=0;iW3[i]*=gsc; for(int i=0;irms_att[i]*=gsc; g->rms_ffn[i]*=gsc;} } for(int i=0;iWq,1,g->Wq,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s; - vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s; - vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s; + vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WK_SZ); grad_norm_sq+=s; + vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WV_SZ); grad_norm_sq+=s; vDSP_dotpr(g->Wo,1,g->Wo,1,&s,(vDSP_Length)WO_SZ); grad_norm_sq+=s; vDSP_dotpr(g->W1,1,g->W1,1,&s,(vDSP_Length)W1_SZ); grad_norm_sq+=s; vDSP_dotpr(g->W2,1,g->W2,1,&s,(vDSP_Length)W2_SZ); grad_norm_sq+=s; @@ -793,13 +786,12 @@ int main(int argc, char *argv[]) { } float grad_norm = sqrtf(grad_norm_sq); if ((step+1) % 10 == 0) { - // Per-component gradient norms for diagnostics float attn_sq=0, ffn_sq=0, embed_sq=0; for (int L=0; LWq,1,g->Wq,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s; - vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s; - vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s; + vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WK_SZ); attn_sq+=s; + vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WV_SZ); attn_sq+=s; vDSP_dotpr(g->Wo,1,g->Wo,1,&s,(vDSP_Length)WO_SZ); attn_sq+=s; vDSP_dotpr(g->W1,1,g->W1,1,&s,(vDSP_Length)W1_SZ); ffn_sq+=s; vDSP_dotpr(g->W2,1,g->W2,1,&s,(vDSP_Length)W2_SZ); ffn_sq+=s; @@ -818,8 +810,8 @@ int main(int argc, char *argv[]) { for (int L=0; LWq,1,&clip_scale,g->Wq,1,(vDSP_Length)WQ_SZ); - vDSP_vsmul(g->Wk,1,&clip_scale,g->Wk,1,(vDSP_Length)WQ_SZ); - vDSP_vsmul(g->Wv,1,&clip_scale,g->Wv,1,(vDSP_Length)WQ_SZ); + vDSP_vsmul(g->Wk,1,&clip_scale,g->Wk,1,(vDSP_Length)WK_SZ); + vDSP_vsmul(g->Wv,1,&clip_scale,g->Wv,1,(vDSP_Length)WV_SZ); vDSP_vsmul(g->Wo,1,&clip_scale,g->Wo,1,(vDSP_Length)WO_SZ); vDSP_vsmul(g->W1,1,&clip_scale,g->W1,1,(vDSP_Length)W1_SZ); vDSP_vsmul(g->W2,1,&clip_scale,g->W2,1,(vDSP_Length)W2_SZ); @@ -854,25 +846,26 @@ int main(int argc, char *argv[]) { adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f); // Update transposed weight buffers - transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM); - transpose_weight(Wkt_buf[L], lw[L].Wk, DIM, DIM); - transpose_weight(Wvt_buf[L], lw[L].Wv, DIM, DIM); - transpose_weight(Wot_buf[L], lw[L].Wo, DIM, DIM); + transpose_weight(Wqt_buf[L], lw[L].Wq, Q_DIM, DIM); + transpose_weight(Wkt_buf[L], lw[L].Wk, KV_DIM, DIM); + transpose_weight(Wvt_buf[L], lw[L].Wv, KV_DIM, DIM); + transpose_weight(Wot_buf[L], lw[L].Wo, DIM, Q_DIM); transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM); transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN); transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM); - // Re-stage weights into per-layer IOSurfaces - stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]); + // Re-stage weights + stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L]); + stage_wo_fwd_weights(pls[L].woFwd_in, Wot_buf[L]); stage_ffn_fused_weights(pls[L].ffnFused_in, W1t_buf[L], W3t_buf[L], lw[L].W2); stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, lw[L].W2); stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, lw[L].W1, lw[L].W3); stage_wot_bwd_weights(pls[L].wotBwd_in, lw[L].Wo); - stage_qkv_bwd_weights(pls[L].qkvBwd_in, lw[L].Wq, lw[L].Wk, lw[L].Wv); + stage_q_bwd_weights(pls[L].qBwd_in, lw[L].Wq); + stage_kv_bwd_weights(pls[L].kvBwd_in, lw[L].Wk, lw[L].Wv); } adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f); adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps, wd); - // Re-extract compact embed from updated full embed free(cembed); cembed = vocab_compact_embed(embed, &vm, DIM); @@ -908,9 +901,13 @@ int main(int argc, char *argv[]) { free(W1t_buf[L]); free(W2t_buf[L]); free(W3t_buf[L]); } free_per_layer(pls, plr); - free_kern(dk.sdpaFwd); free_kern(dk.ffnFused); + free_kern(dk.sdpaFwd); free_kern(dk.woFwd); free_kern(dk.ffnFused); free_kern(dk.ffnBwdW2t); free_kern(dk.ffnBwdW13t); free_kern(dk.wotBwd); - free_kern(dk.sdpaBwd1); free_kern(dk.sdpaBwd2); free_kern(dk.qkvBwd); + free_kern(dk.sdpaBwd1); free_kern(dk.sdpaBwd2); + free_kern(dk.qBwd); free_kern(dk.kvBwd); + free(da_buf); free(k_tiled); free(v_tiled); + free(dq_full); free(dk_full); free(dv_full); + free(dq); free(dk_buf); free(dv); munmap(token_data, data_len); close(data_fd); } return 0;