// train_large.m — Train a single transformer layer FULLY on ANE // 7 ANE kernels per step: // Forward: kFwdAttn (QKV+SDPA+Wo, taps Q,K,V,attn_out) + kFwdFFN (W1+W3+SiLU+W2, taps h1,h3,silu_out) // Backward: kFFNBwd (W2^T+SiLU_bwd+W1^T+W3^T) + kSdpaBwd1 (Wo^T+SDPA) + kSdpaBwd2 + kQKVb (Wq^T+Wk^T+Wv^T) // CPU: RMSNorm (fwd+bwd), residuals, loss, dW accumulation (cblas), SGD update // NO CPU recompute of Q,K,V,h1,h3 — all exposed via forward taps #import #import #import #import #import #import #import #include #include #include #define DIM 768 #define HIDDEN 2048 #define HEADS 12 #define HD (DIM/HEADS) #define SEQ 512 #define ACCUM_STEPS 100 #define MAX_COMPILES 100 #define NUM_KERNELS 6 #define CKPT_PATH "/tmp/ane_large_ckpt.bin" static Class g_D, g_I, g_AR, g_AIO; static mach_timebase_info_data_t g_tb; static int g_compile_count = 0; static void ane_init(void) { dlopen("/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine", RTLD_NOW); g_D = NSClassFromString(@"_ANEInMemoryModelDescriptor"); g_I = NSClassFromString(@"_ANEInMemoryModel"); g_AR = NSClassFromString(@"_ANERequest"); g_AIO= NSClassFromString(@"_ANEIOSurfaceObject"); } static double tb_ms(uint64_t t) { return (double)t * g_tb.numer / g_tb.denom / 1e6; } static IOSurfaceRef make_surface(size_t bytes) { return IOSurfaceCreate((__bridge CFDictionaryRef)@{ (id)kIOSurfaceWidth:@(bytes), (id)kIOSurfaceHeight:@1, (id)kIOSurfaceBytesPerElement:@1, (id)kIOSurfaceBytesPerRow:@(bytes), (id)kIOSurfaceAllocSize:@(bytes), (id)kIOSurfacePixelFormat:@0}); } static NSData *build_blob(const float *w, int rows, int cols) { int ws=rows*cols*2, tot=128+ws; uint8_t *b=(uint8_t*)calloc(tot,1); b[0]=1;b[4]=2;b[64]=0xEF;b[65]=0xBE;b[66]=0xAD;b[67]=0xDE;b[68]=1; *(uint32_t*)(b+72)=ws;*(uint32_t*)(b+80)=128; _Float16 *fp16=(_Float16*)(b+128); for(int i=0;i({{\"coremlc-component-MIL\", \"3510.2.1\"}, " \ "{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " \ 
"{\"coremltools-version\", \"9.0\"}})]\n{\n" #define CONV_CONST \ " string pt = const()[name=string(\"pt\"), val=string(\"valid\")];\n" \ " tensor st = const()[name=string(\"st\"), val=tensor([1,1])];\n" \ " tensor pd = const()[name=string(\"pd\"), val=tensor([0,0,0,0])];\n" \ " tensor dl = const()[name=string(\"dl\"), val=tensor([1,1])];\n" \ " int32 gr = const()[name=string(\"gr\"), val=int32(1)];\n" // SDPA forward + taps: x_in → rmsnorm → QKV+SDPA+Wo → concat(o_out, Q, K, V, attn_out, xnorm) fp16 static NSString *gen_sdpa_fwd_taps(void) { float sc = 1.0f/sqrtf((float)HD); float invd = 1.0f/(float)DIM; NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; [m appendFormat:@" func main(tensor x) {\n", DIM, SEQ]; // --- RMSNorm: x → xn --- [m appendFormat:@" tensor sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ]; [m appendFormat:@" tensor rax = const()[name=string(\"rax\"), val=tensor([1])];\n"]; [m appendFormat:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"]; [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ]; [m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd]; [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ]; [m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"]; [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ]; [m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"]; [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ]; [m appendFormat:@" tensor xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ]; [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rms1.bin\"), offset=uint64(64)))];\n", DIM, DIM]; [m appendFormat:@" tensor xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ]; // --- QKV + SDPA + Wo (operates on xn) 
--- [m appendString:@CONV_CONST]; [m appendFormat:@" tensor Wq = const()[name=string(\"Wq\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wq.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; [m appendFormat:@" tensor Wk = const()[name=string(\"Wk\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wk.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; [m appendFormat:@" tensor Wv = const()[name=string(\"Wv\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wv.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; [m appendFormat:@" tensor Wo = const()[name=string(\"Wo\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wo.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; [m appendFormat:@" tensor qf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wq,x=xn)[name=string(\"cq\")];\n", DIM,SEQ]; [m appendFormat:@" tensor kf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wk,x=xn)[name=string(\"ck\")];\n", DIM,SEQ]; [m appendFormat:@" tensor vf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wv,x=xn)[name=string(\"cv\")];\n", DIM,SEQ]; [m appendFormat:@" tensor qsh = const()[name=string(\"qsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ]; [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; [m appendFormat:@" tensor q4 = reshape(shape=qsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor q = transpose(perm=pm,x=q4)[name=string(\"tq\")];\n", HEADS,SEQ,HD]; [m appendFormat:@" tensor k4 = reshape(shape=qsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", HEADS,SEQ,HD]; [m appendFormat:@" tensor v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS,SEQ,HD]; [m appendString:@" bool tx = const()[name=string(\"tx\"), val=bool(false)];\n"]; [m 
appendString:@" bool ty = const()[name=string(\"ty\"), val=bool(true)];\n"]; [m appendFormat:@" tensor sc1 = matmul(transpose_x=tx,transpose_y=ty,x=q,y=k)[name=string(\"mm1\")];\n", HEADS,SEQ,SEQ]; [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS,SEQ,SEQ]; [m appendFormat:@" tensor cm = const()[name=string(\"cm\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ,SEQ,SEQ,SEQ]; [m appendFormat:@" tensor ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS,SEQ,SEQ]; [m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"]; [m appendFormat:@" tensor aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS,SEQ,SEQ]; [m appendFormat:@" tensor a4 = matmul(transpose_x=tx,transpose_y=tx,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS,SEQ,HD]; [m appendFormat:@" tensor at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor os = const()[name=string(\"os\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ]; [m appendFormat:@" tensor af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM,SEQ]; [m appendFormat:@" tensor oo = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wo,x=af)[name=string(\"co\")];\n", DIM,SEQ]; [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM,SEQ]; [m appendString:@" } -> (out);\n}\n"]; return m; } // FFN forward + taps: x2 → rmsnorm → FFN → concat(ffn_out, h1, h3, silu_out, x2norm) fp16 static NSString *gen_ffn_fwd_taps(void) { float invd = 1.0f/(float)DIM; NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; [m appendFormat:@" func main(tensor x) {\n", DIM, SEQ]; // 
--- RMSNorm: x → xn --- [m appendFormat:@" tensor sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ]; [m appendFormat:@" tensor rax = const()[name=string(\"rax\"), val=tensor([1])];\n"]; [m appendFormat:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"]; [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ]; [m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd]; [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ]; [m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"]; [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ]; [m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"]; [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ]; [m appendFormat:@" tensor xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ]; [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rms2.bin\"), offset=uint64(64)))];\n", DIM, DIM]; [m appendFormat:@" tensor xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ]; // --- FFN (operates on xn) --- [m appendString:@CONV_CONST]; [m appendFormat:@" tensor W1 = const()[name=string(\"W1\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w1.bin\"), offset=uint64(64)))];\n", HIDDEN,DIM,HIDDEN,DIM]; [m appendFormat:@" tensor W3 = const()[name=string(\"W3\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w3.bin\"), offset=uint64(64)))];\n", HIDDEN,DIM,HIDDEN,DIM]; [m appendFormat:@" tensor W2 = const()[name=string(\"W2\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w2.bin\"), offset=uint64(64)))];\n", DIM,HIDDEN,DIM,HIDDEN]; [m appendFormat:@" tensor h1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1,x=xn)[name=string(\"c1\")];\n", HIDDEN,SEQ]; [m appendFormat:@" tensor h3 = 
conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3,x=xn)[name=string(\"c3\")];\n", HIDDEN,SEQ]; [m appendFormat:@" tensor sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN,SEQ]; [m appendFormat:@" tensor silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN,SEQ]; [m appendFormat:@" tensor gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN,SEQ]; [m appendFormat:@" tensor y = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2,x=gate)[name=string(\"c2\")];\n", DIM,SEQ]; [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(y,h1,h3,gate,xn))[name=string(\"cat\")];\n", 2*DIM+3*HIDDEN,SEQ]; [m appendString:@" } -> (out);\n}\n"]; return m; } // Fused FFN backward: concat(dffn,h1,h3) → concat(dx,dh1,dh3) fp16 static NSString *gen_ffn_bwd(void) { NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; [m appendFormat:@" func main(tensor x) {\n", DIM+2*HIDDEN, SEQ]; [m appendString:@CONV_CONST]; [m appendString:@" tensor bd = const()[name=string(\"bd\"), val=tensor([0,0,0,0])];\n"]; [m appendFormat:@" tensor sd = const()[name=string(\"sd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor dffn = slice_by_size(x=x,begin=bd,size=sd)[name=string(\"s0\")];\n", DIM, SEQ]; [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM]; [m appendFormat:@" tensor s1 = const()[name=string(\"s1\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor h1 = slice_by_size(x=x,begin=b1,size=s1)[name=string(\"s1x\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", DIM+HIDDEN]; [m appendFormat:@" tensor h3 = slice_by_size(x=x,begin=b3,size=s1)[name=string(\"s3x\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor W2t = 
const()[name=string(\"W2t\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w2t.bin\"), offset=uint64(64)))];\n", HIDDEN, DIM, HIDDEN, DIM]; [m appendFormat:@" tensor dsilu = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2t,x=dffn)[name=string(\"cw2\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ]; [m appendString:@" fp16 one = const()[name=string(\"one\"), val=fp16(1.0)];\n"]; [m appendFormat:@" tensor oms = sub(x=one,y=sig)[name=string(\"oms\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor homs = mul(x=h1,y=oms)[name=string(\"homs\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor brk = add(x=one,y=homs)[name=string(\"brk\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor dsd = mul(x=sig,y=brk)[name=string(\"dsd\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor t1 = mul(x=dsilu,y=h3)[name=string(\"t1\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor dh1 = mul(x=t1,y=dsd)[name=string(\"dh1\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor slh = mul(x=h1,y=sig)[name=string(\"slh\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor dh3 = mul(x=dsilu,y=slh)[name=string(\"dh3\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor W1t = const()[name=string(\"W1t\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w1t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN]; [m appendFormat:@" tensor W3t = const()[name=string(\"W3t\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w3t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN]; [m appendFormat:@" tensor dx1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1t,x=dh1)[name=string(\"cw1\")];\n", DIM, SEQ]; [m appendFormat:@" tensor dx3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3t,x=dh3)[name=string(\"cw3\")];\n", DIM, SEQ]; [m appendFormat:@" tensor dx = add(x=dx1,y=dx3)[name=string(\"adx\")];\n", DIM, SEQ]; [m appendString:@" int32 cax = const()[name=string(\"cax\"), 
val=int32(1)];\n"]; [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dx,dh1,dh3))[name=string(\"cat\")];\n", DIM+2*HIDDEN, SEQ]; [m appendString:@" } -> (out);\n}\n"]; return m; } // Fused QKV backward: concat(dq,dk,dv) → dx fp16 static NSString *gen_qkvb(void) { NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; [m appendFormat:@" func main(tensor x) {\n", 3*DIM, SEQ]; [m appendString:@CONV_CONST]; [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; [m appendFormat:@" tensor dq = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM,SEQ]; [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM]; [m appendFormat:@" tensor dk = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM,SEQ]; [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM]; [m appendFormat:@" tensor dv = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM,SEQ]; [m appendFormat:@" tensor Wqt = const()[name=string(\"Wqt\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wqt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; [m appendFormat:@" tensor Wkt = const()[name=string(\"Wkt\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wkt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; [m appendFormat:@" tensor Wvt = const()[name=string(\"Wvt\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wvt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; [m appendFormat:@" tensor dxq = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wqt,x=dq)[name=string(\"cq\")];\n", DIM,SEQ]; [m appendFormat:@" tensor dxk = 
conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wkt,x=dk)[name=string(\"ck\")];\n", DIM,SEQ]; [m appendFormat:@" tensor dxv = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wvt,x=dv)[name=string(\"cv\")];\n", DIM,SEQ]; [m appendFormat:@" tensor dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", DIM,SEQ]; [m appendFormat:@" tensor out = add(x=dxqk,y=dxv)[name=string(\"out\")];\n", DIM,SEQ]; [m appendString:@" } -> (out);\n}\n"]; return m; } // SDPA backward part 1 + Wo^T: concat(Q,K,V,dx2) → Wo^T(dx2) → concat(dV, probs_flat, dp_flat) fp16 // SCORE_CH: channels needed for flattened attention scores [HEADS,SEQ,SEQ] → [HEADS*SEQ, SEQ] #define SCORE_CH (HEADS*SEQ) static NSString *gen_sdpa_bwd1(void) { float sc = 1.0f/sqrtf((float)HD); NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; [m appendFormat:@" func main(tensor x) {\n", 4*DIM, SEQ]; [m appendString:@CONV_CONST]; [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM,SEQ]; [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM]; [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM,SEQ]; [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM]; [m appendFormat:@" tensor vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM,SEQ]; [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 3*DIM]; [m appendFormat:@" tensor dx2f = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM,SEQ]; // Wo^T backward: dx2 → dattn [m appendFormat:@" tensor Wot = const()[name=string(\"Wot\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wot.bin\"), 
offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; [m appendFormat:@" tensor df = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wot,x=dx2f)[name=string(\"cwo\")];\n", DIM,SEQ]; [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ]; [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS,SEQ,HD]; [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS,SEQ,HD]; [m appendFormat:@" tensor vr = reshape(shape=rsh,x=vf)[name=string(\"rv\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor v = transpose(perm=pm,x=vr)[name=string(\"tv\")];\n", HEADS,SEQ,HD]; [m appendFormat:@" tensor dr = reshape(shape=rsh,x=df)[name=string(\"rd\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor da = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS,SEQ,HD]; [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS,SEQ,SEQ]; [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS,SEQ,SEQ]; [m appendFormat:@" tensor cm = const()[name=string(\"cm\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ,SEQ,SEQ,SEQ]; [m appendFormat:@" tensor ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS,SEQ,SEQ]; [m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"]; [m appendFormat:@" tensor probs = 
softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS,SEQ,SEQ]; [m appendFormat:@" tensor dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=da)[name=string(\"dv\")];\n", HEADS,SEQ,HD]; [m appendFormat:@" tensor dp4 = matmul(transpose_x=bF,transpose_y=bT,x=da,y=v)[name=string(\"dp\")];\n", HEADS,SEQ,SEQ]; // Flatten dv back to [1,DIM,1,SEQ] [m appendFormat:@" tensor dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor dvs = const()[name=string(\"dvs\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ]; [m appendFormat:@" tensor dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM,SEQ]; // Flatten probs [1,H,S,S] → [1,H*S,1,S] and dp [1,H,S,S] → [1,H*S,1,S] [m appendFormat:@" tensor scs = const()[name=string(\"scs\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH,SEQ]; [m appendFormat:@" tensor pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH,SEQ]; [m appendFormat:@" tensor dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH,SEQ]; [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", DIM+2*SCORE_CH,SEQ]; [m appendString:@" } -> (out);\n}\n"]; return m; } // SDPA backward part 2: concat(probs[SCORE_CH],dp[SCORE_CH],Q[DIM],K[DIM]) → concat(dQ,dK) fp16 static NSString *gen_sdpa_bwd2(void) { float sc = 1.0f/sqrtf((float)HD); int bwd2_in = 2*SCORE_CH + 2*DIM; NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; [m appendFormat:@" func main(tensor x) {\n", bwd2_in, SEQ]; // Slice probs [m appendFormat:@" tensor sz_sc = const()[name=string(\"szsc\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH, SEQ]; [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; [m appendFormat:@" tensor pf = 
slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH,SEQ]; // Slice dp [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", SCORE_CH]; [m appendFormat:@" tensor dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH,SEQ]; // Slice Q [m appendFormat:@" tensor sz_d = const()[name=string(\"szd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH]; [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM,SEQ]; // Slice K [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH+DIM]; [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM,SEQ]; // Reshape to multi-head [m appendFormat:@" tensor ssh = const()[name=string(\"ssh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,SEQ,SEQ]; [m appendFormat:@" tensor probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS,SEQ,SEQ]; [m appendFormat:@" tensor dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS,SEQ,SEQ]; [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ]; [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS,SEQ,HD]; [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS,SEQ,HD]; // Softmax grad: ds = probs * (dp - sum(probs*dp)) * scale [m appendFormat:@" tensor pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS,SEQ,SEQ]; [m appendString:@" tensor rax = const()[name=string(\"rax\"), val=tensor([-1])];\n"]; 
[m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"]; [m appendFormat:@" tensor spdp = reduce_sum(x=pdp,axes=rax,keep_dims=kd)[name=string(\"rs\")];\n", HEADS,SEQ]; [m appendFormat:@" tensor dps = sub(x=dp,y=spdp)[name=string(\"dps\")];\n", HEADS,SEQ,SEQ]; [m appendFormat:@" tensor ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS,SEQ,SEQ]; [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; [m appendFormat:@" tensor ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS,SEQ,SEQ]; [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; [m appendFormat:@" tensor dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS,SEQ,HD]; [m appendFormat:@" tensor dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS,SEQ,HD]; [m appendFormat:@" tensor dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS,HD,SEQ]; [m appendFormat:@" tensor fs = const()[name=string(\"fs\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ]; [m appendFormat:@" tensor dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM,SEQ]; [m appendFormat:@" tensor dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM,SEQ]; [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM,SEQ]; [m appendString:@" } -> (out);\n}\n"]; return m; } // ===== Weight builders ===== static NSData *g_mask_blob = nil; static NSData *get_mask_blob(void) { if (!g_mask_blob) { _Float16 *mask = (_Float16*)calloc(SEQ*SEQ, sizeof(_Float16)); for(int t=0;tmodel = CFBridgingRetain(mdl); 
k->ioIn = make_surface(ic_bytes);
        k->ioOut = make_surface(oc_bytes);
        id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn);
        id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
        k->request = CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
            @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
            @[wI], @[@0], @[wO], @[@0], nil, nil, @0));
        k->tmpDir = CFBridgingRetain(td);
        return k;
    }
}

// Tear down a kernel: unload the model, release both IOSurfaces, delete the
// kernel's temporary directory, drop the retained model/request/tmpDir
// references, and free the struct itself.
//
// Passing NULL is a no-op. Every retained field is checked before release
// because CFRelease(NULL) crashes, and any of these fields can be NULL
// (make_surface returns whatever IOSurfaceCreate returned; CFBridgingRetain
// of nil yields NULL).
static void free_kern(Kern *k) {
    if (!k) return;

    // Unload is best-effort during teardown; messaging a nil model is a no-op.
    id mdl = (__bridge id)k->model;
    NSError *e = nil;
    ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e);

    if (k->ioIn)  CFRelease(k->ioIn);
    if (k->ioOut) CFRelease(k->ioOut);

    // The temp directory must be removed before its path reference is released.
    if (k->tmpDir) {
        [[NSFileManager defaultManager] removeItemAtPath:(__bridge id)k->tmpDir error:nil];
    }

    if (k->model)   CFRelease(k->model);
    if (k->request) CFRelease(k->request);
    if (k->tmpDir)  CFRelease(k->tmpDir);
    free(k);
}

// Run one evaluation of kernel `k` using its prebuilt request (which wraps
// k->ioIn and k->ioOut).
//
// The BOOL result of the evaluate call is checked (not the NSError pointer);
// on failure a diagnostic is written to stderr so that a failed evaluation is
// not silently mistaken for fresh output in k->ioOut. A NULL kernel is
// reported the same way instead of being dereferenced.
static void ane_eval(Kern *k) {
    if (!k) {
        fprintf(stderr, "ane_eval: called with NULL kernel\n");
        return;
    }
    id mdl = (__bridge id)k->model;
    id req = (__bridge id)k->request;
    NSError *e = nil;
    BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
        mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
    if (!ok) {
        fprintf(stderr, "ane_eval: evaluation failed: %s\n",
                e ? [[e localizedDescription] UTF8String] : "no error reported");
    }
}

// ===== Vectorized conversion helpers (NEON) =====
#include <arm_neon.h>

// Convert `n` half-precision values at `src` to single precision at `dst`.
// Processes 8 elements per iteration with NEON, then finishes the remaining
// (n % 8) elements with scalar casts. `dst` and `src` need room for n elements.
static void cvt_f16_f32(float *dst, const _Float16 *src, int n) {
    int i = 0;
    for (; i+7 < n; i += 8) {
        float16x8_t h = vld1q_f16((const __fp16*)(src+i));
        vst1q_f32(dst+i,   vcvt_f32_f16(vget_low_f16(h)));   // lanes 0..3
        vst1q_f32(dst+i+4, vcvt_f32_f16(vget_high_f16(h)));  // lanes 4..7
    }
    for (; i < n; i++) dst[i] = (float)src[i];
}

// Convert `n` single-precision values at `src` to half precision at `dst`.
// Same structure as cvt_f16_f32: 8-wide NEON main loop plus a scalar tail.
static void cvt_f32_f16(_Float16 *dst, const float *src, int n) {
    int i = 0;
    for (; i+7 < n; i += 8) {
        float16x8_t h = vcombine_f16(vcvt_f16_f32(vld1q_f32(src+i)),
                                     vcvt_f16_f32(vld1q_f32(src+i+4)));
        vst1q_f16((__fp16*)(dst+i), h);
    }
    for (; i < n; i++) dst[i] = (_Float16)src[i];
}

// ===== IOSurface I/O helpers (channel-first, no transpose) =====
// All CPU buffers are [C,S] channel-first matching IOSurface
// [1,C,1,S]

// Store a [channels, sp] fp32 buffer into surface `s` as fp16.
// Element order is preserved; only the element type changes.
static void io_write_fp16(IOSurfaceRef s, const float *data, int channels, int sp) {
    const int count = channels * sp;
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    cvt_f32_f16(base, data, count);
    IOSurfaceUnlock(s, 0, NULL);
}

// Store a [channels, sp] fp32 buffer into surface `s` unchanged (raw copy).
static void io_write_fp32(IOSurfaceRef s, const float *data, int channels, int sp) {
    const int count = channels * sp;
    const size_t nbytes = count * sizeof(float);
    IOSurfaceLock(s, 0, NULL);
    void *base = IOSurfaceGetBaseAddress(s);
    memcpy(base, data, nbytes);
    IOSurfaceUnlock(s, 0, NULL);
}

// Load `channels` rows of `sp` fp16 values from surface `s`, starting
// `ch_off` rows (ch_off * sp elements) past the base address, and widen
// them to fp32 in `data`. The surface is locked read-only for the copy.
static void io_read_fp16(IOSurfaceRef s, float *data, int ch_off, int channels, int sp) {
    const int skip = ch_off * sp;
    const int count = channels * sp;
    IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    cvt_f16_f32(data, base + skip, count);
    IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}

// Load a [channels, sp] fp32 buffer from the start of surface `s` into
// `data` (raw copy). The surface is locked read-only for the copy.
static void io_read_fp32(IOSurfaceRef s, float *data, int channels, int sp) {
    const int count = channels * sp;
    const size_t nbytes = count * sizeof(float);
    IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
    const void *base = IOSurfaceGetBaseAddress(s);
    memcpy(data, base, nbytes);
    IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}

// Write multiple fp32 [C,S] arrays concatenated along channel dim as fp16
static void io_write_multi_fp16(IOSurfaceRef s, int sp, int n, ...)
{ IOSurfaceLock(s, 0, NULL); _Float16 *dst = (_Float16*)IOSurfaceGetBaseAddress(s); va_list ap; va_start(ap, n); int ch_off = 0; for (int i=0; im); free(s->v); } static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) { float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t); for (size_t i=0; in; i++) { s->m[i] = b1*s->m[i] + (1-b1)*g[i]; s->v[i] = b2*s->v[i] + (1-b2)*g[i]*g[i]; float mh = s->m[i]/bc1, vh = s->v[i]/bc2; w[i] -= lr * mh / (sqrtf(vh) + eps); } } int main(int argc, char *argv[]) { @autoreleasepool { setbuf(stdout, NULL); ane_init(); mach_timebase_info(&g_tb); int total_steps = 400; float lr = 1e-3f; float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f; int adam_t = 0; int start_step = 0; size_t wq_sz = DIM*DIM, wo_sz = DIM*DIM; size_t w1_sz = HIDDEN*DIM, w2_sz = DIM*HIDDEN, w3_sz = HIDDEN*DIM; size_t total_params = 4*wq_sz + w1_sz + w2_sz + w3_sz; float *Wq=malloc(wq_sz*4), *Wk=malloc(wq_sz*4), *Wv=malloc(wq_sz*4), *Wo=malloc(wo_sz*4); float *W1=malloc(w1_sz*4), *W2=malloc(w2_sz*4), *W3=malloc(w3_sz*4); float *rms1_w=malloc(DIM*4), *rms2_w=malloc(DIM*4); // Adam optimizer states (m and v for each weight) AdamState aWq=adam_alloc(wq_sz), aWk=adam_alloc(wq_sz), aWv=adam_alloc(wq_sz), aWo=adam_alloc(wo_sz); AdamState aW1=adam_alloc(w1_sz), aW2=adam_alloc(w2_sz), aW3=adam_alloc(w3_sz); AdamState arms1=adam_alloc(DIM), arms2=adam_alloc(DIM); double cum_compile=0, cum_train=0, cum_wall=0; int cum_steps=0, cum_batches=0; bool resuming = false; if (argc > 1 && strcmp(argv[1], "--resume") == 0) { FILE *f = fopen(CKPT_PATH, "rb"); if (f) { CkptHdr h; fread(&h, sizeof(h), 1, f); start_step=h.step; total_steps=h.total_steps; lr=h.lr; cum_compile=h.cum_compile; cum_train=h.cum_train; cum_wall=h.cum_wall; cum_steps=h.cum_steps; cum_batches=h.cum_batches; adam_t=h.adam_t; fread(Wq,4,wq_sz,f); fread(Wk,4,wq_sz,f); fread(Wv,4,wq_sz,f); fread(Wo,4,wo_sz,f); fread(W1,4,w1_sz,f); fread(W2,4,w2_sz,f); 
fread(W3,4,w3_sz,f); fread(rms1_w,4,DIM,f); fread(rms2_w,4,DIM,f); // Adam state fread(aWq.m,4,wq_sz,f);fread(aWq.v,4,wq_sz,f); fread(aWk.m,4,wq_sz,f);fread(aWk.v,4,wq_sz,f); fread(aWv.m,4,wq_sz,f);fread(aWv.v,4,wq_sz,f); fread(aWo.m,4,wo_sz,f);fread(aWo.v,4,wo_sz,f); fread(aW1.m,4,w1_sz,f);fread(aW1.v,4,w1_sz,f); fread(aW2.m,4,w2_sz,f);fread(aW2.v,4,w2_sz,f); fread(aW3.m,4,w3_sz,f);fread(aW3.v,4,w3_sz,f); fread(arms1.m,4,DIM,f);fread(arms1.v,4,DIM,f); fread(arms2.m,4,DIM,f);fread(arms2.v,4,DIM,f); fclose(f); resuming = true; printf("[RESUMED step %d, loss=%.6f]\n", start_step, h.loss); } } if (!resuming) { srand48(42); float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN); for(size_t i=0;i