// stories_mil.h — MIL program generators for ANE kernels
// Same architecture as single-layer train_large.m but parameterized
//
// Every generator below builds the text of one program in an NSMutableString
// and returns it. Each program declares a single `main` function with one
// input `x` and one output `out`; several logical tensors are packed into
// `x`/`out` by concatenating or slicing along axis 1.
//
// NOTE(review): in this copy of the file many appendFormat: calls pass
// arguments (DIM, SEQ, HEADS, ...) to format strings that contain no
// conversion specifiers, and MIL_HDR contains an unbalanced "dict, tensor>".
// That looks like "<...>" type annotations were stripped from the string
// literals — TODO confirm against the canonical source before relying on the
// exact program text.
#pragma once
#include "stories_io.h"

// Program preamble appended first by every generator; it opens the program's
// outer "{" which each generator closes with its final "}\n".
#define MIL_HDR \
    @"program(1.0)\n[buildInfo = dict, tensor>({{\"coremlc-version\", \"3505.4.1\"}})]\n{\n"
// Constant declarations shared by all conv(...) calls below: pt (pad_type),
// st (strides), pd (pad), dl (dilations) and gr (groups). This is a plain C
// string, so it is used as `@CONV_CONST` to form an NSString literal.
#define CONV_CONST \
    " tensor pt = const()[name=tensor(\"pt\"), val=tensor(\"valid\")];\n" \
    " tensor st = const()[name=tensor(\"st\"), val=tensor([1,1])];\n" \
    " tensor pd = const()[name=tensor(\"pd\"), val=tensor([0,0,0,0])];\n" \
    " tensor dl = const()[name=tensor(\"dl\"), val=tensor([1,1])];\n" \
    " tensor gr = const()[name=tensor(\"gr\"), val=tensor(1)];\n"

