Add dynamic weight training pipeline — 110ms/step without recompilation

Dynamic weight pipeline that eliminates the ~3.7s recompile-every-10-steps
bottleneck. Weights are passed via IOSurface spatial dimension instead of
baked as constants, so kernels compile once at startup (345ms) and run
indefinitely without exec() restart.

Key components:
- training_dynamic/ — full pipeline (config, IO, MIL generators, train loop)
  - 9 dynamic kernels shared across all 12 layers
  - Vocab compaction 32K→9.2K for faster classifier
  - Vectorized cross-entropy with vDSP/NEON
  - Adam optimizer with gradient clipping + cosine LR schedule
  - Checkpoint save/resume

- test_dynamic_matmul.m — validates dynamic weight matmul vs cblas
- test_weight_patch.m — tests weight update via IOSurface

- dashboard.py — updated with --dynamic flag for v2 pipeline support,
  improved step regex parsing, --scratch/--lr/--accum CLI args

Performance: 110ms/step steady-state (no recompile overhead)
  ane_fwd=21 ane_bwd=28 io_fwd=12 io_bwd=15 silu=10 cls=13 rms=5 ms
This commit is contained in:
maderix 2026-03-02 23:49:55 -08:00
parent c33077430e
commit cb474e1537
9 changed files with 2749 additions and 5 deletions

View File

@ -279,7 +279,7 @@ RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+
RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)')
RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing')
RE_ACCUM = re.compile(r'Accum (\d+).*LR=([\d.e+-]+)')
RE_STEP = re.compile(r'step\s+(\d+)\s+loss=([\d.]+)')
RE_STEP = re.compile(r'step\s+(\d+)\s+loss=([\d.]+)(?:\s+lr=([\d.e+-]+))?(?:\s+([\d.]+)ms/step)?')
RE_BATCH = re.compile(r'\[batch (\d+): compile=([\d.]+)ms train=([\d.]+)ms \(([\d.]+)ms/step\) compiles=(\d+)\]')
RE_TIMING = re.compile(r'ane=([\d.]+) io=([\d.]+) cls=([\d.]+) elem=([\d.]+) rms=([\d.]+) cblas_wait=([\d.]+)')
RE_RESTART = re.compile(r'\[exec\(\) restart step (\d+)')
@ -323,6 +323,10 @@ def parse_line(line):
m = RE_STEP.search(line)
if m:
S.step, S.loss = int(m[1]), float(m[2])
if m[3]:
S.training['lr'] = m[3]
if m[4]:
S.ms_per_step = float(m[4])
S.loss_history.append((S.step, S.loss))
S.best_loss = min(S.best_loss, S.loss)
return
@ -659,10 +663,19 @@ def set_nonblock(fd):
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
def spawn_training(resume=False, steps=10000):
def spawn_training(resume=False, steps=10000, dynamic=False, scratch=False, lr=None, accum=None):
if dynamic:
cmd = 'cd training_dynamic && make 2>&1 && ./train'
else:
cmd = 'make train_large 2>&1 && ./train_large'
if resume:
cmd += ' --resume'
if scratch and dynamic:
cmd += ' --scratch'
if lr is not None:
cmd += f' --lr {lr}'
if accum is not None:
cmd += f' --accum {accum}'
cmd += f' --steps {steps}'
proc = subprocess.Popen(
['bash', '-c', cmd],
@ -684,6 +697,10 @@ def spawn_powermetrics():
def main():
parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)')
parser.add_argument('--resume', action='store_true', help='Resume from checkpoint')
parser.add_argument('--dynamic', action='store_true', help='Use v2 dynamic weight pipeline (training_dynamic/)')
parser.add_argument('--scratch', action='store_true', help='Train from scratch (random init)')
parser.add_argument('--lr', type=float, default=None, help='Learning rate')
parser.add_argument('--accum', type=int, default=None, help='Gradient accumulation steps')
parser.add_argument('--infinite', action='store_true', help='Train indefinitely')
parser.add_argument('--no-powermetrics', action='store_true')
parser.add_argument('--no-generate', action='store_true', help='Disable text generation')
@ -697,7 +714,8 @@ def main():
term = Terminal()
procs = []
train_proc = spawn_training(resume=args.resume, steps=args.steps)
train_proc = spawn_training(resume=args.resume, steps=args.steps, dynamic=args.dynamic,
scratch=args.scratch, lr=args.lr, accum=args.accum)
S.train_pid = train_proc.pid
procs.append(train_proc)
@ -837,7 +855,8 @@ def main():
if train_proc:
train_proc.terminate()
train_proc.wait()
train_proc = spawn_training(resume=True, steps=args.steps)
train_proc = spawn_training(resume=True, steps=args.steps, dynamic=args.dynamic,
lr=args.lr, accum=args.accum)
S.train_pid = train_proc.pid
procs = [p for p in procs if p.poll() is None]
procs.append(train_proc)

View File

