Merge branch 'main' into docs/fix-readme-outdated-info

This commit is contained in:
Ömer 2026-03-15 20:38:37 +03:00 committed by GitHub
commit 2206c55bd8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 1739 additions and 842 deletions

2
.gitattributes vendored Normal file
View File

@ -0,0 +1,2 @@
assets/models/*.bin filter=lfs diff=lfs merge=lfs -text
assets/data/*.bin filter=lfs diff=lfs merge=lfs -text

114
README.md
View File

@ -36,6 +36,7 @@ Some coverage of this project has overstated its implications. To be clear:
The honest results — including all limitations — are documented in the accompanying articles:
- [Part 1: Reverse Engineering](https://maderix.substack.com/p/inside-the-m4-apple-neural-engine)
- [Part 2: Benchmarks](https://maderix.substack.com/p/inside-the-m4-apple-neural-engine-615)
- [Part 3: Training](https://maderix.substack.com/p/inside-the-m4-apple-neural-engine-c8b)
### On Maintenance
@ -55,29 +56,46 @@ This is MIT licensed for a reason. Everyone now has access to AI-assisted develo
## What This Is
A from-scratch implementation of transformer training (forward + backward pass) running on the ANE in Apple Silicon. The ANE is a 15.8 TFLOPS (M4) inference accelerator that Apple does not expose for training. This project reverse-engineers the `_ANEClient` / `_ANECompiler` private APIs and the MIL (Model Intermediate Language) format to run custom compute graphs — including backpropagation — directly on ANE hardware.
A from-scratch implementation of transformer training (forward + backward pass) running on the ANE in Apple Silicon. The ANE is a 15.8 TFLOPS FP16 (M4) inference accelerator that Apple does not expose for training. This project reverse-engineers the `_ANEClient` / `_ANECompiler` private APIs and the MIL (Model Intermediate Language) format to run custom compute graphs — including backpropagation — directly on ANE hardware.
**Current results:**
| Model | Params | ms/step | Pipeline |
|-------|--------|---------|----------|
| Stories110M (12L, dim=768, MHA 12/12) | 109M | **91 ms** | Dynamic (no recompile) |
| Qwen3-0.6B (28L, dim=1024, GQA 16/8) | 596M | **412 ms** | Dynamic (no recompile) |
**Current results — Stories110M (12-layer, dim=768, seq=256, 109M params):**
- Static pipeline: **91 ms/step** (M3 Ultra), **106 ms/step** (M4)
- Dynamic pipeline: **110 ms/step**, no recompilation
- 72 ANE kernels per step (static), 9 shared kernels (dynamic)
- All forward and backward dx passes on ANE, dW gradients on CPU (Accelerate cblas)
- Adam optimizer, gradient accumulation, checkpoint/resume via exec() restart
- GQA (Grouped-Query Attention) support with per-head tiling/reduction
- GPU↔ANE zero-copy pipeline via shared IOSurface (GPU prefill → ANE decode)
**INT8 W8A8 quantization — 1.88x throughput (M4, H16G):**
| Config | FP16 | INT8 W8A8 | Speedup |
|--------|------|-----------|---------|
| 128x conv 512ch 64x64 | 18.6 TOPS, 14.8ms | 35.1 TOPS, 7.8ms | **1.88x** |
| 64x conv 512ch 64x64 | 18.4 TOPS, 7.5ms | 34.1 TOPS, 4.0ms | **1.85x** |
INT8 activations halve L2 SRAM bandwidth between tiles via MIL `quantize`/`dequantize` ops. Weights use `constexpr_affine_dequantize` (int8 stored, fp16 at compile time).
## Architecture
The training loop uses 6 ANE kernels per step:
The dynamic pipeline uses shared ANE kernels with weights packed into spatial dimensions (no recompilation when weights change):
| Kernel | Function | Weights |
|--------|----------|---------|
| `kFwdAttn` | RMSNorm + QKV projection + SDPA + output projection | Wq, Wk, Wv, Wo, rms1, mask |
| `kFwdFFN` | RMSNorm + SwiGLU FFN (W1, W3, SiLU, W2) | W1, W2, W3, rms2 |
| `kFFNBwd` | FFN backward (W2^T + SiLU_bwd + W1^T + W3^T) | W2^T, W1^T, W3^T |
| `kSdpaBwd1` | Wo^T + SDPA backward part 1 (dV, probs, dp) | Wo^T, mask |
| `kSdpaBwd2` | SDPA backward part 2 (softmax grad, dQ, dK) | — |
| `kQKVb` | QKV backward (Wq^T + Wk^T + Wv^T → dx) | Wq^T, Wk^T, Wv^T |
**MHA models (Stories110M) — 6 kernels per layer:**
CPU handles: RMSNorm backward, residual connections, loss computation, dW gradient accumulation (cblas_sgemm), Adam optimizer updates.
| Kernel | Function |
|--------|----------|
| `sdpaFwd` | QKV projection + SDPA + output projection |
| `ffnFused` | SwiGLU FFN (W1, W3, SiLU, W2) |
| `ffnBwdW2t` / `ffnBwdW13t` | FFN backward (split for memory) |
| `sdpaBwd1` / `sdpaBwd2` | SDPA backward |
**GQA models (Qwen3-0.6B) — 10 kernels per layer:**
Adds separate `woFwd`, `qBwd`, `kvBwd` kernels for grouped-query attention (Q_DIM ≠ DIM).
CPU handles: RMSNorm forward/backward, residual connections (DeepNet α scaling), loss computation, dW gradient accumulation (cblas_sgemm), Adam optimizer updates.
Key optimizations:
- **Channel-first CPU layout** — matches ANE IOSurface `[1,C,1,S]` format, eliminates all transpose overhead
@ -149,13 +167,24 @@ See [training/README.md](training/README.md) for detailed training instructions.
Requires macOS 15+ on Apple Silicon (tested on M4).
```bash
# Build the main training program
xcrun clang -O2 -framework Foundation -framework IOSurface \
-framework CoreML -framework Accelerate -ldl -lobjc \
-o train_large training/train_large.m
# Dynamic pipeline (recommended) — model selected at build time
cd training/training_dynamic
make MODEL=stories110m # Stories110M (12L, MHA, 109M params)
make MODEL=qwen3_06b # Qwen3-0.6B (28L, GQA, 596M params)
./train --scratch # train from random init
./train --resume # resume from checkpoint
# Run
./train_large
# Static pipeline (legacy — recompiles weights each step)
cd training && make train_large
./train_large ane_stories110M_ckpt.bin 256 100 1e-4
# INT8 benchmark
xcrun clang -O2 -fobjc-arc -framework Foundation -framework IOSurface -ldl \
-o ane_int8_bench ane_int8_bench.m
./ane_int8_bench
# Bridge library (C-callable ANE API)
cd bridge && make
```
No external dependencies. Uses only system frameworks + private ANE APIs resolved at runtime via `objc_msgSend`.
@ -164,28 +193,40 @@ No external dependencies. Uses only system frameworks + private ANE APIs resolve
1. **MIL generation** — Objective-C code constructs MIL program text at runtime, specifying convolutions (for linear layers), matmul (for attention), softmax, element-wise ops
2. **In-memory compilation**`_ANEInMemoryModelDescriptor` compiles MIL text + weight blobs directly to ANE programs, no disk mlmodelc needed
3. **IOSurface I/O** — Input/output tensors passed via IOSurface shared memory in `[1, channels, 1, spatial]` format (fp16)
4. **Weight embedding** — Weights baked into ANE programs as BLOBFILE constants; recompiled each batch when weights change
3. **IOSurface I/O** — Input/output tensors passed via IOSurface shared memory in `[1, channels, 1, spatial]` format (fp16 or fp32; fp16 direct I/O is ~37% faster)
4. **Dynamic weights** — Activations and weights packed into a single spatial input dimension, sliced apart inside the MIL kernel. Weights change without recompilation.
5. **Gradient flow** — Forward taps expose intermediates needed for backward; backward kernels compute dx (input gradients) on ANE; dW (weight gradients) computed on CPU via cblas
6. **INT8 quantization**`constexpr_affine_dequantize` for int8 weights, `quantize`/`dequantize` between layers for int8 activation caching in L2 SRAM (1.88x throughput)
## Limitations
- **SDPA causal masking** — ANE hardware ignores `attn_mask` in SDPA ops; causal attention is decomposed into separate Q@K^T (ANE) → mask+softmax (ANE via add+softmax) → scores@V (ANE)
- **SDPA causal masking** — ANE hardware ignores `attn_mask` in SDPA ops; causal attention is decomposed into separate Q@K^T (ANE) → mask+softmax (CPU) → scores@V (ANE)
- **~119 compile limit** — ANE compiler leaks resources; worked around via `exec()` restart with checkpoint
- **Compile overhead** — Static pipeline recompiles 60+ kernels every 10 steps (~3.7s); dynamic pipeline avoids this
- **Low utilization** — Training sustains ~1-2 TFLOPS out of 15.8+ peak due to CPU fallbacks and I/O overhead
- **FP16 gradient underflow** — backward matmuls underflow in fp16; fixed with global loss scaling (`256 * NLAYERS`)
- **Single-input constraint** — multi-input ANE requests cause 0x1d error; inputs packed into spatial dimension instead
## Performance History
## Performance
| Optimization | ms/step | ANE util |
|---|---|---|
| Baseline (vDSP transpose) | 33.5 | 3.1% |
| Channel-first layout | 20.3 | 5.2% |
| vDSP vectorized RMSNorm | 14.2 | 7.4% |
| GCD async cblas overlap | 11.4 | 9.2% |
| ANE RMSNorm fusion | 11.4 | 9.2% |
| Wo^T fusion (7→6 kernels) | 11.4 | 9.2% |
| Deferred cblas wait | **9.3** | **11.2%** |
**Training throughput (M4):**
| Model | Params | ms/step | Layers | Kernels/layer |
|-------|--------|---------|--------|---------------|
| Stories110M | 109M | 91 ms | 12 | 6 (MHA) |
| Qwen3-0.6B | 596M | 412 ms | 28 | 10 (GQA) |
**ANE peak throughput (M4, H16G):**
| Precision | Peak TOPS | Config |
|-----------|-----------|--------|
| FP16 | 18.6 | 128x conv 512ch 64x64 |
| INT8 W8A8 | 35.1 | 128x conv 512ch 64x64 |
**GPU↔ANE inference pipeline (M4, seq=256):**
| Model | GPU Prefill | ANE Decode | Total |
|-------|------------|------------|-------|
| Stories110M | 6.7ms | 1.9ms | 8.8ms |
| Qwen3-0.6B | 9.7ms | 2.3ms | 12.0ms |
## Disclaimer
@ -199,3 +240,4 @@ MIT — see [LICENSE](LICENSE)
*Built by a human + Claude, one weekend at a time.*

268
ane_int8_bench.m Normal file
View File

@ -0,0 +1,268 @@
// ane_int8_bench.m INT8 W8A8 benchmark on ANE via _ANEInMemoryModel
// Build: xcrun clang -O2 -fobjc-arc -framework Foundation -framework IOSurface -ldl -o ane_int8_bench ane_int8_bench.m
// Usage: ./ane_int8_bench
//
// Tests FP16 vs W8A8 (int8 weights + int8 activation caching) throughput.
// Key MIL ops: constexpr_affine_dequantize, quantize, dequantize
#import <Foundation/Foundation.h>
#import <objc/runtime.h>
#import <objc/message.h>
#import <dlfcn.h>
#import <mach/mach_time.h>
#import <IOSurface/IOSurface.h>
static mach_timebase_info_data_t g_tb;
static double ticksToMs(uint64_t t) { return (double)t * g_tb.numer / g_tb.denom / 1e6; }
// Weight blob for int8 weights (1 byte per element)
static NSData *buildWeightBlobInt8(int ch, int depth) {
NSUInteger wsize = ch * ch * 1;
NSUInteger chunkSize = 64 + wsize;
NSUInteger total = 64 + chunkSize * depth;
uint8_t *buf = calloc(total, 1);
buf[0] = 0x01; buf[4] = 0x02;
for (int i = 0; i < depth; i++) {
uint8_t *chunk = buf + 64 + i * chunkSize;
chunk[0]=0xEF; chunk[1]=0xBE; chunk[2]=0xAD; chunk[3]=0xDE;
chunk[4]=0x01; chunk[10]=0x08;
int8_t *data = (int8_t*)(chunk + 64);
for (NSUInteger j = 0; j < wsize; j++) data[j] = (int8_t)(arc4random() % 256 - 128);
}
return [NSData dataWithBytesNoCopy:buf length:total freeWhenDone:YES];
}
// Weight blob for fp16 weights (2 bytes per element)
static NSData *buildWeightBlobFP16(int ch, int depth) {
NSUInteger wsize = ch * ch * 2;
NSUInteger chunkSize = 64 + wsize;
NSUInteger total = 64 + chunkSize * depth;
uint8_t *buf = calloc(total, 1);
buf[0] = 0x01; buf[4] = 0x02;
for (int i = 0; i < depth; i++) {
uint8_t *chunk = buf + 64 + i * chunkSize;
chunk[0]=0xEF; chunk[1]=0xBE; chunk[2]=0xAD; chunk[3]=0xDE;
chunk[4]=0x01; chunk[10]=0x10;
_Float16 *data = (_Float16*)(chunk + 64);
for (NSUInteger j = 0; j < (NSUInteger)(ch*ch); j++) data[j] = (_Float16)(((float)(arc4random()%1000) - 500.0f) * 0.001f);
}
return [NSData dataWithBytesNoCopy:buf length:total freeWhenDone:YES];
}
// Generate W8A8 INT8 MIL: conv with int8 weights + quantize/dequantize between layers
static NSString *genMILInt8(int ch, int sp, int depth) {
NSMutableString *m = [NSMutableString string];
[m appendString:@"program(1.3)\n[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, {\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, {\"coremltools-version\", \"9.0\"}})]\n{\n"];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, %d, %d]> x) {\n", ch, sp, sp];
// Conv constants
[m appendString:@" string c_pad_type = const()[name = string(\"c_pad_type\"), val = string(\"valid\")];\n"
@" tensor<int32, [2]> c_strides = const()[name = string(\"c_strides\"), val = tensor<int32, [2]>([1, 1])];\n"
@" tensor<int32, [4]> c_pad = const()[name = string(\"c_pad\"), val = tensor<int32, [4]>([0, 0, 0, 0])];\n"
@" tensor<int32, [2]> c_dilations = const()[name = string(\"c_dilations\"), val = tensor<int32, [2]>([1, 1])];\n"
@" int32 c_groups = const()[name = string(\"c_groups\"), val = int32(1)];\n"];
// Quantize/dequantize scale
[m appendString:@" fp16 q_scale = const()[name = string(\"q_scale\"), val = fp16(0x1p-3)];\n"
@" string q_dtype = const()[name = string(\"q_dtype\"), val = string(\"int8\")];\n"
@" fp16 dq_scale = const()[name = string(\"dq_scale\"), val = fp16(0x1p-3)];\n"];
NSUInteger cs = 64 + ch * ch * 1; // int8 chunk size
NSString *prev = @"x";
for (int i = 0; i < depth; i++) {
// constexpr_affine_dequantize: int8 weights fp16 at compile time
[m appendFormat:
@" tensor<fp16, [%d, %d, 1, 1]> W%d = constexpr_affine_dequantize()"
@"[axis = int32(0), name = string(\"W%d\"), "
@"quantized_data = tensor<int8, [%d, %d, 1, 1]>"
@"(BLOBFILE(path = string(\"@model_path/weights/weight.bin\"), offset = uint64(%lu))), "
@"scale = fp16(0x1p-3), zero_point = int8(0)];\n",
ch, ch, i, i, ch, ch, (unsigned long)(64 + i * cs)];
// conv
NSString *conv_out = [NSString stringWithFormat:@"c%d", i];
[m appendFormat:@" tensor<fp16, [1, %d, %d, %d]> %@ = conv(dilations = c_dilations, groups = c_groups, pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = W%d, x = %@)[name = string(\"%@\")];\n",
ch, sp, sp, conv_out, i, prev, conv_out];
if (i < depth - 1) {
// quantize: fp16 int8
NSString *q_out = [NSString stringWithFormat:@"q%d", i];
[m appendFormat:@" tensor<int8, [1, %d, %d, %d]> %@ = quantize(input = %@, output_dtype = q_dtype, scale = q_scale)[name = string(\"%@\")];\n",
ch, sp, sp, q_out, conv_out, q_out];
// dequantize: int8 fp16
NSString *dq_out = [NSString stringWithFormat:@"dq%d", i];
[m appendFormat:@" tensor<fp16, [1, %d, %d, %d]> %@ = dequantize(input = %@, scale = dq_scale)[name = string(\"%@\")];\n",
ch, sp, sp, dq_out, q_out, dq_out];
prev = dq_out;
} else {
prev = conv_out;
}
}
[m appendFormat:@" } -> (%@);\n}\n", prev];
return m;
}
// Generate FP16 baseline MIL: pure fp16 conv chain
static NSString *genMILFP16(int ch, int sp, int depth) {
NSMutableString *m = [NSMutableString string];
[m appendString:@"program(1.3)\n[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, {\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, {\"coremltools-version\", \"9.0\"}})]\n{\n"];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, %d, %d]> x) {\n", ch, sp, sp];
[m appendString:@" string c_pad_type = const()[name = string(\"c_pad_type\"), val = string(\"valid\")];\n"
@" tensor<int32, [2]> c_strides = const()[name = string(\"c_strides\"), val = tensor<int32, [2]>([1, 1])];\n"
@" tensor<int32, [4]> c_pad = const()[name = string(\"c_pad\"), val = tensor<int32, [4]>([0, 0, 0, 0])];\n"
@" tensor<int32, [2]> c_dilations = const()[name = string(\"c_dilations\"), val = tensor<int32, [2]>([1, 1])];\n"
@" int32 c_groups = const()[name = string(\"c_groups\"), val = int32(1)];\n"];
NSUInteger cs = 64 + ch * ch * 2; // fp16 chunk size
NSString *prev = @"x";
for (int i = 0; i < depth; i++) {
// fp16 weights from blob
[m appendFormat:
@" tensor<fp16, [%d, %d, 1, 1]> W%d = const()"
@"[name = string(\"W%d\"), "
@"val = tensor<fp16, [%d, %d, 1, 1]>"
@"(BLOBFILE(path = string(\"@model_path/weights/weight.bin\"), offset = uint64(%lu)))];\n",
ch, ch, i, i, ch, ch, (unsigned long)(64 + i * cs)];
NSString *conv_out = [NSString stringWithFormat:@"c%d", i];
[m appendFormat:@" tensor<fp16, [1, %d, %d, %d]> %@ = conv(dilations = c_dilations, groups = c_groups, pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = W%d, x = %@)[name = string(\"%@\")];\n",
ch, sp, sp, conv_out, i, prev, conv_out];
prev = conv_out;
}
[m appendFormat:@" } -> (%@);\n}\n", prev];
return m;
}
static double benchModel(NSString *milStr, NSData *wb, int ch, int sp, const char *label) {
@autoreleasepool {
NSError *e = nil;
NSData *milData = [milStr dataUsingEncoding:NSUTF8StringEncoding];
Class D = NSClassFromString(@"_ANEInMemoryModelDescriptor");
Class I = NSClassFromString(@"_ANEInMemoryModel");
Class AR = NSClassFromString(@"_ANERequest");
Class AIO = NSClassFromString(@"_ANEIOSurfaceObject");
id desc = ((id(*)(Class,SEL,id,id,id))objc_msgSend)(D,
@selector(modelWithMILText:weights:optionsPlist:), milData,
@{@"@model_path/weights/weight.bin": @{@"offset": @0, @"data": wb}}, nil);
if (!desc) { printf(" %s: desc FAIL\n", label); return -1; }
id mdl = ((id(*)(Class,SEL,id))objc_msgSend)(I, @selector(inMemoryModelWithDescriptor:), desc);
if (!mdl) { printf(" %s: mdl FAIL\n", label); return -2; }
id hx = ((id(*)(id,SEL))objc_msgSend)(mdl, @selector(hexStringIdentifier));
NSString *td = [NSTemporaryDirectory() stringByAppendingPathComponent:hx];
NSFileManager *fm = [NSFileManager defaultManager];
[fm createDirectoryAtPath:[td stringByAppendingPathComponent:@"weights"]
withIntermediateDirectories:YES attributes:nil error:nil];
[milData writeToFile:[td stringByAppendingPathComponent:@"model.mil"] atomically:YES];
[wb writeToFile:[td stringByAppendingPathComponent:@"weights/weight.bin"] atomically:YES];
if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(
mdl, @selector(compileWithQoS:options:error:), 0, @{}, &e)) {
printf(" %s: compile FAIL: %s\n", label, e ? [[e description] UTF8String] : "?");
[fm removeItemAtPath:td error:nil];
return -3;
}
if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(
mdl, @selector(loadWithQoS:options:error:), 0, @{}, &e)) {
printf(" %s: load FAIL\n", label);
[fm removeItemAtPath:td error:nil];
return -4;
}
NSUInteger bytes = (NSUInteger)ch * sp * sp * 2; // fp16 I/O
IOSurfaceRef ioI = IOSurfaceCreate((__bridge CFDictionaryRef)@{
(id)kIOSurfaceWidth: @(bytes), (id)kIOSurfaceHeight: @1,
(id)kIOSurfaceBytesPerElement: @1, (id)kIOSurfaceBytesPerRow: @(bytes),
(id)kIOSurfaceAllocSize: @(bytes), (id)kIOSurfacePixelFormat: @0});
IOSurfaceRef ioO = IOSurfaceCreate((__bridge CFDictionaryRef)@{
(id)kIOSurfaceWidth: @(bytes), (id)kIOSurfaceHeight: @1,
(id)kIOSurfaceBytesPerElement: @1, (id)kIOSurfaceBytesPerRow: @(bytes),
(id)kIOSurfaceAllocSize: @(bytes), (id)kIOSurfacePixelFormat: @0});
id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(AIO, @selector(objectWithIOSurface:), ioI);
id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(AIO, @selector(objectWithIOSurface:), ioO);
id req = ((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(AR,
@selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
@[wI], @[@0], @[wO], @[@0], nil, nil, @0);
// Warmup
for (int i = 0; i < 10; i++)
((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
mdl, @selector(evaluateWithQoS:options:request:error:), 0, @{}, req, &e);
int iters = 50;
uint64_t t0 = mach_absolute_time();
for (int i = 0; i < iters; i++)
((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
mdl, @selector(evaluateWithQoS:options:request:error:), 0, @{}, req, &e);
double ms = ticksToMs(mach_absolute_time() - t0) / iters;
((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(
mdl, @selector(unloadWithQoS:error:), 0, &e);
CFRelease(ioI); CFRelease(ioO);
[fm removeItemAtPath:td error:nil];
return ms;
}
}
int main(void) {
mach_timebase_info(&g_tb);
dlopen("/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine", RTLD_NOW);
// Query HW info
Class DI = NSClassFromString(@"_ANEDeviceInfo");
const char *ane_type = "unknown";
if (DI) {
id subType = ((id(*)(Class,SEL))objc_msgSend)(DI, @selector(aneSubType));
if (subType) ane_type = [[subType description] UTF8String];
}
printf("=== ANE INT8 W8A8 Benchmark (M4, %s) ===\n\n", ane_type);
printf("%-30s %7s %7s %9s %7s %7s\n", "Config", "W(MB)", "GOP", "ms/eval", "TOPS", "Ratio");
printf("--------------------------------------------------------------------------------\n");
typedef struct { int ch; int sp; int depth; } Config;
Config configs[] = {
{512, 64, 128},
{512, 64, 64},
{256, 64, 256},
{256, 64, 128},
{384, 64, 128},
};
int ncfg = sizeof(configs) / sizeof(configs[0]);
for (int ci = 0; ci < ncfg; ci++) {
int ch = configs[ci].ch, sp = configs[ci].sp, depth = configs[ci].depth;
double gop = 2.0 * ch * ch * sp * sp * depth / 1e9;
// FP16
double w_fp16 = (double)ch * ch * 2 * depth / 1024 / 1024;
NSString *milFP16 = genMILFP16(ch, sp, depth);
NSData *wbFP16 = buildWeightBlobFP16(ch, depth);
char lbl[64];
snprintf(lbl, 64, "FP16 %dx conv %dch", depth, ch);
double ms_fp16 = benchModel(milFP16, wbFP16, ch, sp, lbl);
// INT8 W8A8
double w_int8 = (double)ch * ch * 1 * depth / 1024 / 1024;
NSString *milInt8 = genMILInt8(ch, sp, depth);
NSData *wbInt8 = buildWeightBlobInt8(ch, depth);
snprintf(lbl, 64, "W8A8 %dx conv %dch", depth, ch);
double ms_int8 = benchModel(milInt8, wbInt8, ch, sp, lbl);
if (ms_fp16 > 0 && ms_int8 > 0) {
double tops_fp16 = gop / ms_fp16;
double tops_int8 = gop / ms_int8;
double ratio = ms_fp16 / ms_int8;
printf("FP16 %-25s %6.1f %6.2f %7.3f ms %6.2f\n",
[NSString stringWithFormat:@"%dx conv %dch %dx%d", depth, ch, sp, sp].UTF8String,
w_fp16, gop, ms_fp16, tops_fp16);
printf("W8A8 %-25s %6.1f %6.2f %7.3f ms %6.2f %.2fx\n",
[NSString stringWithFormat:@"%dx conv %dch %dx%d", depth, ch, sp, sp].UTF8String,
w_int8, gop, ms_int8, tops_int8, ratio);
printf("\n");
} else {
printf(" %dx conv %dch: FP16=%.1f INT8=%.1f (FAIL)\n", depth, ch, ms_fp16, ms_int8);
}
}
printf("=== Done ===\n");
return 0;
}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50a52ef822ee9e83de5ce9d0be0a025a773d019437f58b5ff9dcafb063ece361
size 433869

View File

@ -77,6 +77,19 @@ uint8_t *ane_bridge_build_weight_blob(const float *src, int rows, int cols,
uint8_t *ane_bridge_build_weight_blob_transposed(const float *src, int rows, int cols,
size_t *out_len);
// Build an int8 weight blob in ANE format (64-byte header + int8 data per chunk)
// src: int8 weights [rows x cols], scale: dequantization scale, zero_point: int8 zero
// For use with constexpr_affine_dequantize in MIL
// Returns allocated buffer and sets out_len. Caller must free().
uint8_t *ane_bridge_build_weight_blob_int8(const int8_t *src, int rows, int cols,
size_t *out_len);
// Quantize float32 weights to int8 and build ANE blob in one step
// Computes per-channel (axis=0) scale = max(abs(row)) / 127
// Returns allocated buffer, sets out_len and out_scale. Caller must free().
uint8_t *ane_bridge_build_weight_blob_quantized(const float *src, int rows, int cols,
float *out_scale, size_t *out_len);
// Free a blob allocated by ane_bridge_build_weight_blob*
void ane_bridge_free_blob(void *ptr);

View File

@ -93,7 +93,7 @@ ANEKernelHandle *ane_bridge_compile_multi_weights(
id desc = ((id(*)(Class,SEL,id,id,id))objc_msgSend)(
g_ANEDesc, @selector(modelWithMILText:weights:optionsPlist:),
milData, wdict.count > 0 ? wdict : nil, nil);
milData, wdict.count > 0 ? wdict : @{}, nil);
if (!desc) {
fprintf(stderr, "ane_bridge: modelWithMILText failed\n");
return NULL;
@ -326,3 +326,46 @@ uint8_t *ane_bridge_build_weight_blob_transposed(const float *src, int rows, int
*out_len = total;
return buf;
}
uint8_t *ane_bridge_build_weight_blob_int8(const int8_t *src, int rows, int cols,
size_t *out_len) {
int wsize = rows * cols; // 1 byte per int8 element
int total = 64 + wsize; // 64-byte header + data
uint8_t *buf = (uint8_t *)calloc(total, 1);
// ANE int8 blob header
buf[0] = 0xEF; buf[1] = 0xBE; buf[2] = 0xAD; buf[3] = 0xDE;
buf[4] = 0x01;
buf[10] = 0x08; // 8-bit element marker
memcpy(buf + 64, src, wsize);
*out_len = total;
return buf;
}
uint8_t *ane_bridge_build_weight_blob_quantized(const float *src, int rows, int cols,
float *out_scale, size_t *out_len) {
// Find global max abs for symmetric quantization
float max_abs = 0.0f;
for (int i = 0; i < rows * cols; i++) {
float a = src[i] < 0 ? -src[i] : src[i];
if (a > max_abs) max_abs = a;
}
float scale = max_abs / 127.0f;
if (scale == 0.0f) scale = 1.0f;
// Quantize to int8
int wsize = rows * cols;
int8_t *qdata = (int8_t *)malloc(wsize);
for (int i = 0; i < wsize; i++) {
float v = src[i] / scale;
if (v > 127.0f) v = 127.0f;
if (v < -128.0f) v = -128.0f;
qdata[i] = (int8_t)(v + (v >= 0 ? 0.5f : -0.5f));
}
uint8_t *blob = ane_bridge_build_weight_blob_int8(qdata, rows, cols, out_len);
free(qdata);
*out_scale = scale;
return blob;
}

View File

@ -1,5 +1,5 @@
CC = xcrun clang
CFLAGS = -O2 -Wall -Wno-deprecated-declarations -fobjc-arc
CFLAGS = -O2 -Wall -DACCELERATE_NEW_LAPACK -fobjc-arc
FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface
LDFLAGS = $(FRAMEWORKS) -ldl

View File

@ -1,14 +1,22 @@
# ANE Training — Stories110M on Apple Neural Engine
# ANE Training — On-Device Training on Apple Neural Engine
Training a 109M-parameter Llama2-architecture transformer (Stories110M) directly on Apple's Neural Engine using private ANE APIs.
Training transformer models directly on Apple's Neural Engine using private ANE APIs. Supports multiple architectures including GQA (Grouped-Query Attention).
![Dashboard](dashboard.gif)
## Supported Models
| Model | Layers | Heads (Q/KV) | Dim | Hidden | Params | ms/step |
|-------|--------|--------------|-----|--------|--------|---------|
| Stories110M | 12 | 12/12 (MHA) | 768 | 2048 | 109M | ~115 |
| Qwen3-0.6B | 28 | 16/8 (GQA) | 1024 | 3072 | 596M | ~412 |
Model configs live in `training_dynamic/models/*.h`. To add a new model, create a header with the architecture defines (see below).
## Architecture
- **Model**: Stories110M — dim=768, hidden=2048, heads=12, layers=12, vocab=32000, seq=256
- **109.53M params** (84.95M transformer + 24.58M embedding)
- **SDPA causal mask workaround**: ANE hardware ignores attn_mask — decompose into Q@K^T (ANE conv) + mask+softmax (CPU) + scores@V (ANE conv)
- **GQA support**: K/V heads tiled to match Q heads for SDPA, reduced back after backward pass
## Three Training Pipelines
@ -27,10 +35,10 @@ Offloads classifier forward (32K conv), softmax, final RMSNorm, and RMSNorm back
- Use `--no-ane-extras` to disable and fall back to CPU (for debugging)
### 3. Dynamic Weight Pipeline (`training_dynamic/`)
Weights passed via IOSurface spatial dimension — compile 9 kernels once at startup, no recompilation needed.
Weights passed via IOSurface spatial dimension — compile 10 kernels once at startup, no recompilation needed. Supports multiple models via `make MODEL=xxx`.
- 9 shared kernels across all 12 layers
- **111 ms/step**, 0.4s one-time compile
- 10 shared kernels across all layers (GQA-aware: split sdpaFwd/woFwd, split qBwd/kvBwd)
- **~115 ms/step** (Stories110M) / **~412 ms/step** (Qwen3-0.6B), 0.4s one-time compile
- No exec() restart, no compile limit issues
## Performance Comparison (20 Steps)
@ -56,10 +64,11 @@ Weights passed via IOSurface spatial dimension — compile 9 kernels once at sta
|------|-------------|
| `train_large.m` | Static baseline — 72 kernels, classifier/softmax on CPU |
| `train_large_ane.m` | PR#19 — 86 kernels, classifier/softmax/rmsnorm_bwd on ANE |
| `training_dynamic/train.m` | Dynamic pipeline — 9 kernels, weights via IOSurface |
| `training_dynamic/mil_dynamic.h` | MIL generators for dynamic weight kernels |
| `training_dynamic/config.h` | Model config (DIM=768, HIDDEN=2048, etc.) |
| `training_dynamic/io.h` | IOSurface I/O + MIL compilation helpers |
| `training_dynamic/train.m` | Dynamic pipeline — 10 kernels, weights via IOSurface |
| `training_dynamic/mil_dynamic.h` | MIL generators for dynamic weight kernels (GQA-aware) |
| `training_dynamic/config.h` | Derived sizes, structs, alloc helpers (model-agnostic) |
| `training_dynamic/models/*.h` | Per-model configs (stories110m.h, qwen3_06b.h) |
| `training_dynamic/io.h` | IOSurface I/O, weight staging, GQA tile/reduce |
| `training_dynamic/cpu_ops.h` | CPU ops (SiLU backward, cross-entropy, Adam) |
| `stories_config.h` | Static pipeline config, structs, alloc helpers |
| `stories_io.h` | IOSurface I/O, NEON fp16 conversion, kernel compile/eval |
@ -83,33 +92,35 @@ Downloads pretokenized TinyStories (Llama 2 BPE, 32K vocab) from HuggingFace. Pr
### 2. Build & Train
```bash
# Static baseline (classifier + softmax on CPU)
make train_large
./train_large stories110M.bin 256 100 1e-4
./train_large --model stories110M.bin --steps 100 --lr 1e-4
./train_large --data ./tinystories_data00.bin --steps 100 --lr 1e-4
# PR#19: ANE-offloaded classifier + softmax + rmsnorm_bwd
make train_large_ane
./train_large_ane stories110M.bin 256 100 1e-4
./train_large_ane --no-ane-extras --steps 100 # disable ANE extras
./train_large_ane --data ./tinystories_data00.bin --steps 100 --lr 1e-4
# Static baseline (classifier + softmax on CPU)
make train_large
./train_large stories110M.bin 256 100 1e-4
./train_large --model stories110M.bin --steps 100 --lr 1e-4
./train_large --data ./tinystories_data00.bin --steps 100 --lr 1e-4
# Dynamic pipeline (no recompilation)
cd training_dynamic && make train
# PR#19: ANE-offloaded classifier + softmax + rmsnorm_bwd
make train_large_ane
./train_large_ane stories110M.bin 256 100 1e-4
./train_large_ane --no-ane-extras --steps 100 # disable ANE extras
./train_large_ane --data ./tinystories_data00.bin --steps 100 --lr 1e-4
# Dynamic pipeline (model selected at build time)
cd training_dynamic
make MODEL=qwen3_06b # default — Qwen3-0.6B (28L, GQA, 596M)
make MODEL=stories110m # Stories110M (12L, MHA, 109M)
./train --scratch # train from random init
./train # resume from checkpoint
./train --resume # resume from checkpoint
./train --steps 200 --lr 1e-4 # custom steps/lr
```
**CLI flags (`train_large` / `train_large_ane`):**
**CLI flags (`train_large` / `train_large_ane`):**
- `--steps N` (default 10000)
- `--lr F` (default 3e-4)
- `--model PATH` — pretrained weights file
- `--data PATH` — tokenized TinyStories `.bin` file (default: `tinystories_data00.bin`)
- `--ckpt PATH` — checkpoint file (preserved across exec() restarts)
- `--resume` — resume from checkpoint
- `--no-ane-extras` — (train_large_ane only) disable ANE classifier/softmax/rmsnorm_bwd
- `--lr F` (default 3e-4)
- `--model PATH` — pretrained weights file
- `--data PATH` — tokenized TinyStories `.bin` file (default: `tinystories_data00.bin`)
- `--ckpt PATH` — checkpoint file (preserved across exec() restarts)
- `--resume` — resume from checkpoint
- `--no-ane-extras` — (train_large_ane only) disable ANE classifier/softmax/rmsnorm_bwd
### 3. Monitor with Dashboard
@ -133,11 +144,42 @@ Avg train: 91.8 ms/step
ANE TFLOPS: 1.15 sustained
```
## Adding a New Model
Create `training_dynamic/models/mymodel.h`:
```c
#pragma once
#define MODEL_NAME "MyModel-1B"
#define DIM 2048 // model hidden dim
#define HIDDEN 5504 // FFN intermediate dim
#define HEADS 32 // number of query heads
#define KV_HEADS 8 // number of KV heads (= HEADS for MHA)
#define HD 64 // head dim (can differ from DIM/HEADS)
#define SEQ 256 // sequence length
#define NLAYERS 22 // number of transformer layers
#define VOCAB 32000 // vocabulary size
#define CKPT_PATH "ane_mymodel_dyn_ckpt.bin"
#define DEFAULT_DATA_PATH "../tinystories_data00.bin"
```
Everything else is derived automatically: `GQA_RATIO`, `Q_DIM`, `KV_DIM`, weight sizes, IOSurface layouts, MIL kernels.
Build with: `make MODEL=mymodel`
**Constraints:**
- `HEADS` must be divisible by `KV_HEADS`
- `HD` is explicit (not necessarily `DIM/HEADS` — Qwen3 uses HD=128 with DIM/HEADS=64)
- For MHA (no GQA), set `KV_HEADS = HEADS`
## Key Techniques
- **NEON vectorized fp16↔fp32**: ARM NEON intrinsics for fast IOSurface data transfer
- **vDSP cross-entropy**: `vDSP_mtrans` + `vvexpf` + `vDSP_sve` — 8x faster than scalar
- **Async weight gradients**: cblas_sgemm dispatched to background queue, overlapped with ANE
- **Vocab compaction** (dynamic): 32K → 9.2K active tokens, 3.5x reduction in classifier work
- **Dynamic weight packing**: Activations + weights concatenated in IOSurface spatial dimension — one kernel serves all 12 layers
- **Vocab compaction** (dynamic): 32K152K → 9.2K active tokens, up to 16.5x reduction in classifier work
- **Dynamic weight packing**: Activations + weights concatenated in IOSurface spatial dimension — one kernel serves all layers
- **GQA tile/reduce**: K/V tiled from KV_HEADS→HEADS on CPU before SDPA backward, gradients reduced HEADS→KV_HEADS after
- **exec() restart**: Workaround for ANE ~119 compile limit per process

View File

@ -18,16 +18,48 @@ try:
except ImportError:
HAS_PSUTIL = False
DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 256, 32000, 12
HD = DIM // HEADS
CKPT_PATH_STATIC = 'ane_stories110M_ckpt.bin'
CKPT_PATH_DYNAMIC = 'training_dynamic/ane_stories110M_dyn_ckpt.bin'
CKPT_PATH = CKPT_PATH_STATIC # set in main() based on --dynamic
TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent.parent / 'assets' / 'models' / 'tokenizer.bin')
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
# Model configs — set at startup based on --model flag
MODEL_CONFIGS = {
'stories110m': {
'dim': 768, 'hidden': 2048, 'heads': 12, 'kv_heads': 12,
'hd': 64, 'seq': 256, 'vocab': 32000, 'nlayers': 12,
'ckpt_static': 'ane_stories110M_ckpt.bin',
'ckpt_dynamic': 'training_dynamic/ane_stories110M_dyn_ckpt.bin',
},
'qwen3_06b': {
'dim': 1024, 'hidden': 3072, 'heads': 16, 'kv_heads': 8,
'hd': 128, 'seq': 256, 'vocab': 151936, 'nlayers': 28,
'ckpt_static': None,
'ckpt_dynamic': 'training_dynamic/ane_qwen3_06b_dyn_ckpt.bin',
},
}
# Active model dims — set in main()
DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 12, 64, 256, 32000, 12
Q_DIM, KV_DIM, GQA_RATIO = DIM, DIM, 1
CKPT_PATH = 'ane_stories110M_ckpt.bin'
TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent / 'assets' / 'models' / 'tokenizer.bin')
def set_model_config(name):
global DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS
global Q_DIM, KV_DIM, GQA_RATIO
cfg = MODEL_CONFIGS[name]
DIM, HIDDEN, HEADS, KV_HEADS = cfg['dim'], cfg['hidden'], cfg['heads'], cfg['kv_heads']
HD, SEQ, VOCAB, NLAYERS = cfg['hd'], cfg['seq'], cfg['vocab'], cfg['nlayers']
Q_DIM = HEADS * HD
KV_DIM = KV_HEADS * HD
GQA_RATIO = HEADS // KV_HEADS
class State:
def __init__(self):
self.active_model = 'stories110m'
self.model_config = {}
self.params = {}
self.kernels = {}
@ -62,6 +94,7 @@ class State:
self.train_start = None # wall clock when first step seen
self.compile_ms = 0.0 # total compile time
S = State()
@ -71,8 +104,12 @@ class Tokenizer:
self.scores = []
with open(path, 'rb') as f:
max_len = struct.unpack('i', f.read(4))[0]
for _ in range(VOCAB):
score = struct.unpack('f', f.read(4))[0]
# Read until EOF — works for any vocab size
while True:
data = f.read(4)
if len(data) < 4:
break
score = struct.unpack('f', data)[0]
slen = struct.unpack('i', f.read(4))[0]
tok = f.read(slen).decode('utf-8', errors='replace')
self.vocab.append(tok)
@ -104,33 +141,32 @@ def get_tokenizer():
def load_weights_from_ckpt(path):
try:
with open(path, 'rb') as f:
# CkptHdr: 96 bytes (verified with sizeof)
hdr = f.read(96)
if len(hdr) < 96:
return None
wq_sz = DIM * DIM
wo_sz = DIM * DIM
wq_sz = Q_DIM * DIM
wk_sz = KV_DIM * DIM
wv_sz = KV_DIM * DIM
wo_sz = DIM * Q_DIM
w1_sz = HIDDEN * DIM
w2_sz = DIM * HIDDEN
w3_sz = HIDDEN * DIM
# Per-layer: weights + adam state (m,v for each)
adam_per_layer = (wq_sz*2 + wq_sz*2 + wq_sz*2 + wo_sz*2 +
adam_per_layer = (wq_sz*2 + wk_sz*2 + wv_sz*2 + wo_sz*2 +
w1_sz*2 + w2_sz*2 + w3_sz*2 + DIM*2 + DIM*2)
W = {}
for L in range(NLAYERS):
W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'Wk{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'Wv{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(Q_DIM, DIM).copy()
W[f'Wk{L}'] = np.frombuffer(f.read(wk_sz * 4), dtype=np.float32).reshape(KV_DIM, DIM).copy()
W[f'Wv{L}'] = np.frombuffer(f.read(wv_sz * 4), dtype=np.float32).reshape(KV_DIM, DIM).copy()
W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, Q_DIM).copy()
W[f'W1_{L}'] = np.frombuffer(f.read(w1_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
W[f'W2_{L}'] = np.frombuffer(f.read(w2_sz * 4), dtype=np.float32).reshape(DIM, HIDDEN).copy()
W[f'W3_{L}'] = np.frombuffer(f.read(w3_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
W[f'rms1_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
W[f'rms2_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
# Skip adam state for this layer
f.seek(adam_per_layer * 4, 1)
W['rms_final'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
f.seek(DIM * 2 * 4, 1) # skip rms_final adam
f.seek(DIM * 2 * 4, 1)
W['embed'] = np.frombuffer(f.read(VOCAB * DIM * 4), dtype=np.float32).reshape(VOCAB, DIM).copy()
return W
except Exception as e:
@ -151,33 +187,39 @@ def generate_text(W, max_tokens=64, temperature=0.8):
tokenizer = get_tokenizer()
if tokenizer is None:
return '[no tokenizer]'
if len(tokenizer.vocab) < VOCAB:
return f'[tokenizer has {len(tokenizer.vocab)} tokens, model needs {VOCAB}]'
tokens = [1]
text_parts = []
# Precompute RoPE frequencies
freqs = np.zeros((SEQ, HD // 2), dtype=np.float32)
for pos in range(SEQ):
for i in range(HD // 2):
freq = 1.0 / (10000.0 ** (2.0 * i / HD))
freqs[pos, i] = pos * freq
# KV cache: per-layer, per KV head
k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(KV_HEADS)] for _ in range(NLAYERS)]
v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(KV_HEADS)] for _ in range(NLAYERS)]
res_alpha = 1.0 / math.sqrt(2.0 * NLAYERS)
for step in range(max_tokens):
seq_len = len(tokens)
if seq_len > SEQ:
break
x = W['embed'][tokens[-1]].copy()
pos = seq_len - 1
for L in range(NLAYERS):
# RMSNorm + QKV
xn = rmsnorm(x, W[f'rms1_{L}'])
q = W[f'Wq{L}'] @ xn
k = W[f'Wk{L}'] @ xn
v = W[f'Wv{L}'] @ xn
q = W[f'Wq{L}'] @ xn # [Q_DIM]
k = W[f'Wk{L}'] @ xn # [KV_DIM]
v = W[f'Wv{L}'] @ xn # [KV_DIM]
# RoPE
pos = seq_len - 1
# RoPE on Q (HEADS heads) and K (KV_HEADS heads)
for h in range(HEADS):
for i in range(HD // 2):
freq = freqs[pos, i]
@ -185,31 +227,41 @@ def generate_text(W, max_tokens=64, temperature=0.8):
qi, qi1 = q[h * HD + 2 * i], q[h * HD + 2 * i + 1]
q[h * HD + 2 * i] = qi * cos_v - qi1 * sin_v
q[h * HD + 2 * i + 1] = qi * sin_v + qi1 * cos_v
for h in range(KV_HEADS):
for i in range(HD // 2):
freq = freqs[pos, i]
cos_v, sin_v = math.cos(freq), math.sin(freq)
ki, ki1 = k[h * HD + 2 * i], k[h * HD + 2 * i + 1]
k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v
k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v
# Attention (single token)
o = np.zeros(DIM, dtype=np.float32)
# Append to KV cache (KV_HEADS entries)
for kv in range(KV_HEADS):
kh = k[kv * HD:(kv + 1) * HD].reshape(1, HD)
vh = v[kv * HD:(kv + 1) * HD].reshape(1, HD)
k_cache[L][kv] = np.vstack([k_cache[L][kv], kh])
v_cache[L][kv] = np.vstack([v_cache[L][kv], vh])
# GQA attention: each Q head uses its corresponding KV head
o = np.zeros(Q_DIM, dtype=np.float32)
for h in range(HEADS):
kv = h // GQA_RATIO
qh = q[h * HD:(h + 1) * HD]
kh = k[h * HD:(h + 1) * HD]
vh = v[h * HD:(h + 1) * HD]
score = np.dot(qh, kh) / math.sqrt(HD)
o[h * HD:(h + 1) * HD] = vh
scores = k_cache[L][kv] @ qh / math.sqrt(HD)
attn = softmax(scores)
o[h * HD:(h + 1) * HD] = attn @ v_cache[L][kv]
# Residual + output projection
x2 = x + W[f'Wo{L}'] @ o
x2 = x + res_alpha * (W[f'Wo{L}'] @ o)
# FFN
x2n = rmsnorm(x2, W[f'rms2_{L}'])
h1 = W[f'W1_{L}'] @ x2n
h3 = W[f'W3_{L}'] @ x2n
# SiLU
h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3
ffn_out = W[f'W2_{L}'] @ h1
x = x2 + ffn_out
x = x2 + res_alpha * ffn_out
x = rmsnorm(x, W['rms_final'])
@ -220,8 +272,11 @@ def generate_text(W, max_tokens=64, temperature=0.8):
next_tok = int(np.argmax(logits))
else:
logits = logits / temperature
probs = softmax(logits)
next_tok = int(np.random.choice(VOCAB, p=probs))
top_k = 50
top_idx = np.argpartition(logits, -top_k)[-top_k:]
top_logits = logits[top_idx]
probs = softmax(top_logits)
next_tok = int(top_idx[np.random.choice(len(top_idx), p=probs)])
if next_tok == 2:
break
@ -281,6 +336,8 @@ def sysmetrics_thread():
RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)')
RE_CONFIG_GQA = re.compile(r'dim=(\d+) q_dim=(\d+) kv_dim=(\d+) hd=(\d+) hidden=(\d+) seq=(\d+) vocab=(\d+)')
RE_MODEL_NAME = re.compile(r'ANE Dynamic Training: (.+?) \((\d+) layers')
RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)')
RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing')
RE_KERNELS_DYN = re.compile(r'Kernels: (\d+) compiled, (\d+) weight-bearing')
@ -297,10 +354,67 @@ RE_ANE_TFLOPS = re.compile(r'ANE TFLOPS:\s+([\d.]+)')
RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%')
RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Compile|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)')
RE_COMPILED = re.compile(r'Compiled (\d+) kernels in (\d+)ms')
RE_CKPT_SAVED = re.compile(r'\[ckpt saved, best_loss=([\d.]+)\]')
RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW')
RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW')
RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW')
USE_WANDB = False
def wandb_log_step():
"""Log current state to wandb. Called after each step update."""
if not USE_WANDB:
return
d = {'step': S.step, 'loss': S.loss, 'best_loss': S.best_loss}
if S.ms_per_step > 0:
d['ms_per_step'] = S.ms_per_step
lr = S.training.get('lr')
if lr:
try:
d['lr'] = float(lr)
except ValueError:
pass
ct = S.component_timing
if ct:
for k, v in ct.items():
if k != '_dynamic':
d[f'timing/{k}'] = v
fl = S.flops
if fl.get('ane_tflops'):
d['perf/ane_tflops'] = fl['ane_tflops']
if fl.get('ane_util'):
d['perf/ane_util_pct'] = fl['ane_util']
pw = S.power
if pw['ane'] > 0:
d['power/ane_w'] = pw['ane']
if pw['cpu'] > 0:
d['power/cpu_w'] = pw['cpu']
wandb.log(d, step=S.step)
def _sync_globals_from_parsed(cfg):
"""Sync dashboard globals from parsed binary output so text gen uses correct dims."""
global DIM, HIDDEN, HEADS, KV_HEADS, HD, SEQ, VOCAB, NLAYERS
global Q_DIM, KV_DIM, GQA_RATIO
if 'dim' in cfg:
DIM = cfg['dim']
if 'hidden' in cfg:
HIDDEN = cfg['hidden']
if 'heads' in cfg:
HEADS = cfg['heads']
if 'kv_heads' in cfg:
KV_HEADS = cfg['kv_heads']
if 'hd' in cfg:
HD = cfg['hd']
if 'seq' in cfg:
SEQ = cfg['seq']
if 'vocab' in cfg:
VOCAB = cfg['vocab']
if 'layers' in cfg:
NLAYERS = cfg['layers']
Q_DIM = HEADS * HD
KV_DIM = KV_HEADS * HD
GQA_RATIO = HEADS // KV_HEADS if KV_HEADS else 1
def parse_line(line):
S.logs.append(line)
# Parse JSON lines from static pipeline ({"type":"step",...} or {"type":"batch",...})
@ -329,6 +443,7 @@ def parse_line(line):
ct[k[2:]] = j[k] # strip 't_' prefix
if ct:
S.component_timing = ct
wandb_log_step()
return
elif jt == 'batch':
S.batch_num = j.get('batch', S.batch_num)
@ -346,9 +461,21 @@ def parse_line(line):
return
except (json.JSONDecodeError, KeyError):
pass
m = RE_MODEL_NAME.search(line)
if m:
S.model_config['name'] = m[1]
S.model_config['layers'] = int(m[2])
m = RE_CONFIG_GQA.search(line)
if m:
d, qd, kvd, hd, hid, seq, voc = map(int, m.groups())
S.model_config.update(dim=d, q_dim=qd, kv_dim=kvd, hd=hd, hidden=hid, seq=seq, vocab=voc,
heads=qd//hd, kv_heads=kvd//hd)
_sync_globals_from_parsed(S.model_config)
return
m = RE_CONFIG.search(line)
if m:
S.model_config = dict(zip(['dim', 'hidden', 'heads', 'seq', 'vocab', 'layers'], map(int, m.groups())))
_sync_globals_from_parsed(S.model_config)
return
m = RE_PARAMS.search(line)
if m:
@ -388,6 +515,7 @@ def parse_line(line):
S.ms_per_step = dt * 1000
S.loss_history.append((S.step, S.loss))
S.best_loss = min(S.best_loss, S.loss)
wandb_log_step()
return
m = RE_BATCH.search(line)
if m:
@ -424,6 +552,11 @@ def parse_line(line):
S.compiles = int(m[1])
S.compile_ms += float(m[2])
return
m = RE_CKPT_SAVED.search(line)
if m:
if USE_WANDB:
wandb.log({'checkpoint/best_loss': float(m[1]), 'checkpoint/saved': True}, step=S.step)
return
m = RE_EFFICIENCY.search(line)
if m:
S.efficiency[m[1].strip()] = m[2].strip()
@ -523,15 +656,19 @@ def draw(term):
buf = []
def put(y, x, text, style=''):
def put(y, x, text, style='', clear_eol=False):
if 0 <= y < h and x < w:
text = text[:w - x]
suffix = term.clear_eol if clear_eol else ''
if style:
buf.append(term.move(y, x) + style + text + term.normal)
buf.append(term.move(y, x) + style + text + term.normal + suffix)
return
buf.append(term.move(y, x) + text)
buf.append(term.move(y, x) + text + suffix)
buf.append(term.home + term.clear)
buf.append(term.home)
# Clear each line individually (avoids full-screen flash from term.clear)
for y in range(h):
buf.append(term.move(y, 0) + term.clear_eol)
mid_x = w // 2
right_w = w - mid_x - 1
@ -539,14 +676,17 @@ def draw(term):
row = 0
# Model Config header
hdr = '\u2500 Model Config '
put(row, 0, '\u250c' + hdr + '\u2500' * max(0, w - len(hdr) - 2) + '\u2510', term.cyan)
# Model Config header — use parsed name from binary if available, else CLI arg
model_label = S.model_config.get('name', S.active_model)
keys_hint = '[r]estart [g]en [q]uit'
hdr_text = f'\u2500 {model_label} \u2500\u2500 {keys_hint} '
put(row, 0, '\u250c' + hdr_text + '\u2500' * max(0, w - len(hdr_text) - 2) + '\u2510', term.cyan)
row += 1
cfg = S.model_config
if cfg:
line1 = f"stories110M dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}"
gqa_str = f" kv_heads={cfg.get('kv_heads', '')}" if cfg.get('kv_heads', cfg.get('heads', 0)) != cfg.get('heads', 0) else ''
line1 = f"dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')}{gqa_str} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}"
put(row, 0, '\u2502', term.cyan)
put(row, 2, line1)
put(row, w - 1, '\u2502', term.cyan)
@ -764,9 +904,10 @@ def set_nonblock(fd):
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
def spawn_training(resume=False, steps=10000, dynamic=False, ane=False, scratch=False,
lr=None, accum=None, no_ane_extras=False):
lr=None, accum=None, no_ane_extras=False, data=None, model=None):
if dynamic:
cmd = 'cd training_dynamic && make 2>&1 && ./train'
model_arg = f' MODEL={model}' if model else ''
cmd = f'cd training_dynamic && make{model_arg} 2>&1 && ./train'
elif ane:
cmd = 'make train_large_ane 2>&1 && ./train_large_ane'
else:
@ -781,6 +922,8 @@ def spawn_training(resume=False, steps=10000, dynamic=False, ane=False, scratch=
cmd += f' --accum {accum}'
if no_ane_extras and ane:
cmd += ' --no-ane-extras'
if data is not None:
cmd += f' --data {data}'
cmd += f' --steps {steps}'
proc = subprocess.Popen(
['bash', '-c', cmd],
@ -802,9 +945,12 @@ def spawn_powermetrics():
return None
def main():
parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)')
parser = argparse.ArgumentParser(description='ANE Training Dashboard')
parser.add_argument('--resume', action='store_true', help='Resume from checkpoint')
parser.add_argument('--dynamic', action='store_true', help='Dynamic weight pipeline (training_dynamic/)')
parser.add_argument('--model', type=str, default=None,
choices=list(MODEL_CONFIGS.keys()),
help='Model config (default: stories110m for static, qwen3_06b for dynamic)')
parser.add_argument('--ane', action='store_true', help='PR#19: ANE-offloaded classifier/softmax/rmsnorm_bwd')
parser.add_argument('--no-ane-extras', action='store_true', help='Disable ANE extras (use with --ane)')
parser.add_argument('--scratch', action='store_true', help='Train from scratch (random init)')
@ -814,21 +960,62 @@ def main():
parser.add_argument('--no-powermetrics', action='store_true')
parser.add_argument('--no-generate', action='store_true', help='Disable text generation')
parser.add_argument('--steps', type=int, default=10000, help='Total steps (default: 10000)')
parser.add_argument('--data', type=str, default=None, help='Path to training data shard (.bin)')
parser.add_argument('--wandb', action='store_true', help='Log to Weights & Biases')
parser.add_argument('--wandb-project', type=str, default='ane-training', help='W&B project name')
parser.add_argument('--wandb-name', type=str, default=None, help='W&B run name')
args = parser.parse_args()
if args.infinite:
args.steps = 999999999
S.total_steps = args.steps
global CKPT_PATH
CKPT_PATH = CKPT_PATH_DYNAMIC if args.dynamic else CKPT_PATH_STATIC
# Select model
if args.model is None:
args.model = 'qwen3_06b' if args.dynamic else 'stories110m'
cfg = MODEL_CONFIGS[args.model]
# Auto-enable dynamic for models without a static pipeline
if cfg['ckpt_static'] is None:
args.dynamic = True
set_model_config(args.model)
S.active_model = args.model
# For dynamic: default to --scratch when --resume not given
if args.dynamic and not args.resume:
args.scratch = True
global CKPT_PATH, USE_WANDB
CKPT_PATH = cfg['ckpt_dynamic'] if args.dynamic else cfg['ckpt_static']
# Weights & Biases
if args.wandb:
if not HAS_WANDB:
print('pip install wandb')
sys.exit(1)
run_name = args.wandb_name or f'{args.model}-{"resume" if args.resume else "scratch"}'
wandb.init(
project=args.wandb_project,
name=run_name,
config={
'model': args.model,
'dim': DIM, 'hidden': HIDDEN, 'heads': HEADS,
'kv_heads': KV_HEADS, 'hd': HD, 'seq': SEQ,
'vocab': VOCAB, 'nlayers': NLAYERS,
'q_dim': Q_DIM, 'kv_dim': KV_DIM,
'pipeline': 'dynamic' if args.dynamic else 'static',
'resume': args.resume,
'lr': args.lr, 'accum': args.accum,
'steps': args.steps,
},
)
USE_WANDB = True
term = Terminal()
procs = []
train_proc = spawn_training(resume=args.resume, steps=args.steps, dynamic=args.dynamic,
scratch=args.scratch, lr=args.lr, accum=args.accum,
ane=args.ane, no_ane_extras=args.no_ane_extras)
ane=args.ane, no_ane_extras=args.no_ane_extras,
data=args.data, model=args.model)
S.train_pid = train_proc.pid
procs.append(train_proc)
@ -856,6 +1043,8 @@ def main():
p.terminate()
except Exception:
pass
if USE_WANDB:
wandb.finish()
signal.signal(signal.SIGINT, lambda *a: cleanup())
signal.signal(signal.SIGTERM, lambda *a: cleanup())
@ -970,11 +1159,12 @@ def main():
train_proc.wait()
train_proc = spawn_training(resume=True, steps=args.steps, dynamic=args.dynamic,
lr=args.lr, accum=args.accum,
ane=args.ane, no_ane_extras=args.no_ane_extras)
ane=args.ane, no_ane_extras=args.no_ane_extras,
data=args.data, model=S.active_model)
S.train_pid = train_proc.pid
procs = [p for p in procs if p.poll() is None]
procs.append(train_proc)
S.logs.append('[dashboard] Restarted with --resume')
S.logs.append(f'[dashboard] Restarted {S.active_model} with --resume')
need_draw = True
elif key == 'g':
with S.gen_lock:

View File

@ -254,15 +254,16 @@ int main(int argc, char *argv[]) {
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
if (ane_extras) printf("NEW: final_rmsnorm, classifier_fwd, softmax, rmsnorm_bwd on ANE\n");
else printf("ANE extras DISABLED (classifier/softmax/rmsnorm_bwd on CPU)\n");
if (!load_pretrained(lw, rms_final, embed, model_path)) {
printf("Pretrained load failed, using random init\n");
{
printf(" Training from scratch (random init)\n");
srand48(42);
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
float res_scale = 1.0f/sqrtf(2.0f*NLAYERS); // LLaMA-style output proj scaling
for (int L=0; L<NLAYERS; L++) {
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*res_scale*(2*drand48()-1);}
for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*(2*drand48()-1);
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*res_scale*(2*drand48()-1);
for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
for(int i=0;i<DIM;i++){lw[L].rms_att[i]=1.0f; lw[L].rms_ffn[i]=1.0f;}
}

View File

@ -1,9 +1,17 @@
CC = xcrun clang
CFLAGS = -O2 -framework Foundation -framework IOSurface -framework Accelerate \
CFLAGS = -O2 -DACCELERATE_NEW_LAPACK -framework Foundation -framework IOSurface -framework Accelerate \
-isysroot $(shell xcrun --show-sdk-path) -fobjc-arc
train: train.m config.h io.h cpu_ops.h mil_dynamic.h
$(CC) $(CFLAGS) -o train train.m
# Model selection: make MODEL=qwen3_06b (default)
# Available: stories110m, qwen3_06b
MODEL ?= qwen3_06b
MODEL_HDR = models/$(MODEL).h
train: train.m config.h io.h cpu_ops.h mil_dynamic.h $(MODEL_HDR)
@echo "Building for model: $(MODEL)"
$(CC) $(CFLAGS) -include $(MODEL_HDR) -o train train.m
clean:
rm -f train
.PHONY: clean

View File

@ -1,4 +1,5 @@
// config.h — Stories110M model config, structs, ANE init
// config.h — Model-agnostic structs, derived sizes, ANE init
// Model-specific dims come from models/*.h, selected via -DMODEL_HEADER
#pragma once
#import <Foundation/Foundation.h>
#import <objc/runtime.h>
@ -15,22 +16,21 @@
#include <fcntl.h>
#include <arm_neon.h>
// Stories110M config
#define DIM 768
#define HIDDEN 2048
#define HEADS 12
#define HD (DIM/HEADS)
#define SEQ 256
#define NLAYERS 12
#define VOCAB 32000
// Include selected model config
// MODEL_HEADER is set by Makefile via -include models/xxx.h
#ifndef MODEL_NAME
#error "No model selected. Build with: make MODEL=qwen3_06b (or stories110m)"
#endif
// Weight sizes per layer
#define WQ_SZ (DIM*DIM)
#define WO_SZ (DIM*DIM)
// Derived weight sizes per layer (GQA-aware)
#define WQ_SZ (Q_DIM*DIM)
#define WK_SZ (KV_DIM*DIM)
#define WV_SZ (KV_DIM*DIM)
#define WO_SZ (DIM*Q_DIM)
#define W1_SZ (HIDDEN*DIM)
#define W2_SZ (DIM*HIDDEN)
#define W3_SZ (HIDDEN*DIM)
#define LAYER_PARAMS (4*WQ_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM)
#define LAYER_PARAMS (WQ_SZ + WK_SZ + WV_SZ + WO_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM)
// Attention score channels for SDPA backward
#define SCORE_CH (HEADS*SEQ)
@ -62,6 +62,18 @@ typedef struct {
// ANE kernel handle
typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; } Kern;
// Per-layer IOSurfaces for pre-staged weights
typedef struct {
IOSurfaceRef sdpaFwd_in, woFwd_in, ffnFused_in;
IOSurfaceRef ffnBwdW2t_in, ffnBwdW13t_in, wotBwd_in, qBwd_in, kvBwd_in;
} PerLayerSurfaces;
// Per-layer ANE requests (bound to per-layer IOSurfaces)
typedef struct {
void *sdpaFwd, *woFwd, *ffnFused;
void *ffnBwdW2t, *ffnBwdW13t, *wotBwd, *qBwd, *kvBwd;
} PerLayerRequests;
// Checkpoint header
typedef struct {
int magic, version, step, total_steps;
@ -69,14 +81,10 @@ typedef struct {
float lr, loss;
double cum_compile, cum_train, cum_wall;
int cum_steps, cum_batches, adam_t;
int pad[3];
int kv_heads, head_dim, q_dim; // GQA fields
// Note: was int pad[3] in v3, now stores GQA info in v4+
} CkptHdr;
// llama2.c model file header
typedef struct {
int dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len;
} Llama2Config;
// Globals
static Class g_D, g_I, g_AR, g_AIO;
static mach_timebase_info_data_t g_tb;
@ -97,8 +105,8 @@ static void adam_free(AdamState *s) { free(s->m); free(s->v); }
static LayerWeights layer_weights_alloc(void) {
LayerWeights w;
w.Wq=(float*)malloc(WQ_SZ*4); w.Wk=(float*)malloc(WQ_SZ*4);
w.Wv=(float*)malloc(WQ_SZ*4); w.Wo=(float*)malloc(WO_SZ*4);
w.Wq=(float*)malloc(WQ_SZ*4); w.Wk=(float*)malloc(WK_SZ*4);
w.Wv=(float*)malloc(WV_SZ*4); w.Wo=(float*)malloc(WO_SZ*4);
w.W1=(float*)malloc(W1_SZ*4); w.W2=(float*)malloc(W2_SZ*4); w.W3=(float*)malloc(W3_SZ*4);
w.rms_att=(float*)malloc(DIM*4); w.rms_ffn=(float*)malloc(DIM*4);
return w;
@ -109,7 +117,7 @@ static void layer_weights_free(LayerWeights *w) {
}
static LayerAdam layer_adam_alloc(void) {
LayerAdam a;
a.Wq=adam_alloc(WQ_SZ); a.Wk=adam_alloc(WQ_SZ); a.Wv=adam_alloc(WQ_SZ); a.Wo=adam_alloc(WO_SZ);
a.Wq=adam_alloc(WQ_SZ); a.Wk=adam_alloc(WK_SZ); a.Wv=adam_alloc(WV_SZ); a.Wo=adam_alloc(WO_SZ);
a.W1=adam_alloc(W1_SZ); a.W2=adam_alloc(W2_SZ); a.W3=adam_alloc(W3_SZ);
a.rms_att=adam_alloc(DIM); a.rms_ffn=adam_alloc(DIM);
return a;
@ -123,8 +131,8 @@ static LayerActs layer_acts_alloc(void) {
LayerActs a;
a.layer_in=(float*)malloc(SEQ*DIM*4);
a.xnorm=(float*)malloc(SEQ*DIM*4);
a.Q=(float*)malloc(SEQ*DIM*4); a.K=(float*)malloc(SEQ*DIM*4); a.V=(float*)malloc(SEQ*DIM*4);
a.attn_out=(float*)malloc(SEQ*DIM*4); a.o_out=(float*)malloc(SEQ*DIM*4);
a.Q=(float*)malloc(SEQ*Q_DIM*4); a.K=(float*)malloc(SEQ*KV_DIM*4); a.V=(float*)malloc(SEQ*KV_DIM*4);
a.attn_out=(float*)malloc(SEQ*Q_DIM*4); a.o_out=(float*)malloc(SEQ*DIM*4);
a.x2=(float*)malloc(SEQ*DIM*4); a.x2norm=(float*)malloc(SEQ*DIM*4);
a.h1=(float*)malloc(SEQ*HIDDEN*4); a.h3=(float*)malloc(SEQ*HIDDEN*4);
a.silu_out=(float*)malloc(SEQ*HIDDEN*4); a.ffn_out=(float*)malloc(SEQ*DIM*4);
@ -138,15 +146,15 @@ static void layer_acts_free(LayerActs *a) {
}
static LayerGrads layer_grads_alloc(void) {
LayerGrads g;
g.Wq=(float*)calloc(WQ_SZ,4); g.Wk=(float*)calloc(WQ_SZ,4);
g.Wv=(float*)calloc(WQ_SZ,4); g.Wo=(float*)calloc(WO_SZ,4);
g.Wq=(float*)calloc(WQ_SZ,4); g.Wk=(float*)calloc(WK_SZ,4);
g.Wv=(float*)calloc(WV_SZ,4); g.Wo=(float*)calloc(WO_SZ,4);
g.W1=(float*)calloc(W1_SZ,4); g.W2=(float*)calloc(W2_SZ,4); g.W3=(float*)calloc(W3_SZ,4);
g.rms_att=(float*)calloc(DIM,4); g.rms_ffn=(float*)calloc(DIM,4);
return g;
}
static void layer_grads_zero(LayerGrads *g) {
memset(g->Wq,0,WQ_SZ*4);memset(g->Wk,0,WQ_SZ*4);
memset(g->Wv,0,WQ_SZ*4);memset(g->Wo,0,WO_SZ*4);
memset(g->Wq,0,WQ_SZ*4);memset(g->Wk,0,WK_SZ*4);
memset(g->Wv,0,WV_SZ*4);memset(g->Wo,0,WO_SZ*4);
memset(g->W1,0,W1_SZ*4);memset(g->W2,0,W2_SZ*4);memset(g->W3,0,W3_SZ*4);
memset(g->rms_att,0,DIM*4);memset(g->rms_ffn,0,DIM*4);
}

View File

@ -53,13 +53,13 @@ static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, c
free(ss); free(rrms); free(dot);
}
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps, float wd) {
float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t);
for (size_t i=0; i<s->n; i++) {
s->m[i] = b1*s->m[i] + (1-b1)*g[i];
s->v[i] = b2*s->v[i] + (1-b2)*g[i]*g[i];
float mh = s->m[i]/bc1, vh = s->v[i]/bc2;
w[i] -= lr * mh / (sqrtf(vh) + eps);
w[i] -= lr * (mh / (sqrtf(vh) + eps) + wd * w[i]);
}
}
@ -162,3 +162,23 @@ static void embed_backward(float *d_embed, const float *dx, const uint16_t *toke
d_embed[tok*dim + d] += dx[d*seq + t];
}
}
// RoPE backward (in-place): inverse rotation on dQ/dK gradients
// Data layout: [DIM, SEQ] channel-first, DIM = nheads * hd
static void rope_backward_inplace(float *dx, int seq, int dim, int hd) {
int nheads = dim / hd;
for (int h = 0; h < nheads; h++) {
for (int i = 0; i < hd/2; i++) {
float freq = 1.0f / powf(10000.0f, 2.0f * i / (float)hd);
for (int p = 0; p < seq; p++) {
float theta = p * freq;
float cos_t = cosf(theta), sin_t = sinf(theta);
int idx0 = (h * hd + 2 * i) * seq + p;
int idx1 = (h * hd + 2 * i + 1) * seq + p;
float v0 = dx[idx0], v1 = dx[idx1];
dx[idx0] = v0 * cos_t + v1 * sin_t;
dx[idx1] = -v0 * sin_t + v1 * cos_t;
}
}
}
}

View File

@ -1,4 +1,5 @@
// io.h — IOSurface helpers, NEON conversion, kernel compile/eval
// Updated for GQA (Qwen3-0.6B): Q_DIM != DIM, separate KV heads
#pragma once
#include "config.h"
@ -74,17 +75,15 @@ static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int
IOSurfaceUnlock(s, 0, NULL);
}
// fp32 IOSurface I/O (for dynamic matmul kernels that use fp32 input/output)
// Layout: [1, IC, 1, SP] where SP = SEQ + OC
// Write activations at sp[0:SEQ] and weights at sp[SEQ:SEQ+OC]
// fp16 IOSurface I/O (for dynamic matmul kernels with fp16 input/output)
static void io_write_dyn(IOSurfaceRef s, const float *act, int ic, int seq,
const float *W, int oc) {
int sp = seq + oc;
IOSurfaceLock(s, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(s);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < ic; d++) {
memcpy(buf + d*sp, act + d*seq, seq*4);
memcpy(buf + d*sp + seq, W + d*oc, oc*4);
cvt_f32_f16(buf + d*sp, act + d*seq, seq);
cvt_f32_f16(buf + d*sp + seq, W + d*oc, oc);
}
IOSurfaceUnlock(s, 0, NULL);
}
@ -92,7 +91,7 @@ static void io_write_dyn(IOSurfaceRef s, const float *act, int ic, int seq,
// Read output from dynamic matmul kernel: [1, OC, 1, SEQ]
static void io_read_dyn(IOSurfaceRef s, float *out, int oc, int seq) {
IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
memcpy(out, (float*)IOSurfaceGetBaseAddress(s), oc * seq * 4);
cvt_f16_f32(out, (_Float16*)IOSurfaceGetBaseAddress(s), oc * seq);
IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}
@ -145,3 +144,208 @@ static void ane_eval(Kern *k) {
id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; NSError *e = nil;
((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
}
static void ane_eval_req(Kern *k, void *request) {
id mdl = (__bridge id)k->model; id req = (__bridge id)request; NSError *e = nil;
((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
}
static void *make_request(Kern *k, IOSurfaceRef ioIn) {
id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), ioIn);
id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
id req = ((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
@selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
@[wI], @[@0], @[wO], @[@0], nil, nil, @0);
return (void*)CFBridgingRetain(req);
}
// ===== Per-layer weight staging for GQA =====
// sdpaFwd: [1, DIM, 1, SEQ + Q_DIM + KV_DIM + KV_DIM] fp16 — no Wo (separate kernel)
// Wq: [DIM, Q_DIM], Wk: [DIM, KV_DIM], Wv: [DIM, KV_DIM]
#define SDPA_FWD_SP (SEQ + Q_DIM + KV_DIM + KV_DIM)
static void stage_sdpa_fwd_weights(IOSurfaceRef s, const float *Wq, const float *Wk, const float *Wv) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < DIM; d++) {
cvt_f32_f16(buf + d*SDPA_FWD_SP + SEQ, Wq + d*Q_DIM, Q_DIM);
cvt_f32_f16(buf + d*SDPA_FWD_SP + SEQ+Q_DIM, Wk + d*KV_DIM, KV_DIM);
cvt_f32_f16(buf + d*SDPA_FWD_SP + SEQ+Q_DIM+KV_DIM, Wv + d*KV_DIM, KV_DIM);
}
IOSurfaceUnlock(s, 0, NULL);
}
static void write_sdpa_fwd_acts(IOSurfaceRef s, const float *xnorm) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < DIM; d++)
cvt_f32_f16(buf + d*SDPA_FWD_SP, xnorm + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// woFwd: [1, Q_DIM, 1, SEQ + DIM] fp16 — Wo: [Q_DIM, DIM]
#define WO_FWD_SP (SEQ + DIM)
static void stage_wo_fwd_weights(IOSurfaceRef s, const float *Wo) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < Q_DIM; d++)
cvt_f32_f16(buf + d*WO_FWD_SP + SEQ, Wo + d*DIM, DIM);
IOSurfaceUnlock(s, 0, NULL);
}
static void write_wo_fwd_acts(IOSurfaceRef s, const float *attn_out) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < Q_DIM; d++)
cvt_f32_f16(buf + d*WO_FWD_SP, attn_out + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// ffnFused: [1, DIM, 1, 2*SEQ+3*HIDDEN] fp16
#define FFN_FUSED_SP (2*SEQ + 3*HIDDEN)
static void stage_ffn_fused_weights(IOSurfaceRef s,
const float *W1t, const float *W3t, const float *W2_orig) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < DIM; d++) {
cvt_f32_f16(buf + d*FFN_FUSED_SP + 2*SEQ, W1t + d*HIDDEN, HIDDEN);
cvt_f32_f16(buf + d*FFN_FUSED_SP + 2*SEQ+HIDDEN, W3t + d*HIDDEN, HIDDEN);
cvt_f32_f16(buf + d*FFN_FUSED_SP + 2*SEQ+2*HIDDEN, W2_orig + d*HIDDEN, HIDDEN);
}
IOSurfaceUnlock(s, 0, NULL);
}
static void write_ffn_fused_acts(IOSurfaceRef s, const float *x2norm, const float *x2) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < DIM; d++) {
cvt_f32_f16(buf + d*FFN_FUSED_SP, x2norm + d*SEQ, SEQ);
cvt_f32_f16(buf + d*FFN_FUSED_SP + SEQ, x2 + d*SEQ, SEQ);
}
IOSurfaceUnlock(s, 0, NULL);
}
// ffnBwdW2t: [1, DIM, 1, SEQ+HIDDEN] fp16
#define FFN_BWD_W2T_SP (SEQ + HIDDEN)
static void stage_ffn_bwd_w2t_weights(IOSurfaceRef s, const float *W2) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < DIM; d++)
cvt_f32_f16(buf + d*FFN_BWD_W2T_SP + SEQ, W2 + d*HIDDEN, HIDDEN);
IOSurfaceUnlock(s, 0, NULL);
}
static void write_ffn_bwd_w2t_acts(IOSurfaceRef s, const float *dffn) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < DIM; d++)
cvt_f32_f16(buf + d*FFN_BWD_W2T_SP, dffn + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// ffnBwdW13t: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp16
#define FFN_BWD_W13T_SP (2*SEQ + 2*DIM)
static void stage_ffn_bwd_w13t_weights(IOSurfaceRef s, const float *W1, const float *W3) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < HIDDEN; d++) {
cvt_f32_f16(buf + d*FFN_BWD_W13T_SP + 2*SEQ, W1 + d*DIM, DIM);
cvt_f32_f16(buf + d*FFN_BWD_W13T_SP + 2*SEQ + DIM, W3 + d*DIM, DIM);
}
IOSurfaceUnlock(s, 0, NULL);
}
static void write_ffn_bwd_w13t_acts(IOSurfaceRef s, const float *dh1, const float *dh3) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < HIDDEN; d++) {
cvt_f32_f16(buf + d*FFN_BWD_W13T_SP, dh1 + d*SEQ, SEQ);
cvt_f32_f16(buf + d*FFN_BWD_W13T_SP + SEQ, dh3 + d*SEQ, SEQ);
}
IOSurfaceUnlock(s, 0, NULL);
}
// wotBwd: [1, DIM, 1, SEQ+Q_DIM] fp16 — Wo is [DIM, Q_DIM], matmul gives Wo^T @ dy
#define WOT_BWD_SP (SEQ + Q_DIM)
static void stage_wot_bwd_weights(IOSurfaceRef s, const float *Wo) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < DIM; d++)
cvt_f32_f16(buf + d*WOT_BWD_SP + SEQ, Wo + d*Q_DIM, Q_DIM);
IOSurfaceUnlock(s, 0, NULL);
}
static void write_wot_bwd_acts(IOSurfaceRef s, const float *dy) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < DIM; d++)
cvt_f32_f16(buf + d*WOT_BWD_SP, dy + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// qBwd: [1, Q_DIM, 1, SEQ+DIM] fp16 — Wq is [Q_DIM, DIM], matmul gives Wq^T @ dq
#define Q_BWD_SP (SEQ + DIM)
static void stage_q_bwd_weights(IOSurfaceRef s, const float *Wq) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < Q_DIM; d++)
cvt_f32_f16(buf + d*Q_BWD_SP + SEQ, Wq + d*DIM, DIM);
IOSurfaceUnlock(s, 0, NULL);
}
static void write_q_bwd_acts(IOSurfaceRef s, const float *dq) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < Q_DIM; d++)
cvt_f32_f16(buf + d*Q_BWD_SP, dq + d*SEQ, SEQ);
IOSurfaceUnlock(s, 0, NULL);
}
// kvBwd: [1, KV_DIM, 1, 2*SEQ+2*DIM] fp16 — dk @ Wk + dv @ Wv → dx_kv
#define KV_BWD_SP (2*SEQ + 2*DIM)
static void stage_kv_bwd_weights(IOSurfaceRef s, const float *Wk, const float *Wv) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < KV_DIM; d++) {
cvt_f32_f16(buf + d*KV_BWD_SP + 2*SEQ, Wk + d*DIM, DIM);
cvt_f32_f16(buf + d*KV_BWD_SP + 2*SEQ + DIM, Wv + d*DIM, DIM);
}
IOSurfaceUnlock(s, 0, NULL);
}
static void write_kv_bwd_acts(IOSurfaceRef s, const float *dk, const float *dv) {
IOSurfaceLock(s, 0, NULL);
_Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < KV_DIM; d++) {
cvt_f32_f16(buf + d*KV_BWD_SP, dk + d*SEQ, SEQ);
cvt_f32_f16(buf + d*KV_BWD_SP + SEQ, dv + d*SEQ, SEQ);
}
IOSurfaceUnlock(s, 0, NULL);
}
// Free per-layer surfaces and requests
static void free_per_layer(PerLayerSurfaces *pls, PerLayerRequests *plr) {
for (int L = 0; L < NLAYERS; L++) {
CFRelease(pls[L].sdpaFwd_in); CFRelease(pls[L].woFwd_in); CFRelease(pls[L].ffnFused_in);
CFRelease(pls[L].ffnBwdW2t_in); CFRelease(pls[L].ffnBwdW13t_in);
CFRelease(pls[L].wotBwd_in); CFRelease(pls[L].qBwd_in); CFRelease(pls[L].kvBwd_in);
CFRelease(plr[L].sdpaFwd); CFRelease(plr[L].woFwd); CFRelease(plr[L].ffnFused);
CFRelease(plr[L].ffnBwdW2t); CFRelease(plr[L].ffnBwdW13t);
CFRelease(plr[L].wotBwd); CFRelease(plr[L].qBwd); CFRelease(plr[L].kvBwd);
}
}
// GQA helpers: tile KV from KV_HEADS to HEADS, and reduce HEADS to KV_HEADS
// tile_kv: input [KV_DIM, SEQ], output [Q_DIM, SEQ]
// Each KV head is duplicated GQA_RATIO times
static void gqa_tile_kv(float *out, const float *in, int seq) {
for (int kv = 0; kv < KV_HEADS; kv++) {
for (int r = 0; r < GQA_RATIO; r++) {
int q_head = kv * GQA_RATIO + r;
memcpy(out + q_head * HD * seq, in + kv * HD * seq, HD * seq * sizeof(float));
}
}
}
// reduce_kv: input [Q_DIM, SEQ], output [KV_DIM, SEQ]
// Sum contributions from Q heads sharing each KV head
static void gqa_reduce_kv(float *out, const float *in, int seq) {
memset(out, 0, KV_DIM * seq * sizeof(float));
for (int kv = 0; kv < KV_HEADS; kv++) {
for (int r = 0; r < GQA_RATIO; r++) {
int q_head = kv * GQA_RATIO + r;
const float *src = in + q_head * HD * seq;
float *dst = out + kv * HD * seq;
for (int i = 0; i < HD * seq; i++)
dst[i] += src[i];
}
}
}

View File

@ -1,7 +1,7 @@
// mil_dynamic.h — MIL generators using dynamic matmul (weights via IOSurface)
// Instead of conv(const_weight, x), we use matmul(x, W) where both come from input.
// Input layout: [1, IC, 1, SP] fp32, SP = SEQ + total_weight_cols
// Activations in sp[0:SEQ], weight matrices packed sequentially in sp[SEQ:]
// mil_dynamic.h — MIL generators for Qwen3-0.6B with GQA
// Q_DIM=2048 != DIM=1024, KV_DIM=1024, GQA_RATIO=2
// SDPA split: sdpaFwd (QKV proj + attention, no Wo) + woFwd (Wo matmul)
// Backward: qBwd + kvBwd (split from qkvBwd)
#pragma once
#include "io.h"
@ -11,143 +11,181 @@
"{\"coremltools-version\", \"9.0\"}})]\n{\n"
// Helper: generate a dynamic matmul within a MIL function
// Slices activation [1,ic,1,seq] and weight [1,ic,1,oc] from input, does matmul
// act_sp_off: spatial offset for activations (usually 0)
// w_sp_off: spatial offset for weight block
// Returns variable name of result [1,oc,1,seq] in fp16
static void gen_dyn_matmul(NSMutableString *m, const char *prefix,
int ic, int oc, int seq,
int act_sp_off, int w_sp_off,
const char *input_var) {
// Slice activations
[m appendFormat:@" tensor<int32, [4]> %s_ba = const()[name=string(\"%s_ba\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", prefix, prefix, act_sp_off];
[m appendFormat:@" tensor<int32, [4]> %s_sa = const()[name=string(\"%s_sa\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, ic, seq];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_act = slice_by_size(x=%s,begin=%s_ba,size=%s_sa)[name=string(\"%s_act\")];\n", ic, seq, prefix, input_var, prefix, prefix, prefix];
// Slice weight
[m appendFormat:@" tensor<int32, [4]> %s_bw = const()[name=string(\"%s_bw\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", prefix, prefix, w_sp_off];
[m appendFormat:@" tensor<int32, [4]> %s_sw = const()[name=string(\"%s_sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, ic, oc];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_wt = slice_by_size(x=%s,begin=%s_bw,size=%s_sw)[name=string(\"%s_wt\")];\n", ic, oc, prefix, input_var, prefix, prefix, prefix];
// Reshape act: [1,ic,1,seq] → [1,1,ic,seq] → transpose → [1,1,seq,ic]
[m appendFormat:@" tensor<int32, [4]> %s_ra = const()[name=string(\"%s_ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", prefix, prefix, ic, seq];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_a2 = reshape(shape=%s_ra,x=%s_act)[name=string(\"%s_a2\")];\n", ic, seq, prefix, prefix, prefix, prefix];
[m appendFormat:@" tensor<int32, [4]> %s_pm = const()[name=string(\"%s_pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n", prefix, prefix];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_a3 = transpose(perm=%s_pm,x=%s_a2)[name=string(\"%s_a3\")];\n", seq, ic, prefix, prefix, prefix, prefix];
// Reshape weight: [1,ic,1,oc] → [1,1,ic,oc]
[m appendFormat:@" tensor<int32, [4]> %s_rw = const()[name=string(\"%s_rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", prefix, prefix, ic, oc];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_W = reshape(shape=%s_rw,x=%s_wt)[name=string(\"%s_W\")];\n", ic, oc, prefix, prefix, prefix, prefix];
// matmul: [1,1,seq,ic] @ [1,1,ic,oc] → [1,1,seq,oc]
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_yh = matmul(transpose_x=bF,transpose_y=bF,x=%s_a3,y=%s_W)[name=string(\"%s_yh\")];\n", seq, oc, prefix, prefix, prefix, prefix];
// Transpose back + reshape: [1,1,seq,oc] → [1,1,oc,seq] → [1,oc,1,seq]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_yt = transpose(perm=%s_pm,x=%s_yh)[name=string(\"%s_yt\")];\n", oc, seq, prefix, prefix, prefix, prefix];
[m appendFormat:@" tensor<int32, [4]> %s_ro = const()[name=string(\"%s_ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, oc, seq];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_y = reshape(shape=%s_ro,x=%s_yt)[name=string(\"%s_y\")];\n", oc, seq, prefix, prefix, prefix, prefix];
}
// ===== Dynamic matmul kernel: y = x @ W =====
// Input: [1, IC, 1, SEQ+OC] fp32 — act[0:SEQ] + W[SEQ:SEQ+OC]
// Output: [1, OC, 1, SEQ] fp32
// Simple dynamic matmul kernel: y = x @ W, input [1,IC,1,SEQ+OC], output [1,OC,1,SEQ]
static NSString *gen_dyn_matmul_mil(int ic, int oc, int seq) {
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
int sp = seq + oc;
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", ic, sp];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", ic, sp];
gen_dyn_matmul(m, "mm", ic, oc, seq, 0, seq, "xh");
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=mm_y)[name=string(\"cout\")];\n", oc, seq];
[m appendString:@" } -> (y);\n}\n"];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", ic, sp];
gen_dyn_matmul(m, "mm", ic, oc, seq, 0, seq, "x");
[m appendString:@" } -> (mm_y);\n}\n"];
return m;
}
// ===== SDPA forward (dynamic weights) =====
// Replaces gen_sdpa_fwd_taps: RMSNorm done on CPU, this kernel does QKV matmul + SDPA + Wo matmul
// Input: [1, DIM, 1, SEQ + 4*DIM] fp32
// sp[0:SEQ] = xnorm (rmsnorm output, DIM channels)
// sp[SEQ:SEQ+DIM] = Wq[DIM,DIM]
// sp[SEQ+DIM:SEQ+2D] = Wk[DIM,DIM]
// sp[SEQ+2D:SEQ+3D] = Wv[DIM,DIM]
// sp[SEQ+3D:SEQ+4D] = Wo[DIM,DIM]
// Output: [1, 6*DIM, 1, SEQ] fp16 = concat(o_out, Q, K, V, attn_out, xnorm_pass)
// NOTE: mask is still a const weight (it doesn't change)
// ===== SDPA forward with GQA (no Wo) =====
// Input: [1, DIM, 1, SEQ + Q_DIM + KV_DIM + KV_DIM] fp16
// sp[0:SEQ] = xnorm [DIM, SEQ]
// sp[SEQ:SEQ+Q_DIM] = Wq [DIM, Q_DIM]
// sp[SEQ+Q_DIM:SEQ+Q_DIM+KVD] = Wk [DIM, KV_DIM]
// sp[SEQ+Q_DIM+KVD:...] = Wv [DIM, KV_DIM]
// Output: [1, Q_DIM+Q_DIM+KV_DIM+KV_DIM+DIM, 1, SEQ] fp16
// = concat(attn_out, Q_rope, K_rope, V, xnorm_pass)
static NSString *gen_sdpa_fwd_dynamic(void) {
float sc = 1.0f/sqrtf((float)HD);
int w_total = 4*DIM; // Wq+Wk+Wv+Wo
int sp_in = SEQ + w_total;
int sp_in = SDPA_FWD_SP;
int out_ch = Q_DIM + Q_DIM + KV_DIM + KV_DIM + DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
// Cast to fp16
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
// Slice xnorm [1,DIM,1,SEQ]
[m appendString:@" tensor<int32, [4]> bx = const()[name=string(\"bx\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sx = const()[name=string(\"sx\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=x,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ];
// Slice Wq [1,DIM,1,DIM]
// Slice Wq [1,DIM,1,Q_DIM]
[m appendFormat:@" tensor<int32, [4]> bq = const()[name=string(\"bq\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wq = slice_by_size(x=xh,begin=bq,size=sw)[name=string(\"Wq\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> swq = const()[name=string(\"swq\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, Q_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wq = slice_by_size(x=x,begin=bq,size=swq)[name=string(\"Wq\")];\n", DIM, Q_DIM];
// Slice Wk
[m appendFormat:@" tensor<int32, [4]> bk = const()[name=string(\"bk\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wk = slice_by_size(x=xh,begin=bk,size=sw)[name=string(\"Wk\")];\n", DIM, DIM];
// Slice Wk [1,DIM,1,KV_DIM]
[m appendFormat:@" tensor<int32, [4]> bk = const()[name=string(\"bk\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+Q_DIM];
[m appendFormat:@" tensor<int32, [4]> swk = const()[name=string(\"swk\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, KV_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wk = slice_by_size(x=x,begin=bk,size=swk)[name=string(\"Wk\")];\n", DIM, KV_DIM];
// Slice Wv
[m appendFormat:@" tensor<int32, [4]> bv = const()[name=string(\"bv\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wv = slice_by_size(x=xh,begin=bv,size=sw)[name=string(\"Wv\")];\n", DIM, DIM];
// Slice Wv [1,DIM,1,KV_DIM]
[m appendFormat:@" tensor<int32, [4]> bv = const()[name=string(\"bv\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+Q_DIM+KV_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wv = slice_by_size(x=x,begin=bv,size=swk)[name=string(\"Wv\")];\n", DIM, KV_DIM];
// Slice Wo
[m appendFormat:@" tensor<int32, [4]> bo = const()[name=string(\"bo\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+3*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wo = slice_by_size(x=xh,begin=bo,size=sw)[name=string(\"Wo\")];\n", DIM, DIM];
// Reshape for matmul: [1,D,1,S] → [1,1,D,S] → [1,1,S,D]
// Reshape xnorm for matmul: [1,DIM,1,SEQ] → [1,1,DIM,SEQ] → [1,1,SEQ,DIM]
[m appendFormat:@" tensor<int32, [4]> r2 = const()[name=string(\"r2\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=r2,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
// Reshape weights: [1,D,1,D] → [1,1,D,D]
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wq2 = reshape(shape=rw,x=Wq)[name=string(\"Wq2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wk2 = reshape(shape=rw,x=Wk)[name=string(\"Wk2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wv2 = reshape(shape=rw,x=Wv)[name=string(\"Wv2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wo2 = reshape(shape=rw,x=Wo)[name=string(\"Wo2\")];\n", DIM, DIM];
// Reshape weights
[m appendFormat:@" tensor<int32, [4]> rwq = const()[name=string(\"rwq\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, Q_DIM];
[m appendFormat:@" tensor<int32, [4]> rwk = const()[name=string(\"rwk\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, KV_DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wq2 = reshape(shape=rwq,x=Wq)[name=string(\"Wq2\")];\n", DIM, Q_DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wk2 = reshape(shape=rwk,x=Wk)[name=string(\"Wk2\")];\n", DIM, KV_DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wv2 = reshape(shape=rwk,x=Wv)[name=string(\"Wv2\")];\n", DIM, KV_DIM];
// QKV matmul: [1,1,S,D] @ [1,1,D,D] → [1,1,S,D]
// QKV matmul
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> qm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wq2)[name=string(\"qm\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> km = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wk2)[name=string(\"km\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> vm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wv2)[name=string(\"vm\")];\n", SEQ, DIM];
// Q: [1,1,SEQ,DIM] @ [1,1,DIM,Q_DIM] → [1,1,SEQ,Q_DIM]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> qm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wq2)[name=string(\"qm\")];\n", SEQ, Q_DIM];
// K: [1,1,SEQ,DIM] @ [1,1,DIM,KV_DIM] → [1,1,SEQ,KV_DIM]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> km = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wk2)[name=string(\"km\")];\n", SEQ, KV_DIM];
// V: same as K
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> vm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wv2)[name=string(\"vm\")];\n", SEQ, KV_DIM];
// Transpose back: [1,1,S,D] → [1,1,D,S] → reshape [1,D,1,S]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> qt = transpose(perm=pm,x=qm)[name=string(\"qt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> kt = transpose(perm=pm,x=km)[name=string(\"kt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> vt = transpose(perm=pm,x=vm)[name=string(\"vt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> os = const()[name=string(\"os\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = reshape(shape=os,x=qt)[name=string(\"qf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = reshape(shape=os,x=kt)[name=string(\"kf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = reshape(shape=os,x=vt)[name=string(\"vf\")];\n", DIM, SEQ];
// Transpose back: [1,1,SEQ,X] → [1,1,X,SEQ]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> qt = transpose(perm=pm,x=qm)[name=string(\"qt\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> kt = transpose(perm=pm,x=km)[name=string(\"kt\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> vt = transpose(perm=pm,x=vm)[name=string(\"vt\")];\n", KV_DIM, SEQ];
// SDPA: reshape to heads, matmul, mask, softmax, matmul
[m appendFormat:@" tensor<int32, [4]> qsh = const()[name=string(\"qsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q4 = reshape(shape=qsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
// Reshape to [1,X,1,SEQ]
[m appendFormat:@" tensor<int32, [4]> qsh = const()[name=string(\"qsh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> kvsh = const()[name=string(\"kvsh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = reshape(shape=qsh,x=qt)[name=string(\"qf\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = reshape(shape=kvsh,x=kt)[name=string(\"kf\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = reshape(shape=kvsh,x=vt)[name=string(\"vf\")];\n", KV_DIM, SEQ];
// Reshape to heads for attention
// Q: [1,Q_DIM,1,SEQ] → [1,HEADS,HD,SEQ] → transpose → [1,HEADS,SEQ,HD]
[m appendFormat:@" tensor<int32, [4]> qhsh = const()[name=string(\"qhsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q4 = reshape(shape=qhsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=q4)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k4 = reshape(shape=qsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS, SEQ, HD];
// K: [1,KV_DIM,1,SEQ] → [1,KV_HEADS,HD,SEQ] → [1,KV_HEADS,SEQ,HD]
[m appendFormat:@" tensor<int32, [4]> khsh = const()[name=string(\"khsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", KV_HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k4 = reshape(shape=khsh,x=kf)[name=string(\"rk\")];\n", KV_HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", KV_HEADS, SEQ, HD];
// V: same reshape as K
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v4 = reshape(shape=khsh,x=vf)[name=string(\"rv\")];\n", KV_HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", KV_HEADS, SEQ, HD];
// Q @ K^T
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
// RoPE on Q: [1,HEADS,SEQ,HD]
int pairs_q = SEQ * HD / 2;
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> rope_cos = const()[name=string(\"rc\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/rope_cos.bin\"), offset=uint64(64)))];\n", SEQ, HD, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> rope_sin = const()[name=string(\"rs\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/rope_sin.bin\"), offset=uint64(64)))];\n", SEQ, HD, SEQ, HD];
[m appendFormat:@" tensor<int32, [4]> rp_sh = const()[name=string(\"rp_sh\"), val=tensor<int32, [4]>([1,%d,%d,2])];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<int32, [4]> rp_s1 = const()[name=string(\"rp_s1\"), val=tensor<int32, [4]>([1,%d,%d,1])];\n", HEADS, pairs_q];
[m appendString:@" tensor<int32, [4]> rp_b0 = const()[name=string(\"rp_b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendString:@" tensor<int32, [4]> rp_b1 = const()[name=string(\"rp_b1\"), val=tensor<int32, [4]>([0,0,0,1])];\n"];
[m appendString:@" fp16 neg1 = const()[name=string(\"neg1\"), val=fp16(-1)];\n"];
[m appendString:@" int32 rpax = const()[name=string(\"rpax\"), val=int32(3)];\n"];
[m appendString:@" bool rpil = const()[name=string(\"rpil\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<int32, [4]> rp_bk_q = const()[name=string(\"rp_bk_q\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, SEQ, HD];
// rotate_half(q)
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> q_p = reshape(shape=rp_sh,x=q)[name=string(\"q_p\")];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> q_e = slice_by_size(x=q_p,begin=rp_b0,size=rp_s1)[name=string(\"q_e\")];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> q_o = slice_by_size(x=q_p,begin=rp_b1,size=rp_s1)[name=string(\"q_o\")];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> nq = mul(x=q_o,y=neg1)[name=string(\"nq\")];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> qrp = concat(axis=rpax,interleave=rpil,values=(nq,q_e))[name=string(\"qrp\")];\n", HEADS, pairs_q];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q_rot = reshape(shape=rp_bk_q,x=qrp)[name=string(\"q_rot\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qc = mul(x=q,y=rope_cos)[name=string(\"qc\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qrs = mul(x=q_rot,y=rope_sin)[name=string(\"qrs\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q_rope = add(x=qc,y=qrs)[name=string(\"q_rope\")];\n", HEADS, SEQ, HD];
// RoPE on K: [1,KV_HEADS,SEQ,HD]
int pairs_k = SEQ * HD / 2;
[m appendFormat:@" tensor<int32, [4]> rp_sh_k = const()[name=string(\"rp_sh_k\"), val=tensor<int32, [4]>([1,%d,%d,2])];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<int32, [4]> rp_s1_k = const()[name=string(\"rp_s1_k\"), val=tensor<int32, [4]>([1,%d,%d,1])];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<int32, [4]> rp_bk_k = const()[name=string(\"rp_bk_k\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", KV_HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> k_p = reshape(shape=rp_sh_k,x=k)[name=string(\"k_p\")];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> k_e = slice_by_size(x=k_p,begin=rp_b0,size=rp_s1_k)[name=string(\"k_e\")];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> k_o = slice_by_size(x=k_p,begin=rp_b1,size=rp_s1_k)[name=string(\"k_o\")];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> nk = mul(x=k_o,y=neg1)[name=string(\"nk\")];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> krp = concat(axis=rpax,interleave=rpil,values=(nk,k_e))[name=string(\"krp\")];\n", KV_HEADS, pairs_k];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k_rot = reshape(shape=rp_bk_k,x=krp)[name=string(\"k_rot\")];\n", KV_HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kc = mul(x=k,y=rope_cos)[name=string(\"kc\")];\n", KV_HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> krs = mul(x=k_rot,y=rope_sin)[name=string(\"krs\")];\n", KV_HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k_rope = add(x=kc,y=krs)[name=string(\"k_rope\")];\n", KV_HEADS, SEQ, HD];
// GQA: tile K,V from KV_HEADS to HEADS
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
// For GQA_RATIO=2: concat(k_rope, k_rope) along head dim
NSMutableString *k_vals = [NSMutableString string];
NSMutableString *v_vals = [NSMutableString string];
for (int r = 0; r < GQA_RATIO; r++) {
if (r > 0) { [k_vals appendString:@","]; [v_vals appendString:@","]; }
[k_vals appendString:@"k_rope"]; [v_vals appendString:@"v"];
}
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k_tiled = concat(axis=cax,interleave=cid,values=(%@))[name=string(\"ktile\")];\n", HEADS, SEQ, HD, k_vals];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v_tiled = concat(axis=cax,interleave=cid,values=(%@))[name=string(\"vtile\")];\n", HEADS, SEQ, HD, v_vals];
// Q_rope @ K_tiled^T → [1,HEADS,SEQ,SEQ]
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q_rope,y=k_tiled)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ];
// Causal mask (still const — doesn't change)
// Causal mask
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ, SEQ, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS, SEQ, SEQ];
@ -155,87 +193,67 @@ static NSString *gen_sdpa_fwd_dynamic(void) {
[m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS, SEQ, SEQ];
// scores @ V
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> a4 = matmul(transpose_x=bF,transpose_y=bF,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS, SEQ, HD];
// scores @ V_tiled → [1,HEADS,SEQ,HD]
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> a4 = matmul(transpose_x=bF,transpose_y=bF,x=aw,y=v_tiled)[name=string(\"mm2\")];\n", HEADS, SEQ, HD];
// Reshape back to [1,DIM,1,SEQ]
// Reshape attn_out to [1,Q_DIM,1,SEQ]
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> af = reshape(shape=qsh,x=at)[name=string(\"ra\")];\n", Q_DIM, SEQ];
// Wo matmul: af → [1,1,S,D] @ Wo[1,1,D,D] → [1,1,S,D] → [1,D,1,S]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> af2 = reshape(shape=r2,x=af)[name=string(\"af2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> aft = transpose(perm=pm,x=af2)[name=string(\"aft\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> om = matmul(transpose_x=bF,transpose_y=bF,x=aft,y=Wo2)[name=string(\"om\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> ot = transpose(perm=pm,x=om)[name=string(\"ot\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> oo = reshape(shape=os,x=ot)[name=string(\"oo\")];\n", DIM, SEQ];
// Convert RoPE'd Q,K back to flat layout for backward
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qrt = transpose(perm=pm,x=q_rope)[name=string(\"qrt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qrf = reshape(shape=qsh,x=qrt)[name=string(\"qrf\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> krt = transpose(perm=pm,x=k_rope)[name=string(\"krt\")];\n", KV_HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> krf = reshape(shape=kvsh,x=krt)[name=string(\"krf\")];\n", KV_DIM, SEQ];
// Output: concat(o_out, qf, kf, vf, af, xn) — same as original for backward compatibility
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ];
// Cast to fp32
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", 6*DIM, SEQ];
[m appendString:@" } -> (out32);\n}\n"];
// Output: concat(attn_out[Q_DIM], Q_rope[Q_DIM], K_rope[KV_DIM], V[KV_DIM], xnorm[DIM])
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(af,qrf,krf,vf,xn))[name=string(\"cat\")];\n", out_ch, SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// ===== FFN forward (dynamic weights) =====
// RMSNorm on CPU. This kernel: xnorm @ W1 → SiLU, xnorm @ W3 → gate, gate*silu @ W2 → out
// Input: [1, DIM, 1, SEQ + HIDDEN + HIDDEN + DIM] fp32
// sp[0:SEQ] = xnorm [DIM,SEQ]
// sp[SEQ:SEQ+HIDDEN] = W1[DIM,HIDDEN]
// sp[SEQ+HIDDEN:SEQ+2*HIDDEN] = W3[DIM,HIDDEN]
// sp[SEQ+2*HIDDEN:SEQ+2*HIDDEN+DIM]= W2[HIDDEN→DIM] — but W2 is [DIM,HIDDEN], we need HIDDEN input channels
// PROBLEM: W2 has shape [DIM,HIDDEN] = HIDDEN input channels, but our kernel has DIM input channels.
// Solution: separate kernels for W1/W3 (DIM→HIDDEN) and W2 (HIDDEN→DIM)
// OR: do W1,W3 in one kernel, SiLU on CPU/ANE, W2 in another kernel.
// Simpler: 3 separate matmul kernels per FFN direction. But that's too many dispatches.
// Better: one kernel for W1+W3 (same input dim), CPU SiLU, one kernel for W2.
// woFwd: attn_out[Q_DIM,SEQ] @ Wo → o_out[DIM,SEQ]
// Simple dyn_matmul: IC=Q_DIM, OC=DIM
static NSString *gen_wo_fwd_dynamic(void) {
return gen_dyn_matmul_mil(Q_DIM, DIM, SEQ);
}
// FFN part 1: xnorm @ W1, xnorm @ W3 (both DIM→HIDDEN)
// Input: [1, DIM, 1, SEQ + 2*HIDDEN] fp32
// sp[0:SEQ] = xnorm
// sp[SEQ:SEQ+HIDDEN] = W1[DIM,HIDDEN]
// sp[SEQ+HIDDEN:SEQ+2*HIDDEN]= W3[DIM,HIDDEN]
// Output: [1, 2*HIDDEN, 1, SEQ] fp32 = concat(h1, h3)
static NSString *gen_ffn_w13_dynamic(void) {
int sp_in = SEQ + 2*HIDDEN;
// ===== Fused FFN forward: W1,W3 + SiLU + W2 + residual =====
// Same structure as before, just with Qwen3 DIM=1024, HIDDEN=3072
static NSString *gen_ffn_fused_dynamic(void) {
int sp_in = FFN_FUSED_SP;
int out_ch = DIM + 3*HIDDEN;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in];
// Slice xnorm
[m appendString:@" tensor<int32, [4]> bx = const()[name=string(\"bx\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sx = const()[name=string(\"sx\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ];
// Slice W1
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<int32, [4]> s1 = const()[name=string(\"s1\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1 = slice_by_size(x=xh,begin=b1,size=s1)[name=string(\"W1\")];\n", DIM, HIDDEN];
// Slice W3
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3 = slice_by_size(x=xh,begin=b3,size=s1)[name=string(\"W3\")];\n", DIM, HIDDEN];
// Reshape for matmul
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=rd,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
// Slice x2norm, x2, W1, W3, W2_orig
[m appendString:@" tensor<int32, [4]> b_xn = const()[name=string(\"b_xn\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> s_ds = const()[name=string(\"s_ds\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x2norm = slice_by_size(x=x,begin=b_xn,size=s_ds)[name=string(\"x2norm\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b_x2 = const()[name=string(\"b_x2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x2 = slice_by_size(x=x,begin=b_x2,size=s_ds)[name=string(\"x2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b_w1 = const()[name=string(\"b_w1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<int32, [4]> s_wh = const()[name=string(\"s_wh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1 = slice_by_size(x=x,begin=b_w1,size=s_wh)[name=string(\"W1\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<int32, [4]> b_w3 = const()[name=string(\"b_w3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3 = slice_by_size(x=x,begin=b_w3,size=s_wh)[name=string(\"W3\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<int32, [4]> b_w2 = const()[name=string(\"b_w2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+2*HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W2r = slice_by_size(x=x,begin=b_w2,size=s_wh)[name=string(\"W2r\")];\n", DIM, HIDDEN];
// xnorm matmul
[m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=rd,x=x2norm)[name=string(\"xn2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W12 = reshape(shape=rw,x=W1)[name=string(\"W12\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W32 = reshape(shape=rw,x=W3)[name=string(\"W32\")];\n", DIM, HIDDEN];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W12)[name=string(\"h1m\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W32)[name=string(\"h3m\")];\n", SEQ, HIDDEN];
// Transpose back
// Reshape back
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1t = transpose(perm=pm,x=h1m)[name=string(\"h1t\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3t = transpose(perm=pm,x=h3m)[name=string(\"h3t\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<int32, [4]> rh = const()[name=string(\"rh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
@ -247,115 +265,62 @@ static NSString *gen_ffn_w13_dynamic(void) {
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN, SEQ];
// Concat output: (h1, h3, gate)
// gate @ W2: W2 is [DIM, HIDDEN] stored as-is, transpose inside kernel
[m appendFormat:@" tensor<int32, [4]> rg = const()[name=string(\"rg\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> g2 = reshape(shape=rg,x=gate)[name=string(\"g2\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> gt = transpose(perm=pm,x=g2)[name=string(\"gtt\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W22 = reshape(shape=rw,x=W2r)[name=string(\"W22\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W2t = transpose(perm=pm,x=W22)[name=string(\"W2t\")];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> fm = matmul(transpose_x=bF,transpose_y=bF,x=gt,y=W2t)[name=string(\"fm\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> ft = transpose(perm=pm,x=fm)[name=string(\"ft\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> rd2 = const()[name=string(\"rd2\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> ffn_out = reshape(shape=rd2,x=ft)[name=string(\"ffn_out\")];\n", DIM, SEQ];
// Residual: x_next = x2 + alpha * ffn_out
float alpha = 1.0f / sqrtf(2.0f * NLAYERS);
[m appendFormat:@" fp16 res_alpha = const()[name=string(\"res_alpha\"), val=fp16(%g)];\n", alpha];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> ffn_scaled = mul(x=ffn_out,y=res_alpha)[name=string(\"ffn_sc\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x_next = add(x=x2,y=ffn_scaled)[name=string(\"x_next\")];\n", DIM, SEQ];
// Output: concat(x_next, h1, h3, gate)
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(h1,h3,gate))[name=string(\"cat\")];\n", 2*HIDDEN+HIDDEN, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", 3*HIDDEN, SEQ];
[m appendString:@" } -> (out32);\n}\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(x_next,h1,h3,gate))[name=string(\"cat\")];\n", out_ch, SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// FFN part 2: gate @ W2 (HIDDEN→DIM)
// Input: [1, HIDDEN, 1, SEQ + DIM] fp32
// sp[0:SEQ] = gate [HIDDEN,SEQ]
// sp[SEQ:SEQ+DIM] = W2[HIDDEN,DIM]
// Output: [1, DIM, 1, SEQ] fp32
static NSString *gen_ffn_w2_dynamic(void) {
int sp_in = SEQ + DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", HIDDEN, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, sp_in];
// ===== Backward kernels =====
[m appendString:@" tensor<int32, [4]> ba = const()[name=string(\"ba\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sa = const()[name=string(\"sa\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> act = slice_by_size(x=xh,begin=ba,size=sa)[name=string(\"act\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<int32, [4]> bw = const()[name=string(\"bw\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W2 = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"W2\")];\n", HIDDEN, DIM];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<int32, [4]> ra = const()[name=string(\"ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> a2 = reshape(shape=ra,x=act)[name=string(\"a2\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> at = transpose(perm=pm,x=a2)[name=string(\"at\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W22 = reshape(shape=rw,x=W2)[name=string(\"W22\")];\n", HIDDEN, DIM];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> ym = matmul(transpose_x=bF,transpose_y=bF,x=at,y=W22)[name=string(\"ym\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> yt = transpose(perm=pm,x=ym)[name=string(\"yt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> yr = reshape(shape=ro,x=yt)[name=string(\"yr\")];\n", DIM, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=yr)[name=string(\"cout\")];\n", DIM, SEQ];
[m appendString:@" } -> (y);\n}\n"];
return m;
}
// ===== FFN backward (dynamic weights) =====
// Input: [1, DIM+2*HIDDEN, 1, SEQ + HIDDEN + DIM + DIM] fp32
// Actually simpler to split into separate backward kernels like forward.
// FFN backward part 1: dffn @ W2^T → dsilu (HIDDEN), then SiLU derivative
// Input: [1, DIM, 1, SEQ + HIDDEN] fp32
// sp[0:SEQ] = dffn [DIM, SEQ]
// sp[SEQ:SEQ+HIDDEN]= W2^T [DIM, HIDDEN]
// Output: [1, HIDDEN, 1, SEQ] fp32 = dsilu_raw
// ffnBwdW2t: dffn @ W2 → dsilu_raw (IC=DIM, OC=HIDDEN)
static NSString *gen_ffn_bwd_w2t_dynamic(void) {
return gen_dyn_matmul_mil(DIM, HIDDEN, SEQ);
}
// FFN backward part 2: dh1 @ W1^T + dh3 @ W3^T → dx
// We need h1,h3 for SiLU derivative, but those are on CPU.
// Actually the SiLU derivative + gating is element-wise, do on CPU.
// Then: dh1 @ W1^T and dh3 @ W3^T are two separate matmuls (HIDDEN→DIM).
// Combine into one kernel:
// Input: [1, HIDDEN, 1, SEQ + SEQ + DIM + DIM] fp32
// sp[0:SEQ] = dh1 [HIDDEN,SEQ]
// sp[SEQ:2*SEQ] = dh3 [HIDDEN,SEQ]
// sp[2*SEQ:2*SEQ+DIM] = W1^T [HIDDEN,DIM]
// sp[2*SEQ+DIM:2*SEQ+2D] = W3^T [HIDDEN,DIM]
// Output: [1, DIM, 1, SEQ] fp32 = dx1 + dx3
// ffnBwdW13t: dh1 @ W1 + dh3 @ W3 → dx_ffn (IC=HIDDEN, two matmuls added)
static NSString *gen_ffn_bwd_w13t_dynamic(void) {
int sp_in = 2*SEQ + 2*DIM;
int sp_in = FFN_BWD_W13T_SP;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", HIDDEN, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, sp_in];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", HIDDEN, sp_in];
// Slice dh1 [HIDDEN, SEQ]
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sh = const()[name=string(\"sh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh1 = slice_by_size(x=xh,begin=b0,size=sh)[name=string(\"dh1\")];\n", HIDDEN, SEQ];
// Slice dh3
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh1 = slice_by_size(x=x,begin=b0,size=sh)[name=string(\"dh1\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh3 = slice_by_size(x=xh,begin=b1,size=sh)[name=string(\"dh3\")];\n", HIDDEN, SEQ];
// Slice W1^T [HIDDEN, DIM]
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh3 = slice_by_size(x=x,begin=b1,size=sh)[name=string(\"dh3\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1t = slice_by_size(x=xh,begin=b2,size=sw)[name=string(\"W1t\")];\n", HIDDEN, DIM];
// Slice W3^T
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1t = slice_by_size(x=x,begin=b2,size=sw)[name=string(\"W1t\")];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3t = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"W3t\")];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3t = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"W3t\")];\n", HIDDEN, DIM];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
// dh1 matmul: [S,H] @ [H,D] → [S,D]
[m appendFormat:@" tensor<int32, [4]> ra = const()[name=string(\"ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh12 = reshape(shape=ra,x=dh1)[name=string(\"dh12\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh1t = transpose(perm=pm,x=dh12)[name=string(\"dh1t\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh32 = reshape(shape=ra,x=dh3)[name=string(\"dh32\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh3t = transpose(perm=pm,x=dh32)[name=string(\"dh3t\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W1t2 = reshape(shape=rw,x=W1t)[name=string(\"W1t2\")];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W3t2 = reshape(shape=rw,x=W3t)[name=string(\"W3t2\")];\n", HIDDEN, DIM];
@ -363,53 +328,88 @@ static NSString *gen_ffn_bwd_w13t_dynamic(void) {
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dx1m = matmul(transpose_x=bF,transpose_y=bF,x=dh1t,y=W1t2)[name=string(\"dx1m\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dx3m = matmul(transpose_x=bF,transpose_y=bF,x=dh3t,y=W3t2)[name=string(\"dx3m\")];\n", SEQ, DIM];
// Add
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxm = add(x=dx1m,y=dx3m)[name=string(\"dxm\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ];
[m appendString:@" } -> (y);\n}\n"];
[m appendString:@" } -> (dx);\n}\n"];
return m;
}
// ===== SDPA backward part 1 (dynamic Wo^T) =====
// Same as original gen_sdpa_bwd1 but Wo^T comes from input instead of const
// Input: [1, 4*DIM, 1, SEQ + DIM] fp32 — Q,K,V,dx2 in channels, Wo^T in spatial
// Wait — channels must match for all data. Q,K,V are [DIM,SEQ], dx2 is [DIM,SEQ].
// Total input channels = 4*DIM. But Wo^T is [DIM,DIM] = DIM channels of DIM spatial.
// Problem: can't mix 4*DIM channels for data with DIM channels for Wo^T.
// Solution: Wo^T matmul as separate kernel, then SDPA part purely element-wise on ANE.
// Wo^T matmul: dx2 @ Wo^T → da (DIM→DIM)
// wotBwd: dy @ Wo → da (IC=DIM, OC=Q_DIM)
static NSString *gen_wot_dynamic(void) {
return gen_dyn_matmul_mil(DIM, DIM, SEQ);
return gen_dyn_matmul_mil(DIM, Q_DIM, SEQ);
}
// SDPA backward part 1 (no weights, all data): Q,K,V,da → dV,probs,dp
// Same as original but without Wo^T conv (already done)
// Input: [1, 4*DIM, 1, SEQ] fp16
static NSString *gen_sdpa_bwd1_noweight(void) {
float sc = 1.0f/sqrtf((float)HD);
// qBwd: dq @ Wq → dx_q (IC=Q_DIM, OC=DIM)
static NSString *gen_q_bwd_dynamic(void) {
return gen_dyn_matmul_mil(Q_DIM, DIM, SEQ);
}
// kvBwd: dk @ Wk + dv @ Wv → dx_kv (IC=KV_DIM)
// Input: [1, KV_DIM, 1, 2*SEQ+2*DIM] fp16
// Same pattern as ffnBwdW13t but with KV_DIM channels
static NSString *gen_kv_bwd_dynamic(void) {
int sp_in = KV_BWD_SP;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", 4*DIM, SEQ];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", KV_DIM, sp_in];
// Slice Q,K,V,da
[m appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> sh = const()[name=string(\"sh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", KV_DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 3*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> da = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dk = slice_by_size(x=x,begin=b0,size=sh)[name=string(\"dk\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dv = slice_by_size(x=x,begin=b1,size=sh)[name=string(\"dv\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", KV_DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wkt = slice_by_size(x=x,begin=b2,size=sw)[name=string(\"Wkt\")];\n", KV_DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wvt = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"Wvt\")];\n", KV_DIM, DIM];
// Reshape to heads
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<int32, [4]> ra = const()[name=string(\"ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dk2 = reshape(shape=ra,x=dk)[name=string(\"dk2\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dkt = transpose(perm=pm,x=dk2)[name=string(\"dkt\")];\n", SEQ, KV_DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dv2 = reshape(shape=ra,x=dv)[name=string(\"dv2\")];\n", KV_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dvt = transpose(perm=pm,x=dv2)[name=string(\"dvt\")];\n", SEQ, KV_DIM];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", KV_DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wkt2 = reshape(shape=rw,x=Wkt)[name=string(\"Wkt2\")];\n", KV_DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wvt2 = reshape(shape=rw,x=Wvt)[name=string(\"Wvt2\")];\n", KV_DIM, DIM];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxk = matmul(transpose_x=bF,transpose_y=bF,x=dkt,y=Wkt2)[name=string(\"dxk\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxv = matmul(transpose_x=bF,transpose_y=bF,x=dvt,y=Wvt2)[name=string(\"dxv\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxm = add(x=dxk,y=dxv)[name=string(\"dxm\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
[m appendString:@" } -> (dx);\n}\n"];
return m;
}
// SDPA backward part 1: recompute attention + dV, dp
// Uses tiled K,V at HEADS dimension (CPU pre-tiles)
// Input: [1, 2*Q_DIM+2*Q_DIM, 1, SEQ] fp16 = (Q, K_tiled, V_tiled, da)
// Output: [1, Q_DIM+2*SCORE_CH, 1, SEQ] fp16 = (dV_full, probs, dp)
static NSString *gen_sdpa_bwd1_noweight(void) {
float sc = 1.0f/sqrtf((float)HD);
int in_ch = 4*Q_DIM; // Q + K_tiled + V_tiled + da, all at Q_DIM (HEADS*HD)
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", in_ch, SEQ];
// Slice Q,K_tiled,V_tiled,da — all [Q_DIM, SEQ]
[m appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", Q_DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", Q_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*Q_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 3*Q_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> da = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", Q_DIM, SEQ];
// Reshape to heads [1,HEADS,HD,SEQ] → [1,HEADS,SEQ,HD]
[m appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
@ -421,7 +421,7 @@ static NSString *gen_sdpa_bwd1_noweight(void) {
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dr = reshape(shape=rsh,x=da)[name=string(\"rd\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dat = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS, SEQ, HD];
// Forward attention scores (recompute)
// Recompute attention scores
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
@ -436,49 +436,57 @@ static NSString *gen_sdpa_bwd1_noweight(void) {
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=dat)[name=string(\"dv\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp4 = matmul(transpose_x=bF,transpose_y=bT,x=dat,y=v)[name=string(\"dp\")];\n", HEADS, SEQ, SEQ];
// Reshape dV back
// Reshape dV to [Q_DIM, SEQ] (will be reduced to KV_DIM on CPU)
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<int32, [4]> dvs = const()[name=string(\"dvs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> dvs = const()[name=string(\"dvs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", Q_DIM, SEQ];
// Flatten probs and dp for output
// Flatten probs and dp
[m appendFormat:@" tensor<int32, [4]> scs = const()[name=string(\"scs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH, SEQ];
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", DIM+2*SCORE_CH, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", Q_DIM+2*SCORE_CH, SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// SDPA backward part 2: same as original (no weights, pure computation)
// SDPA backward part 2: probs, dp, Q, K_tiled → dQ, dK_full
// Input: [1, 2*SCORE_CH + 2*Q_DIM, 1, SEQ]
// Output: [1, 2*Q_DIM, 1, SEQ] = (dQ, dK_full)
static NSString *gen_sdpa_bwd2(void) {
float sc = 1.0f/sqrtf((float)HD);
int bwd2_in = 2*SCORE_CH + 2*DIM;
int bwd2_in = 2*SCORE_CH + 2*Q_DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", bwd2_in, SEQ];
[m appendFormat:@" tensor<int32, [4]> sz_sc = const()[name=string(\"szsc\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", SCORE_CH];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<int32, [4]> sz_d = const()[name=string(\"szd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> sz_q = const()[name=string(\"szq\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b2,size=sz_q)[name=string(\"s2\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH+Q_DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b3,size=sz_q)[name=string(\"s3\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ssh = const()[name=string(\"ssh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS, SEQ, HD];
// Softmax backward: ds = (dp - sum(dp*probs)) * probs * scale
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS, SEQ, SEQ];
[m appendString:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([-1])];\n"];
[m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
@ -487,96 +495,26 @@ static NSString *gen_sdpa_bwd2(void) {
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS, SEQ, SEQ];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<int32, [4]> fs = const()[name=string(\"fs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> fs = const()[name=string(\"fs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", Q_DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", Q_DIM, SEQ];
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*Q_DIM, SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// QKV backward (dynamic): dq @ Wq^T + dk @ Wk^T + dv @ Wv^T → dx
// Input: [1, DIM, 1, 3*SEQ + 3*DIM] fp32
// sp[0:SEQ] = dq [DIM,SEQ]
// sp[SEQ:2*SEQ] = dk [DIM,SEQ]
// sp[2*SEQ:3*SEQ] = dv [DIM,SEQ]
// sp[3*SEQ:3*SEQ+DIM] = Wq^T [DIM,DIM]
// sp[3*SEQ+DIM:3*SEQ+2D] = Wk^T [DIM,DIM]
// sp[3*SEQ+2D:3*SEQ+3D] = Wv^T [DIM,DIM]
// Output: [1, DIM, 1, SEQ] fp32 = dxq + dxk + dxv
static NSString *gen_qkvb_dynamic(void) {
int sp_in = 3*SEQ + 3*DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in];
// Slice dq, dk, dv
[m appendFormat:@" tensor<int32, [4]> sd = const()[name=string(\"sd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dq = slice_by_size(x=xh,begin=b0,size=sd)[name=string(\"dq\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dk = slice_by_size(x=xh,begin=b1,size=sd)[name=string(\"dk\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dv = slice_by_size(x=xh,begin=b2,size=sd)[name=string(\"dv\")];\n", DIM, SEQ];
// Slice Wq^T, Wk^T, Wv^T
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wqt = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"Wqt\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b4 = const()[name=string(\"b4\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wkt = slice_by_size(x=xh,begin=b4,size=sw)[name=string(\"Wkt\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b5 = const()[name=string(\"b5\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wvt = slice_by_size(x=xh,begin=b5,size=sw)[name=string(\"Wvt\")];\n", DIM, DIM];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
// Reshape and matmul for each
[m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, DIM];
// dq @ Wq^T
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dq2 = reshape(shape=rd,x=dq)[name=string(\"dq2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dqt = transpose(perm=pm,x=dq2)[name=string(\"dqt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wqt2 = reshape(shape=rw,x=Wqt)[name=string(\"Wqt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxq = matmul(transpose_x=bF,transpose_y=bF,x=dqt,y=Wqt2)[name=string(\"dxq\")];\n", SEQ, DIM];
// dk @ Wk^T
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dk2 = reshape(shape=rd,x=dk)[name=string(\"dk2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dkt = transpose(perm=pm,x=dk2)[name=string(\"dkt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wkt2 = reshape(shape=rw,x=Wkt)[name=string(\"Wkt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxk = matmul(transpose_x=bF,transpose_y=bF,x=dkt,y=Wkt2)[name=string(\"dxk\")];\n", SEQ, DIM];
// dv @ Wv^T
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dv2 = reshape(shape=rd,x=dv)[name=string(\"dv2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dvt = transpose(perm=pm,x=dv2)[name=string(\"dvt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wvt2 = reshape(shape=rw,x=Wvt)[name=string(\"Wvt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxv = matmul(transpose_x=bF,transpose_y=bF,x=dvt,y=Wvt2)[name=string(\"dxv\")];\n", SEQ, DIM];
// Sum: dxq + dxk + dxv
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxall = add(x=dxqk,y=dxv)[name=string(\"aall\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxall)[name=string(\"dxt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ];
[m appendString:@" } -> (y);\n}\n"];
return m;
}
// Causal mask blob (used by sdpa_fwd and sdpa_bwd1)
// Causal mask blob
static NSData *g_mask_blob = nil;
static NSData *get_mask_blob(void) {
if (!g_mask_blob) {
@ -588,3 +526,39 @@ static NSData *get_mask_blob(void) {
}
return g_mask_blob;
}
// RoPE cos/sin blobs [1, 1, SEQ, HD]
static NSData *g_rope_cos_blob = nil;
static NSData *g_rope_sin_blob = nil;
static NSData *get_rope_cos_blob(void) {
if (!g_rope_cos_blob) {
_Float16 *buf = (_Float16*)calloc(SEQ * HD, sizeof(_Float16));
for (int p = 0; p < SEQ; p++)
for (int i = 0; i < HD/2; i++) {
float theta = p / powf(10000.0f, 2.0f * i / (float)HD);
_Float16 cv = (_Float16)cosf(theta);
buf[p * HD + 2*i] = cv;
buf[p * HD + 2*i + 1] = cv;
}
g_rope_cos_blob = build_blob_fp16(buf, SEQ * HD);
free(buf);
}
return g_rope_cos_blob;
}
static NSData *get_rope_sin_blob(void) {
if (!g_rope_sin_blob) {
_Float16 *buf = (_Float16*)calloc(SEQ * HD, sizeof(_Float16));
for (int p = 0; p < SEQ; p++)
for (int i = 0; i < HD/2; i++) {
float theta = p / powf(10000.0f, 2.0f * i / (float)HD);
_Float16 sv = (_Float16)sinf(theta);
buf[p * HD + 2*i] = sv;
buf[p * HD + 2*i + 1] = sv;
}
g_rope_sin_blob = build_blob_fp16(buf, SEQ * HD);
free(buf);
}
return g_rope_sin_blob;
}

View File

@ -0,0 +1,19 @@
// qwen3_06b.h — Qwen3-0.6B (28 layers, GQA 16q/8kv, head_dim=128)
#pragma once
#define MODEL_NAME "Qwen3-0.6B"
#define DIM 1024
#define HIDDEN 3072
#define HEADS 16
#define KV_HEADS 8
#define HD 128 // explicit head_dim (NOT DIM/HEADS)
#define GQA_RATIO (HEADS / KV_HEADS) // = 2
#define Q_DIM (HEADS * HD) // = 2048
#define KV_DIM (KV_HEADS * HD) // = 1024 (= DIM for this model)
#define SEQ 256
#define NLAYERS 28
#define VOCAB 151936
#define CKPT_PATH "ane_qwen3_06b_dyn_ckpt.bin"
#define DEFAULT_DATA_PATH "../tinystories_data00.bin"

View File

@ -0,0 +1,19 @@
// stories110m.h — Stories110M (Llama2-style, 12 layers, MHA)
#pragma once
#define MODEL_NAME "Stories110M"
#define DIM 768
#define HIDDEN 2048
#define HEADS 12
#define KV_HEADS 12
#define HD (DIM/HEADS) // = 64
#define GQA_RATIO 1 // MHA: no GQA
#define Q_DIM (HEADS * HD) // = 768 = DIM
#define KV_DIM (KV_HEADS * HD) // = 768 = DIM
#define SEQ 256
#define NLAYERS 12
#define VOCAB 32000
#define CKPT_PATH "ane_stories110M_dyn_ckpt.bin"
#define DEFAULT_DATA_PATH "../tinystories_data00.bin"

File diff suppressed because it is too large Load Diff