// ane_rmsnorm_bwd.h — MIL generator for RMSNorm backward on ANE
// Replaces CPU rmsnorm_bwd() from stories_cpu_ops.h
//
// RMSNorm forward:  xn = x * rrms * w, where rrms = 1/sqrt(mean(x²) + eps)
// RMSNorm backward: dx = rrms * (dy * w - x * sum(dy*w*x) * invd * rrms²)
//   (w multiplies dy only; the correction term x * sum(...) is not scaled by w.)
//
// Input:  concat(dy, x) as [1, 2*DIM, 1, SEQ]
// Baked:  RMSNorm weights w [1, DIM, 1, 1] as BLOBFILE
// Output: dx [1, DIM, 1, SEQ]
//
// Note: dw (weight gradient) stays on CPU — it requires reduce_sum over SEQ
// and accumulation across steps, which is cheap and better done on CPU.
#pragma once
#include "stories_mil.h"

// Generate the MIL program text for RMSNorm backward.
//
// The emitted program:
//   1. slices the single input `inp` [1, 2*DIM, 1, SEQ] into dy (channels
//      0..DIM-1) and x (channels DIM..2*DIM-1),
//   2. recomputes rrms = (sum(x*x, axis=1) * invd + eps)^-0.5 per SEQ position,
//   3. loads the baked scale weights w from weights/rms_w.bin,
//   4. computes coeff = sum(dy*w*x, axis=1) * invd * rrms²,
//   5. outputs out = (dy*w - x*coeff) * rrms.
//
// Returns: a string starting with MIL_HDR followed by `func main` and the
//          closing braces. DIM and SEQ are read at call time and baked into
//          the text as literal shapes.
//
// Every format argument below is consumed by a conversion specifier; the
// tensor<dtype, [shape]> annotations carry the DIM/SEQ values into the program.
static NSString *gen_rmsnorm_bwd(void) {
    // 1/DIM, emitted as an fp16 scalar constant so sums become means.
    float invd = 1.0f / (float)DIM;

    // MIL type spellings that recur throughout the program.
    NSString *tFull = [NSString stringWithFormat:@"tensor<fp16, [1,%d,1,%d]>", DIM, SEQ]; // per-channel, per-position
    NSString *tRow  = [NSString stringWithFormat:@"tensor<fp16, [1,1,1,%d]>", SEQ];       // channel-reduced
    NSString *tW    = [NSString stringWithFormat:@"tensor<fp16, [1,%d,1,1]>", DIM];       // baked weights

    NSMutableString *m = [NSMutableString string];
    [m appendString:MIL_HDR];

    // Input: concat of dy and x along the channel dimension.
    // NOTE(review): the opset tag <ios18> is an assumption — confirm it matches
    // what MIL_HDR and the other generators in this project use.
    [m appendFormat:@"    func main<ios18>(tensor<fp16, [1,%d,1,%d]> inp) {\n", 2*DIM, SEQ];

    // Slice out dy [1, DIM, 1, SEQ] and x [1, DIM, 1, SEQ].
    [m appendFormat:@"        tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [m appendString:@"        tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [m appendFormat:@"        %@ dy = slice_by_size(x=inp,begin=b0,size=sz)[name=string(\"sdy\")];\n", tFull];
    [m appendFormat:@"        tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
    [m appendFormat:@"        %@ x = slice_by_size(x=inp,begin=b1,size=sz)[name=string(\"sx\")];\n", tFull];

    // Step 1: rrms = 1/sqrt(mean(x²) + eps)
    // sq = x * x
    [m appendFormat:@"        %@ sq = mul(x=x,y=x)[name=string(\"sq\")];\n", tFull];
    // ss = sum(sq, axis=1, keepdims=true) → [1,1,1,SEQ]
    [m appendString:@"        tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([1])];\n"];
    [m appendString:@"        bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [m appendFormat:@"        %@ ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", tRow];
    // ss2 = ss * invd; ss3 = ss2 + eps
    [m appendFormat:@"        fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd];
    [m appendFormat:@"        %@ ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", tRow];
    [m appendString:@"        fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"];
    [m appendFormat:@"        %@ ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", tRow];
    // rrms = pow(ss3, -0.5) → [1,1,1,SEQ]
    [m appendString:@"        fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"];
    [m appendFormat:@"        %@ rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", tRow];

    // Step 2: load RMSNorm weights w [1, DIM, 1, 1] from the baked blob.
    [m appendFormat:@"        %@ w = const()[name=string(\"w\"), val=%@(BLOBFILE(path=string(\"@model_path/weights/rms_w.bin\"), offset=uint64(64)))];\n", tW, tW];

    // Step 3: coeff = sum(dy * w * x, axis=1) * invd * rrms²
    // dyw = dy * w → [1, DIM, 1, SEQ] (w broadcasts along SEQ)
    [m appendFormat:@"        %@ dyw = mul(x=dy,y=w)[name=string(\"dyw\")];\n", tFull];
    // dywx = dyw * x → [1, DIM, 1, SEQ]
    [m appendFormat:@"        %@ dywx = mul(x=dyw,y=x)[name=string(\"dywx\")];\n", tFull];
    // dot_sum = sum(dywx, axis=1, keepdims=true) → [1,1,1,SEQ]
    [m appendFormat:@"        %@ dot_sum = reduce_sum(x=dywx,axes=rax,keep_dims=kd)[name=string(\"ds\")];\n", tRow];
    // dot_sc = dot_sum * invd → [1,1,1,SEQ]
    [m appendFormat:@"        %@ dot_sc = mul(x=dot_sum,y=invd)[name=string(\"dsc\")];\n", tRow];
    // rrms2 = rrms * rrms → [1,1,1,SEQ]
    [m appendFormat:@"        %@ rrms2 = mul(x=rrms,y=rrms)[name=string(\"rr2\")];\n", tRow];
    // coeff = dot_sc * rrms2 → [1,1,1,SEQ]
    [m appendFormat:@"        %@ coeff = mul(x=dot_sc,y=rrms2)[name=string(\"cof\")];\n", tRow];

    // Step 4: dx = (dy * w - x * coeff) * rrms
    // xc = x * coeff → [1, DIM, 1, SEQ] (coeff broadcasts along channels)
    [m appendFormat:@"        %@ xc = mul(x=x,y=coeff)[name=string(\"xc\")];\n", tFull];
    // diff = dyw - xc → [1, DIM, 1, SEQ]
    [m appendFormat:@"        %@ diff = sub(x=dyw,y=xc)[name=string(\"dif\")];\n", tFull];
    // out = diff * rrms → [1, DIM, 1, SEQ]
    [m appendFormat:@"        %@ out = mul(x=diff,y=rrms)[name=string(\"out\")];\n", tFull];

    // Close func main (returning `out`) and the enclosing program block.
    [m appendString:@"    } -> (out);\n}\n"];
    return m;
}