@ -0,0 +1,333 @@
// test_dynamic_matmul.m — Benchmark dynamic matmul on ANE (no recompile)
// Layout: input [1, D, 1, S+D] — activations in sp[0:S], weight rows in sp[S:S+D]
// MIL: slice → reshape → matmul → reshape → output
#import <Foundation/Foundation.h>
#import <objc/runtime.h>
#import <objc/message.h>
#import <dlfcn.h>
#import <IOSurface/IOSurface.h>
#import <mach/mach_time.h>
#include <arm_neon.h>
#include <Accelerate/Accelerate.h>
#include "stories_io.h"
// Build the MIL program text for y = x @ W, where activations and weights are both
// read from the single input tensor, so the weights can change without recompiling.
//   Input:  [1, IC, 1, SEQ+OC] fp32
//             sp[0:SEQ]      = activations x[IC, SEQ]
//             sp[SEQ:SEQ+OC] = weight W[IC, OC] (channel d holds row W[d, :])
//   Output: [1, OC, 1, SEQ] fp32
// Returns the complete program source as a string.
static NSString *gen_dynamic_matmul_mil(int ic, int oc, int seq) {
    // Width of the packed spatial axis: SEQ activation columns followed by OC weight columns.
    const int packed_sp = seq + oc;
    NSMutableString *mil = [NSMutableString string];

    // Program header and entry point signature.
    [mil appendString:@"program(1.3)\n"
        "[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, "
        "{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, "
        "{\"coremltools-version\", \"9.0\"}})]\n{\n"];
    [mil appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", ic, packed_sp];

    // Cast the whole packed input to fp16.
    [mil appendString:@" string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"];
    [mil appendFormat:@" tensor<fp16, [1, %d, 1, %d]> xh = cast(dtype = to16, x = x)[name = string(\"cin\")];\n",
        ic, packed_sp];

    // Slice out the activations: [1, IC, 1, SEQ] starting at spatial offset 0.
    [mil appendFormat:
        @" tensor<int32, [4]> ba = const()[name = string(\"ba\"), val = tensor<int32, [4]>([0,0,0,0])];\n"
        " tensor<int32, [4]> sa = const()[name = string(\"sa\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n"
        " tensor<fp16, [1,%d,1,%d]> act = slice_by_size(x=xh,begin=ba,size=sa)[name=string(\"act\")];\n",
        ic, seq,
        ic, seq];

    // Slice out the weights: [1, IC, 1, OC] starting at spatial offset SEQ.
    [mil appendFormat:
        @" tensor<int32, [4]> bw = const()[name = string(\"bw\"), val = tensor<int32, [4]>([0,0,0,%d])];\n"
        " tensor<int32, [4]> sw = const()[name = string(\"sw\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n"
        " tensor<fp16, [1,%d,1,%d]> wt = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"wt\")];\n",
        seq,
        ic, oc,
        ic, oc];

    // Activations: [1,IC,1,SEQ] -> reshape [1,1,IC,SEQ] -> transpose [1,1,SEQ,IC].
    [mil appendFormat:
        @" tensor<int32, [4]> ra = const()[name = string(\"ra\"), val = tensor<int32, [4]>([1,1,%d,%d])];\n"
        " tensor<fp16, [1,1,%d,%d]> a2 = reshape(shape=ra,x=act)[name=string(\"a2\")];\n"
        " tensor<int32, [4]> pm = const()[name = string(\"pm\"), val = tensor<int32, [4]>([0,1,3,2])];\n"
        " tensor<fp16, [1,1,%d,%d]> a3 = transpose(perm=pm,x=a2)[name=string(\"a3\")];\n",
        ic, seq,
        ic, seq,
        seq, ic];

    // Weights: [1,IC,1,OC] -> reshape [1,1,IC,OC].
    [mil appendFormat:
        @" tensor<int32, [4]> rw = const()[name = string(\"rw\"), val = tensor<int32, [4]>([1,1,%d,%d])];\n"
        " tensor<fp16, [1,1,%d,%d]> W = reshape(shape=rw,x=wt)[name=string(\"W\")];\n",
        ic, oc,
        ic, oc];

    // matmul: [1,1,SEQ,IC] @ [1,1,IC,OC] -> [1,1,SEQ,OC].
    [mil appendFormat:
        @" bool bF = const()[name = string(\"bF\"), val = bool(false)];\n"
        " tensor<fp16, [1,1,%d,%d]> yh = matmul(transpose_x=bF,transpose_y=bF,x=a3,y=W)[name=string(\"mm\")];\n",
        seq, oc];

    // Back to channel-first: transpose to [1,1,OC,SEQ], then reshape to [1,OC,1,SEQ].
    [mil appendFormat:
        @" tensor<fp16, [1,1,%d,%d]> yt = transpose(perm=pm,x=yh)[name=string(\"yt\")];\n"
        " tensor<int32, [4]> ro = const()[name = string(\"ro\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n"
        " tensor<fp16, [1,%d,1,%d]> yr = reshape(shape=ro,x=yt)[name=string(\"yr\")];\n",
        oc, seq,
        oc, seq,
        oc, seq];

    // Cast the result back to fp32 and close the function and program.
    [mil appendFormat:
        @" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"
        " tensor<fp32, [1,%d,1,%d]> y = cast(dtype = to32, x = yr)[name = string(\"cout\")];\n"
        " } -> (y);\n}\n",
        oc, seq];

    return mil;
}
// Tiled version: splits OC into tiles, each tile is a separate kernel.
// For W[IC, OC], tile along OC: tile t handles columns W[:, t*T:(t+1)*T]
// (the last tile is narrower when OC is not a multiple of T).
// Input per tile:  [1, IC, 1, SEQ+T]
// Output per tile: [1, T, 1, SEQ]
typedef struct {
// Array of n_tiles compiled kernels, one per output-channel tile.
Kern **tiles;
// n_tiles = ceil(oc / tile_oc); tile_oc = nominal tile width T;
// ic / oc = full matrix dimensions; seq = number of activation columns.
int n_tiles, tile_oc, ic, oc, seq;
} TiledMatmul;
// Compile one dynamic-matmul kernel per output-channel tile of a W[ic, oc] matmul.
//   ic, oc   full weight matrix dimensions
//   tile_oc  nominal tile width along OC (the last tile may be narrower)
//   seq      number of activation columns per channel
// Returns a heap-allocated TiledMatmul. The caller frees each tile with free_kern(),
// then tm->tiles and tm. Returns NULL on invalid dimensions, allocation failure, or
// when any tile fails to compile; in that case everything built so far is released
// so a failed attempt does not leak kernels or surfaces.
static TiledMatmul *compile_tiled_matmul(int ic, int oc, int tile_oc, int seq) {
    // tile_oc <= 0 would divide by zero below; the other dimensions size the surfaces.
    if (ic <= 0 || oc <= 0 || tile_oc <= 0 || seq <= 0) {
        printf("compile_tiled_matmul: invalid dims ic=%d oc=%d tile_oc=%d seq=%d\n", ic, oc, tile_oc, seq);
        return NULL;
    }
    TiledMatmul *tm = (TiledMatmul*)calloc(1, sizeof(TiledMatmul));
    if (!tm) return NULL;
    tm->ic = ic; tm->oc = oc; tm->seq = seq; tm->tile_oc = tile_oc;
    tm->n_tiles = (oc + tile_oc - 1) / tile_oc;
    tm->tiles = (Kern**)calloc(tm->n_tiles, sizeof(Kern*));
    if (!tm->tiles) { free(tm); return NULL; }
    for (int t = 0; t < tm->n_tiles; t++) {
        // The final tile only covers the remainder when oc is not a multiple of tile_oc.
        int this_oc = (t == tm->n_tiles-1 && oc % tile_oc) ? (oc % tile_oc) : tile_oc;
        NSString *mil = gen_dynamic_matmul_mil(ic, this_oc, seq);
        int in_bytes = ic * (seq + this_oc) * 4;   // fp32 [1, ic, 1, seq+this_oc]
        int out_bytes = this_oc * seq * 4;         // fp32 [1, this_oc, 1, seq]
        tm->tiles[t] = compile_kern_mil_w(mil, @{}, in_bytes, out_bytes);
        if (!tm->tiles[t]) {
            printf("Tile %d compile FAIL\n", t);
            // Unwind: release the tiles that did compile before reporting failure.
            for (int u = 0; u < t; u++) free_kern(tm->tiles[u]);
            free(tm->tiles);
            free(tm);
            return NULL;
        }
    }
    return tm;
}
// Fill one tile's input IOSurface with activations plus that tile's weight columns.
// Each channel row of the surface holds SEQ activation values followed by the
// channel's slice of the weight matrix.
//   act: indexed act[d*SEQ + t] (one contiguous run of SEQ floats per input channel d)
//   W:   full weight matrix indexed W[d*OC + c]; only this tile's columns are copied
static void write_tile_input(TiledMatmul *tm, int tile_idx, const float *act, const float *W) {
    Kern *kern = tm->tiles[tile_idx];
    const int n_ch = tm->ic, n_seq = tm->seq, tile_w = tm->tile_oc;
    const int first_col = tile_idx * tile_w;
    // The final tile is narrower when OC is not a multiple of the tile width.
    int cols = tile_w;
    if (tile_idx == tm->n_tiles-1) {
        int tail = tm->oc % tile_w;
        if (tail) cols = tail;
    }
    const int row_len = n_seq + cols;
    IOSurfaceLock(kern->ioIn, 0, NULL);
    float *surf = (float*)IOSurfaceGetBaseAddress(kern->ioIn);
    for (int ch = 0; ch < n_ch; ch++) {
        float *row = surf + ch*row_len;
        // Activations occupy row[0 .. SEQ).
        memcpy(row, act + ch*n_seq, n_seq*sizeof(float));
        // Weight slice W[ch, first_col .. first_col+cols) occupies row[SEQ .. SEQ+cols).
        const float *w_src = W + ch*tm->oc + first_col;
        float *w_dst = row + n_seq;
        for (int c = 0; c < cols; c++)
            w_dst[c] = w_src[c];
    }
    IOSurfaceUnlock(kern->ioIn, 0, NULL);
}
// Copy one tile's output surface into the matching rows of the full output buffer.
// The tile surface is read as [this_oc, SEQ] floats; row c lands at
// out[(tile_idx*tile_oc + c)*SEQ].
static void read_tile_output(TiledMatmul *tm, int tile_idx, float *out) {
    Kern *kern = tm->tiles[tile_idx];
    const int n_seq = tm->seq, tile_w = tm->tile_oc;
    const int first_row = tile_idx * tile_w;
    // The final tile is narrower when OC is not a multiple of the tile width.
    int rows = tile_w;
    if (tile_idx == tm->n_tiles-1) {
        int tail = tm->oc % tile_w;
        if (tail) rows = tail;
    }
    IOSurfaceLock(kern->ioOut, kIOSurfaceLockReadOnly, NULL);
    float *src = (float*)IOSurfaceGetBaseAddress(kern->ioOut);
    for (int r = 0; r < rows; r++) {
        float *dst_row = out + (first_row+r)*n_seq;
        const float *src_row = src + r*n_seq;
        memcpy(dst_row, src_row, n_seq*sizeof(float));
    }
    IOSurfaceUnlock(kern->ioOut, kIOSurfaceLockReadOnly, NULL);
}
// Entry point for the dynamic-matmul experiments:
//   Test 1  64x64 kernel: identity and 2x-identity weights written into the input surface.
//   Test 2  768x768 single kernel: compile time and eval throughput, with and without
//           rewriting the weight columns before every eval.
//   Test 3  768x768 matmul split into OC tiles of varying width, timed including IO.
//   Check   full-width kernel output compared against a cblas_sgemm reference.
// Returns 1 if the Test 1 kernel fails to compile, otherwise 0.
int main(int argc, char **argv) {
    @autoreleasepool {
        mach_timebase_info(&g_tb);
        ane_init();

        // === Test 1: Single 64×64 dynamic matmul (correctness) ===
        printf("=== Test 1: 64×64 dynamic matmul correctness ===\n");
        {
            int D = 64, S = 64;
            NSString *mil = gen_dynamic_matmul_mil(D, D, S);
            int in_b = D * (S+D) * 4, out_b = D * S * 4;
            Kern *k = compile_kern_mil_w(mil, @{}, in_b, out_b);
            if (!k) { printf("FAIL\n"); return 1; }

            // Identity test: activations are a ramp, weights are the identity matrix,
            // so the output should reproduce the activations.
            IOSurfaceLock(k->ioIn, 0, NULL);
            float *inp = (float*)IOSurfaceGetBaseAddress(k->ioIn);
            memset(inp, 0, in_b);
            for (int d = 0; d < D; d++)
                for (int s = 0; s < S; s++)
                    inp[d*(S+D) + s] = (float)(d*S + s) * 0.001f;
            for (int d = 0; d < D; d++)
                for (int c = 0; c < D; c++)
                    inp[d*(S+D) + S + c] = (d == c) ? 1.0f : 0.0f;
            IOSurfaceUnlock(k->ioIn, 0, NULL);
            ane_eval(k);
            IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
            float *out = (float*)IOSurfaceGetBaseAddress(k->ioOut);
            float me = 0;
            for (int d = 0; d < D; d++)
                for (int s = 0; s < S; s++) {
                    float e = fabsf(out[d*S+s] - inp[d*(S+D)+s]);
                    if (e > me) me = e;
                }
            IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
            printf("Identity: max_err=%.6f %s\n", me, me < 0.01 ? "PASS" : "FAIL");

            // 2× test: only the weight columns are rewritten; the kernel is not recompiled.
            IOSurfaceLock(k->ioIn, 0, NULL);
            for (int d = 0; d < D; d++)
                for (int c = 0; c < D; c++)
                    inp[d*(S+D) + S + c] = (d == c) ? 2.0f : 0.0f;
            IOSurfaceUnlock(k->ioIn, 0, NULL);
            ane_eval(k);
            IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
            float sr = 0; int cnt = 0;
            for (int i = 0; i < D*S; i++) {
                // Activation matching output element i (row i/S, column i%S of the packed input).
                float a = inp[(i/S)*(S+D) + i%S];
                if (fabsf(a) > 0.001f) { sr += out[i]/a; cnt++; }
            }
            IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
            // Compute the mean ratio once, guarding cnt == 0, and use it for both print and check.
            float ratio = cnt ? sr/cnt : 0.0f;
            printf("2× W: ratio=%.3f %s\n\n", ratio, fabsf(ratio-2.0f)<0.1?"PASS":"FAIL");
            free_kern(k);
        }

        // === Test 2: 768×768 single kernel (if it compiles) ===
        printf("=== Test 2: 768×768 single dynamic matmul ===\n");
        {
            int D = 768, S = 256;
            int sp_total = S + D;          // 256 + 768 = 1024
            int in_b = D * sp_total * 4;   // 768 * 1024 * 4 = 3.1MB
            int out_b = D * S * 4;         // 768 * 256 * 4 = 786KB
            printf("IOSurface: in=%.1fMB out=%.1fKB\n", in_b/1e6, out_b/1e3);
            NSString *mil = gen_dynamic_matmul_mil(D, D, S);
            uint64_t t0 = mach_absolute_time();
            Kern *k = compile_kern_mil_w(mil, @{}, in_b, out_b);
            double compile_ms = tb_ms(mach_absolute_time() - t0);
            if (!k) { printf("768×768 compile FAIL\n"); }
            else {
                printf("Compile: %.1fms\n", compile_ms);
                // Random activations and weights
                float *act = (float*)calloc(D*S, sizeof(float));
                float *W = (float*)calloc(D*D, sizeof(float));
                for (int i = 0; i < D*S; i++) act[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.1f;
                for (int i = 0; i < D*D; i++) W[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.01f;
                // Write to IOSurface: each channel row is [S activations | D weights]
                IOSurfaceLock(k->ioIn, 0, NULL);
                float *inp = (float*)IOSurfaceGetBaseAddress(k->ioIn);
                for (int d = 0; d < D; d++) {
                    memcpy(inp + d*(S+D), act + d*S, S*4);
                    memcpy(inp + d*(S+D) + S, W + d*D, D*4);
                }
                IOSurfaceUnlock(k->ioIn, 0, NULL);
                // Warmup
                for (int i = 0; i < 3; i++) ane_eval(k);
                // Benchmark
                int iters = 50;
                t0 = mach_absolute_time();
                for (int i = 0; i < iters; i++) ane_eval(k);
                double total_ms = tb_ms(mach_absolute_time() - t0);
                double per_eval = total_ms / iters;
                double flops = 2.0 * D * D * S; // matmul FLOPs
                double gflops = flops / (per_eval * 1e6);
                printf("768×768×256 matmul: %.3fms/eval %.1f GFLOP/s\n", per_eval, gflops);
                // Benchmark with IO write (simulating weight update)
                t0 = mach_absolute_time();
                for (int i = 0; i < iters; i++) {
                    IOSurfaceLock(k->ioIn, 0, NULL);
                    float *p = (float*)IOSurfaceGetBaseAddress(k->ioIn);
                    for (int d = 0; d < D; d++)
                        memcpy(p + d*(S+D) + S, W + d*D, D*4);
                    IOSurfaceUnlock(k->ioIn, 0, NULL);
                    ane_eval(k);
                }
                total_ms = tb_ms(mach_absolute_time() - t0);
                per_eval = total_ms / iters;
                gflops = flops / (per_eval * 1e6);
                printf("With weight IO: %.3fms/eval %.1f GFLOP/s\n", per_eval, gflops);
                free(act); free(W); free_kern(k);
            }
        }

        // === Test 3: Tiled matmul benchmark ===
        int tile_sizes[] = {64, 128, 256, 384, 768};
        int n_tiles_test = sizeof(tile_sizes)/sizeof(tile_sizes[0]);
        printf("\n=== Test 3: Tiled 768×768 matmul (varying tile_oc) ===\n");
        printf("%-10s %-8s %-10s %-12s %-10s\n", "tile_oc", "tiles", "compile", "eval/ms", "GFLOP/s");
        {
            int D = 768, S = 256;
            float *act = (float*)calloc(D*S, sizeof(float));
            float *W = (float*)calloc(D*D, sizeof(float));
            float *out_full = (float*)calloc(D*S, sizeof(float));
            for (int i = 0; i < D*S; i++) act[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.1f;
            for (int i = 0; i < D*D; i++) W[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.01f;
            for (int ti = 0; ti < n_tiles_test; ti++) {
                int T = tile_sizes[ti];
                if (T > D) continue;
                uint64_t t0 = mach_absolute_time();
                TiledMatmul *tm = compile_tiled_matmul(D, D, T, S);
                double compile_ms = tb_ms(mach_absolute_time() - t0);
                if (!tm) { printf("%-10d FAIL\n", T); continue; }
                // Warmup
                for (int w = 0; w < 2; w++) {
                    for (int t = 0; t < tm->n_tiles; t++) {
                        write_tile_input(tm, t, act, W);
                        ane_eval(tm->tiles[t]);
                    }
                }
                // Benchmark (with IO)
                int iters = 20;
                t0 = mach_absolute_time();
                for (int i = 0; i < iters; i++) {
                    for (int t = 0; t < tm->n_tiles; t++) {
                        write_tile_input(tm, t, act, W);
                        ane_eval(tm->tiles[t]);
                        read_tile_output(tm, t, out_full);
                    }
                }
                double total_ms = tb_ms(mach_absolute_time() - t0);
                double per_matmul = total_ms / iters;
                double flops = 2.0 * D * D * S;
                double gflops = flops / (per_matmul * 1e6);
                printf("%-10d %-8d %-10.0fms %-12.3fms %-10.1f\n",
                       T, tm->n_tiles, compile_ms, per_matmul, gflops);
                for (int t = 0; t < tm->n_tiles; t++) free_kern(tm->tiles[t]);
                free(tm->tiles); free(tm);
            }

            // === Correctness check: compare with cblas ===
            printf("\n=== Correctness: dynamic matmul vs cblas_sgemm ===\n");
            {
                int T = 768; // full, no tiling
                TiledMatmul *tm = compile_tiled_matmul(D, D, T, S);
                if (tm) {
                    write_tile_input(tm, 0, act, W);
                    ane_eval(tm->tiles[0]);
                    read_tile_output(tm, 0, out_full);
                    // Reference: out[oc*S + s] = sum_d W[d*D + oc] * act[d*S + s].
                    // All three buffers are row-major here: W is [D x D] indexed [d][oc],
                    // act is [D x S] indexed [d][s], and the output is [D x S] indexed [oc][s].
                    // That is C = W^T * B in row-major terms:
                    //   A = W   stored K x M = D x D (transposed by the call), lda = D
                    //   B = act stored K x N = D x S,                          ldb = S
                    //   C = ref        M x N = D x S,                          ldc = S
                    float *ref = (float*)calloc(D*S, sizeof(float));
                    cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                                D, S, D, 1.0f, W, D, act, S, 0.0f, ref, S);
                    float me = 0, ref_peak = 0;
                    for (int i = 0; i < D*S; i++) {
                        float e = fabsf(out_full[i] - ref[i]);
                        // Written as !(e <= me) so a NaN output sticks in `me` and fails below
                        // instead of being skipped by an `e > me` comparison.
                        if (!(e <= me)) me = e;
                        float r = fabsf(ref[i]);
                        if (r > ref_peak) ref_peak = r;
                    }
                    // With inputs of magnitude ~0.05 and ~0.005 the outputs are only ~1e-2, so an
                    // absolute tolerance of 1.0 could never fail. Judge the error relative to the
                    // peak reference magnitude instead; 5% is a heuristic allowance for the fp16
                    // compute path and should be tuned on real hardware.
                    const float rel_tol = 0.05f;
                    int pass = (me <= rel_tol * ref_peak);
                    printf("vs cblas: max_err=%.6f %s\n", me, pass ? "PASS" : "FAIL");
                    free(ref);
                    for (int t = 0; t < tm->n_tiles; t++) free_kern(tm->tiles[t]);
                    free(tm->tiles); free(tm);
                }
            }
            free(act); free(W); free(out_full);
        }

        // === Summary for training ===
        printf("\n=== Summary ===\n");
        printf("Stories110M: 12 layers × 10 matmuls/layer = 120 matmuls/step\n");
        printf("Sizes: Wq/Wk/Wv/Wo [768,768], W1/W3 [2048,768], W2 [768,2048]\n");
        printf("With dynamic weights: compile once, update IOSurface every step\n");
        printf("\nDone.\n");
    }
    return 0;
}

View File

@ -0,0 +1,450 @@
// test_weight_patch.m — Test whether ANE weights can be patched after compile
#import <Foundation/Foundation.h>
#import <objc/runtime.h>
#import <objc/message.h>
#import <dlfcn.h>
#import <IOSurface/IOSurface.h>
#import <mach/mach.h>
#import <mach/mach_time.h>
#import <mach/vm_map.h>
#include <arm_neon.h>
#include <Accelerate/Accelerate.h>
#include "stories_io.h"
// Build the MIL program text for a 1x1 convolution with a constant weight blob:
//   fp32 in -> cast fp16 -> conv -> cast fp32 out (matches inmem_peak.m pattern)
//   ic / oc  input / output channel counts
//   sp       spatial width; input is [1, ic, 1, sp], output is [1, oc, 1, sp]
// The weight tensor W [oc, ic, 1, 1] is referenced from
// "@model_path/weights/w.bin" at offset 64.
static NSString *gen_conv_mil(int ic, int oc, int sp) {
    NSMutableString *mil = [NSMutableString string];

    // Program header and entry point signature.
    [mil appendString:@"program(1.3)\n"
        "[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, "
        "{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, "
        "{\"coremltools-version\", \"9.0\"}})]\n{\n"];
    [mil appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", ic, sp];

    // Convolution hyper-parameters (valid padding, unit stride/dilation, one group)
    // and the fp16 dtype string.
    [mil appendString:
        @" string pt = const()[name = string(\"pt\"), val = string(\"valid\")];\n"
        " tensor<int32, [2]> st = const()[name = string(\"st\"), val = tensor<int32, [2]>([1, 1])];\n"
        " tensor<int32, [4]> pd = const()[name = string(\"pd\"), val = tensor<int32, [4]>([0, 0, 0, 0])];\n"
        " tensor<int32, [2]> dl = const()[name = string(\"dl\"), val = tensor<int32, [2]>([1, 1])];\n"
        " int32 gr = const()[name = string(\"gr\"), val = int32(1)];\n"
        " string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"];

    // Input cast, weight constant, and the conv itself.
    [mil appendFormat:
        @" tensor<fp16, [1, %d, 1, %d]> xh = cast(dtype = to16, x = x)[name = string(\"cast_in\")];\n"
        " tensor<fp16, [%d, %d, 1, 1]> W = const()[name = string(\"W\"), "
        "val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = string(\"@model_path/weights/w.bin\"), offset = uint64(64)))];\n"
        " tensor<fp16, [1, %d, 1, %d]> yh = conv(dilations = dl, groups = gr, pad = pd, pad_type = pt, strides = st, weight = W, x = xh)"
        "[name = string(\"conv\")];\n",
        ic, sp,
        oc, ic,
        oc, ic,
        oc, sp];

    // Cast the result back to fp32 and close the function and program.
    [mil appendFormat:
        @" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"
        " tensor<fp32, [1, %d, 1, %d]> y = cast(dtype = to32, x = yh)[name = string(\"cast_out\")];\n"
        " } -> (y);\n}\n",
        oc, sp];

    return mil;
}
int main(int argc, char **argv) {
@autoreleasepool {
mach_timebase_info(&g_tb);
ane_init();
int IC = 256, OC = 256, SP = 64;
int io_bytes = IC * SP * 4; // fp32
// Identity weight
float *W_id = (float*)calloc(OC*IC, sizeof(float));
for (int i = 0; i < IC; i++) W_id[i*IC+i] = 1.0f;
NSString *mil = gen_conv_mil(IC, OC, SP);
NSDictionary *wd = @{@"@model_path/weights/w.bin": @{@"offset":@0, @"data":build_blob(W_id, OC, IC)}};
printf("=== Compiling conv %dx%d sp=%d ===\n", OC, IC, SP);
Kern *k = compile_kern_mil_w(mil, wd, io_bytes, io_bytes);
if (!k) { printf("COMPILE FAILED\n"); free(W_id); return 1; }
printf("Compile OK!\n");
// Write fp32 input
IOSurfaceLock(k->ioIn, 0, NULL);
float *inp = (float*)IOSurfaceGetBaseAddress(k->ioIn);
for (int i = 0; i < IC*SP; i++) inp[i] = (i % 100) * 0.01f;
IOSurfaceUnlock(k->ioIn, 0, NULL);
// Eval with identity
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
float *out = (float*)IOSurfaceGetBaseAddress(k->ioOut);
printf("In: [%.3f, %.3f, %.3f, %.3f]\n", inp[0], inp[1], inp[2], inp[3]);
printf("Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]);
float max_err = 0;
for (int i = 0; i < OC*SP; i++) {
float err = fabsf(out[i] - inp[i]);
if (err > max_err) max_err = err;
}
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("Identity max_err=%.6f %s\n\n", max_err, max_err < 0.1 ? "PASS" : "FAIL");
// === Approach 1: Patch weight on disk, unload+reload ===
printf("=== Approach 1: Disk patch + unload/reload ===\n");
float *W_2x = (float*)calloc(OC*IC, sizeof(float));
for (int i = 0; i < IC; i++) W_2x[i*IC+i] = 2.0f;
[build_blob(W_2x, OC, IC) writeToFile:
[(__bridge NSString*)k->tmpDir stringByAppendingPathComponent:@"weights/w.bin"] atomically:YES];
id mdl = (__bridge id)k->model;
NSError *e = nil;
((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e);
e = nil;
BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e);
printf("Reload: %s\n", ok?"OK":"FAIL");
if (ok) {
// Re-create request after reload
id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn);
id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
CFRelease(k->request);
k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
@selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
@[wI], @[@0], @[wO], @[@0], nil, nil, @0));
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]);
float sr = 0; int cnt = 0;
for (int i = 0; i < OC*SP; i++)
if (fabsf(inp[i]) > 0.01f) { sr += out[i]/inp[i]; cnt++; }
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("Ratio: %.3f (2.0=patched, 1.0=cached)\n\n", cnt>0?sr/cnt:0);
}
// === Approach 2: Memory scan ===
printf("=== Approach 2: Memory scan ===\n");
uint16_t pat1[8] = {0x3C00, 0, 0, 0, 0, 0, 0, 0};
uint16_t pat2[8] = {0x4000, 0, 0, 0, 0, 0, 0, 0};
mach_port_t task = mach_task_self();
vm_address_t addr = 0; vm_size_t sz; natural_t depth = 1;
int f1 = 0, f2 = 0;
while (1) {
struct vm_region_submap_info_64 info;
mach_msg_type_number_t count = VM_REGION_SUBMAP_INFO_COUNT_64;
if (vm_region_recurse_64(task, &addr, &sz, &depth, (vm_region_recurse_info_t)&info, &count) != KERN_SUCCESS) break;
if (info.is_submap) { depth++; continue; }
if (!(info.protection & VM_PROT_READ) || sz < (size_t)(OC*IC*2)) { addr += sz; continue; }
uint8_t *base = (uint8_t*)addr;
for (size_t off = 0; off + OC*IC*2 <= sz; off += 2) {
int w = 0;
if (memcmp(base+off, pat1, 16) == 0) w = 1;
else if (memcmp(base+off, pat2, 16) == 0) w = 2;
if (!w) continue;
uint16_t *p = (uint16_t*)(base+off), diag = (w==1)?0x3C00:0x4000;
int ok2 = 1;
for (int r = 0; r < OC && ok2; r++)
for (int c = 0; c < IC && ok2; c++)
if (p[r*IC+c] != ((r==c)?diag:0)) ok2 = 0;
if (!ok2) continue;
if (w==1) f1++; else f2++;
printf(" FOUND %dx @%p prot=%d/%d %s\n", w, (void*)(addr+off),
info.protection, info.max_protection, (info.protection&VM_PROT_WRITE)?"WR":"RO");
}
addr += sz;
}
printf("Found: 1x=%d 2x=%d\n", f1, f2);
// Now patch ALL found weight patterns to 3× and re-eval
if (f1 > 0 || f2 > 0) {
printf("Patching all found patterns to 3x identity...\n");
addr = 0; depth = 1;
while (1) {
struct vm_region_submap_info_64 info2;
mach_msg_type_number_t count2 = VM_REGION_SUBMAP_INFO_COUNT_64;
if (vm_region_recurse_64(task, &addr, &sz, &depth, (vm_region_recurse_info_t)&info2, &count2) != KERN_SUCCESS) break;
if (info2.is_submap) { depth++; continue; }
if (!(info2.protection & VM_PROT_READ) || sz < (size_t)(OC*IC*2)) { addr += sz; continue; }
uint8_t *base2 = (uint8_t*)addr;
for (size_t off = 0; off + OC*IC*2 <= sz; off += 2) {
int w2 = 0;
if (memcmp(base2+off, pat1, 16) == 0) w2 = 1;
else if (memcmp(base2+off, pat2, 16) == 0) w2 = 2;
if (!w2) continue;
uint16_t *p2 = (uint16_t*)(base2+off), diag2 = (w2==1)?0x3C00:0x4000;
int ok3 = 1;
for (int r = 0; r < OC && ok3; r++)
for (int c = 0; c < IC && ok3; c++)
if (p2[r*IC+c] != ((r==c)?diag2:0)) ok3 = 0;
if (!ok3) continue;
if (info2.protection & VM_PROT_WRITE) {
printf(" Patching %dx @%p to 3x\n", w2, (void*)(addr+off));
for (int r = 0; r < OC; r++)
for (int c = 0; c < IC; c++)
p2[r*IC+c] = (r==c) ? 0x4200 : 0; // fp16(3.0)
}
}
addr += sz;
}
printf("\n=== Eval after memory patch (expect 3x) ===\n");
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]);
float sr2 = 0; int cnt2 = 0;
for (int i = 0; i < OC*SP; i++)
if (fabsf(inp[i]) > 0.01f) { sr2 += out[i]/inp[i]; cnt2++; }
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("Ratio: %.3f (3.0=mem patch works!, 1.0=ANE uses SRAM copy)\n", cnt2>0?sr2/cnt2:0);
}
printf("\n");
// === Approach 3: Explore classes ===
printf("=== ANE classes ===\n");
const char *cn[] = {"_ANEWeight", "_ANEProgramForEvaluation", "_ANEChainingRequest", NULL};
for (int i = 0; cn[i]; i++) {
Class cls = NSClassFromString([NSString stringWithUTF8String:cn[i]]);
if (!cls) { printf("%s: NOT FOUND\n", cn[i]); continue; }
printf("%s:\n", cn[i]);
unsigned int mc = 0; Method *ms = class_copyMethodList(cls, &mc);
for (unsigned j = 0; j < mc; j++) printf(" - %s\n", sel_getName(method_getName(ms[j])));
free(ms);
mc = 0; ms = class_copyMethodList(object_getClass(cls), &mc);
for (unsigned j = 0; j < mc; j++) printf(" + %s\n", sel_getName(method_getName(ms[j])));
free(ms); printf("\n");
}
@try { printf("programHandle: %s\n", [[[mdl valueForKey:@"programHandle"] description] UTF8String]); } @catch(id x) {}
@try { printf("intermediateBufferHandle: %s\n", [[[mdl valueForKey:@"intermediateBufferHandle"] description] UTF8String]); } @catch(id x) {}
// === Approach 4: _ANEWeight + updateWeightURL ===
printf("\n=== Approach 4: _ANEWeight API ===\n");
Class AW = NSClassFromString(@"_ANEWeight");
if (AW) {
// Write 5× identity weights to a new file
float *W_5x = (float*)calloc(OC*IC, sizeof(float));
for (int i = 0; i < IC; i++) W_5x[i*IC+i] = 5.0f;
NSString *wpath = [NSTemporaryDirectory() stringByAppendingPathComponent:@"patched_w.bin"];
[build_blob(W_5x, OC, IC) writeToFile:wpath atomically:YES];
free(W_5x);
NSURL *wurl = [NSURL fileURLWithPath:wpath];
id wobj = ((id(*)(Class,SEL,id,id))objc_msgSend)(AW,
@selector(weightWithSymbolAndURL:weightURL:), @"W", wurl);
printf(" _ANEWeight: %s\n", wobj ? [[wobj description] UTF8String] : "nil");
if (wobj) {
printf(" weightSymbol: %s\n", [((id(*)(id,SEL))objc_msgSend)(wobj, @selector(weightSymbol)) UTF8String]);
printf(" weightURL: %s\n", [[((id(*)(id,SEL))objc_msgSend)(wobj, @selector(weightURL)) description] UTF8String]);
}
// Try to pass as weightsBuffer in request
printf("\n Trying weightsBuffer in request...\n");
id wI2 = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn);
id wO2 = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
// Try passing weight array as weightsBuffer
if (wobj) {
CFRelease(k->request);
k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
@selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
@[wI2], @[@0], @[wO2], @[@0], @[wobj], nil, @0));
printf(" Request with weightsBuffer created\n");
@try {
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf(" Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]);
float sr3 = 0; int cnt3 = 0;
for (int i2 = 0; i2 < OC*SP; i2++)
if (fabsf(inp[i2]) > 0.01f) { sr3 += out[i2]/inp[i2]; cnt3++; }
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf(" Ratio: %.3f (5.0=weightsBuffer works!)\n", cnt3>0?sr3/cnt3:0);
} @catch(NSException *ex) {
printf(" Eval exception: %s\n", [[ex description] UTF8String]);
}
}
// Also try IOSurface as weightsBuffer
printf("\n Trying IOSurface as weightsBuffer...\n");
IOSurfaceRef wSurf = make_surface(OC*IC*2); // fp16 weights
IOSurfaceLock(wSurf, 0, NULL);
_Float16 *wfp16 = (_Float16*)IOSurfaceGetBaseAddress(wSurf);
for (int r = 0; r < OC; r++)
for (int c2 = 0; c2 < IC; c2++)
wfp16[r*IC+c2] = (r==c2) ? (_Float16)7.0f : (_Float16)0.0f; // 7× identity
IOSurfaceUnlock(wSurf, 0, NULL);
id wSurfObj = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), wSurf);
CFRelease(k->request);
k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
@selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
@[wI2], @[@0], @[wO2], @[@0], wSurfObj, nil, @0));
@try {
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf(" Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]);
float sr4 = 0; int cnt4 = 0;
for (int i3 = 0; i3 < OC*SP; i3++)
if (fabsf(inp[i3]) > 0.01f) { sr4 += out[i3]/inp[i3]; cnt4++; }
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf(" Ratio: %.3f (7.0=IOSurface weights work!)\n", cnt4>0?sr4/cnt4:0);
} @catch(NSException *ex) {
printf(" Eval exception: %s\n", [[ex description] UTF8String]);
}
CFRelease(wSurf);
}
// === Approach 5: Weights packed into input IOSurface (fp16 with cast) ===
printf("\n=== Approach 5: Dynamic weights via input IOSurface ===\n");
// Element-wise mul: x * w where both come from input
// Input [1, IC*2, 1, SP] fp32 cast fp16 slice mul cast fp32
{
int C5 = IC;
NSMutableString *m5 = [NSMutableString string];
[m5 appendString:@"program(1.3)\n"
"[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, "
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, "
"{\"coremltools-version\", \"9.0\"}})]\n{\n"];
[m5 appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", C5*2, SP];
[m5 appendString:@" string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"];
[m5 appendFormat:@" tensor<fp16, [1, %d, 1, %d]> xh = cast(dtype = to16, x = x)[name = string(\"cin\")];\n", C5*2, SP];
[m5 appendFormat:@" tensor<int32, [4]> b0 = const()[name = string(\"b0\"), val = tensor<int32, [4]>([0,0,0,0])];\n"];
[m5 appendFormat:@" tensor<int32, [4]> s0 = const()[name = string(\"s0\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n", C5, SP];
[m5 appendFormat:@" tensor<fp16, [1,%d,1,%d]> data = slice_by_size(x=xh,begin=b0,size=s0)[name=string(\"data\")];\n", C5, SP];
[m5 appendFormat:@" tensor<int32, [4]> b1 = const()[name = string(\"b1\"), val = tensor<int32, [4]>([0,%d,0,0])];\n", C5];
[m5 appendFormat:@" tensor<fp16, [1,%d,1,%d]> wt = slice_by_size(x=xh,begin=b1,size=s0)[name=string(\"wt\")];\n", C5, SP];
[m5 appendFormat:@" tensor<fp16, [1,%d,1,%d]> yh = mul(x=data,y=wt)[name=string(\"mul\")];\n", C5, SP];
[m5 appendString:@" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"];
[m5 appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype = to32, x = yh)[name = string(\"cout\")];\n", C5, SP];
[m5 appendString:@" } -> (y);\n}\n"];
int io5_in = C5*2*SP*4;
int io5_out = C5*SP*4;
Kern *k5 = compile_kern_mil_w(m5, @{}, io5_in, io5_out);
if (k5) {
printf("Compile OK!\n");
IOSurfaceLock(k5->ioIn, 0, NULL);
float *in5 = (float*)IOSurfaceGetBaseAddress(k5->ioIn);
for (int i = 0; i < C5*SP; i++) in5[i] = (i%100)*0.01f;
for (int i = 0; i < C5*SP; i++) in5[C5*SP+i] = 2.0f;
IOSurfaceUnlock(k5->ioIn, 0, NULL);
ane_eval(k5);
IOSurfaceLock(k5->ioOut, kIOSurfaceLockReadOnly, NULL);
float *out5 = (float*)IOSurfaceGetBaseAddress(k5->ioOut);
printf("data=[%.3f,%.3f,%.3f], w=2.0 → out=[%.3f,%.3f,%.3f]\n",
in5[0],in5[1],in5[2], out5[0],out5[1],out5[2]);
IOSurfaceUnlock(k5->ioOut, kIOSurfaceLockReadOnly, NULL);
// Change weight dynamically NO recompile!
IOSurfaceLock(k5->ioIn, 0, NULL);
for (int i = 0; i < C5*SP; i++) in5[C5*SP+i] = 5.0f;
IOSurfaceUnlock(k5->ioIn, 0, NULL);
ane_eval(k5);
IOSurfaceLock(k5->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("w=5.0 → out=[%.3f,%.3f,%.3f] (expect 5×)\n", out5[0],out5[1],out5[2]);
IOSurfaceUnlock(k5->ioOut, kIOSurfaceLockReadOnly, NULL);
free_kern(k5);
} else printf("Compile FAILED\n");
}
// === Approach 6: matmul with dynamic weights from input ===
printf("\n=== Approach 6: matmul with dynamic W from input ===\n");
// Pack x[1,D,S,1] and W[1,D,1,D] into input, then reshape+matmul
// Input shape: [1, D+D*D, 1, S] first D channels=activations, rest=weight matrix flattened
// Actually, matmul needs [1,H,S,D] shapes. Let's try:
// Input: [1, D*(S+D), 1, 1] reshaped as needed
// Simpler: just test matmul with two sliced inputs
{
int D6 = 64, S6 = 64; // small for test
// Input: [1, D6+D6, S6, D6] but that's 4D...
// Actually ANE matmul works on [1,H,M,K] @ [1,H,K,N] [1,H,M,N]
// Let's pack x[1,1,S6,D6] and W[1,1,D6,D6] into [1,2,S6,D6]
// Then slice matmul
NSMutableString *m6 = [NSMutableString string];
[m6 appendString:@"program(1.3)\n"
"[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, "
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, "
"{\"coremltools-version\", \"9.0\"}})]\n{\n"];
// Input: [1, D6+D6, 1, S6*D6] flatten everything, then reshape
// Actually simplest: two separate regions in channel dim
// x_data: [1, D6, 1, S6] and W: [1, D6*D6, 1, 1]
// Total input channels: D6 + D6*D6
int total_ch = D6 + D6*D6;
[m6 appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", total_ch, S6];
[m6 appendString:@" string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"];
[m6 appendFormat:@" tensor<fp16, [1, %d, 1, %d]> xh = cast(dtype = to16, x = x)[name = string(\"cin\")];\n", total_ch, S6];
// Slice activations: [1, D6, 1, S6]
[m6 appendFormat:@" tensor<int32, [4]> b0 = const()[name = string(\"b0\"), val = tensor<int32, [4]>([0,0,0,0])];\n"];
[m6 appendFormat:@" tensor<int32, [4]> sa = const()[name = string(\"sa\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n", D6, S6];
[m6 appendFormat:@" tensor<fp16, [1,%d,1,%d]> act = slice_by_size(x=xh,begin=b0,size=sa)[name=string(\"act\")];\n", D6, S6];
// Slice weight: [1, D6*D6, 1, S6] but we only need [D6, D6] reshape
[m6 appendFormat:@" tensor<int32, [4]> bw = const()[name = string(\"bw\"), val = tensor<int32, [4]>([0,%d,0,0])];\n", D6];
[m6 appendFormat:@" tensor<int32, [4]> sw = const()[name = string(\"sw\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n", D6*D6, S6];
[m6 appendFormat:@" tensor<fp16, [1,%d,1,%d]> wf = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"wf\")];\n", D6*D6, S6];
// Reshape weight to [1, D6, D6, S6] for matmul-like operation
// Actually for conv: weight needs to be [OC, IC, 1, 1] const. Can't use dynamic weight with conv.
// For matmul: need [1, 1, D6, D6] or similar
// Let's try: reshape wf to [1, D6, D6, S6], take first slice [:,:,:,0] no, that's hard
// Simpler: reshape to [D6, D6] and use matmul
// But matmul expects specific ranks... let me try:
[m6 appendFormat:@" tensor<int32, [4]> ws = const()[name = string(\"ws\"), val = tensor<int32, [4]>([1, 1, %d, %d])];\n", D6, D6];
// Only take first column of wf to get [1, D6*D6, 1, 1]
[m6 appendFormat:@" tensor<int32, [4]> sw1 = const()[name = string(\"sw1\"), val = tensor<int32, [4]>([1,%d,1,1])];\n", D6*D6];
[m6 appendFormat:@" tensor<fp16, [1,%d,1,1]> wf1 = slice_by_size(x=wf,begin=b0,size=sw1)[name=string(\"wf1\")];\n", D6*D6];
[m6 appendFormat:@" tensor<fp16, [1,1,%d,%d]> W = reshape(shape=ws,x=wf1)[name=string(\"W\")];\n", D6, D6];
// Reshape act to [1, 1, S6, D6] for matmul
[m6 appendFormat:@" tensor<int32, [4]> as2 = const()[name = string(\"as2\"), val = tensor<int32, [4]>([1, 1, %d, %d])];\n", D6, S6];
[m6 appendFormat:@" tensor<int32, [4]> pm = const()[name = string(\"pm\"), val = tensor<int32, [4]>([0, 1, 3, 2])];\n"];
[m6 appendFormat:@" tensor<fp16, [1,1,%d,%d]> a2 = reshape(shape=as2,x=act)[name=string(\"a2\")];\n", D6, S6];
[m6 appendFormat:@" tensor<fp16, [1,1,%d,%d]> a3 = transpose(perm=pm,x=a2)[name=string(\"a3\")];\n", S6, D6];
// matmul: [1,1,S6,D6] @ [1,1,D6,D6] [1,1,S6,D6]
[m6 appendString:@" bool bF = const()[name = string(\"bF\"), val = bool(false)];\n"];
[m6 appendFormat:@" tensor<fp16, [1, 1, %d, %d]> yh = matmul(transpose_x = bF, transpose_y = bF, x = a3, y = W)[name = string(\"mm\")];\n", S6, D6];
// Reshape back to [1, D6, 1, S6]
[m6 appendFormat:@" tensor<fp16, [1,1,%d,%d]> yt = transpose(perm=pm,x=yh)[name=string(\"yt\")];\n", D6, S6];
[m6 appendFormat:@" tensor<int32, [4]> os = const()[name = string(\"os\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n", D6, S6];
[m6 appendFormat:@" tensor<fp16, [1,%d,1,%d]> yr = reshape(shape=os,x=yt)[name=string(\"yr\")];\n", D6, S6];
[m6 appendString:@" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"];
[m6 appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype = to32, x = yr)[name = string(\"cout\")];\n", D6, S6];
[m6 appendString:@" } -> (y);\n}\n"];
int io6_in = total_ch * S6 * 4;
int io6_out = D6 * S6 * 4;
Kern *k6 = compile_kern_mil_w(m6, @{}, io6_in, io6_out);
if (k6) {
printf("Dynamic matmul compile OK!\n");
// Set up: identity W, ramp input
IOSurfaceLock(k6->ioIn, 0, NULL);
float *in6 = (float*)IOSurfaceGetBaseAddress(k6->ioIn);
memset(in6, 0, io6_in);
// Activations: [D6, S6] in channel-first layout
for (int d = 0; d < D6; d++)
for (int s = 0; s < S6; s++)
in6[d*S6+s] = (d*S6+s) * 0.001f;
// Weight: identity matrix [D6, D6] packed in channels D6..D6+D6*D6, only col 0
float *wbase = in6 + D6*S6;
for (int r = 0; r < D6; r++)
for (int c = 0; c < D6; c++)
wbase[(r*D6+c)*S6] = (r==c) ? 1.0f : 0.0f; // only sp=0 matters
IOSurfaceUnlock(k6->ioIn, 0, NULL);
ane_eval(k6);
IOSurfaceLock(k6->ioOut, kIOSurfaceLockReadOnly, NULL);
float *out6 = (float*)IOSurfaceGetBaseAddress(k6->ioOut);
printf("Identity W: in=[%.4f,%.4f,%.4f] out=[%.4f,%.4f,%.4f]\n",
in6[0],in6[1],in6[2], out6[0],out6[1],out6[2]);
// Check
float me6 = 0;
for (int i = 0; i < D6*S6; i++) {
float e6 = fabsf(out6[i] - in6[i]);
if (e6 > me6) me6 = e6;
}
IOSurfaceUnlock(k6->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("max_err=%.6f %s\n", me6, me6 < 0.1 ? "PASS" : "FAIL");
// Now: 2× identity just change the IOSurface weight, no recompile!
IOSurfaceLock(k6->ioIn, 0, NULL);
for (int r = 0; r < D6; r++)
for (int c = 0; c < D6; c++)
wbase[(r*D6+c)*S6] = (r==c) ? 2.0f : 0.0f;
IOSurfaceUnlock(k6->ioIn, 0, NULL);
ane_eval(k6);
IOSurfaceLock(k6->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("2× W: in=[%.4f,%.4f] out=[%.4f,%.4f] (expect 2×)\n",
in6[0],in6[1], out6[0],out6[1]);
IOSurfaceUnlock(k6->ioOut, kIOSurfaceLockReadOnly, NULL);
free_kern(k6);
} else printf("Dynamic matmul compile FAILED\n");
}
free_kern(k); free(W_id); free(W_2x);
printf("\nDone.\n");
}
return 0;
}

View File

@ -0,0 +1,9 @@
# Builds the dynamic-weight training binary `train` from train.m with the Xcode toolchain clang.
CC = xcrun clang
# Optimised build with ARC enabled; Foundation, IOSurface and Accelerate are linked as frameworks
# and the SDK path is resolved through xcrun at make time.
CFLAGS = -O2 -framework Foundation -framework IOSurface -framework Accelerate \
-isysroot $(shell xcrun --show-sdk-path) -fobjc-arc
# Only train.m is compiled; the headers are listed as prerequisites so editing them triggers a rebuild.
# NOTE(review): make requires recipe lines to start with a TAB - confirm the leading whitespace is intact in the real file.
train: train.m config.h io.h cpu_ops.h mil_dynamic.h
$(CC) $(CFLAGS) -o train train.m
# Removes the built binary.
clean:
rm -f train

View File

@ -0,0 +1,156 @@
// config.h — Stories110M model config, structs, ANE init
#pragma once
#import <Foundation/Foundation.h>
#import <objc/runtime.h>
#import <objc/message.h>
#import <dlfcn.h>
#import <IOSurface/IOSurface.h>
#import <mach/mach_time.h>
#import <Accelerate/Accelerate.h>
#include <math.h>
#include <unistd.h>
#include <dispatch/dispatch.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <arm_neon.h>
// Stories110M config
// Compile-time model dimensions; the allocation helpers further down size their buffers from these.
#define DIM 768
#define HIDDEN 2048
#define HEADS 12
// Width of one attention head (DIM split evenly across HEADS)
#define HD (DIM/HEADS)
#define SEQ 256
#define NLAYERS 12
#define VOCAB 32000
// Weight sizes per layer (element counts, not bytes; allocators multiply by 4)
#define WQ_SZ (DIM*DIM)
#define WO_SZ (DIM*DIM)
#define W1_SZ (HIDDEN*DIM)
#define W2_SZ (DIM*HIDDEN)
#define W3_SZ (HIDDEN*DIM)
// Elements in one layer: four DIM*DIM matrices, the three FFN matrices, and two DIM-long RMSNorm vectors
#define LAYER_PARAMS (4*WQ_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM)
// Attention score channels for SDPA backward
#define SCORE_CH (HEADS*SEQ)
// Trainable fp32 parameters of a single transformer layer.
// Buffers are created by layer_weights_alloc and released by layer_weights_free.
typedef struct {
    float *Wq;       // WQ_SZ elements
    float *Wk;       // WQ_SZ elements
    float *Wv;       // WQ_SZ elements
    float *Wo;       // WO_SZ elements
    float *W1;       // W1_SZ elements
    float *W2;       // W2_SZ elements
    float *W3;       // W3_SZ elements
    float *rms_att;  // DIM elements
    float *rms_ffn;  // DIM elements
} LayerWeights;
// Adam moment buffers for one parameter tensor of n elements.
typedef struct {
    float *m;   // first-moment running average (updated in adam_update)
    float *v;   // second-moment running average (updated in adam_update)
    size_t n;   // element count of m and v
} AdamState;

// One AdamState per trainable tensor of a layer, mirroring the LayerWeights fields.
typedef struct {
    AdamState Wq;
    AdamState Wk;
    AdamState Wv;
    AdamState Wo;
    AdamState W1;
    AdamState W2;
    AdamState W3;
    AdamState rms_att;
    AdamState rms_ffn;
} LayerAdam;
// Forward-pass activations of one layer, kept so the backward pass can reuse them.
// layer_acts_alloc sizes h1/h3/silu_out as SEQ*HIDDEN floats and every other buffer as SEQ*DIM floats.
typedef struct {
    float *layer_in;
    float *xnorm;
    float *Q;
    float *K;
    float *V;
    float *attn_out;
    float *o_out;
    float *x2;
    float *x2norm;
    float *h1;
    float *h3;
    float *silu_out;
    float *ffn_out;
} LayerActs;
// Gradient accumulators for one layer; each field has the size of the same-named LayerWeights tensor.
typedef struct {
    float *Wq;
    float *Wk;
    float *Wv;
    float *Wo;
    float *W1;
    float *W2;
    float *W3;
    float *rms_att;
    float *rms_ffn;
} LayerGrads;
// Handle for one compiled and loaded ANE kernel.
// model, request and tmpDir hold retained Objective-C objects stored as raw pointers
// (CFBridgingRetain in compile_kern_mil_w, CFRelease in free_kern).
typedef struct {
    void *model;         // retained model object
    IOSurfaceRef ioIn;   // input surface bound to the request
    IOSurfaceRef ioOut;  // output surface bound to the request
    void *request;       // retained request object
    void *tmpDir;        // retained NSString path of the kernel's temp directory
} Kern;
// Fixed-layout checkpoint header. Field order and types define the struct layout,
// so they must not be reordered.
typedef struct {
    int magic;
    int version;
    int step;
    int total_steps;
    int n_layers;
    int vocab_size;
    int dim;
    int hidden_dim;
    int n_heads;
    int seq_len;
    float lr;
    float loss;
    double cum_compile;
    double cum_train;
    double cum_wall;
    int cum_steps;
    int cum_batches;
    int adam_t;
    int pad[3];
} CkptHdr;
// Header of a llama2.c model file: seven consecutive ints in this order.
typedef struct {
    int dim;
    int hidden_dim;
    int n_layers;
    int n_heads;
    int n_kv_heads;
    int vocab_size;
    int seq_len;
} Llama2Config;
// Globals
// Private ANE classes looked up by name in ane_init():
//   g_D = _ANEInMemoryModelDescriptor, g_I = _ANEInMemoryModel,
//   g_AR = _ANERequest, g_AIO = _ANEIOSurfaceObject. Nil until ane_init() runs.
static Class g_D, g_I, g_AR, g_AIO;
// Timebase read by tb_ms().
// NOTE(review): nothing in this header fills g_tb - presumably the caller runs mach_timebase_info(&g_tb) first; confirm.
static mach_timebase_info_data_t g_tb;
// Count of kernels that compiled and loaded successfully (incremented in compile_kern_mil_w).
static int g_compile_count = 0;
// Loads the private AppleNeuralEngine framework and caches the four classes the
// kernel helpers message (g_D, g_I, g_AR, g_AIO).
// A failed dlopen or a missing class is reported on stderr; previously both were
// silent and only surfaced later as a "desc=NULL" compile failure.
static void ane_init(void) {
    static const char *const kANEFrameworkPath =
        "/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine";
    if (!dlopen(kANEFrameworkPath, RTLD_NOW)) {
        const char *why = dlerror();
        fprintf(stderr, "[ane_init] dlopen failed: %s\n", why ? why : "unknown error");
    }
    g_D  = NSClassFromString(@"_ANEInMemoryModelDescriptor");
    g_I  = NSClassFromString(@"_ANEInMemoryModel");
    g_AR = NSClassFromString(@"_ANERequest");
    g_AIO= NSClassFromString(@"_ANEIOSurfaceObject");
    if (!g_D || !g_I || !g_AR || !g_AIO) {
        // Messaging a nil class returns nil, so later calls would fail without explanation.
        fprintf(stderr, "[ane_init] missing ANE classes: descriptor=%d model=%d request=%d iosurface=%d\n",
                g_D != nil, g_I != nil, g_AR != nil, g_AIO != nil);
    }
}
// Scales a tick count by g_tb.numer/g_tb.denom and divides by 1e6.
// With a Mach timebase that turns ticks into milliseconds; g_tb must already be populated.
static double tb_ms(uint64_t t) {
    double scaled = (double)t * g_tb.numer;
    scaled = scaled / g_tb.denom;
    return scaled / 1e6;
}
// Alloc helpers
// Creates Adam state for n parameters with both moment buffers zero-initialised.
static AdamState adam_alloc(size_t n) {
    AdamState state;
    state.m = (float*)calloc(n, 4);
    state.v = (float*)calloc(n, 4);
    state.n = n;
    return state;
}
// Releases both moment buffers of s; the struct itself is owned by the caller.
static void adam_free(AdamState *s) {
    float *first_moment = s->m;
    float *second_moment = s->v;
    free(first_moment);
    free(second_moment);
}
// Allocates (uninitialised) storage for every weight tensor of one layer.
static LayerWeights layer_weights_alloc(void) {
    const size_t fsz = sizeof(float);
    LayerWeights lw;
    // attention projections
    lw.Wq = (float*)malloc((size_t)WQ_SZ * fsz);
    lw.Wk = (float*)malloc((size_t)WQ_SZ * fsz);
    lw.Wv = (float*)malloc((size_t)WQ_SZ * fsz);
    lw.Wo = (float*)malloc((size_t)WO_SZ * fsz);
    // feed-forward matrices
    lw.W1 = (float*)malloc((size_t)W1_SZ * fsz);
    lw.W2 = (float*)malloc((size_t)W2_SZ * fsz);
    lw.W3 = (float*)malloc((size_t)W3_SZ * fsz);
    // RMSNorm scale vectors
    lw.rms_att = (float*)malloc((size_t)DIM * fsz);
    lw.rms_ffn = (float*)malloc((size_t)DIM * fsz);
    return lw;
}
// Frees every buffer owned by w (the struct itself is not freed).
static void layer_weights_free(LayerWeights *w) {
    float *owned[] = {
        w->Wq, w->Wk, w->Wv, w->Wo,
        w->W1, w->W2, w->W3, w->rms_att, w->rms_ffn,
    };
    for (size_t k = 0; k < sizeof(owned) / sizeof(owned[0]); k++)
        free(owned[k]);
}
// Builds zeroed Adam state for each trainable tensor of a layer.
static LayerAdam layer_adam_alloc(void) {
    LayerAdam st;
    st.Wq = adam_alloc(WQ_SZ);
    st.Wk = adam_alloc(WQ_SZ);
    st.Wv = adam_alloc(WQ_SZ);
    st.Wo = adam_alloc(WO_SZ);
    st.W1 = adam_alloc(W1_SZ);
    st.W2 = adam_alloc(W2_SZ);
    st.W3 = adam_alloc(W3_SZ);
    st.rms_att = adam_alloc(DIM);
    st.rms_ffn = adam_alloc(DIM);
    return st;
}
// Releases the moment buffers of all nine per-tensor Adam states in a.
static void layer_adam_free(LayerAdam *a) {
    AdamState *states[] = {
        &a->Wq, &a->Wk, &a->Wv, &a->Wo,
        &a->W1, &a->W2, &a->W3,
        &a->rms_att, &a->rms_ffn,
    };
    for (size_t k = 0; k < sizeof(states) / sizeof(states[0]); k++)
        adam_free(states[k]);
}
// Allocates (uninitialised) activation buffers for one layer:
// SEQ*HIDDEN floats for the FFN intermediates, SEQ*DIM floats for everything else.
static LayerActs layer_acts_alloc(void) {
    const size_t model_bytes  = (size_t)(SEQ*DIM) * sizeof(float);
    const size_t hidden_bytes = (size_t)(SEQ*HIDDEN) * sizeof(float);
    LayerActs acts;
    acts.layer_in = (float*)malloc(model_bytes);
    acts.xnorm    = (float*)malloc(model_bytes);
    acts.Q        = (float*)malloc(model_bytes);
    acts.K        = (float*)malloc(model_bytes);
    acts.V        = (float*)malloc(model_bytes);
    acts.attn_out = (float*)malloc(model_bytes);
    acts.o_out    = (float*)malloc(model_bytes);
    acts.x2       = (float*)malloc(model_bytes);
    acts.x2norm   = (float*)malloc(model_bytes);
    acts.h1       = (float*)malloc(hidden_bytes);
    acts.h3       = (float*)malloc(hidden_bytes);
    acts.silu_out = (float*)malloc(hidden_bytes);
    acts.ffn_out  = (float*)malloc(model_bytes);
    return acts;
}
// Frees all thirteen activation buffers owned by a.
static void layer_acts_free(LayerActs *a) {
    float *owned[] = {
        a->layer_in, a->xnorm,
        a->Q, a->K, a->V,
        a->attn_out, a->o_out, a->x2, a->x2norm,
        a->h1, a->h3, a->silu_out, a->ffn_out,
    };
    for (size_t k = 0; k < sizeof(owned) / sizeof(owned[0]); k++)
        free(owned[k]);
}
// Allocates zero-filled gradient accumulators matching the layer's weight shapes.
static LayerGrads layer_grads_alloc(void) {
    const size_t fsz = sizeof(float);
    LayerGrads grads;
    grads.Wq = (float*)calloc(WQ_SZ, fsz);
    grads.Wk = (float*)calloc(WQ_SZ, fsz);
    grads.Wv = (float*)calloc(WQ_SZ, fsz);
    grads.Wo = (float*)calloc(WO_SZ, fsz);
    grads.W1 = (float*)calloc(W1_SZ, fsz);
    grads.W2 = (float*)calloc(W2_SZ, fsz);
    grads.W3 = (float*)calloc(W3_SZ, fsz);
    grads.rms_att = (float*)calloc(DIM, fsz);
    grads.rms_ffn = (float*)calloc(DIM, fsz);
    return grads;
}
// Resets every gradient accumulator in g to zero without reallocating.
static void layer_grads_zero(LayerGrads *g) {
    const struct { float *buf; size_t count; } spans[] = {
        { g->Wq, WQ_SZ }, { g->Wk, WQ_SZ }, { g->Wv, WQ_SZ }, { g->Wo, WO_SZ },
        { g->W1, W1_SZ }, { g->W2, W2_SZ }, { g->W3, W3_SZ },
        { g->rms_att, DIM }, { g->rms_ffn, DIM },
    };
    for (size_t k = 0; k < sizeof(spans) / sizeof(spans[0]); k++)
        memset(spans[k].buf, 0, spans[k].count * sizeof(float));
}
// Frees every gradient buffer owned by g.
static void layer_grads_free(LayerGrads *g) {
    float *owned[] = {
        g->Wq, g->Wk, g->Wv, g->Wo,
        g->W1, g->W2, g->W3, g->rms_att, g->rms_ffn,
    };
    for (size_t k = 0; k < sizeof(owned) / sizeof(owned[0]); k++)
        free(owned[k]);
}

View File

@ -0,0 +1,164 @@
// cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, embedding
#pragma once
#include "config.h"
static float *g_rms_tmp = NULL;
// RMSNorm forward on channel-first data: x and out hold d channels of S values,
// element (i, t) at index i*S + t.
//   out[i,t] = w[i] * x[i,t] / sqrt(mean_i(x[i,t]^2) + 1e-5)
// Scratch is allocated per call. The earlier version lazily allocated the shared
// g_rms_tmp using the first call's S, so a later call with a larger S wrote past
// the end of it; per-call scratch also removes the shared mutable state.
static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
    if (d <= 0 || S <= 0) return;
    const vDSP_Length len = (vDSP_Length)S;
    // One zeroed allocation: [0,S) accumulates the sum of squares, [S,2S) is a temp row.
    float *scratch = (float*)calloc((size_t)S * 2, sizeof(float));
    float *ss = scratch;
    float *sq = scratch + S;
    for (int i=0; i<d; i++) {
        vDSP_vmul(x+i*S, 1, x+i*S, 1, sq, 1, len);
        vDSP_vadd(sq, 1, ss, 1, ss, 1, len);
    }
    // ss = ss/d + eps, then ss = 1/sqrt(ss)
    float invd = 1.0f/d, eps=1e-5f;
    vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, len);
    int n = S; vvrsqrtf(ss, ss, &n);
    for (int i=0; i<d; i++) {
        vDSP_vmul(x+i*S, 1, ss, 1, out+i*S, 1, len);
        vDSP_vsmul(out+i*S, 1, &w[i], out+i*S, 1, len);
    }
    free(scratch);
}
// RMSNorm backward on channel-first data (element (i, t) at i*S + t).
// With r[t] = 1/sqrt(mean_i(x[i,t]^2) + 1e-5) and y[i,t] = w[i]*x[i,t]*r[t]:
//   dx[i,t]  = r[t] * ( w[i]*dy[i,t] - x[i,t] * (r[t]^2/d) * sum_j dy[j,t]*w[j]*x[j,t] )
//   dw[i]   += sum_t dy[i,t]*x[i,t]*r[t]          (accumulated into the caller's buffer)
// Fixes two defects of the earlier version:
//   * dx was computed as w[i]*r*(dy - x*proj), which wrongly scales the projection
//     term by w[i] whenever w[i] != 1;
//   * scratch came from the shared g_rms_tmp sized by the first call's S, which
//     overflowed for a later, larger S. Scratch is now per call.
static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, const float *w, int d, int S) {
    if (d <= 0 || S <= 0) return;
    const vDSP_Length len = (vDSP_Length)S;
    // One zeroed allocation split into four S-length rows.
    float *scratch = (float*)calloc((size_t)S * 4, sizeof(float));
    float *ss = scratch, *rrms = scratch + S, *dot = scratch + 2*S, *tmp = scratch + 3*S;
    // ss[t] = mean_i x[i,t]^2 + eps ; rrms[t] = 1/sqrt(ss[t])
    for (int i=0; i<d; i++) {
        vDSP_vmul(x+i*S, 1, x+i*S, 1, tmp, 1, len);
        vDSP_vadd(tmp, 1, ss, 1, ss, 1, len);
    }
    float invd = 1.0f/d, eps=1e-5f;
    vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, len);
    int n = S; vvrsqrtf(rrms, ss, &n);
    // dot[t] = sum_i dy[i,t]*x[i,t]*w[i]
    for (int i=0; i<d; i++) {
        vDSP_vmul(dy+i*S, 1, x+i*S, 1, tmp, 1, len);
        vDSP_vsma(tmp, 1, &w[i], dot, 1, dot, 1, len);
    }
    // dot[t] *= rrms[t]^2 / d
    vDSP_vmul(rrms, 1, rrms, 1, ss, 1, len);
    vDSP_vsmul(ss, 1, &invd, ss, 1, len);
    vDSP_vmul(dot, 1, ss, 1, dot, 1, len);
    for (int i=0; i<d; i++) {
        // Weight gradient first, so dy is fully read before dx row i is written.
        vDSP_vmul(dy+i*S, 1, x+i*S, 1, tmp, 1, len);
        vDSP_vmul(tmp, 1, rrms, 1, tmp, 1, len);
        float s; vDSP_sve(tmp, 1, &s, len);
        dw[i] += s;
        // tmp = x*dot ; tmp = dy*w[i] - tmp ; dx = tmp*rrms
        vDSP_vmul(x+i*S, 1, dot, 1, tmp, 1, len);
        vDSP_vsmsb(dy+i*S, 1, &w[i], tmp, 1, tmp, 1, len);
        vDSP_vmul(tmp, 1, rrms, 1, dx+i*S, 1, len);
    }
    free(scratch);
}
// One Adam step over s->n parameters.
//   w    parameters, updated in place
//   g    gradients
//   s    moment buffers, updated in place
//   t    step index used for bias correction (1 - b^t); t = 0 makes the correction zero
//   lr, b1, b2, eps   usual Adam hyper-parameters
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
    const float corr1 = 1.0f - powf(b1, t);
    const float corr2 = 1.0f - powf(b2, t);
    float *mom = s->m;
    float *var = s->v;
    const size_t count = s->n;
    size_t k = 0;
    while (k < count) {
        mom[k] = b1*mom[k] + (1-b1)*g[k];
        var[k] = b2*var[k] + (1-b2)*g[k]*g[k];
        // bias-corrected moments
        float m_hat = mom[k]/corr1;
        float v_hat = var[k]/corr2;
        w[k] -= lr * m_hat / (sqrtf(v_hat) + eps);
        k++;
    }
}
// Softmax cross-entropy over S tokens with a V-way vocabulary.
// logits and dlogits are indexed as [v*S + t], so the V scores of token t sit S
// floats apart; each token is gathered into a contiguous scratch column, processed,
// and scattered back, which avoids transposing the whole matrix.
// Writes dlogits = (softmax - onehot(target)) / S and returns the mean loss.
static float cross_entropy_loss(float *dlogits, const float *logits, const uint16_t *targets, int V, int S) {
    const vDSP_Length vlen = (vDSP_Length)V;
    const float grad_scale = 1.0f / S;
    float *probs = (float*)malloc(V * 4); // scratch for one token's column
    float loss_sum = 0;
    int pos = 0;
    while (pos < S) {
        // gather the strided column of token `pos`
        cblas_scopy(V, logits + pos, S, probs, 1);
        // numerically stable softmax: subtract the max, exponentiate, normalise
        float peak; vDSP_maxv(probs, 1, &peak, vlen);
        float shift = -peak;
        vDSP_vsadd(probs, 1, &shift, probs, 1, vlen);
        int count = V; vvexpf(probs, probs, &count);
        float denom; vDSP_sve(probs, 1, &denom, vlen);
        float norm = 1.0f / denom;
        vDSP_vsmul(probs, 1, &norm, probs, 1, vlen);
        // loss contribution and gradient for this token
        int label = targets[pos];
        loss_sum -= logf(probs[label] + 1e-10f);
        probs[label] -= 1.0f;
        vDSP_vsmul(probs, 1, &grad_scale, probs, 1, vlen);
        // scatter back with stride S
        cblas_scopy(V, probs, 1, dlogits + pos, S);
        pos++;
    }
    free(probs);
    return loss_sum / S;
}
// Vocab compaction: mapping between the full vocabulary and the subset of token
// ids that actually occur in the training data.
typedef struct {
    int compact_vocab;     // number of active tokens
    int *full_to_compact;  // [full_vocab] → compact id (-1 if unused)
    int *compact_to_full;  // [compact_vocab] → full vocab id
} VocabMap;

