diff --git a/training/training_dynamic/config.h b/training/training_dynamic/config.h index d66d045..d22b6f1 100644 --- a/training/training_dynamic/config.h +++ b/training/training_dynamic/config.h @@ -62,6 +62,18 @@ typedef struct { // ANE kernel handle typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; } Kern; +// Per-layer IOSurfaces for pre-staged weights +typedef struct { + IOSurfaceRef sdpaFwd_in, ffnFused_in; + IOSurfaceRef ffnBwdW2t_in, ffnBwdW13t_in, wotBwd_in, qkvBwd_in; +} PerLayerSurfaces; + +// Per-layer ANE requests (bound to per-layer IOSurfaces) +typedef struct { + void *sdpaFwd, *ffnFused; + void *ffnBwdW2t, *ffnBwdW13t, *wotBwd, *qkvBwd; +} PerLayerRequests; + // Checkpoint header typedef struct { int magic, version, step, total_steps; diff --git a/training/training_dynamic/cpu_ops.h b/training/training_dynamic/cpu_ops.h index aed7e6f..5c446e5 100644 --- a/training/training_dynamic/cpu_ops.h +++ b/training/training_dynamic/cpu_ops.h @@ -53,13 +53,13 @@ static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, c free(ss); free(rrms); free(dot); } -static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) { +static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps, float wd) { float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t); for (size_t i=0; in; i++) { s->m[i] = b1*s->m[i] + (1-b1)*g[i]; s->v[i] = b2*s->v[i] + (1-b2)*g[i]*g[i]; float mh = s->m[i]/bc1, vh = s->v[i]/bc2; - w[i] -= lr * mh / (sqrtf(vh) + eps); + w[i] -= lr * (mh / (sqrtf(vh) + eps) + wd * w[i]); } } diff --git a/training/training_dynamic/io.h b/training/training_dynamic/io.h index 0a6969e..776e4b7 100644 --- a/training/training_dynamic/io.h +++ b/training/training_dynamic/io.h @@ -74,17 +74,17 @@ static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int IOSurfaceUnlock(s, 0, NULL); } -// fp32 
IOSurface I/O (for dynamic matmul kernels that use fp32 input/output) +// fp16 IOSurface I/O (for dynamic matmul kernels with fp16 input/output) // Layout: [1, IC, 1, SP] where SP = SEQ + OC // Write activations at sp[0:SEQ] and weights at sp[SEQ:SEQ+OC] static void io_write_dyn(IOSurfaceRef s, const float *act, int ic, int seq, const float *W, int oc) { int sp = seq + oc; IOSurfaceLock(s, 0, NULL); - float *buf = (float*)IOSurfaceGetBaseAddress(s); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); for (int d = 0; d < ic; d++) { - memcpy(buf + d*sp, act + d*seq, seq*4); - memcpy(buf + d*sp + seq, W + d*oc, oc*4); + cvt_f32_f16(buf + d*sp, act + d*seq, seq); + cvt_f32_f16(buf + d*sp + seq, W + d*oc, oc); } IOSurfaceUnlock(s, 0, NULL); } @@ -92,7 +92,7 @@ static void io_write_dyn(IOSurfaceRef s, const float *act, int ic, int seq, // Read output from dynamic matmul kernel: [1, OC, 1, SEQ] static void io_read_dyn(IOSurfaceRef s, float *out, int oc, int seq) { IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL); - memcpy(out, (float*)IOSurfaceGetBaseAddress(s), oc * seq * 4); + cvt_f16_f32(out, (_Float16*)IOSurfaceGetBaseAddress(s), oc * seq); IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL); } @@ -145,3 +145,201 @@ static void ane_eval(Kern *k) { id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; NSError *e = nil; ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e); } + +// Evaluate with a per-layer request (different ioIn, same model) +static void ane_eval_req(Kern *k, void *request) { + id mdl = (__bridge id)k->model; id req = (__bridge id)request; NSError *e = nil; + ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e); +} + +// Create an ANE request binding a custom ioIn to a kernel's model+ioOut +static void *make_request(Kern *k, IOSurfaceRef ioIn) { + id wI = 
((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), ioIn); + id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut); + id req = ((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR, + @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:), + @[wI], @[@0], @[wO], @[@0], nil, nil, @0); + return (void*)CFBridgingRetain(req); +} + +// ===== Per-layer weight staging (write once, reuse across steps) ===== +// All surfaces are now fp16 — staging converts fp32 weights to fp16 + +// sdpaFwd: [1, DIM, 1, SEQ+4*DIM] fp16 — weights at sp[SEQ:] +static void stage_sdpa_fwd_weights(IOSurfaceRef s, const float *Wq, const float *Wk, + const float *Wv, const float *Wo) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = SEQ + 4*DIM; + for (int d = 0; d < DIM; d++) { + cvt_f32_f16(buf + d*sp + SEQ, Wq + d*DIM, DIM); + cvt_f32_f16(buf + d*sp + SEQ+DIM, Wk + d*DIM, DIM); + cvt_f32_f16(buf + d*sp + SEQ+2*DIM, Wv + d*DIM, DIM); + cvt_f32_f16(buf + d*sp + SEQ+3*DIM, Wo + d*DIM, DIM); + } + IOSurfaceUnlock(s, 0, NULL); +} +static void write_sdpa_fwd_acts(IOSurfaceRef s, const float *xnorm) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = SEQ + 4*DIM; + for (int d = 0; d < DIM; d++) + cvt_f32_f16(buf + d*sp, xnorm + d*SEQ, SEQ); + IOSurfaceUnlock(s, 0, NULL); +} + +// ffnFused: [1, DIM, 1, 2*SEQ+3*HIDDEN] fp16 +static void stage_ffn_fused_weights(IOSurfaceRef s, + const float *W1t, const float *W3t, const float *W2_orig) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = 2*SEQ + 3*HIDDEN; + for (int d = 0; d < DIM; d++) { + cvt_f32_f16(buf + d*sp + 2*SEQ, W1t + d*HIDDEN, HIDDEN); + cvt_f32_f16(buf + d*sp + 2*SEQ+HIDDEN, W3t + d*HIDDEN, HIDDEN); + cvt_f32_f16(buf + d*sp + 2*SEQ+2*HIDDEN, W2_orig + d*HIDDEN, 
HIDDEN); + } + IOSurfaceUnlock(s, 0, NULL); +} +static void write_ffn_fused_acts(IOSurfaceRef s, const float *x2norm, const float *x2) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = 2*SEQ + 3*HIDDEN; + for (int d = 0; d < DIM; d++) { + cvt_f32_f16(buf + d*sp, x2norm + d*SEQ, SEQ); + cvt_f32_f16(buf + d*sp + SEQ, x2 + d*SEQ, SEQ); + } + IOSurfaceUnlock(s, 0, NULL); +} + +// ffnW13: [1, DIM, 1, SEQ+2*HIDDEN] fp16 +static void stage_ffn_w13_weights(IOSurfaceRef s, const float *W1, const float *W3) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = SEQ + 2*HIDDEN; + for (int d = 0; d < DIM; d++) { + cvt_f32_f16(buf + d*sp + SEQ, W1 + d*HIDDEN, HIDDEN); + cvt_f32_f16(buf + d*sp + SEQ+HIDDEN, W3 + d*HIDDEN, HIDDEN); + } + IOSurfaceUnlock(s, 0, NULL); +} +static void write_ffn_w13_acts(IOSurfaceRef s, const float *xnorm) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = SEQ + 2*HIDDEN; + for (int d = 0; d < DIM; d++) + cvt_f32_f16(buf + d*sp, xnorm + d*SEQ, SEQ); + IOSurfaceUnlock(s, 0, NULL); +} + +// ffnW2: [1, HIDDEN, 1, SEQ+DIM] fp16 +static void stage_ffn_w2_weights(IOSurfaceRef s, const float *W2) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = SEQ + DIM; + for (int d = 0; d < HIDDEN; d++) + cvt_f32_f16(buf + d*sp + SEQ, W2 + d*DIM, DIM); + IOSurfaceUnlock(s, 0, NULL); +} +static void write_ffn_w2_acts(IOSurfaceRef s, const float *gate) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = SEQ + DIM; + for (int d = 0; d < HIDDEN; d++) + cvt_f32_f16(buf + d*sp, gate + d*SEQ, SEQ); + IOSurfaceUnlock(s, 0, NULL); +} + +// ffnBwdW2t: [1, DIM, 1, SEQ+HIDDEN] fp16 +static void stage_ffn_bwd_w2t_weights(IOSurfaceRef s, const float *W2) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = 
SEQ + HIDDEN; + for (int d = 0; d < DIM; d++) + cvt_f32_f16(buf + d*sp + SEQ, W2 + d*HIDDEN, HIDDEN); + IOSurfaceUnlock(s, 0, NULL); +} +static void write_ffn_bwd_w2t_acts(IOSurfaceRef s, const float *dffn) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = SEQ + HIDDEN; + for (int d = 0; d < DIM; d++) + cvt_f32_f16(buf + d*sp, dffn + d*SEQ, SEQ); + IOSurfaceUnlock(s, 0, NULL); +} + +// ffnBwdW13t: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp16 +static void stage_ffn_bwd_w13t_weights(IOSurfaceRef s, const float *W1, const float *W3) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = 2*SEQ + 2*DIM; + for (int d = 0; d < HIDDEN; d++) { + cvt_f32_f16(buf + d*sp + 2*SEQ, W1 + d*DIM, DIM); + cvt_f32_f16(buf + d*sp + 2*SEQ + DIM, W3 + d*DIM, DIM); + } + IOSurfaceUnlock(s, 0, NULL); +} +static void write_ffn_bwd_w13t_acts(IOSurfaceRef s, const float *dh1, const float *dh3) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = 2*SEQ + 2*DIM; + for (int d = 0; d < HIDDEN; d++) { + cvt_f32_f16(buf + d*sp, dh1 + d*SEQ, SEQ); + cvt_f32_f16(buf + d*sp + SEQ, dh3 + d*SEQ, SEQ); + } + IOSurfaceUnlock(s, 0, NULL); +} + +// wotBwd: [1, DIM, 1, SEQ+DIM] fp16 +static void stage_wot_bwd_weights(IOSurfaceRef s, const float *Wo) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = SEQ + DIM; + for (int d = 0; d < DIM; d++) + cvt_f32_f16(buf + d*sp + SEQ, Wo + d*DIM, DIM); + IOSurfaceUnlock(s, 0, NULL); +} +static void write_wot_bwd_acts(IOSurfaceRef s, const float *dx2) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = SEQ + DIM; + for (int d = 0; d < DIM; d++) + cvt_f32_f16(buf + d*sp, dx2 + d*SEQ, SEQ); + IOSurfaceUnlock(s, 0, NULL); +} + +// qkvBwd: [1, DIM, 1, 3*SEQ+3*DIM] fp16 +static void stage_qkv_bwd_weights(IOSurfaceRef s, const float *Wq, const float *Wk, 
const float *Wv) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = 3*SEQ + 3*DIM; + for (int d = 0; d < DIM; d++) { + cvt_f32_f16(buf + d*sp + 3*SEQ, Wq + d*DIM, DIM); + cvt_f32_f16(buf + d*sp + 3*SEQ + DIM, Wk + d*DIM, DIM); + cvt_f32_f16(buf + d*sp + 3*SEQ + 2*DIM, Wv + d*DIM, DIM); + } + IOSurfaceUnlock(s, 0, NULL); +} +static void write_qkv_bwd_acts(IOSurfaceRef s, const float *dq, const float *dk, const float *dv) { + IOSurfaceLock(s, 0, NULL); + _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s); + int sp = 3*SEQ + 3*DIM; + for (int d = 0; d < DIM; d++) { + cvt_f32_f16(buf + d*sp, dq + d*SEQ, SEQ); + cvt_f32_f16(buf + d*sp + SEQ, dk + d*SEQ, SEQ); + cvt_f32_f16(buf + d*sp + 2*SEQ, dv + d*SEQ, SEQ); + } + IOSurfaceUnlock(s, 0, NULL); +} + +// Free per-layer surfaces and requests +static void free_per_layer(PerLayerSurfaces *pls, PerLayerRequests *plr) { + for (int L = 0; L < NLAYERS; L++) { + CFRelease(pls[L].sdpaFwd_in); CFRelease(pls[L].ffnFused_in); + CFRelease(pls[L].ffnBwdW2t_in); CFRelease(pls[L].ffnBwdW13t_in); + CFRelease(pls[L].wotBwd_in); CFRelease(pls[L].qkvBwd_in); + CFRelease(plr[L].sdpaFwd); CFRelease(plr[L].ffnFused); + CFRelease(plr[L].ffnBwdW2t); CFRelease(plr[L].ffnBwdW13t); + CFRelease(plr[L].wotBwd); CFRelease(plr[L].qkvBwd); + } +} diff --git a/training/training_dynamic/mil_dynamic.h b/training/training_dynamic/mil_dynamic.h index e6c5798..20746a2 100644 --- a/training/training_dynamic/mil_dynamic.h +++ b/training/training_dynamic/mil_dynamic.h @@ -45,25 +45,21 @@ static void gen_dyn_matmul(NSMutableString *m, const char *prefix, } // ===== Dynamic matmul kernel: y = x @ W ===== -// Input: [1, IC, 1, SEQ+OC] fp32 — act[0:SEQ] + W[SEQ:SEQ+OC] -// Output: [1, OC, 1, SEQ] fp32 +// Input: [1, IC, 1, SEQ+OC] fp16 — act[0:SEQ] + W[SEQ:SEQ+OC] +// Output: [1, OC, 1, SEQ] fp16 static NSString *gen_dyn_matmul_mil(int ic, int oc, int seq) { NSMutableString *m = [NSMutableString string]; [m 
appendString:MIL_HDR]; int sp = seq + oc; - [m appendFormat:@" func main(tensor x) {\n", ic, sp]; - [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"]; - [m appendFormat:@" tensor xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", ic, sp]; - gen_dyn_matmul(m, "mm", ic, oc, seq, 0, seq, "xh"); - [m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; - [m appendFormat:@" tensor y = cast(dtype=to32,x=mm_y)[name=string(\"cout\")];\n", oc, seq]; - [m appendString:@" } -> (y);\n}\n"]; + [m appendFormat:@" func main(tensor x) {\n", ic, sp]; + gen_dyn_matmul(m, "mm", ic, oc, seq, 0, seq, "x"); + [m appendString:@" } -> (mm_y);\n}\n"]; return m; } // ===== SDPA forward (dynamic weights) ===== // Replaces gen_sdpa_fwd_taps: RMSNorm done on CPU, this kernel does QKV matmul + SDPA + Wo matmul -// Input: [1, DIM, 1, SEQ + 4*DIM] fp32 +// Input: [1, DIM, 1, SEQ + 4*DIM] fp16 // sp[0:SEQ] = xnorm (rmsnorm output, DIM channels) // sp[SEQ:SEQ+DIM] = Wq[DIM,DIM] // sp[SEQ+DIM:SEQ+2D] = Wk[DIM,DIM] @@ -77,32 +73,29 @@ static NSString *gen_sdpa_fwd_dynamic(void) { int sp_in = SEQ + w_total; NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; - // Cast to fp16 - [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"]; - [m appendFormat:@" tensor xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in]; + [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; // Slice xnorm [1,DIM,1,SEQ] [m appendString:@" tensor bx = const()[name=string(\"bx\"), val=tensor([0,0,0,0])];\n"]; [m appendFormat:@" tensor sx = const()[name=string(\"sx\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor xn = slice_by_size(x=x,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ]; 
// Slice Wq [1,DIM,1,DIM] [m appendFormat:@" tensor bq = const()[name=string(\"bq\"), val=tensor([0,0,0,%d])];\n", SEQ]; [m appendFormat:@" tensor sw = const()[name=string(\"sw\"), val=tensor([1,%d,1,%d])];\n", DIM, DIM]; - [m appendFormat:@" tensor Wq = slice_by_size(x=xh,begin=bq,size=sw)[name=string(\"Wq\")];\n", DIM, DIM]; + [m appendFormat:@" tensor Wq = slice_by_size(x=x,begin=bq,size=sw)[name=string(\"Wq\")];\n", DIM, DIM]; // Slice Wk [m appendFormat:@" tensor bk = const()[name=string(\"bk\"), val=tensor([0,0,0,%d])];\n", SEQ+DIM]; - [m appendFormat:@" tensor Wk = slice_by_size(x=xh,begin=bk,size=sw)[name=string(\"Wk\")];\n", DIM, DIM]; + [m appendFormat:@" tensor Wk = slice_by_size(x=x,begin=bk,size=sw)[name=string(\"Wk\")];\n", DIM, DIM]; // Slice Wv [m appendFormat:@" tensor bv = const()[name=string(\"bv\"), val=tensor([0,0,0,%d])];\n", SEQ+2*DIM]; - [m appendFormat:@" tensor Wv = slice_by_size(x=xh,begin=bv,size=sw)[name=string(\"Wv\")];\n", DIM, DIM]; + [m appendFormat:@" tensor Wv = slice_by_size(x=x,begin=bv,size=sw)[name=string(\"Wv\")];\n", DIM, DIM]; // Slice Wo [m appendFormat:@" tensor bo = const()[name=string(\"bo\"), val=tensor([0,0,0,%d])];\n", SEQ+3*DIM]; - [m appendFormat:@" tensor Wo = slice_by_size(x=xh,begin=bo,size=sw)[name=string(\"Wo\")];\n", DIM, DIM]; + [m appendFormat:@" tensor Wo = slice_by_size(x=x,begin=bo,size=sw)[name=string(\"Wo\")];\n", DIM, DIM]; // Reshape for matmul: [1,D,1,S] → [1,1,D,S] → [1,1,S,D] [m appendFormat:@" tensor r2 = const()[name=string(\"r2\"), val=tensor([1,1,%d,%d])];\n", DIM, SEQ]; @@ -173,10 +166,7 @@ static NSString *gen_sdpa_fwd_dynamic(void) { [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ]; - // Cast to fp32 - [m appendString:@" string to32 = 
const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; - [m appendFormat:@" tensor out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", 6*DIM, SEQ]; - [m appendString:@" } -> (out32);\n}\n"]; + [m appendString:@" } -> (out);\n}\n"]; return m; } @@ -257,6 +247,101 @@ static NSString *gen_ffn_w13_dynamic(void) { return m; } +// ===== Fused FFN forward: W1,W3 + SiLU + W2 + residual ===== +// RMSNorm stays on CPU (ANE can't handle RMS + 3 matmuls without BNNS fallback) +// Replaces: ffnW13 + CPU gate read + ffnW2 + CPU residual +// Input: [1, DIM, 1, 2*SEQ + 3*HIDDEN] fp16 +// sp[0:SEQ] = x2norm (RMSNorm output, from CPU) +// sp[SEQ:2*SEQ] = x2 (residual, for x_next = x2 + ffn_out) +// sp[2*SEQ : 2*SEQ+HIDDEN] = W1t[DIM,HIDDEN] +// sp[2*SEQ+HIDDEN : 2*SEQ+2*HIDDEN] = W3t[DIM,HIDDEN] +// sp[2*SEQ+2*HIDDEN : 2*SEQ+3*HIDDEN] = W2_orig[DIM,HIDDEN] (transposed inside kernel) +// Output: [1, DIM + 3*HIDDEN, 1, SEQ] fp16 +// = concat(x_next[DIM], h1[HIDDEN], h3[HIDDEN], silu_out[HIDDEN]) +static NSString *gen_ffn_fused_dynamic(void) { + int sp_in = 2*SEQ + 3*HIDDEN; + int out_ch = DIM + 3*HIDDEN; + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; + [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; + + // Slice x2norm [DIM, SEQ] — RMSNorm output (computed on CPU) + [m appendString:@" tensor b_xn = const()[name=string(\"b_xn\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor s_ds = const()[name=string(\"s_ds\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor x2norm = slice_by_size(x=x,begin=b_xn,size=s_ds)[name=string(\"x2norm\")];\n", DIM, SEQ]; + + // Slice x2 [DIM, SEQ] — for residual: x_next = x2 + ffn_out + [m appendFormat:@" tensor b_x2 = const()[name=string(\"b_x2\"), val=tensor([0,0,0,%d])];\n", SEQ]; + [m appendFormat:@" 
tensor x2 = slice_by_size(x=x,begin=b_x2,size=s_ds)[name=string(\"x2\")];\n", DIM, SEQ]; + + // Slice W1 [DIM, HIDDEN] + [m appendFormat:@" tensor b_w1 = const()[name=string(\"b_w1\"), val=tensor([0,0,0,%d])];\n", 2*SEQ]; + [m appendFormat:@" tensor s_wh = const()[name=string(\"s_wh\"), val=tensor([1,%d,1,%d])];\n", DIM, HIDDEN]; + [m appendFormat:@" tensor W1 = slice_by_size(x=x,begin=b_w1,size=s_wh)[name=string(\"W1\")];\n", DIM, HIDDEN]; + + // Slice W3 [DIM, HIDDEN] + [m appendFormat:@" tensor b_w3 = const()[name=string(\"b_w3\"), val=tensor([0,0,0,%d])];\n", 2*SEQ+HIDDEN]; + [m appendFormat:@" tensor W3 = slice_by_size(x=x,begin=b_w3,size=s_wh)[name=string(\"W3\")];\n", DIM, HIDDEN]; + + // Slice W2_orig [DIM, HIDDEN] (transposed inside kernel) + [m appendFormat:@" tensor b_w2 = const()[name=string(\"b_w2\"), val=tensor([0,0,0,%d])];\n", 2*SEQ+2*HIDDEN]; + [m appendFormat:@" tensor W2r = slice_by_size(x=x,begin=b_w2,size=s_wh)[name=string(\"W2r\")];\n", DIM, HIDDEN]; + + // Reshape for matmul: x2norm [1,DIM,1,SEQ] → [1,1,DIM,SEQ] → [1,1,SEQ,DIM] + [m appendFormat:@" tensor rd = const()[name=string(\"rd\"), val=tensor([1,1,%d,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor xn2 = reshape(shape=rd,x=x2norm)[name=string(\"xn2\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM]; + + // Reshape weights: [1,DIM,1,HIDDEN] → [1,1,DIM,HIDDEN] + [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor([1,1,%d,%d])];\n", DIM, HIDDEN]; + [m appendFormat:@" tensor W12 = reshape(shape=rw,x=W1)[name=string(\"W12\")];\n", DIM, HIDDEN]; + [m appendFormat:@" tensor W32 = reshape(shape=rw,x=W3)[name=string(\"W32\")];\n", DIM, HIDDEN]; + + // h1 = x2norm_t @ W1, h3 = x2norm_t @ W3 [SEQ,DIM] @ [DIM,HIDDEN] → [SEQ,HIDDEN] + [m appendFormat:@" tensor h1m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W12)[name=string(\"h1m\")];\n", SEQ, HIDDEN]; + [m appendFormat:@" tensor h3m = 
matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W32)[name=string(\"h3m\")];\n", SEQ, HIDDEN]; + + // Reshape back: [1,1,SEQ,HIDDEN] → [1,1,HIDDEN,SEQ] → [1,HIDDEN,1,SEQ] + [m appendFormat:@" tensor h1t = transpose(perm=pm,x=h1m)[name=string(\"h1t\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor h3t = transpose(perm=pm,x=h3m)[name=string(\"h3t\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor rh = const()[name=string(\"rh\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor h1 = reshape(shape=rh,x=h1t)[name=string(\"h1\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor h3 = reshape(shape=rh,x=h3t)[name=string(\"h3\")];\n", HIDDEN, SEQ]; + + // SiLU + gate: gate = silu(h1) * h3 + [m appendFormat:@" tensor sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN, SEQ]; + + // gate @ W2: reshape gate [1,HIDDEN,1,SEQ] → [1,1,HIDDEN,SEQ] → [1,1,SEQ,HIDDEN] + [m appendFormat:@" tensor rg = const()[name=string(\"rg\"), val=tensor([1,1,%d,%d])];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor g2 = reshape(shape=rg,x=gate)[name=string(\"g2\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor gt = transpose(perm=pm,x=g2)[name=string(\"gtt\")];\n", SEQ, HIDDEN]; + + // W2: [1,DIM,1,HIDDEN] → [1,1,DIM,HIDDEN] → transpose → [1,1,HIDDEN,DIM] + [m appendFormat:@" tensor W22 = reshape(shape=rw,x=W2r)[name=string(\"W22\")];\n", DIM, HIDDEN]; + [m appendFormat:@" tensor W2t = transpose(perm=pm,x=W22)[name=string(\"W2t\")];\n", HIDDEN, DIM]; + + // matmul: [1,1,SEQ,HIDDEN] @ [1,1,HIDDEN,DIM] → [1,1,SEQ,DIM] + [m appendFormat:@" tensor fm = matmul(transpose_x=bF,transpose_y=bF,x=gt,y=W2t)[name=string(\"fm\")];\n", SEQ, DIM]; + // Reshape: [1,1,SEQ,DIM] → [1,1,DIM,SEQ] → [1,DIM,1,SEQ] + [m appendFormat:@" tensor ft = transpose(perm=pm,x=fm)[name=string(\"ft\")];\n", DIM, SEQ]; + [m 
appendFormat:@" tensor rd2 = const()[name=string(\"rd2\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor ffn_out = reshape(shape=rd2,x=ft)[name=string(\"ffn_out\")];\n", DIM, SEQ]; + + // Residual: x_next = x2 + ffn_out + [m appendFormat:@" tensor x_next = add(x=x2,y=ffn_out)[name=string(\"x_next\")];\n", DIM, SEQ]; + + // Output: concat(x_next, h1, h3, gate) — gate=silu*h3 needed for dW2 + [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; + [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(x_next,h1,h3,gate))[name=string(\"cat\")];\n", out_ch, SEQ]; + [m appendString:@" } -> (out);\n}\n"]; + return m; +} + // FFN part 2: gate @ W2 (HIDDEN→DIM) // Input: [1, HIDDEN, 1, SEQ + DIM] fp32 // sp[0:SEQ] = gate [HIDDEN,SEQ] @@ -305,7 +390,7 @@ static NSString *gen_ffn_w2_dynamic(void) { // Input: [1, DIM, 1, SEQ + HIDDEN] fp32 // sp[0:SEQ] = dffn [DIM, SEQ] // sp[SEQ:SEQ+HIDDEN]= W2^T [DIM, HIDDEN] -// Output: [1, HIDDEN, 1, SEQ] fp32 = dsilu_raw +// Output: [1, HIDDEN, 1, SEQ] fp16 = dsilu_raw static NSString *gen_ffn_bwd_w2t_dynamic(void) { return gen_dyn_matmul_mil(DIM, HIDDEN, SEQ); } @@ -320,32 +405,30 @@ static NSString *gen_ffn_bwd_w2t_dynamic(void) { // sp[SEQ:2*SEQ] = dh3 [HIDDEN,SEQ] // sp[2*SEQ:2*SEQ+DIM] = W1^T [HIDDEN,DIM] // sp[2*SEQ+DIM:2*SEQ+2D] = W3^T [HIDDEN,DIM] -// Output: [1, DIM, 1, SEQ] fp32 = dx1 + dx3 +// Output: [1, DIM, 1, SEQ] fp16 = dx1 + dx3 static NSString *gen_ffn_bwd_w13t_dynamic(void) { int sp_in = 2*SEQ + 2*DIM; NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", HIDDEN, sp_in]; - [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"]; - [m appendFormat:@" tensor xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, sp_in]; + [m appendFormat:@" func main(tensor x) {\n", 
HIDDEN, sp_in]; // Slice dh1 [HIDDEN, SEQ] [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; [m appendFormat:@" tensor sh = const()[name=string(\"sh\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor dh1 = slice_by_size(x=xh,begin=b0,size=sh)[name=string(\"dh1\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor dh1 = slice_by_size(x=x,begin=b0,size=sh)[name=string(\"dh1\")];\n", HIDDEN, SEQ]; // Slice dh3 [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,0,0,%d])];\n", SEQ]; - [m appendFormat:@" tensor dh3 = slice_by_size(x=xh,begin=b1,size=sh)[name=string(\"dh3\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor dh3 = slice_by_size(x=x,begin=b1,size=sh)[name=string(\"dh3\")];\n", HIDDEN, SEQ]; // Slice W1^T [HIDDEN, DIM] [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,0,0,%d])];\n", 2*SEQ]; [m appendFormat:@" tensor sw = const()[name=string(\"sw\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, DIM]; - [m appendFormat:@" tensor W1t = slice_by_size(x=xh,begin=b2,size=sw)[name=string(\"W1t\")];\n", HIDDEN, DIM]; + [m appendFormat:@" tensor W1t = slice_by_size(x=x,begin=b2,size=sw)[name=string(\"W1t\")];\n", HIDDEN, DIM]; // Slice W3^T [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,0,0,%d])];\n", 2*SEQ+DIM]; - [m appendFormat:@" tensor W3t = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"W3t\")];\n", HIDDEN, DIM]; + [m appendFormat:@" tensor W3t = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"W3t\")];\n", HIDDEN, DIM]; [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; @@ -370,9 +453,7 @@ static NSString *gen_ffn_bwd_w13t_dynamic(void) { [m appendFormat:@" tensor dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ]; [m appendFormat:@" tensor ro = const()[name=string(\"ro\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor dx = 
reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ]; - [m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; - [m appendFormat:@" tensor y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ]; - [m appendString:@" } -> (y);\n}\n"]; + [m appendString:@" } -> (dx);\n}\n"]; return m; } @@ -516,27 +597,25 @@ static NSString *gen_qkvb_dynamic(void) { int sp_in = 3*SEQ + 3*DIM; NSMutableString *m = [NSMutableString string]; [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; - [m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"]; - [m appendFormat:@" tensor xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in]; + [m appendFormat:@" func main(tensor x) {\n", DIM, sp_in]; // Slice dq, dk, dv [m appendFormat:@" tensor sd = const()[name=string(\"sd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; - [m appendFormat:@" tensor dq = slice_by_size(x=xh,begin=b0,size=sd)[name=string(\"dq\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor dq = slice_by_size(x=x,begin=b0,size=sd)[name=string(\"dq\")];\n", DIM, SEQ]; [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,0,0,%d])];\n", SEQ]; - [m appendFormat:@" tensor dk = slice_by_size(x=xh,begin=b1,size=sd)[name=string(\"dk\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor dk = slice_by_size(x=x,begin=b1,size=sd)[name=string(\"dk\")];\n", DIM, SEQ]; [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,0,0,%d])];\n", 2*SEQ]; - [m appendFormat:@" tensor dv = slice_by_size(x=xh,begin=b2,size=sd)[name=string(\"dv\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor dv = slice_by_size(x=x,begin=b2,size=sd)[name=string(\"dv\")];\n", DIM, SEQ]; // Slice Wq^T, Wk^T, Wv^T [m appendFormat:@" tensor sw = const()[name=string(\"sw\"), val=tensor([1,%d,1,%d])];\n", DIM, DIM]; [m appendFormat:@" 
tensor b3 = const()[name=string(\"b3\"), val=tensor([0,0,0,%d])];\n", 3*SEQ]; - [m appendFormat:@" tensor Wqt = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"Wqt\")];\n", DIM, DIM]; + [m appendFormat:@" tensor Wqt = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"Wqt\")];\n", DIM, DIM]; [m appendFormat:@" tensor b4 = const()[name=string(\"b4\"), val=tensor([0,0,0,%d])];\n", 3*SEQ+DIM]; - [m appendFormat:@" tensor Wkt = slice_by_size(x=xh,begin=b4,size=sw)[name=string(\"Wkt\")];\n", DIM, DIM]; + [m appendFormat:@" tensor Wkt = slice_by_size(x=x,begin=b4,size=sw)[name=string(\"Wkt\")];\n", DIM, DIM]; [m appendFormat:@" tensor b5 = const()[name=string(\"b5\"), val=tensor([0,0,0,%d])];\n", 3*SEQ+2*DIM]; - [m appendFormat:@" tensor Wvt = slice_by_size(x=xh,begin=b5,size=sw)[name=string(\"Wvt\")];\n", DIM, DIM]; + [m appendFormat:@" tensor Wvt = slice_by_size(x=x,begin=b5,size=sw)[name=string(\"Wvt\")];\n", DIM, DIM]; [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; @@ -570,9 +649,7 @@ static NSString *gen_qkvb_dynamic(void) { [m appendFormat:@" tensor dxt = transpose(perm=pm,x=dxall)[name=string(\"dxt\")];\n", DIM, SEQ]; [m appendFormat:@" tensor ro = const()[name=string(\"ro\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ]; - [m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; - [m appendFormat:@" tensor y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ]; - [m appendString:@" } -> (y);\n}\n"]; + [m appendString:@" } -> (dx);\n}\n"]; return m; } diff --git a/training/training_dynamic/train.m b/training/training_dynamic/train.m index 412c4d8..01362e5 100644 --- a/training/training_dynamic/train.m +++ b/training/training_dynamic/train.m @@ -11,8 +11,7 @@ // Dynamic kernel set per layer typedef struct { Kern 
*sdpaFwd; // QKV matmul + SDPA + Wo matmul (dynamic weights via IOSurface) - Kern *ffnW13; // W1,W3 matmul (dynamic) - Kern *ffnW2; // W2 matmul (dynamic) + Kern *ffnFused; // residual + RMSNorm + W1,W3 + SiLU + W2 + residual (fused) Kern *ffnBwdW2t; // dffn @ W2^T (dynamic) Kern *ffnBwdW13t; // dh1@W1^T + dh3@W3^T (dynamic) Kern *wotBwd; // dx2 @ Wo^T (dynamic) @@ -60,40 +59,36 @@ static void transpose_weight(float *dst, const float *src, int rows, int cols) { static bool compile_dynamic_kernels(DynLayerKernels *dk) { NSDictionary *mask_w = @{@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}}; - // SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp32 → [1, 6*DIM, 1, SEQ] fp32 + // SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp16 → [1, 6*DIM, 1, SEQ] fp16 printf(" Compiling sdpaFwd...\n"); dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), mask_w, - DIM*(SEQ+4*DIM)*4, 6*DIM*SEQ*4); + DIM*(SEQ+4*DIM)*2, 6*DIM*SEQ*2); if (!dk->sdpaFwd) return false; - // FFN W1+W3: [1, DIM, 1, SEQ+2*HIDDEN] fp32 → [1, 3*HIDDEN, 1, SEQ] fp32 - printf(" Compiling ffnW13...\n"); - dk->ffnW13 = compile_kern_mil_w(gen_ffn_w13_dynamic(), @{}, - DIM*(SEQ+2*HIDDEN)*4, 3*HIDDEN*SEQ*4); - if (!dk->ffnW13) return false; + // Fused FFN: W1,W3 + SiLU + W2 + residual (RMSNorm on CPU) + printf(" Compiling ffnFused...\n"); + int ffn_fused_sp = 2*SEQ + 3*HIDDEN; + int ffn_fused_och = DIM + 3*HIDDEN; + dk->ffnFused = compile_kern_mil_w(gen_ffn_fused_dynamic(), @{}, + DIM*ffn_fused_sp*2, ffn_fused_och*SEQ*2); + if (!dk->ffnFused) return false; - // FFN W2: [1, HIDDEN, 1, SEQ+DIM] fp32 → [1, DIM, 1, SEQ] fp32 - printf(" Compiling ffnW2...\n"); - dk->ffnW2 = compile_kern_mil_w(gen_ffn_w2_dynamic(), @{}, - HIDDEN*(SEQ+DIM)*4, DIM*SEQ*4); - if (!dk->ffnW2) return false; - - // FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] fp32 → [1, HIDDEN, 1, SEQ] fp32 + // FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] fp16 → [1, HIDDEN, 1, SEQ] fp16 printf(" Compiling ffnBwdW2t...\n"); dk->ffnBwdW2t = 
compile_kern_mil_w(gen_ffn_bwd_w2t_dynamic(), @{}, - DIM*(SEQ+HIDDEN)*4, HIDDEN*SEQ*4); + DIM*(SEQ+HIDDEN)*2, HIDDEN*SEQ*2); if (!dk->ffnBwdW2t) return false; - // FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp32 → [1, DIM, 1, SEQ] fp32 + // FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp16 → [1, DIM, 1, SEQ] fp16 printf(" Compiling ffnBwdW13t...\n"); dk->ffnBwdW13t = compile_kern_mil_w(gen_ffn_bwd_w13t_dynamic(), @{}, - HIDDEN*(2*SEQ+2*DIM)*4, DIM*SEQ*4); + HIDDEN*(2*SEQ+2*DIM)*2, DIM*SEQ*2); if (!dk->ffnBwdW13t) return false; - // Wo^T backward: [1, DIM, 1, SEQ+DIM] fp32 → [1, DIM, 1, SEQ] fp32 + // Wo^T backward: [1, DIM, 1, SEQ+DIM] fp16 → [1, DIM, 1, SEQ] fp16 printf(" Compiling wotBwd...\n"); dk->wotBwd = compile_kern_mil_w(gen_wot_dynamic(), @{}, - DIM*(SEQ+DIM)*4, DIM*SEQ*4); + DIM*(SEQ+DIM)*2, DIM*SEQ*2); if (!dk->wotBwd) return false; // SDPA bwd1 (no dynamic weights, has mask): [1, 4*DIM, 1, SEQ] fp16 → [1, DIM+2*SCORE_CH, 1, SEQ] fp16 @@ -108,10 +103,10 @@ static bool compile_dynamic_kernels(DynLayerKernels *dk) { (2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2); if (!dk->sdpaBwd2) return false; - // QKV backward: [1, DIM, 1, 3*SEQ+3*DIM] fp32 → [1, DIM, 1, SEQ] fp32 + // QKV backward: [1, DIM, 1, 3*SEQ+3*DIM] fp16 → [1, DIM, 1, SEQ] fp16 printf(" Compiling qkvBwd...\n"); dk->qkvBwd = compile_kern_mil_w(gen_qkvb_dynamic(), @{}, - DIM*(3*SEQ+3*DIM)*4, DIM*SEQ*4); + DIM*(3*SEQ+3*DIM)*2, DIM*SEQ*2); if (!dk->qkvBwd) return false; return true; @@ -134,32 +129,6 @@ static void write_sdpa_fwd_input(DynLayerKernels *dk, const float *xnorm, IOSurfaceUnlock(dk->sdpaFwd->ioIn, 0, NULL); } -// ffnW13: [1, DIM, 1, SEQ+2*HIDDEN] — xnorm at sp[0:S], W1,W3 at sp[S:] -static void write_ffn_w13_input(DynLayerKernels *dk, const float *xnorm, - const float *W1, const float *W3) { - IOSurfaceLock(dk->ffnW13->ioIn, 0, NULL); - float *buf = (float*)IOSurfaceGetBaseAddress(dk->ffnW13->ioIn); - int sp = SEQ + 2*HIDDEN; - for (int d = 0; d < DIM; d++) { - memcpy(buf + 
d*sp, xnorm + d*SEQ, SEQ*4); - memcpy(buf + d*sp + SEQ, W1 + d*HIDDEN, HIDDEN*4); - memcpy(buf + d*sp + SEQ+HIDDEN, W3 + d*HIDDEN, HIDDEN*4); - } - IOSurfaceUnlock(dk->ffnW13->ioIn, 0, NULL); -} - -// ffnW2: [1, HIDDEN, 1, SEQ+DIM] — gate at sp[0:S], W2 at sp[S:] -static void write_ffn_w2_input(DynLayerKernels *dk, const float *gate, const float *W2) { - IOSurfaceLock(dk->ffnW2->ioIn, 0, NULL); - float *buf = (float*)IOSurfaceGetBaseAddress(dk->ffnW2->ioIn); - int sp = SEQ + DIM; - for (int d = 0; d < HIDDEN; d++) { - memcpy(buf + d*sp, gate + d*SEQ, SEQ*4); - memcpy(buf + d*sp + SEQ, W2 + d*DIM, DIM*4); - } - IOSurfaceUnlock(dk->ffnW2->ioIn, 0, NULL); -} - // ===== Checkpoint ===== static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss, double ct, double cw, int cs, int adam_t, @@ -238,11 +207,13 @@ int main(int argc, char *argv[]) { int total_steps = 10000; float max_lr = 3e-4f; - float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f; + float adam_b1=0.9f, adam_b2=0.95f, adam_eps=1e-8f, wd=0.1f; int adam_t = 0, start_step = 0; int accum_steps = 10; int warmup_steps = 100; float grad_clip = 1.0f; + float loss_scale = 256.0f; // fp16 loss scaling for ANE backward + float act_clip = 20.0f; float min_lr_frac = 0.1f; // min_lr = max_lr * 0.1 bool do_resume = false, from_scratch = false; @@ -287,7 +258,7 @@ int main(int argc, char *argv[]) { double xformer_m = (double)NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ + 2.0*DIM) / 1e6; double embed_m = (double)VOCAB*DIM / 1e6; printf("Params: %.1fM (transformer %.1fM + embed %.1fM)\n", xformer_m+embed_m, xformer_m, embed_m); - printf("Kernels: 9 compiled, 9 weight-bearing\n"); + printf("Kernels: 8 compiled (ffnFused replaces ffnW13+ffnW2, RMSNorm on CPU)\n"); printf("Accum %d steps, LR=%g\n", accum_steps, max_lr); // FLOPs estimate: 6*N*B*T for transformer (forward+backward ≈ 3x forward) double fwd_flops = 2.0*NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ) * SEQ; @@ -353,14 +324,48 @@ 
int main(int argc, char *argv[]) { AdamState acembed = adam_alloc((size_t)CV*DIM); // ===== Compile all kernels ONCE ===== - printf("Compiling %d dynamic kernels (one-time)...\n", 9); + printf("Compiling %d dynamic kernels (one-time)...\n", 8); uint64_t tc = mach_absolute_time(); DynLayerKernels dk; if (!compile_dynamic_kernels(&dk)) { printf("Compilation failed!\n"); return 1; } double compile_ms = tb_ms(mach_absolute_time() - tc); - printf("Compiled 9 kernels in %.0fms (shared across all %d layers)\n\n", compile_ms, NLAYERS); + printf("Compiled 8 kernels in %.0fms (shared across all %d layers)\n", compile_ms, NLAYERS); + + // Allocate per-layer IOSurfaces + requests (pre-stage weights) + int per_layer_bytes = (DIM*(SEQ+4*DIM) + DIM*(2*SEQ+3*HIDDEN) + + DIM*(SEQ+HIDDEN) + HIDDEN*(2*SEQ+2*DIM) + DIM*(SEQ+DIM) + DIM*(3*SEQ+3*DIM)) * 2; + int total_surf_mb = (int)((long)per_layer_bytes * NLAYERS / (1024*1024)); + printf("Allocating per-layer IOSurfaces (%d surfaces, ~%dMB fp16)...\n", NLAYERS*6, total_surf_mb); + PerLayerSurfaces pls[NLAYERS]; + PerLayerRequests plr[NLAYERS]; + for (int L = 0; L < NLAYERS; L++) { + pls[L].sdpaFwd_in = make_surface(DIM*(SEQ+4*DIM)*2); + pls[L].ffnFused_in = make_surface(DIM*(2*SEQ+3*HIDDEN)*2); + pls[L].ffnBwdW2t_in = make_surface(DIM*(SEQ+HIDDEN)*2); + pls[L].ffnBwdW13t_in= make_surface(HIDDEN*(2*SEQ+2*DIM)*2); + pls[L].wotBwd_in = make_surface(DIM*(SEQ+DIM)*2); + pls[L].qkvBwd_in = make_surface(DIM*(3*SEQ+3*DIM)*2); + + plr[L].sdpaFwd = make_request(dk.sdpaFwd, pls[L].sdpaFwd_in); + plr[L].ffnFused = make_request(dk.ffnFused, pls[L].ffnFused_in); + plr[L].ffnBwdW2t = make_request(dk.ffnBwdW2t, pls[L].ffnBwdW2t_in); + plr[L].ffnBwdW13t= make_request(dk.ffnBwdW13t,pls[L].ffnBwdW13t_in); + plr[L].wotBwd = make_request(dk.wotBwd, pls[L].wotBwd_in); + plr[L].qkvBwd = make_request(dk.qkvBwd, pls[L].qkvBwd_in); + } + + // Stage weights into per-layer surfaces + for (int L = 0; L < NLAYERS; L++) { + stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, 
Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]); + stage_ffn_fused_weights(pls[L].ffnFused_in, W1t_buf[L], W3t_buf[L], lw[L].W2); + stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, W2t_buf[L]); + stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, W1t_buf[L], W3t_buf[L]); + stage_wot_bwd_weights(pls[L].wotBwd_in, Wot_buf[L]); + stage_qkv_bwd_weights(pls[L].qkvBwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L]); + } + printf("Per-layer weight staging complete\n\n"); // Gradient + work buffers float *dy = (float*)malloc(SEQ*DIM*4); @@ -428,70 +433,63 @@ int main(int argc, char *argv[]) { dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER); t_cblas_wait += tb_ms(mach_absolute_time() - t0); - // SDPA forward (ANE): xnorm + Wq,Wk,Wv,Wo → o_out,Q,K,V,attn_out,xnorm + // SDPA forward (ANE): xnorm + pre-staged Wq,Wk,Wv,Wo → o_out,Q,K,V,attn_out,xnorm t0 = mach_absolute_time(); - write_sdpa_fwd_input(&dk, xnorm_buf, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]); + write_sdpa_fwd_acts(pls[L].sdpaFwd_in, xnorm_buf); t_io_fwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); - ane_eval(dk.sdpaFwd); + ane_eval_req(dk.sdpaFwd, plr[L].sdpaFwd); t_ane_fwd += tb_ms(mach_absolute_time() - t0); - // Read output: [1, 6*DIM, 1, SEQ] fp32 + // Read output: [1, 6*DIM, 1, SEQ] fp16 t0 = mach_absolute_time(); IOSurfaceLock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL); - float *fwd_out = (float*)IOSurfaceGetBaseAddress(dk.sdpaFwd->ioOut); - memcpy(ac->o_out, fwd_out + 0*DIM*SEQ, DIM*SEQ*4); - memcpy(ac->Q, fwd_out + 1*DIM*SEQ, DIM*SEQ*4); - memcpy(ac->K, fwd_out + 2*DIM*SEQ, DIM*SEQ*4); - memcpy(ac->V, fwd_out + 3*DIM*SEQ, DIM*SEQ*4); - memcpy(ac->attn_out, fwd_out + 4*DIM*SEQ, DIM*SEQ*4); + _Float16 *fwd_out = (_Float16*)IOSurfaceGetBaseAddress(dk.sdpaFwd->ioOut); + cvt_f16_f32(ac->o_out, fwd_out + 0*DIM*SEQ, DIM*SEQ); + cvt_f16_f32(ac->Q, fwd_out + 1*DIM*SEQ, DIM*SEQ); + cvt_f16_f32(ac->K, fwd_out + 2*DIM*SEQ, DIM*SEQ); + cvt_f16_f32(ac->V, fwd_out + 3*DIM*SEQ, DIM*SEQ); + 
cvt_f16_f32(ac->attn_out, fwd_out + 4*DIM*SEQ, DIM*SEQ); IOSurfaceUnlock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL); t_io_fwd += tb_ms(mach_absolute_time() - t0); - // Residual: x2 = x_cur + o_out - vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM)); - - // RMSNorm2 (CPU) + // CPU: residual + RMSNorm (ANE can't fuse RMS with 3 matmuls) t0 = mach_absolute_time(); - rmsnorm(xnorm_buf, ac->x2, lw[L].rms_ffn, DIM, SEQ); - memcpy(ac->x2norm, xnorm_buf, SEQ*DIM*4); + vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM)); + rmsnorm(ac->x2norm, ac->x2, lw[L].rms_ffn, DIM, SEQ); t_rms += tb_ms(mach_absolute_time() - t0); - // FFN W1+W3 (ANE): xnorm → h1, h3, gate + // Fused FFN (ANE): W1,W3 + SiLU + W2 + residual + // Input: x2norm + x2 (acts), W1t + W3t + W2 (pre-staged weights) + // Output: x_next, h1, h3, silu_out t0 = mach_absolute_time(); - write_ffn_w13_input(&dk, xnorm_buf, W1t_buf[L], W3t_buf[L]); + write_ffn_fused_acts(pls[L].ffnFused_in, ac->x2norm, ac->x2); t_io_fwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); - ane_eval(dk.ffnW13); + ane_eval_req(dk.ffnFused, plr[L].ffnFused); t_ane_fwd += tb_ms(mach_absolute_time() - t0); - // Read h1, h3, gate from output [1, 3*HIDDEN, 1, SEQ] + // Read fused output: [1, DIM+3*HIDDEN, 1, SEQ] fp16 + // Layout: x_next[DIM], h1[HIDDEN], h3[HIDDEN], silu_out[HIDDEN] t0 = mach_absolute_time(); - IOSurfaceLock(dk.ffnW13->ioOut, kIOSurfaceLockReadOnly, NULL); - float *ffn13_out = (float*)IOSurfaceGetBaseAddress(dk.ffnW13->ioOut); - memcpy(ac->h1, ffn13_out, HIDDEN*SEQ*4); - memcpy(ac->h3, ffn13_out + HIDDEN*SEQ, HIDDEN*SEQ*4); - memcpy(gate_buf, ffn13_out + 2*HIDDEN*SEQ, HIDDEN*SEQ*4); - memcpy(ac->silu_out, gate_buf, HIDDEN*SEQ*4); - IOSurfaceUnlock(dk.ffnW13->ioOut, kIOSurfaceLockReadOnly, NULL); + IOSurfaceLock(dk.ffnFused->ioOut, kIOSurfaceLockReadOnly, NULL); + _Float16 *ffn_out = (_Float16*)IOSurfaceGetBaseAddress(dk.ffnFused->ioOut); + int off = 0; + cvt_f16_f32(x_cur, 
ffn_out + off, DIM*SEQ); off += DIM*SEQ; + cvt_f16_f32(ac->h1, ffn_out + off, HIDDEN*SEQ); off += HIDDEN*SEQ; + cvt_f16_f32(ac->h3, ffn_out + off, HIDDEN*SEQ); off += HIDDEN*SEQ; + cvt_f16_f32(ac->silu_out,ffn_out + off, HIDDEN*SEQ); + IOSurfaceUnlock(dk.ffnFused->ioOut, kIOSurfaceLockReadOnly, NULL); t_io_fwd += tb_ms(mach_absolute_time() - t0); - // FFN W2 (ANE): gate @ W2 → ffn_out - t0 = mach_absolute_time(); - write_ffn_w2_input(&dk, gate_buf, W2t_buf[L]); - t_io_fwd += tb_ms(mach_absolute_time() - t0); - t0 = mach_absolute_time(); - ane_eval(dk.ffnW2); - t_ane_fwd += tb_ms(mach_absolute_time() - t0); - - t0 = mach_absolute_time(); - IOSurfaceLock(dk.ffnW2->ioOut, kIOSurfaceLockReadOnly, NULL); - memcpy(ac->ffn_out, (float*)IOSurfaceGetBaseAddress(dk.ffnW2->ioOut), DIM*SEQ*4); - IOSurfaceUnlock(dk.ffnW2->ioOut, kIOSurfaceLockReadOnly, NULL); - t_io_fwd += tb_ms(mach_absolute_time() - t0); - - // Residual: x_cur = x2 + ffn_out - vDSP_vadd(ac->x2, 1, ac->ffn_out, 1, x_cur, 1, (vDSP_Length)(SEQ*DIM)); + // Scale down residual stream if max magnitude exceeds threshold + { + float amx; vDSP_maxmgv(x_cur, 1, &amx, (vDSP_Length)(SEQ*DIM)); + if (amx > act_clip) { + float sc = act_clip / amx; + vDSP_vsmul(x_cur, 1, &sc, x_cur, 1, (vDSP_Length)(SEQ*DIM)); + } + } } // Final RMSNorm + classifier + loss (CPU) @@ -507,6 +505,10 @@ int main(int argc, char *argv[]) { last_loss = loss; // ===== BACKWARD ===== + // Loss scaling: scale dlogits to prevent fp16 underflow in ANE backward kernels + // All gradients flow scaled; weight grads divided by loss_scale before Adam + vDSP_vsmul(dlogits, 1, &loss_scale, dlogits, 1, (vDSP_Length)(SEQ*CV)); + // Classifier backward: dy[DIM, SEQ] = cembed^T[DIM, CV] @ dlogits[CV, SEQ] t0 = mach_absolute_time(); cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, @@ -531,12 +533,12 @@ int main(int argc, char *argv[]) { LayerGrads *gr = &grads[L]; memcpy(dffn, dy, SEQ*DIM*4); - // FFN backward: dffn @ W2^T → dsilu_raw + // FFN backward: dffn @ 
pre-staged W2^T → dsilu_raw t0 = mach_absolute_time(); - io_write_dyn(dk.ffnBwdW2t->ioIn, dffn, DIM, SEQ, lw[L].W2, HIDDEN); + write_ffn_bwd_w2t_acts(pls[L].ffnBwdW2t_in, dffn); t_io_bwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); - ane_eval(dk.ffnBwdW2t); + ane_eval_req(dk.ffnBwdW2t, plr[L].ffnBwdW2t); t_ane_bwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); io_read_dyn(dk.ffnBwdW2t->ioOut, dsilu, HIDDEN, SEQ); @@ -569,23 +571,12 @@ int main(int argc, char *argv[]) { } t_silu += tb_ms(mach_absolute_time() - t0); - // dh1@W1^T + dh3@W3^T → dx_ffn (ANE) + // dh1@W1^T + dh3@W3^T → dx_ffn (ANE, pre-staged weights) t0 = mach_absolute_time(); - { - IOSurfaceLock(dk.ffnBwdW13t->ioIn, 0, NULL); - float *buf = (float*)IOSurfaceGetBaseAddress(dk.ffnBwdW13t->ioIn); - int sp = 2*SEQ + 2*DIM; - for (int d = 0; d < HIDDEN; d++) { - memcpy(buf + d*sp, dh1 + d*SEQ, SEQ*4); - memcpy(buf + d*sp + SEQ, dh3 + d*SEQ, SEQ*4); - memcpy(buf + d*sp + 2*SEQ, lw[L].W1 + d*DIM, DIM*4); - memcpy(buf + d*sp + 2*SEQ + DIM, lw[L].W3 + d*DIM, DIM*4); - } - IOSurfaceUnlock(dk.ffnBwdW13t->ioIn, 0, NULL); - } + write_ffn_bwd_w13t_acts(pls[L].ffnBwdW13t_in, dh1, dh3); t_io_bwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); - ane_eval(dk.ffnBwdW13t); + ane_eval_req(dk.ffnBwdW13t, plr[L].ffnBwdW13t); t_ane_bwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); io_read_dyn(dk.ffnBwdW13t->ioOut, dx_ffn, DIM, SEQ); @@ -616,12 +607,12 @@ int main(int argc, char *argv[]) { for(int i=0;iioIn, dx2, DIM, SEQ, lw[L].Wo, DIM); + write_wot_bwd_acts(pls[L].wotBwd_in, dx2); t_io_bwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); - ane_eval(dk.wotBwd); + ane_eval_req(dk.wotBwd, plr[L].wotBwd); t_ane_bwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); float *da_buf = (float*)malloc(SEQ*DIM*4); @@ -639,6 +630,19 @@ int main(int argc, char *argv[]) { free(capt_do); free(capt_attn); }); + if (L == 0 && step % 10 == 0) { + 
float damx, dx2mx, dx2mean; + vDSP_maxmgv(da_buf, 1, &damx, (vDSP_Length)(SEQ*DIM)); + vDSP_maxmgv(dx2, 1, &dx2mx, (vDSP_Length)(SEQ*DIM)); + vDSP_meamgv(dx2, 1, &dx2mean, (vDSP_Length)(SEQ*DIM)); + // Count how many dx2 values survive fp16 conversion + int nz = 0; + for (int i=0; iioIn, 0, ac->Q, DIM, SEQ); @@ -667,6 +671,15 @@ int main(int argc, char *argv[]) { io_read_fp16(dk.sdpaBwd1->ioOut, dv, 0, DIM, SEQ); t_io_bwd += tb_ms(mach_absolute_time() - t0); + // Debug: check SDPA backward output magnitudes + if (L == 0 && step % 10 == 0) { + float dqmx, dkmx, dvmx; + vDSP_maxmgv(dq, 1, &dqmx, (vDSP_Length)(SEQ*DIM)); + vDSP_maxmgv(dk_buf, 1, &dkmx, (vDSP_Length)(SEQ*DIM)); + vDSP_maxmgv(dv, 1, &dvmx, (vDSP_Length)(SEQ*DIM)); + printf(" L0 sdpa_bwd: |dq|=%.6f |dk|=%.6f |dv|=%.6f\n", dqmx, dkmx, dvmx); + } + // dWq/dWk/dWv async t0 = mach_absolute_time(); float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4); @@ -684,25 +697,12 @@ int main(int argc, char *argv[]) { free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn); }); - // QKV backward (ANE): dq,dk,dv @ Wq^T,Wk^T,Wv^T → dx_attn + // QKV backward (ANE): dq,dk,dv @ pre-staged Wq^T,Wk^T,Wv^T → dx_attn t0 = mach_absolute_time(); - { - IOSurfaceLock(dk.qkvBwd->ioIn, 0, NULL); - float *buf = (float*)IOSurfaceGetBaseAddress(dk.qkvBwd->ioIn); - int sp = 3*SEQ + 3*DIM; - for (int d = 0; d < DIM; d++) { - memcpy(buf + d*sp, dq + d*SEQ, SEQ*4); - memcpy(buf + d*sp + SEQ, dk_buf + d*SEQ, SEQ*4); - memcpy(buf + d*sp + 2*SEQ, dv + d*SEQ, SEQ*4); - memcpy(buf + d*sp + 3*SEQ, lw[L].Wq + d*DIM, DIM*4); - memcpy(buf + d*sp + 3*SEQ+DIM, lw[L].Wk + d*DIM, DIM*4); - memcpy(buf + d*sp + 3*SEQ+2*DIM, lw[L].Wv + d*DIM, DIM*4); - } - IOSurfaceUnlock(dk.qkvBwd->ioIn, 0, NULL); - } + write_qkv_bwd_acts(pls[L].qkvBwd_in, dq, dk_buf, dv); t_io_bwd += tb_ms(mach_absolute_time() - t0); t0 = mach_absolute_time(); - ane_eval(dk.qkvBwd); + ane_eval_req(dk.qkvBwd, plr[L].qkvBwd); t_ane_bwd += tb_ms(mach_absolute_time() 
- t0); t0 = mach_absolute_time(); io_read_dyn(dk.qkvBwd->ioOut, dx_attn, DIM, SEQ); @@ -741,10 +741,10 @@ int main(int argc, char *argv[]) { // Adam update every accum_steps if ((step+1) % accum_steps == 0 || step == total_steps-1) { dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER); - float gsc = 1.0f / accum_steps; + float gsc = 1.0f / (accum_steps * loss_scale); adam_t++; - // Scale gradients by 1/accum_steps + // Scale gradients by 1/(accum_steps * loss_scale) for (int L=0; LWq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;} @@ -778,7 +778,25 @@ int main(int argc, char *argv[]) { vDSP_dotpr(gembed,1,gembed,1,&s,(vDSP_Length)(VOCAB*DIM)); grad_norm_sq+=s; } float grad_norm = sqrtf(grad_norm_sq); - if ((step+1) % 10 == 0) printf(" grad_norm=%.4f\n", grad_norm); + if ((step+1) % 10 == 0) { + // Per-component gradient norms for diagnostics + float attn_sq=0, ffn_sq=0, embed_sq=0; + for (int L=0; LWq,1,g->Wq,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s; + vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s; + vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s; + vDSP_dotpr(g->Wo,1,g->Wo,1,&s,(vDSP_Length)WO_SZ); attn_sq+=s; + vDSP_dotpr(g->W1,1,g->W1,1,&s,(vDSP_Length)W1_SZ); ffn_sq+=s; + vDSP_dotpr(g->W2,1,g->W2,1,&s,(vDSP_Length)W2_SZ); ffn_sq+=s; + vDSP_dotpr(g->W3,1,g->W3,1,&s,(vDSP_Length)W3_SZ); ffn_sq+=s; + } + { float s; + vDSP_dotpr(gembed,1,gembed,1,&s,(vDSP_Length)(VOCAB*DIM)); embed_sq=s; + } + printf(" grad_norm=%.4f attn=%.4f ffn=%.4f embed=%.4f\n", + grad_norm, sqrtf(attn_sq), sqrtf(ffn_sq), sqrtf(embed_sq)); + } // Gradient clipping if (grad_clip > 0 && grad_norm > grad_clip) { @@ -811,15 +829,15 @@ int main(int argc, char *argv[]) { // Adam update for (int L=0; LWq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps); - adam_update(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps); - adam_update(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps); - adam_update(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, 
lr, adam_b1, adam_b2, adam_eps); - adam_update(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps); - adam_update(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps); - adam_update(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps); - adam_update(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps); - adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps); + adam_update(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps, wd); + adam_update(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps, wd); + adam_update(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps, wd); + adam_update(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps, wd); + adam_update(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps, wd); + adam_update(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps, wd); + adam_update(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps, wd); + adam_update(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f); + adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f); // Update transposed weight buffers transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM); @@ -829,9 +847,17 @@ int main(int argc, char *argv[]) { transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM); transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN); transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM); + + // Re-stage weights into per-layer IOSurfaces + stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]); + stage_ffn_fused_weights(pls[L].ffnFused_in, W1t_buf[L], W3t_buf[L], lw[L].W2); + stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, W2t_buf[L]); + stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, W1t_buf[L], W3t_buf[L]); + stage_wot_bwd_weights(pls[L].wotBwd_in, 
Wot_buf[L]); + stage_qkv_bwd_weights(pls[L].qkvBwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L]); } - adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps); - adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps); + adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f); + adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps, wd); // Re-extract compact embed from updated full embed free(cembed); cembed = vocab_compact_embed(embed, &vm, DIM); @@ -867,7 +893,8 @@ int main(int argc, char *argv[]) { free(Wqt_buf[L]); free(Wkt_buf[L]); free(Wvt_buf[L]); free(Wot_buf[L]); free(W1t_buf[L]); free(W2t_buf[L]); free(W3t_buf[L]); } - free_kern(dk.sdpaFwd); free_kern(dk.ffnW13); free_kern(dk.ffnW2); + free_per_layer(pls, plr); + free_kern(dk.sdpaFwd); free_kern(dk.ffnFused); free_kern(dk.ffnBwdW2t); free_kern(dk.ffnBwdW13t); free_kern(dk.wotBwd); free_kern(dk.sdpaBwd1); free_kern(dk.sdpaBwd2); free_kern(dk.qkvBwd); munmap(token_data, data_len); close(data_fd);