Add Qwen3-0.6B GQA support and multi-model build system

Implement Grouped-Query Attention (16q/8kv heads, head_dim=128) for
Qwen3-0.6B (28 layers, 596M params). Model configs moved to
models/*.h headers selected at build time via make MODEL=xxx.

Key changes:
- GQA-aware MIL kernels: sdpaFwd split from woFwd (Q_DIM!=DIM),
  qBwd/kvBwd split from qkvBwd (different IC dimensions)
- K/V tile (KV_HEADS→HEADS) before SDPA backward, reduce after
- 10 kernels total, all model-agnostic via compile-time defines
- Makefile: make MODEL=qwen3_06b (default) or MODEL=stories110m
- Both models verified: Stories110M ~115ms/step, Qwen3 ~412ms/step
This commit is contained in:
maderix 2026-03-06 06:23:15 -08:00
parent c3c5094865
commit 475348ad14
8 changed files with 761 additions and 855 deletions

View File

@ -1,14 +1,22 @@
# ANE Training — Stories110M on Apple Neural Engine
# ANE Training — On-Device Training on Apple Neural Engine
Training a 109M-parameter Llama2-architecture transformer (Stories110M) directly on Apple's Neural Engine using private ANE APIs.
Training transformer models directly on Apple's Neural Engine using private ANE APIs. Supports multiple architectures including GQA (Grouped-Query Attention).
![Dashboard](dashboard.gif)
## Supported Models
| Model | Layers | Heads (Q/KV) | Dim | Hidden | Params | ms/step |
|-------|--------|--------------|-----|--------|--------|---------|
| Stories110M | 12 | 12/12 (MHA) | 768 | 2048 | 109M | ~115 |
| Qwen3-0.6B | 28 | 16/8 (GQA) | 1024 | 3072 | 596M | ~412 |
Model configs live in `training_dynamic/models/*.h`. To add a new model, create a header with the architecture defines (see below).
## Architecture
- **Model**: Stories110M — dim=768, hidden=2048, heads=12, layers=12, vocab=32000, seq=256
- **109.53M params** (84.95M transformer + 24.58M embedding)
- **SDPA causal mask workaround**: ANE hardware ignores attn_mask — decompose into Q@K^T (ANE conv) + mask+softmax (CPU) + scores@V (ANE conv)
- **GQA support**: K/V heads tiled to match Q heads for SDPA, reduced back after backward pass
## Three Training Pipelines
@ -27,10 +35,10 @@ Offloads classifier forward (32K conv), softmax, final RMSNorm, and RMSNorm back
- Use `--no-ane-extras` to disable and fall back to CPU (for debugging)
### 3. Dynamic Weight Pipeline (`training_dynamic/`)
Weights passed via IOSurface spatial dimension — compile 9 kernels once at startup, no recompilation needed.
Weights passed via IOSurface spatial dimension — compile 10 kernels once at startup, no recompilation needed. Supports multiple models via `make MODEL=xxx`.
- 9 shared kernels across all 12 layers
- **111 ms/step**, 0.4s one-time compile
- 10 shared kernels across all layers (GQA-aware: split sdpaFwd/woFwd, split qBwd/kvBwd)
- **~115 ms/step** (Stories110M) / **~412 ms/step** (Qwen3-0.6B), 0.4s one-time compile
- No exec() restart, no compile limit issues
## Performance Comparison (20 Steps)
@ -56,10 +64,11 @@ Weights passed via IOSurface spatial dimension — compile 9 kernels once at sta
|------|-------------|
| `train_large.m` | Static baseline — 72 kernels, classifier/softmax on CPU |
| `train_large_ane.m` | PR#19 — 86 kernels, classifier/softmax/rmsnorm_bwd on ANE |
| `training_dynamic/train.m` | Dynamic pipeline — 9 kernels, weights via IOSurface |
| `training_dynamic/mil_dynamic.h` | MIL generators for dynamic weight kernels |
| `training_dynamic/config.h` | Model config (DIM=768, HIDDEN=2048, etc.) |
| `training_dynamic/io.h` | IOSurface I/O + MIL compilation helpers |
| `training_dynamic/train.m` | Dynamic pipeline — 10 kernels, weights via IOSurface |
| `training_dynamic/mil_dynamic.h` | MIL generators for dynamic weight kernels (GQA-aware) |
| `training_dynamic/config.h` | Derived sizes, structs, alloc helpers (model-agnostic) |
| `training_dynamic/models/*.h` | Per-model configs (stories110m.h, qwen3_06b.h) |
| `training_dynamic/io.h` | IOSurface I/O, weight staging, GQA tile/reduce |
| `training_dynamic/cpu_ops.h` | CPU ops (SiLU backward, cross-entropy, Adam) |
| `stories_config.h` | Static pipeline config, structs, alloc helpers |
| `stories_io.h` | IOSurface I/O, NEON fp16 conversion, kernel compile/eval |
@ -83,33 +92,35 @@ Downloads pretokenized TinyStories (Llama 2 BPE, 32K vocab) from HuggingFace. Pr
### 2. Build & Train
```bash
# Static baseline (classifier + softmax on CPU)
make train_large
./train_large stories110M.bin 256 100 1e-4
./train_large --model stories110M.bin --steps 100 --lr 1e-4
./train_large --data ./tinystories_data00.bin --steps 100 --lr 1e-4
# PR#19: ANE-offloaded classifier + softmax + rmsnorm_bwd
make train_large_ane
./train_large_ane stories110M.bin 256 100 1e-4
./train_large_ane --no-ane-extras --steps 100 # disable ANE extras
./train_large_ane --data ./tinystories_data00.bin --steps 100 --lr 1e-4
# Static baseline (classifier + softmax on CPU)
make train_large
./train_large stories110M.bin 256 100 1e-4
./train_large --model stories110M.bin --steps 100 --lr 1e-4
./train_large --data ./tinystories_data00.bin --steps 100 --lr 1e-4
# Dynamic pipeline (no recompilation)
cd training_dynamic && make train
# PR#19: ANE-offloaded classifier + softmax + rmsnorm_bwd
make train_large_ane
./train_large_ane stories110M.bin 256 100 1e-4
./train_large_ane --no-ane-extras --steps 100 # disable ANE extras
./train_large_ane --data ./tinystories_data00.bin --steps 100 --lr 1e-4
# Dynamic pipeline (model selected at build time)
cd training_dynamic
make MODEL=qwen3_06b # default — Qwen3-0.6B (28L, GQA, 596M)
make MODEL=stories110m # Stories110M (12L, MHA, 109M)
./train --scratch # train from random init
./train # resume from checkpoint
./train --resume # resume from checkpoint
./train --steps 200 --lr 1e-4 # custom steps/lr
```
**CLI flags (`train_large` / `train_large_ane`):**
**CLI flags (`train_large` / `train_large_ane`):**
- `--steps N` (default 10000)
- `--lr F` (default 3e-4)
- `--model PATH` — pretrained weights file
- `--data PATH` — tokenized TinyStories `.bin` file (default: `tinystories_data00.bin`)
- `--ckpt PATH` — checkpoint file (preserved across exec() restarts)
- `--resume` — resume from checkpoint
- `--no-ane-extras` — (train_large_ane only) disable ANE classifier/softmax/rmsnorm_bwd
- `--lr F` (default 3e-4)
- `--model PATH` — pretrained weights file
- `--data PATH` — tokenized TinyStories `.bin` file (default: `tinystories_data00.bin`)
- `--ckpt PATH` — checkpoint file (preserved across exec() restarts)
- `--resume` — resume from checkpoint
- `--no-ane-extras` — (train_large_ane only) disable ANE classifier/softmax/rmsnorm_bwd
### 3. Monitor with Dashboard
@ -133,11 +144,42 @@ Avg train: 91.8 ms/step
ANE TFLOPS: 1.15 sustained
```
## Adding a New Model
Create `training_dynamic/models/mymodel.h`:
```c
#pragma once
#define MODEL_NAME "MyModel-1B"
#define DIM 2048 // model hidden dim
#define HIDDEN 5504 // FFN intermediate dim
#define HEADS 32 // number of query heads
#define KV_HEADS 8 // number of KV heads (= HEADS for MHA)
#define HD 64 // head dim (can differ from DIM/HEADS)
#define SEQ 256 // sequence length
#define NLAYERS 22 // number of transformer layers
#define VOCAB 32000 // vocabulary size
#define CKPT_PATH "ane_mymodel_dyn_ckpt.bin"
#define DEFAULT_DATA_PATH "../tinystories_data00.bin"
```
Everything else is derived automatically: `GQA_RATIO`, `Q_DIM`, `KV_DIM`, weight sizes, IOSurface layouts, MIL kernels.
Build with: `make MODEL=mymodel`
**Constraints:**
- `HEADS` must be divisible by `KV_HEADS`
- `HD` is explicit (not necessarily `DIM/HEADS` — Qwen3 uses HD=128 with DIM/HEADS=64)
- For MHA (no GQA), set `KV_HEADS = HEADS`
## Key Techniques
- **NEON vectorized fp16↔fp32**: ARM NEON intrinsics for fast IOSurface data transfer
- **vDSP cross-entropy**: `vDSP_mtrans` + `vvexpf` + `vDSP_sve` — 8x faster than scalar
- **Async weight gradients**: cblas_sgemm dispatched to background queue, overlapped with ANE
- **Vocab compaction** (dynamic): 32K → 9.2K active tokens, 3.5x reduction in classifier work
- **Dynamic weight packing**: Activations + weights concatenated in IOSurface spatial dimension — one kernel serves all 12 layers
- **Vocab compaction** (dynamic): 32K152K → 9.2K active tokens, up to 16.5x reduction in classifier work
- **Dynamic weight packing**: Activations + weights concatenated in IOSurface spatial dimension — one kernel serves all layers
- **GQA tile/reduce**: K/V tiled from KV_HEADS→HEADS on CPU before SDPA backward, gradients reduced HEADS→KV_HEADS after
- **exec() restart**: Workaround for ANE ~119 compile limit per process

View File

@ -2,8 +2,16 @@ CC = xcrun clang
CFLAGS = -O2 -DACCELERATE_NEW_LAPACK -framework Foundation -framework IOSurface -framework Accelerate \
-isysroot $(shell xcrun --show-sdk-path) -fobjc-arc
train: train.m config.h io.h cpu_ops.h mil_dynamic.h
$(CC) $(CFLAGS) -o train train.m
# Model selection: make MODEL=qwen3_06b (default)
# Available: stories110m, qwen3_06b
MODEL ?= qwen3_06b
MODEL_HDR = models/$(MODEL).h
train: train.m config.h io.h cpu_ops.h mil_dynamic.h $(MODEL_HDR)
@echo "Building for model: $(MODEL)"
$(CC) $(CFLAGS) -include $(MODEL_HDR) -o train train.m
clean:
rm -f train
.PHONY: clean

View File

@ -1,4 +1,5 @@
// config.h — Stories110M model config, structs, ANE init
// config.h — Model-agnostic structs, derived sizes, ANE init
// Model-specific dims come from models/*.h, selected via -DMODEL_HEADER
#pragma once
#import <Foundation/Foundation.h>
#import <objc/runtime.h>
@ -15,22 +16,21 @@
#include <fcntl.h>
#include <arm_neon.h>
// Stories110M config
#define DIM 768
#define HIDDEN 2048
#define HEADS 12
#define HD (DIM/HEADS)
#define SEQ 256
#define NLAYERS 12
#define VOCAB 32000
// Include selected model config
// MODEL_HEADER is set by Makefile via -include models/xxx.h
#ifndef MODEL_NAME
#error "No model selected. Build with: make MODEL=qwen3_06b (or stories110m)"
#endif
// Weight sizes per layer
#define WQ_SZ (DIM*DIM)
#define WO_SZ (DIM*DIM)
// Derived weight sizes per layer (GQA-aware)
#define WQ_SZ (Q_DIM*DIM)
#define WK_SZ (KV_DIM*DIM)
#define WV_SZ (KV_DIM*DIM)
#define WO_SZ (DIM*Q_DIM)
#define W1_SZ (HIDDEN*DIM)
#define W2_SZ (DIM*HIDDEN)
#define W3_SZ (HIDDEN*DIM)
#define LAYER_PARAMS (4*WQ_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM)
#define LAYER_PARAMS (WQ_SZ + WK_SZ + WV_SZ + WO_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM)
// Attention score channels for SDPA backward
#define SCORE_CH (HEADS*SEQ)
@ -64,14 +64,14 @@ typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmp
// Per-layer IOSurfaces for pre-staged weights
typedef struct {
IOSurfaceRef sdpaFwd_in, ffnFused_in;
IOSurfaceRef ffnBwdW2t_in, ffnBwdW13t_in, wotBwd_in, qkvBwd_in;
IOSurfaceRef sdpaFwd_in, woFwd_in, ffnFused_in;
IOSurfaceRef ffnBwdW2t_in, ffnBwdW13t_in, wotBwd_in, qBwd_in, kvBwd_in;
} PerLayerSurfaces;
// Per-layer ANE requests (bound to per-layer IOSurfaces)
typedef struct {
void *sdpaFwd, *ffnFused;
void *ffnBwdW2t, *ffnBwdW13t, *wotBwd, *qkvBwd;
void *sdpaFwd, *woFwd, *ffnFused;
void *ffnBwdW2t, *ffnBwdW13t, *wotBwd, *qBwd, *kvBwd;
} PerLayerRequests;
// Checkpoint header
@ -81,14 +81,10 @@ typedef struct {
float lr, loss;
double cum_compile, cum_train, cum_wall;
int cum_steps, cum_batches, adam_t;
int pad[3];
int kv_heads, head_dim, q_dim; // GQA fields
// Note: was int pad[3] in v3, now stores GQA info in v4+
} CkptHdr;
// llama2.c model file header
typedef struct {
int dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len;
} Llama2Config;
// Globals
static Class g_D, g_I, g_AR, g_AIO;
static mach_timebase_info_data_t g_tb;
@ -109,8 +105,8 @@ static void adam_free(AdamState *s) { free(s->m); free(s->v); }
static LayerWeights layer_weights_alloc(void) {
LayerWeights w;
w.Wq=(float*)malloc(WQ_SZ*4); w.Wk=(float*)malloc(WQ_SZ*4);
w.Wv=(float*)malloc(WQ_SZ*4); w.Wo=(float*)malloc(WO_SZ*4);
w.Wq=(float*)malloc(WQ_SZ*4); w.Wk=(float*)malloc(WK_SZ*4);
w.Wv=(float*)malloc(WV_SZ*4); w.Wo=(float*)malloc(WO_SZ*4);
w.W1=(float*)malloc(W1_SZ*4); w.W2=(float*)malloc(W2_SZ*4); w.W3=(float*)malloc(W3_SZ*4);
w.rms_att=(float*)malloc(DIM*4); w.rms_ffn=(float*)malloc(DIM*4);
return w;
@ -121,7 +117,7 @@ static void layer_weights_free(LayerWeights *w) {
}
static LayerAdam layer_adam_alloc(void) {
LayerAdam a;
a.Wq=adam_alloc(WQ_SZ); a.Wk=adam_alloc(WQ_SZ); a.Wv=adam_alloc(WQ_SZ); a.Wo=adam_alloc(WO_SZ);
a.Wq=adam_alloc(WQ_SZ); a.Wk=adam_alloc(WK_SZ); a.Wv=adam_alloc(WV_SZ); a.Wo=adam_alloc(WO_SZ);
a.W1=adam_alloc(W1_SZ); a.W2=adam_alloc(W2_SZ); a.W3=adam_alloc(W3_SZ);
a.rms_att=adam_alloc(DIM); a.rms_ffn=adam_alloc(DIM);
return a;
@ -135,8 +131,8 @@ static LayerActs layer_acts_alloc(void) {
LayerActs a;
a.layer_in=(float*)malloc(SEQ*DIM*4);
a.xnorm=(float*)malloc(SEQ*DIM*4);
a.Q=(float*)malloc(SEQ*DIM*4); a.K=(float*)malloc(SEQ*DIM*4); a.V=(float*)malloc(SEQ*DIM*4);
a.attn_out=(float*)malloc(SEQ*DIM*4); a.o_out=(float*)malloc(SEQ*DIM*4);
a.Q=(float*)malloc(SEQ*Q_DIM*4); a.K=(float*)malloc(SEQ*KV_DIM*4); a.V=(float*)malloc(SEQ*KV_DIM*4);
a.attn_out=(float*)malloc(SEQ*Q_DIM*4); a.o_out=(float*)malloc(SEQ*DIM*4);
a.x2=(float*)malloc(SEQ*DIM*4); a.x2norm=(float*)malloc(SEQ*DIM*4);
a.h1=(float*)malloc(SEQ*HIDDEN*4); a.h3=(float*)malloc(SEQ*HIDDEN*4);
a.silu_out=(float*)malloc(SEQ*HIDDEN*4); a.ffn_out=(float*)malloc(SEQ*DIM*4);
@ -150,15 +146,15 @@ static void layer_acts_free(LayerActs *a) {
}
static LayerGrads layer_grads_alloc(void) {
LayerGrads g;
g.Wq=(float*)calloc(WQ_SZ,4); g.Wk=(float*)calloc(WQ_SZ,4);
g.Wv=(float*)calloc(WQ_SZ,4); g.Wo=(float*)calloc(WO_SZ,4);
g.Wq=(float*)calloc(WQ_SZ,4); g.Wk=(float*)calloc(WK_SZ,4);
g.Wv=(float*)calloc(WV_SZ,4); g.Wo=(float*)calloc(WO_SZ,4);
g.W1=(float*)calloc(W1_SZ,4); g.W2=(float*)calloc(W2_SZ,4); g.W3=(float*)calloc(W3_SZ,4);
g.rms_att=(float*)calloc(DIM,4); g.rms_ffn=(float*)calloc(DIM,4);
return g;
}
static void layer_grads_zero(LayerGrads *g) {
memset(g->Wq,0,WQ_SZ*4);memset(g->Wk,0,WQ_SZ*4);
memset(g->Wv,0,WQ_SZ*4);memset(g->Wo,0,WO_SZ*4);
memset(g->Wq,0,WQ_SZ*4);memset(g->Wk,0,WK_SZ*4);
memset(g->Wv,0,WV_SZ*4);memset(g->Wo,0,WO_SZ*4);
memset(g->W1,0,W1_SZ*4);memset(g->W2,0,W2_SZ*4);memset(g->W3,0,W3_SZ*4);
memset(g->rms_att,0,DIM*4);memset(g->rms_ffn,0,DIM*4);
}

View File

@ -1,4 +1,5 @@
// io.h — IOSurface helpers, NEON conversion, kernel compile/eval
// Updated for GQA (Qwen3-0.6B): Q_DIM != DIM, separate KV heads
#pragma once
#include "config.h"
@ -75,8 +76,6 @@ static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int
}
// fp16 IOSurface I/O (for dynamic matmul kernels with fp16 input/output)
// Layout: [1, IC, 1, SP] where SP = SEQ + OC
// Write activations at sp[0:SEQ] and weights at sp[SEQ:SEQ+OC]
static void io_write_dyn(IOSurfaceRef s, const float *act, int ic, int seq,
const float *W, int oc) {
int sp = seq + oc;
@ -145,14 +144,10 @@ static void ane_eval(Kern *k) {
id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; NSError *e = nil;
((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
}
// Evaluate with a per-layer request (different ioIn, same model)
static void ane_eval_req(Kern *k, void *request) {
id mdl = (__bridge id)k->model; id req = (__bridge id)request; NSError *e = nil;
((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
}
// Create an ANE request binding a custom ioIn to a kernel's model+ioOut
static void *make_request(Kern *k, IOSurfaceRef ioIn) {
id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), ioIn);
id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
@ -162,172 +157,157 @@ static void *make_request(Kern *k, IOSurfaceRef ioIn) {
return (void*)CFBridgingRetain(req);
}
// ===== Per-layer weight staging (write once, reuse across steps) =====
// All surfaces are now fp16 — staging converts fp32 weights to fp16
// sdpaFwd: [1, DIM, 1, SEQ+4*DIM] fp16 — weights at sp[SEQ:]
static void stage_sdpa_fwd_weights(IOSurfaceRef s, const float *Wq, const float *Wk,
const float *Wv, const float *Wo) {
// ===== Per-layer weight staging for GQA =====
// sdpaFwd: [1, DIM, 1, SEQ + Q_DIM + KV_DIM + KV_DIM] fp16 — no Wo (separate kernel)
// Wq: [DIM, Q_DIM], Wk: [DIM, KV_DIM], Wv: [DIM, KV_DIM]
#define SDPA_FWD_SP (SEQ + Q_DIM + KV_DIM + KV_DIM)
static void stage_sdpa_fwd_weights(IOSurfaceRef s, const float *Wq, const float *Wk, const float *Wv) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = SEQ + 4*DIM;
for (int d = 0; d < DIM; d++) {
cvt_f32_f16(buf + d*sp + SEQ, Wq + d*DIM, DIM);
cvt_f32_f16(buf + d*sp + SEQ+DIM, Wk + d*DIM, DIM);
cvt_f32_f16(buf + d*sp + SEQ+2*DIM, Wv + d*DIM, DIM);
cvt_f32_f16(buf + d*sp + SEQ+3*DIM, Wo + d*DIM, DIM);
cvt_f32_f16(buf + d*SDPA_FWD_SP + SEQ, Wq + d*Q_DIM, Q_DIM);
cvt_f32_f16(buf + d*SDPA_FWD_SP + SEQ+Q_DIM, Wk + d*KV_DIM, KV_DIM);
cvt_f32_f16(buf + d*SDPA_FWD_SP + SEQ+Q_DIM+KV_DIM, Wv + d*KV_DIM, KV_DIM);
}
IOSurfaceUnlock(s, 0, NULL);
}
static void write_sdpa_fwd_acts(IOSurfaceRef s, const float *xnorm) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = SEQ + 4*DIM;
for (int d = 0; d < DIM; d++)
cvt_f32_f16(buf + d*sp, xnorm + d*SEQ, SEQ);
cvt_f32_f16(buf + d*SDPA_FWD_SP, xnorm + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// woFwd: [1, Q_DIM, 1, SEQ + DIM] fp16 — Wo: [Q_DIM, DIM]
#define WO_FWD_SP (SEQ + DIM)
static void stage_wo_fwd_weights(IOSurfaceRef s, const float *Wo) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < Q_DIM; d++)
cvt_f32_f16(buf + d*WO_FWD_SP + SEQ, Wo + d*DIM, DIM);
IOSurfaceUnlock(s, 0, NULL);
}
static void write_wo_fwd_acts(IOSurfaceRef s, const float *attn_out) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < Q_DIM; d++)
cvt_f32_f16(buf + d*WO_FWD_SP, attn_out + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// ffnFused: [1, DIM, 1, 2*SEQ+3*HIDDEN] fp16
#define FFN_FUSED_SP (2*SEQ + 3*HIDDEN)
static void stage_ffn_fused_weights(IOSurfaceRef s,
const float *W1t, const float *W3t, const float *W2_orig) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = 2*SEQ + 3*HIDDEN;
for (int d = 0; d < DIM; d++) {
cvt_f32_f16(buf + d*sp + 2*SEQ, W1t + d*HIDDEN, HIDDEN);
cvt_f32_f16(buf + d*sp + 2*SEQ+HIDDEN, W3t + d*HIDDEN, HIDDEN);
cvt_f32_f16(buf + d*sp + 2*SEQ+2*HIDDEN, W2_orig + d*HIDDEN, HIDDEN);
cvt_f32_f16(buf + d*FFN_FUSED_SP + 2*SEQ, W1t + d*HIDDEN, HIDDEN);
cvt_f32_f16(buf + d*FFN_FUSED_SP + 2*SEQ+HIDDEN, W3t + d*HIDDEN, HIDDEN);
cvt_f32_f16(buf + d*FFN_FUSED_SP + 2*SEQ+2*HIDDEN, W2_orig + d*HIDDEN, HIDDEN);
}
IOSurfaceUnlock(s, 0, NULL);
}
static void write_ffn_fused_acts(IOSurfaceRef s, const float *x2norm, const float *x2) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = 2*SEQ + 3*HIDDEN;
for (int d = 0; d < DIM; d++) {
cvt_f32_f16(buf + d*sp, x2norm + d*SEQ, SEQ);
cvt_f32_f16(buf + d*sp + SEQ, x2 + d*SEQ, SEQ);
cvt_f32_f16(buf + d*FFN_FUSED_SP, x2norm + d*SEQ, SEQ);
cvt_f32_f16(buf + d*FFN_FUSED_SP + SEQ, x2 + d*SEQ, SEQ);
}
IOSurfaceUnlock(s, 0, NULL);
}
// ffnW13: [1, DIM, 1, SEQ+2*HIDDEN] fp16
static void stage_ffn_w13_weights(IOSurfaceRef s, const float *W1, const float *W3) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = SEQ + 2*HIDDEN;
for (int d = 0; d < DIM; d++) {
cvt_f32_f16(buf + d*sp + SEQ, W1 + d*HIDDEN, HIDDEN);
cvt_f32_f16(buf + d*sp + SEQ+HIDDEN, W3 + d*HIDDEN, HIDDEN);
}
IOSurfaceUnlock(s, 0, NULL);
}
static void write_ffn_w13_acts(IOSurfaceRef s, const float *xnorm) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = SEQ + 2*HIDDEN;
for (int d = 0; d < DIM; d++)
cvt_f32_f16(buf + d*sp, xnorm + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// ffnW2: [1, HIDDEN, 1, SEQ+DIM] fp16
static void stage_ffn_w2_weights(IOSurfaceRef s, const float *W2) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = SEQ + DIM;
for (int d = 0; d < HIDDEN; d++)
cvt_f32_f16(buf + d*sp + SEQ, W2 + d*DIM, DIM);
IOSurfaceUnlock(s, 0, NULL);
}
static void write_ffn_w2_acts(IOSurfaceRef s, const float *gate) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = SEQ + DIM;
for (int d = 0; d < HIDDEN; d++)
cvt_f32_f16(buf + d*sp, gate + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// ffnBwdW2t: [1, DIM, 1, SEQ+HIDDEN] fp16
#define FFN_BWD_W2T_SP (SEQ + HIDDEN)
static void stage_ffn_bwd_w2t_weights(IOSurfaceRef s, const float *W2) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = SEQ + HIDDEN;
for (int d = 0; d < DIM; d++)
cvt_f32_f16(buf + d*sp + SEQ, W2 + d*HIDDEN, HIDDEN);
cvt_f32_f16(buf + d*FFN_BWD_W2T_SP + SEQ, W2 + d*HIDDEN, HIDDEN);
IOSurfaceUnlock(s, 0, NULL);
}
static void write_ffn_bwd_w2t_acts(IOSurfaceRef s, const float *dffn) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = SEQ + HIDDEN;
for (int d = 0; d < DIM; d++)
cvt_f32_f16(buf + d*sp, dffn + d*SEQ, SEQ);
cvt_f32_f16(buf + d*FFN_BWD_W2T_SP, dffn + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// ffnBwdW13t: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp16
#define FFN_BWD_W13T_SP (2*SEQ + 2*DIM)
static void stage_ffn_bwd_w13t_weights(IOSurfaceRef s, const float *W1, const float *W3) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = 2*SEQ + 2*DIM;
for (int d = 0; d < HIDDEN; d++) {
cvt_f32_f16(buf + d*sp + 2*SEQ, W1 + d*DIM, DIM);
cvt_f32_f16(buf + d*sp + 2*SEQ + DIM, W3 + d*DIM, DIM);
cvt_f32_f16(buf + d*FFN_BWD_W13T_SP + 2*SEQ, W1 + d*DIM, DIM);
cvt_f32_f16(buf + d*FFN_BWD_W13T_SP + 2*SEQ + DIM, W3 + d*DIM, DIM);
}
IOSurfaceUnlock(s, 0, NULL);
}
static void write_ffn_bwd_w13t_acts(IOSurfaceRef s, const float *dh1, const float *dh3) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = 2*SEQ + 2*DIM;
for (int d = 0; d < HIDDEN; d++) {
cvt_f32_f16(buf + d*sp, dh1 + d*SEQ, SEQ);
cvt_f32_f16(buf + d*sp + SEQ, dh3 + d*SEQ, SEQ);
cvt_f32_f16(buf + d*FFN_BWD_W13T_SP, dh1 + d*SEQ, SEQ);
cvt_f32_f16(buf + d*FFN_BWD_W13T_SP + SEQ, dh3 + d*SEQ, SEQ);
}
IOSurfaceUnlock(s, 0, NULL);
}
// wotBwd: [1, DIM, 1, SEQ+DIM] fp16
// wotBwd: [1, DIM, 1, SEQ+Q_DIM] fp16 — Wo is [DIM, Q_DIM], matmul gives Wo^T @ dy
#define WOT_BWD_SP (SEQ + Q_DIM)
static void stage_wot_bwd_weights(IOSurfaceRef s, const float *Wo) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = SEQ + DIM;
for (int d = 0; d < DIM; d++)
cvt_f32_f16(buf + d*sp + SEQ, Wo + d*DIM, DIM);
cvt_f32_f16(buf + d*WOT_BWD_SP + SEQ, Wo + d*Q_DIM, Q_DIM);
IOSurfaceUnlock(s, 0, NULL);
}
static void write_wot_bwd_acts(IOSurfaceRef s, const float *dx2) {
static void write_wot_bwd_acts(IOSurfaceRef s, const float *dy) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = SEQ + DIM;
for (int d = 0; d < DIM; d++)
cvt_f32_f16(buf + d*sp, dx2 + d*SEQ, SEQ);
cvt_f32_f16(buf + d*WOT_BWD_SP, dy + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// qkvBwd: [1, DIM, 1, 3*SEQ+3*DIM] fp16
static void stage_qkv_bwd_weights(IOSurfaceRef s, const float *Wq, const float *Wk, const float *Wv) {
// qBwd: [1, Q_DIM, 1, SEQ+DIM] fp16 — Wq is [Q_DIM, DIM], matmul gives Wq^T @ dq
#define Q_BWD_SP (SEQ + DIM)
static void stage_q_bwd_weights(IOSurfaceRef s, const float *Wq) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = 3*SEQ + 3*DIM;
for (int d = 0; d < DIM; d++) {
cvt_f32_f16(buf + d*sp + 3*SEQ, Wq + d*DIM, DIM);
cvt_f32_f16(buf + d*sp + 3*SEQ + DIM, Wk + d*DIM, DIM);
cvt_f32_f16(buf + d*sp + 3*SEQ + 2*DIM, Wv + d*DIM, DIM);
for (int d = 0; d < Q_DIM; d++)
cvt_f32_f16(buf + d*Q_BWD_SP + SEQ, Wq + d*DIM, DIM);
IOSurfaceUnlock(s, 0, NULL);
}
static void write_q_bwd_acts(IOSurfaceRef s, const float *dq) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < Q_DIM; d++)
cvt_f32_f16(buf + d*Q_BWD_SP, dq + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// kvBwd: [1, KV_DIM, 1, 2*SEQ+2*DIM] fp16 — dk @ Wk + dv @ Wv → dx_kv
#define KV_BWD_SP (2*SEQ + 2*DIM)
static void stage_kv_bwd_weights(IOSurfaceRef s, const float *Wk, const float *Wv) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < KV_DIM; d++) {
cvt_f32_f16(buf + d*KV_BWD_SP + 2*SEQ, Wk + d*DIM, DIM);
cvt_f32_f16(buf + d*KV_BWD_SP + 2*SEQ + DIM, Wv + d*DIM, DIM);
}
IOSurfaceUnlock(s, 0, NULL);
}
static void write_qkv_bwd_acts(IOSurfaceRef s, const float *dq, const float *dk, const float *dv) {
static void write_kv_bwd_acts(IOSurfaceRef s, const float *dk, const float *dv) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
int sp = 3*SEQ + 3*DIM;
for (int d = 0; d < DIM; d++) {
cvt_f32_f16(buf + d*sp, dq + d*SEQ, SEQ);
cvt_f32_f16(buf + d*sp + SEQ, dk + d*SEQ, SEQ);
cvt_f32_f16(buf + d*sp + 2*SEQ, dv + d*SEQ, SEQ);
for (int d = 0; d < KV_DIM; d++) {
cvt_f32_f16(buf + d*KV_BWD_SP, dk + d*SEQ, SEQ);
cvt_f32_f16(buf + d*KV_BWD_SP + SEQ, dv + d*SEQ, SEQ);
}
IOSurfaceUnlock(s, 0, NULL);
}
@ -335,11 +315,37 @@ static void write_qkv_bwd_acts(IOSurfaceRef s, const float *dq, const float *dk,
// Free per-layer surfaces and requests
static void free_per_layer(PerLayerSurfaces *pls, PerLayerRequests *plr) {
for (int L = 0; L < NLAYERS; L++) {
CFRelease(pls[L].sdpaFwd_in); CFRelease(pls[L].ffnFused_in);
CFRelease(pls[L].sdpaFwd_in); CFRelease(pls[L].woFwd_in); CFRelease(pls[L].ffnFused_in);
CFRelease(pls[L].ffnBwdW2t_in); CFRelease(pls[L].ffnBwdW13t_in);
CFRelease(pls[L].wotBwd_in); CFRelease(pls[L].qkvBwd_in);
CFRelease(plr[L].sdpaFwd); CFRelease(plr[L].ffnFused);
CFRelease(pls[L].wotBwd_in); CFRelease(pls[L].qBwd_in); CFRelease(pls[L].kvBwd_in);
CFRelease(plr[L].sdpaFwd); CFRelease(plr[L].woFwd); CFRelease(plr[L].ffnFused);
CFRelease(plr[L].ffnBwdW2t); CFRelease(plr[L].ffnBwdW13t);
CFRelease(plr[L].wotBwd); CFRelease(plr[L].qkvBwd);
CFRelease(plr[L].wotBwd); CFRelease(plr[L].qBwd); CFRelease(plr[L].kvBwd);
}
}
// GQA helpers: tile KV from KV_HEADS to HEADS, and reduce HEADS to KV_HEADS
// tile_kv: input [KV_DIM, SEQ], output [Q_DIM, SEQ]
// Each KV head is duplicated GQA_RATIO times
static void gqa_tile_kv(float *out, const float *in, int seq) {
for (int kv = 0; kv < KV_HEADS; kv++) {
for (int r = 0; r < GQA_RATIO; r++) {
int q_head = kv * GQA_RATIO + r;
memcpy(out + q_head * HD * seq, in + kv * HD * seq, HD * seq * sizeof(float));
}
}
}
// reduce_kv: input [Q_DIM, SEQ], output [KV_DIM, SEQ]
// Sum contributions from Q heads sharing each KV head
static void gqa_reduce_kv(float *out, const float *in, int seq) {
memset(out, 0, KV_DIM * seq * sizeof(float));
for (int kv = 0; kv < KV_HEADS; kv++) {
for (int r = 0; r < GQA_RATIO; r++) {
int q_head = kv * GQA_RATIO + r;
const float *src = in + q_head * HD * seq;
float *dst = out + kv * HD * seq;
for (int i = 0; i < HD * seq; i++)
dst[i] += src[i];
}
}
}

View File

@ -1,7 +1,7 @@
// mil_dynamic.h — MIL generators using dynamic matmul (weights via IOSurface)
// Instead of conv(const_weight, x), we use matmul(x, W) where both come from input.
// Input layout: [1, IC, 1, SP] fp32, SP = SEQ + total_weight_cols
// Activations in sp[0:SEQ], weight matrices packed sequentially in sp[SEQ:]
// mil_dynamic.h — MIL generators for Qwen3-0.6B with GQA
// Q_DIM=2048 != DIM=1024, KV_DIM=1024, GQA_RATIO=2
// SDPA split: sdpaFwd (QKV proj + attention, no Wo) + woFwd (Wo matmul)
// Backward: qBwd + kvBwd (split from qkvBwd)
#pragma once
#include "io.h"
@ -11,42 +11,30 @@
"{\"coremltools-version\", \"9.0\"}})]\n{\n"
// Helper: generate a dynamic matmul within a MIL function
// Slices activation [1,ic,1,seq] and weight [1,ic,1,oc] from input, does matmul
// act_sp_off: spatial offset for activations (usually 0)
// w_sp_off: spatial offset for weight block
// Returns variable name of result [1,oc,1,seq] in fp16
static void gen_dyn_matmul(NSMutableString *m, const char *prefix,
int ic, int oc, int seq,
int act_sp_off, int w_sp_off,
const char *input_var) {
// Slice activations
[m appendFormat:@" tensor<int32, [4]> %s_ba = const()[name=string(\"%s_ba\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", prefix, prefix, act_sp_off];
[m appendFormat:@" tensor<int32, [4]> %s_sa = const()[name=string(\"%s_sa\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, ic, seq];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_act = slice_by_size(x=%s,begin=%s_ba,size=%s_sa)[name=string(\"%s_act\")];\n", ic, seq, prefix, input_var, prefix, prefix, prefix];
// Slice weight
[m appendFormat:@" tensor<int32, [4]> %s_bw = const()[name=string(\"%s_bw\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", prefix, prefix, w_sp_off];
[m appendFormat:@" tensor<int32, [4]> %s_sw = const()[name=string(\"%s_sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, ic, oc];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_wt = slice_by_size(x=%s,begin=%s_bw,size=%s_sw)[name=string(\"%s_wt\")];\n", ic, oc, prefix, input_var, prefix, prefix, prefix];
// Reshape act: [1,ic,1,seq] → [1,1,ic,seq] → transpose → [1,1,seq,ic]
[m appendFormat:@" tensor<int32, [4]> %s_ra = const()[name=string(\"%s_ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", prefix, prefix, ic, seq];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_a2 = reshape(shape=%s_ra,x=%s_act)[name=string(\"%s_a2\")];\n", ic, seq, prefix, prefix, prefix, prefix];
[m appendFormat:@" tensor<int32, [4]> %s_pm = const()[name=string(\"%s_pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n", prefix, prefix];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_a3 = transpose(perm=%s_pm,x=%s_a2)[name=string(\"%s_a3\")];\n", seq, ic, prefix, prefix, prefix, prefix];
// Reshape weight: [1,ic,1,oc] → [1,1,ic,oc]
[m appendFormat:@" tensor<int32, [4]> %s_rw = const()[name=string(\"%s_rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", prefix, prefix, ic, oc];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_W = reshape(shape=%s_rw,x=%s_wt)[name=string(\"%s_W\")];\n", ic, oc, prefix, prefix, prefix, prefix];
// matmul: [1,1,seq,ic] @ [1,1,ic,oc] → [1,1,seq,oc]
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_yh = matmul(transpose_x=bF,transpose_y=bF,x=%s_a3,y=%s_W)[name=string(\"%s_yh\")];\n", seq, oc, prefix, prefix, prefix, prefix];
// Transpose back + reshape: [1,1,seq,oc] → [1,1,oc,seq] → [1,oc,1,seq]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_yt = transpose(perm=%s_pm,x=%s_yh)[name=string(\"%s_yt\")];\n", oc, seq, prefix, prefix, prefix, prefix];
[m appendFormat:@" tensor<int32, [4]> %s_ro = const()[name=string(\"%s_ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, oc, seq];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_y = reshape(shape=%s_ro,x=%s_yt)[name=string(\"%s_y\")];\n", oc, seq, prefix, prefix, prefix, prefix];
}
// ===== Dynamic matmul kernel: y = x @ W =====
// Input: [1, IC, 1, SEQ+OC] fp16 — act[0:SEQ] + W[SEQ:SEQ+OC]
// Output: [1, OC, 1, SEQ] fp16
// Simple dynamic matmul kernel: y = x @ W, input [1,IC,1,SEQ+OC], output [1,OC,1,SEQ]
static NSString *gen_dyn_matmul_mil(int ic, int oc, int seq) {
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
@ -57,20 +45,18 @@ static NSString *gen_dyn_matmul_mil(int ic, int oc, int seq) {
return m;
}
// ===== SDPA forward (dynamic weights) =====
// Replaces gen_sdpa_fwd_taps: RMSNorm done on CPU, this kernel does QKV matmul + SDPA + Wo matmul
// Input: [1, DIM, 1, SEQ + 4*DIM] fp16
// sp[0:SEQ] = xnorm (rmsnorm output, DIM channels)
// sp[SEQ:SEQ+DIM] = Wq[DIM,DIM]
// sp[SEQ+DIM:SEQ+2D] = Wk[DIM,DIM]
// sp[SEQ+2D:SEQ+3D] = Wv[DIM,DIM]
// sp[SEQ+3D:SEQ+4D] = Wo[DIM,DIM]
// Output: [1, 6*DIM, 1, SEQ] fp16 = concat(o_out, Q, K, V, attn_out, xnorm_pass)
// NOTE: mask is still a const weight (it doesn't change)
// ===== SDPA forward with GQA (no Wo) =====
// Input: [1, DIM, 1, SEQ + Q_DIM + KV_DIM + KV_DIM] fp16
// sp[0:SEQ] = xnorm [DIM, SEQ]
// sp[SEQ:SEQ+Q_DIM] = Wq [DIM, Q_DIM]
// sp[SEQ+Q_DIM:SEQ+Q_DIM+KVD] = Wk [DIM, KV_DIM]
// sp[SEQ+Q_DIM+KVD:...] = Wv [DIM, KV_DIM]
// Output: [1, Q_DIM+Q_DIM+KV_DIM+KV_DIM+DIM, 1, SEQ] fp16
// = concat(attn_out, Q_rope, K_rope, V, xnorm_pass)
static NSString *gen_sdpa_fwd_dynamic(void) {
float sc = 1.0f/sqrtf((float)HD);
int w_total = 4*DIM; // Wq+Wk+Wv+Wo
int sp_in = SEQ + w_total;
int sp_in = SDPA_FWD_SP;
int out_ch = Q_DIM + Q_DIM + KV_DIM + KV_DIM + DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
@ -80,100 +66,126 @@ static NSString *gen_sdpa_fwd_dynamic(void) {
[m appendFormat:@" tensor<int32, [4]> sx = const()[name=string(\"sx\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=x,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ];
// Slice Wq [1,DIM,1,DIM]
// Slice Wq [1,DIM,1,Q_DIM]
[m appendFormat:@" tensor<int32, [4]> bq = const()[name=string(\"bq\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wq = slice_by_size(x=x,begin=bq,size=sw)[name=string(\"Wq\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> swq = const()[name=string(\"swq\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, Q_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wq = slice_by_size(x=x,begin=bq,size=swq)[name=string(\"Wq\")];\n", DIM, Q_DIM];
// Slice Wk
[m appendFormat:@" tensor<int32, [4]> bk = const()[name=string(\"bk\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wk = slice_by_size(x=x,begin=bk,size=sw)[name=string(\"Wk\")];\n", DIM, DIM];
// Slice Wk [1,DIM,1,KV_DIM]
[m appendFormat:@" tensor<int32, [4]> bk = const()[name=string(\"bk\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+Q_DIM];
[m appendFormat:@" tensor<int32, [4]> swk = const()[name=string(\"swk\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, KV_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wk = slice_by_size(x=x,begin=bk,size=swk)[name=string(\"Wk\")];\n", DIM, KV_DIM];
// Slice Wv
[m appendFormat:@" tensor<int32, [4]> bv = const()[name=string(\"bv\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wv = slice_by_size(x=x,begin=bv,size=sw)[name=string(\"Wv\")];\n", DIM, DIM];
// Slice Wv [1,DIM,1,KV_DIM]
[m appendFormat:@" tensor<int32, [4]> bv = const()[name=string(\"bv\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+Q_DIM+KV_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wv = slice_by_size(x=x,begin=bv,size=swk)[name=string(\"Wv\")];\n", DIM, KV_DIM];
// Slice Wo
[m appendFormat:@" tensor<int32, [4]> bo = const()[name=string(\"bo\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+3*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wo = slice_by_size(x=x,begin=bo,size=sw)[name=string(\"Wo\")];\n", DIM, DIM];
// Reshape for matmul: [1,D,1,S] → [1,1,D,S] → [1,1,S,D]
// Reshape xnorm for matmul: [1,DIM,1,SEQ] → [1,1,DIM,SEQ] → [1,1,SEQ,DIM]
[m appendFormat:@" tensor<int32, [4]> r2 = const()[name=string(\"r2\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=r2,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
// Reshape weights: [1,D,1,D] → [1,1,D,D]
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wq2 = reshape(shape=rw,x=Wq)[name=string(\"Wq2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wk2 = reshape(shape=rw,x=Wk)[name=string(\"Wk2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wv2 = reshape(shape=rw,x=Wv)[name=string(\"Wv2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wo2 = reshape(shape=rw,x=Wo)[name=string(\"Wo2\")];\n", DIM, DIM];
// Reshape weights
[m appendFormat:@" tensor<int32, [4]> rwq = const()[name=string(\"rwq\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, Q_DIM];
[m appendFormat:@" tensor<int32, [4]> rwk = const()[name=string(\"rwk\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, KV_DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wq2 = reshape(shape=rwq,x=Wq)[name=string(\"Wq2\")];\n", DIM, Q_DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wk2 = reshape(shape=rwk,x=Wk)[name=string(\"Wk2\")];\n", DIM, KV_DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wv2 = reshape(shape=rwk,x=Wv)[name=string(\"Wv2\")];\n", DIM, KV_DIM];
// QKV matmul: [1,1,S,D] @ [1,1,D,D] → [1,1,S,D]
// QKV matmul
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> qm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wq2)[name=string(\"qm\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> km = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wk2)[name=string(\"km\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> vm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wv2)[name=string(\"vm\")];\n", SEQ, DIM];
// Q: [1,1,SEQ,DIM] @ [1,1,DIM,Q_DIM] → [1,1,SEQ,Q_DIM]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> qm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wq2)[name=string(\"qm\")];\n", SEQ, Q_DIM];
// K: [1,1,SEQ,DIM] @ [1,1,DIM,KV_DIM] → [1,1,SEQ,KV_DIM]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> km = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wk2)[name=string(\"km\")];\n", SEQ, KV_DIM];
// V: same as K
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> vm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wv2)[name=string(\"vm\")];\n", SEQ, KV_DIM];
// Transpose back: [1,1,S,D] → [1,1,D,S] → reshape [1,D,1,S]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> qt = transpose(perm=pm,x=qm)[name=string(\"qt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> kt = transpose(perm=pm,x=km)[name=string(\"kt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> vt = transpose(perm=pm,x=vm)[name=string(\"vt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> os = const()[name=string(\"os\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = reshape(shape=os,x=qt)[name=string(\"qf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = reshape(shape=os,x=kt)[name=string(\"kf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = reshape(shape=os,x=vt)[name=string(\"vf\")];\n", DIM, SEQ];
// Transpose back: [1,1,SEQ,X] → [1,1,X,SEQ]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> qt = transpose(perm=pm,x=qm)[name=string(\"qt\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> kt = transpose(perm=pm,x=km)[name=string(\"kt\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> vt = transpose(perm=pm,x=vm)[name=string(\"vt\")];\n", KV_DIM, SEQ];
// SDPA: reshape to heads, matmul, mask, softmax, matmul
[m appendFormat:@" tensor<int32, [4]> qsh = const()[name=string(\"qsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q4 = reshape(shape=qsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
// Reshape to [1,X,1,SEQ]
[m appendFormat:@" tensor<int32, [4]> qsh = const()[name=string(\"qsh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> kvsh = const()[name=string(\"kvsh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = reshape(shape=qsh,x=qt)[name=string(\"qf\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = reshape(shape=kvsh,x=kt)[name=string(\"kf\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = reshape(shape=kvsh,x=vt)[name=string(\"vf\")];\n", KV_DIM, SEQ];
// Reshape to heads for attention
// Q: [1,Q_DIM,1,SEQ] → [1,HEADS,HD,SEQ] → transpose → [1,HEADS,SEQ,HD]
[m appendFormat:@" tensor<int32, [4]> qhsh = const()[name=string(\"qhsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q4 = reshape(shape=qhsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=q4)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k4 = reshape(shape=qsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS, SEQ, HD];
// K: [1,KV_DIM,1,SEQ] → [1,KV_HEADS,HD,SEQ] → [1,KV_HEADS,SEQ,HD]
[m appendFormat:@" tensor<int32, [4]> khsh = const()[name=string(\"khsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", KV_HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k4 = reshape(shape=khsh,x=kf)[name=string(\"rk\")];\n", KV_HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", KV_HEADS, SEQ, HD];
// V: same reshape as K
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v4 = reshape(shape=khsh,x=vf)[name=string(\"rv\")];\n", KV_HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", KV_HEADS, SEQ, HD];
// RoPE: q_rope = q * cos + rotate_half(q) * sin, same for k
int pairs = SEQ * HD / 2;
// RoPE on Q: [1,HEADS,SEQ,HD]
int pairs_q = SEQ * HD / 2;
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> rope_cos = const()[name=string(\"rc\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/rope_cos.bin\"), offset=uint64(64)))];\n", SEQ, HD, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> rope_sin = const()[name=string(\"rs\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/rope_sin.bin\"), offset=uint64(64)))];\n", SEQ, HD, SEQ, HD];
[m appendFormat:@" tensor<int32, [4]> rp_sh = const()[name=string(\"rp_sh\"), val=tensor<int32, [4]>([1,%d,%d,2])];\n", HEADS, pairs];
[m appendFormat:@" tensor<int32, [4]> rp_s1 = const()[name=string(\"rp_s1\"), val=tensor<int32, [4]>([1,%d,%d,1])];\n", HEADS, pairs];
[m appendFormat:@" tensor<int32, [4]> rp_sh = const()[name=string(\"rp_sh\"), val=tensor<int32, [4]>([1,%d,%d,2])];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<int32, [4]> rp_s1 = const()[name=string(\"rp_s1\"), val=tensor<int32, [4]>([1,%d,%d,1])];\n", HEADS, pairs_q];
[m appendString:@" tensor<int32, [4]> rp_b0 = const()[name=string(\"rp_b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendString:@" tensor<int32, [4]> rp_b1 = const()[name=string(\"rp_b1\"), val=tensor<int32, [4]>([0,0,0,1])];\n"];
[m appendFormat:@" tensor<int32, [4]> rp_bk = const()[name=string(\"rp_bk\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, SEQ, HD];
// rotate_half(q): reshape to pairs, swap+negate, reshape back
[m appendString:@" fp16 neg1 = const()[name=string(\"neg1\"), val=fp16(-1)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> q_p = reshape(shape=rp_sh,x=q)[name=string(\"q_p\")];\n", HEADS, pairs];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> q_e = slice_by_size(x=q_p,begin=rp_b0,size=rp_s1)[name=string(\"q_e\")];\n", HEADS, pairs];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> q_o = slice_by_size(x=q_p,begin=rp_b1,size=rp_s1)[name=string(\"q_o\")];\n", HEADS, pairs];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> nq = mul(x=q_o,y=neg1)[name=string(\"nq\")];\n", HEADS, pairs];
[m appendString:@" int32 rpax = const()[name=string(\"rpax\"), val=int32(3)];\n"];
[m appendString:@" bool rpil = const()[name=string(\"rpil\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> qrp = concat(axis=rpax,interleave=rpil,values=(nq,q_e))[name=string(\"qrp\")];\n", HEADS, pairs];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q_rot = reshape(shape=rp_bk,x=qrp)[name=string(\"q_rot\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<int32, [4]> rp_bk_q = const()[name=string(\"rp_bk_q\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, SEQ, HD];
// rotate_half(q)
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> q_p = reshape(shape=rp_sh,x=q)[name=string(\"q_p\")];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> q_e = slice_by_size(x=q_p,begin=rp_b0,size=rp_s1)[name=string(\"q_e\")];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> q_o = slice_by_size(x=q_p,begin=rp_b1,size=rp_s1)[name=string(\"q_o\")];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> nq = mul(x=q_o,y=neg1)[name=string(\"nq\")];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> qrp = concat(axis=rpax,interleave=rpil,values=(nq,q_e))[name=string(\"qrp\")];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q_rot = reshape(shape=rp_bk_q,x=qrp)[name=string(\"q_rot\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qc = mul(x=q,y=rope_cos)[name=string(\"qc\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qrs = mul(x=q_rot,y=rope_sin)[name=string(\"qrs\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q_rope = add(x=qc,y=qrs)[name=string(\"q_rope\")];\n", HEADS, SEQ, HD];
// rotate_half(k)
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> k_p = reshape(shape=rp_sh,x=k)[name=string(\"k_p\")];\n", HEADS, pairs];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> k_e = slice_by_size(x=k_p,begin=rp_b0,size=rp_s1)[name=string(\"k_e\")];\n", HEADS, pairs];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> k_o = slice_by_size(x=k_p,begin=rp_b1,size=rp_s1)[name=string(\"k_o\")];\n", HEADS, pairs];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> nk = mul(x=k_o,y=neg1)[name=string(\"nk\")];\n", HEADS, pairs];
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> krp = concat(axis=rpax,interleave=rpil,values=(nk,k_e))[name=string(\"krp\")];\n", HEADS, pairs];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k_rot = reshape(shape=rp_bk,x=krp)[name=string(\"k_rot\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kc = mul(x=k,y=rope_cos)[name=string(\"kc\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> krs = mul(x=k_rot,y=rope_sin)[name=string(\"krs\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k_rope = add(x=kc,y=krs)[name=string(\"k_rope\")];\n", HEADS, SEQ, HD];
// Q_rope @ K_rope^T
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q_rope,y=k_rope)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
// RoPE on K: [1,KV_HEADS,SEQ,HD]
int pairs_k = SEQ * HD / 2;
[m appendFormat:@" tensor<int32, [4]> rp_sh_k = const()[name=string(\"rp_sh_k\"), val=tensor<int32, [4]>([1,%d,%d,2])];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<int32, [4]> rp_s1_k = const()[name=string(\"rp_s1_k\"), val=tensor<int32, [4]>([1,%d,%d,1])];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<int32, [4]> rp_bk_k = const()[name=string(\"rp_bk_k\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", KV_HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> k_p = reshape(shape=rp_sh_k,x=k)[name=string(\"k_p\")];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> k_e = slice_by_size(x=k_p,begin=rp_b0,size=rp_s1_k)[name=string(\"k_e\")];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> k_o = slice_by_size(x=k_p,begin=rp_b1,size=rp_s1_k)[name=string(\"k_o\")];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> nk = mul(x=k_o,y=neg1)[name=string(\"nk\")];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> krp = concat(axis=rpax,interleave=rpil,values=(nk,k_e))[name=string(\"krp\")];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k_rot = reshape(shape=rp_bk_k,x=krp)[name=string(\"k_rot\")];\n", KV_HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kc = mul(x=k,y=rope_cos)[name=string(\"kc\")];\n", KV_HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> krs = mul(x=k_rot,y=rope_sin)[name=string(\"krs\")];\n", KV_HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k_rope = add(x=kc,y=krs)[name=string(\"k_rope\")];\n", KV_HEADS, SEQ, HD];
// GQA: tile K,V from KV_HEADS to HEADS
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
// For GQA_RATIO=2: concat(k_rope, k_rope) along head dim
NSMutableString *k_vals = [NSMutableString string];
NSMutableString *v_vals = [NSMutableString string];
for (int r = 0; r < GQA_RATIO; r++) {
if (r > 0) { [k_vals appendString:@","]; [v_vals appendString:@","]; }
[k_vals appendString:@"k_rope"]; [v_vals appendString:@"v"];
}
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k_tiled = concat(axis=cax,interleave=cid,values=(%@))[name=string(\"ktile\")];\n", HEADS, SEQ, HD, k_vals];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v_tiled = concat(axis=cax,interleave=cid,values=(%@))[name=string(\"vtile\")];\n", HEADS, SEQ, HD, v_vals];
// Q_rope @ K_tiled^T → [1,HEADS,SEQ,SEQ]
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q_rope,y=k_tiled)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ];
// Causal mask (still const — doesn't change)
// Causal mask
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ, SEQ, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS, SEQ, SEQ];
@ -181,90 +193,67 @@ static NSString *gen_sdpa_fwd_dynamic(void) {
[m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS, SEQ, SEQ];
// scores @ V
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> a4 = matmul(transpose_x=bF,transpose_y=bF,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS, SEQ, HD];
// scores @ V_tiled → [1,HEADS,SEQ,HD]
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> a4 = matmul(transpose_x=bF,transpose_y=bF,x=aw,y=v_tiled)[name=string(\"mm2\")];\n", HEADS, SEQ, HD];
// Reshape back to [1,DIM,1,SEQ]
// Reshape attn_out to [1,Q_DIM,1,SEQ]
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> af = reshape(shape=qsh,x=at)[name=string(\"ra\")];\n", Q_DIM, SEQ];
// Wo matmul: af → [1,1,S,D] @ Wo[1,1,D,D] → [1,1,S,D] → [1,D,1,S]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> af2 = reshape(shape=r2,x=af)[name=string(\"af2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> aft = transpose(perm=pm,x=af2)[name=string(\"aft\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> om = matmul(transpose_x=bF,transpose_y=bF,x=aft,y=Wo2)[name=string(\"om\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> ot = transpose(perm=pm,x=om)[name=string(\"ot\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> oo = reshape(shape=os,x=ot)[name=string(\"oo\")];\n", DIM, SEQ];
// Convert RoPE'd Q,K back to [1,DIM,1,SEQ] for backward pass output
// Convert RoPE'd Q,K back to flat layout for backward
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qrt = transpose(perm=pm,x=q_rope)[name=string(\"qrt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qrf = reshape(shape=os,x=qrt)[name=string(\"qrf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> krt = transpose(perm=pm,x=k_rope)[name=string(\"krt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> krf = reshape(shape=os,x=krt)[name=string(\"krf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qrf = reshape(shape=qsh,x=qrt)[name=string(\"qrf\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> krt = transpose(perm=pm,x=k_rope)[name=string(\"krt\")];\n", KV_HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> krf = reshape(shape=kvsh,x=krt)[name=string(\"krf\")];\n", KV_DIM, SEQ];
// Output: concat(o_out, Q_rope, K_rope, V, attn_out, xnorm) for backward
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(oo,qrf,krf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ];
// Output: concat(attn_out[Q_DIM], Q_rope[Q_DIM], K_rope[KV_DIM], V[KV_DIM], xnorm[DIM])
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(af,qrf,krf,vf,xn))[name=string(\"cat\")];\n", out_ch, SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// ===== FFN forward (dynamic weights) =====
// RMSNorm on CPU. This kernel: xnorm @ W1 → SiLU, xnorm @ W3 → gate, gate*silu @ W2 → out
// Input: [1, DIM, 1, SEQ + HIDDEN + HIDDEN + DIM] fp32
// sp[0:SEQ] = xnorm [DIM,SEQ]
// sp[SEQ:SEQ+HIDDEN] = W1[DIM,HIDDEN]
// sp[SEQ+HIDDEN:SEQ+2*HIDDEN] = W3[DIM,HIDDEN]
// sp[SEQ+2*HIDDEN:SEQ+2*HIDDEN+DIM]= W2[HIDDEN→DIM] — but W2 is [DIM,HIDDEN], we need HIDDEN input channels
// PROBLEM: W2 has shape [DIM,HIDDEN] = HIDDEN input channels, but our kernel has DIM input channels.
// Solution: separate kernels for W1/W3 (DIM→HIDDEN) and W2 (HIDDEN→DIM)
// OR: do W1,W3 in one kernel, SiLU on CPU/ANE, W2 in another kernel.
// Simpler: 3 separate matmul kernels per FFN direction. But that's too many dispatches.
// Better: one kernel for W1+W3 (same input dim), CPU SiLU, one kernel for W2.
// woFwd: attn_out[Q_DIM,SEQ] @ Wo → o_out[DIM,SEQ]
// Simple dyn_matmul: IC=Q_DIM, OC=DIM
static NSString *gen_wo_fwd_dynamic(void) {
return gen_dyn_matmul_mil(Q_DIM, DIM, SEQ);
}
// FFN part 1: xnorm @ W1, xnorm @ W3 (both DIM→HIDDEN)
// Input: [1, DIM, 1, SEQ + 2*HIDDEN] fp32
// sp[0:SEQ] = xnorm
// sp[SEQ:SEQ+HIDDEN] = W1[DIM,HIDDEN]
// sp[SEQ+HIDDEN:SEQ+2*HIDDEN]= W3[DIM,HIDDEN]
// Output: [1, 2*HIDDEN, 1, SEQ] fp32 = concat(h1, h3)
static NSString *gen_ffn_w13_dynamic(void) {
int sp_in = SEQ + 2*HIDDEN;
// ===== Fused FFN forward: W1,W3 + SiLU + W2 + residual =====
// Same structure as before, just with Qwen3 DIM=1024, HIDDEN=3072
static NSString *gen_ffn_fused_dynamic(void) {
int sp_in = FFN_FUSED_SP;
int out_ch = DIM + 3*HIDDEN;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in];
// Slice xnorm
[m appendString:@" tensor<int32, [4]> bx = const()[name=string(\"bx\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sx = const()[name=string(\"sx\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ];
// Slice W1
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<int32, [4]> s1 = const()[name=string(\"s1\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1 = slice_by_size(x=xh,begin=b1,size=s1)[name=string(\"W1\")];\n", DIM, HIDDEN];
// Slice W3
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3 = slice_by_size(x=xh,begin=b3,size=s1)[name=string(\"W3\")];\n", DIM, HIDDEN];
// Reshape for matmul
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=rd,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
// Slice x2norm, x2, W1, W3, W2_orig
[m appendString:@" tensor<int32, [4]> b_xn = const()[name=string(\"b_xn\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> s_ds = const()[name=string(\"s_ds\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x2norm = slice_by_size(x=x,begin=b_xn,size=s_ds)[name=string(\"x2norm\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b_x2 = const()[name=string(\"b_x2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x2 = slice_by_size(x=x,begin=b_x2,size=s_ds)[name=string(\"x2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b_w1 = const()[name=string(\"b_w1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<int32, [4]> s_wh = const()[name=string(\"s_wh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1 = slice_by_size(x=x,begin=b_w1,size=s_wh)[name=string(\"W1\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<int32, [4]> b_w3 = const()[name=string(\"b_w3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3 = slice_by_size(x=x,begin=b_w3,size=s_wh)[name=string(\"W3\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<int32, [4]> b_w2 = const()[name=string(\"b_w2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+2*HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W2r = slice_by_size(x=x,begin=b_w2,size=s_wh)[name=string(\"W2r\")];\n", DIM, HIDDEN];
// xnorm matmul
[m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=rd,x=x2norm)[name=string(\"xn2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W12 = reshape(shape=rw,x=W1)[name=string(\"W12\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W32 = reshape(shape=rw,x=W3)[name=string(\"W32\")];\n", DIM, HIDDEN];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W12)[name=string(\"h1m\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W32)[name=string(\"h3m\")];\n", SEQ, HIDDEN];
// Transpose back
// Reshape back
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1t = transpose(perm=pm,x=h1m)[name=string(\"h1t\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3t = transpose(perm=pm,x=h3m)[name=string(\"h3t\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<int32, [4]> rh = const()[name=string(\"rh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
@ -276,107 +265,24 @@ static NSString *gen_ffn_w13_dynamic(void) {
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN, SEQ];
// Concat output: (h1, h3, gate)
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(h1,h3,gate))[name=string(\"cat\")];\n", 2*HIDDEN+HIDDEN, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", 3*HIDDEN, SEQ];
[m appendString:@" } -> (out32);\n}\n"];
return m;
}
// ===== Fused FFN forward: W1,W3 + SiLU + W2 + residual =====
// RMSNorm stays on CPU (ANE can't handle RMS + 3 matmuls without BNNS fallback)
// Replaces: ffnW13 + CPU gate read + ffnW2 + CPU residual
// Input: [1, DIM, 1, 2*SEQ + 3*HIDDEN] fp16
// sp[0:SEQ] = x2norm (RMSNorm output, from CPU)
// sp[SEQ:2*SEQ] = x2 (residual, for x_next = x2 + ffn_out)
// sp[2*SEQ : 2*SEQ+HIDDEN] = W1t[DIM,HIDDEN]
// sp[2*SEQ+HIDDEN : 2*SEQ+2*HIDDEN] = W3t[DIM,HIDDEN]
// sp[2*SEQ+2*HIDDEN : 2*SEQ+3*HIDDEN] = W2_orig[DIM,HIDDEN] (transposed inside kernel)
// Output: [1, DIM + 3*HIDDEN, 1, SEQ] fp16
// = concat(x_next[DIM], h1[HIDDEN], h3[HIDDEN], silu_out[HIDDEN])
static NSString *gen_ffn_fused_dynamic(void) {
int sp_in = 2*SEQ + 3*HIDDEN;
int out_ch = DIM + 3*HIDDEN;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
// Slice x2norm [DIM, SEQ] — RMSNorm output (computed on CPU)
[m appendString:@" tensor<int32, [4]> b_xn = const()[name=string(\"b_xn\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> s_ds = const()[name=string(\"s_ds\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x2norm = slice_by_size(x=x,begin=b_xn,size=s_ds)[name=string(\"x2norm\")];\n", DIM, SEQ];
// Slice x2 [DIM, SEQ] — for residual: x_next = x2 + ffn_out
[m appendFormat:@" tensor<int32, [4]> b_x2 = const()[name=string(\"b_x2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x2 = slice_by_size(x=x,begin=b_x2,size=s_ds)[name=string(\"x2\")];\n", DIM, SEQ];
// Slice W1 [DIM, HIDDEN]
[m appendFormat:@" tensor<int32, [4]> b_w1 = const()[name=string(\"b_w1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<int32, [4]> s_wh = const()[name=string(\"s_wh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1 = slice_by_size(x=x,begin=b_w1,size=s_wh)[name=string(\"W1\")];\n", DIM, HIDDEN];
// Slice W3 [DIM, HIDDEN]
[m appendFormat:@" tensor<int32, [4]> b_w3 = const()[name=string(\"b_w3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3 = slice_by_size(x=x,begin=b_w3,size=s_wh)[name=string(\"W3\")];\n", DIM, HIDDEN];
// Slice W2_orig [DIM, HIDDEN] (transposed inside kernel)
[m appendFormat:@" tensor<int32, [4]> b_w2 = const()[name=string(\"b_w2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+2*HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W2r = slice_by_size(x=x,begin=b_w2,size=s_wh)[name=string(\"W2r\")];\n", DIM, HIDDEN];
// Reshape for matmul: x2norm [1,DIM,1,SEQ] → [1,1,DIM,SEQ] → [1,1,SEQ,DIM]
[m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=rd,x=x2norm)[name=string(\"xn2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
// Reshape weights: [1,DIM,1,HIDDEN] → [1,1,DIM,HIDDEN]
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W12 = reshape(shape=rw,x=W1)[name=string(\"W12\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W32 = reshape(shape=rw,x=W3)[name=string(\"W32\")];\n", DIM, HIDDEN];
// h1 = x2norm_t @ W1, h3 = x2norm_t @ W3 [SEQ,DIM] @ [DIM,HIDDEN] → [SEQ,HIDDEN]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W12)[name=string(\"h1m\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W32)[name=string(\"h3m\")];\n", SEQ, HIDDEN];
// Reshape back: [1,1,SEQ,HIDDEN] → [1,1,HIDDEN,SEQ] → [1,HIDDEN,1,SEQ]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1t = transpose(perm=pm,x=h1m)[name=string(\"h1t\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3t = transpose(perm=pm,x=h3m)[name=string(\"h3t\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<int32, [4]> rh = const()[name=string(\"rh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h1 = reshape(shape=rh,x=h1t)[name=string(\"h1\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h3 = reshape(shape=rh,x=h3t)[name=string(\"h3\")];\n", HIDDEN, SEQ];
// SiLU + gate: gate = silu(h1) * h3
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN, SEQ];
// gate @ W2: reshape gate [1,HIDDEN,1,SEQ] → [1,1,HIDDEN,SEQ] → [1,1,SEQ,HIDDEN]
// gate @ W2: W2 is [DIM, HIDDEN] stored as-is, transpose inside kernel
[m appendFormat:@" tensor<int32, [4]> rg = const()[name=string(\"rg\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> g2 = reshape(shape=rg,x=gate)[name=string(\"g2\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> gt = transpose(perm=pm,x=g2)[name=string(\"gtt\")];\n", SEQ, HIDDEN];
// W2: [1,DIM,1,HIDDEN] → [1,1,DIM,HIDDEN] → transpose → [1,1,HIDDEN,DIM]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W22 = reshape(shape=rw,x=W2r)[name=string(\"W22\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W2t = transpose(perm=pm,x=W22)[name=string(\"W2t\")];\n", HIDDEN, DIM];
// matmul: [1,1,SEQ,HIDDEN] @ [1,1,HIDDEN,DIM] → [1,1,SEQ,DIM]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> fm = matmul(transpose_x=bF,transpose_y=bF,x=gt,y=W2t)[name=string(\"fm\")];\n", SEQ, DIM];
// Reshape: [1,1,SEQ,DIM] → [1,1,DIM,SEQ] → [1,DIM,1,SEQ]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> ft = transpose(perm=pm,x=fm)[name=string(\"ft\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> rd2 = const()[name=string(\"rd2\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> ffn_out = reshape(shape=rd2,x=ft)[name=string(\"ffn_out\")];\n", DIM, SEQ];
// Residual: x_next = x2 + alpha * ffn_out (residual scaling)
// Residual: x_next = x2 + alpha * ffn_out
float alpha = 1.0f / sqrtf(2.0f * NLAYERS);
[m appendFormat:@" fp16 res_alpha = const()[name=string(\"res_alpha\"), val=fp16(%g)];\n", alpha];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> ffn_scaled = mul(x=ffn_out,y=res_alpha)[name=string(\"ffn_sc\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x_next = add(x=x2,y=ffn_scaled)[name=string(\"x_next\")];\n", DIM, SEQ];
// Output: concat(x_next, h1, h3, gate) — gate=silu*h3 needed for dW2
// Output: concat(x_next, h1, h3, gate)
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(x_next,h1,h3,gate))[name=string(\"cat\")];\n", out_ch, SEQ];
@ -384,103 +290,37 @@ static NSString *gen_ffn_fused_dynamic(void) {
return m;
}
// FFN part 2: gate @ W2 (HIDDEN→DIM)
// Input: [1, HIDDEN, 1, SEQ + DIM] fp32
// sp[0:SEQ] = gate [HIDDEN,SEQ]
// sp[SEQ:SEQ+DIM] = W2[HIDDEN,DIM]
// Output: [1, DIM, 1, SEQ] fp32
static NSString *gen_ffn_w2_dynamic(void) {
int sp_in = SEQ + DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", HIDDEN, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, sp_in];
// ===== Backward kernels =====
[m appendString:@" tensor<int32, [4]> ba = const()[name=string(\"ba\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sa = const()[name=string(\"sa\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> act = slice_by_size(x=xh,begin=ba,size=sa)[name=string(\"act\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<int32, [4]> bw = const()[name=string(\"bw\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W2 = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"W2\")];\n", HIDDEN, DIM];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<int32, [4]> ra = const()[name=string(\"ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> a2 = reshape(shape=ra,x=act)[name=string(\"a2\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> at = transpose(perm=pm,x=a2)[name=string(\"at\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W22 = reshape(shape=rw,x=W2)[name=string(\"W22\")];\n", HIDDEN, DIM];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> ym = matmul(transpose_x=bF,transpose_y=bF,x=at,y=W22)[name=string(\"ym\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> yt = transpose(perm=pm,x=ym)[name=string(\"yt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> yr = reshape(shape=ro,x=yt)[name=string(\"yr\")];\n", DIM, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=yr)[name=string(\"cout\")];\n", DIM, SEQ];
[m appendString:@" } -> (y);\n}\n"];
return m;
}
// ===== FFN backward (dynamic weights) =====
// Input: [1, DIM+2*HIDDEN, 1, SEQ + HIDDEN + DIM + DIM] fp32
// Actually simpler to split into separate backward kernels like forward.
// FFN backward part 1: dffn @ W2^T → dsilu (HIDDEN), then SiLU derivative
// Input: [1, DIM, 1, SEQ + HIDDEN] fp32
// sp[0:SEQ] = dffn [DIM, SEQ]
// sp[SEQ:SEQ+HIDDEN]= W2^T [DIM, HIDDEN]
// Output: [1, HIDDEN, 1, SEQ] fp16 = dsilu_raw
// ffnBwdW2t: dffn @ W2 → dsilu_raw (IC=DIM, OC=HIDDEN)
static NSString *gen_ffn_bwd_w2t_dynamic(void) {
return gen_dyn_matmul_mil(DIM, HIDDEN, SEQ);
}
// FFN backward part 2: dh1 @ W1^T + dh3 @ W3^T → dx
// We need h1,h3 for SiLU derivative, but those are on CPU.
// Actually the SiLU derivative + gating is element-wise, do on CPU.
// Then: dh1 @ W1^T and dh3 @ W3^T are two separate matmuls (HIDDEN→DIM).
// Combine into one kernel:
// Input: [1, HIDDEN, 1, SEQ + SEQ + DIM + DIM] fp32
// sp[0:SEQ] = dh1 [HIDDEN,SEQ]
// sp[SEQ:2*SEQ] = dh3 [HIDDEN,SEQ]
// sp[2*SEQ:2*SEQ+DIM] = W1^T [HIDDEN,DIM]
// sp[2*SEQ+DIM:2*SEQ+2D] = W3^T [HIDDEN,DIM]
// Output: [1, DIM, 1, SEQ] fp16 = dx1 + dx3
// ffnBwdW13t: dh1 @ W1 + dh3 @ W3 → dx_ffn (IC=HIDDEN, two matmuls added)
static NSString *gen_ffn_bwd_w13t_dynamic(void) {
int sp_in = 2*SEQ + 2*DIM;
int sp_in = FFN_BWD_W13T_SP;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", HIDDEN, sp_in];
// Slice dh1 [HIDDEN, SEQ]
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sh = const()[name=string(\"sh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh1 = slice_by_size(x=x,begin=b0,size=sh)[name=string(\"dh1\")];\n", HIDDEN, SEQ];
// Slice dh3
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh3 = slice_by_size(x=x,begin=b1,size=sh)[name=string(\"dh3\")];\n", HIDDEN, SEQ];
// Slice W1^T [HIDDEN, DIM]
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1t = slice_by_size(x=x,begin=b2,size=sw)[name=string(\"W1t\")];\n", HIDDEN, DIM];
// Slice W3^T
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3t = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"W3t\")];\n", HIDDEN, DIM];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
// dh1 matmul: [S,H] @ [H,D] → [S,D]
[m appendFormat:@" tensor<int32, [4]> ra = const()[name=string(\"ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh12 = reshape(shape=ra,x=dh1)[name=string(\"dh12\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh1t = transpose(perm=pm,x=dh12)[name=string(\"dh1t\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh32 = reshape(shape=ra,x=dh3)[name=string(\"dh32\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh3t = transpose(perm=pm,x=dh32)[name=string(\"dh3t\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W1t2 = reshape(shape=rw,x=W1t)[name=string(\"W1t2\")];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W3t2 = reshape(shape=rw,x=W3t)[name=string(\"W3t2\")];\n", HIDDEN, DIM];
@ -488,10 +328,7 @@ static NSString *gen_ffn_bwd_w13t_dynamic(void) {
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dx1m = matmul(transpose_x=bF,transpose_y=bF,x=dh1t,y=W1t2)[name=string(\"dx1m\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dx3m = matmul(transpose_x=bF,transpose_y=bF,x=dh3t,y=W3t2)[name=string(\"dx3m\")];\n", SEQ, DIM];
// Add
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxm = add(x=dx1m,y=dx3m)[name=string(\"dxm\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
@ -499,40 +336,80 @@ static NSString *gen_ffn_bwd_w13t_dynamic(void) {
return m;
}
// ===== SDPA backward part 1 (dynamic Wo^T) =====
// Same as original gen_sdpa_bwd1 but Wo^T comes from input instead of const
// Input: [1, 4*DIM, 1, SEQ + DIM] fp32 — Q,K,V,dx2 in channels, Wo^T in spatial
// Wait — channels must match for all data. Q,K,V are [DIM,SEQ], dx2 is [DIM,SEQ].
// Total input channels = 4*DIM. But Wo^T is [DIM,DIM] = DIM channels of DIM spatial.
// Problem: can't mix 4*DIM channels for data with DIM channels for Wo^T.
// Solution: Wo^T matmul as separate kernel, then SDPA part purely element-wise on ANE.
// Wo^T matmul: dx2 @ Wo^T → da (DIM→DIM)
// wotBwd: dy @ Wo → da (IC=DIM, OC=Q_DIM)
static NSString *gen_wot_dynamic(void) {
return gen_dyn_matmul_mil(DIM, DIM, SEQ);
return gen_dyn_matmul_mil(DIM, Q_DIM, SEQ);
}
// SDPA backward part 1 (no weights, all data): Q,K,V,da → dV,probs,dp
// Same as original but without Wo^T conv (already done)
// Input: [1, 4*DIM, 1, SEQ] fp16
static NSString *gen_sdpa_bwd1_noweight(void) {
float sc = 1.0f/sqrtf((float)HD);
// qBwd: dq @ Wq → dx_q (IC=Q_DIM, OC=DIM)
static NSString *gen_q_bwd_dynamic(void) {
return gen_dyn_matmul_mil(Q_DIM, DIM, SEQ);
}
// kvBwd: dk @ Wk + dv @ Wv → dx_kv (IC=KV_DIM)
// Input: [1, KV_DIM, 1, 2*SEQ+2*DIM] fp16
// Same pattern as ffnBwdW13t but with KV_DIM channels
static NSString *gen_kv_bwd_dynamic(void) {
int sp_in = KV_BWD_SP;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", 4*DIM, SEQ];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", KV_DIM, sp_in];
// Slice Q,K,V,da
[m appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> sh = const()[name=string(\"sh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", KV_DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 3*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> da = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dk = slice_by_size(x=x,begin=b0,size=sh)[name=string(\"dk\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dv = slice_by_size(x=x,begin=b1,size=sh)[name=string(\"dv\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", KV_DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wkt = slice_by_size(x=x,begin=b2,size=sw)[name=string(\"Wkt\")];\n", KV_DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wvt = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"Wvt\")];\n", KV_DIM, DIM];
// Reshape to heads
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<int32, [4]> ra = const()[name=string(\"ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dk2 = reshape(shape=ra,x=dk)[name=string(\"dk2\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dkt = transpose(perm=pm,x=dk2)[name=string(\"dkt\")];\n", SEQ, KV_DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dv2 = reshape(shape=ra,x=dv)[name=string(\"dv2\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dvt = transpose(perm=pm,x=dv2)[name=string(\"dvt\")];\n", SEQ, KV_DIM];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", KV_DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wkt2 = reshape(shape=rw,x=Wkt)[name=string(\"Wkt2\")];\n", KV_DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wvt2 = reshape(shape=rw,x=Wvt)[name=string(\"Wvt2\")];\n", KV_DIM, DIM];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxk = matmul(transpose_x=bF,transpose_y=bF,x=dkt,y=Wkt2)[name=string(\"dxk\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxv = matmul(transpose_x=bF,transpose_y=bF,x=dvt,y=Wvt2)[name=string(\"dxv\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxm = add(x=dxk,y=dxv)[name=string(\"dxm\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
[m appendString:@" } -> (dx);\n}\n"];
return m;
}
// SDPA backward part 1: recompute attention + dV, dp
// Uses tiled K,V at HEADS dimension (CPU pre-tiles)
// Input: [1, 2*Q_DIM+2*Q_DIM, 1, SEQ] fp16 = (Q, K_tiled, V_tiled, da)
// Output: [1, Q_DIM+2*SCORE_CH, 1, SEQ] fp16 = (dV_full, probs, dp)
static NSString *gen_sdpa_bwd1_noweight(void) {
float sc = 1.0f/sqrtf((float)HD);
int in_ch = 4*Q_DIM; // Q + K_tiled + V_tiled + da, all at Q_DIM (HEADS*HD)
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", in_ch, SEQ];
// Slice Q,K_tiled,V_tiled,da — all [Q_DIM, SEQ]
[m appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", Q_DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", Q_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*Q_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 3*Q_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> da = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", Q_DIM, SEQ];
// Reshape to heads [1,HEADS,HD,SEQ] → [1,HEADS,SEQ,HD]
[m appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
@ -544,7 +421,7 @@ static NSString *gen_sdpa_bwd1_noweight(void) {
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dr = reshape(shape=rsh,x=da)[name=string(\"rd\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dat = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS, SEQ, HD];
// Forward attention scores (recompute)
// Recompute attention scores
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
@ -559,49 +436,57 @@ static NSString *gen_sdpa_bwd1_noweight(void) {
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=dat)[name=string(\"dv\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp4 = matmul(transpose_x=bF,transpose_y=bT,x=dat,y=v)[name=string(\"dp\")];\n", HEADS, SEQ, SEQ];
// Reshape dV back
// Reshape dV to [Q_DIM, SEQ] (will be reduced to KV_DIM on CPU)
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<int32, [4]> dvs = const()[name=string(\"dvs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> dvs = const()[name=string(\"dvs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", Q_DIM, SEQ];
// Flatten probs and dp for output
// Flatten probs and dp
[m appendFormat:@" tensor<int32, [4]> scs = const()[name=string(\"scs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH, SEQ];
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", DIM+2*SCORE_CH, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", Q_DIM+2*SCORE_CH, SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// SDPA backward part 2: same as original (no weights, pure computation)
// SDPA backward part 2: probs, dp, Q, K_tiled → dQ, dK_full
// Input: [1, 2*SCORE_CH + 2*Q_DIM, 1, SEQ]
// Output: [1, 2*Q_DIM, 1, SEQ] = (dQ, dK_full)
static NSString *gen_sdpa_bwd2(void) {
float sc = 1.0f/sqrtf((float)HD);
int bwd2_in = 2*SCORE_CH + 2*DIM;
int bwd2_in = 2*SCORE_CH + 2*Q_DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", bwd2_in, SEQ];
[m appendFormat:@" tensor<int32, [4]> sz_sc = const()[name=string(\"szsc\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", SCORE_CH];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<int32, [4]> sz_d = const()[name=string(\"szd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> sz_q = const()[name=string(\"szq\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b2,size=sz_q)[name=string(\"s2\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH+Q_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b3,size=sz_q)[name=string(\"s3\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ssh = const()[name=string(\"ssh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS, SEQ, HD];
// Softmax backward: ds = (dp - sum(dp*probs)) * probs * scale
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS, SEQ, SEQ];
[m appendString:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([-1])];\n"];
[m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
@ -610,92 +495,26 @@ static NSString *gen_sdpa_bwd2(void) {
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS, SEQ, SEQ];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<int32, [4]> fs = const()[name=string(\"fs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> fs = const()[name=string(\"fs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", Q_DIM, SEQ];
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*Q_DIM, SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// QKV backward (dynamic): dq @ Wq^T + dk @ Wk^T + dv @ Wv^T → dx
// Input: [1, DIM, 1, 3*SEQ + 3*DIM] fp32
// sp[0:SEQ] = dq [DIM,SEQ]
// sp[SEQ:2*SEQ] = dk [DIM,SEQ]
// sp[2*SEQ:3*SEQ] = dv [DIM,SEQ]
// sp[3*SEQ:3*SEQ+DIM] = Wq^T [DIM,DIM]
// sp[3*SEQ+DIM:3*SEQ+2D] = Wk^T [DIM,DIM]
// sp[3*SEQ+2D:3*SEQ+3D] = Wv^T [DIM,DIM]
// Output: [1, DIM, 1, SEQ] fp32 = dxq + dxk + dxv
static NSString *gen_qkvb_dynamic(void) {
int sp_in = 3*SEQ + 3*DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
// Slice dq, dk, dv
[m appendFormat:@" tensor<int32, [4]> sd = const()[name=string(\"sd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dq = slice_by_size(x=x,begin=b0,size=sd)[name=string(\"dq\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dk = slice_by_size(x=x,begin=b1,size=sd)[name=string(\"dk\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dv = slice_by_size(x=x,begin=b2,size=sd)[name=string(\"dv\")];\n", DIM, SEQ];
// Slice Wq^T, Wk^T, Wv^T
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wqt = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"Wqt\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b4 = const()[name=string(\"b4\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wkt = slice_by_size(x=x,begin=b4,size=sw)[name=string(\"Wkt\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b5 = const()[name=string(\"b5\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wvt = slice_by_size(x=x,begin=b5,size=sw)[name=string(\"Wvt\")];\n", DIM, DIM];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
// Reshape and matmul for each
[m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, DIM];
// dq @ Wq^T
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dq2 = reshape(shape=rd,x=dq)[name=string(\"dq2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dqt = transpose(perm=pm,x=dq2)[name=string(\"dqt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wqt2 = reshape(shape=rw,x=Wqt)[name=string(\"Wqt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxq = matmul(transpose_x=bF,transpose_y=bF,x=dqt,y=Wqt2)[name=string(\"dxq\")];\n", SEQ, DIM];
// dk @ Wk^T
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dk2 = reshape(shape=rd,x=dk)[name=string(\"dk2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dkt = transpose(perm=pm,x=dk2)[name=string(\"dkt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wkt2 = reshape(shape=rw,x=Wkt)[name=string(\"Wkt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxk = matmul(transpose_x=bF,transpose_y=bF,x=dkt,y=Wkt2)[name=string(\"dxk\")];\n", SEQ, DIM];
// dv @ Wv^T
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dv2 = reshape(shape=rd,x=dv)[name=string(\"dv2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dvt = transpose(perm=pm,x=dv2)[name=string(\"dvt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wvt2 = reshape(shape=rw,x=Wvt)[name=string(\"Wvt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxv = matmul(transpose_x=bF,transpose_y=bF,x=dvt,y=Wvt2)[name=string(\"dxv\")];\n", SEQ, DIM];
// Sum: dxq + dxk + dxv
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxall = add(x=dxqk,y=dxv)[name=string(\"aall\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxall)[name=string(\"dxt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
[m appendString:@" } -> (dx);\n}\n"];
return m;
}
// Causal mask blob (used by sdpa_fwd and sdpa_bwd1)
// Causal mask blob
static NSData *g_mask_blob = nil;
static NSData *get_mask_blob(void) {
if (!g_mask_blob) {
@ -708,7 +527,7 @@ static NSData *get_mask_blob(void) {
return g_mask_blob;
}
// RoPE cos/sin blobs [1, 1, SEQ, HD] — rotary position encodings
// RoPE cos/sin blobs [1, 1, SEQ, HD]
static NSData *g_rope_cos_blob = nil;
static NSData *g_rope_sin_blob = nil;

View File

@ -0,0 +1,19 @@
// qwen3_06b.h — Qwen3-0.6B (28 layers, GQA 16q/8kv, head_dim=128)
#pragma once
#define MODEL_NAME "Qwen3-0.6B"
#define DIM 1024
#define HIDDEN 3072
#define HEADS 16
#define KV_HEADS 8
#define HD 128 // explicit head_dim (NOT DIM/HEADS)
#define GQA_RATIO (HEADS / KV_HEADS) // = 2
#define Q_DIM (HEADS * HD) // = 2048
#define KV_DIM (KV_HEADS * HD) // = 1024 (= DIM for this model)
#define SEQ 256
#define NLAYERS 28
#define VOCAB 151936
#define CKPT_PATH "ane_qwen3_06b_dyn_ckpt.bin"
#define DEFAULT_DATA_PATH "../tinystories_data00.bin"

View File

@ -0,0 +1,19 @@
// stories110m.h — Stories110M (Llama2-style, 12 layers, MHA)
#pragma once
#define MODEL_NAME "Stories110M"
#define DIM 768
#define HIDDEN 2048
#define HEADS 12
#define KV_HEADS 12
#define HD (DIM/HEADS) // = 64
#define GQA_RATIO 1 // MHA: no GQA
#define Q_DIM (HEADS * HD) // = 768 = DIM
#define KV_DIM (KV_HEADS * HD) // = 768 = DIM
#define SEQ 256
#define NLAYERS 12
#define VOCAB 32000
#define CKPT_PATH "ane_stories110M_dyn_ckpt.bin"
#define DEFAULT_DATA_PATH "../tinystories_data00.bin"

View File

@ -1,53 +1,23 @@
// train.m Dynamic weight ANE training for Stories110M
// train.m Dynamic weight ANE training (model-agnostic GQA support)
// Model selected at compile time via: make MODEL=qwen3_06b (or stories110m)
// Compile kernels ONCE at startup, update weights via IOSurface every step.
// No exec() restart needed eliminates 76% compile overhead.
#include "mil_dynamic.h"
#include "cpu_ops.h"
#define CKPT_PATH "ane_stories110M_dyn_ckpt.bin"
#define MODEL_PATH "../../../assets/models/stories110M.bin"
#define DEFAULT_DATA_PATH "../tinystories_data00.bin"
// Dynamic kernel set per layer
typedef struct {
Kern *sdpaFwd; // QKV matmul + SDPA + Wo matmul (dynamic weights via IOSurface)
Kern *ffnFused; // residual + RMSNorm + W1,W3 + SiLU + W2 + residual (fused)
Kern *ffnBwdW2t; // dffn @ W2^T (dynamic)
Kern *ffnBwdW13t; // dh1@W1^T + dh3@W3^T (dynamic)
Kern *wotBwd; // dx2 @ Wo^T (dynamic)
Kern *sdpaBwd1; // Q,K,V,da dV,probs,dp (weight-free, has mask const)
Kern *sdpaBwd2; // probs,dp,Q,K dQ,dK (weight-free)
Kern *qkvBwd; // dq@Wq^T + dk@Wk^T + dv@Wv^T (dynamic)
Kern *sdpaFwd; // QKV matmul + RoPE + GQA tile + SDPA (no Wo)
Kern *woFwd; // attn_out @ Wo^T o_out (Q_DIM DIM)
Kern *ffnFused; // W1,W3 + SiLU + W2 + residual (fused)
Kern *ffnBwdW2t; // dffn @ W2^T dsilu_raw (DIM HIDDEN)
Kern *ffnBwdW13t; // dh1@W1^T + dh3@W3^T dx_ffn (HIDDEN DIM)
Kern *wotBwd; // dx2 @ Wo da (DIM Q_DIM)
Kern *sdpaBwd1; // Q,K,V,da dV_full,probs,dp (weight-free, has mask)
Kern *sdpaBwd2; // probs,dp,Q,K dQ,dK_full (weight-free)
Kern *qBwd; // dq @ Wq dx_q (Q_DIM DIM)
Kern *kvBwd; // dk@Wk + dv@Wv dx_kv (KV_DIM DIM)
} DynLayerKernels;
// ===== Weight loading from llama2.c format =====
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
FILE *f = fopen(path, "rb");
if (!f) { printf("Cannot open %s\n", path); return false; }
Llama2Config cfg;
fread(&cfg, sizeof(cfg), 1, f);
printf(" Model: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);
if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) {
printf(" ERROR: Config mismatch!\n"); fclose(f); return false;
}
int V = abs(cfg.vocab_size);
fread(embed, 4, V * DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_att, 4, DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wq, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wk, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wv, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wo, 4, WO_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_ffn, 4, DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W1, 4, W1_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W2, 4, W2_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W3, 4, W3_SZ, f);
fread(rms_final, 4, DIM, f);
fclose(f);
printf(" Loaded pretrained weights\n");
return true;
}
// Transpose W[rows,cols] W^T[cols,rows] stored as [cols channels, rows spatial]
static void transpose_weight(float *dst, const float *src, int rows, int cols) {
for (int r = 0; r < rows; r++)
@ -64,76 +34,72 @@ static bool compile_dynamic_kernels(DynLayerKernels *dk) {
@"@model_path/weights/rope_sin.bin": @{@"offset":@0, @"data":get_rope_sin_blob()}
};
// SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp16 [1, 6*DIM, 1, SEQ] fp16
printf(" Compiling sdpaFwd...\n");
int sdpa_out_ch = Q_DIM + Q_DIM + KV_DIM + KV_DIM + DIM;
// SDPA forward (no Wo): [1, DIM, 1, SDPA_FWD_SP] [1, sdpa_out_ch, 1, SEQ]
printf(" Compiling sdpaFwd (GQA)...\n");
dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), sdpa_fwd_w,
DIM*(SEQ+4*DIM)*2, 6*DIM*SEQ*2);
DIM*SDPA_FWD_SP*2, sdpa_out_ch*SEQ*2);
if (!dk->sdpaFwd) return false;
// Fused FFN: W1,W3 + SiLU + W2 + residual (RMSNorm on CPU)
// Wo forward: [1, Q_DIM, 1, SEQ+DIM] [1, DIM, 1, SEQ]
printf(" Compiling woFwd...\n");
dk->woFwd = compile_kern_mil_w(gen_wo_fwd_dynamic(), @{},
Q_DIM*WO_FWD_SP*2, DIM*SEQ*2);
if (!dk->woFwd) return false;
// Fused FFN: [1, DIM, 1, FFN_FUSED_SP] [1, DIM+3*HIDDEN, 1, SEQ]
printf(" Compiling ffnFused...\n");
int ffn_fused_sp = 2*SEQ + 3*HIDDEN;
int ffn_fused_och = DIM + 3*HIDDEN;
dk->ffnFused = compile_kern_mil_w(gen_ffn_fused_dynamic(), @{},
DIM*ffn_fused_sp*2, ffn_fused_och*SEQ*2);
DIM*FFN_FUSED_SP*2, ffn_fused_och*SEQ*2);
if (!dk->ffnFused) return false;
// FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] fp16 [1, HIDDEN, 1, SEQ] fp16
// FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] [1, HIDDEN, 1, SEQ]
printf(" Compiling ffnBwdW2t...\n");
dk->ffnBwdW2t = compile_kern_mil_w(gen_ffn_bwd_w2t_dynamic(), @{},
DIM*(SEQ+HIDDEN)*2, HIDDEN*SEQ*2);
DIM*FFN_BWD_W2T_SP*2, HIDDEN*SEQ*2);
if (!dk->ffnBwdW2t) return false;
// FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp16 [1, DIM, 1, SEQ] fp16
// FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] [1, DIM, 1, SEQ]
printf(" Compiling ffnBwdW13t...\n");
dk->ffnBwdW13t = compile_kern_mil_w(gen_ffn_bwd_w13t_dynamic(), @{},
HIDDEN*(2*SEQ+2*DIM)*2, DIM*SEQ*2);
HIDDEN*FFN_BWD_W13T_SP*2, DIM*SEQ*2);
if (!dk->ffnBwdW13t) return false;
// Wo^T backward: [1, DIM, 1, SEQ+DIM] fp16 [1, DIM, 1, SEQ] fp16
// Wo^T backward: [1, DIM, 1, SEQ+Q_DIM] [1, Q_DIM, 1, SEQ]
printf(" Compiling wotBwd...\n");
dk->wotBwd = compile_kern_mil_w(gen_wot_dynamic(), @{},
DIM*(SEQ+DIM)*2, DIM*SEQ*2);
DIM*WOT_BWD_SP*2, Q_DIM*SEQ*2);
if (!dk->wotBwd) return false;
// SDPA bwd1 (no dynamic weights, has mask): [1, 4*DIM, 1, SEQ] fp16 [1, DIM+2*SCORE_CH, 1, SEQ] fp16
printf(" Compiling sdpaBwd1...\n");
// SDPA bwd1 (weight-free, has mask): [1, 4*Q_DIM, 1, SEQ] [1, Q_DIM+2*SCORE_CH, 1, SEQ]
printf(" Compiling sdpaBwd1 (GQA)...\n");
dk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1_noweight(), mask_w,
4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);
4*Q_DIM*SEQ*2, (Q_DIM+2*SCORE_CH)*SEQ*2);
if (!dk->sdpaBwd1) return false;
// SDPA bwd2 (no weights): [1, 2*SCORE_CH+2*DIM, 1, SEQ] fp16 [1, 2*DIM, 1, SEQ] fp16
printf(" Compiling sdpaBwd2...\n");
// SDPA bwd2 (weight-free): [1, 2*SCORE_CH+2*Q_DIM, 1, SEQ] [1, 2*Q_DIM, 1, SEQ]
printf(" Compiling sdpaBwd2 (GQA)...\n");
dk->sdpaBwd2 = compile_kern_mil_w(gen_sdpa_bwd2(), @{},
(2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
(2*SCORE_CH+2*Q_DIM)*SEQ*2, 2*Q_DIM*SEQ*2);
if (!dk->sdpaBwd2) return false;
// QKV backward: [1, DIM, 1, 3*SEQ+3*DIM] fp16 [1, DIM, 1, SEQ] fp16
printf(" Compiling qkvBwd...\n");
dk->qkvBwd = compile_kern_mil_w(gen_qkvb_dynamic(), @{},
DIM*(3*SEQ+3*DIM)*2, DIM*SEQ*2);
if (!dk->qkvBwd) return false;
// Q backward: [1, Q_DIM, 1, SEQ+DIM] [1, DIM, 1, SEQ]
printf(" Compiling qBwd...\n");
dk->qBwd = compile_kern_mil_w(gen_q_bwd_dynamic(), @{},
Q_DIM*Q_BWD_SP*2, DIM*SEQ*2);
if (!dk->qBwd) return false;
// KV backward: [1, KV_DIM, 1, 2*SEQ+2*DIM] [1, DIM, 1, SEQ]
printf(" Compiling kvBwd...\n");
dk->kvBwd = compile_kern_mil_w(gen_kv_bwd_dynamic(), @{},
KV_DIM*KV_BWD_SP*2, DIM*SEQ*2);
if (!dk->kvBwd) return false;
return true;
}
// ===== Write dynamic weights into IOSurface =====
// sdpaFwd: [1, DIM, 1, SEQ+4*DIM] xnorm at sp[0:S], Wq/Wk/Wv/Wo at sp[S:]
static void write_sdpa_fwd_input(DynLayerKernels *dk, const float *xnorm,
const float *Wq, const float *Wk, const float *Wv, const float *Wo) {
IOSurfaceLock(dk->sdpaFwd->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk->sdpaFwd->ioIn);
int sp = SEQ + 4*DIM;
for (int d = 0; d < DIM; d++) {
memcpy(buf + d*sp, xnorm + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, Wq + d*DIM, DIM*4);
memcpy(buf + d*sp + SEQ+DIM, Wk + d*DIM, DIM*4);
memcpy(buf + d*sp + SEQ+2*DIM, Wv + d*DIM, DIM*4);
memcpy(buf + d*sp + SEQ+3*DIM, Wo + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk->sdpaFwd->ioIn, 0, NULL);
}
// ===== Checkpoint =====
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
double ct, double cw, int cs, int adam_t,
@ -141,21 +107,22 @@ static void save_checkpoint(const char *path, int step, int total_steps, float l
float *embed, AdamState *aembed) {
FILE *f = fopen(path, "wb");
CkptHdr h = {0};
h.magic = 0x424C5A54; h.version = 3;
h.magic = 0x424C5A54; h.version = 4;
h.step = step; h.total_steps = total_steps;
h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
h.lr = lr; h.loss = loss;
h.cum_train = ct; h.cum_wall = cw; h.cum_steps = cs; h.adam_t = adam_t;
h.kv_heads = KV_HEADS; h.head_dim = HD; h.q_dim = Q_DIM;
fwrite(&h, sizeof(h), 1, f);
for (int L = 0; L < NLAYERS; L++) {
fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f);
fwrite(lw[L].Wv,4,WQ_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f);
fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WK_SZ,f);
fwrite(lw[L].Wv,4,WV_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f);
fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f);
fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f);
fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f);
fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f);
fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f);
fwrite(la[L].Wk.m,4,WK_SZ,f); fwrite(la[L].Wk.v,4,WK_SZ,f);
fwrite(la[L].Wv.m,4,WV_SZ,f); fwrite(la[L].Wv.v,4,WV_SZ,f);
fwrite(la[L].Wo.m,4,WO_SZ,f); fwrite(la[L].Wo.v,4,WO_SZ,f);
fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f);
fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f);
@ -178,17 +145,17 @@ static bool load_checkpoint(const char *path, int *step, int *total_steps, float
if (!f) return false;
CkptHdr h;
fread(&h, sizeof(h), 1, f);
if (h.magic != 0x424C5A54 || h.version != 3) { fclose(f); return false; }
if (h.magic != 0x424C5A54 || h.version != 4) { fclose(f); return false; }
*step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
*ct = h.cum_train; *cw = h.cum_wall; *cs = h.cum_steps; *adam_t = h.adam_t;
for (int L = 0; L < NLAYERS; L++) {
fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WQ_SZ,f);
fread(lw[L].Wv,4,WQ_SZ,f); fread(lw[L].Wo,4,WO_SZ,f);
fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WK_SZ,f);
fread(lw[L].Wv,4,WV_SZ,f); fread(lw[L].Wo,4,WO_SZ,f);
fread(lw[L].W1,4,W1_SZ,f); fread(lw[L].W2,4,W2_SZ,f); fread(lw[L].W3,4,W3_SZ,f);
fread(lw[L].rms_att,4,DIM,f); fread(lw[L].rms_ffn,4,DIM,f);
fread(la[L].Wq.m,4,WQ_SZ,f); fread(la[L].Wq.v,4,WQ_SZ,f);
fread(la[L].Wk.m,4,WQ_SZ,f); fread(la[L].Wk.v,4,WQ_SZ,f);
fread(la[L].Wv.m,4,WQ_SZ,f); fread(la[L].Wv.v,4,WQ_SZ,f);
fread(la[L].Wk.m,4,WK_SZ,f); fread(la[L].Wk.v,4,WK_SZ,f);
fread(la[L].Wv.m,4,WV_SZ,f); fread(la[L].Wv.v,4,WV_SZ,f);
fread(la[L].Wo.m,4,WO_SZ,f); fread(la[L].Wo.v,4,WO_SZ,f);
fread(la[L].W1.m,4,W1_SZ,f); fread(la[L].W1.v,4,W1_SZ,f);
fread(la[L].W2.m,4,W2_SZ,f); fread(la[L].W2.v,4,W2_SZ,f);
@ -217,9 +184,9 @@ int main(int argc, char *argv[]) {
int accum_steps = 10;
int warmup_steps = 100;
float grad_clip = 1.0f;
float loss_scale = 256.0f; // fp16 loss scaling for ANE backward
float res_alpha = 1.0f / sqrtf(2.0f * NLAYERS); // residual scaling (DeepNet-style)
float min_lr_frac = 0.1f; // min_lr = max_lr * 0.1
float loss_scale = 256.0f;
float res_alpha = 1.0f / sqrtf(2.0f * NLAYERS);
float min_lr_frac = 0.1f;
bool do_resume = false, from_scratch = false;
const char *data_path = DEFAULT_DATA_PATH;
@ -259,29 +226,28 @@ int main(int argc, char *argv[]) {
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
}
if (!resuming) {
printf("=== ANE Dynamic Training: Stories110M (12 layers) ===\n");
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
// Param counts for dashboard
double xformer_m = (double)NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ + 2.0*DIM) / 1e6;
printf("=== ANE Dynamic Training: %s (%d layers, GQA %d/%d heads) ===\n",
MODEL_NAME, NLAYERS, HEADS, KV_HEADS);
printf("dim=%d q_dim=%d kv_dim=%d hd=%d hidden=%d seq=%d vocab=%d\n",
DIM, Q_DIM, KV_DIM, HD, HIDDEN, SEQ, VOCAB);
double xformer_m = (double)NLAYERS*(WQ_SZ + WK_SZ + WV_SZ + (double)WO_SZ + W1_SZ + W2_SZ + W3_SZ + 2.0*DIM) / 1e6;
double embed_m = (double)VOCAB*DIM / 1e6;
printf("Params: %.1fM (transformer %.1fM + embed %.1fM)\n", xformer_m+embed_m, xformer_m, embed_m);
printf("Kernels: 8 compiled (ffnFused replaces ffnW13+ffnW2, RMSNorm on CPU)\n");
printf("Kernels: 10 compiled (sdpaFwd+woFwd, ffnFused, ffnBwdW2t+W13t, wotBwd, sdpaBwd1+2, qBwd+kvBwd)\n");
printf("Accum %d steps, LR=%g\n", accum_steps, max_lr);
// FLOPs estimate: 6*N*B*T for transformer (forward+backward 3x forward)
double fwd_flops = 2.0*NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ) * SEQ;
double total_flops = 3.0 * fwd_flops; // fwd + bwd 3x fwd
printf("FLOPs/step: fwd=%.1fM bwd_dx=%.1fM bwd_dW=%.1fM sdpa_bwd=0.0M total=%.1fM\n",
fwd_flops/1e6, fwd_flops/1e6, fwd_flops/1e6, total_flops/1e6);
printf("ANE FLOPs/step: %.1fM\n", total_flops/1e6);
if (from_scratch || !load_pretrained(lw, rms_final, embed, MODEL_PATH)) {
if (from_scratch) printf(" Training from scratch (random init)\n");
else printf(" Pretrained load failed, using random init\n");
double fwd_flops = 2.0*NLAYERS*((double)WQ_SZ + WK_SZ + WV_SZ + WO_SZ + W1_SZ + W2_SZ + W3_SZ) * SEQ;
double total_flops = 3.0 * fwd_flops;
printf("FLOPs/step: fwd=%.1fM total=%.1fM\n", fwd_flops/1e6, total_flops/1e6);
if (from_scratch) {
printf(" Training from scratch (random init)\n");
srand48(42);
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
float res_scale = 1.0f/sqrtf(2.0f*NLAYERS); // LLaMA-style output proj scaling
float scale_d=1.0f/sqrtf(DIM), scale_qd=1.0f/sqrtf(Q_DIM), scale_h=1.0f/sqrtf(HIDDEN);
float res_scale = 1.0f/sqrtf(2.0f*NLAYERS);
for (int L=0; L<NLAYERS; L++) {
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*res_scale*(2*drand48()-1);}
for(size_t i=0;i<WQ_SZ;i++) lw[L].Wq[i]=scale_d*(2*drand48()-1);
for(size_t i=0;i<WK_SZ;i++) lw[L].Wk[i]=scale_d*(2*drand48()-1);
for(size_t i=0;i<WV_SZ;i++) lw[L].Wv[i]=scale_d*(2*drand48()-1);
for(size_t i=0;i<WO_SZ;i++) lw[L].Wo[i]=scale_qd*res_scale*(2*drand48()-1);
for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*res_scale*(2*drand48()-1);
for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
@ -290,22 +256,31 @@ int main(int argc, char *argv[]) {
for(int i=0;i<DIM;i++) rms_final[i]=1.0f;
float escale = 0.02f;
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) embed[i]=escale*(2*drand48()-1);
} else {
printf(" ERROR: Pretrained weight loading not implemented for Qwen3. Use --scratch.\n");
return 1;
}
}
// Precompute transposed weights (for backward pass kernels)
// These get updated after each Adam step
// Precompute transposed weights for forward/backward kernels
// Forward: sdpaFwd needs Wq^T[Q_DIM,DIM], Wk^T[KV_DIM,DIM], Wv^T[KV_DIM,DIM]
// woFwd needs Wo^T[DIM,Q_DIM]
// Backward uses original (non-transposed) weights
float *Wqt_buf[NLAYERS], *Wkt_buf[NLAYERS], *Wvt_buf[NLAYERS], *Wot_buf[NLAYERS];
float *W1t_buf[NLAYERS], *W2t_buf[NLAYERS], *W3t_buf[NLAYERS];
for (int L=0; L<NLAYERS; L++) {
Wqt_buf[L]=(float*)malloc(WQ_SZ*4); Wkt_buf[L]=(float*)malloc(WQ_SZ*4);
Wvt_buf[L]=(float*)malloc(WQ_SZ*4); Wot_buf[L]=(float*)malloc(WO_SZ*4);
Wqt_buf[L]=(float*)malloc(WQ_SZ*4); Wkt_buf[L]=(float*)malloc(WK_SZ*4);
Wvt_buf[L]=(float*)malloc(WV_SZ*4); Wot_buf[L]=(float*)malloc(WO_SZ*4);
W1t_buf[L]=(float*)malloc(W1_SZ*4); W2t_buf[L]=(float*)malloc(W2_SZ*4);
W3t_buf[L]=(float*)malloc(W3_SZ*4);
transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM);
transpose_weight(Wkt_buf[L], lw[L].Wk, DIM, DIM);
transpose_weight(Wvt_buf[L], lw[L].Wv, DIM, DIM);
transpose_weight(Wot_buf[L], lw[L].Wo, DIM, DIM);
// Wq is [Q_DIM, DIM] Wq^T is [DIM, Q_DIM] (staged as [DIM channels, Q_DIM spatial])
transpose_weight(Wqt_buf[L], lw[L].Wq, Q_DIM, DIM);
// Wk is [KV_DIM, DIM] Wk^T is [DIM, KV_DIM]
transpose_weight(Wkt_buf[L], lw[L].Wk, KV_DIM, DIM);
// Wv is [KV_DIM, DIM] Wv^T is [DIM, KV_DIM]
transpose_weight(Wvt_buf[L], lw[L].Wv, KV_DIM, DIM);
// Wo is [DIM, Q_DIM] Wo^T is [Q_DIM, DIM]
transpose_weight(Wot_buf[L], lw[L].Wo, DIM, Q_DIM);
transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM);
transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN);
transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM);
@ -321,69 +296,72 @@ int main(int argc, char *argv[]) {
size_t n_tokens = data_len / 2;
printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
// Vocab compaction: map 32K sparse vocab ~9K compact
// Vocab compaction
VocabMap vm = vocab_map_build(token_data, n_tokens, VOCAB);
int CV = vm.compact_vocab;
printf("Vocab compaction: %d → %d active tokens (%.1fx reduction)\n", VOCAB, CV, (float)VOCAB/CV);
// Create compact embedding + adam state
float *cembed = vocab_compact_embed(embed, &vm, DIM);
float *gcembed = (float*)calloc((size_t)CV*DIM, 4);
AdamState acembed = adam_alloc((size_t)CV*DIM);
// ===== Compile all kernels ONCE =====
printf("Compiling %d dynamic kernels (one-time)...\n", 8);
printf("Compiling 10 dynamic kernels (one-time)...\n");
uint64_t tc = mach_absolute_time();
DynLayerKernels dk;
if (!compile_dynamic_kernels(&dk)) {
printf("Compilation failed!\n"); return 1;
}
double compile_ms = tb_ms(mach_absolute_time() - tc);
printf("Compiled 9 kernels in %.0fms (shared across all %d layers)\n", compile_ms, NLAYERS);
printf("Compiled 10 kernels in %.0fms (shared across all %d layers)\n", compile_ms, NLAYERS);
// Allocate per-layer IOSurfaces + requests (pre-stage weights)
int per_layer_bytes = (DIM*(SEQ+4*DIM) + DIM*(2*SEQ+3*HIDDEN) +
DIM*(SEQ+HIDDEN) + HIDDEN*(2*SEQ+2*DIM) + DIM*(SEQ+DIM) + DIM*(3*SEQ+3*DIM)) * 2;
int total_surf_mb = (int)((long)per_layer_bytes * NLAYERS / (1024*1024));
printf("Allocating per-layer IOSurfaces (%d surfaces, ~%dMB fp16)...\n", NLAYERS*6, total_surf_mb);
// Allocate per-layer IOSurfaces + requests
printf("Allocating per-layer IOSurfaces...\n");
PerLayerSurfaces pls[NLAYERS];
PerLayerRequests plr[NLAYERS];
for (int L = 0; L < NLAYERS; L++) {
pls[L].sdpaFwd_in = make_surface(DIM*(SEQ+4*DIM)*2);
pls[L].ffnFused_in = make_surface(DIM*(2*SEQ+3*HIDDEN)*2);
pls[L].ffnBwdW2t_in = make_surface(DIM*(SEQ+HIDDEN)*2);
pls[L].ffnBwdW13t_in= make_surface(HIDDEN*(2*SEQ+2*DIM)*2);
pls[L].wotBwd_in = make_surface(DIM*(SEQ+DIM)*2);
pls[L].qkvBwd_in = make_surface(DIM*(3*SEQ+3*DIM)*2);
pls[L].sdpaFwd_in = make_surface(DIM*SDPA_FWD_SP*2);
pls[L].woFwd_in = make_surface(Q_DIM*WO_FWD_SP*2);
pls[L].ffnFused_in = make_surface(DIM*FFN_FUSED_SP*2);
pls[L].ffnBwdW2t_in = make_surface(DIM*FFN_BWD_W2T_SP*2);
pls[L].ffnBwdW13t_in = make_surface(HIDDEN*FFN_BWD_W13T_SP*2);
pls[L].wotBwd_in = make_surface(DIM*WOT_BWD_SP*2);
pls[L].qBwd_in = make_surface(Q_DIM*Q_BWD_SP*2);
pls[L].kvBwd_in = make_surface(KV_DIM*KV_BWD_SP*2);
plr[L].sdpaFwd = make_request(dk.sdpaFwd, pls[L].sdpaFwd_in);
plr[L].woFwd = make_request(dk.woFwd, pls[L].woFwd_in);
plr[L].ffnFused = make_request(dk.ffnFused, pls[L].ffnFused_in);
plr[L].ffnBwdW2t = make_request(dk.ffnBwdW2t, pls[L].ffnBwdW2t_in);
plr[L].ffnBwdW13t= make_request(dk.ffnBwdW13t,pls[L].ffnBwdW13t_in);
plr[L].wotBwd = make_request(dk.wotBwd, pls[L].wotBwd_in);
plr[L].qkvBwd = make_request(dk.qkvBwd, pls[L].qkvBwd_in);
plr[L].qBwd = make_request(dk.qBwd, pls[L].qBwd_in);
plr[L].kvBwd = make_request(dk.kvBwd, pls[L].kvBwd_in);
}
// Stage weights into per-layer surfaces
for (int L = 0; L < NLAYERS; L++) {
stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]);
stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L]);
stage_wo_fwd_weights(pls[L].woFwd_in, Wot_buf[L]);
stage_ffn_fused_weights(pls[L].ffnFused_in, W1t_buf[L], W3t_buf[L], lw[L].W2);
stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, lw[L].W2);
stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, lw[L].W1, lw[L].W3);
stage_wot_bwd_weights(pls[L].wotBwd_in, lw[L].Wo);
stage_qkv_bwd_weights(pls[L].qkvBwd_in, lw[L].Wq, lw[L].Wk, lw[L].Wv);
stage_q_bwd_weights(pls[L].qBwd_in, lw[L].Wq);
stage_kv_bwd_weights(pls[L].kvBwd_in, lw[L].Wk, lw[L].Wv);
}
printf("Per-layer weight staging complete\n\n");
// Gradient + work buffers
// Gradient + work buffers (GQA: Q has Q_DIM, K/V have KV_DIM)
float *dy = (float*)malloc(SEQ*DIM*4);
float *dffn = (float*)malloc(SEQ*DIM*4);
float *dx_ffn = (float*)malloc(SEQ*DIM*4);
float *dx2 = (float*)malloc(SEQ*DIM*4);
float *dx_attn = (float*)malloc(SEQ*DIM*4);
float *dq = (float*)malloc(SEQ*DIM*4);
float *dk_buf = (float*)malloc(SEQ*DIM*4);
float *dv = (float*)malloc(SEQ*DIM*4);
float *dq = (float*)malloc(SEQ*Q_DIM*4); // Q_DIM for Q grads
float *dk_buf = (float*)malloc(SEQ*KV_DIM*4); // KV_DIM for K grads
float *dv = (float*)malloc(SEQ*KV_DIM*4); // KV_DIM for V grads
float *da_buf = (float*)malloc(SEQ*Q_DIM*4); // Q_DIM for attn grads
float *x_cur = (float*)malloc(SEQ*DIM*4);
float *x_final = (float*)malloc(SEQ*DIM*4);
float *xnorm_buf = (float*)malloc(SEQ*DIM*4);
@ -395,6 +373,12 @@ int main(int argc, char *argv[]) {
float *dsilu = (float*)malloc(SEQ*HIDDEN*4);
float *silu_tmp = (float*)malloc(SEQ*HIDDEN*4);
float *silu_tmp2 = (float*)malloc(SEQ*HIDDEN*4);
// GQA tile/reduce buffers
float *k_tiled = (float*)malloc(SEQ*Q_DIM*4); // KV_DIM Q_DIM
float *v_tiled = (float*)malloc(SEQ*Q_DIM*4);
float *dq_full = (float*)malloc(SEQ*Q_DIM*4); // from sdpaBwd2
float *dk_full = (float*)malloc(SEQ*Q_DIM*4); // from sdpaBwd2 (needs reduce)
float *dv_full = (float*)malloc(SEQ*Q_DIM*4); // from sdpaBwd1 (needs reduce)
dispatch_queue_t dw_q = dispatch_queue_create("dw_cblas", DISPATCH_QUEUE_SERIAL);
dispatch_group_t dw_grp = dispatch_group_create();
@ -414,18 +398,15 @@ int main(int argc, char *argv[]) {
uint16_t *input_tokens = token_data + pos;
uint16_t *target_tokens_raw = token_data + pos + 1;
// Map targets to compact vocab IDs
uint16_t ctargets[SEQ];
for (int t = 0; t < SEQ; t++) ctargets[t] = (uint16_t)vm.full_to_compact[target_tokens_raw[t]];
// Embedding lookup (uses full embed for now input tokens are full IDs)
embed_lookup(x_cur, embed, input_tokens, DIM, SEQ);
// Timing accumulators (reset each step)
double t_rms=0, t_ane_fwd=0, t_io_fwd=0, t_cblas_wait=0;
double t_ane_bwd=0, t_io_bwd=0, t_silu=0, t_rms_bwd=0, t_cls=0, t_dw_copy=0;
// ===== FORWARD (12 layers) =====
// ===== FORWARD (28 layers) =====
for (int L=0; L<NLAYERS; L++) {
LayerActs *ac = &acts[L];
memcpy(ac->layer_in, x_cur, SEQ*DIM*4);
@ -441,7 +422,7 @@ int main(int argc, char *argv[]) {
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
t_cblas_wait += tb_ms(mach_absolute_time() - t0);
// SDPA forward (ANE): xnorm + pre-staged Wq,Wk,Wv,Wo o_out,Q,K,V,attn_out,xnorm
// SDPA forward (ANE): xnorm + Wq,Wk,Wv attn_out[Q_DIM], Q_rope[Q_DIM], K_rope[KV_DIM], V[KV_DIM], xnorm[DIM]
t0 = mach_absolute_time();
write_sdpa_fwd_acts(pls[L].sdpaFwd_in, xnorm_buf);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
@ -449,28 +430,37 @@ int main(int argc, char *argv[]) {
ane_eval_req(dk.sdpaFwd, plr[L].sdpaFwd);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
// Read output: [1, 6*DIM, 1, SEQ] fp16
// Read SDPA output: [1, Q_DIM+Q_DIM+KV_DIM+KV_DIM+DIM, 1, SEQ] fp16
t0 = mach_absolute_time();
IOSurfaceLock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL);
_Float16 *fwd_out = (_Float16*)IOSurfaceGetBaseAddress(dk.sdpaFwd->ioOut);
cvt_f16_f32(ac->o_out, fwd_out + 0*DIM*SEQ, DIM*SEQ);
cvt_f16_f32(ac->Q, fwd_out + 1*DIM*SEQ, DIM*SEQ);
cvt_f16_f32(ac->K, fwd_out + 2*DIM*SEQ, DIM*SEQ);
cvt_f16_f32(ac->V, fwd_out + 3*DIM*SEQ, DIM*SEQ);
cvt_f16_f32(ac->attn_out, fwd_out + 4*DIM*SEQ, DIM*SEQ);
int off = 0;
cvt_f16_f32(ac->attn_out, fwd_out + off, Q_DIM*SEQ); off += Q_DIM*SEQ;
cvt_f16_f32(ac->Q, fwd_out + off, Q_DIM*SEQ); off += Q_DIM*SEQ;
cvt_f16_f32(ac->K, fwd_out + off, KV_DIM*SEQ); off += KV_DIM*SEQ;
cvt_f16_f32(ac->V, fwd_out + off, KV_DIM*SEQ); off += KV_DIM*SEQ;
// xnorm passthrough (DIM*SEQ) not needed, already saved
IOSurfaceUnlock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// CPU: scaled residual + RMSNorm (ANE can't fuse RMS with 3 matmuls)
// Wo forward (ANE): attn_out[Q_DIM] o_out[DIM]
t0 = mach_absolute_time();
write_wo_fwd_acts(pls[L].woFwd_in, ac->attn_out);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval_req(dk.woFwd, plr[L].woFwd);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.woFwd->ioOut, ac->o_out, DIM, SEQ);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// CPU: scaled residual + RMSNorm
t0 = mach_absolute_time();
// x2 = x_cur + alpha * o_out (residual scaling keeps activations bounded)
vDSP_vsma(ac->o_out, 1, &res_alpha, x_cur, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
rmsnorm(ac->x2norm, ac->x2, lw[L].rms_ffn, DIM, SEQ);
t_rms += tb_ms(mach_absolute_time() - t0);
// Fused FFN (ANE): W1,W3 + SiLU + W2 + residual
// Input: x2norm + x2 (acts), W1t + W3t + W2 (pre-staged weights)
// Output: x_next, h1, h3, silu_out
// Fused FFN (ANE)
t0 = mach_absolute_time();
write_ffn_fused_acts(pls[L].ffnFused_in, ac->x2norm, ac->x2);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
@ -478,21 +468,17 @@ int main(int argc, char *argv[]) {
ane_eval_req(dk.ffnFused, plr[L].ffnFused);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
// Read fused output: [1, DIM+3*HIDDEN, 1, SEQ] fp16
// Layout: x_next[DIM], h1[HIDDEN], h3[HIDDEN], silu_out[HIDDEN]
// Read fused output: [1, DIM+3*HIDDEN, 1, SEQ]
t0 = mach_absolute_time();
IOSurfaceLock(dk.ffnFused->ioOut, kIOSurfaceLockReadOnly, NULL);
_Float16 *ffn_out = (_Float16*)IOSurfaceGetBaseAddress(dk.ffnFused->ioOut);
int off = 0;
off = 0;
cvt_f16_f32(x_cur, ffn_out + off, DIM*SEQ); off += DIM*SEQ;
cvt_f16_f32(ac->h1, ffn_out + off, HIDDEN*SEQ); off += HIDDEN*SEQ;
cvt_f16_f32(ac->h3, ffn_out + off, HIDDEN*SEQ); off += HIDDEN*SEQ;
cvt_f16_f32(ac->silu_out,ffn_out + off, HIDDEN*SEQ);
IOSurfaceUnlock(dk.ffnFused->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// (act_clip removed was causing gradient explosion without backward,
// vanishing gradients with backward. RMSNorm keeps activations bounded.)
}
// Final RMSNorm + classifier + loss (CPU)
@ -500,7 +486,6 @@ int main(int argc, char *argv[]) {
rmsnorm(x_final, x_cur, rms_final, DIM, SEQ);
t_rms += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
// Classifier: logits[CV, SEQ] = cembed[CV, DIM] @ x_final[DIM, SEQ]
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
CV, SEQ, DIM, 1.0f, cembed, DIM, x_final, SEQ, 0.0f, logits, SEQ);
float loss = cross_entropy_loss(dlogits, logits, ctargets, CV, SEQ);
@ -508,17 +493,15 @@ int main(int argc, char *argv[]) {
last_loss = loss;
// ===== BACKWARD =====
// Loss scaling: scale dlogits to prevent fp16 underflow in ANE backward kernels
// All gradients flow scaled; weight grads divided by loss_scale before Adam
vDSP_vsmul(dlogits, 1, &loss_scale, dlogits, 1, (vDSP_Length)(SEQ*CV));
// Classifier backward: dy[DIM, SEQ] = cembed^T[DIM, CV] @ dlogits[CV, SEQ]
// Classifier backward
t0 = mach_absolute_time();
cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
DIM, SEQ, CV, 1.0f, cembed, DIM, dlogits, SEQ, 0.0f, dy, SEQ);
t_cls += tb_ms(mach_absolute_time() - t0);
// dEmbed async: gcembed[CV, DIM] += dlogits[CV, SEQ] @ x_final^T[SEQ, DIM]
// dEmbed async
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
CV, DIM, SEQ, 1.0f, dlogits, SEQ, x_final, SEQ, 1.0f, gcembed, DIM);
@ -530,15 +513,15 @@ int main(int argc, char *argv[]) {
memcpy(dy, dx_rms_final, SEQ*DIM*4);
free(dx_rms_final);
// ===== BACKWARD (12 layers, reverse) =====
// ===== BACKWARD (28 layers, reverse) =====
for (int L=NLAYERS-1; L>=0; L--) {
LayerActs *ac = &acts[L];
LayerGrads *gr = &grads[L];
// dffn = alpha * dy (gradient into FFN branch scaled by residual alpha)
// dffn = alpha * dy
vDSP_vsmul(dy, 1, &res_alpha, dffn, 1, (vDSP_Length)(SEQ*DIM));
// FFN backward: dffn @ pre-staged W2^T dsilu_raw
// FFN backward: dffn @ W2^T dsilu_raw
t0 = mach_absolute_time();
write_ffn_bwd_w2t_acts(pls[L].ffnBwdW2t_in, dffn);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
@ -549,34 +532,28 @@ int main(int argc, char *argv[]) {
io_read_dyn(dk.ffnBwdW2t->ioOut, dsilu, HIDDEN, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// SiLU derivative (vectorized): dsilu dh1, dh3
// silu(h1) = h1*sig(h1), dsilu_dh1 = sig*(1+h1*(1-sig))
// dh1 = dsilu * h3 * dsilu_dh1, dh3 = dsilu * silu(h1)
// SiLU derivative (vectorized)
t0 = mach_absolute_time();
{
int n = HIDDEN*SEQ;
// sig = 1/(1+exp(-h1))
float minus1 = -1.0f, one = 1.0f;
vDSP_vsmul(ac->h1, 1, &minus1, silu_tmp, 1, (vDSP_Length)n);
vvexpf(silu_tmp, silu_tmp, &n);
vDSP_vsadd(silu_tmp, 1, &one, silu_tmp, 1, (vDSP_Length)n);
vvrecf(silu_tmp, silu_tmp, &n); // silu_tmp = sig
// dh3 = dsilu * h1 * sig (= dsilu * silu(h1))
vvrecf(silu_tmp, silu_tmp, &n); // sig
vDSP_vmul(ac->h1, 1, silu_tmp, 1, dh3, 1, (vDSP_Length)n);
vDSP_vmul(dsilu, 1, dh3, 1, dh3, 1, (vDSP_Length)n);
// dsilu_dh1 = sig*(1+h1*(1-sig)), store in silu_tmp2
vDSP_vsadd(silu_tmp, 1, &minus1, silu_tmp2, 1, (vDSP_Length)n); // sig-1
vDSP_vneg(silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // 1-sig
vDSP_vmul(ac->h1, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // h1*(1-sig)
vDSP_vsadd(silu_tmp2, 1, &one, silu_tmp2, 1, (vDSP_Length)n); // 1+h1*(1-sig)
vDSP_vmul(silu_tmp, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // full dsilu_dh1
// dh1 = dsilu * h3 * dsilu_dh1
vDSP_vsadd(silu_tmp, 1, &minus1, silu_tmp2, 1, (vDSP_Length)n);
vDSP_vneg(silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n);
vDSP_vmul(ac->h1, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n);
vDSP_vsadd(silu_tmp2, 1, &one, silu_tmp2, 1, (vDSP_Length)n);
vDSP_vmul(silu_tmp, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n);
vDSP_vmul(dsilu, 1, ac->h3, 1, dh1, 1, (vDSP_Length)n);
vDSP_vmul(dh1, 1, silu_tmp2, 1, dh1, 1, (vDSP_Length)n);
}
t_silu += tb_ms(mach_absolute_time() - t0);
// dh1@W1^T + dh3@W3^T dx_ffn (ANE, pre-staged weights)
// dh1@W1^T + dh3@W3^T dx_ffn (ANE)
t0 = mach_absolute_time();
write_ffn_bwd_w13t_acts(pls[L].ffnBwdW13t_in, dh1, dh3);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
@ -587,7 +564,7 @@ int main(int argc, char *argv[]) {
io_read_dyn(dk.ffnBwdW13t->ioOut, dx_ffn, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dW FFN async (cblas)
// dW FFN async
t0 = mach_absolute_time();
float *capt_dffn = (float*)malloc(SEQ*DIM*4); memcpy(capt_dffn, dffn, SEQ*DIM*4);
float *capt_silu = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_silu, ac->silu_out, SEQ*HIDDEN*4);
@ -612,8 +589,7 @@ int main(int argc, char *argv[]) {
for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
t_rms_bwd += tb_ms(mach_absolute_time() - t0);
// Wo^T backward (ANE): alpha*dx2 @ pre-staged Wo^T da
// Scale dx2 by alpha for the attention branch (residual scaling backward)
// Wo^T backward (ANE): alpha*dx2 @ Wo da[Q_DIM]
float *dx2_scaled = (float*)malloc(SEQ*DIM*4);
vDSP_vsmul(dx2, 1, &res_alpha, dx2_scaled, 1, (vDSP_Length)(SEQ*DIM));
t0 = mach_absolute_time();
@ -623,105 +599,120 @@ int main(int argc, char *argv[]) {
ane_eval_req(dk.wotBwd, plr[L].wotBwd);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
float *da_buf = (float*)malloc(SEQ*DIM*4);
io_read_dyn(dk.wotBwd->ioOut, da_buf, DIM, SEQ);
io_read_dyn(dk.wotBwd->ioOut, da_buf, Q_DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dWo async (uses alpha-scaled dx2)
// dWo async: gr->Wo[DIM,Q_DIM] += dx2_scaled[DIM,SEQ] @ attn_out^T[SEQ,Q_DIM]
t0 = mach_absolute_time();
float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, dx2_scaled, SEQ*DIM*4);
free(dx2_scaled);
float *capt_attn = (float*)malloc(SEQ*DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*DIM*4);
float *capt_attn = (float*)malloc(SEQ*Q_DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*Q_DIM*4);
t_dw_copy += tb_ms(mach_absolute_time() - t0);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_do, SEQ, capt_attn, SEQ, 1.0f, gr->Wo, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, Q_DIM, SEQ,
1.0f, capt_do, SEQ, capt_attn, SEQ, 1.0f, gr->Wo, Q_DIM);
free(capt_do); free(capt_attn);
});
if (L == 0 && step % 10 == 0) {
float damx, dx2mx, dx2mean;
vDSP_maxmgv(da_buf, 1, &damx, (vDSP_Length)(SEQ*DIM));
vDSP_maxmgv(dx2, 1, &dx2mx, (vDSP_Length)(SEQ*DIM));
vDSP_meamgv(dx2, 1, &dx2mean, (vDSP_Length)(SEQ*DIM));
// Count how many dx2 values survive fp16 conversion
int nz = 0;
for (int i=0; i<SEQ*DIM && i<1000; i++) {
_Float16 h = (_Float16)dx2[i];
if (h != 0) nz++;
}
printf(" L0 wot_bwd: |da|=%.2e |dx2| max=%.2e mean=%.2e fp16_nz=%d/1000\n", damx, dx2mx, dx2mean, nz);
}
// SDPA backward part 1 (ANE, fp16): Q,K,V,da dV,probs,dp
// GQA: tile K,V from KV_DIM Q_DIM for SDPA backward
t0 = mach_absolute_time();
io_write_fp16_at(dk.sdpaBwd1->ioIn, 0, ac->Q, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, DIM, ac->K, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, 2*DIM, ac->V, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, 3*DIM, da_buf, DIM, SEQ);
free(da_buf);
gqa_tile_kv(k_tiled, ac->K, SEQ);
gqa_tile_kv(v_tiled, ac->V, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// SDPA backward part 1: Q[Q_DIM],K_tiled[Q_DIM],V_tiled[Q_DIM],da[Q_DIM] dV_full[Q_DIM],probs,dp
t0 = mach_absolute_time();
io_write_fp16_at(dk.sdpaBwd1->ioIn, 0, ac->Q, Q_DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, Q_DIM, k_tiled, Q_DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, 2*Q_DIM, v_tiled, Q_DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, 3*Q_DIM, da_buf, Q_DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.sdpaBwd1);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
// SDPA backward part 2: probs,dp,Q,K dQ,dK
// SDPA backward part 2: probs,dp,Q[Q_DIM],K_tiled[Q_DIM] dQ[Q_DIM],dK_full[Q_DIM]
t0 = mach_absolute_time();
io_copy(dk.sdpaBwd2->ioIn, 0, dk.sdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ);
io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH, ac->Q, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH+DIM, ac->K, DIM, SEQ);
io_copy(dk.sdpaBwd2->ioIn, 0, dk.sdpaBwd1->ioOut, Q_DIM, 2*SCORE_CH, SEQ);
io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH, ac->Q, Q_DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH+Q_DIM, k_tiled, Q_DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.sdpaBwd2);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
// Read SDPA backward outputs
t0 = mach_absolute_time();
io_read_fp16(dk.sdpaBwd2->ioOut, dq, 0, DIM, SEQ);
io_read_fp16(dk.sdpaBwd2->ioOut, dk_buf, DIM, DIM, SEQ);
io_read_fp16(dk.sdpaBwd1->ioOut, dv, 0, DIM, SEQ);
io_read_fp16(dk.sdpaBwd2->ioOut, dq_full, 0, Q_DIM, SEQ); // dQ at full HEADS
io_read_fp16(dk.sdpaBwd2->ioOut, dk_full, Q_DIM, Q_DIM, SEQ); // dK at full HEADS
io_read_fp16(dk.sdpaBwd1->ioOut, dv_full, 0, Q_DIM, SEQ); // dV at full HEADS
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// RoPE backward: dq, dk are grads w.r.t. Q_rope, K_rope
// Inverse rotation to get grads w.r.t. pre-RoPE Q, K
rope_backward_inplace(dq, SEQ, DIM, HD);
rope_backward_inplace(dk_buf, SEQ, DIM, HD);
// GQA: reduce dK, dV from Q_DIM (HEADS) KV_DIM (KV_HEADS)
gqa_reduce_kv(dk_buf, dk_full, SEQ);
gqa_reduce_kv(dv, dv_full, SEQ);
// dQ stays at Q_DIM no reduction needed
memcpy(dq, dq_full, SEQ*Q_DIM*4);
// RoPE backward on dQ[Q_DIM] and dK[KV_DIM]
rope_backward_inplace(dq, SEQ, Q_DIM, HD);
rope_backward_inplace(dk_buf, SEQ, KV_DIM, HD);
// Debug: check SDPA backward output magnitudes
if (L == 0 && step % 10 == 0) {
float dqmx, dkmx, dvmx;
vDSP_maxmgv(dq, 1, &dqmx, (vDSP_Length)(SEQ*DIM));
vDSP_maxmgv(dk_buf, 1, &dkmx, (vDSP_Length)(SEQ*DIM));
vDSP_maxmgv(dv, 1, &dvmx, (vDSP_Length)(SEQ*DIM));
vDSP_maxmgv(dq, 1, &dqmx, (vDSP_Length)(SEQ*Q_DIM));
vDSP_maxmgv(dk_buf, 1, &dkmx, (vDSP_Length)(SEQ*KV_DIM));
vDSP_maxmgv(dv, 1, &dvmx, (vDSP_Length)(SEQ*KV_DIM));
printf(" L0 sdpa_bwd: |dq|=%.6f |dk|=%.6f |dv|=%.6f\n", dqmx, dkmx, dvmx);
}
// dWq/dWk/dWv async
// dWq[Q_DIM,DIM] += dq[Q_DIM,SEQ] @ xnorm^T[SEQ,DIM]
// dWk[KV_DIM,DIM] += dk[KV_DIM,SEQ] @ xnorm^T[SEQ,DIM]
// dWv[KV_DIM,DIM] += dv[KV_DIM,SEQ] @ xnorm^T[SEQ,DIM]
t0 = mach_absolute_time();
float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4);
float *capt_dk = (float*)malloc(SEQ*DIM*4); memcpy(capt_dk, dk_buf, SEQ*DIM*4);
float *capt_dv = (float*)malloc(SEQ*DIM*4); memcpy(capt_dv, dv, SEQ*DIM*4);
float *capt_dq = (float*)malloc(SEQ*Q_DIM*4); memcpy(capt_dq, dq, SEQ*Q_DIM*4);
float *capt_dk = (float*)malloc(SEQ*KV_DIM*4); memcpy(capt_dk, dk_buf, SEQ*KV_DIM*4);
float *capt_dv = (float*)malloc(SEQ*KV_DIM*4); memcpy(capt_dv, dv, SEQ*KV_DIM*4);
float *capt_xn = (float*)malloc(SEQ*DIM*4); memcpy(capt_xn, ac->xnorm, SEQ*DIM*4);
t_dw_copy += tb_ms(mach_absolute_time() - t0);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Q_DIM, DIM, SEQ,
1.0f, capt_dq, SEQ, capt_xn, SEQ, 1.0f, gr->Wq, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, KV_DIM, DIM, SEQ,
1.0f, capt_dk, SEQ, capt_xn, SEQ, 1.0f, gr->Wk, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, KV_DIM, DIM, SEQ,
1.0f, capt_dv, SEQ, capt_xn, SEQ, 1.0f, gr->Wv, DIM);
free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn);
});
// QKV backward (ANE): dq,dk,dv @ pre-staged Wq^T,Wk^T,Wv^T dx_attn
// Q backward (ANE): dq[Q_DIM] @ Wq dx_q[DIM]
t0 = mach_absolute_time();
write_qkv_bwd_acts(pls[L].qkvBwd_in, dq, dk_buf, dv);
write_q_bwd_acts(pls[L].qBwd_in, dq);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval_req(dk.qkvBwd, plr[L].qkvBwd);
ane_eval_req(dk.qBwd, plr[L].qBwd);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.qkvBwd->ioOut, dx_attn, DIM, SEQ);
io_read_dyn(dk.qBwd->ioOut, dx_attn, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// KV backward (ANE): dk[KV_DIM]@Wk + dv[KV_DIM]@Wv dx_kv[DIM]
float *dx_kv = (float*)malloc(SEQ*DIM*4);
t0 = mach_absolute_time();
write_kv_bwd_acts(pls[L].kvBwd_in, dk_buf, dv);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval_req(dk.kvBwd, plr[L].kvBwd);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.kvBwd->ioOut, dx_kv, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dx_attn = dx_q + dx_kv
for(int i=0; i<SEQ*DIM; i++) dx_attn[i] += dx_kv[i];
free(dx_kv);
// RMSNorm1 backward
t0 = mach_absolute_time();
float *dx_rms1 = (float*)calloc(SEQ*DIM, 4);
@ -758,17 +749,19 @@ int main(int argc, char *argv[]) {
float gsc = 1.0f / (accum_steps * loss_scale);
adam_t++;
// Scale gradients by 1/(accum_steps * loss_scale)
// Scale gradients
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
for(size_t i=0;i<WQ_SZ;i++){g->Wq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;}
for(size_t i=0;i<WQ_SZ;i++) g->Wq[i]*=gsc;
for(size_t i=0;i<WK_SZ;i++) g->Wk[i]*=gsc;
for(size_t i=0;i<WV_SZ;i++) g->Wv[i]*=gsc;
for(size_t i=0;i<WO_SZ;i++) g->Wo[i]*=gsc;
for(size_t i=0;i<W1_SZ;i++) g->W1[i]*=gsc;
for(size_t i=0;i<W2_SZ;i++) g->W2[i]*=gsc;
for(size_t i=0;i<W3_SZ;i++) g->W3[i]*=gsc;
for(int i=0;i<DIM;i++){g->rms_att[i]*=gsc; g->rms_ffn[i]*=gsc;}
}
for(int i=0;i<DIM;i++) grms_final[i]*=gsc;
// Merge compact classifier grads into full embed grads
vocab_scatter_grads(gembed, gcembed, &vm, DIM);
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) gembed[i]*=gsc;
@ -778,8 +771,8 @@ int main(int argc, char *argv[]) {
LayerGrads *g = &grads[L];
float s;
vDSP_dotpr(g->Wq,1,g->Wq,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WK_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WV_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wo,1,g->Wo,1,&s,(vDSP_Length)WO_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->W1,1,g->W1,1,&s,(vDSP_Length)W1_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->W2,1,g->W2,1,&s,(vDSP_Length)W2_SZ); grad_norm_sq+=s;
@ -793,13 +786,12 @@ int main(int argc, char *argv[]) {
}
float grad_norm = sqrtf(grad_norm_sq);
if ((step+1) % 10 == 0) {
// Per-component gradient norms for diagnostics
float attn_sq=0, ffn_sq=0, embed_sq=0;
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L]; float s;
vDSP_dotpr(g->Wq,1,g->Wq,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s;
vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s;
vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s;
vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WK_SZ); attn_sq+=s;
vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WV_SZ); attn_sq+=s;
vDSP_dotpr(g->Wo,1,g->Wo,1,&s,(vDSP_Length)WO_SZ); attn_sq+=s;
vDSP_dotpr(g->W1,1,g->W1,1,&s,(vDSP_Length)W1_SZ); ffn_sq+=s;
vDSP_dotpr(g->W2,1,g->W2,1,&s,(vDSP_Length)W2_SZ); ffn_sq+=s;
@ -818,8 +810,8 @@ int main(int argc, char *argv[]) {
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
vDSP_vsmul(g->Wq,1,&clip_scale,g->Wq,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wk,1,&clip_scale,g->Wk,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wv,1,&clip_scale,g->Wv,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wk,1,&clip_scale,g->Wk,1,(vDSP_Length)WK_SZ);
vDSP_vsmul(g->Wv,1,&clip_scale,g->Wv,1,(vDSP_Length)WV_SZ);
vDSP_vsmul(g->Wo,1,&clip_scale,g->Wo,1,(vDSP_Length)WO_SZ);
vDSP_vsmul(g->W1,1,&clip_scale,g->W1,1,(vDSP_Length)W1_SZ);
vDSP_vsmul(g->W2,1,&clip_scale,g->W2,1,(vDSP_Length)W2_SZ);
@ -854,25 +846,26 @@ int main(int argc, char *argv[]) {
adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f);
// Update transposed weight buffers
transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM);
transpose_weight(Wkt_buf[L], lw[L].Wk, DIM, DIM);
transpose_weight(Wvt_buf[L], lw[L].Wv, DIM, DIM);
transpose_weight(Wot_buf[L], lw[L].Wo, DIM, DIM);
transpose_weight(Wqt_buf[L], lw[L].Wq, Q_DIM, DIM);
transpose_weight(Wkt_buf[L], lw[L].Wk, KV_DIM, DIM);
transpose_weight(Wvt_buf[L], lw[L].Wv, KV_DIM, DIM);
transpose_weight(Wot_buf[L], lw[L].Wo, DIM, Q_DIM);
transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM);
transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN);
transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM);
// Re-stage weights into per-layer IOSurfaces
stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]);
// Re-stage weights
stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L]);
stage_wo_fwd_weights(pls[L].woFwd_in, Wot_buf[L]);
stage_ffn_fused_weights(pls[L].ffnFused_in, W1t_buf[L], W3t_buf[L], lw[L].W2);
stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, lw[L].W2);
stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, lw[L].W1, lw[L].W3);
stage_wot_bwd_weights(pls[L].wotBwd_in, lw[L].Wo);
stage_qkv_bwd_weights(pls[L].qkvBwd_in, lw[L].Wq, lw[L].Wk, lw[L].Wv);
stage_q_bwd_weights(pls[L].qBwd_in, lw[L].Wq);
stage_kv_bwd_weights(pls[L].kvBwd_in, lw[L].Wk, lw[L].Wv);
}
adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f);
adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps, wd);
// Re-extract compact embed from updated full embed
free(cembed);
cembed = vocab_compact_embed(embed, &vm, DIM);
@ -908,9 +901,13 @@ int main(int argc, char *argv[]) {
free(W1t_buf[L]); free(W2t_buf[L]); free(W3t_buf[L]);
}
free_per_layer(pls, plr);
free_kern(dk.sdpaFwd); free_kern(dk.ffnFused);
free_kern(dk.sdpaFwd); free_kern(dk.woFwd); free_kern(dk.ffnFused);
free_kern(dk.ffnBwdW2t); free_kern(dk.ffnBwdW13t); free_kern(dk.wotBwd);
free_kern(dk.sdpaBwd1); free_kern(dk.sdpaBwd2); free_kern(dk.qkvBwd);
free_kern(dk.sdpaBwd1); free_kern(dk.sdpaBwd2);
free_kern(dk.qBwd); free_kern(dk.kvBwd);
free(da_buf); free(k_tiled); free(v_tiled);
free(dq_full); free(dk_full); free(dv_full);
free(dq); free(dk_buf); free(dv);
munmap(token_data, data_len); close(data_fd);
}
return 0;