// Scans data[0..n_tokens) and assigns compact ids, in increasing full-id order,
// to every token id that occurs. Both arrays in the result are malloc'd and owned
// by the caller.
// Token ids >= full_vocab are skipped and reported on stderr: ids are uint16_t
// (up to 65535), so indexing full_to_compact with them unchecked wrote past the
// end of the table.
static VocabMap vocab_map_build(const uint16_t *data, size_t n_tokens, int full_vocab) {
    VocabMap vm;
    vm.full_to_compact = (int*)malloc(full_vocab * sizeof(int));
    memset(vm.full_to_compact, -1, full_vocab * sizeof(int)); // all bytes 0xFF → every entry is -1
    // Scan for used tokens
    size_t out_of_range = 0;
    for (size_t i = 0; i < n_tokens; i++) {
        if ((int)data[i] >= full_vocab) { out_of_range++; continue; }
        vm.full_to_compact[data[i]] = 0; // mark as used
    }
    if (out_of_range)
        fprintf(stderr, "[vocab_map_build] skipped %zu token(s) with id >= %d\n", out_of_range, full_vocab);
    // Assign compact IDs
    int cid = 0;
    for (int v = 0; v < full_vocab; v++) {
        if (vm.full_to_compact[v] == 0)
            vm.full_to_compact[v] = cid++;
        else
            vm.full_to_compact[v] = -1;
    }
    vm.compact_vocab = cid;
    // Inverse table
    vm.compact_to_full = (int*)malloc(cid * sizeof(int));
    for (int v = 0; v < full_vocab; v++) {
        if (vm.full_to_compact[v] >= 0)
            vm.compact_to_full[vm.full_to_compact[v]] = v;
    }
    return vm;
}
// Builds a newly malloc'd [compact_vocab, dim] embedding table by copying, for each
// compact id, the matching row of full_embed. The caller owns the result.
static float *vocab_compact_embed(const float *full_embed, const VocabMap *vm, int dim) {
    const int rows = vm->compact_vocab;
    float *compact = (float*)malloc((size_t)rows * dim * 4);
    for (int row = 0; row < rows; row++) {
        const float *src = full_embed + vm->compact_to_full[row]*dim;
        float *dst = compact + row*dim;
        memcpy(dst, src, dim*4);
    }
    return compact;
}
// Adds every compact-embedding gradient row onto the matching row of the
// full-vocabulary gradient table (accumulates, does not overwrite).
static void vocab_scatter_grads(float *full_gembed, const float *compact_gembed, const VocabMap *vm, int dim) {
    const int rows = vm->compact_vocab;
    for (int row = 0; row < rows; row++) {
        const int full_id = vm->compact_to_full[row];
        float *dst = full_gembed + full_id*dim;
        const float *src = compact_gembed + row*dim;
        for (int col = 0; col < dim; col++)
            dst[col] += src[col];
    }
}
// Copies every row of the compact embedding back into its slot in the full table
// (used after the optimizer has updated the compact copy).
static void vocab_update_full(float *full_embed, const float *compact_embed, const VocabMap *vm, int dim) {
    const int rows = vm->compact_vocab;
    for (int row = 0; row < rows; row++) {
        float *dst = full_embed + vm->compact_to_full[row]*dim;
        const float *src = compact_embed + row*dim;
        memcpy(dst, src, dim*4);
    }
}
// Gathers token embeddings into a channel-first activation buffer:
// x[d*seq + t] = embed[tokens[t]*dim + d] for every position t and channel d.
static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) {
    for (int pos = 0; pos < seq; pos++) {
        const int token = tokens[pos];
        const float *row = embed + token*dim;
        float *dst = x + pos;              // walks down column `pos`, one channel per step
        for (int ch = 0; ch < dim; ch++, dst += seq)
            *dst = row[ch];
    }
}
// Accumulates channel-first activation gradients into the embedding gradient table:
// d_embed[tokens[t]*dim + d] += dx[d*seq + t]. Repeated tokens add up.
static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) {
    for (int pos = 0; pos < seq; pos++) {
        const int token = tokens[pos];
        float *row = d_embed + token*dim;
        const float *src = dx + pos;       // column `pos`, stride seq between channels
        for (int ch = 0; ch < dim; ch++, src += seq)
            row[ch] += *src;
    }
}