// SDPA forward + taps: x_in → rmsnorm → QKV+SDPA+Wo → concat(o_out, Q, K, V, attn_out, xnorm)
//
// Returns the program text. Weights are referenced by path
// (@model_path/weights/{rms1,wq,wk,wv,wo,mask}.bin, offset 64); nothing is
// read from disk here. The output concatenates (oo, qf, kf, vf, af, xn) along
// axis 1, i.e. the attention output plus the intermediate "taps".
static NSString *gen_sdpa_fwd_taps(void) {
    // Attention score scale 1/sqrt(HD) and mean factor 1/DIM; both are
    // written into the program with %f (six fractional digits).
    float sc = 1.0f/sqrtf((float)HD);
    float invd = 1.0f/(float)DIM;
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x) {\n", DIM, SEQ];
    // RMSNorm: xn = x * (sum(x*x over axis 1) * invd + eps)^-0.5 * rw
    [m appendFormat:@" tensor sq = mul(x=x,y=x)[name=tensor(\"sq\")];\n", DIM, SEQ];
    [m appendFormat:@" tensor rax = const()[name=tensor(\"rax\"), val=tensor([1])];\n"];
    [m appendFormat:@" tensor kd = const()[name=tensor(\"kd\"), val=tensor(true)];\n"];
    [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=tensor(\"ss\")];\n", SEQ];
    [m appendFormat:@" tensor invd = const()[name=tensor(\"invd\"), val=tensor(%f)];\n", invd];
    [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd)[name=tensor(\"ss2\")];\n", SEQ];
    [m appendFormat:@" tensor eps = const()[name=tensor(\"eps\"), val=tensor(0.00001)];\n"];
    [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps)[name=tensor(\"ss3\")];\n", SEQ];
    [m appendFormat:@" tensor nhalf = const()[name=tensor(\"nhalf\"), val=tensor(-0.5)];\n"];
    [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf)[name=tensor(\"rrms\")];\n", SEQ];
    [m appendFormat:@" tensor xr = mul(x=x,y=rrms)[name=tensor(\"xr\")];\n", DIM, SEQ];
    [m appendFormat:@" tensor rw = const()[name=tensor(\"rw\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/rms1.bin\"), offset=tensor(64)))];\n", DIM, DIM];
    [m appendFormat:@" tensor xn = mul(x=xr,y=rw)[name=tensor(\"xn\")];\n", DIM, SEQ];
    // Q/K/V/O projections are expressed as conv ops using the shared constants.
    [m appendString:@CONV_CONST];
    [m appendFormat:@" tensor Wq = const()[name=tensor(\"Wq\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/wq.bin\"), offset=tensor(64)))];\n", DIM,DIM,DIM,DIM];
    [m appendFormat:@" tensor Wk = const()[name=tensor(\"Wk\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/wk.bin\"), offset=tensor(64)))];\n", DIM,DIM,DIM,DIM];
    [m appendFormat:@" tensor Wv = const()[name=tensor(\"Wv\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/wv.bin\"), offset=tensor(64)))];\n", DIM,DIM,DIM,DIM];
    [m appendFormat:@" tensor Wo = const()[name=tensor(\"Wo\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/wo.bin\"), offset=tensor(64)))];\n", DIM,DIM,DIM,DIM];
    [m appendFormat:@" tensor qf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wq,x=xn)[name=tensor(\"cq\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor kf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wk,x=xn)[name=tensor(\"ck\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor vf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wv,x=xn)[name=tensor(\"cv\")];\n", DIM,SEQ];
    // Split into heads: reshape to [1,HEADS,HD,SEQ], then swap the last two
    // axes (perm [0,1,3,2]).
    [m appendFormat:@" tensor qsh = const()[name=tensor(\"qsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
    [m appendString:@" tensor pm = const()[name=tensor(\"pm\"), val=tensor([0,1,3,2])];\n"];
    [m appendFormat:@" tensor q4 = reshape(shape=qsh,x=qf)[name=tensor(\"rq\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor q = transpose(perm=pm,x=q4)[name=tensor(\"tq\")];\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor k4 = reshape(shape=qsh,x=kf)[name=tensor(\"rk\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor k = transpose(perm=pm,x=k4)[name=tensor(\"tk\")];\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor v4 = reshape(shape=qsh,x=vf)[name=tensor(\"rv\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor v = transpose(perm=pm,x=v4)[name=tensor(\"tv\")];\n", HEADS,SEQ,HD];
    // Attention weights: softmax(q·kᵀ * sc + mask) over the last axis.
    [m appendString:@" tensor tx = const()[name=tensor(\"tx\"), val=tensor(false)];\n"];
    [m appendString:@" tensor ty = const()[name=tensor(\"ty\"), val=tensor(true)];\n"];
    [m appendFormat:@" tensor sc1 = matmul(transpose_x=tx,transpose_y=ty,x=q,y=k)[name=tensor(\"mm1\")];\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor scv = const()[name=tensor(\"scv\"), val=tensor(%f)];\n", sc];
    [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv)[name=tensor(\"scl\")];\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor cm = const()[name=tensor(\"cm\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/mask.bin\"), offset=tensor(64)))];\n", SEQ,SEQ,SEQ,SEQ];
    [m appendFormat:@" tensor ms = add(x=sc2,y=cm)[name=tensor(\"msk\")];\n", HEADS,SEQ,SEQ];
    [m appendString:@" tensor sax = const()[name=tensor(\"sax\"), val=tensor(-1)];\n"];
    [m appendFormat:@" tensor aw = softmax(axis=sax,x=ms)[name=tensor(\"sm\")];\n", HEADS,SEQ,SEQ];
    // Weighted sum of values (tx == false is reused for both transpose flags),
    // then merge heads back to [1,DIM,1,SEQ] and apply the output projection.
    [m appendFormat:@" tensor a4 = matmul(transpose_x=tx,transpose_y=tx,x=aw,y=v)[name=tensor(\"mm2\")];\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor at = transpose(perm=pm,x=a4)[name=tensor(\"ta\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor os = const()[name=tensor(\"os\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ];
    [m appendFormat:@" tensor af = reshape(shape=os,x=at)[name=tensor(\"ra\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor oo = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wo,x=af)[name=tensor(\"co\")];\n", DIM,SEQ];
    // Pack the result and the taps along axis 1.
    [m appendString:@" tensor cax = const()[name=tensor(\"cax\"), val=tensor(1)];\n"];
    [m appendString:@" tensor cid = const()[name=tensor(\"cid\"), val=tensor(false)];\n"];
    [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=tensor(\"cat\")];\n", 6*DIM,SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// FFN forward + taps: x2 → rmsnorm → FFN → concat(ffn_out, h1, h3, silu_out, x2norm)
//
// Returns the program text. Weights are referenced from
// @model_path/weights/{rms2,w1,w3,w2}.bin. Note that the fourth concatenated
// value is `gate` (= silu(h1) * h3), not the bare silu(h1).
static NSString *gen_ffn_fwd_taps(void) {
    float invd = 1.0f/(float)DIM;
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x) {\n", DIM, SEQ];
    // RMSNorm: xn = x * (sum(x*x over axis 1) * invd + eps)^-0.5 * rw
    [m appendFormat:@" tensor sq = mul(x=x,y=x)[name=tensor(\"sq\")];\n", DIM, SEQ];
    [m appendFormat:@" tensor rax = const()[name=tensor(\"rax\"), val=tensor([1])];\n"];
    [m appendFormat:@" tensor kd = const()[name=tensor(\"kd\"), val=tensor(true)];\n"];
    [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=tensor(\"ss\")];\n", SEQ];
    [m appendFormat:@" tensor invd = const()[name=tensor(\"invd\"), val=tensor(%f)];\n", invd];
    [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd)[name=tensor(\"ss2\")];\n", SEQ];
    [m appendFormat:@" tensor eps = const()[name=tensor(\"eps\"), val=tensor(0.00001)];\n"];
    [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps)[name=tensor(\"ss3\")];\n", SEQ];
    [m appendFormat:@" tensor nhalf = const()[name=tensor(\"nhalf\"), val=tensor(-0.5)];\n"];
    [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf)[name=tensor(\"rrms\")];\n", SEQ];
    [m appendFormat:@" tensor xr = mul(x=x,y=rrms)[name=tensor(\"xr\")];\n", DIM, SEQ];
    [m appendFormat:@" tensor rw = const()[name=tensor(\"rw\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/rms2.bin\"), offset=tensor(64)))];\n", DIM, DIM];
    [m appendFormat:@" tensor xn = mul(x=xr,y=rw)[name=tensor(\"xn\")];\n", DIM, SEQ];
    [m appendString:@CONV_CONST];
    [m appendFormat:@" tensor W1 = const()[name=tensor(\"W1\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/w1.bin\"), offset=tensor(64)))];\n", HIDDEN,DIM,HIDDEN,DIM];
    [m appendFormat:@" tensor W3 = const()[name=tensor(\"W3\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/w3.bin\"), offset=tensor(64)))];\n", HIDDEN,DIM,HIDDEN,DIM];
    [m appendFormat:@" tensor W2 = const()[name=tensor(\"W2\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/w2.bin\"), offset=tensor(64)))];\n", DIM,HIDDEN,DIM,HIDDEN];
    // Gated FFN: y = W2 · (silu(W1·xn) * (W3·xn)), with silu(h) = h * sigmoid(h).
    [m appendFormat:@" tensor h1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1,x=xn)[name=tensor(\"c1\")];\n", HIDDEN,SEQ];
    [m appendFormat:@" tensor h3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3,x=xn)[name=tensor(\"c3\")];\n", HIDDEN,SEQ];
    [m appendFormat:@" tensor sig = sigmoid(x=h1)[name=tensor(\"sg\")];\n", HIDDEN,SEQ];
    [m appendFormat:@" tensor silu = mul(x=h1,y=sig)[name=tensor(\"si\")];\n", HIDDEN,SEQ];
    [m appendFormat:@" tensor gate = mul(x=silu,y=h3)[name=tensor(\"gt\")];\n", HIDDEN,SEQ];
    [m appendFormat:@" tensor y = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2,x=gate)[name=tensor(\"c2\")];\n", DIM,SEQ];
    // Pack the result and the taps along axis 1.
    [m appendString:@" tensor cax = const()[name=tensor(\"cax\"), val=tensor(1)];\n"];
    [m appendString:@" tensor cid = const()[name=tensor(\"cid\"), val=tensor(false)];\n"];
    [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(y,h1,h3,gate,xn))[name=tensor(\"cat\")];\n", 2*DIM+3*HIDDEN,SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// FFN backward: concat(dffn,h1,h3) → concat(dx,dh1,dh3)
//
// Returns the program text. The input is sliced along axis 1 at offsets
// 0 (dffn, DIM wide), DIM (h1, HIDDEN wide) and DIM+HIDDEN (h3, HIDDEN wide).
// Transposed weights are referenced from @model_path/weights/{w2t,w1t,w3t}.bin.
static NSString *gen_ffn_bwd(void) {
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x) {\n", DIM+2*HIDDEN, SEQ];
    [m appendString:@CONV_CONST];
    // Unpack the three inputs.
    [m appendString:@" tensor bd = const()[name=tensor(\"bd\"), val=tensor([0,0,0,0])];\n"];
    [m appendFormat:@" tensor sd = const()[name=tensor(\"sd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ];
    [m appendFormat:@" tensor dffn = slice_by_size(x=x,begin=bd,size=sd)[name=tensor(\"s0\")];\n", DIM, SEQ];
    [m appendFormat:@" tensor b1 = const()[name=tensor(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM];
    [m appendFormat:@" tensor s1 = const()[name=tensor(\"s1\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor h1 = slice_by_size(x=x,begin=b1,size=s1)[name=tensor(\"s1x\")];\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor b3 = const()[name=tensor(\"b3\"), val=tensor([0,%d,0,0])];\n", DIM+HIDDEN];
    [m appendFormat:@" tensor h3 = slice_by_size(x=x,begin=b3,size=s1)[name=tensor(\"s3x\")];\n", HIDDEN, SEQ];
    // Gradient w.r.t. the gated activation: dsilu = W2t · dffn.
    [m appendFormat:@" tensor W2t = const()[name=tensor(\"W2t\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/w2t.bin\"), offset=tensor(64)))];\n", HIDDEN, DIM, HIDDEN, DIM];
    [m appendFormat:@" tensor dsilu = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2t,x=dffn)[name=tensor(\"cw2\")];\n", HIDDEN, SEQ];
    // dsd = sig * (1 + h1 * (1 - sig)), the derivative of h1 * sigmoid(h1).
    [m appendFormat:@" tensor sig = sigmoid(x=h1)[name=tensor(\"sg\")];\n", HIDDEN, SEQ];
    [m appendString:@" tensor one = const()[name=tensor(\"one\"), val=tensor(1.0)];\n"];
    [m appendFormat:@" tensor oms = sub(x=one,y=sig)[name=tensor(\"oms\")];\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor homs = mul(x=h1,y=oms)[name=tensor(\"homs\")];\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor brk = add(x=one,y=homs)[name=tensor(\"brk\")];\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor dsd = mul(x=sig,y=brk)[name=tensor(\"dsd\")];\n", HIDDEN, SEQ];
    // dh1 = dsilu * h3 * dsd;  dh3 = dsilu * (h1 * sig).
    [m appendFormat:@" tensor t1 = mul(x=dsilu,y=h3)[name=tensor(\"t1\")];\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor dh1 = mul(x=t1,y=dsd)[name=tensor(\"dh1\")];\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor slh = mul(x=h1,y=sig)[name=tensor(\"slh\")];\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor dh3 = mul(x=dsilu,y=slh)[name=tensor(\"dh3\")];\n", HIDDEN, SEQ];
    // dx = W1t · dh1 + W3t · dh3.
    [m appendFormat:@" tensor W1t = const()[name=tensor(\"W1t\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/w1t.bin\"), offset=tensor(64)))];\n", DIM, HIDDEN, DIM, HIDDEN];
    [m appendFormat:@" tensor W3t = const()[name=tensor(\"W3t\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/w3t.bin\"), offset=tensor(64)))];\n", DIM, HIDDEN, DIM, HIDDEN];
    [m appendFormat:@" tensor dx1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1t,x=dh1)[name=tensor(\"cw1\")];\n", DIM, SEQ];
    [m appendFormat:@" tensor dx3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3t,x=dh3)[name=tensor(\"cw3\")];\n", DIM, SEQ];
    [m appendFormat:@" tensor dx = add(x=dx1,y=dx3)[name=tensor(\"adx\")];\n", DIM, SEQ];
    // Pack the three gradients along axis 1.
    [m appendString:@" tensor cax = const()[name=tensor(\"cax\"), val=tensor(1)];\n"];
    [m appendString:@" tensor cid = const()[name=tensor(\"cid\"), val=tensor(false)];\n"];
    [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dx,dh1,dh3))[name=tensor(\"cat\")];\n", DIM+2*HIDDEN, SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// QKV backward: concat(dq,dk,dv) → dx
//
// Returns the program text. dq, dk and dv are sliced from the input at axis-1
// offsets 0, DIM and 2*DIM, each projected with its transposed weight
// (@model_path/weights/{wqt,wkt,wvt}.bin), and the three results are summed.
static NSString *gen_qkvb(void) {
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x) {\n", 3*DIM, SEQ];
    [m appendString:@CONV_CONST];
    // Unpack dq, dk, dv (all share the slice size `sz`).
    [m appendFormat:@" tensor sz = const()[name=tensor(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ];
    [m appendString:@" tensor b0 = const()[name=tensor(\"b0\"), val=tensor([0,0,0,0])];\n"];
    [m appendFormat:@" tensor dq = slice_by_size(x=x,begin=b0,size=sz)[name=tensor(\"s0\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor b1 = const()[name=tensor(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM];
    [m appendFormat:@" tensor dk = slice_by_size(x=x,begin=b1,size=sz)[name=tensor(\"s1\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor b2 = const()[name=tensor(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM];
    [m appendFormat:@" tensor dv = slice_by_size(x=x,begin=b2,size=sz)[name=tensor(\"s2\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor Wqt = const()[name=tensor(\"Wqt\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/wqt.bin\"), offset=tensor(64)))];\n", DIM,DIM,DIM,DIM];
    [m appendFormat:@" tensor Wkt = const()[name=tensor(\"Wkt\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/wkt.bin\"), offset=tensor(64)))];\n", DIM,DIM,DIM,DIM];
    [m appendFormat:@" tensor Wvt = const()[name=tensor(\"Wvt\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/wvt.bin\"), offset=tensor(64)))];\n", DIM,DIM,DIM,DIM];
    // Back-project each gradient and accumulate: out = dxq + dxk + dxv.
    [m appendFormat:@" tensor dxq = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wqt,x=dq)[name=tensor(\"cq\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor dxk = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wkt,x=dk)[name=tensor(\"ck\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor dxv = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wvt,x=dv)[name=tensor(\"cv\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor dxqk = add(x=dxq,y=dxk)[name=tensor(\"aqk\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor out = add(x=dxqk,y=dxv)[name=tensor(\"out\")];\n", DIM,SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// SDPA backward part 1 + Wo^T
//
// Returns the program text. The input is sliced along axis 1 into qf, kf, vf
// and dx2f at offsets 0, DIM, 2*DIM and 3*DIM. The program back-projects dx2f
// through Wot (@model_path/weights/wot.bin), recomputes the attention
// probabilities from q and k, and emits concat(dvf, pf, dpf): the gradient
// w.r.t. V, the probabilities, and the gradient w.r.t. the probabilities.
static NSString *gen_sdpa_bwd1(void) {
    float sc = 1.0f/sqrtf((float)HD);
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x) {\n", 4*DIM, SEQ];
    [m appendString:@CONV_CONST];
    // Unpack qf, kf, vf, dx2f (all share the slice size `sz`).
    [m appendFormat:@" tensor sz = const()[name=tensor(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ];
    [m appendString:@" tensor b0 = const()[name=tensor(\"b0\"), val=tensor([0,0,0,0])];\n"];
    [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b0,size=sz)[name=tensor(\"s0\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor b1 = const()[name=tensor(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM];
    [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b1,size=sz)[name=tensor(\"s1\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor b2 = const()[name=tensor(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM];
    [m appendFormat:@" tensor vf = slice_by_size(x=x,begin=b2,size=sz)[name=tensor(\"s2\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor b3 = const()[name=tensor(\"b3\"), val=tensor([0,%d,0,0])];\n", 3*DIM];
    [m appendFormat:@" tensor dx2f = slice_by_size(x=x,begin=b3,size=sz)[name=tensor(\"s3\")];\n", DIM,SEQ];
    // df = Wot · dx2f: gradient flowing into the attention output.
    [m appendFormat:@" tensor Wot = const()[name=tensor(\"Wot\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/wot.bin\"), offset=tensor(64)))];\n", DIM,DIM,DIM,DIM];
    [m appendFormat:@" tensor df = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wot,x=dx2f)[name=tensor(\"cwo\")];\n", DIM,SEQ];
    // Split q, k, v and df into heads: reshape to [1,HEADS,HD,SEQ], then swap
    // the last two axes.
    [m appendFormat:@" tensor rsh = const()[name=tensor(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
    [m appendString:@" tensor pm = const()[name=tensor(\"pm\"), val=tensor([0,1,3,2])];\n"];
    [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=tensor(\"rq\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor q = transpose(perm=pm,x=qr)[name=tensor(\"tq\")];\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf)[name=tensor(\"rk\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor k = transpose(perm=pm,x=kr)[name=tensor(\"tk\")];\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor vr = reshape(shape=rsh,x=vf)[name=tensor(\"rv\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor v = transpose(perm=pm,x=vr)[name=tensor(\"tv\")];\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor dr = reshape(shape=rsh,x=df)[name=tensor(\"rd\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor da = transpose(perm=pm,x=dr)[name=tensor(\"td\")];\n", HEADS,SEQ,HD];
    // Recompute probs = softmax(q·kᵀ * sc + mask) over the last axis.
    [m appendString:@" tensor bF = const()[name=tensor(\"bF\"), val=tensor(false)];\n"];
    [m appendString:@" tensor bT = const()[name=tensor(\"bT\"), val=tensor(true)];\n"];
    [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=tensor(\"mm1\")];\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor scv = const()[name=tensor(\"scv\"), val=tensor(%f)];\n", sc];
    [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv)[name=tensor(\"scl\")];\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor cm = const()[name=tensor(\"cm\"), val=tensor(BLOBFILE(path=tensor(\"@model_path/weights/mask.bin\"), offset=tensor(64)))];\n", SEQ,SEQ,SEQ,SEQ];
    [m appendFormat:@" tensor ms = add(x=sc2,y=cm)[name=tensor(\"msk\")];\n", HEADS,SEQ,SEQ];
    [m appendString:@" tensor sax = const()[name=tensor(\"sax\"), val=tensor(-1)];\n"];
    [m appendFormat:@" tensor probs = softmax(axis=sax,x=ms)[name=tensor(\"sm\")];\n", HEADS,SEQ,SEQ];
    // dv4 = probsᵀ · da (transpose_x is true);  dp4 = da · vᵀ (transpose_y is true).
    [m appendFormat:@" tensor dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=da)[name=tensor(\"dv\")];\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor dp4 = matmul(transpose_x=bF,transpose_y=bT,x=da,y=v)[name=tensor(\"dp\")];\n", HEADS,SEQ,SEQ];
    // Flatten dv back to [1,DIM,1,SEQ] and probs/dp4 to [1,SCORE_CH,1,SEQ].
    [m appendFormat:@" tensor dvt = transpose(perm=pm,x=dv4)[name=tensor(\"dvt\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor dvs = const()[name=tensor(\"dvs\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ];
    [m appendFormat:@" tensor dvf = reshape(shape=dvs,x=dvt)[name=tensor(\"dvf\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor scs = const()[name=tensor(\"scs\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH,SEQ];
    [m appendFormat:@" tensor pf = reshape(shape=scs,x=probs)[name=tensor(\"pf\")];\n", SCORE_CH,SEQ];
    [m appendFormat:@" tensor dpf = reshape(shape=scs,x=dp4)[name=tensor(\"dpf\")];\n", SCORE_CH,SEQ];
    // Pack the three results along axis 1.
    [m appendString:@" tensor cax = const()[name=tensor(\"cax\"), val=tensor(1)];\n"];
    [m appendString:@" tensor cid = const()[name=tensor(\"cid\"), val=tensor(false)];\n"];
    [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=tensor(\"cat\")];\n", DIM+2*SCORE_CH,SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// SDPA backward part 2: concat(probs,dp,Q,K) → concat(dQ,dK)
//
// Returns the program text. The input is sliced along axis 1 into pf and dpf
// (SCORE_CH wide each, at offsets 0 and SCORE_CH) and qf and kf (DIM wide
// each, at offsets 2*SCORE_CH and 2*SCORE_CH+DIM). This program uses no conv
// ops, so CONV_CONST is not appended.
static NSString *gen_sdpa_bwd2(void) {
    float sc = 1.0f/sqrtf((float)HD);
    // Total width of the packed input along axis 1.
    int bwd2_in = 2*SCORE_CH + 2*DIM;
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x) {\n", bwd2_in, SEQ];
    // Unpack probs and dp (flattened), then q and k.
    [m appendFormat:@" tensor sz_sc = const()[name=tensor(\"szsc\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH, SEQ];
    [m appendString:@" tensor b0 = const()[name=tensor(\"b0\"), val=tensor([0,0,0,0])];\n"];
    [m appendFormat:@" tensor pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=tensor(\"s0\")];\n", SCORE_CH,SEQ];
    [m appendFormat:@" tensor b1 = const()[name=tensor(\"b1\"), val=tensor([0,%d,0,0])];\n", SCORE_CH];
    [m appendFormat:@" tensor dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=tensor(\"s1\")];\n", SCORE_CH,SEQ];
    [m appendFormat:@" tensor sz_d = const()[name=tensor(\"szd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ];
    [m appendFormat:@" tensor b2 = const()[name=tensor(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH];
    [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=tensor(\"s2\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor b3 = const()[name=tensor(\"b3\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH+DIM];
    [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=tensor(\"s3\")];\n", DIM,SEQ];
    // Restore per-head shapes: probs/dp to [1,HEADS,SEQ,SEQ]; q/k to
    // [1,HEADS,HD,SEQ] followed by a swap of the last two axes.
    [m appendFormat:@" tensor ssh = const()[name=tensor(\"ssh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor probs = reshape(shape=ssh,x=pf)[name=tensor(\"rp\")];\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor dp = reshape(shape=ssh,x=dpf)[name=tensor(\"rdp\")];\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor rsh = const()[name=tensor(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
    [m appendString:@" tensor pm = const()[name=tensor(\"pm\"), val=tensor([0,1,3,2])];\n"];
    [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=tensor(\"rq\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor q = transpose(perm=pm,x=qr)[name=tensor(\"tq\")];\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf)[name=tensor(\"rk\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor k = transpose(perm=pm,x=kr)[name=tensor(\"tk\")];\n", HEADS,SEQ,HD];
    // Softmax backward: ds = probs * (dp - sum(probs * dp over last axis)) * sc.
    [m appendFormat:@" tensor pdp = mul(x=probs,y=dp)[name=tensor(\"pdp\")];\n", HEADS,SEQ,SEQ];
    [m appendString:@" tensor rax = const()[name=tensor(\"rax\"), val=tensor([-1])];\n"];
    [m appendString:@" tensor kd = const()[name=tensor(\"kd\"), val=tensor(true)];\n"];
    [m appendFormat:@" tensor spdp = reduce_sum(x=pdp,axes=rax,keep_dims=kd)[name=tensor(\"rs\")];\n", HEADS,SEQ];
    [m appendFormat:@" tensor dps = sub(x=dp,y=spdp)[name=tensor(\"dps\")];\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor ds0 = mul(x=probs,y=dps)[name=tensor(\"ds0\")];\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor scv = const()[name=tensor(\"scv\"), val=tensor(%f)];\n", sc];
    [m appendFormat:@" tensor ds = mul(x=ds0,y=scv)[name=tensor(\"ds\")];\n", HEADS,SEQ,SEQ];
    // dq4 = ds · k;  dk4 = dsᵀ · q (transpose_x is true).
    [m appendString:@" tensor bF = const()[name=tensor(\"bF\"), val=tensor(false)];\n"];
    [m appendString:@" tensor bT = const()[name=tensor(\"bT\"), val=tensor(true)];\n"];
    [m appendFormat:@" tensor dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=tensor(\"dq\")];\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=tensor(\"dk\")];\n", HEADS,SEQ,HD];
    // Merge heads back to [1,DIM,1,SEQ] and pack dQ, dK along axis 1.
    [m appendFormat:@" tensor dqt = transpose(perm=pm,x=dq4)[name=tensor(\"dqt\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor dkt = transpose(perm=pm,x=dk4)[name=tensor(\"dkt\")];\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor fs = const()[name=tensor(\"fs\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ];
    [m appendFormat:@" tensor dqf = reshape(shape=fs,x=dqt)[name=tensor(\"dqf\")];\n", DIM,SEQ];
    [m appendFormat:@" tensor dkf = reshape(shape=fs,x=dkt)[name=tensor(\"dkf\")];\n", DIM,SEQ];
    [m appendString:@" tensor cax = const()[name=tensor(\"cax\"), val=tensor(1)];\n"];
    [m appendString:@" tensor cid = const()[name=tensor(\"cid\"), val=tensor(false)];\n"];
    [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=tensor(\"cat\")];\n", 2*DIM,SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// Mask blob (causal mask [SEQ,SEQ])
static NSData *g_mask_blob = nil;
static NSData *get_mask_blob(void) {
    if (!g_mask_blob) {
        _Float16 *mask = (_Float16*)calloc(SEQ*SEQ, sizeof(_Float16));
        for(int t=0;t