// stories_mil.h — MIL program generators for ANE kernels (Weights-as-Tensors version)
//
// Each gen_*_flex() function below builds the text of one MIL program in an
// NSMutableString and returns it. Weight matrices are declared as inputs of
// `func main` (not baked-in constants) and are consumed by `conv` ops.
// The size macros used here (DIM, SEQ, HEADS, HD, HIDDEN, SCORE_CH) are not
// defined in this file — presumably they come from stories_io.h; confirm.
//
// NOTE(review): many appendFormat: calls below pass size arguments (e.g.
// DIM, SEQ) while their format strings contain no %d conversions. Those
// arguments are silently ignored, and the emitted `tensor` declarations
// carry no dtype/shape. It looks as if `<dtype, [shape]>` type parameters were
// lost from these string literals — verify against the upstream file before
// relying on the generated MIL.
#pragma once
#include "stories_io.h"

// Program preamble: MIL version plus the buildInfo dictionary, ending with the
// opening "{" of the program body. Every generator closes the program by
// appending " } -> (out);\n}\n" after its last op.
#define MIL_HDR \
    @"program(1.3)\n[buildInfo = dict({{\"coremlc-component-MIL\", \"3510.2.1\"}, " \
    "{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " \
    "{\"coremltools-version\", \"9.0\"}})]\n{\n"

// Shared conv hyperparameters: pad_type "valid", strides [1,1], pads
// [0,0,0,0], dilations [1,1], groups 1. Appended once (as @CONV_CONST)
// before the first conv op of a generator; every conv references these names.
#define CONV_CONST \
    " string pt = const()[name=string(\"pt\"), val=string(\"valid\")];\n" \
    " tensor st = const()[name=string(\"st\"), val=tensor([1,1])];\n" \
    " tensor pd = const()[name=string(\"pd\"), val=tensor([0,0,0,0])];\n" \
    " tensor dl = const()[name=string(\"dl\"), val=tensor([1,1])];\n" \
    " int32 gr = const()[name=string(\"gr\"), val=int32(1)];\n"