View File

@ -0,0 +1,147 @@
// io.h — IOSurface helpers, NEON conversion, kernel compile/eval
#pragma once
#include "config.h"
// Creates an IOSurface described as a single row of `bytes` one-byte elements
// (width = bytes-per-row = alloc size = bytes, height 1, pixel format 0).
// The caller owns the returned surface and must CFRelease it.
static IOSurfaceRef make_surface(size_t bytes) {
    NSNumber *length = @(bytes);
    NSDictionary *props = @{
        (id)kIOSurfaceWidth:           length,
        (id)kIOSurfaceHeight:          @1,
        (id)kIOSurfaceBytesPerElement: @1,
        (id)kIOSurfaceBytesPerRow:     length,
        (id)kIOSurfaceAllocSize:       length,
        (id)kIOSurfacePixelFormat:     @0,
    };
    return IOSurfaceCreate((__bridge CFDictionaryRef)props);
}
// Blob builders for const weights (mask, rms)
// Packs a rows x cols fp32 matrix into a weight blob: a 128-byte header followed by
// the values converted to fp16. The header carries fixed marker bytes, the payload
// byte count at offset 72 and the value 128 at offset 80.
// The returned NSData takes ownership of the buffer.
static NSData *build_blob(const float *w, int rows, int cols) {
    const int count = rows*cols;
    const int payload = rows*cols*2;
    const int total = 128 + payload;
    uint8_t *buf = (uint8_t*)calloc(total, 1);
    // header markers
    buf[0] = 1;
    buf[4] = 2;
    buf[64] = 0xEF; buf[65] = 0xBE; buf[66] = 0xAD; buf[67] = 0xDE;
    buf[68] = 1;
    *(uint32_t*)(buf+72) = payload;
    *(uint32_t*)(buf+80) = 128;
    // payload: fp32 → fp16
    _Float16 *half = (_Float16*)(buf+128);
    for (int k = 0; k < count; k++)
        half[k] = (_Float16)w[k];
    return [NSData dataWithBytesNoCopy:buf length:total freeWhenDone:YES];
}
// Same blob layout as build_blob, but the cnt values are already fp16 and are
// copied verbatim after the 128-byte header. The returned NSData owns the buffer.
static NSData *build_blob_fp16(_Float16 *d, int cnt) {
    const int payload = cnt*2;
    const int total = 128 + payload;
    uint8_t *buf = (uint8_t*)calloc(total, 1);
    // header markers
    buf[0] = 1;
    buf[4] = 2;
    buf[64] = 0xEF; buf[65] = 0xBE; buf[66] = 0xAD; buf[67] = 0xDE;
    buf[68] = 1;
    *(uint32_t*)(buf+72) = payload;
    *(uint32_t*)(buf+80) = 128;
    memcpy(buf+128, d, payload);
    return [NSData dataWithBytesNoCopy:buf length:total freeWhenDone:YES];
}
// NEON vectorized conversion
// Widens n fp16 values to fp32: eight at a time with NEON, then a scalar tail.
static void cvt_f16_f32(float *dst, const _Float16 *src, int n) {
    int pos = 0;
    while (n - pos >= 8) {
        float16x8_t halves = vld1q_f16((const __fp16*)(src+pos));
        float32x4_t lo = vcvt_f32_f16(vget_low_f16(halves));
        float32x4_t hi = vcvt_f32_f16(vget_high_f16(halves));
        vst1q_f32(dst+pos, lo);
        vst1q_f32(dst+pos+4, hi);
        pos += 8;
    }
    while (pos < n) {
        dst[pos] = (float)src[pos];
        pos++;
    }
}
// Narrows n fp32 values to fp16: eight at a time with NEON, then a scalar tail.
static void cvt_f32_f16(_Float16 *dst, const float *src, int n) {
    int pos = 0;
    while (n - pos >= 8) {
        float16x4_t lo = vcvt_f16_f32(vld1q_f32(src+pos));
        float16x4_t hi = vcvt_f16_f32(vld1q_f32(src+pos+4));
        vst1q_f16((__fp16*)(dst+pos), vcombine_f16(lo, hi));
        pos += 8;
    }
    while (pos < n) {
        dst[pos] = (_Float16)src[pos];
        pos++;
    }
}
// IOSurface I/O (channel-first [C,S] layout, fp16 on surface)
// Converts channels*sp fp32 values to fp16 and writes them at the start of the
// surface, holding the surface lock for the duration of the write.
static void io_write_fp16(IOSurfaceRef s, const float *data, int channels, int sp) {
    const int count = channels * sp;
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    cvt_f32_f16(base, data, count);
    IOSurfaceUnlock(s, 0, NULL);
}
// Reads `channels` channels of fp16 data starting at channel ch_off from the
// surface and widens them to fp32 in `data`, under a read-only lock.
static void io_read_fp16(IOSurfaceRef s, float *data, int ch_off, int channels, int sp) {
    const int count = channels * sp;
    IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
    const _Float16 *first = (_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp;
    cvt_f16_f32(data, first, count);
    IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}
// Copies `channels` fp16 channels of width sp from channel src_ch of `src` to
// channel dst_ch of `dst`. dst is locked for writing and src read-only while copying.
static void io_copy(IOSurfaceRef dst, int dst_ch, IOSurfaceRef src, int src_ch, int channels, int sp) {
    const size_t nbytes = channels * sp * sizeof(_Float16);
    IOSurfaceLock(dst, 0, NULL);
    IOSurfaceLock(src, kIOSurfaceLockReadOnly, NULL);
    _Float16 *to = (_Float16*)IOSurfaceGetBaseAddress(dst) + dst_ch*sp;
    const _Float16 *from = (_Float16*)IOSurfaceGetBaseAddress(src) + src_ch*sp;
    memcpy(to, from, nbytes);
    IOSurfaceUnlock(src, kIOSurfaceLockReadOnly, NULL);
    IOSurfaceUnlock(dst, 0, NULL);
}
// Like io_write_fp16, but the fp16 data lands at channel offset ch_off
// (element offset ch_off*sp) instead of at the start of the surface.
static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int channels, int sp) {
    const int count = channels * sp;
    IOSurfaceLock(s, 0, NULL);
    _Float16 *first = (_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp;
    cvt_f32_f16(first, data, count);
    IOSurfaceUnlock(s, 0, NULL);
}
// fp32 IOSurface I/O (for dynamic matmul kernels that use fp32 input/output)
// Surface layout: [1, IC, 1, SP] with SP = seq + oc.
// For every input channel d, the row holds that channel's seq activations followed
// by the oc weights W[d*oc .. d*oc+oc).
static void io_write_dyn(IOSurfaceRef s, const float *act, int ic, int seq,
                         const float *W, int oc) {
    const int row_len = seq + oc;
    IOSurfaceLock(s, 0, NULL);
    float *row = (float*)IOSurfaceGetBaseAddress(s);
    for (int ch = 0; ch < ic; ch++, row += row_len) {
        memcpy(row, act + ch*seq, seq*4);        // activations at sp[0:seq]
        memcpy(row + seq, W + ch*oc, oc*4);      // weights at sp[seq:seq+oc]
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// Read output from dynamic matmul kernel: [1, OC, 1, SEQ]
// Copies oc*seq fp32 values from the start of the surface into `out` under a read-only lock.
static void io_read_dyn(IOSurfaceRef s, float *out, int oc, int seq) {
    const size_t nbytes = oc * seq * 4;
    IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
    const float *src = (float*)IOSurfaceGetBaseAddress(s);
    memcpy(out, src, nbytes);
    IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}
// Compile MIL to ANE kernel
// Builds an in-memory model from `mil` and `weights`, stages model.mil and the weight
// files in a temp directory named after the model's hex identifier, compiles and
// loads the model, and binds one input and one output IOSurface to a request.
//   mil       MIL program text
//   weights   dictionary keyed by "@model_path/..." paths; each value's @"data" entry is written to disk
//   ic_bytes  byte size of the input surface
//   oc_bytes  byte size of the output surface
// Returns a heap-allocated Kern (release with free_kern) or NULL on failure.
// Failure paths now remove the staged temp directory (and unload the model when it
// was already loaded); the model object, its identifier and the IOSurfaces are
// nil-checked; and the load error is printed instead of discarded.
static Kern *compile_kern_mil_w(NSString *mil, NSDictionary *weights, int ic_bytes, int oc_bytes) {
    @autoreleasepool {
        NSFileManager *fm = [NSFileManager defaultManager];
        NSData *md = [mil dataUsingEncoding:NSUTF8StringEncoding];
        id desc = ((id(*)(Class,SEL,id,id,id))objc_msgSend)(g_D, @selector(modelWithMILText:weights:optionsPlist:), md, weights, nil);
        if (!desc) { printf(" [compile] desc=NULL\n"); return NULL; }
        id mdl = ((id(*)(Class,SEL,id))objc_msgSend)(g_I, @selector(inMemoryModelWithDescriptor:), desc);
        if (!mdl) { printf(" [compile] model=NULL\n"); return NULL; }
        id hx = ((id(*)(id,SEL))objc_msgSend)(mdl, @selector(hexStringIdentifier));
        // stringByAppendingPathComponent: raises on nil, so bail out explicitly.
        if (!hx) { printf(" [compile] identifier=NULL\n"); return NULL; }

        // Stage the MIL text and weight files on disk.
        NSString *td = [NSTemporaryDirectory() stringByAppendingPathComponent:hx];
        [fm createDirectoryAtPath:[td stringByAppendingPathComponent:@"weights"] withIntermediateDirectories:YES attributes:nil error:nil];
        [md writeToFile:[td stringByAppendingPathComponent:@"model.mil"] atomically:YES];
        for (NSString *path in weights) {
            NSString *rel = [path stringByReplacingOccurrencesOfString:@"@model_path/" withString:@""];
            [weights[path][@"data"] writeToFile:[td stringByAppendingPathComponent:rel] atomically:YES];
        }

        NSError *e = nil;
        if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(compileWithQoS:options:error:), 21, @{}, &e)) {
            printf(" [compile] FAIL: %s\n", e ? [[e description] UTF8String] : "no error");
            [fm removeItemAtPath:td error:nil];
            return NULL;
        }
        e = nil;
        if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e)) {
            printf(" [compile] load FAIL: %s\n", e ? [[e description] UTF8String] : "no error");
            [fm removeItemAtPath:td error:nil];
            return NULL;
        }
        __sync_fetch_and_add(&g_compile_count, 1);

        IOSurfaceRef surfIn = make_surface(ic_bytes);
        IOSurfaceRef surfOut = make_surface(oc_bytes);
        if (!surfIn || !surfOut) {
            printf(" [compile] IOSurface alloc FAIL\n");
            if (surfIn) CFRelease(surfIn);
            if (surfOut) CFRelease(surfOut);
            // Undo the load and the staging, mirroring free_kern.
            e = nil;
            ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e);
            [fm removeItemAtPath:td error:nil];
            return NULL;
        }

        Kern *k = (Kern*)calloc(1, sizeof(Kern));
        k->model = (void*)CFBridgingRetain(mdl);
        k->ioIn = surfIn;
        k->ioOut = surfOut;
        id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn);
        id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
        k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
            @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
            @[wI], @[@0], @[wO], @[@0], nil, nil, @0));
        k->tmpDir = (void*)CFBridgingRetain(td);
        return k;
    }
}
// Tears down a kernel created by compile_kern_mil_w: unloads the model, releases both
// surfaces, deletes the temp directory, drops the retained objects and frees the handle.
// Passing NULL is a no-op.
static void free_kern(Kern *k) {
    if (k == NULL) return;
    NSError *err = nil;
    id model = (__bridge id)k->model;
    ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(model, @selector(unloadWithQoS:error:), 21, &err);
    CFRelease(k->ioIn);
    CFRelease(k->ioOut);
    NSString *dir = (__bridge id)k->tmpDir;
    [[NSFileManager defaultManager] removeItemAtPath:dir error:nil];
    CFRelease(k->model);
    CFRelease(k->request);
    CFRelease(k->tmpDir);
    free(k);
}
// Runs the kernel's request once; inputs are read from k->ioIn and results land in k->ioOut.
// A failed evaluation is reported on stderr. The result and NSError used to be
// discarded, so a failed run left stale data in the output surface with no diagnostic.
static void ane_eval(Kern *k) {
    id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; NSError *e = nil;
    BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
    if (!ok)
        fprintf(stderr, " [ane_eval] FAIL: %s\n", e ? [[e description] UTF8String] : "no error");
}

