From cb474e15378f75de37b7fc99e97af3c7c00a2afa Mon Sep 17 00:00:00 2001 From: maderix Date: Mon, 2 Mar 2026 23:49:55 -0800 Subject: [PATCH 1/3] =?UTF-8?q?Add=20dynamic=20weight=20training=20pipelin?= =?UTF-8?q?e=20=E2=80=94=20110ms/step=20without=20recompilation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Dynamic weight pipeline that eliminates the ~3.7s recompile-every-10-steps bottleneck. Weights are passed via IOSurface spatial dimension instead of baked as constants, so kernels compile once at startup (345ms) and run indefinitely without exec() restart. Key components: - training_dynamic/ — full pipeline (config, IO, MIL generators, train loop) - 9 dynamic kernels shared across all 12 layers - Vocab compaction 32K→9.2K for faster classifier - Vectorized cross-entropy with vDSP/NEON - Adam optimizer with gradient clipping + cosine LR schedule - Checkpoint save/resume - test_dynamic_matmul.m — validates dynamic weight matmul vs cblas - test_weight_patch.m — tests weight update via IOSurface - dashboard.py — updated with --dynamic flag for v2 pipeline support, improved step regex parsing, --scratch/--lr/--accum CLI args Performance: 110ms/step steady-state (no recompile overhead) ane_fwd=21 ane_bwd=28 io_fwd=12 io_bwd=15 silu=10 cls=13 rms=5 ms --- training/dashboard.py | 29 +- training/test_dynamic_matmul.m | 333 +++++++++ training/test_weight_patch.m | 450 ++++++++++++ training/training_dynamic/Makefile | 9 + training/training_dynamic/config.h | 156 +++++ training/training_dynamic/cpu_ops.h | 164 +++++ training/training_dynamic/io.h | 147 ++++ training/training_dynamic/mil_dynamic.h | 590 ++++++++++++++++ training/training_dynamic/train.m | 876 ++++++++++++++++++++++++ 9 files changed, 2749 insertions(+), 5 deletions(-) create mode 100644 training/test_dynamic_matmul.m create mode 100644 training/test_weight_patch.m create mode 100644 training/training_dynamic/Makefile create mode 100644 
training/training_dynamic/config.h create mode 100644 training/training_dynamic/cpu_ops.h create mode 100644 training/training_dynamic/io.h create mode 100644 training/training_dynamic/mil_dynamic.h create mode 100644 training/training_dynamic/train.m diff --git a/training/dashboard.py b/training/dashboard.py index a3a1503..06d46a2 100644 --- a/training/dashboard.py +++ b/training/dashboard.py @@ -279,7 +279,7 @@ RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+ RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)') RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing') RE_ACCUM = re.compile(r'Accum (\d+).*LR=([\d.e+-]+)') -RE_STEP = re.compile(r'step\s+(\d+)\s+loss=([\d.]+)') +RE_STEP = re.compile(r'step\s+(\d+)\s+loss=([\d.]+)(?:\s+lr=([\d.e+-]+))?(?:\s+([\d.]+)ms/step)?') RE_BATCH = re.compile(r'\[batch (\d+): compile=([\d.]+)ms train=([\d.]+)ms \(([\d.]+)ms/step\) compiles=(\d+)\]') RE_TIMING = re.compile(r'ane=([\d.]+) io=([\d.]+) cls=([\d.]+) elem=([\d.]+) rms=([\d.]+) cblas_wait=([\d.]+)') RE_RESTART = re.compile(r'\[exec\(\) restart step (\d+)') @@ -323,6 +323,10 @@ def parse_line(line): m = RE_STEP.search(line) if m: S.step, S.loss = int(m[1]), float(m[2]) + if m[3]: + S.training['lr'] = m[3] + if m[4]: + S.ms_per_step = float(m[4]) S.loss_history.append((S.step, S.loss)) S.best_loss = min(S.best_loss, S.loss) return @@ -659,10 +663,19 @@ def set_nonblock(fd): fl = fcntl.fcntl(fd, fcntl.F_GETFL) fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) -def spawn_training(resume=False, steps=10000): - cmd = 'make train_large 2>&1 && ./train_large' +def spawn_training(resume=False, steps=10000, dynamic=False, scratch=False, lr=None, accum=None): + if dynamic: + cmd = 'cd training_dynamic && make 2>&1 && ./train' + else: + cmd = 'make train_large 2>&1 && ./train_large' if resume: cmd += ' --resume' + if scratch and dynamic: + cmd += ' --scratch' + if lr is not None: + cmd += f' --lr {lr}' 
+ if accum is not None: + cmd += f' --accum {accum}' cmd += f' --steps {steps}' proc = subprocess.Popen( ['bash', '-c', cmd], @@ -684,6 +697,10 @@ def spawn_powermetrics(): def main(): parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)') parser.add_argument('--resume', action='store_true', help='Resume from checkpoint') + parser.add_argument('--dynamic', action='store_true', help='Use v2 dynamic weight pipeline (training_dynamic/)') + parser.add_argument('--scratch', action='store_true', help='Train from scratch (random init)') + parser.add_argument('--lr', type=float, default=None, help='Learning rate') + parser.add_argument('--accum', type=int, default=None, help='Gradient accumulation steps') parser.add_argument('--infinite', action='store_true', help='Train indefinitely') parser.add_argument('--no-powermetrics', action='store_true') parser.add_argument('--no-generate', action='store_true', help='Disable text generation') @@ -697,7 +714,8 @@ def main(): term = Terminal() procs = [] - train_proc = spawn_training(resume=args.resume, steps=args.steps) + train_proc = spawn_training(resume=args.resume, steps=args.steps, dynamic=args.dynamic, + scratch=args.scratch, lr=args.lr, accum=args.accum) S.train_pid = train_proc.pid procs.append(train_proc) @@ -837,7 +855,8 @@ def main(): if train_proc: train_proc.terminate() train_proc.wait() - train_proc = spawn_training(resume=True, steps=args.steps) + train_proc = spawn_training(resume=True, steps=args.steps, dynamic=args.dynamic, + lr=args.lr, accum=args.accum) S.train_pid = train_proc.pid procs = [p for p in procs if p.poll() is None] procs.append(train_proc) diff --git a/training/test_dynamic_matmul.m b/training/test_dynamic_matmul.m new file mode 100644 index 0000000..72addbd --- /dev/null +++ b/training/test_dynamic_matmul.m @@ -0,0 +1,333 @@ +// test_dynamic_matmul.m — Benchmark dynamic matmul on ANE (no recompile) +// Layout: input [1, D, 1, S+D] — activations in sp[0:S], weight rows 
in sp[S:S+D] +// MIL: slice → reshape → matmul → reshape → output +#import +#import +#import +#import +#import +#import +#include +#include + +#include "stories_io.h" + +// Generate MIL for y = x @ W where both come from input IOSurface +// Input: [1, IC, 1, SEQ+OC] fp32 +// sp[0:SEQ] = activations x[IC, SEQ] +// sp[SEQ:SEQ+OC] = weight W[IC, OC] (each channel d holds W[d, :]) +// Output: [1, OC, 1, SEQ] fp32 +static NSString *gen_dynamic_matmul_mil(int ic, int oc, int seq) { + NSMutableString *m = [NSMutableString string]; + [m appendString:@"program(1.3)\n" + "[buildInfo = dict({{\"coremlc-component-MIL\", \"3510.2.1\"}, " + "{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " + "{\"coremltools-version\", \"9.0\"}})]\n{\n"]; + int sp_total = seq + oc; + [m appendFormat:@" func main(tensor x) {\n", ic, sp_total]; + // Cast to fp16 + [m appendString:@" string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"]; + [m appendFormat:@" tensor xh = cast(dtype = to16, x = x)[name = string(\"cin\")];\n", ic, sp_total]; + // Slice activations [1, IC, 1, SEQ] + [m appendString:@" tensor ba = const()[name = string(\"ba\"), val = tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor sa = const()[name = string(\"sa\"), val = tensor([1,%d,1,%d])];\n", ic, seq]; + [m appendFormat:@" tensor act = slice_by_size(x=xh,begin=ba,size=sa)[name=string(\"act\")];\n", ic, seq]; + // Slice weight [1, IC, 1, OC] + [m appendFormat:@" tensor bw = const()[name = string(\"bw\"), val = tensor([0,0,0,%d])];\n", seq]; + [m appendFormat:@" tensor sw = const()[name = string(\"sw\"), val = tensor([1,%d,1,%d])];\n", ic, oc]; + [m appendFormat:@" tensor wt = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"wt\")];\n", ic, oc]; + // Reshape act: [1,IC,1,SEQ] → [1,1,IC,SEQ] → transpose → [1,1,SEQ,IC] + [m appendFormat:@" tensor ra = const()[name = string(\"ra\"), val = tensor([1,1,%d,%d])];\n", ic, seq]; + [m appendFormat:@" tensor a2 = 
reshape(shape=ra,x=act)[name=string(\"a2\")];\n", ic, seq]; + [m appendString:@" tensor pm = const()[name = string(\"pm\"), val = tensor([0,1,3,2])];\n"]; + [m appendFormat:@" tensor a3 = transpose(perm=pm,x=a2)[name=string(\"a3\")];\n", seq, ic]; + // Reshape weight: [1,IC,1,OC] → [1,1,IC,OC] + [m appendFormat:@" tensor rw = const()[name = string(\"rw\"), val = tensor([1,1,%d,%d])];\n", ic, oc]; + [m appendFormat:@" tensor W = reshape(shape=rw,x=wt)[name=string(\"W\")];\n", ic, oc]; + // matmul: [1,1,SEQ,IC] @ [1,1,IC,OC] → [1,1,SEQ,OC] + [m appendString:@" bool bF = const()[name = string(\"bF\"), val = bool(false)];\n"]; + [m appendFormat:@" tensor yh = matmul(transpose_x=bF,transpose_y=bF,x=a3,y=W)[name=string(\"mm\")];\n", seq, oc]; + // Reshape+transpose back: [1,1,SEQ,OC] → transpose → [1,1,OC,SEQ] → reshape → [1,OC,1,SEQ] + [m appendFormat:@" tensor yt = transpose(perm=pm,x=yh)[name=string(\"yt\")];\n", oc, seq]; + [m appendFormat:@" tensor ro = const()[name = string(\"ro\"), val = tensor([1,%d,1,%d])];\n", oc, seq]; + [m appendFormat:@" tensor yr = reshape(shape=ro,x=yt)[name=string(\"yr\")];\n", oc, seq]; + // Cast back to fp32 + [m appendString:@" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"]; + [m appendFormat:@" tensor y = cast(dtype = to32, x = yr)[name = string(\"cout\")];\n", oc, seq]; + [m appendString:@" } -> (y);\n}\n"]; + return m; +} + +// Tiled version: splits OC into tiles, each tile is a separate kernel +// For W[IC, OC], tile along OC: each tile handles W[:, t*T:(t+1)*T] +// Input per tile: [1, IC, 1, SEQ+T] +// Output per tile: [1, T, 1, SEQ] +typedef struct { + Kern **tiles; + int n_tiles, tile_oc, ic, oc, seq; +} TiledMatmul; + +static TiledMatmul *compile_tiled_matmul(int ic, int oc, int tile_oc, int seq) { + TiledMatmul *tm = (TiledMatmul*)calloc(1, sizeof(TiledMatmul)); + tm->ic = ic; tm->oc = oc; tm->seq = seq; tm->tile_oc = tile_oc; + tm->n_tiles = (oc + tile_oc - 1) / tile_oc; + tm->tiles = 
(Kern**)calloc(tm->n_tiles, sizeof(Kern*)); + for (int t = 0; t < tm->n_tiles; t++) { + int this_oc = (t == tm->n_tiles-1 && oc % tile_oc) ? (oc % tile_oc) : tile_oc; + NSString *mil = gen_dynamic_matmul_mil(ic, this_oc, seq); + int in_bytes = ic * (seq + this_oc) * 4; + int out_bytes = this_oc * seq * 4; + tm->tiles[t] = compile_kern_mil_w(mil, @{}, in_bytes, out_bytes); + if (!tm->tiles[t]) { printf("Tile %d compile FAIL\n", t); for (int u = 0; u < t; u++) free_kern(tm->tiles[u]); free(tm->tiles); free(tm); return NULL; } // free already-compiled tiles + tables: callers only see NULL and cannot clean up + } + return tm; +} + +// Write activations + weight tile into IOSurface +// act: [IC, SEQ] row-major, channel-first (act[d*seq + t]) +// W: [IC, OC] — full weight matrix, we extract the tile +static void write_tile_input(TiledMatmul *tm, int tile_idx, const float *act, const float *W) { + Kern *k = tm->tiles[tile_idx]; + int ic = tm->ic, seq = tm->seq, toc = tm->tile_oc; + int oc_off = tile_idx * toc; + int this_oc = (tile_idx == tm->n_tiles-1 && tm->oc % toc) ? (tm->oc % toc) : toc; + + IOSurfaceLock(k->ioIn, 0, NULL); + float *buf = (float*)IOSurfaceGetBaseAddress(k->ioIn); + // Activations: buf[d * (seq+this_oc) + t] = act[d * seq + t] + for (int d = 0; d < ic; d++) { + memcpy(buf + d*(seq+this_oc), act + d*seq, seq*sizeof(float)); + // Weight: buf[d * (seq+this_oc) + seq + c] = W[d * oc + oc_off + c] + for (int c = 0; c < this_oc; c++) + buf[d*(seq+this_oc) + seq + c] = W[d*tm->oc + oc_off + c]; + } + IOSurfaceUnlock(k->ioIn, 0, NULL); +} + +// Read tile output into full output buffer +static void read_tile_output(TiledMatmul *tm, int tile_idx, float *out) { + Kern *k = tm->tiles[tile_idx]; + int seq = tm->seq, toc = tm->tile_oc; + int oc_off = tile_idx * toc; + int this_oc = (tile_idx == tm->n_tiles-1 && tm->oc % toc) ?
(tm->oc % toc) : toc; + + IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + float *obuf = (float*)IOSurfaceGetBaseAddress(k->ioOut); + for (int c = 0; c < this_oc; c++) + memcpy(out + (oc_off+c)*seq, obuf + c*seq, seq*sizeof(float)); + IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL); +} + +int main(int argc, char **argv) { + @autoreleasepool { + mach_timebase_info(&g_tb); + ane_init(); + + // === Test 1: Single 64×64 dynamic matmul (correctness) === + printf("=== Test 1: 64×64 dynamic matmul correctness ===\n"); + { + int D = 64, S = 64; + NSString *mil = gen_dynamic_matmul_mil(D, D, S); + int in_b = D * (S+D) * 4, out_b = D * S * 4; + Kern *k = compile_kern_mil_w(mil, @{}, in_b, out_b); + if (!k) { printf("FAIL\n"); return 1; } + + // Identity test + IOSurfaceLock(k->ioIn, 0, NULL); + float *inp = (float*)IOSurfaceGetBaseAddress(k->ioIn); + memset(inp, 0, in_b); + for (int d = 0; d < D; d++) + for (int s = 0; s < S; s++) + inp[d*(S+D) + s] = (float)(d*S + s) * 0.001f; + for (int d = 0; d < D; d++) + for (int c = 0; c < D; c++) + inp[d*(S+D) + S + c] = (d == c) ? 1.0f : 0.0f; + IOSurfaceUnlock(k->ioIn, 0, NULL); + + ane_eval(k); + IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + float *out = (float*)IOSurfaceGetBaseAddress(k->ioOut); + float me = 0; + for (int d = 0; d < D; d++) + for (int s = 0; s < S; s++) { + float e = fabsf(out[d*S+s] - inp[d*(S+D)+s]); + if (e > me) me = e; + } + IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + printf("Identity: max_err=%.6f %s\n", me, me < 0.01 ? "PASS" : "FAIL"); + + // 2× test + IOSurfaceLock(k->ioIn, 0, NULL); + for (int d = 0; d < D; d++) + for (int c = 0; c < D; c++) + inp[d*(S+D) + S + c] = (d == c) ? 
2.0f : 0.0f; + IOSurfaceUnlock(k->ioIn, 0, NULL); + ane_eval(k); + IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + float sr = 0; int cnt = 0; + for (int i = 0; i < D*S; i++) + if (fabsf(inp[i/(S)*((S)+D) + i%S]) > 0.001f) { sr += out[i]/inp[i/S*(S+D)+i%S]; cnt++; } + IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + printf("2× W: ratio=%.3f %s\n\n", cnt?sr/cnt:0, fabsf(sr/cnt-2.0f)<0.1?"PASS":"FAIL"); + free_kern(k); + } + + // === Test 2: 768×768 single kernel (if it compiles) === + printf("=== Test 2: 768×768 single dynamic matmul ===\n"); + { + int D = 768, S = 256; + int sp_total = S + D; // 256 + 768 = 1024 + int in_b = D * sp_total * 4; // 768 * 1024 * 4 = 3.1MB + int out_b = D * S * 4; // 768 * 256 * 4 = 786KB + printf("IOSurface: in=%.1fMB out=%.1fKB\n", in_b/1e6, out_b/1e3); + + NSString *mil = gen_dynamic_matmul_mil(D, D, S); + uint64_t t0 = mach_absolute_time(); + Kern *k = compile_kern_mil_w(mil, @{}, in_b, out_b); + double compile_ms = tb_ms(mach_absolute_time() - t0); + if (!k) { printf("768×768 compile FAIL\n"); } + else { + printf("Compile: %.1fms\n", compile_ms); + // Random weights + float *act = (float*)calloc(D*S, sizeof(float)); + float *W = (float*)calloc(D*D, sizeof(float)); + for (int i = 0; i < D*S; i++) act[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.1f; + for (int i = 0; i < D*D; i++) W[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.01f; + + // Write to IOSurface + IOSurfaceLock(k->ioIn, 0, NULL); + float *inp = (float*)IOSurfaceGetBaseAddress(k->ioIn); + for (int d = 0; d < D; d++) { + memcpy(inp + d*(S+D), act + d*S, S*4); + memcpy(inp + d*(S+D) + S, W + d*D, D*4); + } + IOSurfaceUnlock(k->ioIn, 0, NULL); + + // Warmup + for (int i = 0; i < 3; i++) ane_eval(k); + + // Benchmark + int iters = 50; + t0 = mach_absolute_time(); + for (int i = 0; i < iters; i++) ane_eval(k); + double total_ms = tb_ms(mach_absolute_time() - t0); + double per_eval = total_ms / iters; + double flops = 2.0 * D * D * S; // matmul 
FLOPs + double gflops = flops / (per_eval * 1e6); + printf("768×768×256 matmul: %.3fms/eval %.1f GFLOP/s\n", per_eval, gflops); + + // Benchmark with IO write (simulating weight update) + t0 = mach_absolute_time(); + for (int i = 0; i < iters; i++) { + IOSurfaceLock(k->ioIn, 0, NULL); + float *p = (float*)IOSurfaceGetBaseAddress(k->ioIn); + for (int d = 0; d < D; d++) + memcpy(p + d*(S+D) + S, W + d*D, D*4); + IOSurfaceUnlock(k->ioIn, 0, NULL); + ane_eval(k); + } + total_ms = tb_ms(mach_absolute_time() - t0); + per_eval = total_ms / iters; + gflops = flops / (per_eval * 1e6); + printf("With weight IO: %.3fms/eval %.1f GFLOP/s\n", per_eval, gflops); + + free(act); free(W); free_kern(k); + } + } + + // === Test 3: Tiled matmul benchmark === + int tile_sizes[] = {64, 128, 256, 384, 768}; + int n_tiles_test = sizeof(tile_sizes)/sizeof(tile_sizes[0]); + printf("\n=== Test 3: Tiled 768×768 matmul (varying tile_oc) ===\n"); + printf("%-10s %-8s %-10s %-12s %-10s\n", "tile_oc", "tiles", "compile", "eval/ms", "GFLOP/s"); + { + int D = 768, S = 256; + float *act = (float*)calloc(D*S, sizeof(float)); + float *W = (float*)calloc(D*D, sizeof(float)); + float *out_full = (float*)calloc(D*S, sizeof(float)); + for (int i = 0; i < D*S; i++) act[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.1f; + for (int i = 0; i < D*D; i++) W[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.01f; + + for (int ti = 0; ti < n_tiles_test; ti++) { + int T = tile_sizes[ti]; + if (T > D) continue; + uint64_t t0 = mach_absolute_time(); + TiledMatmul *tm = compile_tiled_matmul(D, D, T, S); + double compile_ms = tb_ms(mach_absolute_time() - t0); + if (!tm) { printf("%-10d FAIL\n", T); continue; } + + // Warmup + for (int w = 0; w < 2; w++) { + for (int t = 0; t < tm->n_tiles; t++) { + write_tile_input(tm, t, act, W); + ane_eval(tm->tiles[t]); + } + } + + // Benchmark (with IO) + int iters = 20; + t0 = mach_absolute_time(); + for (int i = 0; i < iters; i++) { + for (int t = 0; t < tm->n_tiles; t++) { 
+ write_tile_input(tm, t, act, W); + ane_eval(tm->tiles[t]); + read_tile_output(tm, t, out_full); + } + } + double total_ms = tb_ms(mach_absolute_time() - t0); + double per_matmul = total_ms / iters; + double flops = 2.0 * D * D * S; + double gflops = flops / (per_matmul * 1e6); + printf("%-10d %-8d %-10.0fms %-12.3fms %-10.1f\n", + T, tm->n_tiles, compile_ms, per_matmul, gflops); + + for (int t = 0; t < tm->n_tiles; t++) free_kern(tm->tiles[t]); + free(tm->tiles); free(tm); + } + + // === Correctness check: compare with cblas === + printf("\n=== Correctness: dynamic matmul vs cblas_sgemm ===\n"); + { + int T = 768; // full, no tiling + TiledMatmul *tm = compile_tiled_matmul(D, D, T, S); + if (tm) { + write_tile_input(tm, 0, act, W); + ane_eval(tm->tiles[0]); + read_tile_output(tm, 0, out_full); + + // Reference: cblas y = act^T @ W → y[s,oc] = sum_d act[d,s]*W[d,oc] + // act is [D,S] row-major (act[d*S+s]), W is [D,D] row-major (W[d*D+oc]) + // We want out[oc,s] = sum_d act[d,s] * W[d,oc] + // = W^T @ act where W^T is [D,D] and act is [D,S] → out is [D,S] + float *ref = (float*)calloc(D*S, sizeof(float)); + // out[oc*S+s] = sum_d W[d*D+oc] * act[d*S+s] + // This is: (W^T) @ act with every buffer row-major: M=D,N=S,K=D, lda=D, ldb=S, ldc=S + // cblas: C = alpha*A*B + beta*C + // A=W^T [D×D], B=act [D×S], C=ref [D×S]; outputs are only ~1e-2 in magnitude, so the tolerance below must sit well under that to mean anything + cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, + D, S, D, 1.0f, W, D, act, S, 0.0f, ref, S); + float me = 0; + for (int i = 0; i < D*S; i++) { + float e = fabsf(out_full[i] - ref[i]); + if (e > me) me = e; + } + printf("vs cblas: max_err=%.6f %s\n", me, me < 1e-3 ?
"PASS" : "FAIL"); + free(ref); + for (int t = 0; t < tm->n_tiles; t++) free_kern(tm->tiles[t]); + free(tm->tiles); free(tm); + } + } + + free(act); free(W); free(out_full); + } + + // === Summary for training === + printf("\n=== Summary ===\n"); + printf("Stories110M: 12 layers × 10 matmuls/layer = 120 matmuls/step\n"); + printf("Sizes: Wq/Wk/Wv/Wo [768,768], W1/W3 [2048,768], W2 [768,2048]\n"); + printf("With dynamic weights: compile once, update IOSurface every step\n"); + + printf("\nDone.\n"); + } + return 0; +} diff --git a/training/test_weight_patch.m b/training/test_weight_patch.m new file mode 100644 index 0000000..13473b7 --- /dev/null +++ b/training/test_weight_patch.m @@ -0,0 +1,450 @@ +// test_weight_patch.m — Test whether ANE weights can be patched after compile +#import +#import +#import +#import +#import +#import +#import +#import +#include +#include + +#include "stories_io.h" + +// MIL: fp32 in → cast fp16 → conv → cast fp32 out (matches inmem_peak.m pattern) +static NSString *gen_conv_mil(int ic, int oc, int sp) { + NSMutableString *m = [NSMutableString string]; + [m appendString:@"program(1.3)\n" + "[buildInfo = dict({{\"coremlc-component-MIL\", \"3510.2.1\"}, " + "{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " + "{\"coremltools-version\", \"9.0\"}})]\n{\n"]; + [m appendFormat:@" func main(tensor x) {\n", ic, sp]; + [m appendString: + @" string pt = const()[name = string(\"pt\"), val = string(\"valid\")];\n" + " tensor st = const()[name = string(\"st\"), val = tensor([1, 1])];\n" + " tensor pd = const()[name = string(\"pd\"), val = tensor([0, 0, 0, 0])];\n" + " tensor dl = const()[name = string(\"dl\"), val = tensor([1, 1])];\n" + " int32 gr = const()[name = string(\"gr\"), val = int32(1)];\n" + " string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"]; + [m appendFormat:@" tensor xh = cast(dtype = to16, x = x)[name = string(\"cast_in\")];\n", ic, sp]; + [m appendFormat:@" tensor W = 
const()[name = string(\"W\"), " + "val = tensor(BLOBFILE(path = string(\"@model_path/weights/w.bin\"), offset = uint64(64)))];\n", + oc, ic, oc, ic]; + [m appendFormat:@" tensor yh = conv(dilations = dl, groups = gr, pad = pd, pad_type = pt, strides = st, weight = W, x = xh)" + "[name = string(\"conv\")];\n", oc, sp]; + [m appendString:@" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"]; + [m appendFormat:@" tensor y = cast(dtype = to32, x = yh)[name = string(\"cast_out\")];\n", oc, sp]; + [m appendString:@" } -> (y);\n}\n"]; + return m; +} + +int main(int argc, char **argv) { + @autoreleasepool { + mach_timebase_info(&g_tb); + ane_init(); + + int IC = 256, OC = 256, SP = 64; + int io_bytes = IC * SP * 4; // fp32 + + // Identity weight + float *W_id = (float*)calloc(OC*IC, sizeof(float)); + for (int i = 0; i < IC; i++) W_id[i*IC+i] = 1.0f; + + NSString *mil = gen_conv_mil(IC, OC, SP); + NSDictionary *wd = @{@"@model_path/weights/w.bin": @{@"offset":@0, @"data":build_blob(W_id, OC, IC)}}; + + printf("=== Compiling conv %dx%d sp=%d ===\n", OC, IC, SP); + Kern *k = compile_kern_mil_w(mil, wd, io_bytes, io_bytes); + if (!k) { printf("COMPILE FAILED\n"); free(W_id); return 1; } + printf("Compile OK!\n"); + + // Write fp32 input + IOSurfaceLock(k->ioIn, 0, NULL); + float *inp = (float*)IOSurfaceGetBaseAddress(k->ioIn); + for (int i = 0; i < IC*SP; i++) inp[i] = (i % 100) * 0.01f; + IOSurfaceUnlock(k->ioIn, 0, NULL); + + // Eval with identity + ane_eval(k); + IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + float *out = (float*)IOSurfaceGetBaseAddress(k->ioOut); + printf("In: [%.3f, %.3f, %.3f, %.3f]\n", inp[0], inp[1], inp[2], inp[3]); + printf("Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]); + float max_err = 0; + for (int i = 0; i < OC*SP; i++) { + float err = fabsf(out[i] - inp[i]); + if (err > max_err) max_err = err; + } + IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + printf("Identity max_err=%.6f 
%s\n\n", max_err, max_err < 0.1 ? "PASS" : "FAIL"); + + // === Approach 1: Patch weight on disk, unload+reload === + printf("=== Approach 1: Disk patch + unload/reload ===\n"); + float *W_2x = (float*)calloc(OC*IC, sizeof(float)); + for (int i = 0; i < IC; i++) W_2x[i*IC+i] = 2.0f; + [build_blob(W_2x, OC, IC) writeToFile: + [(__bridge NSString*)k->tmpDir stringByAppendingPathComponent:@"weights/w.bin"] atomically:YES]; + + id mdl = (__bridge id)k->model; + NSError *e = nil; + ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e); + e = nil; + BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e); + printf("Reload: %s\n", ok?"OK":"FAIL"); + if (ok) { + // Re-create request after reload + id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn); + id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut); + CFRelease(k->request); + k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR, + @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:), + @[wI], @[@0], @[wO], @[@0], nil, nil, @0)); + ane_eval(k); + IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + printf("Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]); + float sr = 0; int cnt = 0; + for (int i = 0; i < OC*SP; i++) + if (fabsf(inp[i]) > 0.01f) { sr += out[i]/inp[i]; cnt++; } + IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + printf("Ratio: %.3f (2.0=patched, 1.0=cached)\n\n", cnt>0?sr/cnt:0); + } + + // === Approach 2: Memory scan === + printf("=== Approach 2: Memory scan ===\n"); + uint16_t pat1[8] = {0x3C00, 0, 0, 0, 0, 0, 0, 0}; + uint16_t pat2[8] = {0x4000, 0, 0, 0, 0, 0, 0, 0}; + mach_port_t task = mach_task_self(); + vm_address_t addr = 0; vm_size_t sz; natural_t depth = 1; + 
int f1 = 0, f2 = 0; + while (1) { + struct vm_region_submap_info_64 info; + mach_msg_type_number_t count = VM_REGION_SUBMAP_INFO_COUNT_64; + if (vm_region_recurse_64(task, &addr, &sz, &depth, (vm_region_recurse_info_t)&info, &count) != KERN_SUCCESS) break; + if (info.is_submap) { depth++; continue; } + if (!(info.protection & VM_PROT_READ) || sz < (size_t)(OC*IC*2)) { addr += sz; continue; } + uint8_t *base = (uint8_t*)addr; + for (size_t off = 0; off + OC*IC*2 <= sz; off += 2) { + int w = 0; + if (memcmp(base+off, pat1, 16) == 0) w = 1; + else if (memcmp(base+off, pat2, 16) == 0) w = 2; + if (!w) continue; + uint16_t *p = (uint16_t*)(base+off), diag = (w==1)?0x3C00:0x4000; + int ok2 = 1; + for (int r = 0; r < OC && ok2; r++) + for (int c = 0; c < IC && ok2; c++) + if (p[r*IC+c] != ((r==c)?diag:0)) ok2 = 0; + if (!ok2) continue; + if (w==1) f1++; else f2++; + printf(" FOUND %dx @%p prot=%d/%d %s\n", w, (void*)(addr+off), + info.protection, info.max_protection, (info.protection&VM_PROT_WRITE)?"WR":"RO"); + } + addr += sz; + } + printf("Found: 1x=%d 2x=%d\n", f1, f2); + + // Now patch ALL found weight patterns to 3× and re-eval + if (f1 > 0 || f2 > 0) { + printf("Patching all found patterns to 3x identity...\n"); + addr = 0; depth = 1; + while (1) { + struct vm_region_submap_info_64 info2; + mach_msg_type_number_t count2 = VM_REGION_SUBMAP_INFO_COUNT_64; + if (vm_region_recurse_64(task, &addr, &sz, &depth, (vm_region_recurse_info_t)&info2, &count2) != KERN_SUCCESS) break; + if (info2.is_submap) { depth++; continue; } + if (!(info2.protection & VM_PROT_READ) || sz < (size_t)(OC*IC*2)) { addr += sz; continue; } + uint8_t *base2 = (uint8_t*)addr; + for (size_t off = 0; off + OC*IC*2 <= sz; off += 2) { + int w2 = 0; + if (memcmp(base2+off, pat1, 16) == 0) w2 = 1; + else if (memcmp(base2+off, pat2, 16) == 0) w2 = 2; + if (!w2) continue; + uint16_t *p2 = (uint16_t*)(base2+off), diag2 = (w2==1)?0x3C00:0x4000; + int ok3 = 1; + for (int r = 0; r < OC && ok3; r++) + for (int c 
= 0; c < IC && ok3; c++) + if (p2[r*IC+c] != ((r==c)?diag2:0)) ok3 = 0; + if (!ok3) continue; + if (info2.protection & VM_PROT_WRITE) { + printf(" Patching %dx @%p to 3x\n", w2, (void*)(addr+off)); + for (int r = 0; r < OC; r++) + for (int c = 0; c < IC; c++) + p2[r*IC+c] = (r==c) ? 0x4200 : 0; // fp16(3.0) + } + } + addr += sz; + } + + printf("\n=== Eval after memory patch (expect 3x) ===\n"); + ane_eval(k); + IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + printf("Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]); + float sr2 = 0; int cnt2 = 0; + for (int i = 0; i < OC*SP; i++) + if (fabsf(inp[i]) > 0.01f) { sr2 += out[i]/inp[i]; cnt2++; } + IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + printf("Ratio: %.3f (3.0=mem patch works!, 1.0=ANE uses SRAM copy)\n", cnt2>0?sr2/cnt2:0); + } + printf("\n"); + + // === Approach 3: Explore classes === + printf("=== ANE classes ===\n"); + const char *cn[] = {"_ANEWeight", "_ANEProgramForEvaluation", "_ANEChainingRequest", NULL}; + for (int i = 0; cn[i]; i++) { + Class cls = NSClassFromString([NSString stringWithUTF8String:cn[i]]); + if (!cls) { printf("%s: NOT FOUND\n", cn[i]); continue; } + printf("%s:\n", cn[i]); + unsigned int mc = 0; Method *ms = class_copyMethodList(cls, &mc); + for (unsigned j = 0; j < mc; j++) printf(" - %s\n", sel_getName(method_getName(ms[j]))); + free(ms); + mc = 0; ms = class_copyMethodList(object_getClass(cls), &mc); + for (unsigned j = 0; j < mc; j++) printf(" + %s\n", sel_getName(method_getName(ms[j]))); + free(ms); printf("\n"); + } + @try { printf("programHandle: %s\n", [[[mdl valueForKey:@"programHandle"] description] UTF8String]); } @catch(id x) {} + @try { printf("intermediateBufferHandle: %s\n", [[[mdl valueForKey:@"intermediateBufferHandle"] description] UTF8String]); } @catch(id x) {} + + // === Approach 4: _ANEWeight + updateWeightURL === + printf("\n=== Approach 4: _ANEWeight API ===\n"); + Class AW = NSClassFromString(@"_ANEWeight"); + if (AW) { + 
// Write 5× identity weights to a new file + float *W_5x = (float*)calloc(OC*IC, sizeof(float)); + for (int i = 0; i < IC; i++) W_5x[i*IC+i] = 5.0f; + NSString *wpath = [NSTemporaryDirectory() stringByAppendingPathComponent:@"patched_w.bin"]; + [build_blob(W_5x, OC, IC) writeToFile:wpath atomically:YES]; + free(W_5x); + + NSURL *wurl = [NSURL fileURLWithPath:wpath]; + id wobj = ((id(*)(Class,SEL,id,id))objc_msgSend)(AW, + @selector(weightWithSymbolAndURL:weightURL:), @"W", wurl); + printf(" _ANEWeight: %s\n", wobj ? [[wobj description] UTF8String] : "nil"); + if (wobj) { + printf(" weightSymbol: %s\n", [((id(*)(id,SEL))objc_msgSend)(wobj, @selector(weightSymbol)) UTF8String]); + printf(" weightURL: %s\n", [[((id(*)(id,SEL))objc_msgSend)(wobj, @selector(weightURL)) description] UTF8String]); + } + + // Try to pass as weightsBuffer in request + printf("\n Trying weightsBuffer in request...\n"); + id wI2 = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn); + id wO2 = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut); + + // Try passing weight array as weightsBuffer + if (wobj) { + CFRelease(k->request); + k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR, + @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:), + @[wI2], @[@0], @[wO2], @[@0], @[wobj], nil, @0)); + printf(" Request with weightsBuffer created\n"); + @try { + ane_eval(k); + IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + printf(" Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]); + float sr3 = 0; int cnt3 = 0; + for (int i2 = 0; i2 < OC*SP; i2++) + if (fabsf(inp[i2]) > 0.01f) { sr3 += out[i2]/inp[i2]; cnt3++; } + IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + printf(" Ratio: %.3f (5.0=weightsBuffer works!)\n", cnt3>0?sr3/cnt3:0); + } @catch(NSException *ex) { + printf(" Eval 
exception: %s\n", [[ex description] UTF8String]); + } + } + + // Also try IOSurface as weightsBuffer + printf("\n Trying IOSurface as weightsBuffer...\n"); + IOSurfaceRef wSurf = make_surface(OC*IC*2); // fp16 weights + IOSurfaceLock(wSurf, 0, NULL); + _Float16 *wfp16 = (_Float16*)IOSurfaceGetBaseAddress(wSurf); + for (int r = 0; r < OC; r++) + for (int c2 = 0; c2 < IC; c2++) + wfp16[r*IC+c2] = (r==c2) ? (_Float16)7.0f : (_Float16)0.0f; // 7× identity + IOSurfaceUnlock(wSurf, 0, NULL); + id wSurfObj = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), wSurf); + CFRelease(k->request); + k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR, + @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:), + @[wI2], @[@0], @[wO2], @[@0], wSurfObj, nil, @0)); + @try { + ane_eval(k); + IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + printf(" Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]); + float sr4 = 0; int cnt4 = 0; + for (int i3 = 0; i3 < OC*SP; i3++) + if (fabsf(inp[i3]) > 0.01f) { sr4 += out[i3]/inp[i3]; cnt4++; } + IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL); + printf(" Ratio: %.3f (7.0=IOSurface weights work!)\n", cnt4>0?sr4/cnt4:0); + } @catch(NSException *ex) { + printf(" Eval exception: %s\n", [[ex description] UTF8String]); + } + CFRelease(wSurf); + } + + // === Approach 5: Weights packed into input IOSurface (fp16 with cast) === + printf("\n=== Approach 5: Dynamic weights via input IOSurface ===\n"); + // Element-wise mul: x * w where both come from input + // Input [1, IC*2, 1, SP] fp32 → cast fp16 → slice → mul → cast fp32 + { + int C5 = IC; + NSMutableString *m5 = [NSMutableString string]; + [m5 appendString:@"program(1.3)\n" + "[buildInfo = dict({{\"coremlc-component-MIL\", \"3510.2.1\"}, " + "{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " + 
"{\"coremltools-version\", \"9.0\"}})]\n{\n"]; + [m5 appendFormat:@" func main(tensor x) {\n", C5*2, SP]; + [m5 appendString:@" string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"]; + [m5 appendFormat:@" tensor xh = cast(dtype = to16, x = x)[name = string(\"cin\")];\n", C5*2, SP]; + [m5 appendFormat:@" tensor b0 = const()[name = string(\"b0\"), val = tensor([0,0,0,0])];\n"]; + [m5 appendFormat:@" tensor s0 = const()[name = string(\"s0\"), val = tensor([1,%d,1,%d])];\n", C5, SP]; + [m5 appendFormat:@" tensor data = slice_by_size(x=xh,begin=b0,size=s0)[name=string(\"data\")];\n", C5, SP]; + [m5 appendFormat:@" tensor b1 = const()[name = string(\"b1\"), val = tensor([0,%d,0,0])];\n", C5]; + [m5 appendFormat:@" tensor wt = slice_by_size(x=xh,begin=b1,size=s0)[name=string(\"wt\")];\n", C5, SP]; + [m5 appendFormat:@" tensor yh = mul(x=data,y=wt)[name=string(\"mul\")];\n", C5, SP]; + [m5 appendString:@" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"]; + [m5 appendFormat:@" tensor y = cast(dtype = to32, x = yh)[name = string(\"cout\")];\n", C5, SP]; + [m5 appendString:@" } -> (y);\n}\n"]; + + int io5_in = C5*2*SP*4; + int io5_out = C5*SP*4; + Kern *k5 = compile_kern_mil_w(m5, @{}, io5_in, io5_out); + if (k5) { + printf("Compile OK!\n"); + IOSurfaceLock(k5->ioIn, 0, NULL); + float *in5 = (float*)IOSurfaceGetBaseAddress(k5->ioIn); + for (int i = 0; i < C5*SP; i++) in5[i] = (i%100)*0.01f; + for (int i = 0; i < C5*SP; i++) in5[C5*SP+i] = 2.0f; + IOSurfaceUnlock(k5->ioIn, 0, NULL); + ane_eval(k5); + IOSurfaceLock(k5->ioOut, kIOSurfaceLockReadOnly, NULL); + float *out5 = (float*)IOSurfaceGetBaseAddress(k5->ioOut); + printf("data=[%.3f,%.3f,%.3f], w=2.0 → out=[%.3f,%.3f,%.3f]\n", + in5[0],in5[1],in5[2], out5[0],out5[1],out5[2]); + IOSurfaceUnlock(k5->ioOut, kIOSurfaceLockReadOnly, NULL); + + // Change weight dynamically — NO recompile! 
+ IOSurfaceLock(k5->ioIn, 0, NULL); + for (int i = 0; i < C5*SP; i++) in5[C5*SP+i] = 5.0f; + IOSurfaceUnlock(k5->ioIn, 0, NULL); + ane_eval(k5); + IOSurfaceLock(k5->ioOut, kIOSurfaceLockReadOnly, NULL); + printf("w=5.0 → out=[%.3f,%.3f,%.3f] (expect 5×)\n", out5[0],out5[1],out5[2]); + IOSurfaceUnlock(k5->ioOut, kIOSurfaceLockReadOnly, NULL); + free_kern(k5); + } else printf("Compile FAILED\n"); + } + + // === Approach 6: matmul with dynamic weights from input === + printf("\n=== Approach 6: matmul with dynamic W from input ===\n"); + // Pack x[1,D,S,1] and W[1,D,1,D] into input, then reshape+matmul + // Input shape: [1, D+D*D, 1, S] — first D channels=activations, rest=weight matrix flattened + // Actually, matmul needs [1,H,S,D] shapes. Let's try: + // Input: [1, D*(S+D), 1, 1] reshaped as needed + // Simpler: just test matmul with two sliced inputs + { + int D6 = 64, S6 = 64; // small for test + // Input: [1, D6+D6, S6, D6] — but that's 4D... + // Actually ANE matmul works on [1,H,M,K] @ [1,H,K,N] → [1,H,M,N] + // Let's pack x[1,1,S6,D6] and W[1,1,D6,D6] into [1,2,S6,D6] + // Then slice → matmul + NSMutableString *m6 = [NSMutableString string]; + [m6 appendString:@"program(1.3)\n" + "[buildInfo = dict({{\"coremlc-component-MIL\", \"3510.2.1\"}, " + "{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " + "{\"coremltools-version\", \"9.0\"}})]\n{\n"]; + // Input: [1, D6+D6, 1, S6*D6] — flatten everything, then reshape + // Actually simplest: two separate regions in channel dim + // x_data: [1, D6, 1, S6] and W: [1, D6*D6, 1, 1] + // Total input channels: D6 + D6*D6 + int total_ch = D6 + D6*D6; + [m6 appendFormat:@" func main(tensor x) {\n", total_ch, S6]; + [m6 appendString:@" string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"]; + [m6 appendFormat:@" tensor xh = cast(dtype = to16, x = x)[name = string(\"cin\")];\n", total_ch, S6]; + // Slice activations: [1, D6, 1, S6] + [m6 appendFormat:@" tensor b0 = 
const()[name = string(\"b0\"), val = tensor([0,0,0,0])];\n"]; + [m6 appendFormat:@" tensor sa = const()[name = string(\"sa\"), val = tensor([1,%d,1,%d])];\n", D6, S6]; + [m6 appendFormat:@" tensor act = slice_by_size(x=xh,begin=b0,size=sa)[name=string(\"act\")];\n", D6, S6]; + // Slice weight: [1, D6*D6, 1, S6] but we only need [D6, D6] → reshape + [m6 appendFormat:@" tensor bw = const()[name = string(\"bw\"), val = tensor([0,%d,0,0])];\n", D6]; + [m6 appendFormat:@" tensor sw = const()[name = string(\"sw\"), val = tensor([1,%d,1,%d])];\n", D6*D6, S6]; + [m6 appendFormat:@" tensor wf = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"wf\")];\n", D6*D6, S6]; + // Reshape weight to [1, D6, D6, S6] for matmul-like operation + // Actually for conv: weight needs to be [OC, IC, 1, 1] const. Can't use dynamic weight with conv. + // For matmul: need [1, 1, D6, D6] or similar + // Let's try: reshape wf to [1, D6, D6, S6], take first slice [:,:,:,0] → no, that's hard + // Simpler: reshape to [D6, D6] and use matmul + // But matmul expects specific ranks... 
let me try: + [m6 appendFormat:@" tensor ws = const()[name = string(\"ws\"), val = tensor([1, 1, %d, %d])];\n", D6, D6]; + // Only take first column of wf to get [1, D6*D6, 1, 1] + [m6 appendFormat:@" tensor sw1 = const()[name = string(\"sw1\"), val = tensor([1,%d,1,1])];\n", D6*D6]; + [m6 appendFormat:@" tensor wf1 = slice_by_size(x=wf,begin=b0,size=sw1)[name=string(\"wf1\")];\n", D6*D6]; + [m6 appendFormat:@" tensor W = reshape(shape=ws,x=wf1)[name=string(\"W\")];\n", D6, D6]; + // Reshape act to [1, 1, S6, D6] for matmul + [m6 appendFormat:@" tensor as2 = const()[name = string(\"as2\"), val = tensor([1, 1, %d, %d])];\n", D6, S6]; + [m6 appendFormat:@" tensor pm = const()[name = string(\"pm\"), val = tensor([0, 1, 3, 2])];\n"]; + [m6 appendFormat:@" tensor a2 = reshape(shape=as2,x=act)[name=string(\"a2\")];\n", D6, S6]; + [m6 appendFormat:@" tensor a3 = transpose(perm=pm,x=a2)[name=string(\"a3\")];\n", S6, D6]; + // matmul: [1,1,S6,D6] @ [1,1,D6,D6] → [1,1,S6,D6] + [m6 appendString:@" bool bF = const()[name = string(\"bF\"), val = bool(false)];\n"]; + [m6 appendFormat:@" tensor yh = matmul(transpose_x = bF, transpose_y = bF, x = a3, y = W)[name = string(\"mm\")];\n", S6, D6]; + // Reshape back to [1, D6, 1, S6] + [m6 appendFormat:@" tensor yt = transpose(perm=pm,x=yh)[name=string(\"yt\")];\n", D6, S6]; + [m6 appendFormat:@" tensor os = const()[name = string(\"os\"), val = tensor([1,%d,1,%d])];\n", D6, S6]; + [m6 appendFormat:@" tensor yr = reshape(shape=os,x=yt)[name=string(\"yr\")];\n", D6, S6]; + [m6 appendString:@" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"]; + [m6 appendFormat:@" tensor y = cast(dtype = to32, x = yr)[name = string(\"cout\")];\n", D6, S6]; + [m6 appendString:@" } -> (y);\n}\n"]; + + int io6_in = total_ch * S6 * 4; + int io6_out = D6 * S6 * 4; + Kern *k6 = compile_kern_mil_w(m6, @{}, io6_in, io6_out); + if (k6) { + printf("Dynamic matmul compile OK!\n"); + // Set up: identity W, ramp input + 
IOSurfaceLock(k6->ioIn, 0, NULL); + float *in6 = (float*)IOSurfaceGetBaseAddress(k6->ioIn); + memset(in6, 0, io6_in); + // Activations: [D6, S6] in channel-first layout + for (int d = 0; d < D6; d++) + for (int s = 0; s < S6; s++) + in6[d*S6+s] = (d*S6+s) * 0.001f; + // Weight: identity matrix [D6, D6] packed in channels D6..D6+D6*D6, only col 0 + float *wbase = in6 + D6*S6; + for (int r = 0; r < D6; r++) + for (int c = 0; c < D6; c++) + wbase[(r*D6+c)*S6] = (r==c) ? 1.0f : 0.0f; // only sp=0 matters + IOSurfaceUnlock(k6->ioIn, 0, NULL); + + ane_eval(k6); + IOSurfaceLock(k6->ioOut, kIOSurfaceLockReadOnly, NULL); + float *out6 = (float*)IOSurfaceGetBaseAddress(k6->ioOut); + printf("Identity W: in=[%.4f,%.4f,%.4f] out=[%.4f,%.4f,%.4f]\n", + in6[0],in6[1],in6[2], out6[0],out6[1],out6[2]); + + // Check + float me6 = 0; + for (int i = 0; i < D6*S6; i++) { + float e6 = fabsf(out6[i] - in6[i]); + if (e6 > me6) me6 = e6; + } + IOSurfaceUnlock(k6->ioOut, kIOSurfaceLockReadOnly, NULL); + printf("max_err=%.6f %s\n", me6, me6 < 0.1 ? "PASS" : "FAIL"); + + // Now: 2× identity — just change the IOSurface weight, no recompile! + IOSurfaceLock(k6->ioIn, 0, NULL); + for (int r = 0; r < D6; r++) + for (int c = 0; c < D6; c++) + wbase[(r*D6+c)*S6] = (r==c) ? 
2.0f : 0.0f; + IOSurfaceUnlock(k6->ioIn, 0, NULL); + ane_eval(k6); + IOSurfaceLock(k6->ioOut, kIOSurfaceLockReadOnly, NULL); + printf("2× W: in=[%.4f,%.4f] out=[%.4f,%.4f] (expect 2×)\n", + in6[0],in6[1], out6[0],out6[1]); + IOSurfaceUnlock(k6->ioOut, kIOSurfaceLockReadOnly, NULL); + free_kern(k6); + } else printf("Dynamic matmul compile FAILED\n"); + } + + free_kern(k); free(W_id); free(W_2x); + printf("\nDone.\n"); + } + return 0; +} diff --git a/training/training_dynamic/Makefile b/training/training_dynamic/Makefile new file mode 100644 index 0000000..8c02c11 --- /dev/null +++ b/training/training_dynamic/Makefile @@ -0,0 +1,9 @@ +CC = xcrun clang +CFLAGS = -O2 -framework Foundation -framework IOSurface -framework Accelerate \ + -isysroot $(shell xcrun --show-sdk-path) -fobjc-arc + +train: train.m config.h io.h cpu_ops.h mil_dynamic.h + $(CC) $(CFLAGS) -o train train.m + +clean: + rm -f train diff --git a/training/training_dynamic/config.h b/training/training_dynamic/config.h new file mode 100644 index 0000000..d66d045 --- /dev/null +++ b/training/training_dynamic/config.h @@ -0,0 +1,156 @@ +// config.h — Stories110M model config, structs, ANE init +#pragma once +#import +#import +#import +#import +#import +#import +#import +#include +#include +#include +#include +#include +#include +#include + +// Stories110M config +#define DIM 768 +#define HIDDEN 2048 +#define HEADS 12 +#define HD (DIM/HEADS) +#define SEQ 256 +#define NLAYERS 12 +#define VOCAB 32000 + +// Weight sizes per layer +#define WQ_SZ (DIM*DIM) +#define WO_SZ (DIM*DIM) +#define W1_SZ (HIDDEN*DIM) +#define W2_SZ (DIM*HIDDEN) +#define W3_SZ (HIDDEN*DIM) +#define LAYER_PARAMS (4*WQ_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM) + +// Attention score channels for SDPA backward +#define SCORE_CH (HEADS*SEQ) + +// Per-layer weights +typedef struct { + float *Wq, *Wk, *Wv, *Wo; + float *W1, *W2, *W3; + float *rms_att, *rms_ffn; +} LayerWeights; + +// Adam optimizer state +typedef struct { float *m, *v; size_t n; } 
AdamState;
+typedef struct {
+    AdamState Wq, Wk, Wv, Wo, W1, W2, W3, rms_att, rms_ffn;
+} LayerAdam;
+
+// Per-layer activations (saved for backward)
+typedef struct {
+    float *layer_in, *xnorm, *Q, *K, *V, *attn_out, *o_out;
+    float *x2, *x2norm, *h1, *h3, *silu_out, *ffn_out;
+} LayerActs;
+
+// Per-layer gradients
+typedef struct {
+    float *Wq, *Wk, *Wv, *Wo, *W1, *W2, *W3, *rms_att, *rms_ffn;
+} LayerGrads;
+
+// ANE kernel handle
+typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; } Kern;
+
+// Checkpoint header
+typedef struct {
+    int magic, version, step, total_steps;
+    int n_layers, vocab_size, dim, hidden_dim, n_heads, seq_len;
+    float lr, loss;
+    double cum_compile, cum_train, cum_wall;
+    int cum_steps, cum_batches, adam_t;
+    int pad[3];
+} CkptHdr;
+
+// llama2.c model file header
+typedef struct {
+    int dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len;
+} Llama2Config;
+
+// Globals
+static Class g_D, g_I, g_AR, g_AIO;
+static mach_timebase_info_data_t g_tb;
+static int g_compile_count = 0;
+
+// Load the private AppleNeuralEngine framework and cache the four classes the
+// rest of the pipeline messages through objc_msgSend.
+// A failed dlopen or a missing class is reported on stderr instead of being
+// silently ignored: the globals stay nil in that case, and messaging nil just
+// returns nil, so without the report the first visible symptom would be an
+// unexplained kernel-compile failure much later.
+static void ane_init(void) {
+    static const char *ane_path = "/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine";
+    if (!dlopen(ane_path, RTLD_NOW)) {
+        const char *why = dlerror();
+        fprintf(stderr, "ane_init: dlopen failed: %s\n", why ? why : "unknown error");
+    }
+    g_D  = NSClassFromString(@"_ANEInMemoryModelDescriptor");
+    g_I  = NSClassFromString(@"_ANEInMemoryModel");
+    g_AR = NSClassFromString(@"_ANERequest");
+    g_AIO= NSClassFromString(@"_ANEIOSurfaceObject");
+    if (!g_D || !g_I || !g_AR || !g_AIO)
+        fprintf(stderr, "ane_init: missing ANE classes (descriptor=%d model=%d request=%d iosurface=%d)\n",
+                g_D != nil, g_I != nil, g_AR != nil, g_AIO != nil);
+}
+
+// Convert a tick count expressed in the mach timebase to milliseconds
+// (ticks * numer / denom gives nanoseconds; / 1e6 gives ms).
+// g_tb is a zero-initialised static, so if no caller has run
+// mach_timebase_info() yet the division would be by zero; fetch it on first use.
+static double tb_ms(uint64_t t) {
+    if (g_tb.denom == 0) mach_timebase_info(&g_tb);
+    return (double)t * g_tb.numer / g_tb.denom / 1e6;
+}
+
+// Alloc helpers
+// adam_alloc: zero-filled first/second moment buffers for n parameters.
+static AdamState adam_alloc(size_t n) { AdamState s; s.m=(float*)calloc(n,4); s.v=(float*)calloc(n,4); s.n=n; return s; }
+// adam_free: releases both moment buffers; the AdamState itself is not freed.
+static void adam_free(AdamState *s) { free(s->m); free(s->v); }
+
+static LayerWeights layer_weights_alloc(void) {
+    LayerWeights w;
+    w.Wq=(float*)malloc(WQ_SZ*4); w.Wk=(float*)malloc(WQ_SZ*4);
+    w.Wv=(float*)malloc(WQ_SZ*4); w.Wo=(float*)malloc(WO_SZ*4);
+    w.W1=(float*)malloc(W1_SZ*4); w.W2=(float*)malloc(W2_SZ*4); w.W3=(float*)malloc(W3_SZ*4);
+
w.rms_att=(float*)malloc(DIM*4); w.rms_ffn=(float*)malloc(DIM*4);
+    return w;
+}
+// Free every buffer held by a LayerWeights (the struct itself is not freed).
+static void layer_weights_free(LayerWeights *w) {
+    float *owned[] = { w->Wq, w->Wk, w->Wv, w->Wo,
+                       w->W1, w->W2, w->W3, w->rms_att, w->rms_ffn };
+    for (size_t i = 0; i < sizeof(owned)/sizeof(owned[0]); i++)
+        free(owned[i]);
+}
+// Build one AdamState per trainable tensor of a layer, sized to match the
+// corresponding weight buffer.
+static LayerAdam layer_adam_alloc(void) {
+    LayerAdam st;
+    // DIM x DIM matrices
+    st.Wq = adam_alloc(WQ_SZ);
+    st.Wk = adam_alloc(WQ_SZ);
+    st.Wv = adam_alloc(WQ_SZ);
+    st.Wo = adam_alloc(WO_SZ);
+    // DIM x HIDDEN matrices
+    st.W1 = adam_alloc(W1_SZ);
+    st.W2 = adam_alloc(W2_SZ);
+    st.W3 = adam_alloc(W3_SZ);
+    // DIM-length RMSNorm weights
+    st.rms_att = adam_alloc(DIM);
+    st.rms_ffn = adam_alloc(DIM);
+    return st;
+}
+// Release the moment buffers of every AdamState in a LayerAdam.
+static void layer_adam_free(LayerAdam *a) {
+    AdamState *parts[] = { &a->Wq, &a->Wk, &a->Wv, &a->Wo,
+                           &a->W1, &a->W2, &a->W3,
+                           &a->rms_att, &a->rms_ffn };
+    for (size_t i = 0; i < sizeof(parts)/sizeof(parts[0]); i++)
+        adam_free(parts[i]);
+}
+// Allocate (uninitialised) storage for the activations one layer keeps for
+// its backward pass: SEQ*DIM floats for most, SEQ*HIDDEN for h1/h3/silu_out.
+static LayerActs layer_acts_alloc(void) {
+    const size_t dim_bytes = SEQ*DIM*4;
+    const size_t hid_bytes = SEQ*HIDDEN*4;
+    LayerActs acts;
+    acts.layer_in = (float*)malloc(dim_bytes);
+    acts.xnorm    = (float*)malloc(dim_bytes);
+    acts.Q        = (float*)malloc(dim_bytes);
+    acts.K        = (float*)malloc(dim_bytes);
+    acts.V        = (float*)malloc(dim_bytes);
+    acts.attn_out = (float*)malloc(dim_bytes);
+    acts.o_out    = (float*)malloc(dim_bytes);
+    acts.x2       = (float*)malloc(dim_bytes);
+    acts.x2norm   = (float*)malloc(dim_bytes);
+    acts.h1       = (float*)malloc(hid_bytes);
+    acts.h3       = (float*)malloc(hid_bytes);
+    acts.silu_out = (float*)malloc(hid_bytes);
+    acts.ffn_out  = (float*)malloc(dim_bytes);
+    return acts;
+}
+// Free every activation buffer of a LayerActs (the struct itself is not freed).
+static void layer_acts_free(LayerActs *a) {
+    float *owned[] = { a->layer_in, a->xnorm,
+                       a->Q, a->K, a->V,
+                       a->attn_out, a->o_out, a->x2, a->x2norm,
+                       a->h1, a->h3, a->silu_out, a->ffn_out };
+    for (size_t i = 0; i < sizeof(owned)/sizeof(owned[0]); i++)
+        free(owned[i]);
+}
+// Allocate zero-filled gradient accumulators matching the layer's weights.
+static LayerGrads layer_grads_alloc(void) {
+    LayerGrads grads;
+    grads.Wq = (float*)calloc(WQ_SZ, sizeof(float));
+    grads.Wk = (float*)calloc(WQ_SZ, sizeof(float));
+    grads.Wv = (float*)calloc(WQ_SZ, sizeof(float));
+    grads.Wo = (float*)calloc(WO_SZ, sizeof(float));
+    grads.W1 = (float*)calloc(W1_SZ, sizeof(float));
+    grads.W2 = (float*)calloc(W2_SZ, sizeof(float));
+    grads.W3 = (float*)calloc(W3_SZ, sizeof(float));
+    grads.rms_att = (float*)calloc(DIM, sizeof(float));
+    grads.rms_ffn = (float*)calloc(DIM, sizeof(float));
+    return grads;
+}
+static void layer_grads_zero(LayerGrads *g) {
+
memset(g->Wq,0,WQ_SZ*4);memset(g->Wk,0,WQ_SZ*4); + memset(g->Wv,0,WQ_SZ*4);memset(g->Wo,0,WO_SZ*4); + memset(g->W1,0,W1_SZ*4);memset(g->W2,0,W2_SZ*4);memset(g->W3,0,W3_SZ*4); + memset(g->rms_att,0,DIM*4);memset(g->rms_ffn,0,DIM*4); +} +static void layer_grads_free(LayerGrads *g) { + free(g->Wq);free(g->Wk);free(g->Wv);free(g->Wo); + free(g->W1);free(g->W2);free(g->W3);free(g->rms_att);free(g->rms_ffn); +} diff --git a/training/training_dynamic/cpu_ops.h b/training/training_dynamic/cpu_ops.h new file mode 100644 index 0000000..aed7e6f --- /dev/null +++ b/training/training_dynamic/cpu_ops.h @@ -0,0 +1,164 @@ +// cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, embedding +#pragma once +#include "config.h" + +static float *g_rms_tmp = NULL; + +static void rmsnorm(float *out, const float *x, const float *w, int d, int S) { + if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4); + float *ss = (float*)calloc(S, sizeof(float)); + for (int i=0; in; i++) { + s->m[i] = b1*s->m[i] + (1-b1)*g[i]; + s->v[i] = b2*s->v[i] + (1-b2)*g[i]*g[i]; + float mh = s->m[i]/bc1, vh = s->v[i]/bc2; + w[i] -= lr * mh / (sqrtf(vh) + eps); + } +} + +// Cross-entropy loss: operates on logits[V, S] column-major (each column = one token) +// Avoids transposing by using a per-token temp buffer +static float cross_entropy_loss(float *dlogits, const float *logits, const uint16_t *targets, int V, int S) { + float *col = (float*)malloc(V * 4); // single column buffer + float total_loss = 0; + float invS = 1.0f / S; + for (int t = 0; t < S; t++) { + // Gather column t: logits[v, t] = logits[v*S + t], stride=S + cblas_scopy(V, logits + t, S, col, 1); + // Softmax + float maxv; vDSP_maxv(col, 1, &maxv, (vDSP_Length)V); + float neg_max = -maxv; + vDSP_vsadd(col, 1, &neg_max, col, 1, (vDSP_Length)V); + int n = V; vvexpf(col, col, &n); + float sum; vDSP_sve(col, 1, &sum, (vDSP_Length)V); + float inv_sum = 1.0f / sum; + vDSP_vsmul(col, 1, &inv_sum, col, 1, (vDSP_Length)V); + // Loss + gradient + int tgt = 
targets[t];
+        total_loss -= logf(col[tgt] + 1e-10f);
+        col[tgt] -= 1.0f;
+        vDSP_vsmul(col, 1, &invS, col, 1, (vDSP_Length)V);
+        // Scatter back: dlogits[v*S + t] = col[v]
+        cblas_scopy(V, col, 1, dlogits + t, S);
+    }
+    free(col);
+    return total_loss / S;
+}
+
+// Vocab compaction: build mapping from full 32K vocab to compact vocab
+typedef struct {
+    int compact_vocab;          // number of active tokens
+    int *full_to_compact;       // [VOCAB] → compact id (-1 if unused)
+    int *compact_to_full;       // [compact_vocab] → full vocab id
+} VocabMap;
+
+// Scan a token stream and build the two-way mapping between full vocab ids
+// and a dense "compact" id space containing only the tokens that occur.
+//   data       token ids to scan
+//   n_tokens   number of entries in data
+//   full_vocab size of the full vocabulary (length of full_to_compact)
+// Compact ids are assigned in increasing full-id order. Both arrays in the
+// returned VocabMap are malloc'd here and are not freed by anything in this
+// function.
+static VocabMap vocab_map_build(const uint16_t *data, size_t n_tokens, int full_vocab) {
+    VocabMap vm;
+    vm.full_to_compact = (int*)malloc(full_vocab * sizeof(int));
+    // memset with -1 sets every byte to 0xFF, i.e. every int to -1 ("unused")
+    memset(vm.full_to_compact, -1, full_vocab * sizeof(int));
+    // Scan for used tokens
+    for (size_t i = 0; i < n_tokens; i++) {
+        // A uint16_t token can be as large as 65535; ids at or beyond
+        // full_vocab would write past the end of the table, so skip them.
+        if (data[i] >= full_vocab) continue;
+        vm.full_to_compact[data[i]] = 0;  // mark as used
+    }
+    // Assign compact IDs
+    int cid = 0;
+    for (int v = 0; v < full_vocab; v++) {
+        if (vm.full_to_compact[v] == 0)
+            vm.full_to_compact[v] = cid++;
+        else
+            vm.full_to_compact[v] = -1;
+    }
+    vm.compact_vocab = cid;
+    // Invert the mapping: compact id → full id
+    vm.compact_to_full = (int*)malloc(cid * sizeof(int));
+    for (int v = 0; v < full_vocab; v++) {
+        if (vm.full_to_compact[v] >= 0)
+            vm.compact_to_full[vm.full_to_compact[v]] = v;
+    }
+    return vm;
+}
+
+// Create compact embedding from full embedding
+// Returns a newly malloc'd [compact_vocab, dim] table whose row c is a copy
+// of row compact_to_full[c] of full_embed.
+static float *vocab_compact_embed(const float *full_embed, const VocabMap *vm, int dim) {
+    float *ce = (float*)malloc((size_t)vm->compact_vocab * dim * 4);
+    for (int c = 0; c < vm->compact_vocab; c++)
+        memcpy(ce + c*dim, full_embed + vm->compact_to_full[c]*dim, dim*4);
+    return ce;
+}
+
+// Scatter compact embed gradients back to full embed
+// Accumulates (+=) row c of compact_gembed into row compact_to_full[c] of
+// full_gembed; full_gembed is not cleared first.
+static void vocab_scatter_grads(float *full_gembed, const float *compact_gembed, const VocabMap *vm, int dim) {
+    for (int c = 0; c < vm->compact_vocab; c++) {
+        int fv = vm->compact_to_full[c];
+        for (int d = 0; d < dim; d++)
+            full_gembed[fv*dim + d] += compact_gembed[c*dim + d];
+    }
+}
+
+//
Update full embed from compact embed (after adam) +static void vocab_update_full(float *full_embed, const float *compact_embed, const VocabMap *vm, int dim) { + for (int c = 0; c < vm->compact_vocab; c++) + memcpy(full_embed + vm->compact_to_full[c]*dim, compact_embed + c*dim, dim*4); +} + +static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) { + for (int t = 0; t < seq; t++) { + int tok = tokens[t]; + for (int d = 0; d < dim; d++) + x[d*seq + t] = embed[tok*dim + d]; + } +} + +static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) { + for (int t = 0; t < seq; t++) { + int tok = tokens[t]; + for (int d = 0; d < dim; d++) + d_embed[tok*dim + d] += dx[d*seq + t]; + } +} diff --git a/training/training_dynamic/io.h b/training/training_dynamic/io.h new file mode 100644 index 0000000..0a6969e --- /dev/null +++ b/training/training_dynamic/io.h @@ -0,0 +1,147 @@ +// io.h — IOSurface helpers, NEON conversion, kernel compile/eval +#pragma once +#include "config.h" + +static IOSurfaceRef make_surface(size_t bytes) { + return IOSurfaceCreate((__bridge CFDictionaryRef)@{ + (id)kIOSurfaceWidth:@(bytes), (id)kIOSurfaceHeight:@1, + (id)kIOSurfaceBytesPerElement:@1, (id)kIOSurfaceBytesPerRow:@(bytes), + (id)kIOSurfaceAllocSize:@(bytes), (id)kIOSurfacePixelFormat:@0}); +} + +// Blob builders for const weights (mask, rms) +static NSData *build_blob(const float *w, int rows, int cols) { + int ws=rows*cols*2, tot=128+ws; + uint8_t *b=(uint8_t*)calloc(tot,1); + b[0]=1;b[4]=2;b[64]=0xEF;b[65]=0xBE;b[66]=0xAD;b[67]=0xDE;b[68]=1; + *(uint32_t*)(b+72)=ws;*(uint32_t*)(b+80)=128; + _Float16 *fp16=(_Float16*)(b+128); + for(int i=0;imodel = (void*)CFBridgingRetain(mdl); + k->ioIn = make_surface(ic_bytes); + k->ioOut = make_surface(oc_bytes); + id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn); + id wO = 
((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
+        k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
+            @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
+            @[wI], @[@0], @[wO], @[@0], nil, nil, @0));
+        k->tmpDir = (void*)CFBridgingRetain(td);
+        return k;
+    }
+}
+// Tear down a kernel handle: unload the model, release both IOSurfaces, the
+// model and the request, delete the temp directory, and free the struct.
+// NULL is accepted and ignored.
+static void free_kern(Kern *k) {
+    if (!k) return;
+    id mdl = (__bridge id)k->model; NSError *e = nil;
+    // Check the BOOL result rather than the error pointer; report but keep
+    // going so the remaining resources are still released.
+    if (mdl && !((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e))
+        fprintf(stderr, "free_kern: unload failed: %s\n", e ? [[e description] UTF8String] : "no error info");
+    // CFRelease(NULL) traps. A handle can be NULL here, e.g. when a caller
+    // replaced k->request with CFBridgingRetain() of a nil request.
+    if (k->ioIn) CFRelease(k->ioIn);
+    if (k->ioOut) CFRelease(k->ioOut);
+    if (k->tmpDir) [[NSFileManager defaultManager] removeItemAtPath:(__bridge id)k->tmpDir error:nil];
+    if (k->model) CFRelease(k->model);
+    if (k->request) CFRelease(k->request);
+    if (k->tmpDir) CFRelease(k->tmpDir);
+    free(k);
+}
+// Run one evaluation of the kernel's request on its model (QoS 21, empty
+// options). The function returns void, so a failure cannot be propagated;
+// it is reported on stderr instead of being dropped, otherwise the caller
+// would read stale output from k->ioOut with no indication anything went wrong.
+static void ane_eval(Kern *k) {
+    if (!k) { fprintf(stderr, "ane_eval: NULL kernel\n"); return; }
+    id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; NSError *e = nil;
+    BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
+    if (!ok)
+        fprintf(stderr, "ane_eval: evaluation failed: %s\n", e ? [[e description] UTF8String] : "no error info");
+}
diff --git a/training/training_dynamic/mil_dynamic.h b/training/training_dynamic/mil_dynamic.h
new file mode 100644
index 0000000..e6c5798
--- /dev/null
+++ b/training/training_dynamic/mil_dynamic.h
@@ -0,0 +1,590 @@
+// mil_dynamic.h — MIL generators using dynamic matmul (weights via IOSurface)
+// Instead of conv(const_weight, x), we use matmul(x, W) where both come from input.
+// Input layout: [1, IC, 1, SP] fp32, SP = SEQ + total_weight_cols +// Activations in sp[0:SEQ], weight matrices packed sequentially in sp[SEQ:] +#pragma once +#include "io.h" + +#define MIL_HDR \ + @"program(1.3)\n[buildInfo = dict({{\"coremlc-component-MIL\", \"3510.2.1\"}, " \ + "{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " \ + "{\"coremltools-version\", \"9.0\"}})]\n{\n" + +// Helper: generate a dynamic matmul within a MIL function +// Slices activation [1,ic,1,seq] and weight [1,ic,1,oc] from input, does matmul +// act_sp_off: spatial offset for activations (usually 0) +// w_sp_off: spatial offset for weight block +// Returns variable name of result [1,oc,1,seq] in fp16 +static void gen_dyn_matmul(NSMutableString *m, const char *prefix, + int ic, int oc, int seq, + int act_sp_off, int w_sp_off, + const char *input_var) { + // Slice activations + [m appendFormat:@" tensor %s_ba = const()[name=string(\"%s_ba\"), val=tensor([0,0,0,%d])];\n", prefix, prefix, act_sp_off]; + [m appendFormat:@" tensor %s_sa = const()[name=string(\"%s_sa\"), val=tensor([1,%d,1,%d])];\n", prefix, prefix, ic, seq]; + [m appendFormat:@" tensor %s_act = slice_by_size(x=%s,begin=%s_ba,size=%s_sa)[name=string(\"%s_act\")];\n", ic, seq, prefix, input_var, prefix, prefix, prefix]; + // Slice weight + [m appendFormat:@" tensor %s_bw = const()[name=string(\"%s_bw\"), val=tensor([0,0,0,%d])];\n", prefix, prefix, w_sp_off]; + [m appendFormat:@" tensor %s_sw = const()[name=string(\"%s_sw\"), val=tensor([1,%d,1,%d])];\n", prefix, prefix, ic, oc]; + [m appendFormat:@" tensor %s_wt = slice_by_size(x=%s,begin=%s_bw,size=%s_sw)[name=string(\"%s_wt\")];\n", ic, oc, prefix, input_var, prefix, prefix, prefix]; + // Reshape act: [1,ic,1,seq] → [1,1,ic,seq] → transpose → [1,1,seq,ic] + [m appendFormat:@" tensor %s_ra = const()[name=string(\"%s_ra\"), val=tensor([1,1,%d,%d])];\n", prefix, prefix, ic, seq]; + [m appendFormat:@" tensor %s_a2 = 
reshape(shape=%s_ra,x=%s_act)[name=string(\"%s_a2\")];\n", ic, seq, prefix, prefix, prefix, prefix]; + [m appendFormat:@" tensor %s_pm = const()[name=string(\"%s_pm\"), val=tensor([0,1,3,2])];\n", prefix, prefix]; + [m appendFormat:@" tensor %s_a3 = transpose(perm=%s_pm,x=%s_a2)[name=string(\"%s_a3\")];\n", seq, ic, prefix, prefix, prefix, prefix]; + // Reshape weight: [1,ic,1,oc] → [1,1,ic,oc] + [m appendFormat:@" tensor %s_rw = const()[name=string(\"%s_rw\"), val=tensor([1,1,%d,%d])];\n", prefix, prefix, ic, oc]; + [m appendFormat:@" tensor %s_W = reshape(shape=%s_rw,x=%s_wt)[name=string(\"%s_W\")];\n", ic, oc, prefix, prefix, prefix, prefix]; + // matmul: [1,1,seq,ic] @ [1,1,ic,oc] → [1,1,seq,oc] + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor %s_yh = matmul(transpose_x=bF,transpose_y=bF,x=%s_a3,y=%s_W)[name=string(\"%s_yh\")];\n", seq, oc, prefix, prefix, prefix, prefix]; + // Transpose back + reshape: [1,1,seq,oc] → [1,1,oc,seq] → [1,oc,1,seq] + [m appendFormat:@" tensor %s_yt = transpose(perm=%s_pm,x=%s_yh)[name=string(\"%s_yt\")];\n", oc, seq, prefix, prefix, prefix, prefix]; + [m appendFormat:@" tensor %s_ro = const()[name=string(\"%s_ro\"), val=tensor([1,%d,1,%d])];\n", prefix, prefix, oc, seq]; + [m appendFormat:@" tensor %s_y = reshape(shape=%s_ro,x=%s_yt)[name=string(\"%s_y\")];\n", oc, seq, prefix, prefix, prefix, prefix]; +} + +// ===== Dynamic matmul kernel: y = x @ W ===== +// Input: [1, IC, 1, SEQ+OC] fp32 — act[0:SEQ] + W[SEQ:SEQ+OC] +// Output: [1, OC, 1, SEQ] fp32 +static NSString *gen_dyn_matmul_mil(int ic, int oc, int seq) { + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + int sp = seq + oc; + [m appendFormat:@" func main(tensor x) {\n", ic, sp]; + [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"]; + [m appendFormat:@" tensor xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", ic, sp]; + gen_dyn_matmul(m, "mm", 
ic, oc, seq, 0, seq, "xh"); + [m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; + [m appendFormat:@" tensor y = cast(dtype=to32,x=mm_y)[name=string(\"cout\")];\n", oc, seq]; + [m appendString:@" } -> (y);\n}\n"]; + return m; +} + +// ===== SDPA forward (dynamic weights) ===== +// Replaces gen_sdpa_fwd_taps: RMSNorm done on CPU, this kernel does QKV matmul + SDPA + Wo matmul +// Input: [1, DIM, 1, SEQ + 4*DIM] fp32 +// sp[0:SEQ] = xnorm (rmsnorm output, DIM channels) +// sp[SEQ:SEQ+DIM] = Wq[DIM,DIM] +// sp[SEQ+DIM:SEQ+2D] = Wk[DIM,DIM] +// sp[SEQ+2D:SEQ+3D] = Wv[DIM,DIM] +// sp[SEQ+3D:SEQ+4D] = Wo[DIM,DIM] +// Output: [1, 6*DIM, 1, SEQ] fp16 = concat(o_out, Q, K, V, attn_out, xnorm_pass) +// NOTE: mask is still a const weight (it doesn't change) +static NSString *gen_sdpa_fwd_dynamic(void) { + float sc = 1.0f/sqrtf((float)HD); + int w_total = 4*DIM; // Wq+Wk+Wv+Wo + int sp_in = SEQ + w_total; + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; + // Cast to fp16 + [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"]; + [m appendFormat:@" tensor xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in]; + + // Slice xnorm [1,DIM,1,SEQ] + [m appendString:@" tensor bx = const()[name=string(\"bx\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor sx = const()[name=string(\"sx\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ]; + + // Slice Wq [1,DIM,1,DIM] + [m appendFormat:@" tensor bq = const()[name=string(\"bq\"), val=tensor([0,0,0,%d])];\n", SEQ]; + [m appendFormat:@" tensor sw = const()[name=string(\"sw\"), val=tensor([1,%d,1,%d])];\n", DIM, DIM]; + [m appendFormat:@" tensor Wq = slice_by_size(x=xh,begin=bq,size=sw)[name=string(\"Wq\")];\n", DIM, DIM]; + + // Slice Wk + [m appendFormat:@" 
tensor bk = const()[name=string(\"bk\"), val=tensor([0,0,0,%d])];\n", SEQ+DIM]; + [m appendFormat:@" tensor Wk = slice_by_size(x=xh,begin=bk,size=sw)[name=string(\"Wk\")];\n", DIM, DIM]; + + // Slice Wv + [m appendFormat:@" tensor bv = const()[name=string(\"bv\"), val=tensor([0,0,0,%d])];\n", SEQ+2*DIM]; + [m appendFormat:@" tensor Wv = slice_by_size(x=xh,begin=bv,size=sw)[name=string(\"Wv\")];\n", DIM, DIM]; + + // Slice Wo + [m appendFormat:@" tensor bo = const()[name=string(\"bo\"), val=tensor([0,0,0,%d])];\n", SEQ+3*DIM]; + [m appendFormat:@" tensor Wo = slice_by_size(x=xh,begin=bo,size=sw)[name=string(\"Wo\")];\n", DIM, DIM]; + + // Reshape for matmul: [1,D,1,S] → [1,1,D,S] → [1,1,S,D] + [m appendFormat:@" tensor r2 = const()[name=string(\"r2\"), val=tensor([1,1,%d,%d])];\n", DIM, SEQ]; + [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; + [m appendFormat:@" tensor xn2 = reshape(shape=r2,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM]; + + // Reshape weights: [1,D,1,D] → [1,1,D,D] + [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor([1,1,%d,%d])];\n", DIM, DIM]; + [m appendFormat:@" tensor Wq2 = reshape(shape=rw,x=Wq)[name=string(\"Wq2\")];\n", DIM, DIM]; + [m appendFormat:@" tensor Wk2 = reshape(shape=rw,x=Wk)[name=string(\"Wk2\")];\n", DIM, DIM]; + [m appendFormat:@" tensor Wv2 = reshape(shape=rw,x=Wv)[name=string(\"Wv2\")];\n", DIM, DIM]; + [m appendFormat:@" tensor Wo2 = reshape(shape=rw,x=Wo)[name=string(\"Wo2\")];\n", DIM, DIM]; + + // QKV matmul: [1,1,S,D] @ [1,1,D,D] → [1,1,S,D] + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; + [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; + [m appendFormat:@" tensor qm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wq2)[name=string(\"qm\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor km = 
matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wk2)[name=string(\"km\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor vm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wv2)[name=string(\"vm\")];\n", SEQ, DIM]; + + // Transpose back: [1,1,S,D] → [1,1,D,S] → reshape [1,D,1,S] + [m appendFormat:@" tensor qt = transpose(perm=pm,x=qm)[name=string(\"qt\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor kt = transpose(perm=pm,x=km)[name=string(\"kt\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor vt = transpose(perm=pm,x=vm)[name=string(\"vt\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor os = const()[name=string(\"os\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor qf = reshape(shape=os,x=qt)[name=string(\"qf\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor kf = reshape(shape=os,x=kt)[name=string(\"kf\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor vf = reshape(shape=os,x=vt)[name=string(\"vf\")];\n", DIM, SEQ]; + + // SDPA: reshape to heads, matmul, mask, softmax, matmul + [m appendFormat:@" tensor qsh = const()[name=string(\"qsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor q4 = reshape(shape=qsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor q = transpose(perm=pm,x=q4)[name=string(\"tq\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor k4 = reshape(shape=qsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS, SEQ, HD]; + + // Q @ K^T + [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ]; + [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; + [m appendFormat:@" tensor sc2 = 
mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ]; + + // Causal mask (still const — doesn't change) + [m appendFormat:@" tensor cm = const()[name=string(\"cm\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ, SEQ, SEQ, SEQ]; + [m appendFormat:@" tensor ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS, SEQ, SEQ]; + + // Softmax + [m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"]; + [m appendFormat:@" tensor aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS, SEQ, SEQ]; + + // scores @ V + [m appendFormat:@" tensor a4 = matmul(transpose_x=bF,transpose_y=bF,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS, SEQ, HD]; + + // Reshape back to [1,DIM,1,SEQ] + [m appendFormat:@" tensor at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM, SEQ]; + + // Wo matmul: af → [1,1,S,D] @ Wo[1,1,D,D] → [1,1,S,D] → [1,D,1,S] + [m appendFormat:@" tensor af2 = reshape(shape=r2,x=af)[name=string(\"af2\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor aft = transpose(perm=pm,x=af2)[name=string(\"aft\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor om = matmul(transpose_x=bF,transpose_y=bF,x=aft,y=Wo2)[name=string(\"om\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor ot = transpose(perm=pm,x=om)[name=string(\"ot\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor oo = reshape(shape=os,x=ot)[name=string(\"oo\")];\n", DIM, SEQ]; + + // Output: concat(o_out, qf, kf, vf, af, xn) — same as original for backward compatibility + [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; + [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ]; + // Cast to fp32 + [m appendString:@" string to32 = 
const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; + [m appendFormat:@" tensor out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", 6*DIM, SEQ]; + [m appendString:@" } -> (out32);\n}\n"]; + return m; +} + +// ===== FFN forward (dynamic weights) ===== +// RMSNorm on CPU. This kernel: xnorm @ W1 → SiLU, xnorm @ W3 → gate, gate*silu @ W2 → out +// Input: [1, DIM, 1, SEQ + HIDDEN + HIDDEN + DIM] fp32 +// sp[0:SEQ] = xnorm [DIM,SEQ] +// sp[SEQ:SEQ+HIDDEN] = W1[DIM,HIDDEN] +// sp[SEQ+HIDDEN:SEQ+2*HIDDEN] = W3[DIM,HIDDEN] +// sp[SEQ+2*HIDDEN:SEQ+2*HIDDEN+DIM]= W2[HIDDEN→DIM] — but W2 is [DIM,HIDDEN], we need HIDDEN input channels +// PROBLEM: W2 has shape [DIM,HIDDEN] = HIDDEN input channels, but our kernel has DIM input channels. +// Solution: separate kernels for W1/W3 (DIM→HIDDEN) and W2 (HIDDEN→DIM) +// OR: do W1,W3 in one kernel, SiLU on CPU/ANE, W2 in another kernel. +// Simpler: 3 separate matmul kernels per FFN direction. But that's too many dispatches. +// Better: one kernel for W1+W3 (same input dim), CPU SiLU, one kernel for W2. 
+
+// FFN part 1: h1 = xnorm @ W1 and h3 = xnorm @ W3 (both DIM→HIDDEN), then gate = SiLU(h1) * h3.
+// Input:  [1, DIM, 1, SEQ + 2*HIDDEN] fp32
+//   sp[0:SEQ]                   = xnorm
+//   sp[SEQ:SEQ+HIDDEN]          = W1[DIM,HIDDEN]
+//   sp[SEQ+HIDDEN:SEQ+2*HIDDEN] = W3[DIM,HIDDEN]
+// Output: [1, 3*HIDDEN, 1, SEQ] fp32 = concat(h1, h3, gate) on the channel axis
+// Returns the MIL program text. The input is cast to fp16 on entry and the
+// result is cast back to fp32 on exit; everything in between is fp16.
+// Every tensor declaration spells out tensor<dtype, [shape]> so that each
+// appendFormat: argument is consumed by a %d in the format string.
+// NOTE(review): the <ios18> opset tag on `func main` is assumed — confirm it
+// matches MIL_HDR and the other generators in this file.
+static NSString *gen_ffn_w13_dynamic(void) {
+    int sp_in = SEQ + 2*HIDDEN;
+    NSMutableString *m = [NSMutableString string];
+    [m appendString:MIL_HDR];
+    [m appendFormat:@" func main<ios18>(tensor<fp32, [1,%d,1,%d]> x) {\n", DIM, sp_in];
+    [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in];
+
+    // Slice xnorm: spatial [0, SEQ)
+    [m appendString:@" tensor<int32, [4]> bx = const()[name=string(\"bx\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
+    [m appendFormat:@" tensor<int32, [4]> sx = const()[name=string(\"sx\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ];
+
+    // Slice W1: spatial [SEQ, SEQ+HIDDEN)
+    [m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
+    [m appendFormat:@" tensor<int32, [4]> s1 = const()[name=string(\"s1\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, HIDDEN];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1 = slice_by_size(x=xh,begin=b1,size=s1)[name=string(\"W1\")];\n", DIM, HIDDEN];
+
+    // Slice W3: spatial [SEQ+HIDDEN, SEQ+2*HIDDEN), same size as W1
+    [m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+HIDDEN];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3 = slice_by_size(x=xh,begin=b3,size=s1)[name=string(\"W3\")];\n", DIM, HIDDEN];
+
+    // Reshape for matmul: [1,D,1,S] → [1,1,D,S] → transpose → [1,1,S,D]
+    [m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
+    [m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=rd,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
+
+    // Weights: [1,D,1,H] → [1,1,D,H]
+    [m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, HIDDEN];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W12 = reshape(shape=rw,x=W1)[name=string(\"W12\")];\n", DIM, HIDDEN];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W32 = reshape(shape=rw,x=W3)[name=string(\"W32\")];\n", DIM, HIDDEN];
+
+    // [1,1,S,D] @ [1,1,D,H] → [1,1,S,H]
+    [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W12)[name=string(\"h1m\")];\n", SEQ, HIDDEN];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W32)[name=string(\"h3m\")];\n", SEQ, HIDDEN];
+
+    // Transpose back: [1,1,S,H] → [1,1,H,S] → reshape [1,H,1,S]
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1t = transpose(perm=pm,x=h1m)[name=string(\"h1t\")];\n", HIDDEN, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3t = transpose(perm=pm,x=h3m)[name=string(\"h3t\")];\n", HIDDEN, SEQ];
+    [m appendFormat:@" tensor<int32, [4]> rh = const()[name=string(\"rh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h1 = reshape(shape=rh,x=h1t)[name=string(\"h1\")];\n", HIDDEN, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h3 = reshape(shape=rh,x=h3t)[name=string(\"h3\")];\n", HIDDEN, SEQ];
+
+    // SiLU + gate: silu = h1 * sigmoid(h1), gate = silu * h3
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN, SEQ];
+
+    // Concat output on the channel axis: (h1, h3, gate) → 3*HIDDEN channels
+    [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
+    [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(h1,h3,gate))[name=string(\"cat\")];\n", 3*HIDDEN, SEQ];
+    [m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
+    [m appendFormat:@" tensor<fp32, [1,%d,1,%d]> out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", 3*HIDDEN, SEQ];
+    [m appendString:@" } -> (out32);\n}\n"];
+    return m;
+}
+
+// FFN part 2: gate @ W2 (HIDDEN→DIM)
+// Input:  [1, HIDDEN, 1, SEQ + DIM] fp32
+//   sp[0:SEQ]       = gate [HIDDEN,SEQ]
+//   sp[SEQ:SEQ+DIM] = W2[HIDDEN,DIM]
+// Output: [1, DIM, 1, SEQ] fp32
+// Returns the MIL program text. Same fp32 → fp16 → fp32 casting scheme and
+// explicit tensor<dtype, [shape]> declarations as the W1/W3 kernel.
+// NOTE(review): the <ios18> opset tag on `func main` is assumed — confirm it
+// matches MIL_HDR and the other generators in this file.
+static NSString *gen_ffn_w2_dynamic(void) {
+    int sp_in = SEQ + DIM;
+    NSMutableString *m = [NSMutableString string];
+    [m appendString:MIL_HDR];
+    [m appendFormat:@" func main<ios18>(tensor<fp32, [1,%d,1,%d]> x) {\n", HIDDEN, sp_in];
+    [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, sp_in];
+
+    // Slice the activation: spatial [0, SEQ)
+    [m appendString:@" tensor<int32, [4]> ba = const()[name=string(\"ba\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
+    [m appendFormat:@" tensor<int32, [4]> sa = const()[name=string(\"sa\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> act = slice_by_size(x=xh,begin=ba,size=sa)[name=string(\"act\")];\n", HIDDEN, SEQ];
+
+    // Slice W2: spatial [SEQ, SEQ+DIM)
+    [m appendFormat:@" tensor<int32, [4]> bw = const()[name=string(\"bw\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
+    [m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, DIM];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W2 = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"W2\")];\n", HIDDEN, DIM];
+
+    // Activation: [1,H,1,S] → [1,1,H,S] → transpose → [1,1,S,H]
+    [m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
+    [m appendFormat:@" tensor<int32, [4]> ra = const()[name=string(\"ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> a2 = reshape(shape=ra,x=act)[name=string(\"a2\")];\n", HIDDEN, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> at = transpose(perm=pm,x=a2)[name=string(\"at\")];\n", SEQ, HIDDEN];
+
+    // Weight: [1,H,1,D] → [1,1,H,D]
+    [m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, DIM];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W22 = reshape(shape=rw,x=W2)[name=string(\"W22\")];\n", HIDDEN, DIM];
+
+    // [1,1,S,H] @ [1,1,H,D] → [1,1,S,D] → transpose → [1,1,D,S] → reshape [1,D,1,S]
+    [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> ym = matmul(transpose_x=bF,transpose_y=bF,x=at,y=W22)[name=string(\"ym\")];\n", SEQ, DIM];
+    [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> yt = transpose(perm=pm,x=ym)[name=string(\"yt\")];\n", DIM, SEQ];
+    [m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
+    [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> yr = reshape(shape=ro,x=yt)[name=string(\"yr\")];\n", DIM, SEQ];
+    [m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
+    [m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=yr)[name=string(\"cout\")];\n", DIM, SEQ];
+    [m appendString:@" } -> (y);\n}\n"];
+    return m;
+}
+
+// ===== FFN backward (dynamic weights) =====
+// A single fused kernel would need input [1, DIM+2*HIDDEN, 1, SEQ + HIDDEN + DIM + DIM] fp32.
+// It is simpler to split into separate backward kernels, like the forward pass.
+
+// FFN backward part 1: dffn @ W2^T → dsilu_raw (HIDDEN channels).
+// This kernel is only the matmul; per the notes further down, the SiLU
+// derivative and gating are element-wise and done on CPU.
+// Input:  [1, DIM, 1, SEQ + HIDDEN] fp32
+//   sp[0:SEQ]          = dffn [DIM, SEQ]
+//   sp[SEQ:SEQ+HIDDEN] = W2^T [DIM, HIDDEN]
+// Output: [1, HIDDEN, 1, SEQ] fp32 = dsilu_raw
+static NSString *gen_ffn_bwd_w2t_dynamic(void) {
+    return gen_dyn_matmul_mil(DIM, HIDDEN, SEQ);
+}
+
+// FFN backward part 2: dh1 @ W1^T + dh3 @ W3^T → dx
+// We need h1,h3 for SiLU derivative, but those are on CPU.
+// Actually the SiLU derivative + gating is element-wise, do on CPU.
+// Then: dh1 @ W1^T and dh3 @ W3^T are two separate matmuls (HIDDEN→DIM).
+// Combine into one kernel: +// Input: [1, HIDDEN, 1, SEQ + SEQ + DIM + DIM] fp32 +// sp[0:SEQ] = dh1 [HIDDEN,SEQ] +// sp[SEQ:2*SEQ] = dh3 [HIDDEN,SEQ] +// sp[2*SEQ:2*SEQ+DIM] = W1^T [HIDDEN,DIM] +// sp[2*SEQ+DIM:2*SEQ+2D] = W3^T [HIDDEN,DIM] +// Output: [1, DIM, 1, SEQ] fp32 = dx1 + dx3 +static NSString *gen_ffn_bwd_w13t_dynamic(void) { + int sp_in = 2*SEQ + 2*DIM; + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", HIDDEN, sp_in]; + [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"]; + [m appendFormat:@" tensor xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, sp_in]; + + // Slice dh1 [HIDDEN, SEQ] + [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor sh = const()[name=string(\"sh\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor dh1 = slice_by_size(x=xh,begin=b0,size=sh)[name=string(\"dh1\")];\n", HIDDEN, SEQ]; + + // Slice dh3 + [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,0,0,%d])];\n", SEQ]; + [m appendFormat:@" tensor dh3 = slice_by_size(x=xh,begin=b1,size=sh)[name=string(\"dh3\")];\n", HIDDEN, SEQ]; + + // Slice W1^T [HIDDEN, DIM] + [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,0,0,%d])];\n", 2*SEQ]; + [m appendFormat:@" tensor sw = const()[name=string(\"sw\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, DIM]; + [m appendFormat:@" tensor W1t = slice_by_size(x=xh,begin=b2,size=sw)[name=string(\"W1t\")];\n", HIDDEN, DIM]; + + // Slice W3^T + [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,0,0,%d])];\n", 2*SEQ+DIM]; + [m appendFormat:@" tensor W3t = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"W3t\")];\n", HIDDEN, DIM]; + + [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; + + // dh1 matmul: [S,H] @ [H,D] → [S,D] + [m 
appendFormat:@" tensor ra = const()[name=string(\"ra\"), val=tensor([1,1,%d,%d])];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor dh12 = reshape(shape=ra,x=dh1)[name=string(\"dh12\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor dh1t = transpose(perm=pm,x=dh12)[name=string(\"dh1t\")];\n", SEQ, HIDDEN]; + [m appendFormat:@" tensor dh32 = reshape(shape=ra,x=dh3)[name=string(\"dh32\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor dh3t = transpose(perm=pm,x=dh32)[name=string(\"dh3t\")];\n", SEQ, HIDDEN]; + + [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor([1,1,%d,%d])];\n", HIDDEN, DIM]; + [m appendFormat:@" tensor W1t2 = reshape(shape=rw,x=W1t)[name=string(\"W1t2\")];\n", HIDDEN, DIM]; + [m appendFormat:@" tensor W3t2 = reshape(shape=rw,x=W3t)[name=string(\"W3t2\")];\n", HIDDEN, DIM]; + + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor dx1m = matmul(transpose_x=bF,transpose_y=bF,x=dh1t,y=W1t2)[name=string(\"dx1m\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor dx3m = matmul(transpose_x=bF,transpose_y=bF,x=dh3t,y=W3t2)[name=string(\"dx3m\")];\n", SEQ, DIM]; + + // Add + [m appendFormat:@" tensor dxm = add(x=dx1m,y=dx3m)[name=string(\"dxm\")];\n", SEQ, DIM]; + + [m appendFormat:@" tensor dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor ro = const()[name=string(\"ro\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ]; + [m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; + [m appendFormat:@" tensor y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ]; + [m appendString:@" } -> (y);\n}\n"]; + return m; +} + +// ===== SDPA backward part 1 (dynamic Wo^T) ===== +// Same as original gen_sdpa_bwd1 but Wo^T comes from input instead of const +// Input: [1, 4*DIM, 1, SEQ + DIM] fp32 — Q,K,V,dx2 in channels, Wo^T 
in spatial +// Wait — channels must match for all data. Q,K,V are [DIM,SEQ], dx2 is [DIM,SEQ]. +// Total input channels = 4*DIM. But Wo^T is [DIM,DIM] = DIM channels of DIM spatial. +// Problem: can't mix 4*DIM channels for data with DIM channels for Wo^T. +// Solution: Wo^T matmul as separate kernel, then SDPA part purely element-wise on ANE. + +// Wo^T matmul: dx2 @ Wo^T → da (DIM→DIM) +static NSString *gen_wot_dynamic(void) { + return gen_dyn_matmul_mil(DIM, DIM, SEQ); +} + +// SDPA backward part 1 (no weights, all data): Q,K,V,da → dV,probs,dp +// Same as original but without Wo^T conv (already done) +// Input: [1, 4*DIM, 1, SEQ] fp16 +static NSString *gen_sdpa_bwd1_noweight(void) { + float sc = 1.0f/sqrtf((float)HD); + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", 4*DIM, SEQ]; + + // Slice Q,K,V,da + [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM]; + [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM]; + [m appendFormat:@" tensor vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 3*DIM]; + [m appendFormat:@" tensor da = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM, SEQ]; + + // Reshape to heads + [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS, HD, SEQ]; + [m appendString:@" tensor pm = 
const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; + [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor vr = reshape(shape=rsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor v = transpose(perm=pm,x=vr)[name=string(\"tv\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor dr = reshape(shape=rsh,x=da)[name=string(\"rd\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor dat = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS, SEQ, HD]; + + // Forward attention scores (recompute) + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; + [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; + [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ]; + [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; + [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ]; + [m appendFormat:@" tensor cm = const()[name=string(\"cm\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ, SEQ, SEQ, SEQ]; + [m appendFormat:@" tensor ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS, SEQ, SEQ]; + [m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"]; + [m appendFormat:@" tensor probs = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS, SEQ, SEQ]; + + // dV = probs^T @ da, dp = da @ V^T + [m appendFormat:@" tensor dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=dat)[name=string(\"dv\")];\n", HEADS, SEQ, HD]; + 
[m appendFormat:@" tensor dp4 = matmul(transpose_x=bF,transpose_y=bT,x=dat,y=v)[name=string(\"dp\")];\n", HEADS, SEQ, SEQ]; + + // Reshape dV back + [m appendFormat:@" tensor dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor dvs = const()[name=string(\"dvs\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM, SEQ]; + + // Flatten probs and dp for output + [m appendFormat:@" tensor scs = const()[name=string(\"scs\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH, SEQ]; + [m appendFormat:@" tensor pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH, SEQ]; + [m appendFormat:@" tensor dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH, SEQ]; + + [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; + [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", DIM+2*SCORE_CH, SEQ]; + [m appendString:@" } -> (out);\n}\n"]; + return m; +} + +// SDPA backward part 2: same as original (no weights, pure computation) +static NSString *gen_sdpa_bwd2(void) { + float sc = 1.0f/sqrtf((float)HD); + int bwd2_in = 2*SCORE_CH + 2*DIM; + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", bwd2_in, SEQ]; + [m appendFormat:@" tensor sz_sc = const()[name=string(\"szsc\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH, SEQ]; + [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH, SEQ]; + [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", SCORE_CH]; + [m appendFormat:@" tensor dpf = 
slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH, SEQ]; + [m appendFormat:@" tensor sz_d = const()[name=string(\"szd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH]; + [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH+DIM]; + [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor ssh = const()[name=string(\"ssh\"), val=tensor([1,%d,%d,%d])];\n", HEADS, SEQ, SEQ]; + [m appendFormat:@" tensor probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS, SEQ, SEQ]; + [m appendFormat:@" tensor dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS, SEQ, SEQ]; + [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS, HD, SEQ]; + [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; + [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS, SEQ, SEQ]; + [m appendString:@" tensor rax = const()[name=string(\"rax\"), val=tensor([-1])];\n"]; + [m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"]; + [m appendFormat:@" tensor spdp = reduce_sum(x=pdp,axes=rax,keep_dims=kd)[name=string(\"rs\")];\n", HEADS, SEQ]; + [m appendFormat:@" tensor dps = sub(x=dp,y=spdp)[name=string(\"dps\")];\n", 
HEADS, SEQ, SEQ]; + [m appendFormat:@" tensor ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS, SEQ, SEQ]; + [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; + [m appendFormat:@" tensor ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS, SEQ, SEQ]; + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; + [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; + [m appendFormat:@" tensor dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor fs = const()[name=string(\"fs\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM, SEQ]; + [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; + [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM, SEQ]; + [m appendString:@" } -> (out);\n}\n"]; + return m; +} + +// QKV backward (dynamic): dq @ Wq^T + dk @ Wk^T + dv @ Wv^T → dx +// Input: [1, DIM, 1, 3*SEQ + 3*DIM] fp32 +// sp[0:SEQ] = dq [DIM,SEQ] +// sp[SEQ:2*SEQ] = dk [DIM,SEQ] +// sp[2*SEQ:3*SEQ] = dv [DIM,SEQ] +// sp[3*SEQ:3*SEQ+DIM] = Wq^T [DIM,DIM] +// sp[3*SEQ+DIM:3*SEQ+2D] = Wk^T [DIM,DIM] +// sp[3*SEQ+2D:3*SEQ+3D] = Wv^T [DIM,DIM] +// Output: [1, DIM, 1, SEQ] fp32 = dxq + dxk + dxv +static NSString *gen_qkvb_dynamic(void) { + int 
sp_in = 3*SEQ + 3*DIM; + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; + [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"]; + [m appendFormat:@" tensor xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in]; + + // Slice dq, dk, dv + [m appendFormat:@" tensor sd = const()[name=string(\"sd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor dq = slice_by_size(x=xh,begin=b0,size=sd)[name=string(\"dq\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,0,0,%d])];\n", SEQ]; + [m appendFormat:@" tensor dk = slice_by_size(x=xh,begin=b1,size=sd)[name=string(\"dk\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,0,0,%d])];\n", 2*SEQ]; + [m appendFormat:@" tensor dv = slice_by_size(x=xh,begin=b2,size=sd)[name=string(\"dv\")];\n", DIM, SEQ]; + + // Slice Wq^T, Wk^T, Wv^T + [m appendFormat:@" tensor sw = const()[name=string(\"sw\"), val=tensor([1,%d,1,%d])];\n", DIM, DIM]; + [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,0,0,%d])];\n", 3*SEQ]; + [m appendFormat:@" tensor Wqt = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"Wqt\")];\n", DIM, DIM]; + [m appendFormat:@" tensor b4 = const()[name=string(\"b4\"), val=tensor([0,0,0,%d])];\n", 3*SEQ+DIM]; + [m appendFormat:@" tensor Wkt = slice_by_size(x=xh,begin=b4,size=sw)[name=string(\"Wkt\")];\n", DIM, DIM]; + [m appendFormat:@" tensor b5 = const()[name=string(\"b5\"), val=tensor([0,0,0,%d])];\n", 3*SEQ+2*DIM]; + [m appendFormat:@" tensor Wvt = slice_by_size(x=xh,begin=b5,size=sw)[name=string(\"Wvt\")];\n", DIM, DIM]; + + [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; + [m appendString:@" bool bF = 
const()[name=string(\"bF\"), val=bool(false)];\n"]; + + // Reshape and matmul for each + [m appendFormat:@" tensor rd = const()[name=string(\"rd\"), val=tensor([1,1,%d,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor([1,1,%d,%d])];\n", DIM, DIM]; + + // dq @ Wq^T + [m appendFormat:@" tensor dq2 = reshape(shape=rd,x=dq)[name=string(\"dq2\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor dqt = transpose(perm=pm,x=dq2)[name=string(\"dqt\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor Wqt2 = reshape(shape=rw,x=Wqt)[name=string(\"Wqt2\")];\n", DIM, DIM]; + [m appendFormat:@" tensor dxq = matmul(transpose_x=bF,transpose_y=bF,x=dqt,y=Wqt2)[name=string(\"dxq\")];\n", SEQ, DIM]; + + // dk @ Wk^T + [m appendFormat:@" tensor dk2 = reshape(shape=rd,x=dk)[name=string(\"dk2\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor dkt = transpose(perm=pm,x=dk2)[name=string(\"dkt\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor Wkt2 = reshape(shape=rw,x=Wkt)[name=string(\"Wkt2\")];\n", DIM, DIM]; + [m appendFormat:@" tensor dxk = matmul(transpose_x=bF,transpose_y=bF,x=dkt,y=Wkt2)[name=string(\"dxk\")];\n", SEQ, DIM]; + + // dv @ Wv^T + [m appendFormat:@" tensor dv2 = reshape(shape=rd,x=dv)[name=string(\"dv2\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor dvt = transpose(perm=pm,x=dv2)[name=string(\"dvt\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor Wvt2 = reshape(shape=rw,x=Wvt)[name=string(\"Wvt2\")];\n", DIM, DIM]; + [m appendFormat:@" tensor dxv = matmul(transpose_x=bF,transpose_y=bF,x=dvt,y=Wvt2)[name=string(\"dxv\")];\n", SEQ, DIM]; + + // Sum: dxq + dxk + dxv + [m appendFormat:@" tensor dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", SEQ, DIM]; + [m appendFormat:@" tensor dxall = add(x=dxqk,y=dxv)[name=string(\"aall\")];\n", SEQ, DIM]; + + [m appendFormat:@" tensor dxt = transpose(perm=pm,x=dxall)[name=string(\"dxt\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor ro = const()[name=string(\"ro\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + 
[m appendFormat:@" tensor dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ]; + [m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; + [m appendFormat:@" tensor y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ]; + [m appendString:@" } -> (y);\n}\n"]; + return m; +} + +// Causal mask blob (used by sdpa_fwd and sdpa_bwd1) +static NSData *g_mask_blob = nil; +static NSData *get_mask_blob(void) { + if (!g_mask_blob) { + _Float16 *mask = (_Float16*)calloc(SEQ*SEQ, sizeof(_Float16)); + for(int t=0;t