ANE/training/train_large.m

1006 lines
67 KiB
Objective-C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// train_large.m — Train a single transformer layer FULLY on ANE
// 6 ANE kernels per step:
// Forward: kFwdAttn (QKV+SDPA+Wo, taps Q,K,V,attn_out) + kFwdFFN (W1+W3+SiLU+W2, taps h1,h3,silu_out)
// Backward: kFFNBwd (W2^T+SiLU_bwd+W1^T+W3^T) + kSdpaBwd1 (Wo^T+SDPA) + kSdpaBwd2 + kQKVb (Wq^T+Wk^T+Wv^T)
// CPU: RMSNorm (fwd+bwd), residuals, loss, dW accumulation (cblas), SGD update
// NO CPU recompute of Q,K,V,h1,h3 — all exposed via forward taps
#import <Foundation/Foundation.h>
#import <objc/runtime.h>
#import <objc/message.h>
#import <dlfcn.h>
#import <IOSurface/IOSurface.h>
#import <mach/mach_time.h>
#import <Accelerate/Accelerate.h>
#include <math.h>
#include <unistd.h>
#include <dispatch/dispatch.h>
// Compile-time model configuration. These values are formatted directly into
// the tensor shapes of the generated MIL programs below.
// DIM: channel count of the layer's activations and of the square
// attention projection weights (Wq/Wk/Wv/Wo are [DIM,DIM,1,1]).
#define DIM 768
// HIDDEN: inner width of the FFN (W1/W3 are [HIDDEN,DIM,1,1], W2 is [DIM,HIDDEN,1,1]).
#define HIDDEN 2048
// HEADS: number of attention heads; HD: channels per head.
#define HEADS 12
#define HD (DIM/HEADS)
// SEQ: sequence length (last tensor axis; the causal mask is SEQ x SEQ).
#define SEQ 512
// NOTE(review): the four macros below are not referenced in the part of the
// file visible here; the descriptions are inferred from their names — confirm
// against the training loop.
// Presumably the number of steps over which gradients are accumulated.
#define ACCUM_STEPS 100
// Presumably an upper bound on kernel compilations (see g_compile_count).
#define MAX_COMPILES 100
// Presumably the number of ANE kernels compiled per weight update.
#define NUM_KERNELS 6
// Presumably the checkpoint file location.
#define CKPT_PATH "/tmp/ane_large_ckpt.bin"
// Private ANE classes resolved by name in ane_init():
//   g_D   = _ANEInMemoryModelDescriptor
//   g_I   = _ANEInMemoryModel
//   g_AR  = _ANERequest
//   g_AIO = _ANEIOSurfaceObject
static Class g_D, g_I, g_AR, g_AIO;
// Mach timebase ratio read by tb_ms(). It is not populated anywhere in the
// visible part of the file — presumably filled via mach_timebase_info() at
// startup; confirm.
static mach_timebase_info_data_t g_tb;
// Count of successful compile+load cycles; bumped atomically in compile_kern_mil_w().
static int g_compile_count = 0;
// Loads the private AppleNeuralEngine framework and caches the four private
// classes this file talks to (g_D, g_I, g_AR, g_AIO).
//
// The lookups are best-effort: on failure the globals stay nil and later
// objc_msgSend calls on them return nil. To make that failure mode visible at
// its source, a diagnostic is written to stderr when the framework cannot be
// loaded or when any class is missing. Control flow is otherwise unchanged.
static void ane_init(void) {
    static const char *const kANEFrameworkPath =
        "/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine";
    if (!dlopen(kANEFrameworkPath, RTLD_NOW)) {
        const char *why = dlerror();
        fprintf(stderr, "[ane_init] dlopen failed: %s\n", why ? why : "unknown error");
    }
    g_D   = NSClassFromString(@"_ANEInMemoryModelDescriptor");
    g_I   = NSClassFromString(@"_ANEInMemoryModel");
    g_AR  = NSClassFromString(@"_ANERequest");
    g_AIO = NSClassFromString(@"_ANEIOSurfaceObject");
    if (!g_D || !g_I || !g_AR || !g_AIO) {
        fprintf(stderr,
                "[ane_init] missing private ANE class(es):%s%s%s%s\n",
                g_D   ? "" : " _ANEInMemoryModelDescriptor",
                g_I   ? "" : " _ANEInMemoryModel",
                g_AR  ? "" : " _ANERequest",
                g_AIO ? "" : " _ANEIOSurfaceObject");
    }
}
// Converts a tick count `t` to milliseconds using the numer/denom ratio held
// in g_tb (ticks * numer / denom gives nanoseconds). g_tb must already be
// populated; with a zero denom the result is not finite.
static double tb_ms(uint64_t t) {
    double nanos = (double)t * g_tb.numer / g_tb.denom;
    return nanos / 1e6;
}
// Creates an IOSurface laid out as a single row of `bytes` one-byte elements
// (width = bytes, height = 1, bytes-per-row = bytes, pixel format 0).
// Returns whatever IOSurfaceCreate returns; the caller owns the reference.
static IOSurfaceRef make_surface(size_t bytes) {
    NSNumber *byteCount = @(bytes);
    NSDictionary *props = @{
        (id)kIOSurfaceWidth:           byteCount,
        (id)kIOSurfaceHeight:          @1,
        (id)kIOSurfaceBytesPerElement: @1,
        (id)kIOSurfaceBytesPerRow:     byteCount,
        (id)kIOSurfaceAllocSize:       byteCount,
        (id)kIOSurfacePixelFormat:     @0,
    };
    return IOSurfaceCreate((__bridge CFDictionaryRef)props);
}
// Packs a row-major fp32 matrix `w` (rows x cols) into a weight blob:
// a 128-byte header followed by rows*cols fp16 values in the same order.
//
// Header bytes written (all others are zero):
//   [0]=1, [4]=2, [64..67]=EF BE AD DE, [68]=1,
//   uint32 at 72 = payload size in bytes, uint32 at 80 = 128 (payload offset).
// The MIL programs in this file reference these blobs with offset=uint64(64).
//
// Returns nil if a dimension is negative, `w` is NULL with a non-empty shape,
// the payload does not fit the 32-bit size field, or allocation fails.
// The returned NSData owns the buffer and frees it.
static NSData *build_blob(const float *w, int rows, int cols) {
    if (rows < 0 || cols < 0) return nil;
    const size_t count = (size_t)rows * (size_t)cols;
    if (count > 0 && !w) return nil;
    const size_t payload = count * sizeof(_Float16);
    if (payload > UINT32_MAX) return nil;   // size field at byte 72 is 32-bit
    const size_t total = 128 + payload;

    uint8_t *b = (uint8_t *)calloc(total, 1);
    if (!b) return nil;

    b[0] = 1; b[4] = 2;
    b[64] = 0xEF; b[65] = 0xBE; b[66] = 0xAD; b[67] = 0xDE; b[68] = 1;
    *(uint32_t *)(b + 72) = (uint32_t)payload;
    *(uint32_t *)(b + 80) = 128;

    _Float16 *fp16 = (_Float16 *)(b + 128);
    for (size_t i = 0; i < count; i++) fp16[i] = (_Float16)w[i];

    return [NSData dataWithBytesNoCopy:b length:total freeWhenDone:YES];
}
// Like build_blob(), but stores the transpose: the row-major fp32 matrix `w`
// (rows x cols) is written as a cols x rows fp16 matrix, i.e. element (i,j)
// of `w` lands at payload index j*rows + i.
//
// Header layout (128 bytes, all other bytes zero):
//   [0]=1, [4]=2, [64..67]=EF BE AD DE, [68]=1,
//   uint32 at 72 = payload size in bytes, uint32 at 80 = 128 (payload offset).
//
// Returns nil if a dimension is negative, `w` is NULL with a non-empty shape,
// the payload does not fit the 32-bit size field, or allocation fails.
// The returned NSData owns the buffer and frees it.
static NSData *build_blob_t(const float *w, int rows, int cols) {
    if (rows < 0 || cols < 0) return nil;
    const size_t nr = (size_t)rows, nc = (size_t)cols;
    const size_t count = nr * nc;
    if (count > 0 && !w) return nil;
    const size_t payload = count * sizeof(_Float16);
    if (payload > UINT32_MAX) return nil;   // size field at byte 72 is 32-bit
    const size_t total = 128 + payload;

    uint8_t *b = (uint8_t *)calloc(total, 1);
    if (!b) return nil;

    b[0] = 1; b[4] = 2;
    b[64] = 0xEF; b[65] = 0xBE; b[66] = 0xAD; b[67] = 0xDE; b[68] = 1;
    *(uint32_t *)(b + 72) = (uint32_t)payload;
    *(uint32_t *)(b + 80) = 128;

    _Float16 *fp16 = (_Float16 *)(b + 128);
    for (size_t i = 0; i < nr; i++) {
        for (size_t j = 0; j < nc; j++) {
            fp16[j * nr + i] = (_Float16)w[i * nc + j];
        }
    }

    return [NSData dataWithBytesNoCopy:b length:total freeWhenDone:YES];
}
// Wraps `cnt` already-converted fp16 values from `d` in the same 128-byte
// blob header used by build_blob(); the payload is copied verbatim.
//
// Header layout (all other bytes zero):
//   [0]=1, [4]=2, [64..67]=EF BE AD DE, [68]=1,
//   uint32 at 72 = payload size in bytes, uint32 at 80 = 128 (payload offset).
//
// Returns nil if `cnt` is negative, `d` is NULL with cnt > 0, the payload
// does not fit the 32-bit size field, or allocation fails.
// The returned NSData owns the buffer and frees it.
static NSData *build_blob_fp16(_Float16 *d, int cnt) {
    if (cnt < 0) return nil;
    if (cnt > 0 && !d) return nil;
    const size_t payload = (size_t)cnt * sizeof(_Float16);
    if (payload > UINT32_MAX) return nil;   // size field at byte 72 is 32-bit
    const size_t total = 128 + payload;

    uint8_t *b = (uint8_t *)calloc(total, 1);
    if (!b) return nil;

    b[0] = 1; b[4] = 2;
    b[64] = 0xEF; b[65] = 0xBE; b[66] = 0xAD; b[67] = 0xDE; b[68] = 1;
    *(uint32_t *)(b + 72) = (uint32_t)payload;
    *(uint32_t *)(b + 80) = 128;

    if (payload) memcpy(b + 128, d, payload);

    return [NSData dataWithBytesNoCopy:b length:total freeWhenDone:YES];
}
// ===== MIL generators =====
// MIL_HDR: prologue shared by every generated MIL program — the
// `program(1.3)` line, the buildInfo dictionary and the opening brace.
// It expands to an NSString literal (note the leading @), so call sites pass
// it directly to -appendString:.
#define MIL_HDR \
@"program(1.3)\n[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, " \
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " \
"{\"coremltools-version\", \"9.0\"}})]\n{\n"
// CONV_CONST: declares the MIL constants that every conv() call in this file
// passes as attributes: pt (pad_type "valid"), st (strides [1,1]),
// pd (pad [0,0,0,0]), dl (dilations [1,1]) and gr (groups 1).
// Unlike MIL_HDR this is a plain C string literal, which is why call sites
// write `@CONV_CONST`.
#define CONV_CONST \
" string pt = const()[name=string(\"pt\"), val=string(\"valid\")];\n" \
" tensor<int32, [2]> st = const()[name=string(\"st\"), val=tensor<int32, [2]>([1,1])];\n" \
" tensor<int32, [4]> pd = const()[name=string(\"pd\"), val=tensor<int32, [4]>([0,0,0,0])];\n" \
" tensor<int32, [2]> dl = const()[name=string(\"dl\"), val=tensor<int32, [2]>([1,1])];\n" \
" int32 gr = const()[name=string(\"gr\"), val=int32(1)];\n"
// Attention forward kernel with taps (MIL text generator).
//
// Input : x, fp16 [1, DIM, 1, SEQ].
// Graph : RMSNorm(x) scaled by rms1.bin -> xn
//         Q/K/V = 1x1 conv of xn with Wq/Wk/Wv
//         per-head softmax(Q.K^T * 1/sqrt(HD) + mask) . V
//         output projection with Wo.
// Output: fp16 [1, 6*DIM, 1, SEQ] = concat along axis 1 of
//         (Wo output, Q, K, V, pre-Wo attention output, xn).
// Returns the complete MIL program text.
static NSString *gen_sdpa_fwd_taps(void) {
    const float attnScale = 1.0f/sqrtf((float)HD);
    const float invDim = 1.0f/(float)DIM;
    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, SEQ];

    // RMSNorm over the channel axis: xn = x * (mean(x^2) + eps)^-0.5 * rw
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ];
    [mil appendString:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([1])];\n"];
    [mil appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ];
    [mil appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invDim];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ];
    [mil appendString:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ];
    [mil appendString:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,1]> rw = const()[name=string(\"rw\"), val=tensor<fp16, [1,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/rms1.bin\"), offset=uint64(64)))];\n", DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ];

    // Projection weights and the Q/K/V 1x1 convolutions on xn
    [mil appendString:@CONV_CONST];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wq = const()[name=string(\"Wq\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wq.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wk = const()[name=string(\"Wk\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wk.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wv = const()[name=string(\"Wv\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wv.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wo = const()[name=string(\"Wo\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wo.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wq,x=xn)[name=string(\"cq\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wk,x=xn)[name=string(\"ck\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wv,x=xn)[name=string(\"cv\")];\n", DIM,SEQ];

    // Split channels into heads: [1,DIM,1,SEQ] -> [1,HEADS,HD,SEQ] -> [1,HEADS,SEQ,HD]
    [mil appendFormat:@" tensor<int32, [4]> qsh = const()[name=string(\"qsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q4 = reshape(shape=qsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=q4)[name=string(\"tq\")];\n", HEADS,SEQ,HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k4 = reshape(shape=qsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", HEADS,SEQ,HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS,SEQ,HD];

    // Scaled, masked scores -> softmax -> weighted sum of V
    [mil appendString:@" bool tx = const()[name=string(\"tx\"), val=bool(false)];\n"];
    [mil appendString:@" bool ty = const()[name=string(\"ty\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=tx,transpose_y=ty,x=q,y=k)[name=string(\"mm1\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", attnScale];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ,SEQ,SEQ,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS,SEQ,SEQ];
    [mil appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> a4 = matmul(transpose_x=tx,transpose_y=tx,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS,SEQ,HD];

    // Merge heads back to [1,DIM,1,SEQ] and apply the output projection
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<int32, [4]> os = const()[name=string(\"os\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> oo = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wo,x=af)[name=string(\"co\")];\n", DIM,SEQ];

    // Single output: result plus the taps, stacked on the channel axis
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM,SEQ];
    [mil appendString:@" } -> (out);\n}\n"];
    return mil;
}
// FFN forward kernel with taps (MIL text generator).
//
// Input : x, fp16 [1, DIM, 1, SEQ].
// Graph : RMSNorm(x) scaled by rms2.bin -> xn
//         h1 = W1 * xn, h3 = W3 * xn (1x1 convs)
//         gate = (h1 * sigmoid(h1)) * h3
//         y = W2 * gate.
// Output: fp16 [1, 2*DIM + 3*HIDDEN, 1, SEQ] = concat along axis 1 of
//         (y, h1, h3, gate, xn). Note the fourth tap is the gated product
//         silu(h1)*h3, not silu(h1) alone.
// Returns the complete MIL program text.
static NSString *gen_ffn_fwd_taps(void) {
    const float invDim = 1.0f/(float)DIM;
    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, SEQ];

    // RMSNorm over the channel axis: xn = x * (mean(x^2) + eps)^-0.5 * rw
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ];
    [mil appendString:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([1])];\n"];
    [mil appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ];
    [mil appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invDim];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ];
    [mil appendString:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ];
    [mil appendString:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,1]> rw = const()[name=string(\"rw\"), val=tensor<fp16, [1,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/rms2.bin\"), offset=uint64(64)))];\n", DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ];

    // FFN weights
    [mil appendString:@CONV_CONST];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W1 = const()[name=string(\"W1\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w1.bin\"), offset=uint64(64)))];\n", HIDDEN,DIM,HIDDEN,DIM];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W3 = const()[name=string(\"W3\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w3.bin\"), offset=uint64(64)))];\n", HIDDEN,DIM,HIDDEN,DIM];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W2 = const()[name=string(\"W2\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w2.bin\"), offset=uint64(64)))];\n", DIM,HIDDEN,DIM,HIDDEN];

    // Up projections, SiLU gate, down projection
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> h1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1,x=xn)[name=string(\"c1\")];\n", HIDDEN,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> h3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3,x=xn)[name=string(\"c3\")];\n", HIDDEN,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> y = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2,x=gate)[name=string(\"c2\")];\n", DIM,SEQ];

    // Single output: result plus the taps, stacked on the channel axis
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(y,h1,h3,gate,xn))[name=string(\"cat\")];\n", 2*DIM+3*HIDDEN,SEQ];
    [mil appendString:@" } -> (out);\n}\n"];
    return mil;
}
// Fused FFN backward kernel (MIL text generator).
//
// Input : fp16 [1, DIM + 2*HIDDEN, 1, SEQ] = concat(dffn[DIM], h1[HIDDEN], h3[HIDDEN]).
// Graph : dsilu = W2t * dffn
//         sig   = sigmoid(h1)
//         dh1   = dsilu * h3 * sig * (1 + h1 * (1 - sig))
//         dh3   = dsilu * (h1 * sig)
//         dx    = W1t * dh1 + W3t * dh3
// Output: fp16 [1, DIM + 2*HIDDEN, 1, SEQ] = concat(dx, dh1, dh3).
// Returns the complete MIL program text.
static NSString *gen_ffn_bwd(void) {
    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM+2*HIDDEN, SEQ];
    [mil appendString:@CONV_CONST];

    // Unpack the three inputs from the channel axis
    [mil appendString:@" tensor<int32, [4]> bd = const()[name=string(\"bd\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<int32, [4]> sd = const()[name=string(\"sd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dffn = slice_by_size(x=x,begin=bd,size=sd)[name=string(\"s0\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
    [mil appendFormat:@" tensor<int32, [4]> s1 = const()[name=string(\"s1\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> h1 = slice_by_size(x=x,begin=b1,size=s1)[name=string(\"s1x\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM+HIDDEN];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> h3 = slice_by_size(x=x,begin=b3,size=s1)[name=string(\"s3x\")];\n", HIDDEN, SEQ];

    // Back through W2
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W2t = const()[name=string(\"W2t\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w2t.bin\"), offset=uint64(64)))];\n", HIDDEN, DIM, HIDDEN, DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dsilu = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2t,x=dffn)[name=string(\"cw2\")];\n", HIDDEN, SEQ];

    // SiLU derivative: sig * (1 + h1 * (1 - sig))
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ];
    [mil appendString:@" fp16 one = const()[name=string(\"one\"), val=fp16(1.0)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> oms = sub(x=one,y=sig)[name=string(\"oms\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> homs = mul(x=h1,y=oms)[name=string(\"homs\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> brk = add(x=one,y=homs)[name=string(\"brk\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dsd = mul(x=sig,y=brk)[name=string(\"dsd\")];\n", HIDDEN, SEQ];

    // Gradients w.r.t. the two up-projection outputs
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> t1 = mul(x=dsilu,y=h3)[name=string(\"t1\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh1 = mul(x=t1,y=dsd)[name=string(\"dh1\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> slh = mul(x=h1,y=sig)[name=string(\"slh\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh3 = mul(x=dsilu,y=slh)[name=string(\"dh3\")];\n", HIDDEN, SEQ];

    // Back through W1 and W3, summed into the input gradient
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W1t = const()[name=string(\"W1t\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w1t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W3t = const()[name=string(\"W3t\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w3t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1t,x=dh1)[name=string(\"cw1\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3t,x=dh3)[name=string(\"cw3\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = add(x=dx1,y=dx3)[name=string(\"adx\")];\n", DIM, SEQ];

    // Pack the three gradients on the channel axis
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dx,dh1,dh3))[name=string(\"cat\")];\n", DIM+2*HIDDEN, SEQ];
    [mil appendString:@" } -> (out);\n}\n"];
    return mil;
}
// Fused QKV backward kernel (MIL text generator).
//
// Input : fp16 [1, 3*DIM, 1, SEQ] = concat(dq, dk, dv).
// Graph : out = Wqt * dq + Wkt * dk + Wvt * dv (three 1x1 convs, summed).
// Output: fp16 [1, DIM, 1, SEQ].
// Returns the complete MIL program text.
static NSString *gen_qkvb(void) {
    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", 3*DIM, SEQ];
    [mil appendString:@CONV_CONST];

    // Unpack dq / dk / dv from the channel axis
    [mil appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dq = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dk = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dv = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM,SEQ];

    // Transposed projection weights
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wqt = const()[name=string(\"Wqt\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wqt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wkt = const()[name=string(\"Wkt\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wkt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wvt = const()[name=string(\"Wvt\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wvt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];

    // Per-branch input gradients, then their sum
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxq = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wqt,x=dq)[name=string(\"cq\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxk = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wkt,x=dk)[name=string(\"ck\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxv = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wvt,x=dv)[name=string(\"cv\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = add(x=dxqk,y=dxv)[name=string(\"out\")];\n", DIM,SEQ];
    [mil appendString:@" } -> (out);\n}\n"];
    return mil;
}
// SCORE_CH: channel count used when the per-head attention matrices
// [HEADS,SEQ,SEQ] are flattened to [HEADS*SEQ, SEQ] for transport between kernels.
#define SCORE_CH (HEADS*SEQ)
// SDPA backward, part 1, fused with the Wo^T backward (MIL text generator).
//
// Input : fp16 [1, 4*DIM, 1, SEQ] = concat(Q, K, V, dx2).
// Graph : da    = Wot * dx2, reshaped per head
//         probs = softmax(Q.K^T * 1/sqrt(HD) + mask)   (recomputed here)
//         dV    = probs^T . da
//         dp    = da . V^T
// Output: fp16 [1, DIM + 2*SCORE_CH, 1, SEQ] =
//         concat(dV flattened, probs flattened, dp flattened).
// Returns the complete MIL program text.
static NSString *gen_sdpa_bwd1(void) {
    const float attnScale = 1.0f/sqrtf((float)HD);
    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", 4*DIM, SEQ];
    [mil appendString:@CONV_CONST];

    // Unpack Q, K, V and dx2 from the channel axis
    [mil appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 3*DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx2f = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM,SEQ];

    // Back through the output projection: dx2 -> gradient of the attention output
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wot = const()[name=string(\"Wot\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wot.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> df = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wot,x=dx2f)[name=string(\"cwo\")];\n", DIM,SEQ];

    // Split channels into heads: [1,DIM,1,SEQ] -> [1,HEADS,HD,SEQ] -> [1,HEADS,SEQ,HD]
    [mil appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS,SEQ,HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS,SEQ,HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> vr = reshape(shape=rsh,x=vf)[name=string(\"rv\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=vr)[name=string(\"tv\")];\n", HEADS,SEQ,HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dr = reshape(shape=rsh,x=df)[name=string(\"rd\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> da = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS,SEQ,HD];

    // Recompute the attention probabilities
    [mil appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [mil appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", attnScale];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ,SEQ,SEQ,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS,SEQ,SEQ];
    [mil appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS,SEQ,SEQ];

    // dV = probs^T . da ; dp = da . V^T
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=da)[name=string(\"dv\")];\n", HEADS,SEQ,HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp4 = matmul(transpose_x=bF,transpose_y=bT,x=da,y=v)[name=string(\"dp\")];\n", HEADS,SEQ,SEQ];

    // Merge dV heads back to [1,DIM,1,SEQ]
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<int32, [4]> dvs = const()[name=string(\"dvs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM,SEQ];

    // Flatten probs and dp: [1,HEADS,SEQ,SEQ] -> [1,SCORE_CH,1,SEQ]
    [mil appendFormat:@" tensor<int32, [4]> scs = const()[name=string(\"scs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH,SEQ];

    // Pack the three results on the channel axis
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", DIM+2*SCORE_CH,SEQ];
    [mil appendString:@" } -> (out);\n}\n"];
    return mil;
}
// SDPA backward, part 2 (MIL text generator).
//
// Input : fp16 [1, 2*SCORE_CH + 2*DIM, 1, SEQ] =
//         concat(probs[SCORE_CH], dp[SCORE_CH], Q[DIM], K[DIM]).
// Graph : ds = probs * (dp - sum_over_last_axis(probs * dp)) * 1/sqrt(HD)
//         dQ = ds . K
//         dK = ds^T . Q
// Output: fp16 [1, 2*DIM, 1, SEQ] = concat(dQ, dK), heads merged back into channels.
// Returns the complete MIL program text.
static NSString *gen_sdpa_bwd2(void) {
    const float attnScale = 1.0f/sqrtf((float)HD);
    const int inChannels = 2*SCORE_CH + 2*DIM;
    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", inChannels, SEQ];

    // Unpack flattened probs
    [mil appendFormat:@" tensor<int32, [4]> sz_sc = const()[name=string(\"szsc\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
    [mil appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH,SEQ];

    // Unpack flattened dp
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", SCORE_CH];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH,SEQ];

    // Unpack Q
    [mil appendFormat:@" tensor<int32, [4]> sz_d = const()[name=string(\"szd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM,SEQ];

    // Unpack K
    [mil appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH+DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM,SEQ];

    // Restore the per-head layouts
    [mil appendFormat:@" tensor<int32, [4]> ssh = const()[name=string(\"ssh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS,SEQ,HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS,SEQ,HD];

    // Softmax gradient: ds = probs * (dp - sum(probs*dp)) * scale
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS,SEQ,SEQ];
    [mil appendString:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([-1])];\n"];
    [mil appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,1]> spdp = reduce_sum(x=pdp,axes=rax,keep_dims=kd)[name=string(\"rs\")];\n", HEADS,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dps = sub(x=dp,y=spdp)[name=string(\"dps\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", attnScale];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS,SEQ,SEQ];

    // dQ = ds . K ; dK = ds^T . Q
    [mil appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [mil appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS,SEQ,HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS,SEQ,HD];

    // Merge heads back to [1,DIM,1,SEQ]
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<int32, [4]> fs = const()[name=string(\"fs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM,SEQ];

    // Pack dQ and dK on the channel axis
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM,SEQ];
    [mil appendString:@" } -> (out);\n}\n"];
    return mil;
}
// ===== Weight builders =====
// Cached causal-mask blob; built on first use by get_mask_blob().
static NSData *g_mask_blob = nil;

// Returns the SEQ x SEQ additive causal mask as an fp16 weight blob.
// Entry (t, t2) is 0 when t2 <= t and -65504 (the most negative finite fp16
// value) when t2 > t. The blob is built once and cached in g_mask_blob.
//
// Returns nil, without caching, if the scratch buffer cannot be allocated or
// build_blob_fp16() fails, so a later call can retry.
static NSData *get_mask_blob(void) {
    if (g_mask_blob) return g_mask_blob;

    _Float16 *mask = (_Float16 *)calloc((size_t)SEQ * SEQ, sizeof(_Float16));
    if (!mask) return nil;

    // calloc already zeroed the allowed positions (t2 <= t); only the
    // strictly-upper triangle needs to be filled in.
    for (int t = 0; t < SEQ; t++) {
        for (int t2 = t + 1; t2 < SEQ; t2++) {
            mask[t * SEQ + t2] = (_Float16)(-65504.0f);
        }
    }

    g_mask_blob = build_blob_fp16(mask, SEQ * SEQ);
    free(mask);
    return g_mask_blob;
}
// ===== Kernel compilation and evaluation =====
// Handle for one compiled and loaded ANE kernel, as filled in by compile_kern_mil_w():
//   model        — the _ANEInMemoryModel instance, held via CFBridgingRetain
//   ioIn / ioOut — IOSurfaces created with make_surface() for the single input and output
//   request      — the _ANERequest binding ioIn/ioOut, held via CFBridgingRetain
//   tmpDir       — not assigned in the visible part of the file; presumably the
//                  retained temp-directory path used for the model files (TODO confirm)
typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; } Kern;
static Kern *compile_kern_mil_w(NSString *mil, NSDictionary *weights, int ic_bytes, int oc_bytes) {
@autoreleasepool {
NSData *md = [mil dataUsingEncoding:NSUTF8StringEncoding];
id desc = ((id(*)(Class,SEL,id,id,id))objc_msgSend)(g_D, @selector(modelWithMILText:weights:optionsPlist:), md, weights, nil);
if (!desc) { printf(" [compile] desc=NULL\n"); return NULL; }
id mdl = ((id(*)(Class,SEL,id))objc_msgSend)(g_I, @selector(inMemoryModelWithDescriptor:), desc);
id hx = ((id(*)(id,SEL))objc_msgSend)(mdl, @selector(hexStringIdentifier));
NSString *td = [NSTemporaryDirectory() stringByAppendingPathComponent:hx];
[[NSFileManager defaultManager] createDirectoryAtPath:[td stringByAppendingPathComponent:@"weights"] withIntermediateDirectories:YES attributes:nil error:nil];
[md writeToFile:[td stringByAppendingPathComponent:@"model.mil"] atomically:YES];
for (NSString *path in weights) {
NSString *rel = [path stringByReplacingOccurrencesOfString:@"@model_path/" withString:@""];
[weights[path][@"data"] writeToFile:[td stringByAppendingPathComponent:rel] atomically:YES];
}
NSError *e = nil;
if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(compileWithQoS:options:error:), 21, @{}, &e)) {
printf(" [compile] FAIL: %s\n", e ? [[e description] UTF8String] : "no error"); return NULL;
}
if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e)) {
printf(" [compile] load FAIL\n"); return NULL;
}
__sync_fetch_and_add(&g_compile_count, 1);
Kern *k = calloc(1, sizeof(Kern));
k->model = CFBridgingRetain(mdl);
k->ioIn = make_surface(ic_bytes);
k->ioOut = make_surface(oc_bytes);
id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn);
id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
k->request = CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
@selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
@[wI], @[@0], @[wO], @[@0], nil, nil, @0));
k->tmpDir = CFBridgingRetain(td);
return k;
}
}
// Unloads the model, releases the surfaces/request, deletes the kernel's temp
// directory and frees the Kern itself. Accepts NULL and tolerates a Kern whose
// members were never populated (CFRelease must not be called with NULL).
static void free_kern(Kern *k) {
    if (!k) return;
    if (k->model) {
        id mdl = (__bridge id)k->model; NSError *e = nil;
        ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e);
    }
    if (k->ioIn) CFRelease(k->ioIn);
    if (k->ioOut) CFRelease(k->ioOut);
    if (k->tmpDir)
        [[NSFileManager defaultManager] removeItemAtPath:(__bridge id)k->tmpDir error:nil];
    if (k->model) CFRelease(k->model);
    if (k->request) CFRelease(k->request);
    if (k->tmpDir) CFRelease(k->tmpDir);
    free(k);
}
// Runs the kernel's request once (input read from k->ioIn, result written to k->ioOut).
// A failed evaluation is reported on stderr instead of being silently ignored;
// the caller is not interrupted.
static void ane_eval(Kern *k) {
    id mdl = (__bridge id)k->model;
    id req = (__bridge id)k->request;
    NSError *e = nil;
    BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
    if (!ok)
        fprintf(stderr, " [eval] FAIL: %s\n", e ? [[e description] UTF8String] : "no error");
}
// ===== Vectorized conversion helpers (NEON) =====
#include <arm_neon.h>
// Widens n half-precision values from src into single-precision dst.
// Full groups of 8 go through NEON; the remainder is converted one by one.
static void cvt_f16_f32(float *dst, const _Float16 *src, int n) {
    int done = 0;
    while (n - done >= 8) {
        const float16x8_t packed = vld1q_f16((const __fp16*)(src + done));
        const float32x4_t lo = vcvt_f32_f16(vget_low_f16(packed));
        const float32x4_t hi = vcvt_f32_f16(vget_high_f16(packed));
        vst1q_f32(dst + done, lo);
        vst1q_f32(dst + done + 4, hi);
        done += 8;
    }
    while (done < n) {
        dst[done] = (float)src[done];
        done++;
    }
}
// Narrows n single-precision values from src into half-precision dst.
// Full groups of 8 go through NEON; the remainder is converted one by one.
static void cvt_f32_f16(_Float16 *dst, const float *src, int n) {
    int done = 0;
    while (n - done >= 8) {
        const float16x4_t lo = vcvt_f16_f32(vld1q_f32(src + done));
        const float16x4_t hi = vcvt_f16_f32(vld1q_f32(src + done + 4));
        vst1q_f16((__fp16*)(dst + done), vcombine_f16(lo, hi));
        done += 8;
    }
    while (done < n) {
        dst[done] = (_Float16)src[done];
        done++;
    }
}
// ===== IOSurface I/O helpers (channel-first, no transpose) =====
// All CPU buffers are [C,S] channel-first matching IOSurface [1,C,1,S]
// Converts channels*sp fp32 values to fp16 and stores them at the start of the
// surface. Layout is kept as-is (no transpose); the surface is locked for the copy.
static void io_write_fp16(IOSurfaceRef s, const float *data, int channels, int sp) {
    const int count = channels * sp;
    IOSurfaceLock(s, 0, NULL);
    cvt_f32_f16((_Float16*)IOSurfaceGetBaseAddress(s), data, count);
    IOSurfaceUnlock(s, 0, NULL);
}
// Copies channels*sp fp32 values verbatim to the start of the surface.
static void io_write_fp32(IOSurfaceRef s, const float *data, int channels, int sp) {
    IOSurfaceLock(s, 0, NULL);
    void *base = IOSurfaceGetBaseAddress(s);
    memcpy(base, data, channels * sp * sizeof(float));
    IOSurfaceUnlock(s, 0, NULL);
}
// Reads channels*sp fp16 values starting ch_off channels into the surface and
// widens them to fp32 in data. The surface is locked read-only for the duration.
static void io_read_fp16(IOSurfaceRef s, float *data, int ch_off, int channels, int sp) {
    IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
    const _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    const _Float16 *from = base + ch_off * sp;
    cvt_f16_f32(data, from, channels * sp);
    IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}
// Copies channels*sp fp32 values verbatim from the start of the surface into data.
static void io_read_fp32(IOSurfaceRef s, float *data, int channels, int sp) {
    IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
    const void *base = IOSurfaceGetBaseAddress(s);
    memcpy(data, base, channels * sp * sizeof(float));
    IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}
// Packs n fp32 arrays back to back into the surface as fp16.
// Variadic arguments come in pairs: (const float *data, int channels), n times.
// Each array contributes channels*sp values, placed directly after the previous one.
static void io_write_multi_fp16(IOSurfaceRef s, int sp, int n, ...) {
    IOSurfaceLock(s, 0, NULL);
    _Float16 *cursor = (_Float16*)IOSurfaceGetBaseAddress(s);
    va_list args;
    va_start(args, n);
    int remaining = n;
    while (remaining-- > 0) {
        const float *chunk = va_arg(args, const float*);
        const int chunk_channels = va_arg(args, int);
        const int count = chunk_channels * sp;
        cvt_f32_f16(cursor, chunk, count);
        cursor += count;
    }
    va_end(args);
    IOSurfaceUnlock(s, 0, NULL);
}
// Raw fp16 copy of `channels` channels (sp values each) from src, starting at
// channel src_ch, into dst starting at channel dst_ch. No conversion is done.
// dst is locked for writing first, then src read-only; unlock happens in reverse.
static void io_copy(IOSurfaceRef dst, int dst_ch, IOSurfaceRef src, int src_ch, int channels, int sp) {
    IOSurfaceLock(dst, 0, NULL);
    IOSurfaceLock(src, kIOSurfaceLockReadOnly, NULL);
    _Float16 *to = (_Float16*)IOSurfaceGetBaseAddress(dst) + dst_ch*sp;
    const _Float16 *from = (_Float16*)IOSurfaceGetBaseAddress(src) + src_ch*sp;
    memcpy(to, from, channels * sp * sizeof(_Float16));
    IOSurfaceUnlock(src, kIOSurfaceLockReadOnly, NULL);
    IOSurfaceUnlock(dst, 0, NULL);
}
// Converts channels*sp fp32 values to fp16 and stores them ch_off channels into
// the surface (i.e. at element offset ch_off*sp).
static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int channels, int sp) {
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    cvt_f32_f16(base + ch_off * sp, data, channels * sp);
    IOSurfaceUnlock(s, 0, NULL);
}
// ===== CPU ops (channel-first [C,S] layout) =====
// x[i*S+t] = channel i, position t
// Process all positions in parallel using vectorized column ops
// Shared scratch row of S floats. Allocated on first use with that call's S and
// never resized, so later calls must not pass a larger S.
static float *g_rms_tmp = NULL;
// RMSNorm over the channel axis of a channel-first [d, S] buffer:
//   out[i,t] = x[i,t] * w[i] / sqrt(mean_i(x[i,t]^2) + 1e-5)
// x    input,  d rows of S contiguous values
// w    per-channel scale, d values
// out  result, same layout as x
static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
    const vDSP_Length len = (vDSP_Length)S;
    if (!g_rms_tmp) g_rms_tmp = malloc(S*4);
    float *sumsq = calloc(S, sizeof(float));
    // Accumulate x^2 over channels, one row at a time.
    for (int ch = 0; ch < d; ch++) {
        const float *row = x + ch*S;
        vDSP_vmul(row, 1, row, 1, g_rms_tmp, 1, len);
        vDSP_vadd(g_rms_tmp, 1, sumsq, 1, sumsq, 1, len);
    }
    // sumsq <- sumsq/d + eps, then in-place reciprocal square root.
    float invd = 1.0f/d, eps=1e-5f;
    vDSP_vsmsa(sumsq, 1, &invd, &eps, sumsq, 1, len);
    int count = S;
    vvrsqrtf(sumsq, sumsq, &count);
    for (int ch = 0; ch < d; ch++) {
        float *dst = out + ch*S;
        vDSP_vmul(x + ch*S, 1, sumsq, 1, dst, 1, len);
        vDSP_vsmul(dst, 1, &w[ch], dst, 1, len);
    }
    free(sumsq);
}
// Backward pass of RMSNorm for a channel-first [d, S] buffer.
// With r[t] = 1/sqrt(mean_i(x[i,t]^2) + 1e-5) and y[i,t] = w[i]*x[i,t]*r[t]:
//   dx[i,t]  = r[t] * ( w[i]*dy[i,t] - x[i,t] * r[t]^2/d * sum_j dy[j,t]*w[j]*x[j,t] )
//   dw[i]   += sum_t dy[i,t]*x[i,t]*r[t]
// dx  output, overwritten (not accumulated), same layout as x
// dw  accumulated in place, d values
// dy  upstream gradient, same layout as x
// Scratch space is allocated per call, so any S is accepted.
static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, const float *w, int d, int S) {
    const vDSP_Length len = (vDSP_Length)S;
    // One zeroed allocation split into four S-long rows.
    float *scratch = calloc((size_t)S * 4, sizeof(float));
    if (!scratch) { fprintf(stderr, "rmsnorm_bwd: out of memory\n"); abort(); }
    float *ss = scratch, *rrms = scratch + S, *dot = scratch + 2*(size_t)S, *tmp = scratch + 3*(size_t)S;

    // ss[t] = sum_i x[i,t]^2
    for (int i=0; i<d; i++) {
        vDSP_vmul(x+i*S, 1, x+i*S, 1, tmp, 1, len);
        vDSP_vadd(tmp, 1, ss, 1, ss, 1, len);
    }
    float invd = 1.0f/d, eps=1e-5f;
    vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, len);
    // vvrsqrtf(out, in, n): rrms = 1/sqrt(ss)
    int n = S; vvrsqrtf(rrms, ss, &n);

    // dot[t] = sum_i dy[i,t]*x[i,t]*w[i]
    for (int i=0; i<d; i++) {
        vDSP_vmul(dy+i*S, 1, x+i*S, 1, tmp, 1, len);
        vDSP_vsma(tmp, 1, &w[i], dot, 1, dot, 1, len);
    }
    // dot *= rrms^2/d
    vDSP_vmul(rrms, 1, rrms, 1, ss, 1, len);
    vDSP_vsmul(ss, 1, &invd, ss, 1, len);
    vDSP_vmul(dot, 1, ss, 1, dot, 1, len);

    for (int i=0; i<d; i++) {
        float *dxi = dx + i*S;
        // tmp = x*dot (the correction term; must NOT be scaled by w[i])
        vDSP_vmul(x+i*S, 1, dot, 1, tmp, 1, len);
        // dxi = w[i]*dy
        vDSP_vsmul(dy+i*S, 1, &w[i], dxi, 1, len);
        // vDSP_vsub(B, A, C) computes C = A - B: dxi = w[i]*dy - x*dot
        vDSP_vsub(tmp, 1, dxi, 1, dxi, 1, len);
        vDSP_vmul(dxi, 1, rrms, 1, dxi, 1, len);
        // dw[i] += sum_t dy[i,t]*x[i,t]*rrms[t]
        vDSP_vmul(dy+i*S, 1, x+i*S, 1, tmp, 1, len);
        vDSP_vmul(tmp, 1, rrms, 1, tmp, 1, len);
        float s; vDSP_sve(tmp, 1, &s, len);
        dw[i] += s;
    }
    free(scratch);
}
// ===== Checkpoint =====
// Fixed-size header written at the start of the checkpoint file; the weight and
// Adam-state arrays follow it. main() fills it right before exec()-restarting
// and reads it back when started with --resume.
typedef struct {
// step: step index to continue from; total_steps: loop bound for the whole run
int step, total_steps;
// learning rate and the most recent loss value at save time
float lr, loss;
// compile / train / wall-clock milliseconds summed over all previous process runs
double cum_compile, cum_train, cum_wall;
// steps and batches completed, summed over all previous process runs
int cum_steps, cum_batches;
int adam_t; // Adam timestep
} CkptHdr;
// Per-parameter Adam moments for one weight array.
typedef struct {
    float *m, *v; // first and second moment estimates, n values each
    size_t n;     // number of parameters covered
} AdamState;
// Allocates zero-initialized moments for n parameters.
static AdamState adam_alloc(size_t n) {
    AdamState st;
    st.m = calloc(n, 4);
    st.v = calloc(n, 4);
    st.n = n;
    return st;
}
// Releases both moment arrays; the struct itself is left untouched.
static void adam_free(AdamState *s) {
    free(s->m);
    free(s->v);
}
// Applies one bias-corrected Adam step to w in place.
//   g   gradient, s->n values
//   t   timestep used for bias correction (callers pass values >= 1)
//   lr, b1, b2, eps  the usual Adam hyper-parameters
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
    const float bc1 = 1.0f - powf(b1, t);
    const float bc2 = 1.0f - powf(b2, t);
    float *m = s->m;
    float *v = s->v;
    const size_t count = s->n;
    for (size_t k = 0; k < count; ++k) {
        m[k] = b1*m[k] + (1-b1)*g[k];
        v[k] = b2*v[k] + (1-b2)*g[k]*g[k];
        float mh = m[k]/bc1, vh = v[k]/bc2;
        w[k] -= lr * mh / (sqrtf(vh) + eps);
    }
}
// One contiguous float array stored in the checkpoint file after the header.
typedef struct { float *data; size_t count; } CkptItem;

// Reads the header followed by every item, in table order.
// Returns 1 on success, 0 when the file cannot be opened, -1 when it is short.
static int ckpt_load(const char *path, CkptHdr *hdr, const CkptItem *items, size_t n_items) {
    FILE *f = fopen(path, "rb");
    if (!f) return 0;
    bool ok = fread(hdr, sizeof(*hdr), 1, f) == 1;
    for (size_t i = 0; ok && i < n_items; i++)
        ok = fread(items[i].data, sizeof(float), items[i].count, f) == items[i].count;
    fclose(f);
    return ok ? 1 : -1;
}

// Writes the header followed by every item, in table order. Returns false on any I/O error.
static bool ckpt_save(const char *path, const CkptHdr *hdr, const CkptItem *items, size_t n_items) {
    FILE *f = fopen(path, "wb");
    if (!f) return false;
    bool ok = fwrite(hdr, sizeof(*hdr), 1, f) == 1;
    for (size_t i = 0; ok && i < n_items; i++)
        ok = fwrite(items[i].data, sizeof(float), items[i].count, f) == items[i].count;
    if (fclose(f) != 0) ok = false;
    return ok;
}

// Trains one transformer layer: forward/backward matmuls run as ANE kernels,
// RMSNorm backward, residuals, loss, dW accumulation (cblas) and Adam run on CPU.
// Weights are baked into the kernels, so they are recompiled after every Adam
// update; once the per-process compile budget (MAX_COMPILES) is used up the
// state is checkpointed to CKPT_PATH and the process re-executes itself with
// "--resume". Returns 0 on normal completion, 1 on failure.
int main(int argc, char *argv[]) {
    @autoreleasepool {
        setbuf(stdout, NULL);
        ane_init();
        mach_timebase_info(&g_tb);
        int total_steps = 400;
        float lr = 1e-3f;
        float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
        int adam_t = 0;
        int start_step = 0;
        size_t wq_sz = DIM*DIM, wo_sz = DIM*DIM;
        size_t w1_sz = HIDDEN*DIM, w2_sz = DIM*HIDDEN, w3_sz = HIDDEN*DIM;
        size_t total_params = 4*wq_sz + w1_sz + w2_sz + w3_sz;
        float *Wq=malloc(wq_sz*4), *Wk=malloc(wq_sz*4), *Wv=malloc(wq_sz*4), *Wo=malloc(wo_sz*4);
        float *W1=malloc(w1_sz*4), *W2=malloc(w2_sz*4), *W3=malloc(w3_sz*4);
        float *rms1_w=malloc(DIM*4), *rms2_w=malloc(DIM*4);
        // Adam optimizer states (m and v for each weight)
        AdamState aWq=adam_alloc(wq_sz), aWk=adam_alloc(wq_sz), aWv=adam_alloc(wq_sz), aWo=adam_alloc(wo_sz);
        AdamState aW1=adam_alloc(w1_sz), aW2=adam_alloc(w2_sz), aW3=adam_alloc(w3_sz);
        AdamState arms1=adam_alloc(DIM), arms2=adam_alloc(DIM);
        // Checkpoint payload layout. Save and load both walk this table, so the
        // on-disk order (weights first, then m/v pairs) cannot drift apart.
        const CkptItem ckpt_items[] = {
            {Wq,wq_sz},{Wk,wq_sz},{Wv,wq_sz},{Wo,wo_sz},
            {W1,w1_sz},{W2,w2_sz},{W3,w3_sz},
            {rms1_w,DIM},{rms2_w,DIM},
            {aWq.m,wq_sz},{aWq.v,wq_sz},
            {aWk.m,wq_sz},{aWk.v,wq_sz},
            {aWv.m,wq_sz},{aWv.v,wq_sz},
            {aWo.m,wo_sz},{aWo.v,wo_sz},
            {aW1.m,w1_sz},{aW1.v,w1_sz},
            {aW2.m,w2_sz},{aW2.v,w2_sz},
            {aW3.m,w3_sz},{aW3.v,w3_sz},
            {arms1.m,DIM},{arms1.v,DIM},
            {arms2.m,DIM},{arms2.v,DIM},
        };
        const size_t n_ckpt_items = sizeof(ckpt_items)/sizeof(ckpt_items[0]);
        double cum_compile=0, cum_train=0, cum_wall=0;
        int cum_steps=0, cum_batches=0;
        bool resuming = false;
        if (argc > 1 && strcmp(argv[1], "--resume") == 0) {
            CkptHdr h;
            int rc = ckpt_load(CKPT_PATH, &h, ckpt_items, n_ckpt_items);
            if (rc < 0) {
                // Buffers may be partially overwritten; refuse to train on garbage.
                fprintf(stderr, "Checkpoint %s is truncated or unreadable\n", CKPT_PATH);
                return 1;
            }
            if (rc > 0) {
                start_step=h.step; total_steps=h.total_steps; lr=h.lr;
                cum_compile=h.cum_compile; cum_train=h.cum_train; cum_wall=h.cum_wall;
                cum_steps=h.cum_steps; cum_batches=h.cum_batches; adam_t=h.adam_t;
                resuming = true;
                printf("[RESUMED step %d, loss=%.6f]\n", start_step, h.loss);
            }
        }
        if (!resuming) {
            srand48(42);
            float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
            for(size_t i=0;i<wq_sz;i++){Wq[i]=scale_d*(2*drand48()-1);Wk[i]=scale_d*(2*drand48()-1);}
            for(size_t i=0;i<wq_sz;i++){Wv[i]=scale_d*(2*drand48()-1);Wo[i]=scale_d*(2*drand48()-1);}
            for(size_t i=0;i<w1_sz;i++) W1[i]=scale_h*(2*drand48()-1);
            for(size_t i=0;i<w2_sz;i++) W2[i]=scale_d*(2*drand48()-1);
            for(size_t i=0;i<w3_sz;i++) W3[i]=scale_h*(2*drand48()-1);
            for(int i=0;i<DIM;i++){rms1_w[i]=1.0f; rms2_w[i]=1.0f;}
        }
        if (!resuming) {
            // FLOP accounting: 7 weight matrices, each 2*OC*IC*SEQ for forward
            double fwd_flops = 4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ;
            double bwd_dx_flops = fwd_flops; // same matmuls transposed
            double bwd_dw_flops = fwd_flops; // dW = dy^T @ x, same FLOPs
            double sdpa_flops = 2.0*HEADS*5*SEQ*SEQ*HD; // 5 SEQ×SEQ matmuls in backward
            double total_flops = fwd_flops + bwd_dx_flops + bwd_dw_flops + sdpa_flops;
            double ane_flops_step = fwd_flops + bwd_dx_flops + sdpa_flops;
            printf("=== ANE Training: Fully-ANE Pipeline ===\n");
            printf("dim=%d hidden=%d heads=%d seq=%d\n", DIM, HIDDEN, HEADS, SEQ);
            printf("Params: %.2fM | Weights: %.1fMB FP16\n", total_params/1e6, total_params*2.0/1e6);
            printf("Kernels: %d (fwdAttn+fwdFFN+ffnBwd+sdpaBwd1+sdpaBwd2+qkvBwd)\n", NUM_KERNELS);
            printf("Accum %d steps per recompile | Adam LR=%.1e b1=%.1f b2=%.3f\n\n", ACCUM_STEPS, lr, adam_b1, adam_b2);
            printf("FLOPs/step: fwd=%.0fM bwd_dx=%.0fM bwd_dW=%.0fM sdpa_bwd=%.0fM total=%.0fM\n",
                   fwd_flops/1e6, bwd_dx_flops/1e6, bwd_dw_flops/1e6, sdpa_flops/1e6, total_flops/1e6);
            printf("ANE FLOPs/step: %.0fM (fwd+bwd_dx+sdpa_bwd) | CPU: dW (cblas)\n\n", ane_flops_step/1e6);
        }
        // Training data in channel-first [C,S] layout
        float *x_in=malloc(SEQ*DIM*4), *y_tgt=malloc(SEQ*DIM*4);
        // Always reseed: a resumed process skips the weight-init srand48 above, and
        // without this x_in would be drawn from a different stream after every restart.
        srand48(42);
        for(int c=0;c<DIM;c++) for(int t=0;t<SEQ;t++) {
            int idx = c*SEQ+t;
            x_in[idx]=0.1f*(2*drand48()-1);
            y_tgt[idx]=0.1f*sinf(idx*0.03f+1.0f);
        }
        // Activation buffers (saved from forward for backward).
        // Q/K/V stay on the ANE side and are passed between kernels with io_copy.
        float *xnorm=malloc(SEQ*DIM*4);
        float *attn_out=malloc(SEQ*DIM*4), *o_out=malloc(SEQ*DIM*4);
        float *x2=malloc(SEQ*DIM*4), *x2norm=malloc(SEQ*DIM*4);
        float *h1=malloc(SEQ*HIDDEN*4), *h3=malloc(SEQ*HIDDEN*4), *silu_out=malloc(SEQ*HIDDEN*4);
        float *ffn_out=malloc(SEQ*DIM*4), *y_out=malloc(SEQ*DIM*4);
        // Gradient buffers
        float *dy=malloc(SEQ*DIM*4), *dffn=malloc(SEQ*DIM*4);
        float *dh1=malloc(SEQ*HIDDEN*4), *dh3=malloc(SEQ*HIDDEN*4);
        float *dx_ffn=malloc(SEQ*DIM*4), *dx2=malloc(SEQ*DIM*4);
        float *do_out_buf=malloc(SEQ*DIM*4);
        float *dq=malloc(SEQ*DIM*4), *dk=malloc(SEQ*DIM*4), *dv=malloc(SEQ*DIM*4);
        float *dx_attn=malloc(SEQ*DIM*4);
        // Sink for the RMSNorm1 input gradient; only its grms1 side effect is used.
        // rmsnorm_bwd overwrites every element, so one allocation serves all steps.
        float *dx_rms=calloc(SEQ*DIM, 4);
        // Gradient accumulators
        float *gWq=calloc(wq_sz,4), *gWk=calloc(wq_sz,4), *gWv=calloc(wq_sz,4), *gWo=calloc(wo_sz,4);
        float *gW1=calloc(w1_sz,4), *gW2=calloc(w2_sz,4), *gW3=calloc(w3_sz,4);
        float *grms1=calloc(DIM,4), *grms2=calloc(DIM,4);
        // 6 ANE kernels: 5 weight-bearing (recompiled every batch) + weight-free kSdpaBwd2
        Kern *kFwdAttn=NULL, *kFwdFFN=NULL, *kFFNBwd=NULL;
        Kern *kSdpaBwd1=NULL, *kSdpaBwd2=NULL, *kQKVb=NULL;
        // Compile static (weight-free) kernels ONCE
        kSdpaBwd2 = compile_kern_mil_w(gen_sdpa_bwd2(), @{},
                                       (2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
        if (!kSdpaBwd2) { printf("Static kernel compile failed\n"); return 1; }
        // GCD queue for async dW cblas (overlaps with ANE evals)
        dispatch_queue_t dw_q = dispatch_queue_create("dw_cblas", DISPATCH_QUEUE_SERIAL);
        dispatch_group_t dw_grp = dispatch_group_create();
        float last_loss = 999.0f;
        double total_compile_ms=0, total_train_ms=0;
        int total_steps_done=0, total_batches=0;
        uint64_t t_wall_start = mach_absolute_time();
        int step = start_step;
        while (step < total_steps) {
            // Check compile budget — 5 weight-bearing kernels per batch
            if (g_compile_count + 5 > MAX_COMPILES) {
                free_kern(kFwdAttn);free_kern(kFwdFFN);free_kern(kFFNBwd);
                free_kern(kSdpaBwd1);free_kern(kQKVb);
                free_kern(kSdpaBwd2);
                double wall = tb_ms(mach_absolute_time() - t_wall_start);
                CkptHdr h = {step,total_steps,lr,last_loss,
                             total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
                             total_steps_done+cum_steps, total_batches+cum_batches, adam_t};
                if (!ckpt_save(CKPT_PATH, &h, ckpt_items, n_ckpt_items)) {
                    // Restarting without a valid checkpoint would silently lose all progress.
                    fprintf(stderr, "Failed to write checkpoint %s\n", CKPT_PATH);
                    return 1;
                }
                printf("[exec() restart step %d, %d compiles, loss=%.6f]\n", step, g_compile_count, last_loss);
                fflush(stdout);
                execl(argv[0], argv[0], "--resume", NULL);
                perror("execl"); return 1;
            }
            // Compile 5 weight-bearing kernels (sdpaBwd2 compiled once above)
            uint64_t tc = mach_absolute_time();
            free_kern(kFwdAttn);free_kern(kFwdFFN);free_kern(kFFNBwd);
            free_kern(kSdpaBwd1);free_kern(kQKVb);
            kFwdAttn = compile_kern_mil_w(gen_sdpa_fwd_taps(), (@{
                @"@model_path/weights/rms1.bin": @{@"offset":@0, @"data":build_blob(rms1_w,1,DIM)},
                @"@model_path/weights/wq.bin": @{@"offset":@0, @"data":build_blob(Wq,DIM,DIM)},
                @"@model_path/weights/wk.bin": @{@"offset":@0, @"data":build_blob(Wk,DIM,DIM)},
                @"@model_path/weights/wv.bin": @{@"offset":@0, @"data":build_blob(Wv,DIM,DIM)},
                @"@model_path/weights/wo.bin": @{@"offset":@0, @"data":build_blob(Wo,DIM,DIM)},
                @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
            }), DIM*SEQ*2, 6*DIM*SEQ*2);
            kFwdFFN = compile_kern_mil_w(gen_ffn_fwd_taps(), (@{
                @"@model_path/weights/rms2.bin": @{@"offset":@0, @"data":build_blob(rms2_w,1,DIM)},
                @"@model_path/weights/w1.bin": @{@"offset":@0, @"data":build_blob(W1,HIDDEN,DIM)},
                @"@model_path/weights/w3.bin": @{@"offset":@0, @"data":build_blob(W3,HIDDEN,DIM)},
                @"@model_path/weights/w2.bin": @{@"offset":@0, @"data":build_blob(W2,DIM,HIDDEN)},
            }), DIM*SEQ*2, (2*DIM+3*HIDDEN)*SEQ*2);
            kFFNBwd = compile_kern_mil_w(gen_ffn_bwd(), (@{
                @"@model_path/weights/w2t.bin": @{@"offset":@0, @"data":build_blob_t(W2,DIM,HIDDEN)},
                @"@model_path/weights/w1t.bin": @{@"offset":@0, @"data":build_blob_t(W1,HIDDEN,DIM)},
                @"@model_path/weights/w3t.bin": @{@"offset":@0, @"data":build_blob_t(W3,HIDDEN,DIM)},
            }), (DIM+2*HIDDEN)*SEQ*2, (DIM+2*HIDDEN)*SEQ*2);
            kSdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1(), (@{
                @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
                @"@model_path/weights/wot.bin": @{@"offset":@0, @"data":build_blob_t(Wo,DIM,DIM)},
            }), 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);
            kQKVb = compile_kern_mil_w(gen_qkvb(), (@{
                @"@model_path/weights/wqt.bin": @{@"offset":@0, @"data":build_blob_t(Wq,DIM,DIM)},
                @"@model_path/weights/wkt.bin": @{@"offset":@0, @"data":build_blob_t(Wk,DIM,DIM)},
                @"@model_path/weights/wvt.bin": @{@"offset":@0, @"data":build_blob_t(Wv,DIM,DIM)},
            }), 3*DIM*SEQ*2, DIM*SEQ*2);
            double cms = tb_ms(mach_absolute_time() - tc);
            total_compile_ms += cms;
            if (!kFwdAttn||!kFwdFFN||!kFFNBwd||!kSdpaBwd1||!kQKVb) {
                printf("Compile failed at step %d, restart\n", step);
                // Force the budget branch above on the next iteration: checkpoint + exec().
                g_compile_count = MAX_COMPILES; continue;
            }
            // === Training loop ===
            memset(gWq,0,wq_sz*4);memset(gWk,0,wq_sz*4);memset(gWv,0,wq_sz*4);memset(gWo,0,wo_sz*4);
            memset(gW1,0,w1_sz*4);memset(gW2,0,w2_sz*4);memset(gW3,0,w3_sz*4);
            memset(grms1,0,DIM*4);memset(grms2,0,DIM*4);
            int steps_batch = 0;
            uint64_t tt = mach_absolute_time();
            double t_ane=0,t_io=0,t_elem=0,t_rms=0,t_cblas_wait=0;
            for (int a=0; a<ACCUM_STEPS && step<total_steps; a++, step++) {
                uint64_t t0,t1;
                // ===== FORWARD =====
                // Attention fwd (ANE does rmsnorm internally): x_in → o_out,Q,K,V,attn_out,xnorm
                t0=mach_absolute_time();
                io_write_fp16(kFwdAttn->ioIn, x_in, DIM, SEQ);
                t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                ane_eval(kFwdAttn);
                t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
                // Wait for prev step's dW cblas before reading attn_out/xnorm
                dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
                t1=mach_absolute_time(); t_cblas_wait+=tb_ms(t1-t0); t0=t1;
                io_read_fp16(kFwdAttn->ioOut, o_out, 0, DIM, SEQ);
                io_read_fp16(kFwdAttn->ioOut, attn_out, 4*DIM, DIM, SEQ);
                io_read_fp16(kFwdAttn->ioOut, xnorm, 5*DIM, DIM, SEQ);
                t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                for(int i=0;i<SEQ*DIM;i++) x2[i] = x_in[i] + o_out[i];
                t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
                // FFN fwd (ANE does rmsnorm internally): x2 → ffn_out,h1,h3,silu_out,x2norm
                io_write_fp16(kFwdFFN->ioIn, x2, DIM, SEQ);
                t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                ane_eval(kFwdFFN);
                t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
                io_read_fp16(kFwdFFN->ioOut, ffn_out, 0, DIM, SEQ);
                io_read_fp16(kFwdFFN->ioOut, h1, DIM, HIDDEN, SEQ);
                io_read_fp16(kFwdFFN->ioOut, h3, DIM+HIDDEN, HIDDEN, SEQ);
                io_read_fp16(kFwdFFN->ioOut, silu_out, DIM+2*HIDDEN, HIDDEN, SEQ);
                io_read_fp16(kFwdFFN->ioOut, x2norm, DIM+3*HIDDEN, DIM, SEQ);
                t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                // Residual + Loss (mean squared error against y_tgt)
                for(int i=0;i<SEQ*DIM;i++) y_out[i] = x2[i] + ffn_out[i];
                float loss = 0;
                for(int i=0;i<SEQ*DIM;i++){
                    float d = y_out[i]-y_tgt[i]; loss += d*d;
                    dy[i] = 2.0f*d/(SEQ*DIM);
                }
                loss /= (SEQ*DIM);
                last_loss = loss;
                memcpy(dffn, dy, SEQ*DIM*4);
                t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
                // ===== BACKWARD =====
                // FFN backward (ANE)
                io_write_fp16_at(kFFNBwd->ioIn, 0, dffn, DIM, SEQ);
                io_copy(kFFNBwd->ioIn, DIM, kFwdFFN->ioOut, DIM, 2*HIDDEN, SEQ);
                t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                ane_eval(kFFNBwd);
                t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
                io_read_fp16(kFFNBwd->ioOut, dx_ffn, 0, DIM, SEQ);
                io_read_fp16(kFFNBwd->ioOut, dh1, DIM, HIDDEN, SEQ);
                io_read_fp16(kFFNBwd->ioOut, dh3, DIM+HIDDEN, HIDDEN, SEQ);
                t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                // dW FFN async (overlaps with rmsnorm2_bwd + SDPA)
                dispatch_group_async(dw_grp, dw_q, ^{
                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, HIDDEN, SEQ,
                                1.0f, dffn, SEQ, silu_out, SEQ, 1.0f, gW2, HIDDEN);
                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
                                1.0f, dh1, SEQ, x2norm, SEQ, 1.0f, gW1, DIM);
                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
                                1.0f, dh3, SEQ, x2norm, SEQ, 1.0f, gW3, DIM);
                });
                // RMSNorm2 backward — runs in parallel with dW FFN
                memset(dx2, 0, SEQ*DIM*4);
                rmsnorm_bwd(dx2, grms2, dx_ffn, x2, rms2_w, DIM, SEQ);
                // Residual path: the loss gradient also flows straight into x2.
                for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
                t1=mach_absolute_time(); t_rms+=tb_ms(t1-t0); t0=t1;
                // dWo async (overlaps with SDPA backward)
                memcpy(do_out_buf, dx2, SEQ*DIM*4);
                dispatch_group_async(dw_grp, dw_q, ^{
                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
                                1.0f, do_out_buf, SEQ, attn_out, SEQ, 1.0f, gWo, DIM);
                });
                // SDPA backward (ANE) — includes Wo^T conv
                io_copy(kSdpaBwd1->ioIn, 0, kFwdAttn->ioOut, DIM, 3*DIM, SEQ);
                io_write_fp16_at(kSdpaBwd1->ioIn, 3*DIM, dx2, DIM, SEQ);
                t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                ane_eval(kSdpaBwd1);
                t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
                io_copy(kSdpaBwd2->ioIn, 0, kSdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ);
                io_copy(kSdpaBwd2->ioIn, 2*SCORE_CH, kFwdAttn->ioOut, DIM, 2*DIM, SEQ);
                t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                ane_eval(kSdpaBwd2);
                t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
                // Read dq,dk,dv — dW FFN+dWo still running async on serial queue
                io_read_fp16(kSdpaBwd2->ioOut, dq, 0, DIM, SEQ);
                io_read_fp16(kSdpaBwd2->ioOut, dk, DIM, DIM, SEQ);
                io_read_fp16(kSdpaBwd1->ioOut, dv, 0, DIM, SEQ);
                t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                // dWq/dWk/dWv queues after dWo on serial queue — no wait needed
                dispatch_group_async(dw_grp, dw_q, ^{
                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
                                1.0f, dq, SEQ, xnorm, SEQ, 1.0f, gWq, DIM);
                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
                                1.0f, dk, SEQ, xnorm, SEQ, 1.0f, gWk, DIM);
                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
                                1.0f, dv, SEQ, xnorm, SEQ, 1.0f, gWv, DIM);
                });
                // QKV backward (ANE) — dWq/dWk/dWv runs async
                io_copy(kQKVb->ioIn, 0, kSdpaBwd2->ioOut, 0, 2*DIM, SEQ);
                io_copy(kQKVb->ioIn, 2*DIM, kSdpaBwd1->ioOut, 0, DIM, SEQ);
                t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                ane_eval(kQKVb);
                t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
                io_read_fp16(kQKVb->ioOut, dx_attn, 0, DIM, SEQ);
                t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                // RMSNorm1 backward (CPU) — doesn't touch cblas buffers
                rmsnorm_bwd(dx_rms, grms1, dx_attn, x_in, rms1_w, DIM, SEQ);
                t1=mach_absolute_time(); t_rms+=tb_ms(t1-t0);
                steps_batch++;
                if (step % 10 == 0 || step == start_step)
                    printf("step %-4d loss=%.6f\n", step, loss);
            }
            double tms = tb_ms(mach_absolute_time() - tt);
            total_train_ms += tms;
            total_steps_done += steps_batch;
            total_batches++;
            // Ensure all async dW finished before Adam
            dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
            // Adam update (scale gradients by 1/steps_batch for averaging)
            float gsc = 1.0f / steps_batch;
            for(size_t i=0;i<wq_sz;i++){gWq[i]*=gsc;gWk[i]*=gsc;gWv[i]*=gsc;gWo[i]*=gsc;}
            for(size_t i=0;i<w1_sz;i++) gW1[i]*=gsc;
            for(size_t i=0;i<w2_sz;i++) gW2[i]*=gsc;
            for(size_t i=0;i<w3_sz;i++) gW3[i]*=gsc;
            for(int i=0;i<DIM;i++){grms1[i]*=gsc; grms2[i]*=gsc;}
            adam_t++;
            adam_update(Wq, gWq, &aWq, adam_t, lr, adam_b1, adam_b2, adam_eps);
            adam_update(Wk, gWk, &aWk, adam_t, lr, adam_b1, adam_b2, adam_eps);
            adam_update(Wv, gWv, &aWv, adam_t, lr, adam_b1, adam_b2, adam_eps);
            adam_update(Wo, gWo, &aWo, adam_t, lr, adam_b1, adam_b2, adam_eps);
            adam_update(W1, gW1, &aW1, adam_t, lr, adam_b1, adam_b2, adam_eps);
            adam_update(W2, gW2, &aW2, adam_t, lr, adam_b1, adam_b2, adam_eps);
            adam_update(W3, gW3, &aW3, adam_t, lr, adam_b1, adam_b2, adam_eps);
            adam_update(rms1_w, grms1, &arms1, adam_t, lr, adam_b1, adam_b2, adam_eps);
            adam_update(rms2_w, grms2, &arms2, adam_t, lr, adam_b1, adam_b2, adam_eps);
            printf(" [batch %d: compile=%.0fms train=%.1fms (%.1fms/step) compiles=%d]\n",
                   steps_batch, cms, tms, tms/steps_batch, g_compile_count);
            printf(" ane=%.1f io=%.1f elem=%.1f rms=%.1f cblas_wait=%.1f ms/step\n",
                   t_ane/steps_batch, t_io/steps_batch, t_elem/steps_batch,
                   t_rms/steps_batch, t_cblas_wait/steps_batch);
        }
        // === Efficiency Report ===
        double wall = tb_ms(mach_absolute_time() - t_wall_start);
        // Fold in the totals carried over from earlier process runs.
        total_compile_ms += cum_compile; total_train_ms += cum_train;
        wall += cum_wall; total_steps_done += cum_steps; total_batches += cum_batches;
        double fwd_flops = 4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ;
        double bwd_dx_flops = fwd_flops;
        double sdpa_flops = 2.0*HEADS*5*SEQ*SEQ*HD;
        double ane_flops = (fwd_flops + bwd_dx_flops + sdpa_flops) * total_steps_done;
        double total_flops = (fwd_flops*3 + sdpa_flops) * total_steps_done; // fwd+bwd_dx+bwd_dw+sdpa
        printf("\n=== Efficiency Report ===\n");
        printf("Total steps: %d\n", total_steps_done);
        printf("Wall time: %.0f ms (%.1f s)\n", wall, wall/1000);
        printf("Compile time: %.0f ms (%.1f%%)\n", total_compile_ms, 100*total_compile_ms/wall);
        printf("Train time: %.0f ms (%.1f%%)\n", total_train_ms, 100*total_train_ms/wall);
        printf("Avg compile: %.0f ms per batch (5 kernels)\n", total_compile_ms/total_batches);
        printf("Avg train: %.1f ms/step\n", total_train_ms/total_steps_done);
        printf("ANE TFLOPS: %.2f sustained\n", ane_flops / (total_train_ms * 1e9));
        printf("Total TFLOPS: %.2f (ANE+CPU)\n", total_flops / (total_train_ms * 1e9));
        printf("ANE utilization: %.1f%% of 15.8 TFLOPS\n", 100*ane_flops/(total_train_ms*1e9)/15.8);
        printf("Params: %.2fM Weights: %.1fMB FP16\n", total_params/1e6, total_params*2.0/1e6);
        // Cleanup
        free_kern(kFwdAttn);free_kern(kFwdFFN);free_kern(kFFNBwd);
        free_kern(kSdpaBwd1);free_kern(kSdpaBwd2);free_kern(kQKVb);
        free(Wq);free(Wk);free(Wv);free(Wo);free(W1);free(W2);free(W3);
        free(rms1_w);free(rms2_w);free(x_in);free(y_tgt);
        free(xnorm);free(attn_out);free(o_out);
        free(x2);free(x2norm);free(h1);free(h3);free(silu_out);free(ffn_out);free(y_out);
        free(dy);free(dffn);free(dh1);free(dh3);free(dx_ffn);free(dx2);
        free(do_out_buf);free(dq);free(dk);free(dv);free(dx_attn);
        free(dx_rms);
        free(gWq);free(gWk);free(gWv);free(gWo);free(gW1);free(gW2);free(gW3);
        free(grms1);free(grms2);
        adam_free(&aWq);adam_free(&aWk);adam_free(&aWv);adam_free(&aWo);
        adam_free(&aW1);adam_free(&aW2);adam_free(&aW3);adam_free(&arms1);adam_free(&arms2);
        // The run finished; a leftover checkpoint would make a later --resume start mid-run.
        unlink(CKPT_PATH);
    }
    return 0;
}