View File

@ -0,0 +1,590 @@
// mil_dynamic.h — MIL generators using dynamic matmul (weights via IOSurface)
// Instead of conv(const_weight, x), we use matmul(x, W) where both come from input.
// Input layout: [1, IC, 1, SP] fp32, SP = SEQ + total_weight_cols
// Activations in sp[0:SEQ], weight matrices packed sequentially in sp[SEQ:]
#pragma once
#include "io.h"
// Common prologue of every generated MIL program: the `program(1.3)` line, the
// buildInfo dictionary, and the opening brace of the program body. Expands to a single
// NSString literal; generators append the function text and the closing "}\n".
#define MIL_HDR \
@"program(1.3)\n[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, " \
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " \
"{\"coremltools-version\", \"9.0\"}})]\n{\n"
// Helper: appends the MIL statements for one dynamic matmul to `m`.
// Slices activation [1,ic,1,seq] and weight [1,ic,1,oc] out of the fp16 tensor
// `input_var` along the last (spatial) axis, then computes act^T @ W.
// prefix: prepended to every emitted variable name (e.g. "mm" → mm_act, mm_wt, ...)
// act_sp_off: spatial offset for activations (usually 0)
// w_sp_off: spatial offset for weight block
// Nothing is returned; the result is the emitted fp16 variable `<prefix>_y` of shape [1,oc,1,seq].
// NOTE(review): the constant `bF` is emitted without the prefix, so calling this twice inside
// one MIL function would define `bF` twice - confirm whether MIL accepts that before reusing.
static void gen_dyn_matmul(NSMutableString *m, const char *prefix,
int ic, int oc, int seq,
int act_sp_off, int w_sp_off,
const char *input_var) {
// Slice activations
[m appendFormat:@" tensor<int32, [4]> %s_ba = const()[name=string(\"%s_ba\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", prefix, prefix, act_sp_off];
[m appendFormat:@" tensor<int32, [4]> %s_sa = const()[name=string(\"%s_sa\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, ic, seq];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_act = slice_by_size(x=%s,begin=%s_ba,size=%s_sa)[name=string(\"%s_act\")];\n", ic, seq, prefix, input_var, prefix, prefix, prefix];
// Slice weight
[m appendFormat:@" tensor<int32, [4]> %s_bw = const()[name=string(\"%s_bw\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", prefix, prefix, w_sp_off];
[m appendFormat:@" tensor<int32, [4]> %s_sw = const()[name=string(\"%s_sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, ic, oc];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_wt = slice_by_size(x=%s,begin=%s_bw,size=%s_sw)[name=string(\"%s_wt\")];\n", ic, oc, prefix, input_var, prefix, prefix, prefix];
// Reshape act: [1,ic,1,seq] → [1,1,ic,seq] → transpose → [1,1,seq,ic]
[m appendFormat:@" tensor<int32, [4]> %s_ra = const()[name=string(\"%s_ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", prefix, prefix, ic, seq];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_a2 = reshape(shape=%s_ra,x=%s_act)[name=string(\"%s_a2\")];\n", ic, seq, prefix, prefix, prefix, prefix];
// %s_pm is the permutation [0,1,3,2] (swap the last two axes); reused for the output transpose below
[m appendFormat:@" tensor<int32, [4]> %s_pm = const()[name=string(\"%s_pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n", prefix, prefix];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_a3 = transpose(perm=%s_pm,x=%s_a2)[name=string(\"%s_a3\")];\n", seq, ic, prefix, prefix, prefix, prefix];
// Reshape weight: [1,ic,1,oc] → [1,1,ic,oc]
[m appendFormat:@" tensor<int32, [4]> %s_rw = const()[name=string(\"%s_rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", prefix, prefix, ic, oc];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_W = reshape(shape=%s_rw,x=%s_wt)[name=string(\"%s_W\")];\n", ic, oc, prefix, prefix, prefix, prefix];
// matmul: [1,1,seq,ic] @ [1,1,ic,oc] → [1,1,seq,oc]
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_yh = matmul(transpose_x=bF,transpose_y=bF,x=%s_a3,y=%s_W)[name=string(\"%s_yh\")];\n", seq, oc, prefix, prefix, prefix, prefix];
// Transpose back + reshape: [1,1,seq,oc] → [1,1,oc,seq] → [1,oc,1,seq]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_yt = transpose(perm=%s_pm,x=%s_yh)[name=string(\"%s_yt\")];\n", oc, seq, prefix, prefix, prefix, prefix];
[m appendFormat:@" tensor<int32, [4]> %s_ro = const()[name=string(\"%s_ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, oc, seq];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_y = reshape(shape=%s_ro,x=%s_yt)[name=string(\"%s_y\")];\n", oc, seq, prefix, prefix, prefix, prefix];
}
// ===== Dynamic matmul kernel: y = x @ W =====
// Produces the MIL text of a standalone kernel whose single input packs activations and weights.
// Input:  [1, ic, 1, seq+oc] fp32 — activations in sp[0:seq], W in sp[seq:seq+oc]
// Output: [1, oc, 1, seq] fp32
// The input is cast to fp16, gen_dyn_matmul emits the matmul under the "mm" prefix,
// and its result mm_y is cast back to fp32.
static NSString *gen_dyn_matmul_mil(int ic, int oc, int seq) {
    const int spatial = seq + oc;
    NSMutableString *mil = [NSMutableString stringWithString:MIL_HDR];
    // signature + fp32 → fp16 cast of the packed input
    [mil appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", ic, spatial];
    [mil appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", ic, spatial];
    // matmul body: activations at spatial offset 0, weights right after them
    gen_dyn_matmul(mil, "mm", ic, oc, seq, 0, seq, "xh");
    // fp16 → fp32 cast of the result and function epilogue
    [mil appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
    [mil appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=mm_y)[name=string(\"cout\")];\n", oc, seq];
    [mil appendString:@" } -> (y);\n}\n"];
    return mil;
}
// ===== SDPA forward (dynamic weights) =====
// Replaces gen_sdpa_fwd_taps: RMSNorm done on CPU, this kernel does QKV matmul + SDPA + Wo matmul
// Input: [1, DIM, 1, SEQ + 4*DIM] fp32
// sp[0:SEQ] = xnorm (rmsnorm output, DIM channels)
// sp[SEQ:SEQ+DIM] = Wq[DIM,DIM]
// sp[SEQ+DIM:SEQ+2D] = Wk[DIM,DIM]
// sp[SEQ+2D:SEQ+3D] = Wv[DIM,DIM]
// sp[SEQ+3D:SEQ+4D] = Wo[DIM,DIM]
// Output: [1, 6*DIM, 1, SEQ] fp16 = concat(o_out, Q, K, V, attn_out, xnorm_pass)
// NOTE: mask is still a const weight (it doesn't change)
static NSString *gen_sdpa_fwd_dynamic(void) {
    // Builds the MIL text of the attention forward kernel. The four projection
    // weights are sliced out of the input tensor's last axis rather than being
    // baked in as constants; only the causal mask is still read from a BLOBFILE.
    // Returns the complete program text (header + main function).
    const float attn_scale = 1.0f / sqrtf((float)HD);
    const int in_width = SEQ + 4*DIM;      // xnorm columns, then Wq, Wk, Wv, Wo
    const int wq_at = SEQ;                 // start column of each weight in the last axis
    const int wk_at = SEQ + DIM;
    const int wv_at = SEQ + 2*DIM;
    const int wo_at = SEQ + 3*DIM;
    const int cat_ch = 6*DIM;              // six DIM-channel tensors are concatenated on output

    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, in_width];

    // fp32 -> fp16 on entry
    [mil appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, in_width];

    // xnorm: first SEQ columns, shape [1,DIM,1,SEQ]
    [mil appendString:@" tensor<int32, [4]> bx = const()[name=string(\"bx\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<int32, [4]> sx = const()[name=string(\"sx\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ];

    // Wq / Wk / Wv / Wo: four consecutive [1,DIM,1,DIM] slices sharing the size const "sw"
    [mil appendFormat:@" tensor<int32, [4]> bq = const()[name=string(\"bq\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", wq_at];
    [mil appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wq = slice_by_size(x=xh,begin=bq,size=sw)[name=string(\"Wq\")];\n", DIM, DIM];
    [mil appendFormat:@" tensor<int32, [4]> bk = const()[name=string(\"bk\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", wk_at];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wk = slice_by_size(x=xh,begin=bk,size=sw)[name=string(\"Wk\")];\n", DIM, DIM];
    [mil appendFormat:@" tensor<int32, [4]> bv = const()[name=string(\"bv\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", wv_at];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wv = slice_by_size(x=xh,begin=bv,size=sw)[name=string(\"Wv\")];\n", DIM, DIM];
    [mil appendFormat:@" tensor<int32, [4]> bo = const()[name=string(\"bo\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", wo_at];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wo = slice_by_size(x=xh,begin=bo,size=sw)[name=string(\"Wo\")];\n", DIM, DIM];

    // Activation to matmul layout: [1,D,1,S] -> [1,1,D,S] -> [1,1,S,D]
    [mil appendFormat:@" tensor<int32, [4]> r2 = const()[name=string(\"r2\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=r2,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];

    // Weights to matmul layout: [1,D,1,D] -> [1,1,D,D]
    [mil appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wq2 = reshape(shape=rw,x=Wq)[name=string(\"Wq2\")];\n", DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wk2 = reshape(shape=rw,x=Wk)[name=string(\"Wk2\")];\n", DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wv2 = reshape(shape=rw,x=Wv)[name=string(\"Wv2\")];\n", DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wo2 = reshape(shape=rw,x=Wo)[name=string(\"Wo2\")];\n", DIM, DIM];

    // Q/K/V projections: [1,1,S,D] @ [1,1,D,D] -> [1,1,S,D]
    [mil appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [mil appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> qm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wq2)[name=string(\"qm\")];\n", SEQ, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> km = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wk2)[name=string(\"km\")];\n", SEQ, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> vm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wv2)[name=string(\"vm\")];\n", SEQ, DIM];

    // Back to channel-major: [1,1,S,D] -> [1,1,D,S] -> [1,D,1,S]
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> qt = transpose(perm=pm,x=qm)[name=string(\"qt\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> kt = transpose(perm=pm,x=km)[name=string(\"kt\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> vt = transpose(perm=pm,x=vm)[name=string(\"vt\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> os = const()[name=string(\"os\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = reshape(shape=os,x=qt)[name=string(\"qf\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = reshape(shape=os,x=kt)[name=string(\"kf\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = reshape(shape=os,x=vt)[name=string(\"vf\")];\n", DIM, SEQ];

    // Split into heads: [1,D,1,S] -> [1,HEADS,HD,S] -> [1,HEADS,S,HD]
    [mil appendFormat:@" tensor<int32, [4]> qsh = const()[name=string(\"qsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q4 = reshape(shape=qsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=q4)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k4 = reshape(shape=qsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS, SEQ, HD];

    // Scores = (Q @ K^T) * scale
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", attn_scale];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ];

    // Additive causal mask; it never changes, so it stays a const blob
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ, SEQ, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS, SEQ, SEQ];

    // Softmax over the last axis, then weights @ V
    [mil appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> a4 = matmul(transpose_x=bF,transpose_y=bF,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS, SEQ, HD];

    // Merge heads back to [1,DIM,1,SEQ]
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM, SEQ];

    // Output projection: af -> [1,1,S,D] @ Wo [1,1,D,D] -> [1,1,S,D] -> [1,D,1,S]
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> af2 = reshape(shape=r2,x=af)[name=string(\"af2\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> aft = transpose(perm=pm,x=af2)[name=string(\"aft\")];\n", SEQ, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> om = matmul(transpose_x=bF,transpose_y=bF,x=aft,y=Wo2)[name=string(\"om\")];\n", SEQ, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> ot = transpose(perm=pm,x=om)[name=string(\"ot\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> oo = reshape(shape=os,x=ot)[name=string(\"oo\")];\n", DIM, SEQ];

    // Channel-axis concat of (oo, qf, kf, vf, af, xn), then fp16 -> fp32 on exit
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", cat_ch, SEQ];
    [mil appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
    [mil appendFormat:@" tensor<fp32, [1,%d,1,%d]> out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", cat_ch, SEQ];
    [mil appendString:@" } -> (out32);\n}\n"];
    return mil;
}
// ===== FFN forward (dynamic weights) =====
// RMSNorm is done on the CPU. The full FFN is: xnorm @ W1 → SiLU, xnorm @ W3 → gate, gate*silu @ W2 → out.
// Rejected design — the whole FFN in a single kernel:
// Input: [1, DIM, 1, SEQ + HIDDEN + HIDDEN + DIM] fp32
// sp[0:SEQ] = xnorm [DIM,SEQ]
// sp[SEQ:SEQ+HIDDEN] = W1[DIM,HIDDEN]
// sp[SEQ+HIDDEN:SEQ+2*HIDDEN] = W3[DIM,HIDDEN]
// sp[SEQ+2*HIDDEN:SEQ+2*HIDDEN+DIM]= W2 (HIDDEN→DIM)
// PROBLEM: W2 is [DIM,HIDDEN] = HIDDEN input channels, but this kernel's input has only DIM channels,
// so W2 cannot be packed into the same tensor as W1/W3.
// Alternatives considered:
// - 3 separate matmul kernels per FFN direction: simplest, but too many dispatches.
// - One kernel for W1+W3 (same input dim, DIM→HIDDEN), SiLU on CPU or ANE, one kernel for W2 (HIDDEN→DIM).
// Chosen: the two-kernel split. The W1+W3 kernel below also computes SiLU and the gate itself.
// FFN part 1: xnorm @ W1, xnorm @ W3 (both DIM→HIDDEN)
// Input: [1, DIM, 1, SEQ + 2*HIDDEN] fp32
// sp[0:SEQ] = xnorm
// sp[SEQ:SEQ+HIDDEN] = W1[DIM,HIDDEN]
// sp[SEQ+HIDDEN:SEQ+2*HIDDEN]= W3[DIM,HIDDEN]
// Output: [1, 3*HIDDEN, 1, SEQ] fp32 = concat(h1, h3, gate), where gate = silu(h1) * h3
static NSString *gen_ffn_w13_dynamic(void) {
    // Builds the MIL text for the first FFN kernel: two DIM->HIDDEN projections
    // (W1 and W3, both sliced from the input's last axis), followed by
    // silu(h1) * h3. The fp32 output stacks h1, h3 and the gated product on the
    // channel axis, so it has 3*HIDDEN channels.
    const int in_width = SEQ + 2*HIDDEN;   // xnorm columns, then W1, then W3
    const int cat_ch = 3*HIDDEN;           // h1 + h3 + gate

    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, in_width];
    [mil appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, in_width];

    // xnorm: first SEQ columns
    [mil appendString:@" tensor<int32, [4]> bx = const()[name=string(\"bx\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<int32, [4]> sx = const()[name=string(\"sx\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ];

    // W1 at column SEQ, W3 right after it; both use the size const "s1"
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
    [mil appendFormat:@" tensor<int32, [4]> s1 = const()[name=string(\"s1\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, HIDDEN];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1 = slice_by_size(x=xh,begin=b1,size=s1)[name=string(\"W1\")];\n", DIM, HIDDEN];
    [mil appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ + HIDDEN];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3 = slice_by_size(x=xh,begin=b3,size=s1)[name=string(\"W3\")];\n", DIM, HIDDEN];

    // Matmul layout: activation [1,1,S,D], weights [1,1,D,H]
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
    [mil appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=rd,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
    [mil appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, HIDDEN];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> W12 = reshape(shape=rw,x=W1)[name=string(\"W12\")];\n", DIM, HIDDEN];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> W32 = reshape(shape=rw,x=W3)[name=string(\"W32\")];\n", DIM, HIDDEN];
    [mil appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W12)[name=string(\"h1m\")];\n", SEQ, HIDDEN];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W32)[name=string(\"h3m\")];\n", SEQ, HIDDEN];

    // Back to channel-major [1,HIDDEN,1,SEQ]
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1t = transpose(perm=pm,x=h1m)[name=string(\"h1t\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3t = transpose(perm=pm,x=h3m)[name=string(\"h3t\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> rh = const()[name=string(\"rh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> h1 = reshape(shape=rh,x=h1t)[name=string(\"h1\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> h3 = reshape(shape=rh,x=h3t)[name=string(\"h3\")];\n", HIDDEN, SEQ];

    // silu(h1) = h1 * sigmoid(h1); gate = silu(h1) * h3
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN, SEQ];

    // Channel-axis concat of (h1, h3, gate), then fp16 -> fp32 on exit
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(h1,h3,gate))[name=string(\"cat\")];\n", cat_ch, SEQ];
    [mil appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
    [mil appendFormat:@" tensor<fp32, [1,%d,1,%d]> out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", cat_ch, SEQ];
    [mil appendString:@" } -> (out32);\n}\n"];
    return mil;
}
// FFN part 2: gate @ W2 (HIDDEN→DIM)
// Input: [1, HIDDEN, 1, SEQ + DIM] fp32
// sp[0:SEQ] = gate [HIDDEN,SEQ]
// sp[SEQ:SEQ+DIM] = W2[HIDDEN,DIM]
// Output: [1, DIM, 1, SEQ] fp32
static NSString *gen_ffn_w2_dynamic(void) {
    // Builds the MIL text for the second FFN kernel: a single HIDDEN->DIM
    // matmul. The activation occupies the first SEQ columns of the input's
    // last axis and W2 the following DIM columns. Output is [1,DIM,1,SEQ] fp32.
    const int in_width = SEQ + DIM;

    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", HIDDEN, in_width];
    [mil appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, in_width];

    // Activation slice [1,HIDDEN,1,SEQ]
    [mil appendString:@" tensor<int32, [4]> ba = const()[name=string(\"ba\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<int32, [4]> sa = const()[name=string(\"sa\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> act = slice_by_size(x=xh,begin=ba,size=sa)[name=string(\"act\")];\n", HIDDEN, SEQ];

    // Weight slice [1,HIDDEN,1,DIM], starting at column SEQ
    [mil appendFormat:@" tensor<int32, [4]> bw = const()[name=string(\"bw\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
    [mil appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> W2 = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"W2\")];\n", HIDDEN, DIM];

    // [S,H] @ [H,D] -> [S,D]
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
    [mil appendFormat:@" tensor<int32, [4]> ra = const()[name=string(\"ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> a2 = reshape(shape=ra,x=act)[name=string(\"a2\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> at = transpose(perm=pm,x=a2)[name=string(\"at\")];\n", SEQ, HIDDEN];
    [mil appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> W22 = reshape(shape=rw,x=W2)[name=string(\"W22\")];\n", HIDDEN, DIM];
    [mil appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> ym = matmul(transpose_x=bF,transpose_y=bF,x=at,y=W22)[name=string(\"ym\")];\n", SEQ, DIM];

    // Back to [1,DIM,1,SEQ], then fp16 -> fp32 on exit
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> yt = transpose(perm=pm,x=ym)[name=string(\"yt\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> yr = reshape(shape=ro,x=yt)[name=string(\"yr\")];\n", DIM, SEQ];
    [mil appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
    [mil appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=yr)[name=string(\"cout\")];\n", DIM, SEQ];
    [mil appendString:@" } -> (y);\n}\n"];
    return mil;
}
// ===== FFN backward (dynamic weights) =====
// A single fused backward kernel would need input [1, DIM+2*HIDDEN, 1, SEQ + HIDDEN + DIM + DIM] fp32;
// it is simpler to split the backward pass into separate kernels, like the forward pass.
// FFN backward part 1: dffn @ W2^T → dsilu_raw (HIDDEN). The SiLU derivative is applied afterwards, not in this kernel.
// Input: [1, DIM, 1, SEQ + HIDDEN] fp32
// sp[0:SEQ] = dffn [DIM, SEQ]
// sp[SEQ:SEQ+HIDDEN]= W2^T [DIM, HIDDEN]
// Output: [1, HIDDEN, 1, SEQ] fp32 = dsilu_raw
static NSString *gen_ffn_bwd_w2t_dynamic(void) {
    // No dedicated MIL is needed here: this kernel is the shared dynamic-matmul
    // generator called with (DIM, HIDDEN, SEQ).
    NSString *mil = gen_dyn_matmul_mil(DIM, HIDDEN, SEQ);
    return mil;
}
// FFN backward part 2: dh1 @ W1^T + dh3 @ W3^T → dx
// We need h1,h3 for SiLU derivative, but those are on CPU.
// Actually the SiLU derivative + gating is element-wise, do on CPU.
// Then: dh1 @ W1^T and dh3 @ W3^T are two separate matmuls (HIDDEN→DIM).
// Combine into one kernel:
// Input: [1, HIDDEN, 1, SEQ + SEQ + DIM + DIM] fp32
// sp[0:SEQ] = dh1 [HIDDEN,SEQ]
// sp[SEQ:2*SEQ] = dh3 [HIDDEN,SEQ]
// sp[2*SEQ:2*SEQ+DIM] = W1^T [HIDDEN,DIM]
// sp[2*SEQ+DIM:2*SEQ+2D] = W3^T [HIDDEN,DIM]
// Output: [1, DIM, 1, SEQ] fp32 = dx1 + dx3
static NSString *gen_ffn_bwd_w13t_dynamic(void) {
    // Builds the MIL text for the fused FFN input-gradient kernel:
    //   dx = dh1 @ W1^T + dh3 @ W3^T
    // The input's last axis holds dh1, dh3, W1^T and W3^T back to back, all with
    // HIDDEN channels. Output is [1,DIM,1,SEQ] fp32.
    const int in_width = 2*SEQ + 2*DIM;
    const int w1t_at = 2*SEQ;              // first column of W1^T
    const int w3t_at = 2*SEQ + DIM;        // first column of W3^T

    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", HIDDEN, in_width];
    [mil appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, in_width];

    // Gradient slices dh1 and dh3, each [1,HIDDEN,1,SEQ]
    [mil appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<int32, [4]> sh = const()[name=string(\"sh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh1 = slice_by_size(x=xh,begin=b0,size=sh)[name=string(\"dh1\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh3 = slice_by_size(x=xh,begin=b1,size=sh)[name=string(\"dh3\")];\n", HIDDEN, SEQ];

    // Weight slices W1^T and W3^T, each [1,HIDDEN,1,DIM]
    [mil appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", w1t_at];
    [mil appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1t = slice_by_size(x=xh,begin=b2,size=sw)[name=string(\"W1t\")];\n", HIDDEN, DIM];
    [mil appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", w3t_at];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3t = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"W3t\")];\n", HIDDEN, DIM];
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];

    // Matmul layout: gradients [1,1,S,H], weights [1,1,H,D]
    [mil appendFormat:@" tensor<int32, [4]> ra = const()[name=string(\"ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh12 = reshape(shape=ra,x=dh1)[name=string(\"dh12\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh1t = transpose(perm=pm,x=dh12)[name=string(\"dh1t\")];\n", SEQ, HIDDEN];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh32 = reshape(shape=ra,x=dh3)[name=string(\"dh32\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh3t = transpose(perm=pm,x=dh32)[name=string(\"dh3t\")];\n", SEQ, HIDDEN];
    [mil appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> W1t2 = reshape(shape=rw,x=W1t)[name=string(\"W1t2\")];\n", HIDDEN, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> W3t2 = reshape(shape=rw,x=W3t)[name=string(\"W3t2\")];\n", HIDDEN, DIM];
    [mil appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> dx1m = matmul(transpose_x=bF,transpose_y=bF,x=dh1t,y=W1t2)[name=string(\"dx1m\")];\n", SEQ, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> dx3m = matmul(transpose_x=bF,transpose_y=bF,x=dh3t,y=W3t2)[name=string(\"dx3m\")];\n", SEQ, DIM];

    // Sum both contributions, restore [1,DIM,1,SEQ], cast fp16 -> fp32
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxm = add(x=dx1m,y=dx3m)[name=string(\"dxm\")];\n", SEQ, DIM];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
    [mil appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
    [mil appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ];
    [mil appendString:@" } -> (y);\n}\n"];
    return mil;
}
// ===== SDPA backward part 1 (dynamic Wo^T) =====
// Goal: same as the original gen_sdpa_bwd1, but with Wo^T coming from the input instead of a const.
// Rejected layout: input [1, 4*DIM, 1, SEQ + DIM] fp32 — Q,K,V,dx2 in channels, Wo^T in spatial.
// Channels must match for all data: Q,K,V are [DIM,SEQ] and dx2 is [DIM,SEQ],
// so total input channels = 4*DIM. But Wo^T is [DIM,DIM] = DIM channels of DIM spatial.
// Problem: can't mix 4*DIM channels for data with DIM channels for Wo^T.
// Solution: Wo^T matmul as a separate kernel, then a weight-free SDPA backward kernel on ANE.
// Wo^T matmul: dx2 @ Wo^T → da (DIM→DIM)
static NSString *gen_wot_dynamic(void) {
    // Square DIM->DIM case of the shared dynamic-matmul generator.
    NSString *mil = gen_dyn_matmul_mil(DIM, DIM, SEQ);
    return mil;
}
// SDPA backward part 1 (no weights, all data): Q,K,V,da → dV,probs,dp
// Same as original but without Wo^T conv (already done)
// Input: [1, 4*DIM, 1, SEQ] fp16
static NSString *gen_sdpa_bwd1_noweight(void) {
    // Builds the MIL text for the first, weight-free half of attention backward.
    // The fp16 input stacks Q, K, V and da on the channel axis (4*DIM channels).
    // The kernel recomputes the softmax probabilities, then emits
    //   dV = probs^T @ da   and   dp = da @ V^T.
    // The fp16 output is concat(dV, probs, dp) on the channel axis; there is no
    // fp32 cast on either side.
    const float attn_scale = 1.0f / sqrtf((float)HD);
    const int cat_ch = DIM + 2*SCORE_CH;

    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", 4*DIM, SEQ];

    // Four DIM-channel slices along the channel axis: Q, K, V, da
    [mil appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 3*DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> da = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM, SEQ];

    // Split into heads: [1,D,1,S] -> [1,HEADS,HD,S] -> [1,HEADS,S,HD]
    [mil appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> vr = reshape(shape=rsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=vr)[name=string(\"tv\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dr = reshape(shape=rsh,x=da)[name=string(\"rd\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dat = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS, SEQ, HD];

    // Recompute the forward probabilities: softmax((Q @ K^T) * scale + mask)
    [mil appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [mil appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", attn_scale];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ, SEQ, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS, SEQ, SEQ];
    [mil appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS, SEQ, SEQ];

    // Gradients: dV = probs^T @ da, dp = da @ V^T
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=dat)[name=string(\"dv\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp4 = matmul(transpose_x=bF,transpose_y=bT,x=dat,y=v)[name=string(\"dp\")];\n", HEADS, SEQ, SEQ];

    // dV back to [1,DIM,1,SEQ]
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> dvs = const()[name=string(\"dvs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM, SEQ];

    // probs and dp flattened to [1,SCORE_CH,1,SEQ] so they can join the channel concat
    [mil appendFormat:@" tensor<int32, [4]> scs = const()[name=string(\"scs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH, SEQ];
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", cat_ch, SEQ];
    [mil appendString:@" } -> (out);\n}\n"];
    return mil;
}
// SDPA backward part 2: same as original (no weights, pure computation)
static NSString *gen_sdpa_bwd2(void) {
    // Builds the MIL text for the second half of attention backward (no weights).
    // The fp16 input stacks probs, dp, Q and K on the channel axis. The kernel
    // applies the softmax backward rule
    //   ds = probs * (dp - sum(probs * dp, last axis)) * scale
    // and then computes dQ = ds @ K and dK = ds^T @ Q.
    // The fp16 output is concat(dQ, dK) with 2*DIM channels.
    const float attn_scale = 1.0f / sqrtf((float)HD);
    const int in_ch = 2*SCORE_CH + 2*DIM;
    const int q_at = 2*SCORE_CH;           // channel offset of Q in the input
    const int k_at = 2*SCORE_CH + DIM;     // channel offset of K in the input

    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", in_ch, SEQ];

    // probs and dp: two SCORE_CH-channel slices
    [mil appendFormat:@" tensor<int32, [4]> sz_sc = const()[name=string(\"szsc\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
    [mil appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", SCORE_CH];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH, SEQ];

    // Q and K: two DIM-channel slices
    [mil appendFormat:@" tensor<int32, [4]> sz_d = const()[name=string(\"szd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", q_at];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", k_at];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM, SEQ];

    // Un-flatten probs / dp to [1,HEADS,SEQ,SEQ]
    [mil appendFormat:@" tensor<int32, [4]> ssh = const()[name=string(\"ssh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS, SEQ, SEQ];

    // Q / K split into heads: [1,D,1,S] -> [1,HEADS,HD,S] -> [1,HEADS,S,HD]
    [mil appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS, SEQ, HD];

    // Softmax backward: ds = probs * (dp - rowsum(probs * dp)) * scale
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS, SEQ, SEQ];
    [mil appendString:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([-1])];\n"];
    [mil appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,1]> spdp = reduce_sum(x=pdp,axes=rax,keep_dims=kd)[name=string(\"rs\")];\n", HEADS, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dps = sub(x=dp,y=spdp)[name=string(\"dps\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", attn_scale];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS, SEQ, SEQ];

    // dQ = ds @ K, dK = ds^T @ Q
    [mil appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [mil appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS, SEQ, HD];

    // Merge heads back to [1,DIM,1,SEQ] and concat (dQ, dK) on the channel axis
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> fs = const()[name=string(\"fs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM, SEQ];
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM, SEQ];
    [mil appendString:@" } -> (out);\n}\n"];
    return mil;
}
// QKV backward (dynamic): dq @ Wq^T + dk @ Wk^T + dv @ Wv^T -> dx
// Builds and returns the MIL program text for this kernel. Gradients and weights
// share a single input tensor and are separated along the last (spatial) axis.
// Input: [1, DIM, 1, 3*SEQ + 3*DIM] fp32
// sp[0:SEQ] = dq [DIM,SEQ]
// sp[SEQ:2*SEQ] = dk [DIM,SEQ]
// sp[2*SEQ:3*SEQ] = dv [DIM,SEQ]
// sp[3*SEQ:3*SEQ+DIM] = Wq^T [DIM,DIM]
// sp[3*SEQ+DIM:3*SEQ+2*DIM] = Wk^T [DIM,DIM]
// sp[3*SEQ+2*DIM:3*SEQ+3*DIM] = Wv^T [DIM,DIM]
// Output: [1, DIM, 1, SEQ] fp32 = dxq + dxk + dxv
// The program casts the input to fp16 on entry, does all slicing/matmul/add in
// fp16, and casts the result back to fp32 before returning it.
static NSString *gen_qkvb_dynamic(void) {
// Total width of the packed last axis: three gradient blocks plus three weight blocks.
int sp_in = 3*SEQ + 3*DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
// fp32 -> fp16 cast of the whole packed input.
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in];
// Slice dq, dk, dv: three [1,DIM,1,SEQ] windows at last-axis offsets 0, SEQ, 2*SEQ.
[m appendFormat:@" tensor<int32, [4]> sd = const()[name=string(\"sd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dq = slice_by_size(x=xh,begin=b0,size=sd)[name=string(\"dq\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dk = slice_by_size(x=xh,begin=b1,size=sd)[name=string(\"dk\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dv = slice_by_size(x=xh,begin=b2,size=sd)[name=string(\"dv\")];\n", DIM, SEQ];
// Slice Wq^T, Wk^T, Wv^T: three [1,DIM,1,DIM] windows starting at last-axis offset 3*SEQ.
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wqt = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"Wqt\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b4 = const()[name=string(\"b4\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wkt = slice_by_size(x=xh,begin=b4,size=sw)[name=string(\"Wkt\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b5 = const()[name=string(\"b5\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wvt = slice_by_size(x=xh,begin=b5,size=sw)[name=string(\"Wvt\")];\n", DIM, DIM];
// pm swaps the last two axes; bF is the shared "no transpose" flag for every matmul below.
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
// Reshape targets: gradients become [1,1,DIM,SEQ], weight slices become [1,1,DIM,DIM].
[m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, DIM];
// dq @ Wq^T: reshape dq, transpose it to [1,1,SEQ,DIM], then multiply by the weight slice -> [1,1,SEQ,DIM]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dq2 = reshape(shape=rd,x=dq)[name=string(\"dq2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dqt = transpose(perm=pm,x=dq2)[name=string(\"dqt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wqt2 = reshape(shape=rw,x=Wqt)[name=string(\"Wqt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxq = matmul(transpose_x=bF,transpose_y=bF,x=dqt,y=Wqt2)[name=string(\"dxq\")];\n", SEQ, DIM];
// dk @ Wk^T: same reshape/transpose/matmul sequence as the dq branch
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dk2 = reshape(shape=rd,x=dk)[name=string(\"dk2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dkt = transpose(perm=pm,x=dk2)[name=string(\"dkt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wkt2 = reshape(shape=rw,x=Wkt)[name=string(\"Wkt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxk = matmul(transpose_x=bF,transpose_y=bF,x=dkt,y=Wkt2)[name=string(\"dxk\")];\n", SEQ, DIM];
// dv @ Wv^T: same reshape/transpose/matmul sequence as the dq branch
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dv2 = reshape(shape=rd,x=dv)[name=string(\"dv2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dvt = transpose(perm=pm,x=dv2)[name=string(\"dvt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wvt2 = reshape(shape=rw,x=Wvt)[name=string(\"Wvt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxv = matmul(transpose_x=bF,transpose_y=bF,x=dvt,y=Wvt2)[name=string(\"dxv\")];\n", SEQ, DIM];
// Sum: dxq + dxk + dxv, then transpose back and reshape to the [1,DIM,1,SEQ] output layout.
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxall = add(x=dxqk,y=dxv)[name=string(\"aall\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxall)[name=string(\"dxt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
// fp16 -> fp32 cast of the result, which is the program's single output.
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ];
[m appendString:@" } -> (y);\n}\n"];
return m;
}
// Cached causal-mask blob (used by sdpa_fwd and sdpa_bwd1); built on first request.
static NSData *g_mask_blob = nil;
// Returns the SEQ x SEQ fp16 causal mask wrapped by build_blob_fp16().
// Entry (row, col) is 0 when col <= row and -65504 otherwise.
static NSData *get_mask_blob(void) {
    if (g_mask_blob) return g_mask_blob;

    _Float16 *causal = (_Float16*)calloc(SEQ*SEQ, sizeof(_Float16));
    for (int row = 0; row < SEQ; row++) {
        _Float16 *line = causal + row*SEQ;
        for (int col = 0; col < SEQ; col++) {
            if (col > row) line[col] = (_Float16)(-65504.0f);
            else           line[col] = (_Float16)0.0f;
        }
    }
    g_mask_blob = build_blob_fp16(causal, SEQ*SEQ);
    free(causal);
    return g_mask_blob;
}

View File

@ -0,0 +1,876 @@
// train.m: dynamic-weight ANE training for Stories110M.
// Compile kernels ONCE at startup, update weights via IOSurface every step.
// No exec() restart is needed, which eliminates the 76% compile overhead.
#include "mil_dynamic.h"
#include "cpu_ops.h"
#define CKPT_PATH "ane_stories110M_dyn_ckpt.bin"
#define MODEL_PATH "../../../assets/models/stories110M.bin"
#define DATA_PATH "../tinystories_data00.bin"
// Set of dynamic ANE kernels needed to run one transformer layer forward and backward.
// main() compiles a single set and reuses it for every layer, rewriting the
// kernels' IOSurface inputs before each evaluation.
typedef struct {
Kern *sdpaFwd; // QKV matmul + SDPA + Wo matmul (dynamic weights via IOSurface)
Kern *ffnW13; // W1,W3 matmul (dynamic)
Kern *ffnW2; // W2 matmul (dynamic)
Kern *ffnBwdW2t; // dffn @ W2^T (dynamic)
Kern *ffnBwdW13t; // dh1@W1^T + dh3@W3^T (dynamic)
Kern *wotBwd; // dx2 @ Wo^T (dynamic)
Kern *sdpaBwd1; // Q,K,V,da -> dV,probs,dp (weight-free, has mask const)
Kern *sdpaBwd2; // probs,dp,Q,K -> dQ,dK (weight-free)
Kern *qkvBwd; // dq@Wq^T + dk@Wk^T + dv@Wv^T (dynamic)
} DynLayerKernels;
// ===== Weight loading from llama2.c format =====
// Reads exactly `count` fp32 values from `f` into `dst`.
// Returns false on a short read or I/O error.
static bool pretrained_read_f32(FILE *f, float *dst, size_t count) {
    return fread(dst, sizeof(float), count, f) == count;
}

// Loads pretrained weights from a llama2.c-format file.
//
// File layout as consumed here: a Llama2Config header, the token embedding
// table, then one tensor kind at a time across all layers in the order
// rms_att, Wq, Wk, Wv, Wo, rms_ffn, W1, W2, W3, followed by rms_final.
//
// lw        - array of NLAYERS layer weight sets to fill
// rms_final - destination for the final RMSNorm weights (DIM floats)
// embed     - destination for the embedding table (room for VOCAB*DIM floats)
// path      - file to read
//
// Returns true only if the header matches the compiled configuration and every
// tensor was read completely. On a false return the destination buffers may be
// partially overwritten, so the caller must (re)initialize them.
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
    FILE *f = fopen(path, "rb");
    if (!f) { printf("Cannot open %s\n", path); return false; }

    Llama2Config cfg;
    if (fread(&cfg, sizeof(cfg), 1, f) != 1) {
        printf(" ERROR: Cannot read model header from %s\n", path);
        fclose(f);
        return false;
    }
    printf(" Model: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
           cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);
    if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) {
        printf(" ERROR: Config mismatch!\n"); fclose(f); return false;
    }

    // The header stores vocab_size with a sign; only its magnitude is the row count.
    int V = abs(cfg.vocab_size);
    // embed only has room for VOCAB rows; a larger table would overflow it.
    if (V > VOCAB) {
        printf(" ERROR: Model vocab %d exceeds compiled vocab %d\n", V, VOCAB);
        fclose(f);
        return false;
    }

    // Stop at the first short read so a truncated file is never reported as loaded.
    bool ok = pretrained_read_f32(f, embed, (size_t)V * DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].rms_att, DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].Wq, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].Wk, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].Wv, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].Wo, WO_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].rms_ffn, DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].W1, W1_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].W2, W2_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].W3, W3_SZ);
    ok = ok && pretrained_read_f32(f, rms_final, DIM);
    fclose(f);

    if (!ok) {
        printf(" ERROR: %s is truncated or unreadable\n", path);
        return false;
    }
    printf(" Loaded pretrained weights\n");
    return true;
}
// Transpose W[rows,cols] -> W^T[cols,rows]: element (row, col) of the row-major
// source lands at dst[col*rows + row]. dst and src must not overlap.
static void transpose_weight(float *dst, const float *src, int rows, int cols) {
    for (int row = 0; row < rows; row++) {
        const float *in = src + row * cols;   // start of source row `row`
        for (int col = 0; col < cols; col++) {
            dst[col * rows + row] = in[col];
        }
    }
}
// ===== Compile all dynamic kernels (ONCE) =====
// Compiles the nine kernels of a DynLayerKernels set, in declaration order,
// announcing each one on stdout before it is compiled.
// The byte counts passed to compile_kern_mil_w are input size then output size.
// Returns false as soon as any kernel fails to compile (later fields of *dk are
// then left unassigned); returns true when all nine compiled.
static bool compile_dynamic_kernels(DynLayerKernels *dk) {
    // Only the two kernels that embed the causal mask need a weight dictionary.
    NSDictionary *maskWeights = @{@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}};
    NSDictionary *noWeights = @{};

    // SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp32 -> [1, 6*DIM, 1, SEQ] fp32
    printf(" Compiling sdpaFwd...\n");
    if (!(dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), maskWeights,
                                           DIM*(SEQ+4*DIM)*4, 6*DIM*SEQ*4)))
        return false;

    // FFN W1+W3: [1, DIM, 1, SEQ+2*HIDDEN] fp32 -> [1, 3*HIDDEN, 1, SEQ] fp32
    printf(" Compiling ffnW13...\n");
    if (!(dk->ffnW13 = compile_kern_mil_w(gen_ffn_w13_dynamic(), noWeights,
                                          DIM*(SEQ+2*HIDDEN)*4, 3*HIDDEN*SEQ*4)))
        return false;

    // FFN W2: [1, HIDDEN, 1, SEQ+DIM] fp32 -> [1, DIM, 1, SEQ] fp32
    printf(" Compiling ffnW2...\n");
    if (!(dk->ffnW2 = compile_kern_mil_w(gen_ffn_w2_dynamic(), noWeights,
                                         HIDDEN*(SEQ+DIM)*4, DIM*SEQ*4)))
        return false;

    // FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] fp32 -> [1, HIDDEN, 1, SEQ] fp32
    printf(" Compiling ffnBwdW2t...\n");
    if (!(dk->ffnBwdW2t = compile_kern_mil_w(gen_ffn_bwd_w2t_dynamic(), noWeights,
                                             DIM*(SEQ+HIDDEN)*4, HIDDEN*SEQ*4)))
        return false;

    // FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp32 -> [1, DIM, 1, SEQ] fp32
    printf(" Compiling ffnBwdW13t...\n");
    if (!(dk->ffnBwdW13t = compile_kern_mil_w(gen_ffn_bwd_w13t_dynamic(), noWeights,
                                              HIDDEN*(2*SEQ+2*DIM)*4, DIM*SEQ*4)))
        return false;

    // Wo^T backward: [1, DIM, 1, SEQ+DIM] fp32 -> [1, DIM, 1, SEQ] fp32
    printf(" Compiling wotBwd...\n");
    if (!(dk->wotBwd = compile_kern_mil_w(gen_wot_dynamic(), noWeights,
                                          DIM*(SEQ+DIM)*4, DIM*SEQ*4)))
        return false;

    // SDPA bwd1 (no dynamic weights, has mask): [1, 4*DIM, 1, SEQ] fp16 -> [1, DIM+2*SCORE_CH, 1, SEQ] fp16
    printf(" Compiling sdpaBwd1...\n");
    if (!(dk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1_noweight(), maskWeights,
                                            4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2)))
        return false;

    // SDPA bwd2 (no weights): [1, 2*SCORE_CH+2*DIM, 1, SEQ] fp16 -> [1, 2*DIM, 1, SEQ] fp16
    printf(" Compiling sdpaBwd2...\n");
    if (!(dk->sdpaBwd2 = compile_kern_mil_w(gen_sdpa_bwd2(), noWeights,
                                            (2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2)))
        return false;

    // QKV backward: [1, DIM, 1, 3*SEQ+3*DIM] fp32 -> [1, DIM, 1, SEQ] fp32
    printf(" Compiling qkvBwd...\n");
    dk->qkvBwd = compile_kern_mil_w(gen_qkvb_dynamic(), noWeights,
                                    DIM*(3*SEQ+3*DIM)*4, DIM*SEQ*4);
    return dk->qkvBwd != NULL;
}
// ===== Write dynamic weights into IOSurface =====
// sdpaFwd input: [1, DIM, 1, SEQ+4*DIM] fp32.
// Each of the DIM channel rows is packed as
//   [ xnorm row (SEQ) | Wq row (DIM) | Wk row (DIM) | Wv row (DIM) | Wo row (DIM) ]
// where "row" means row `ch` of the corresponding caller-supplied matrix.
// The surface is locked for the duration of the copy.
static void write_sdpa_fwd_input(DynLayerKernels *dk, const float *xnorm,
                                 const float *Wq, const float *Wk, const float *Wv, const float *Wo) {
    Kern *kern = dk->sdpaFwd;
    const float *weights[4] = { Wq, Wk, Wv, Wo };
    const int row_len = SEQ + 4*DIM;

    IOSurfaceLock(kern->ioIn, 0, NULL);
    float *base = (float*)IOSurfaceGetBaseAddress(kern->ioIn);
    for (int ch = 0; ch < DIM; ch++) {
        float *row = base + ch*row_len;
        memcpy(row, xnorm + ch*SEQ, SEQ*sizeof(float));
        for (int w = 0; w < 4; w++) {
            memcpy(row + SEQ + w*DIM, weights[w] + ch*DIM, DIM*sizeof(float));
        }
    }
    IOSurfaceUnlock(kern->ioIn, 0, NULL);
}
// ffnW13 input: [1, DIM, 1, SEQ+2*HIDDEN] fp32.
// Each of the DIM channel rows is packed as
//   [ xnorm row (SEQ) | W1 row (HIDDEN) | W3 row (HIDDEN) ]
// so W1 and W3 are read as DIM rows of HIDDEN floats each.
// The surface is locked for the duration of the copy.
static void write_ffn_w13_input(DynLayerKernels *dk, const float *xnorm,
                                const float *W1, const float *W3) {
    Kern *kern = dk->ffnW13;
    const int row_len = SEQ + 2*HIDDEN;

    IOSurfaceLock(kern->ioIn, 0, NULL);
    float *base = (float*)IOSurfaceGetBaseAddress(kern->ioIn);
    for (int ch = 0; ch < DIM; ch++) {
        float *row = base + ch*row_len;
        memcpy(row,                xnorm + ch*SEQ,  SEQ*sizeof(float));
        memcpy(row + SEQ,          W1 + ch*HIDDEN,  HIDDEN*sizeof(float));
        memcpy(row + SEQ + HIDDEN, W3 + ch*HIDDEN,  HIDDEN*sizeof(float));
    }
    IOSurfaceUnlock(kern->ioIn, 0, NULL);
}
// ffnW2 input: [1, HIDDEN, 1, SEQ+DIM] fp32.
// Each of the HIDDEN channel rows is packed as
//   [ gate row (SEQ) | W2 row (DIM) ]
// so W2 is read as HIDDEN rows of DIM floats each.
// The surface is locked for the duration of the copy.
static void write_ffn_w2_input(DynLayerKernels *dk, const float *gate, const float *W2) {
    Kern *kern = dk->ffnW2;
    const int row_len = SEQ + DIM;

    IOSurfaceLock(kern->ioIn, 0, NULL);
    float *base = (float*)IOSurfaceGetBaseAddress(kern->ioIn);
    for (int ch = 0; ch < HIDDEN; ch++) {
        float *row = base + ch*row_len;
        memcpy(row,       gate + ch*SEQ, SEQ*sizeof(float));
        memcpy(row + SEQ, W2 + ch*DIM,   DIM*sizeof(float));
    }
    IOSurfaceUnlock(kern->ioIn, 0, NULL);
}
// ===== Checkpoint =====
// Writes a version-3 checkpoint to `path`.
//
// File layout: a CkptHdr, then for every layer the weights
// (Wq, Wk, Wv, Wo, W1, W2, W3, rms_att, rms_ffn) followed by the Adam m/v pair
// for each of those tensors in the same order; then rms_final with its Adam
// state, and the embedding table with its Adam state. load_checkpoint() reads
// the same sequence, so the order here must not change.
//
// step/total_steps/lr/loss      - training progress stored in the header
// ct/cw/cs                      - stored as cum_train, cum_wall and cum_steps
// adam_t                        - Adam step counter
// lw/la                         - NLAYERS weight sets and their Adam states
// rms_final/arms_final          - final RMSNorm weights and Adam state
// embed/aembed                  - embedding table (VOCAB*DIM) and Adam state
//
// The function returns nothing, so failures are reported on stdout: if the
// file cannot be opened nothing is written, and a write/close error produces a
// warning because the file on disk may be incomplete.
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
                            double ct, double cw, int cs, int adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "wb");
    if (!f) {
        // Without this guard every fwrite below would be handed a NULL stream.
        printf("Cannot open %s for writing; checkpoint not saved\n", path);
        return;
    }

    CkptHdr h = {0};
    h.magic = 0x424C5A54; h.version = 3;
    h.step = step; h.total_steps = total_steps;
    h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
    h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
    h.lr = lr; h.loss = loss;
    h.cum_train = ct; h.cum_wall = cw; h.cum_steps = cs; h.adam_t = adam_t;
    fwrite(&h, sizeof(h), 1, f);

    for (int L = 0; L < NLAYERS; L++) {
        // Layer weights.
        fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f);
        fwrite(lw[L].Wv,4,WQ_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f);
        fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f);
        fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f);
        // Adam first/second moments, same tensor order as the weights.
        fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f);
        fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f);
        fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f);
        fwrite(la[L].Wo.m,4,WO_SZ,f); fwrite(la[L].Wo.v,4,WO_SZ,f);
        fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f);
        fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f);
        fwrite(la[L].W3.m,4,W3_SZ,f); fwrite(la[L].W3.v,4,W3_SZ,f);
        fwrite(la[L].rms_att.m,4,DIM,f); fwrite(la[L].rms_att.v,4,DIM,f);
        fwrite(la[L].rms_ffn.m,4,DIM,f); fwrite(la[L].rms_ffn.v,4,DIM,f);
    }
    fwrite(rms_final,4,DIM,f);
    fwrite(arms_final->m,4,DIM,f); fwrite(arms_final->v,4,DIM,f);
    fwrite(embed,4,VOCAB*DIM,f);
    fwrite(aembed->m,4,VOCAB*DIM,f); fwrite(aembed->v,4,VOCAB*DIM,f);

    // A failed fwrite sets the stream's error flag; fclose can also fail while
    // flushing buffered data, so both are checked.
    bool failed = ferror(f) != 0;
    if (fclose(f) != 0) failed = true;
    if (failed) {
        printf("WARNING: error while writing checkpoint %s; file may be incomplete\n", path);
    }
}
// Reads exactly `count` fp32 values from `f` into `dst`.
// Returns false on a short read or I/O error.
static bool ckpt_read_f32(FILE *f, float *dst, size_t count) {
    return fread(dst, sizeof(float), count, f) == count;
}

// Restores training state from a version-3 checkpoint written by save_checkpoint().
//
// step/total_steps/lr/loss, ct/cw/cs, adam_t - out-parameters filled from the
//     header; they are written only when the whole file was read successfully,
//     so a failed load leaves the caller's values untouched.
// lw/la, rms_final/arms_final, embed/aembed  - destination buffers, read in the
//     same order save_checkpoint() writes them.
//
// Returns false when the file cannot be opened, the magic/version do not match,
// the header describes a different model shape than the compiled one, or the
// file is too short. The length is checked before any tensor is read, so a
// truncated file does not touch the destination buffers; only an I/O error in
// the middle of reading can leave them partially overwritten.
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
                            double *ct, double *cw, int *cs, int *adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "rb");
    if (!f) return false;

    CkptHdr h;
    if (fread(&h, sizeof(h), 1, f) != 1) { fclose(f); return false; }
    if (h.magic != 0x424C5A54 || h.version != 3) { fclose(f); return false; }

    // The tensor sizes below come from compile-time constants, so a checkpoint
    // saved with other dimensions would be misparsed rather than rejected.
    if (h.n_layers != NLAYERS || h.dim != DIM || h.hidden_dim != HIDDEN || h.vocab_size != VOCAB) {
        printf("Checkpoint %s does not match the compiled model config; ignoring it\n", path);
        fclose(f);
        return false;
    }

    // Each tensor is stored three times over (value, Adam m, Adam v).
    size_t per_layer = 3*(size_t)WQ_SZ + (size_t)WO_SZ + (size_t)W1_SZ
                     + (size_t)W2_SZ + (size_t)W3_SZ + 2*(size_t)DIM;
    size_t payload_floats = 3 * ((size_t)NLAYERS*per_layer + (size_t)DIM + (size_t)VOCAB*DIM);

    // Verify the payload length up front so a truncated file is rejected
    // before any destination buffer is modified.
    long payload_start = ftell(f);
    if (payload_start < 0 || fseek(f, 0, SEEK_END) != 0) { fclose(f); return false; }
    long file_end = ftell(f);
    if (file_end < payload_start || fseek(f, payload_start, SEEK_SET) != 0) { fclose(f); return false; }
    if ((unsigned long long)(file_end - payload_start) <
        (unsigned long long)payload_floats * sizeof(float)) {
        printf("Checkpoint %s is truncated; ignoring it\n", path);
        fclose(f);
        return false;
    }

    bool ok = true;
    for (int L = 0; ok && L < NLAYERS; L++) {
        ok = ckpt_read_f32(f, lw[L].Wq, WQ_SZ) && ckpt_read_f32(f, lw[L].Wk, WQ_SZ)
          && ckpt_read_f32(f, lw[L].Wv, WQ_SZ) && ckpt_read_f32(f, lw[L].Wo, WO_SZ)
          && ckpt_read_f32(f, lw[L].W1, W1_SZ) && ckpt_read_f32(f, lw[L].W2, W2_SZ)
          && ckpt_read_f32(f, lw[L].W3, W3_SZ)
          && ckpt_read_f32(f, lw[L].rms_att, DIM) && ckpt_read_f32(f, lw[L].rms_ffn, DIM)
          && ckpt_read_f32(f, la[L].Wq.m, WQ_SZ) && ckpt_read_f32(f, la[L].Wq.v, WQ_SZ)
          && ckpt_read_f32(f, la[L].Wk.m, WQ_SZ) && ckpt_read_f32(f, la[L].Wk.v, WQ_SZ)
          && ckpt_read_f32(f, la[L].Wv.m, WQ_SZ) && ckpt_read_f32(f, la[L].Wv.v, WQ_SZ)
          && ckpt_read_f32(f, la[L].Wo.m, WO_SZ) && ckpt_read_f32(f, la[L].Wo.v, WO_SZ)
          && ckpt_read_f32(f, la[L].W1.m, W1_SZ) && ckpt_read_f32(f, la[L].W1.v, W1_SZ)
          && ckpt_read_f32(f, la[L].W2.m, W2_SZ) && ckpt_read_f32(f, la[L].W2.v, W2_SZ)
          && ckpt_read_f32(f, la[L].W3.m, W3_SZ) && ckpt_read_f32(f, la[L].W3.v, W3_SZ)
          && ckpt_read_f32(f, la[L].rms_att.m, DIM) && ckpt_read_f32(f, la[L].rms_att.v, DIM)
          && ckpt_read_f32(f, la[L].rms_ffn.m, DIM) && ckpt_read_f32(f, la[L].rms_ffn.v, DIM);
    }
    ok = ok && ckpt_read_f32(f, rms_final, DIM)
            && ckpt_read_f32(f, arms_final->m, DIM) && ckpt_read_f32(f, arms_final->v, DIM)
            && ckpt_read_f32(f, embed, (size_t)VOCAB*DIM)
            && ckpt_read_f32(f, aembed->m, (size_t)VOCAB*DIM)
            && ckpt_read_f32(f, aembed->v, (size_t)VOCAB*DIM);
    fclose(f);

    if (!ok) {
        printf("Checkpoint %s could not be read completely; ignoring it\n", path);
        return false;
    }

    // Publish the header values only now that the whole payload is in place.
    *step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
    *ct = h.cum_train; *cw = h.cum_wall; *cs = h.cum_steps; *adam_t = h.adam_t;
    return true;
}
int main(int argc, char *argv[]) {
@autoreleasepool {
setbuf(stdout, NULL);
ane_init();
mach_timebase_info(&g_tb);
int total_steps = 10000;
float max_lr = 3e-4f;
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
int adam_t = 0, start_step = 0;
int accum_steps = 10;
int warmup_steps = 100;
float grad_clip = 1.0f;
float min_lr_frac = 0.1f; // min_lr = max_lr * 0.1
bool do_resume = false, from_scratch = false;
for (int i=1; i<argc; i++) {
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
else if (strcmp(argv[i], "--scratch") == 0) from_scratch = true;
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) max_lr = atof(argv[++i]);
else if (strcmp(argv[i], "--accum") == 0 && i+1<argc) accum_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--warmup") == 0 && i+1<argc) warmup_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--clip") == 0 && i+1<argc) grad_clip = atof(argv[++i]);
}
float lr = max_lr;
// Allocate per-layer state
LayerWeights lw[NLAYERS]; LayerAdam la[NLAYERS];
LayerActs acts[NLAYERS]; LayerGrads grads[NLAYERS];
for (int L=0; L<NLAYERS; L++) {
lw[L] = layer_weights_alloc(); la[L] = layer_adam_alloc();
acts[L] = layer_acts_alloc(); grads[L] = layer_grads_alloc();
}
float *rms_final = (float*)malloc(DIM*4);
float *embed = (float*)malloc(VOCAB*DIM*4);
float *grms_final = (float*)calloc(DIM, 4);
float *gembed = (float*)calloc(VOCAB*DIM, 4);
AdamState arms_final = adam_alloc(DIM);
AdamState aembed = adam_alloc((size_t)VOCAB*DIM);
double cum_train=0, cum_wall=0; int cum_steps=0;
float resume_loss = 0;
bool resuming = false;
if (do_resume) {
resuming = load_checkpoint(CKPT_PATH, &start_step, &total_steps, &lr, &resume_loss,
&cum_train, &cum_wall, &cum_steps, &adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
}
if (!resuming) {
printf("=== ANE Dynamic Training: Stories110M (12 layers) ===\n");
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
// Param counts for dashboard
double xformer_m = (double)NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ + 2.0*DIM) / 1e6;
double embed_m = (double)VOCAB*DIM / 1e6;
printf("Params: %.1fM (transformer %.1fM + embed %.1fM)\n", xformer_m+embed_m, xformer_m, embed_m);
printf("Kernels: 9 compiled, 9 weight-bearing\n");
printf("Accum %d steps, LR=%g\n", accum_steps, max_lr);
// FLOPs estimate: 6*N*B*T for transformer (forward+backward 3x forward)
double fwd_flops = 2.0*NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ) * SEQ;
double total_flops = 3.0 * fwd_flops; // fwd + bwd 3x fwd
printf("FLOPs/step: fwd=%.1fM bwd_dx=%.1fM bwd_dW=%.1fM sdpa_bwd=0.0M total=%.1fM\n",
fwd_flops/1e6, fwd_flops/1e6, fwd_flops/1e6, total_flops/1e6);
printf("ANE FLOPs/step: %.1fM\n", total_flops/1e6);
if (from_scratch || !load_pretrained(lw, rms_final, embed, MODEL_PATH)) {
if (from_scratch) printf(" Training from scratch (random init)\n");
else printf(" Pretrained load failed, using random init\n");
srand48(42);
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
for (int L=0; L<NLAYERS; L++) {
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*(2*drand48()-1);
for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
for(int i=0;i<DIM;i++){lw[L].rms_att[i]=1.0f; lw[L].rms_ffn[i]=1.0f;}
}
for(int i=0;i<DIM;i++) rms_final[i]=1.0f;
float escale = 0.02f;
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) embed[i]=escale*(2*drand48()-1);
}
}
// Precompute transposed weights (for backward pass kernels)
// These get updated after each Adam step
float *Wqt_buf[NLAYERS], *Wkt_buf[NLAYERS], *Wvt_buf[NLAYERS], *Wot_buf[NLAYERS];
float *W1t_buf[NLAYERS], *W2t_buf[NLAYERS], *W3t_buf[NLAYERS];
for (int L=0; L<NLAYERS; L++) {
Wqt_buf[L]=(float*)malloc(WQ_SZ*4); Wkt_buf[L]=(float*)malloc(WQ_SZ*4);
Wvt_buf[L]=(float*)malloc(WQ_SZ*4); Wot_buf[L]=(float*)malloc(WO_SZ*4);
W1t_buf[L]=(float*)malloc(W1_SZ*4); W2t_buf[L]=(float*)malloc(W2_SZ*4);
W3t_buf[L]=(float*)malloc(W3_SZ*4);
transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM);
transpose_weight(Wkt_buf[L], lw[L].Wk, DIM, DIM);
transpose_weight(Wvt_buf[L], lw[L].Wv, DIM, DIM);
transpose_weight(Wot_buf[L], lw[L].Wo, DIM, DIM);
transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM);
transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN);
transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM);
}
// mmap token data
int data_fd = open(DATA_PATH, O_RDONLY);
if (data_fd < 0) { printf("Cannot open %s\n", DATA_PATH); return 1; }
struct stat st; fstat(data_fd, &st);
size_t data_len = st.st_size;
uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; }
size_t n_tokens = data_len / 2;
printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
// Vocab compaction: map 32K sparse vocab ~9K compact
VocabMap vm = vocab_map_build(token_data, n_tokens, VOCAB);
int CV = vm.compact_vocab;
printf("Vocab compaction: %d → %d active tokens (%.1fx reduction)\n", VOCAB, CV, (float)VOCAB/CV);
// Create compact embedding + adam state
float *cembed = vocab_compact_embed(embed, &vm, DIM);
float *gcembed = (float*)calloc((size_t)CV*DIM, 4);
AdamState acembed = adam_alloc((size_t)CV*DIM);
// ===== Compile all kernels ONCE =====
printf("Compiling %d dynamic kernels (one-time)...\n", 9);
uint64_t tc = mach_absolute_time();
DynLayerKernels dk;
if (!compile_dynamic_kernels(&dk)) {
printf("Compilation failed!\n"); return 1;
}
double compile_ms = tb_ms(mach_absolute_time() - tc);
printf("Compiled 9 kernels in %.0fms (shared across all %d layers)\n\n", compile_ms, NLAYERS);
// Gradient + work buffers
float *dy = (float*)malloc(SEQ*DIM*4);
float *dffn = (float*)malloc(SEQ*DIM*4);
float *dx_ffn = (float*)malloc(SEQ*DIM*4);
float *dx2 = (float*)malloc(SEQ*DIM*4);
float *dx_attn = (float*)malloc(SEQ*DIM*4);
float *dq = (float*)malloc(SEQ*DIM*4);
float *dk_buf = (float*)malloc(SEQ*DIM*4);
float *dv = (float*)malloc(SEQ*DIM*4);
float *x_cur = (float*)malloc(SEQ*DIM*4);
float *x_final = (float*)malloc(SEQ*DIM*4);
float *xnorm_buf = (float*)malloc(SEQ*DIM*4);
float *logits = (float*)malloc(SEQ*CV*4);
float *dlogits = (float*)malloc(SEQ*CV*4);
float *gate_buf = (float*)malloc(SEQ*HIDDEN*4);
float *dh1 = (float*)malloc(SEQ*HIDDEN*4);
float *dh3 = (float*)malloc(SEQ*HIDDEN*4);
float *dsilu = (float*)malloc(SEQ*HIDDEN*4);
float *silu_tmp = (float*)malloc(SEQ*HIDDEN*4);
float *silu_tmp2 = (float*)malloc(SEQ*HIDDEN*4);
dispatch_queue_t dw_q = dispatch_queue_create("dw_cblas", DISPATCH_QUEUE_SERIAL);
dispatch_group_t dw_grp = dispatch_group_create();
float last_loss = 999.0f;
double total_train_ms = 0;
int total_steps_done = 0;
uint64_t t_wall_start = mach_absolute_time();
srand48(42 + start_step);
for (int step = start_step; step < total_steps; step++) {
uint64_t t0, t1, t_step = mach_absolute_time();
// Sample data
size_t max_pos = n_tokens - SEQ - 1;
size_t pos = (size_t)(drand48() * max_pos);
uint16_t *input_tokens = token_data + pos;
uint16_t *target_tokens_raw = token_data + pos + 1;
// Map targets to compact vocab IDs
uint16_t ctargets[SEQ];
for (int t = 0; t < SEQ; t++) ctargets[t] = (uint16_t)vm.full_to_compact[target_tokens_raw[t]];
// Embedding lookup (uses full embed for now input tokens are full IDs)
embed_lookup(x_cur, embed, input_tokens, DIM, SEQ);
// Timing accumulators (reset each step)
double t_rms=0, t_ane_fwd=0, t_io_fwd=0, t_cblas_wait=0;
double t_ane_bwd=0, t_io_bwd=0, t_silu=0, t_rms_bwd=0, t_cls=0, t_dw_copy=0;
// ===== FORWARD (12 layers) =====
for (int L=0; L<NLAYERS; L++) {
LayerActs *ac = &acts[L];
memcpy(ac->layer_in, x_cur, SEQ*DIM*4);
// RMSNorm1 (CPU)
t0 = mach_absolute_time();
rmsnorm(xnorm_buf, x_cur, lw[L].rms_att, DIM, SEQ);
memcpy(ac->xnorm, xnorm_buf, SEQ*DIM*4);
t_rms += tb_ms(mach_absolute_time() - t0);
// Wait for any pending dW cblas
t0 = mach_absolute_time();
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
t_cblas_wait += tb_ms(mach_absolute_time() - t0);
// SDPA forward (ANE): xnorm + Wq,Wk,Wv,Wo o_out,Q,K,V,attn_out,xnorm
t0 = mach_absolute_time();
write_sdpa_fwd_input(&dk, xnorm_buf, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.sdpaFwd);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
// Read output: [1, 6*DIM, 1, SEQ] fp32
t0 = mach_absolute_time();
IOSurfaceLock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL);
float *fwd_out = (float*)IOSurfaceGetBaseAddress(dk.sdpaFwd->ioOut);
memcpy(ac->o_out, fwd_out + 0*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->Q, fwd_out + 1*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->K, fwd_out + 2*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->V, fwd_out + 3*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->attn_out, fwd_out + 4*DIM*SEQ, DIM*SEQ*4);
IOSurfaceUnlock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// Residual: x2 = x_cur + o_out
vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
// RMSNorm2 (CPU)
t0 = mach_absolute_time();
rmsnorm(xnorm_buf, ac->x2, lw[L].rms_ffn, DIM, SEQ);
memcpy(ac->x2norm, xnorm_buf, SEQ*DIM*4);
t_rms += tb_ms(mach_absolute_time() - t0);
// FFN W1+W3 (ANE): xnorm h1, h3, gate
t0 = mach_absolute_time();
write_ffn_w13_input(&dk, xnorm_buf, W1t_buf[L], W3t_buf[L]);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnW13);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
// Read h1, h3, gate from output [1, 3*HIDDEN, 1, SEQ]
t0 = mach_absolute_time();
IOSurfaceLock(dk.ffnW13->ioOut, kIOSurfaceLockReadOnly, NULL);
float *ffn13_out = (float*)IOSurfaceGetBaseAddress(dk.ffnW13->ioOut);
memcpy(ac->h1, ffn13_out, HIDDEN*SEQ*4);
memcpy(ac->h3, ffn13_out + HIDDEN*SEQ, HIDDEN*SEQ*4);
memcpy(gate_buf, ffn13_out + 2*HIDDEN*SEQ, HIDDEN*SEQ*4);
memcpy(ac->silu_out, gate_buf, HIDDEN*SEQ*4);
IOSurfaceUnlock(dk.ffnW13->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// FFN W2 (ANE): gate @ W2 ffn_out
t0 = mach_absolute_time();
write_ffn_w2_input(&dk, gate_buf, W2t_buf[L]);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnW2);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
IOSurfaceLock(dk.ffnW2->ioOut, kIOSurfaceLockReadOnly, NULL);
memcpy(ac->ffn_out, (float*)IOSurfaceGetBaseAddress(dk.ffnW2->ioOut), DIM*SEQ*4);
IOSurfaceUnlock(dk.ffnW2->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// Residual: x_cur = x2 + ffn_out
vDSP_vadd(ac->x2, 1, ac->ffn_out, 1, x_cur, 1, (vDSP_Length)(SEQ*DIM));
}
// Final RMSNorm + classifier + loss (CPU)
t0 = mach_absolute_time();
rmsnorm(x_final, x_cur, rms_final, DIM, SEQ);
t_rms += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
// Classifier: logits[CV, SEQ] = cembed[CV, DIM] @ x_final[DIM, SEQ]
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
CV, SEQ, DIM, 1.0f, cembed, DIM, x_final, SEQ, 0.0f, logits, SEQ);
float loss = cross_entropy_loss(dlogits, logits, ctargets, CV, SEQ);
t_cls += tb_ms(mach_absolute_time() - t0);
last_loss = loss;
// ===== BACKWARD =====
// Classifier backward: dy[DIM, SEQ] = cembed^T[DIM, CV] @ dlogits[CV, SEQ]
t0 = mach_absolute_time();
cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
DIM, SEQ, CV, 1.0f, cembed, DIM, dlogits, SEQ, 0.0f, dy, SEQ);
t_cls += tb_ms(mach_absolute_time() - t0);
// dEmbed async: gcembed[CV, DIM] += dlogits[CV, SEQ] @ x_final^T[SEQ, DIM]
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
CV, DIM, SEQ, 1.0f, dlogits, SEQ, x_final, SEQ, 1.0f, gcembed, DIM);
});
// Final RMSNorm backward
float *dx_rms_final = (float*)calloc(SEQ*DIM, 4);
rmsnorm_bwd(dx_rms_final, grms_final, dy, x_cur, rms_final, DIM, SEQ);
memcpy(dy, dx_rms_final, SEQ*DIM*4);
free(dx_rms_final);
// ===== BACKWARD (12 layers, reverse) =====
for (int L=NLAYERS-1; L>=0; L--) {
LayerActs *ac = &acts[L];
LayerGrads *gr = &grads[L];
memcpy(dffn, dy, SEQ*DIM*4);
// FFN backward: dffn @ W2^T dsilu_raw
t0 = mach_absolute_time();
io_write_dyn(dk.ffnBwdW2t->ioIn, dffn, DIM, SEQ, lw[L].W2, HIDDEN);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnBwdW2t);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.ffnBwdW2t->ioOut, dsilu, HIDDEN, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// SiLU derivative (vectorized): dsilu dh1, dh3
// silu(h1) = h1*sig(h1), dsilu_dh1 = sig*(1+h1*(1-sig))
// dh1 = dsilu * h3 * dsilu_dh1, dh3 = dsilu * silu(h1)
t0 = mach_absolute_time();
{
int n = HIDDEN*SEQ;
// sig = 1/(1+exp(-h1))
float minus1 = -1.0f, one = 1.0f;
vDSP_vsmul(ac->h1, 1, &minus1, silu_tmp, 1, (vDSP_Length)n);
vvexpf(silu_tmp, silu_tmp, &n);
vDSP_vsadd(silu_tmp, 1, &one, silu_tmp, 1, (vDSP_Length)n);
vvrecf(silu_tmp, silu_tmp, &n); // silu_tmp = sig
// dh3 = dsilu * h1 * sig (= dsilu * silu(h1))
vDSP_vmul(ac->h1, 1, silu_tmp, 1, dh3, 1, (vDSP_Length)n);
vDSP_vmul(dsilu, 1, dh3, 1, dh3, 1, (vDSP_Length)n);
// dsilu_dh1 = sig*(1+h1*(1-sig)), store in silu_tmp2
vDSP_vsadd(silu_tmp, 1, &minus1, silu_tmp2, 1, (vDSP_Length)n); // sig-1
vDSP_vneg(silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // 1-sig
vDSP_vmul(ac->h1, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // h1*(1-sig)
vDSP_vsadd(silu_tmp2, 1, &one, silu_tmp2, 1, (vDSP_Length)n); // 1+h1*(1-sig)
vDSP_vmul(silu_tmp, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // full dsilu_dh1
// dh1 = dsilu * h3 * dsilu_dh1
vDSP_vmul(dsilu, 1, ac->h3, 1, dh1, 1, (vDSP_Length)n);
vDSP_vmul(dh1, 1, silu_tmp2, 1, dh1, 1, (vDSP_Length)n);
}
t_silu += tb_ms(mach_absolute_time() - t0);
// dh1@W1^T + dh3@W3^T dx_ffn (ANE)
t0 = mach_absolute_time();
{
IOSurfaceLock(dk.ffnBwdW13t->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk.ffnBwdW13t->ioIn);
int sp = 2*SEQ + 2*DIM;
for (int d = 0; d < HIDDEN; d++) {
memcpy(buf + d*sp, dh1 + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, dh3 + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 2*SEQ, lw[L].W1 + d*DIM, DIM*4);
memcpy(buf + d*sp + 2*SEQ + DIM, lw[L].W3 + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk.ffnBwdW13t->ioIn, 0, NULL);
}
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnBwdW13t);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.ffnBwdW13t->ioOut, dx_ffn, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dW FFN async (cblas)
t0 = mach_absolute_time();
float *capt_dffn = (float*)malloc(SEQ*DIM*4); memcpy(capt_dffn, dffn, SEQ*DIM*4);
float *capt_silu = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_silu, ac->silu_out, SEQ*HIDDEN*4);
float *capt_dh1 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh1, dh1, SEQ*HIDDEN*4);
float *capt_dh3 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh3, dh3, SEQ*HIDDEN*4);
float *capt_x2n = (float*)malloc(SEQ*DIM*4); memcpy(capt_x2n, ac->x2norm, SEQ*DIM*4);
t_dw_copy += tb_ms(mach_absolute_time() - t0);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, HIDDEN, SEQ,
1.0f, capt_dffn, SEQ, capt_silu, SEQ, 1.0f, gr->W2, HIDDEN);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
1.0f, capt_dh1, SEQ, capt_x2n, SEQ, 1.0f, gr->W1, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
1.0f, capt_dh3, SEQ, capt_x2n, SEQ, 1.0f, gr->W3, DIM);
free(capt_dffn); free(capt_silu); free(capt_dh1); free(capt_dh3); free(capt_x2n);
});
// RMSNorm2 backward
t0 = mach_absolute_time();
memset(dx2, 0, SEQ*DIM*4);
rmsnorm_bwd(dx2, gr->rms_ffn, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ);
for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
t_rms_bwd += tb_ms(mach_absolute_time() - t0);
// Wo^T backward (ANE): dx2 @ Wo^T da
t0 = mach_absolute_time();
io_write_dyn(dk.wotBwd->ioIn, dx2, DIM, SEQ, lw[L].Wo, DIM);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.wotBwd);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
float *da_buf = (float*)malloc(SEQ*DIM*4);
io_read_dyn(dk.wotBwd->ioOut, da_buf, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dWo async
t0 = mach_absolute_time();
float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, dx2, SEQ*DIM*4);
float *capt_attn = (float*)malloc(SEQ*DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*DIM*4);
t_dw_copy += tb_ms(mach_absolute_time() - t0);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_do, SEQ, capt_attn, SEQ, 1.0f, gr->Wo, DIM);
free(capt_do); free(capt_attn);
});
// SDPA backward part 1 (ANE, fp16): Q,K,V,da -> dV,probs,dp
t0 = mach_absolute_time();
io_write_fp16_at(dk.sdpaBwd1->ioIn, 0, ac->Q, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, DIM, ac->K, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, 2*DIM, ac->V, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, 3*DIM, da_buf, DIM, SEQ);
free(da_buf);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.sdpaBwd1);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
// SDPA backward part 2: probs,dp,Q,K -> dQ,dK
t0 = mach_absolute_time();
io_copy(dk.sdpaBwd2->ioIn, 0, dk.sdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ);
io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH, ac->Q, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH+DIM, ac->K, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.sdpaBwd2);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_fp16(dk.sdpaBwd2->ioOut, dq, 0, DIM, SEQ);
io_read_fp16(dk.sdpaBwd2->ioOut, dk_buf, DIM, DIM, SEQ);
io_read_fp16(dk.sdpaBwd1->ioOut, dv, 0, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dWq/dWk/dWv async
t0 = mach_absolute_time();
float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4);
float *capt_dk = (float*)malloc(SEQ*DIM*4); memcpy(capt_dk, dk_buf, SEQ*DIM*4);
float *capt_dv = (float*)malloc(SEQ*DIM*4); memcpy(capt_dv, dv, SEQ*DIM*4);
float *capt_xn = (float*)malloc(SEQ*DIM*4); memcpy(capt_xn, ac->xnorm, SEQ*DIM*4);
t_dw_copy += tb_ms(mach_absolute_time() - t0);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dq, SEQ, capt_xn, SEQ, 1.0f, gr->Wq, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dk, SEQ, capt_xn, SEQ, 1.0f, gr->Wk, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dv, SEQ, capt_xn, SEQ, 1.0f, gr->Wv, DIM);
free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn);
});
// QKV backward (ANE): dq,dk,dv @ Wq^T,Wk^T,Wv^T -> dx_attn
t0 = mach_absolute_time();
{
IOSurfaceLock(dk.qkvBwd->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk.qkvBwd->ioIn);
int sp = 3*SEQ + 3*DIM;
for (int d = 0; d < DIM; d++) {
memcpy(buf + d*sp, dq + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, dk_buf + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 2*SEQ, dv + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 3*SEQ, lw[L].Wq + d*DIM, DIM*4);
memcpy(buf + d*sp + 3*SEQ+DIM, lw[L].Wk + d*DIM, DIM*4);
memcpy(buf + d*sp + 3*SEQ+2*DIM, lw[L].Wv + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk.qkvBwd->ioIn, 0, NULL);
}
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.qkvBwd);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.qkvBwd->ioOut, dx_attn, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// RMSNorm1 backward
t0 = mach_absolute_time();
float *dx_rms1 = (float*)calloc(SEQ*DIM, 4);
rmsnorm_bwd(dx_rms1, gr->rms_att, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ);
for(int i=0;i<SEQ*DIM;i++) dy[i] = dx_rms1[i] + dx2[i];
free(dx_rms1);
t_rms_bwd += tb_ms(mach_absolute_time() - t0);
}
// Embedding backward
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
embed_backward(gembed, dy, input_tokens, DIM, SEQ);
double step_ms = tb_ms(mach_absolute_time() - t_step);
total_train_ms += step_ms;
total_steps_done++;
if (step % 10 == 0 || step == start_step) {
printf(" timing: ane_fwd=%.1f io_fwd=%.1f rms=%.1f ane_bwd=%.1f io_bwd=%.1f silu=%.1f rms_bwd=%.1f cls=%.1f cblas_wait=%.1f dw_copy=%.1f\n",
t_ane_fwd, t_io_fwd, t_rms, t_ane_bwd, t_io_bwd, t_silu, t_rms_bwd, t_cls, t_cblas_wait, t_dw_copy);
float xmx, xmn;
vDSP_maxv(x_cur,1,&xmx,(vDSP_Length)(SEQ*DIM));
vDSP_minv(x_cur,1,&xmn,(vDSP_Length)(SEQ*DIM));
float dmx, dmn;
vDSP_maxv(dy,1,&dmx,(vDSP_Length)(SEQ*DIM));
vDSP_minv(dy,1,&dmn,(vDSP_Length)(SEQ*DIM));
printf("step %-4d loss=%.4f lr=%.2e %.1fms/step x[%.2f,%.2f] dy[%.3e,%.3e]\n",
step, loss, lr, step_ms, xmn, xmx, dmn, dmx);
}
// Adam update every accum_steps
if ((step+1) % accum_steps == 0 || step == total_steps-1) {
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
float gsc = 1.0f / accum_steps;
adam_t++;
// Scale gradients by 1/accum_steps
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
for(size_t i=0;i<WQ_SZ;i++){g->Wq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;}
for(size_t i=0;i<W1_SZ;i++) g->W1[i]*=gsc;
for(size_t i=0;i<W2_SZ;i++) g->W2[i]*=gsc;
for(size_t i=0;i<W3_SZ;i++) g->W3[i]*=gsc;
for(int i=0;i<DIM;i++){g->rms_att[i]*=gsc; g->rms_ffn[i]*=gsc;}
}
for(int i=0;i<DIM;i++) grms_final[i]*=gsc;
// Merge compact classifier grads into full embed grads
vocab_scatter_grads(gembed, gcembed, &vm, DIM);
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) gembed[i]*=gsc;
// Global gradient norm
float grad_norm_sq = 0;
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
float s;
vDSP_dotpr(g->Wq,1,g->Wq,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wo,1,g->Wo,1,&s,(vDSP_Length)WO_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->W1,1,g->W1,1,&s,(vDSP_Length)W1_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->W2,1,g->W2,1,&s,(vDSP_Length)W2_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->W3,1,g->W3,1,&s,(vDSP_Length)W3_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->rms_att,1,g->rms_att,1,&s,(vDSP_Length)DIM); grad_norm_sq+=s;
vDSP_dotpr(g->rms_ffn,1,g->rms_ffn,1,&s,(vDSP_Length)DIM); grad_norm_sq+=s;
}
{ float s;
vDSP_dotpr(grms_final,1,grms_final,1,&s,(vDSP_Length)DIM); grad_norm_sq+=s;
vDSP_dotpr(gembed,1,gembed,1,&s,(vDSP_Length)(VOCAB*DIM)); grad_norm_sq+=s;
}
float grad_norm = sqrtf(grad_norm_sq);
if ((step+1) % 10 == 0) printf(" grad_norm=%.4f\n", grad_norm);
// Gradient clipping
if (grad_clip > 0 && grad_norm > grad_clip) {
float clip_scale = grad_clip / grad_norm;
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
vDSP_vsmul(g->Wq,1,&clip_scale,g->Wq,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wk,1,&clip_scale,g->Wk,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wv,1,&clip_scale,g->Wv,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wo,1,&clip_scale,g->Wo,1,(vDSP_Length)WO_SZ);
vDSP_vsmul(g->W1,1,&clip_scale,g->W1,1,(vDSP_Length)W1_SZ);
vDSP_vsmul(g->W2,1,&clip_scale,g->W2,1,(vDSP_Length)W2_SZ);
vDSP_vsmul(g->W3,1,&clip_scale,g->W3,1,(vDSP_Length)W3_SZ);
vDSP_vsmul(g->rms_att,1,&clip_scale,g->rms_att,1,(vDSP_Length)DIM);
vDSP_vsmul(g->rms_ffn,1,&clip_scale,g->rms_ffn,1,(vDSP_Length)DIM);
}
vDSP_vsmul(grms_final,1,&clip_scale,grms_final,1,(vDSP_Length)DIM);
vDSP_vsmul(gembed,1,&clip_scale,gembed,1,(vDSP_Length)(VOCAB*DIM));
}
// Cosine LR schedule with warmup
if (step < warmup_steps) {
lr = max_lr * ((float)(step + 1)) / warmup_steps;
} else {
float decay_ratio = (float)(step - warmup_steps) / (float)(total_steps - warmup_steps);
float min_lr = max_lr * min_lr_frac;
lr = min_lr + 0.5f * (1.0f + cosf(M_PI * decay_ratio)) * (max_lr - min_lr);
}
// Adam update
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
adam_update(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps);
// Update transposed weight buffers
transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM);
transpose_weight(Wkt_buf[L], lw[L].Wk, DIM, DIM);
transpose_weight(Wvt_buf[L], lw[L].Wv, DIM, DIM);
transpose_weight(Wot_buf[L], lw[L].Wo, DIM, DIM);
transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM);
transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN);
transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM);
}
adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps);
// Re-extract compact embed from updated full embed
free(cembed);
cembed = vocab_compact_embed(embed, &vm, DIM);
// Zero grads
for (int L=0; L<NLAYERS; L++) layer_grads_zero(&grads[L]);
memset(grms_final, 0, DIM*4);
memset(gembed, 0, (size_t)VOCAB*DIM*4);
memset(gcembed, 0, (size_t)CV*DIM*4);
// Checkpoint
if ((step+1) % 100 == 0) {
double wall = tb_ms(mach_absolute_time() - t_wall_start);
save_checkpoint(CKPT_PATH, step+1, total_steps, lr, last_loss,
total_train_ms+cum_train, wall+cum_wall, total_steps_done+cum_steps, adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
}
}
}
// Report
double wall = tb_ms(mach_absolute_time() - t_wall_start);
printf("\n=== Efficiency Report ===\n");
printf("Total steps: %d\n", total_steps_done);
printf("Compile: %.0fms (one-time, %.1f%%)\n", compile_ms, 100*compile_ms/(wall+cum_wall));
printf("Train time: %.0fms (%.1fms/step)\n", total_train_ms, total_train_ms/total_steps_done);
printf("Wall time: %.1fs\n", (wall+cum_wall)/1000);
// Cleanup
for (int L=0; L<NLAYERS; L++) {
layer_weights_free(&lw[L]); layer_adam_free(&la[L]);
layer_acts_free(&acts[L]); layer_grads_free(&grads[L]);
free(Wqt_buf[L]); free(Wkt_buf[L]); free(Wvt_buf[L]); free(Wot_buf[L]);
free(W1t_buf[L]); free(W2t_buf[L]); free(W3t_buf[L]);
}
free_kern(dk.sdpaFwd); free_kern(dk.ffnW13); free_kern(dk.ffnW2);
free_kern(dk.ffnBwdW2t); free_kern(dk.ffnBwdW13t); free_kern(dk.wotBwd);
free_kern(dk.sdpaBwd1); free_kern(dk.sdpaBwd2); free_kern(dk.qkvBwd);
munmap(token_data, data_len); close(data_fd);
}
return 0;
}