// SDPA forward flex: x, rw, Wq, Wk, Wv, Wo, cm
//
// Emits:
//   xn  = RMSNorm(x) * rw             (mean of x^2 over axis 1, eps 1e-5)
//   qf/kf/vf = conv(xn) with Wq/Wk/Wv
//   aw  = softmax(q @ k^T * sc + cm)  over the last axis, per head
//   af  = (aw @ v) flattened back to [1,DIM,1,SEQ]
//   oo  = conv(af) with Wo
// Returns a single tensor `out` = concat along axis 1 of
// (oo, qf, kf, vf, af, xn), i.e. 6*DIM channels. The intermediates are
// presumably packed so the backward kernels can reuse them — TODO confirm.
// cm is added to the scores before softmax; presumably an additive causal
// mask (see get_mask_blob) — confirm.
static NSString *gen_sdpa_fwd_flex(void) {
    float sc = 1.0f/sqrtf((float)HD);   // attention scale 1/sqrt(HD)
    float invd = 1.0f/(float)DIM;       // turns the axis-1 sum of squares into a mean
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x, "
        "tensor rw, "
        "tensor Wq, "
        "tensor Wk, "
        "tensor Wv, "
        "tensor Wo, "
        "tensor cm) {\n", DIM, SEQ, DIM, DIM, DIM, DIM, DIM, DIM, DIM, DIM, DIM, SEQ, SEQ];
    // RMSNorm: rrms = (sum(x*x, axis 1)/DIM + 1e-5)^-0.5; xn = x * rrms * rw.
    [m appendFormat:@" tensor sq = mul(x=x,y=x);\n", DIM, SEQ];
    [m appendString:@" tensor rax = const()[name=string(\"rax\"), val=tensor([1])];\n"];
    [m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd);\n", SEQ];
    // NOTE(review): %f prints six fractional digits, so for large DIM the
    // emitted 1/DIM keeps few significant digits; consider %g if that matters.
    [m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd];
    [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd);\n", SEQ];
    [m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"];
    [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps);\n", SEQ];
    [m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"];
    [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf);\n", SEQ];
    [m appendFormat:@" tensor xr = mul(x=x,y=rrms);\n", DIM, SEQ];
    [m appendFormat:@" tensor xn = mul(x=xr,y=rw);\n", DIM, SEQ];
    // Q/K/V projections: the weights arrive as runtime tensors and are applied via conv.
    [m appendString:@CONV_CONST];
    [m appendFormat:@" tensor qf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wq,x=xn);\n", DIM,SEQ];
    [m appendFormat:@" tensor kf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wk,x=xn);\n", DIM,SEQ];
    [m appendFormat:@" tensor vf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wv,x=xn);\n", DIM,SEQ];
    // Split into heads: reshape to [1,HEADS,HD,SEQ], then perm [0,1,3,2]
    // swaps the last two axes -> [1,HEADS,SEQ,HD].
    [m appendFormat:@" tensor qsh = const()[name=string(\"qsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
    [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"];
    [m appendFormat:@" tensor q4 = reshape(shape=qsh,x=qf);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor q = transpose(perm=pm,x=q4);\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor k4 = reshape(shape=qsh,x=kf);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor k = transpose(perm=pm,x=k4);\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor v4 = reshape(shape=qsh,x=vf);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor v = transpose(perm=pm,x=v4);\n", HEADS,SEQ,HD];
    // Scores: q @ k^T (transpose_y = true), scaled by sc, plus mask cm,
    // then softmax over the last axis.
    [m appendString:@" bool tx = const()[name=string(\"tx\"), val=bool(false)];\n"];
    [m appendString:@" bool ty = const()[name=string(\"ty\"), val=bool(true)];\n"];
    [m appendFormat:@" tensor sc1 = matmul(transpose_x=tx,transpose_y=ty,x=q,y=k);\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
    [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv);\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor ms = add(x=sc2,y=cm);\n", HEADS,SEQ,SEQ];
    [m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
    [m appendFormat:@" tensor aw = softmax(axis=sax,x=ms);\n", HEADS,SEQ,SEQ];
    // Weighted values (no transposes; tx is false), back to [1,HEADS,HD,SEQ],
    // then flattened to [1,DIM,1,SEQ].
    [m appendFormat:@" tensor a4 = matmul(transpose_x=tx,transpose_y=tx,x=aw,y=v);\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor at = transpose(perm=pm,x=a4);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor os = const()[name=string(\"os\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ];
    [m appendFormat:@" tensor af = reshape(shape=os,x=at);\n", DIM,SEQ];
    // Output projection, then pack all results along axis 1 (6*DIM channels).
    [m appendFormat:@" tensor oo = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wo,x=af);\n", DIM,SEQ];
    [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn));\n", 6*DIM,SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// FFN forward flex: x, rw, W1, W2, W3
//
// Emits:
//   xn   = RMSNorm(x) * rw         (same op sequence as in gen_sdpa_fwd_flex)
//   h1   = conv(xn, W1),  h3 = conv(xn, W3)
//   gate = silu(h1) * h3,  where silu(h1) = h1 * sigmoid(h1)
//   y    = conv(gate, W2)
// Returns `out` = concat along axis 1 of (y, h1, h3, gate, xn),
// i.e. 2*DIM + 3*HIDDEN channels.
static NSString *gen_ffn_fwd_flex(void) {
    float invd = 1.0f/(float)DIM;   // turns the axis-1 sum of squares into a mean
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x, "
        "tensor rw, "
        "tensor W1, "
        "tensor W2, "
        "tensor W3) {\n", DIM, SEQ, DIM, HIDDEN, DIM, DIM, HIDDEN, HIDDEN, DIM];
    // RMSNorm: rrms = (sum(x*x, axis 1)/DIM + 1e-5)^-0.5; xn = x * rrms * rw.
    [m appendFormat:@" tensor sq = mul(x=x,y=x);\n", DIM, SEQ];
    [m appendString:@" tensor rax = const()[name=string(\"rax\"), val=tensor([1])];\n"];
    [m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd);\n", SEQ];
    [m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd];
    [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd);\n", SEQ];
    [m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"];
    [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps);\n", SEQ];
    [m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"];
    [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf);\n", SEQ];
    [m appendFormat:@" tensor xr = mul(x=x,y=rrms);\n", DIM, SEQ];
    [m appendFormat:@" tensor xn = mul(x=xr,y=rw);\n", DIM, SEQ];
    // Up/gate projections and the SiLU-gated product.
    [m appendString:@CONV_CONST];
    [m appendFormat:@" tensor h1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1,x=xn);\n", HIDDEN,SEQ];
    [m appendFormat:@" tensor h3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3,x=xn);\n", HIDDEN,SEQ];
    [m appendFormat:@" tensor sig = sigmoid(x=h1);\n", HIDDEN,SEQ];
    [m appendFormat:@" tensor silu = mul(x=h1,y=sig);\n", HIDDEN,SEQ];
    [m appendFormat:@" tensor gate = mul(x=silu,y=h3);\n", HIDDEN,SEQ];
    // Down projection, then pack y with the intermediates along axis 1.
    [m appendFormat:@" tensor y = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2,x=gate);\n", DIM,SEQ];
    [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(y,h1,h3,gate,xn));\n", 2*DIM+3*HIDDEN,SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// FFN backward flex: x, W1t, W2t, W3t
//
// Input x packs, along axis 1: [dffn (DIM) | h1 (HIDDEN) | h3 (HIDDEN)]
// (slice offsets 0, DIM, DIM+HIDDEN). dffn is presumably the upstream
// gradient of the FFN output, and W*t the transposed FFN weights — confirm.
// Emits:
//   dsilu = conv(dffn, W2t)
//   dh1   = dsilu * h3 * sig * (1 + h1 * (1 - sig))    (sig = sigmoid(h1))
//   dh3   = dsilu * h1 * sig
//   dx    = conv(dh1, W1t) + conv(dh3, W3t)
// Returns `out` = concat along axis 1 of (dx, dh1, dh3), DIM + 2*HIDDEN channels.
static NSString *gen_ffn_bwd_flex(void) {
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x, "
        "tensor W1t, "
        "tensor W2t, "
        "tensor W3t) {\n", DIM+2*HIDDEN, SEQ, DIM, HIDDEN, HIDDEN, DIM, DIM, HIDDEN];
    [m appendString:@CONV_CONST];
    // Unpack the three channel ranges of x.
    [m appendString:@" tensor bd = const()[name=string(\"bd\"), val=tensor([0,0,0,0])];\n"];
    [m appendFormat:@" tensor sd = const()[name=string(\"sd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ];
    [m appendFormat:@" tensor dffn = slice_by_size(x=x,begin=bd,size=sd);\n", DIM, SEQ];
    [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM];
    [m appendFormat:@" tensor s1 = const()[name=string(\"s1\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor h1 = slice_by_size(x=x,begin=b1,size=s1);\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", DIM+HIDDEN];
    [m appendFormat:@" tensor h3 = slice_by_size(x=x,begin=b3,size=s1);\n", HIDDEN, SEQ];
    // Gradient flowing into the gate product.
    [m appendFormat:@" tensor dsilu = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2t,x=dffn);\n", HIDDEN, SEQ];
    // dsd = d silu(h1) / d h1 = sig * (1 + h1 * (1 - sig)).
    [m appendFormat:@" tensor sig = sigmoid(x=h1);\n", HIDDEN, SEQ];
    [m appendString:@" fp16 one = const()[name=string(\"one\"), val=fp16(1.0)];\n"];
    [m appendFormat:@" tensor oms = sub(x=one,y=sig);\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor homs = mul(x=h1,y=oms);\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor brk = add(x=one,y=homs);\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor dsd = mul(x=sig,y=brk);\n", HIDDEN, SEQ];
    // dh1 = dsilu * h3 * dsd;  dh3 = dsilu * silu(h1), recomputing silu as h1 * sig.
    [m appendFormat:@" tensor t1 = mul(x=dsilu,y=h3);\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor dh1 = mul(x=t1,y=dsd);\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor slh = mul(x=h1,y=sig);\n", HIDDEN, SEQ];
    [m appendFormat:@" tensor dh3 = mul(x=dsilu,y=slh);\n", HIDDEN, SEQ];
    // Input gradient: both branches are projected back to DIM channels and summed.
    [m appendFormat:@" tensor dx1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1t,x=dh1);\n", DIM, SEQ];
    [m appendFormat:@" tensor dx3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3t,x=dh3);\n", DIM, SEQ];
    [m appendFormat:@" tensor dx = add(x=dx1,y=dx3);\n", DIM, SEQ];
    [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dx,dh1,dh3));\n", DIM+2*HIDDEN, SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// QKV backward flex: x, Wqt, Wkt, Wvt
//
// Input x packs [dq | dk | dv] along axis 1, DIM channels each (offsets 0,
// DIM, 2*DIM). Returns `out` = conv(dq, Wqt) + conv(dk, Wkt) + conv(dv, Wvt),
// i.e. DIM channels — the three projection gradients summed into one.
static NSString *gen_qkvb_flex(void) {
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x, "
        "tensor Wqt, "
        "tensor Wkt, "
        "tensor Wvt) {\n", 3*DIM, SEQ, DIM, DIM, DIM, DIM, DIM, DIM];
    [m appendString:@CONV_CONST];
    // Unpack dq/dk/dv; all three slices share the size [1,DIM,1,SEQ].
    [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ];
    [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"];
    [m appendFormat:@" tensor dq = slice_by_size(x=x,begin=b0,size=sz);\n", DIM,SEQ];
    [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM];
    [m appendFormat:@" tensor dk = slice_by_size(x=x,begin=b1,size=sz);\n", DIM,SEQ];
    [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM];
    [m appendFormat:@" tensor dv = slice_by_size(x=x,begin=b2,size=sz);\n", DIM,SEQ];
    // Project each gradient through its transposed weight, then sum.
    [m appendFormat:@" tensor dxq = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wqt,x=dq);\n", DIM, SEQ];
    [m appendFormat:@" tensor dxk = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wkt,x=dk);\n", DIM, SEQ];
    [m appendFormat:@" tensor dxv = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wvt,x=dv);\n", DIM, SEQ];
    [m appendFormat:@" tensor dxqk = add(x=dxq,y=dxk);\n", DIM,SEQ];
    [m appendFormat:@" tensor out = add(x=dxqk,y=dxv);\n", DIM,SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// SDPA backward part 1 flex: x, Wot, cm
//
// Input x packs [qf | kf | vf | dx2f] along axis 1, DIM channels each
// (offsets 0, DIM, 2*DIM, 3*DIM).
// Emits:
//   df    = conv(dx2f, Wot)               (presumably Wot = Wo^T — confirm)
//   probs = softmax(q @ k^T * sc + cm)    (same ops as gen_sdpa_fwd_flex)
//   dv    = probs^T @ da
//   dp    = da @ v^T
// where da is df split into heads like q/k/v.
// Returns `out` = concat along axis 1 of (dvf, pf, dpf):
//   dv flattened to [1,DIM,1,SEQ], and
//   probs and dp each reshaped to [1,SCORE_CH,1,SEQ],
// for DIM + 2*SCORE_CH channels.
// NOTE(review): reshaping [1,HEADS,SEQ,SEQ] to [1,SCORE_CH,1,SEQ] requires
// SCORE_CH == HEADS*SEQ; presumably stories_io.h defines it that way — confirm.
static NSString *gen_sdpa_bwd1_flex(void) {
    float sc = 1.0f/sqrtf((float)HD);   // attention scale 1/sqrt(HD), as in the forward pass
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x, "
        "tensor Wot, "
        "tensor cm) {\n", 4*DIM, SEQ, DIM, DIM, SEQ, SEQ];
    [m appendString:@CONV_CONST];
    // Unpack the four DIM-channel ranges of x.
    [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ];
    [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"];
    [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b0,size=sz);\n", DIM,SEQ];
    [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM];
    [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b1,size=sz);\n", DIM,SEQ];
    [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM];
    [m appendFormat:@" tensor vf = slice_by_size(x=x,begin=b2,size=sz);\n", DIM,SEQ];
    [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 3*DIM];
    [m appendFormat:@" tensor dx2f = slice_by_size(x=x,begin=b3,size=sz);\n", DIM,SEQ];
    [m appendFormat:@" tensor df = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wot,x=dx2f);\n", DIM, SEQ];
    // Split q, k, v and df into heads: [1,HEADS,HD,SEQ] -> perm [0,1,3,2] -> [1,HEADS,SEQ,HD].
    [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
    [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"];
    [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor q = transpose(perm=pm,x=qr);\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor k = transpose(perm=pm,x=kr);\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor vr = reshape(shape=rsh,x=vf);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor v = transpose(perm=pm,x=vr);\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor dr = reshape(shape=rsh,x=df);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor da = transpose(perm=pm,x=dr);\n", HEADS,SEQ,HD];
    // Recompute the attention probabilities instead of reading them from the forward pass.
    [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
    [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k);\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
    [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv);\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor ms = add(x=sc2,y=cm);\n", HEADS,SEQ,SEQ];
    [m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
    [m appendFormat:@" tensor probs = softmax(axis=sax,x=ms);\n", HEADS,SEQ,SEQ];
    // dv = probs^T @ da;  dp = da @ v^T.
    [m appendFormat:@" tensor dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=da);\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor dp4 = matmul(transpose_x=bF,transpose_y=bT,x=da,y=v);\n", HEADS,SEQ,SEQ];
    // Flatten dv back to [1,DIM,1,SEQ]; flatten probs/dp to [1,SCORE_CH,1,SEQ].
    [m appendFormat:@" tensor dvt = transpose(perm=pm,x=dv4);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor dvs = const()[name=string(\"dvs\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ];
    [m appendFormat:@" tensor dvf = reshape(shape=dvs,x=dvt);\n", DIM,SEQ];
    [m appendFormat:@" tensor scs = const()[name=string(\"scs\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH,SEQ];
    [m appendFormat:@" tensor pf = reshape(shape=scs,x=probs);\n", SCORE_CH,SEQ];
    [m appendFormat:@" tensor dpf = reshape(shape=scs,x=dp4);\n", SCORE_CH,SEQ];
    [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf));\n", DIM+2*SCORE_CH,SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// SDPA backward part 2 flex: single input x, no weight tensors.
//
// Input x packs, along axis 1:
//   [pf (SCORE_CH) | dpf (SCORE_CH) | qf (DIM) | kf (DIM)]
// pf/dpf presumably come from the pf/dpf outputs of the part-1 kernel — confirm.
// Emits the softmax backward and the Q/K gradients:
//   ds = probs * (dp - sum(probs * dp, last axis)) * sc
//   dq = ds @ k
//   dk = ds^T @ q
// Returns `out` = concat along axis 1 of (dqf, dkf), 2*DIM channels.
static NSString *gen_sdpa_bwd2_flex(void) {
    float sc = 1.0f/sqrtf((float)HD);    // attention scale 1/sqrt(HD)
    int bwd2_in = 2*SCORE_CH + 2*DIM;    // total input channels of x
    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];
    [m appendFormat:@" func main(tensor x) {\n", bwd2_in, SEQ];
    // Unpack probs/dp (SCORE_CH channels each) and q/k (DIM channels each).
    [m appendFormat:@" tensor sz_sc = const()[name=string(\"szsc\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH, SEQ];
    [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"];
    [m appendFormat:@" tensor pf = slice_by_size(x=x,begin=b0,size=sz_sc);\n", SCORE_CH,SEQ];
    [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", SCORE_CH];
    [m appendFormat:@" tensor dpf = slice_by_size(x=x,begin=b1,size=sz_sc);\n", SCORE_CH,SEQ];
    [m appendFormat:@" tensor sz_d = const()[name=string(\"szd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ];
    [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH];
    [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b2,size=sz_d);\n", DIM,SEQ];
    [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH+DIM];
    [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b3,size=sz_d);\n", DIM,SEQ];
    // Restore per-head score matrices [1,HEADS,SEQ,SEQ].
    [m appendFormat:@" tensor ssh = const()[name=string(\"ssh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor probs = reshape(shape=ssh,x=pf);\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor dp = reshape(shape=ssh,x=dpf);\n", HEADS,SEQ,SEQ];
    // Split q and k into heads: [1,HEADS,HD,SEQ] -> perm [0,1,3,2] -> [1,HEADS,SEQ,HD].
    [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
    [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"];
    [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor q = transpose(perm=pm,x=qr);\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor k = transpose(perm=pm,x=kr);\n", HEADS,SEQ,HD];
    // Softmax backward: ds = probs * (dp - rowsum(probs * dp)), then scaled by sc.
    [m appendFormat:@" tensor pdp = mul(x=probs,y=dp);\n", HEADS,SEQ,SEQ];
    [m appendString:@" tensor rax = const()[name=string(\"rax\"), val=tensor([-1])];\n"];
    [m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [m appendFormat:@" tensor spdp = reduce_sum(x=pdp,axes=rax,keep_dims=kd);\n", HEADS,SEQ];
    [m appendFormat:@" tensor dps = sub(x=dp,y=spdp);\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" tensor ds0 = mul(x=probs,y=dps);\n", HEADS,SEQ,SEQ];
    [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
    [m appendFormat:@" tensor ds = mul(x=ds0,y=scv);\n", HEADS,SEQ,SEQ];
    // dq = ds @ k;  dk = ds^T @ q.
    [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
    [m appendFormat:@" tensor dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k);\n", HEADS,SEQ,HD];
    [m appendFormat:@" tensor dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q);\n", HEADS,SEQ,HD];
    // Merge heads back to [1,DIM,1,SEQ] and pack along axis 1.
    [m appendFormat:@" tensor dqt = transpose(perm=pm,x=dq4);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor dkt = transpose(perm=pm,x=dk4);\n", HEADS,HD,SEQ];
    [m appendFormat:@" tensor fs = const()[name=string(\"fs\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ];
    [m appendFormat:@" tensor dqf = reshape(shape=fs,x=dqt);\n", DIM,SEQ];
    [m appendFormat:@" tensor dkf = reshape(shape=fs,x=dkt);\n", DIM,SEQ];
    [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dqf,dkf));\n", 2*DIM,SEQ];
    [m appendString:@" } -> (out);\n}\n"];
    return m;
}

// Mask blob helper
static NSData *get_mask_blob(void) {
    _Float16 *mask = (_Float16*)calloc(SEQ*SEQ, sizeof(_Float16));
    for(int t=0;t